diff --git a/gigl/nn/graph_transformer.py b/gigl/nn/graph_transformer.py
index f9d8b345a..28af5127f 100644
--- a/gigl/nn/graph_transformer.py
+++ b/gigl/nn/graph_transformer.py
@@ -22,6 +22,7 @@
 from gigl.src.common.types.graph_data import EdgeType, NodeType
 from gigl.transforms.graph_transformer import (
+    ONE_HOP_MASKED_ATTENTION_BIAS_ATTR_NAME,
     PPR_WEIGHT_FEATURE_NAME,
     SequenceAuxiliaryData,
     TokenInputData,
@@ -432,6 +433,9 @@ class GraphTransformerEncoder(nn.Module):
         pairwise_attention_bias_attr_names: List of pairwise feature names used as
             additive attention bias. These must correspond to sparse graph-level
             attributes on ``data``.
+        pairwise_hard_attention_bias_attr_names: List of hard pairwise attention
+            constraint names. These are structural controls synthesized from the
+            sampled graph. Supported values: ``['one_hop_masked']``.
         feature_embedding_layer_dict: Optional ModuleDict mapping node types to
             feature embedding layers. If provided, these are applied to node
             features before node projection. (default: None)
@@ -495,6 +499,7 @@ def __init__(
         anchor_based_input_attr_names: Optional[list[str]] = None,
         anchor_based_input_embedding_dict: Optional[nn.ModuleDict] = None,
         pairwise_attention_bias_attr_names: Optional[list[str]] = None,
+        pairwise_hard_attention_bias_attr_names: Optional[list[str]] = None,
         feature_embedding_layer_dict: Optional[nn.ModuleDict] = None,
         pe_integration_mode: Literal["concat", "add"] = "concat",
         activation: str = "gelu",
@@ -543,6 +548,19 @@
         anchor_bias_attr_names = anchor_based_attention_bias_attr_names or []
         anchor_input_attr_names = anchor_based_input_attr_names or []
         pairwise_bias_attr_names = pairwise_attention_bias_attr_names or []
+        pairwise_hard_bias_attr_names = (
+            pairwise_hard_attention_bias_attr_names or []
+        )
+        unsupported_pairwise_hard_bias_attr_names = sorted(
+            set(pairwise_hard_bias_attr_names)
+            - {ONE_HOP_MASKED_ATTENTION_BIAS_ATTR_NAME}
+        )
+        if unsupported_pairwise_hard_bias_attr_names:
+            raise ValueError(
+                "Unsupported pairwise hard attention bias attr names "
+                f"{unsupported_pairwise_hard_bias_attr_names}. Supported values: "
+                f"['{ONE_HOP_MASKED_ATTENTION_BIAS_ATTR_NAME}']."
+            )
         if PPR_WEIGHT_FEATURE_NAME in pairwise_bias_attr_names:
             raise ValueError(
                 f"'{PPR_WEIGHT_FEATURE_NAME}' is an anchor-relative feature and "
@@ -568,6 +586,9 @@
         self._anchor_based_input_attr_names = anchor_based_input_attr_names
         self._anchor_based_input_embedding_dict = anchor_based_input_embedding_dict
         self._pairwise_attention_bias_attr_names = pairwise_attention_bias_attr_names
+        self._pairwise_hard_attention_bias_attr_names = (
+            pairwise_hard_attention_bias_attr_names
+        )
         self._feature_embedding_layer_dict = feature_embedding_layer_dict
         self._pe_integration_mode = pe_integration_mode
         self._num_heads = num_heads
@@ -801,6 +822,7 @@ def forward(
             anchor_based_attention_bias_attr_names=self._anchor_based_attention_bias_attr_names,
             anchor_based_input_attr_names=self._anchor_based_input_attr_names,
             pairwise_attention_bias_attr_names=self._pairwise_attention_bias_attr_names,
+            pairwise_hard_attention_bias_attr_names=self._pairwise_hard_attention_bias_attr_names,
         )
 
         # Free memory after sequences are built
@@ -942,18 +964,23 @@ def _build_attention_bias(
         """Build additive attention bias from padding mask and learned relative PE projections.
 
         This function constructs a combined attention bias tensor that is added to
-        attention scores before softmax. The bias has three components:
+        attention scores before softmax. The bias has four components:
 
         1. **Padding mask bias**: Sets padded positions to -inf so they receive zero
            attention weight after softmax. Shape: (batch, 1, 1, seq) broadcasts to
           (batch, num_heads, seq, seq) for key masking.
 
-        2. **Anchor-relative bias** (optional): For each sequence position, looks up
+        2. **Pairwise hard mask** (optional): For each valid query-key pair, blocks
+           structurally illegal edges with a large negative value. Input shape:
+           ``(batch, seq, seq)``. Only applied on valid query rows so padded rows
+           do not become all ``-inf``.
+
+        3. **Anchor-relative bias** (optional): For each sequence position, looks up
            the PE value relative to the anchor (e.g., hop distance from anchor).
            Input shape: (batch, seq, num_anchor_attrs)
            After projection: (batch, num_heads, 1, seq) - same bias for all query positions.
 
-        3. **Pairwise bias** (optional): For each (query, key) pair, looks up the PE
+        4. **Pairwise bias** (optional): For each (query, key) pair, looks up the PE
            value between those two nodes (e.g., random walk structural encoding).
            Input shape: (batch, seq, seq, num_pairwise_attrs)
            After projection: (batch, num_heads, seq, seq) - unique bias per query-key pair.
@@ -965,6 +992,7 @@
                 used only to infer dtype and device.
             attention_bias_data: Dictionary containing optional PE tensors:
                 - "anchor_bias": (batch, seq, num_anchor_attrs) or None
+                - "pairwise_hard_mask": (batch, seq, seq) or None
                 - "pairwise_bias": (batch, seq, seq, num_pairwise_attrs) or None
 
         Returns:
@@ -977,6 +1005,7 @@
         #
         # Output attn_bias shape: (2, 8, 4, 4)
         # - Positions where valid_mask is False get -inf
+        # - Hard mask blocks structurally illegal pairs
        # - Anchor bias adds per-key bias (same for all queries)
         # - Pairwise bias adds unique bias for each (query, key) pair
         """
@@ -997,7 +1026,29 @@
             negative_inf,
         )
 
-        # Step 2: Add anchor-relative bias (optional)
+        # Step 2: Add pairwise hard mask (optional)
+        pairwise_hard_mask = attention_bias_data.get("pairwise_hard_mask")
+        if pairwise_hard_mask is not None:
+            if pairwise_hard_mask.shape != (batch_size, seq_len, seq_len):
+                raise ValueError(
+                    "Pairwise hard mask must have shape "
+                    f"({batch_size}, {seq_len}, {seq_len}), "
+                    f"got {tuple(pairwise_hard_mask.shape)}."
+                )
+            valid_query_rows = valid_mask.unsqueeze(1).unsqueeze(3)
+            pairwise_hard_mask_bias = torch.zeros(
+                (batch_size, 1, seq_len, seq_len),
+                dtype=dtype,
+                device=device,
+            )
+            illegal_pair_mask = valid_query_rows & ~pairwise_hard_mask.unsqueeze(1)
+            pairwise_hard_mask_bias = pairwise_hard_mask_bias.masked_fill(
+                illegal_pair_mask,
+                negative_inf,
+            )
+            attn_bias = attn_bias + pairwise_hard_mask_bias
+
+        # Step 3: Add anchor-relative bias (optional)
         # Projects (batch, seq, num_attrs) → (batch, seq, num_heads)
         # Then reshapes to (batch, num_heads, 1, seq) for key-side bias
         anchor_bias_features = attention_bias_data.get("anchor_bias")
@@ -1012,7 +1063,7 @@
             )  # (batch, num_heads, 1, seq)
             attn_bias = attn_bias + anchor_bias
 
-        # Step 3: Add pairwise bias (optional)
+        # Step 4: Add pairwise bias (optional)
         # Projects (batch, seq, seq, num_attrs) → (batch, seq, seq, num_heads)
         # Then reshapes to (batch, num_heads, seq, seq)
         pairwise_bias_features = attention_bias_data.get("pairwise_bias")
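To make the encoder-side change easier to follow, here is a small standalone sketch of how the four bias components described in _build_attention_bias combine additively before softmax. It is an illustration only, not part of the patch; the tensor values, shapes, and the NEG_INF constant are assumptions chosen for the example.

# Illustration (not part of the patch): combining a padding mask, a boolean hard
# mask, and a soft per-head bias into one additive attention bias, mirroring the
# four-component scheme documented in _build_attention_bias. Shapes and values
# here are assumptions for the example, not GiGL internals.
import torch

batch, heads, seq = 1, 2, 4
NEG_INF = float("-inf")

valid_mask = torch.tensor([[True, True, True, False]])  # (batch, seq), last token padded
hard_mask = torch.tensor(
    [
        [
            [True, True, False, False],
            [True, True, True, False],
            [False, True, True, False],
            [False, False, False, False],
        ]
    ]
)  # (batch, seq, seq): which keys each query may attend to
soft_bias = torch.randn(batch, heads, seq, seq)  # e.g. projected pairwise PE features

attn_bias = torch.zeros(batch, heads, seq, seq)

# 1. Padding mask bias: invalid keys get -inf for every query and head.
attn_bias = attn_bias.masked_fill(~valid_mask[:, None, None, :], NEG_INF)

# 2. Pairwise hard mask: block disallowed pairs, but only on valid query rows so
#    that padded rows do not end up all -inf (softmax over an all -inf row is NaN).
illegal_pairs = valid_mask[:, None, :, None] & ~hard_mask[:, None, :, :]
attn_bias = attn_bias.masked_fill(illegal_pairs, NEG_INF)

# 3./4. Anchor-relative and pairwise soft biases are simply added on top.
attn_bias = attn_bias + soft_bias

attention_weights = torch.softmax(torch.zeros(batch, heads, seq, seq) + attn_bias, dim=-1)
# Blocked and padded keys now receive (numerically) zero attention weight.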
diff --git a/gigl/transforms/graph_transformer.py b/gigl/transforms/graph_transformer.py
index 602f95bde..33e91204e 100644
--- a/gigl/transforms/graph_transformer.py
+++ b/gigl/transforms/graph_transformer.py
@@ -71,10 +71,12 @@ class SequenceAuxiliaryData(TypedDict):
 
     anchor_bias: Optional[Tensor]
     pairwise_bias: Optional[Tensor]
+    pairwise_hard_mask: Optional[Tensor]
     token_input: Optional[TokenInputData]
 
 
 PPR_WEIGHT_FEATURE_NAME = "ppr_weight"
+ONE_HOP_MASKED_ATTENTION_BIAS_ATTR_NAME = "one_hop_masked"
 
 
 def heterodata_to_graph_transformer_input(
@@ -90,6 +92,7 @@
     anchor_based_attention_bias_attr_names: Optional[list[str]] = None,
     anchor_based_input_attr_names: Optional[list[str]] = None,
     pairwise_attention_bias_attr_names: Optional[list[str]] = None,
+    pairwise_hard_attention_bias_attr_names: Optional[list[str]] = None,
 ) -> tuple[Tensor, Tensor, SequenceAuxiliaryData]:
     """
     Transform a HeteroData object to Graph Transformer sequence input.
@@ -131,6 +134,10 @@
         pairwise_attention_bias_attr_names: List of pairwise feature names used as
             attention bias. These must correspond to sparse graph-level attributes
             on ``data``. Example: ['pairwise_distance'].
+        pairwise_hard_attention_bias_attr_names: List of hard pairwise attention
+            constraint names. These are structural controls synthesized from the
+            sampled graph rather than graph-level sparse attributes on ``data``.
+            Supported values: ['one_hop_masked'].
 
     Returns:
         (sequences, valid_mask, attention_bias_data), where:
@@ -143,6 +150,7 @@
             ``"anchor_bias"`` shaped ``(batch, seq, num_anchor_attrs)`` or None
             ``"pairwise_bias"`` shaped ``(batch, seq, seq, num_pairwise_attrs)``
                 or None
+            ``"pairwise_hard_mask"`` shaped ``(batch, seq, seq)`` or None
             ``"token_input"`` as a dict mapping attribute name to a
                 ``(batch, seq, 1)`` tensor, or None
 
@@ -183,6 +191,17 @@
     anchor_bias_attr_names = anchor_based_attention_bias_attr_names or []
     anchor_input_attr_names = anchor_based_input_attr_names or []
     pairwise_bias_attr_names = pairwise_attention_bias_attr_names or []
+    pairwise_hard_bias_attr_names = pairwise_hard_attention_bias_attr_names or []
+    unsupported_pairwise_hard_bias_attr_names = sorted(
+        set(pairwise_hard_bias_attr_names)
+        - {ONE_HOP_MASKED_ATTENTION_BIAS_ATTR_NAME}
+    )
+    if unsupported_pairwise_hard_bias_attr_names:
+        raise ValueError(
+            "Unsupported pairwise hard attention bias attr names "
+            f"{unsupported_pairwise_hard_bias_attr_names}. Supported values: "
+            f"['{ONE_HOP_MASKED_ATTENTION_BIAS_ATTR_NAME}']."
+        )
 
     if PPR_WEIGHT_FEATURE_NAME in pairwise_bias_attr_names:
         raise ValueError(
@@ -312,6 +331,18 @@
         csr_matrices=pairwise_pe_matrices if pairwise_pe_matrices else None,
         device=device,
     )
+    pairwise_hard_mask = None
+    if ONE_HOP_MASKED_ATTENTION_BIAS_ATTR_NAME in pairwise_hard_bias_attr_names:
+        pairwise_hard_mask = _build_one_hop_pairwise_hard_mask(
+            node_index_sequences=node_index_sequences,
+            valid_mask=valid_mask,
+            adjacency_csr=_build_message_passing_adjacency_csr(
+                edge_index=homo_data.edge_index,
+                num_nodes=num_nodes,
+                device=device,
+            ),
+            device=device,
+        )
 
     anchor_bias_features = _compose_anchor_feature_tensor(
         anchor_relative_feature_sequences=anchor_relative_feature_sequences,
@@ -332,6 +363,7 @@
         {
             "anchor_bias": anchor_bias_features,
             "pairwise_bias": pairwise_feature_sequences,
+            "pairwise_hard_mask": pairwise_hard_mask,
             "token_input": token_input_features,
         },
     )
@@ -875,6 +907,69 @@ def _lookup_pairwise_relative_features(
     return features
 
 
+def _build_message_passing_adjacency_csr(
+    edge_index: Tensor,
+    num_nodes: int,
+    device: torch.device,
+) -> Tensor:
+    """Build a CSR adjacency where ``A[i, j] = 1`` means ``j`` may send to ``i``."""
+    self_loops = torch.arange(num_nodes, device=device, dtype=torch.long)
+    reversed_edge_index = torch.stack([edge_index[1], edge_index[0]], dim=0)
+    adjacency_indices = torch.cat(
+        [
+            reversed_edge_index,
+            torch.stack([self_loops, self_loops], dim=0),
+        ],
+        dim=1,
+    )
+    adjacency_values = torch.ones(
+        adjacency_indices.size(1),
+        device=device,
+        dtype=torch.float,
+    )
+    return (
+        torch.sparse_coo_tensor(
+            adjacency_indices,
+            adjacency_values,
+            size=(num_nodes, num_nodes),
+        )
+        .coalesce()
+        .to_sparse_csr()
+    )
+
+
+def _build_one_hop_pairwise_hard_mask(
+    node_index_sequences: Tensor,
+    valid_mask: Tensor,
+    adjacency_csr: Tensor,
+    device: torch.device,
+) -> Tensor:
+    """Return a token-token mask for strict one-hop attention plus self-loops."""
+    batch_size, max_seq_len = node_index_sequences.shape
+    pairwise_hard_mask = torch.zeros(
+        (batch_size, max_seq_len, max_seq_len),
+        dtype=torch.bool,
+        device=device,
+    )
+
+    pair_valid_mask = valid_mask.unsqueeze(2) & valid_mask.unsqueeze(1)
+    if not pair_valid_mask.any():
+        return pairwise_hard_mask
+
+    row_indices = node_index_sequences.unsqueeze(2).expand(-1, -1, max_seq_len)
+    col_indices = node_index_sequences.unsqueeze(1).expand(-1, max_seq_len, -1)
+    valid_row_indices = row_indices[pair_valid_mask]
+    valid_col_indices = col_indices[pair_valid_mask]
+
+    adjacency_values = _lookup_csr_values(
+        csr_matrix=adjacency_csr,
+        row_indices=valid_row_indices,
+        col_indices=valid_col_indices,
+    )
+    pairwise_hard_mask[pair_valid_mask] = adjacency_values > 0
+    return pairwise_hard_mask
+
+
 def _get_k_hop_neighbors_sparse(
     anchor_indices: Tensor,
     edge_index: Tensor,
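Before moving on to the tests, the following standalone sketch re-derives the one-hop mask densely for the 0 <-> 1 <-> 2 chain that the tests below construct, which may help when reading _build_message_passing_adjacency_csr and _build_one_hop_pairwise_hard_mask. The dense tensors and the padded-slot handling are assumptions made for illustration; the patch itself works on CSR matrices.

# Illustration (not part of the patch): dense equivalent of the one-hop mask for
# a 0 <-> 1 <-> 2 chain. adjacency[i, j] = True means token j may send to token i
# (reversed edges plus self-loops), matching the semantics documented above.
# The padded fourth slot and its node index are assumptions for the example.
import torch

edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])  # 0 <-> 1 <-> 2
num_nodes, max_seq_len = 3, 4  # one padded slot at the end of the sequence

adjacency = torch.eye(num_nodes, dtype=torch.bool)  # self-loops
adjacency[edge_index[1], edge_index[0]] = True       # j -> i message-passing direction

node_index_sequence = torch.tensor([0, 1, 2, 0])      # padded slot reuses index 0
valid = torch.tensor([True, True, True, False])

mask = adjacency[node_index_sequence][:, node_index_sequence]  # (seq, seq) token-token
mask &= valid[:, None] & valid[None, :]                          # padded rows/cols stay False

expected = torch.tensor(
    [
        [True, True, False, False],
        [True, True, True, False],
        [False, True, True, False],
        [False, False, False, False],
    ]
)
assert torch.equal(mask, expected)  # same mask the unit test below asserts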
data["user"].x = torch.tensor( + [ + [10.0, 0.0], + [11.0, 0.0], + [12.0, 0.0], + ] + ) + data["user", "connects", "user"].edge_index = torch.tensor( + [ + [0, 1, 1, 2], + [1, 0, 2, 1], + ] + ) + return data + + class TestGetKHopNeighborsSparse(TestCase): """Tests for _get_k_hop_neighbors_sparse helper function.""" @@ -491,6 +510,74 @@ def test_ppr_sequence_can_return_token_input_and_attention_bias_features(self): torch.equal(valid_mask[1], torch.tensor([True, True, True, False])) ) + def test_one_hop_hard_attention_bias_returns_mask_only(self) -> None: + data = create_bidirectional_chain_hetero_data() + + _, valid_mask, sequence_auxiliary_data = heterodata_to_graph_transformer_input( + data=data, + batch_size=1, + max_seq_len=4, + anchor_node_type="user", + hop_distance=2, + pairwise_hard_attention_bias_attr_names=["one_hop_masked"], + ) + + pairwise_hard_mask = sequence_auxiliary_data["pairwise_hard_mask"] + assert pairwise_hard_mask is not None + + self.assertIsNone(sequence_auxiliary_data["pairwise_bias"]) + self.assertEqual(pairwise_hard_mask.shape, (1, 4, 4)) + expected_mask = torch.tensor( + [ + [True, True, False, False], + [True, True, True, False], + [False, True, True, False], + [False, False, False, False], + ] + ) + self.assertTrue(torch.equal(pairwise_hard_mask[0], expected_mask)) + self.assertTrue( + torch.equal(valid_mask[0], torch.tensor([True, True, True, False])) + ) + + def test_one_hop_hard_attention_bias_can_be_combined_with_soft_pairwise_bias( + self, + ) -> None: + data = _create_hetero_data_with_relative_pe() + + _, _, sequence_auxiliary_data = heterodata_to_graph_transformer_input( + data=data, + batch_size=1, + max_seq_len=4, + anchor_node_type="user", + hop_distance=2, + pairwise_attention_bias_attr_names=["pairwise_distance"], + pairwise_hard_attention_bias_attr_names=["one_hop_masked"], + ) + + pairwise_bias = sequence_auxiliary_data["pairwise_bias"] + pairwise_hard_mask = sequence_auxiliary_data["pairwise_hard_mask"] + assert pairwise_bias is not None + assert pairwise_hard_mask is not None + + self.assertEqual(pairwise_bias.shape, (1, 4, 4, 1)) + self.assertEqual(pairwise_hard_mask.shape, (1, 4, 4)) + + def test_invalid_pairwise_hard_attention_bias_attr_name_raises(self) -> None: + data = create_bidirectional_chain_hetero_data() + + with self.assertRaisesRegex( + ValueError, + "Unsupported pairwise hard attention bias attr names", + ): + heterodata_to_graph_transformer_input( + data=data, + batch_size=1, + max_seq_len=4, + anchor_node_type="user", + pairwise_hard_attention_bias_attr_names=["not_supported"], + ) + class TestPyTorchTransformerIntegration(TestCase): """Tests for integration with PyTorch TransformerEncoderLayer."""