61 changes: 56 additions & 5 deletions gigl/nn/graph_transformer.py
@@ -22,6 +22,7 @@

from gigl.src.common.types.graph_data import EdgeType, NodeType
from gigl.transforms.graph_transformer import (
ONE_HOP_MASKED_ATTENTION_BIAS_ATTR_NAME,
PPR_WEIGHT_FEATURE_NAME,
SequenceAuxiliaryData,
TokenInputData,
@@ -432,6 +433,9 @@ class GraphTransformerEncoder(nn.Module):
pairwise_attention_bias_attr_names: List of pairwise feature names used
as additive attention bias. These must correspond to sparse
graph-level attributes on ``data``.
pairwise_hard_attention_bias_attr_names: List of hard pairwise attention
constraint names. These are structural controls synthesized from the
sampled graph. Supported values: ``['one_hop_masked']``.
feature_embedding_layer_dict: Optional ModuleDict mapping node types to
feature embedding layers. If provided, these are applied to node
features before node projection. (default: None)
@@ -495,6 +499,7 @@ def __init__(
anchor_based_input_attr_names: Optional[list[str]] = None,
anchor_based_input_embedding_dict: Optional[nn.ModuleDict] = None,
pairwise_attention_bias_attr_names: Optional[list[str]] = None,
pairwise_hard_attention_bias_attr_names: Optional[list[str]] = None,
feature_embedding_layer_dict: Optional[nn.ModuleDict] = None,
pe_integration_mode: Literal["concat", "add"] = "concat",
activation: str = "gelu",
@@ -543,6 +548,19 @@ def __init__(
anchor_bias_attr_names = anchor_based_attention_bias_attr_names or []
anchor_input_attr_names = anchor_based_input_attr_names or []
pairwise_bias_attr_names = pairwise_attention_bias_attr_names or []
pairwise_hard_bias_attr_names = (
pairwise_hard_attention_bias_attr_names or []
)
unsupported_pairwise_hard_bias_attr_names = sorted(
set(pairwise_hard_bias_attr_names)
- {ONE_HOP_MASKED_ATTENTION_BIAS_ATTR_NAME}
)
if unsupported_pairwise_hard_bias_attr_names:
raise ValueError(
"Unsupported pairwise hard attention bias attr names "
f"{unsupported_pairwise_hard_bias_attr_names}. Supported values: "
f"['{ONE_HOP_MASKED_ATTENTION_BIAS_ATTR_NAME}']."
)
if PPR_WEIGHT_FEATURE_NAME in pairwise_bias_attr_names:
raise ValueError(
f"'{PPR_WEIGHT_FEATURE_NAME}' is an anchor-relative feature and "
self._anchor_based_input_attr_names = anchor_based_input_attr_names
self._anchor_based_input_embedding_dict = anchor_based_input_embedding_dict
self._pairwise_attention_bias_attr_names = pairwise_attention_bias_attr_names
self._pairwise_hard_attention_bias_attr_names = (
pairwise_hard_attention_bias_attr_names
)
self._feature_embedding_layer_dict = feature_embedding_layer_dict
self._pe_integration_mode = pe_integration_mode
self._num_heads = num_heads
@@ -801,6 +822,7 @@ def forward(
anchor_based_attention_bias_attr_names=self._anchor_based_attention_bias_attr_names,
anchor_based_input_attr_names=self._anchor_based_input_attr_names,
pairwise_attention_bias_attr_names=self._pairwise_attention_bias_attr_names,
pairwise_hard_attention_bias_attr_names=self._pairwise_hard_attention_bias_attr_names,
)

# Free memory after sequences are built
@@ -942,6 +964,23 @@ def _build_attention_bias(
"""Build additive attention bias from padding mask and learned relative PE projections.

This function constructs a combined attention bias tensor that is added to
attention scores before softmax. The bias has three components:
attention scores before softmax. The bias has four components:

1. **Padding mask bias**: Sets padded positions to -inf so they receive zero
attention weight after softmax. Shape: (batch, 1, 1, seq) broadcasts to
(batch, num_heads, seq, seq) for key masking.

2. **Anchor-relative bias** (optional): For each sequence position, looks up
2. **Pairwise hard mask** (optional): Masks out structurally disallowed
query-key pairs with a large negative value. Input shape:
``(batch, seq, seq)``. Only applied on valid query rows so padded rows
do not become all ``-inf``.

3. **Anchor-relative bias** (optional): For each sequence position, looks up
the PE value relative to the anchor (e.g., hop distance from anchor).
Input shape: (batch, seq, num_anchor_attrs)
After projection: (batch, num_heads, 1, seq) - same bias for all query positions.

3. **Pairwise bias** (optional): For each (query, key) pair, looks up the PE
4. **Pairwise bias** (optional): For each (query, key) pair, looks up the PE
value between those two nodes (e.g., random walk structural encoding).
Input shape: (batch, seq, seq, num_pairwise_attrs)
After projection: (batch, num_heads, seq, seq) - unique bias per query-key pair.
used only to infer dtype and device.
attention_bias_data: Dictionary containing optional PE tensors:
- "anchor_bias": (batch, seq, num_anchor_attrs) or None
- "pairwise_hard_mask": (batch, seq, seq) or None
- "pairwise_bias": (batch, seq, seq, num_pairwise_attrs) or None

Returns:
#
# Output attn_bias shape: (2, 8, 4, 4)
# - Positions where valid_mask is False get -inf
# - Hard mask blocks structurally illegal pairs
# - Anchor bias adds per-key bias (same for all queries)
# - Pairwise bias adds unique bias for each (query, key) pair
"""
negative_inf,
)

# Step 2: Add anchor-relative bias (optional)
# Step 2: Add pairwise hard mask (optional)
pairwise_hard_mask = attention_bias_data.get("pairwise_hard_mask")
if pairwise_hard_mask is not None:
if pairwise_hard_mask.shape != (batch_size, seq_len, seq_len):
raise ValueError(
"Pairwise hard mask must have shape "
f"({batch_size}, {seq_len}, {seq_len}), "
f"got {tuple(pairwise_hard_mask.shape)}."
)
valid_query_rows = valid_mask.unsqueeze(1).unsqueeze(3)
pairwise_hard_mask_bias = torch.zeros(
(batch_size, 1, seq_len, seq_len),
dtype=dtype,
device=device,
)
illegal_pair_mask = valid_query_rows & ~pairwise_hard_mask.unsqueeze(1)
pairwise_hard_mask_bias = pairwise_hard_mask_bias.masked_fill(
illegal_pair_mask,
negative_inf,
)
attn_bias = attn_bias + pairwise_hard_mask_bias

# Step 3: Add anchor-relative bias (optional)
# Projects (batch, seq, num_attrs) → (batch, seq, num_heads)
# Then reshapes to (batch, num_heads, 1, seq) for key-side bias
anchor_bias_features = attention_bias_data.get("anchor_bias")
) # (batch, num_heads, 1, seq)
attn_bias = attn_bias + anchor_bias

# Step 3: Add pairwise bias (optional)
# Step 4: Add pairwise bias (optional)
# Projects (batch, seq, seq, num_attrs) → (batch, seq, seq, num_heads)
# Then reshapes to (batch, num_heads, seq, seq)
pairwise_bias_features = attention_bias_data.get("pairwise_bias")
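As a worked illustration of the masking arithmetic described in the _build_attention_bias docstring above, here is a minimal, self-contained sketch; the toy shapes, tensor values, and variable names are assumptions made for this example, not values taken from the module.

import torch

# Toy sizes: batch=1, heads=2, seq=4 (assumed for illustration).
valid_mask = torch.tensor([[True, True, True, False]])        # (batch, seq); last token is padding
hard_mask = torch.tensor([[[True, True, False, False],
                           [True, True, True, False],
                           [False, True, True, False],
                           [False, False, False, False]]])    # (batch, seq, seq); True = pair allowed
soft_bias = torch.randn(1, 2, 4, 4)                            # stand-in for projected anchor/pairwise PE bias

neg_inf = float("-inf")
attn_bias = torch.zeros(1, 1, 4, 4)

# Step 1: padding mask. Every key column of an invalid token gets -inf.
attn_bias = attn_bias.masked_fill(~valid_mask[:, None, None, :], neg_inf)

# Step 2: hard mask, applied only on valid query rows so padded rows do not become all -inf.
valid_query_rows = valid_mask[:, None, :, None]                # (batch, 1, seq, 1)
illegal_pairs = valid_query_rows & ~hard_mask[:, None, :, :]   # (batch, 1, seq, seq)
attn_bias = attn_bias.masked_fill(illegal_pairs, neg_inf)

# Steps 3-4: soft biases are added afterwards and broadcast across heads.
attn_bias = attn_bias + soft_bias                               # (batch, heads, seq, seq)

Restricting the hard mask to valid query rows matters because the mask allows no pairs at all for padded query tokens, and a row that is entirely -inf would turn into NaNs after softmax.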
95 changes: 95 additions & 0 deletions gigl/transforms/graph_transformer.py
@@ -71,10 +71,12 @@
class SequenceAuxiliaryData(TypedDict):
anchor_bias: Optional[Tensor]
pairwise_bias: Optional[Tensor]
pairwise_hard_mask: Optional[Tensor]
token_input: Optional[TokenInputData]


PPR_WEIGHT_FEATURE_NAME = "ppr_weight"
ONE_HOP_MASKED_ATTENTION_BIAS_ATTR_NAME = "one_hop_masked"


def heterodata_to_graph_transformer_input(
@@ -90,6 +92,7 @@ def heterodata_to_graph_transformer_input(
anchor_based_attention_bias_attr_names: Optional[list[str]] = None,
anchor_based_input_attr_names: Optional[list[str]] = None,
pairwise_attention_bias_attr_names: Optional[list[str]] = None,
pairwise_hard_attention_bias_attr_names: Optional[list[str]] = None,
) -> tuple[Tensor, Tensor, SequenceAuxiliaryData]:
"""
Transform a HeteroData object to Graph Transformer sequence input.
@@ -131,6 +134,10 @@ def heterodata_to_graph_transformer_input(
pairwise_attention_bias_attr_names: List of pairwise feature names used
as attention bias. These must correspond to sparse graph-level
attributes on ``data``. Example: ['pairwise_distance'].
pairwise_hard_attention_bias_attr_names: List of hard pairwise attention
constraint names. These are structural controls synthesized from the
sampled graph rather than graph-level sparse attributes on ``data``.
Supported values: ['one_hop_masked'].

Returns:
(sequences, valid_mask, attention_bias_data), where:
@@ -143,6 +150,7 @@ def heterodata_to_graph_transformer_input(
``"anchor_bias"`` shaped ``(batch, seq, num_anchor_attrs)`` or None
``"pairwise_bias"`` shaped
``(batch, seq, seq, num_pairwise_attrs)`` or None
``"pairwise_hard_mask"`` shaped ``(batch, seq, seq)`` or None
``"token_input"`` as a dict mapping attribute name to a
``(batch, seq, 1)`` tensor, or None

@@ -183,6 +191,17 @@ def heterodata_to_graph_transformer_input(
anchor_bias_attr_names = anchor_based_attention_bias_attr_names or []
anchor_input_attr_names = anchor_based_input_attr_names or []
pairwise_bias_attr_names = pairwise_attention_bias_attr_names or []
pairwise_hard_bias_attr_names = pairwise_hard_attention_bias_attr_names or []
unsupported_pairwise_hard_bias_attr_names = sorted(
set(pairwise_hard_bias_attr_names)
- {ONE_HOP_MASKED_ATTENTION_BIAS_ATTR_NAME}
)
if unsupported_pairwise_hard_bias_attr_names:
raise ValueError(
"Unsupported pairwise hard attention bias attr names "
f"{unsupported_pairwise_hard_bias_attr_names}. Supported values: "
f"['{ONE_HOP_MASKED_ATTENTION_BIAS_ATTR_NAME}']."
)

if PPR_WEIGHT_FEATURE_NAME in pairwise_bias_attr_names:
raise ValueError(
@@ -312,6 +331,18 @@ def heterodata_to_graph_transformer_input(
csr_matrices=pairwise_pe_matrices if pairwise_pe_matrices else None,
device=device,
)
pairwise_hard_mask = None
if ONE_HOP_MASKED_ATTENTION_BIAS_ATTR_NAME in pairwise_hard_bias_attr_names:
pairwise_hard_mask = _build_one_hop_pairwise_hard_mask(
node_index_sequences=node_index_sequences,
valid_mask=valid_mask,
adjacency_csr=_build_message_passing_adjacency_csr(
edge_index=homo_data.edge_index,
num_nodes=num_nodes,
device=device,
),
device=device,
)

anchor_bias_features = _compose_anchor_feature_tensor(
anchor_relative_feature_sequences=anchor_relative_feature_sequences,
@@ -332,6 +363,7 @@ def heterodata_to_graph_transformer_input(
{
"anchor_bias": anchor_bias_features,
"pairwise_bias": pairwise_feature_sequences,
"pairwise_hard_mask": pairwise_hard_mask,
"token_input": token_input_features,
},
)
@@ -875,6 +907,69 @@ def _lookup_pairwise_relative_features(
return features


def _build_message_passing_adjacency_csr(
edge_index: Tensor,
num_nodes: int,
device: torch.device,
) -> Tensor:
"""Build a CSR adjacency where ``A[i, j] = 1`` means ``j`` may send to ``i``."""
self_loops = torch.arange(num_nodes, device=device, dtype=torch.long)
reversed_edge_index = torch.stack([edge_index[1], edge_index[0]], dim=0)
adjacency_indices = torch.cat(
[
reversed_edge_index,
torch.stack([self_loops, self_loops], dim=0),
],
dim=1,
)
adjacency_values = torch.ones(
adjacency_indices.size(1),
device=device,
dtype=torch.float,
)
return (
torch.sparse_coo_tensor(
adjacency_indices,
adjacency_values,
size=(num_nodes, num_nodes),
)
.coalesce()
.to_sparse_csr()
)


def _build_one_hop_pairwise_hard_mask(
node_index_sequences: Tensor,
valid_mask: Tensor,
adjacency_csr: Tensor,
device: torch.device,
) -> Tensor:
"""Return a token-token mask for strict one-hop attention plus self-loops."""
batch_size, max_seq_len = node_index_sequences.shape
pairwise_hard_mask = torch.zeros(
(batch_size, max_seq_len, max_seq_len),
dtype=torch.bool,
device=device,
)

pair_valid_mask = valid_mask.unsqueeze(2) & valid_mask.unsqueeze(1)
if not pair_valid_mask.any():
return pairwise_hard_mask

row_indices = node_index_sequences.unsqueeze(2).expand(-1, -1, max_seq_len)
col_indices = node_index_sequences.unsqueeze(1).expand(-1, max_seq_len, -1)
valid_row_indices = row_indices[pair_valid_mask]
valid_col_indices = col_indices[pair_valid_mask]

adjacency_values = _lookup_csr_values(
csr_matrix=adjacency_csr,
row_indices=valid_row_indices,
col_indices=valid_col_indices,
)
pairwise_hard_mask[pair_valid_mask] = adjacency_values > 0
return pairwise_hard_mask


def _get_k_hop_neighbors_sparse(
anchor_indices: Tensor,
edge_index: Tensor,
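For intuition about what the CSR-backed helpers above compute, here is a dense reference sketch of the one-hop mask semantics: edges are reversed so that mask[b, i, j] is True when token j may send a message to token i (direct in-neighbours plus self-loops), and pairs that touch padding are always disallowed. The function name, toy graph, and sequences below are assumptions for illustration, not part of the library.

import torch

def dense_one_hop_mask(edge_index, num_nodes, node_index_sequences, valid_mask):
    # Dense adjacency with reversed edges plus self-loops: adj[i, j] = True means j may send to i.
    adj = torch.zeros(num_nodes, num_nodes, dtype=torch.bool)
    adj[edge_index[1], edge_index[0]] = True
    adj[torch.arange(num_nodes), torch.arange(num_nodes)] = True

    # Look up adjacency for every (query token, key token) pair in each sequence.
    rows = node_index_sequences.unsqueeze(2)  # (batch, seq, 1)
    cols = node_index_sequences.unsqueeze(1)  # (batch, 1, seq)
    mask = adj[rows, cols]                    # (batch, seq, seq)

    # Pairs involving padded tokens are never allowed.
    pair_valid = valid_mask.unsqueeze(2) & valid_mask.unsqueeze(1)
    return mask & pair_valid

# Toy graph with edges 0 -> 1 and 1 -> 2; the last sequence slot is padding.
edge_index = torch.tensor([[0, 1], [1, 2]])
sequences = torch.tensor([[0, 1, 2, 0]])
valid = torch.tensor([[True, True, True, False]])
print(dense_one_hop_mask(edge_index, num_nodes=3, node_index_sequences=sequences, valid_mask=valid))

The diff's implementation performs the same lookup against a sparse CSR adjacency via _lookup_csr_values, which avoids ever materializing a dense num_nodes x num_nodes matrix.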
65 changes: 65 additions & 0 deletions tests/unit/nn/graph_transformer_test.py
@@ -361,6 +361,24 @@ def test_forward_accepts_pairwise_attention_bias(self) -> None:
self.assertEqual(embeddings.shape, (3, 6))
self.assertFalse(torch.isnan(embeddings).any())

def test_forward_accepts_pairwise_hard_attention_bias(self) -> None:
data = _create_user_graph_with_pe()

encoder = self._create_encoder(
pairwise_hard_attention_bias_attr_names=["one_hop_masked"],
)
encoder.eval()

with torch.no_grad():
embeddings = encoder(
data=data,
anchor_node_type=self._node_type,
device=self._device,
)

self.assertEqual(embeddings.shape, (3, 6))
self.assertFalse(torch.isnan(embeddings).any())

def test_concat_mode_infers_sequence_width_without_explicit_pe_dim(self) -> None:
data = _create_user_graph_with_pe()

@@ -421,6 +439,53 @@ def test_attention_bias_features_are_projected_per_head(self) -> None:
self.assertEqual(attn_bias[0, 0, 2, 2].item(), 27.0)
self.assertEqual(attn_bias[0, 1, 2, 2].item(), 38.0)

def test_attention_bias_supports_pairwise_hard_mask(self) -> None:
encoder = self._create_encoder(
pairwise_attention_bias_attr_names=["pairwise_distance"],
pairwise_hard_attention_bias_attr_names=["one_hop_masked"],
)

assert encoder._pairwise_pe_attention_bias_projection is not None

with torch.no_grad():
encoder._pairwise_pe_attention_bias_projection.weight.copy_(
torch.tensor([[1.0], [2.0]])
)
attn_bias = encoder._build_attention_bias(
valid_mask=torch.tensor([[True, True, True, False]]),
sequences=torch.zeros((1, 4, 8), dtype=torch.float),
attention_bias_data={
"anchor_bias": None,
"pairwise_bias": torch.tensor(
[
[
[[0.0], [1.0], [2.0], [0.0]],
[[3.0], [4.0], [5.0], [0.0]],
[[6.0], [7.0], [8.0], [0.0]],
[[0.0], [0.0], [0.0], [0.0]],
]
]
),
"pairwise_hard_mask": torch.tensor(
[
[
[True, True, False, False],
[True, True, True, False],
[False, True, True, False],
[False, False, False, False],
]
]
),
"token_input": None,
},
)

self.assertEqual(attn_bias.shape, (1, 2, 4, 4))
self.assertEqual(attn_bias[0, 0, 0, 1].item(), 1.0)
self.assertEqual(attn_bias[0, 1, 1, 2].item(), 10.0)
self.assertLess(attn_bias[0, 0, 0, 2].item(), -1e30)
self.assertEqual(attn_bias[0, 0, 3, 0].item(), 0.0)

def test_attention_bias_supports_anchor_relative_attrs_and_ppr_weights(
self,
) -> None: