202 changes: 193 additions & 9 deletions gigl/nn/graph_transformer.py
@@ -239,6 +239,14 @@ class GraphTransformerEncoderLayer(nn.Module):
activation: Activation function for the feed-forward network.
Supported values: "gelu" (default), "relu", "silu", "tanh",
"geglu", "swiglu", "reglu".
relation_attention_mode: Optional relation-aware augmentation strategy
for attention scores. ``"none"`` preserves the default shared
self-attention path. ``"edge_type_additive"`` adds a learned
per-edge-type bilinear term for token pairs backed by sampled
directed graph edges.
num_relations: Number of relation channels expected in
``pairwise_relation_mask`` when
``relation_attention_mode="edge_type_additive"``.

Raises:
ValueError: If model_dim is not divisible by num_heads.
@@ -252,23 +260,43 @@ def __init__(
dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
activation: str = "gelu",
relation_attention_mode: Literal["none", "edge_type_additive"] = "none",
num_relations: int = 0,
) -> None:
super().__init__()
if model_dim % num_heads != 0:
raise ValueError(
f"model_dim ({model_dim}) must be divisible by num_heads ({num_heads})"
)
if relation_attention_mode not in {"none", "edge_type_additive"}:
raise ValueError(
"relation_attention_mode must be one of "
"{'none', 'edge_type_additive'}, "
f"got '{relation_attention_mode}'"
)
if relation_attention_mode == "edge_type_additive" and num_relations <= 0:
raise ValueError(
"relation_attention_mode='edge_type_additive' requires "
"num_relations > 0."
)

self._num_heads = num_heads
self._head_dim = model_dim // num_heads
self._attention_dropout_rate = attention_dropout_rate
self._relation_attention_mode = relation_attention_mode
self._num_relations = num_relations

self._attention_norm = nn.LayerNorm(model_dim)
self._query_projection = nn.Linear(model_dim, model_dim)
self._key_projection = nn.Linear(model_dim, model_dim)
self._value_projection = nn.Linear(model_dim, model_dim)
self._output_projection = nn.Linear(model_dim, model_dim)
self._dropout = nn.Dropout(dropout_rate)
self._relation_attention_matrices: Optional[nn.Parameter] = None
if relation_attention_mode == "edge_type_additive":
self._relation_attention_matrices = nn.Parameter(
torch.empty(num_relations, num_heads, self._head_dim, self._head_dim)
Collaborator

If we look at HGT's equation 3, the W^(ATT)_{\phi(e)} doesn't have a superscript i, so I think the edge-type-specific transformation is per edge type, not per (edge type, head index). The current implementation has more capacity, but we are not doing an apples-to-apples comparison with HGT.

Collaborator

Also, it seems like we are missing the \mu attention multiplier per (source_type, edge_type, destination_type) in equation 3 if we compare with HGT.
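For reference, a sketch of the per-head attention score in HGT's Eq. (3) that the two comments above refer to, written out from memory of the paper (the exact notation here is an assumption, not a quote from this PR or from HGT):

% HGT Eq. (3), per attention head i (sketch; notation reconstructed, may differ slightly)
\mathrm{ATT\text{-}head}^{i}(s, e, t) =
    \Bigl( K^{i}(s)\, W^{\mathrm{ATT}}_{\phi(e)}\, Q^{i}(t)^{\top} \Bigr)
    \cdot \frac{\mu_{\langle \tau(s),\, \phi(e),\, \tau(t) \rangle}}{\sqrt{d}}

Here K^{i}(s) and Q^{i}(t) are the per-head key/query projections of source s and target t, \phi(e) is the edge type, \tau(\cdot) the node type, and d the per-head dimension. Note that W^{\mathrm{ATT}}_{\phi(e)} carries no head index i, and \mu is a learnable scalar prior per (source type, edge type, target type) triple, which is exactly what the comments above point out is missing or shaped differently here.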

)

self._ffn_norm = nn.LayerNorm(model_dim)
self._ffn = FeedForwardNetwork(
@@ -287,13 +315,18 @@ def reset_parameters(self) -> None:
nn.init.xavier_uniform_(projection.weight)
if projection.bias is not None:
nn.init.zeros_(projection.bias)
if self._relation_attention_matrices is not None:
for relation_matrices in self._relation_attention_matrices:
for head_matrix in relation_matrices:
nn.init.xavier_uniform_(head_matrix)
self._ffn_norm.reset_parameters()
self._ffn.reset_parameters()

def forward(
self,
x: Tensor,
attn_bias: Optional[Tensor] = None,
pairwise_relation_mask: Optional[Tensor] = None,
valid_mask: Optional[Tensor] = None,
) -> Tensor:
"""Forward pass.
@@ -303,6 +336,9 @@ def forward(
attn_bias: Optional attention bias of shape
``(batch, num_heads, seq, seq)`` or broadcastable.
Added as an additive mask to attention scores.
pairwise_relation_mask: Optional boolean multi-hot relation mask of shape
``(batch, seq, seq, num_relations)`` that marks which sampled
directed edge types connect each token pair as ``key -> query``.
valid_mask: Optional boolean tensor of shape ``(batch, seq)`` used
to zero out padded token states after each residual block.

@@ -330,14 +366,23 @@ def forward(
batch_size, seq_len, self._num_heads, self._head_dim
).transpose(1, 2)

attention_output = F.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attn_bias,
dropout_p=self._attention_dropout_rate if self.training else 0.0,
is_causal=False,
)
if self._relation_attention_mode == "none":
attention_output = F.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attn_bias,
dropout_p=self._attention_dropout_rate if self.training else 0.0,
is_causal=False,
)
else:
attention_output = self._run_relation_aware_attention(
query=query,
key=key,
value=value,
attn_bias=attn_bias,
pairwise_relation_mask=pairwise_relation_mask,
)

# Reshape back to (batch, seq, model_dim)
attention_output = attention_output.transpose(1, 2).reshape(
@@ -360,6 +405,102 @@ def forward(

return x

def _run_relation_aware_attention(
self,
query: Tensor,
key: Tensor,
value: Tensor,
attn_bias: Optional[Tensor],
pairwise_relation_mask: Optional[Tensor],
) -> Tensor:
relation_attention_bias = self._build_relation_attention_bias(
query,
key,
pairwise_relation_mask=pairwise_relation_mask,
)
if relation_attention_bias is not None:
attn_bias = (
Collaborator

HGT uses K W Q^T, whereas we use K Q^T + K W Q^T. So basically we have reparametrized HGT's formula with W'^{ATT} = I + W^{ATT}. I think we have the same expressiveness, but maybe we should initialize our W^{ATT} differently: zero init or a small-sigma Gaussian init could be an option, since Xavier could make the variance of the bias a bit too big. I'm open to discussion.
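A minimal sketch of the initialization idea in the comment above, assuming it would replace the Xavier loop over self._relation_attention_matrices in reset_parameters (the zero / std=0.02 choices are illustrative, not part of this PR):

if self._relation_attention_matrices is not None:
    # Zero init: the additive relation term starts as a no-op, so the layer
    # initially behaves exactly like the shared self-attention path (W' = I + 0).
    nn.init.zeros_(self._relation_attention_matrices)
    # Alternative (assumed std value): a small-sigma Gaussian keeps the bias
    # variance small without starting exactly at zero.
    # nn.init.normal_(self._relation_attention_matrices, mean=0.0, std=0.02)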

relation_attention_bias
if attn_bias is None
else attn_bias + relation_attention_bias
)

return F.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attn_bias,
dropout_p=self._attention_dropout_rate if self.training else 0.0,
is_causal=False,
)

def _build_relation_attention_bias(
self,
query: Tensor,
key: Tensor,
pairwise_relation_mask: Optional[Tensor],
) -> Optional[Tensor]:
if pairwise_relation_mask is None:
raise ValueError(
"pairwise_relation_mask is required when "
"relation_attention_mode='edge_type_additive'."
)
if pairwise_relation_mask.size(-1) != self._num_relations:
raise ValueError(
"pairwise_relation_mask has unexpected relation dimension "
f"{pairwise_relation_mask.size(-1)}; expected {self._num_relations}."
)
if self._relation_attention_matrices is None:
raise ValueError("Relation attention matrices are not initialized.")
if pairwise_relation_mask.size(1) != query.size(
2
) or pairwise_relation_mask.size(2) != key.size(2):
raise ValueError(
"pairwise_relation_mask must align with the query/key sequence "
"dimensions."
)

relation_mask = pairwise_relation_mask.to(
device=query.device,
dtype=torch.bool,
)
active_relation_positions = relation_mask.nonzero(as_tuple=False)
if active_relation_positions.numel() == 0:
return None

relation_attention_bias = query.new_zeros(
(query.size(0), query.size(2), key.size(2), self._num_heads)
)
query_by_position = query.transpose(1, 2)
key_by_position = key.transpose(1, 2)
relation_matrices = self._relation_attention_matrices.to(dtype=query.dtype)
active_relation_ids = torch.unique(active_relation_positions[:, 3], sorted=True)

for relation_idx_tensor in active_relation_ids:
relation_idx = int(relation_idx_tensor.item())
Collaborator

This .item() call and the active_relation_ids = torch.unique(active_relation_positions[:, 3], sorted=True) line above both cause a GPU sync. Some AI suggestions to avoid it for performance:

# remove the torch.unique line above
for relation_idx in range(self._num_relations):
    # remove the .item() line
    relation_positions = active_relation_positions[
        active_relation_positions[:, 3] == relation_idx
    ]
    ...
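A further sketch in the same spirit, assuming the goal is to also drop the per-relation Python loop entirely: gather the per-pair relation matrices once and score every active (batch, query, key, relation) tuple in one pass. Variable names reuse those in the surrounding method; only the earlier nonzero() still syncs, and the gathered (n, num_heads, head_dim, head_dim) tensor can get large when many pairs are active, so this is a trade-off rather than a drop-in replacement.

batch_indices, query_indices, key_indices, relation_indices = (
    active_relation_positions.unbind(dim=1)
)
selected_query = query_by_position[batch_indices, query_indices]  # (n, heads, d)
selected_key = key_by_position[batch_indices, key_indices]  # (n, heads, d)
selected_matrices = relation_matrices[relation_indices]  # (n, heads, d, d)
# Per-pair bilinear score, matching the original "nhe,hde->nhd" contraction.
transformed_query = torch.einsum("nhe,nhde->nhd", selected_query, selected_matrices)
relation_scores = (selected_key * transformed_query).sum(dim=-1)  # (n, heads)
relation_attention_bias.index_put_(
    (batch_indices, query_indices, key_indices),
    relation_scores / math.sqrt(self._head_dim),
    accumulate=True,
)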

relation_positions = active_relation_positions[
active_relation_positions[:, 3] == relation_idx
]
batch_indices, query_indices, key_indices = relation_positions[
:, :3
].unbind(dim=1)
# Only materialize bilinear scores for token pairs backed by this relation.
selected_query = query_by_position[batch_indices, query_indices]
transformed_query = torch.einsum(
"nhe,hde->nhd",
selected_query,
relation_matrices[relation_idx],
)
selected_key = key_by_position[batch_indices, key_indices]
relation_scores = (selected_key * transformed_query).sum(dim=-1)
relation_attention_bias.index_put_(
(batch_indices, query_indices, key_indices),
relation_scores / math.sqrt(self._head_dim),
accumulate=True,
)

return relation_attention_bias.permute(0, 3, 1, 2)


class GraphTransformerEncoder(nn.Module):
"""Graph Transformer encoder for heterogeneous graphs.
@@ -450,6 +591,10 @@ class GraphTransformerEncoder(nn.Module):
uses 4.0 for standard activations and 8/3 (~2.67) for XGLU variants,
following the convention that XGLU's gating doubles the effective
parameters, so a smaller ratio maintains similar parameter count.
relation_attention_mode: Optional relation-aware augmentation for
attention scores. ``"none"`` preserves the current dense transformer
path. ``"edge_type_additive"`` adds a learned per-edge-type
bilinear score term for sampled directed edges in ``"khop"`` mode.

Notes:
This encoder uses ``nn.LazyLinear`` for node-level PE fusion. If you wrap
@@ -499,6 +644,7 @@ def __init__(
pe_integration_mode: Literal["concat", "add"] = "concat",
activation: str = "gelu",
feedforward_ratio: Optional[float] = None,
relation_attention_mode: Literal["none", "edge_type_additive"] = "none",
**kwargs: object,
) -> None:
super().__init__()
@@ -540,6 +686,20 @@ def __init__(
"sequence_construction_method='ppr' because khop sequences do not "
"enforce a stable token order."
)
if relation_attention_mode not in {"none", "edge_type_additive"}:
raise ValueError(
"relation_attention_mode must be one of "
"{'none', 'edge_type_additive'}, "
f"got '{relation_attention_mode}'"
)
if (
relation_attention_mode == "edge_type_additive"
and sequence_construction_method != "khop"
):
raise ValueError(
"relation_attention_mode='edge_type_additive' requires "
"sequence_construction_method='khop'."
)
anchor_bias_attr_names = anchor_based_attention_bias_attr_names or []
anchor_input_attr_names = anchor_based_input_attr_names or []
pairwise_bias_attr_names = pairwise_attention_bias_attr_names or []
@@ -571,6 +731,12 @@ def __init__(
self._feature_embedding_layer_dict = feature_embedding_layer_dict
self._pe_integration_mode = pe_integration_mode
self._num_heads = num_heads
self._relation_attention_mode = relation_attention_mode
self._relation_attention_edge_types = (
sorted(edge_type_to_feat_dim_map.keys())
if relation_attention_mode == "edge_type_additive"
else []
)
anchor_input_embedding_attr_names = (
set(anchor_based_input_embedding_dict.keys())
if anchor_based_input_embedding_dict is not None
@@ -664,6 +830,8 @@ def __init__(
dropout_rate=dropout_rate,
attention_dropout_rate=attention_dropout_rate,
activation=activation,
relation_attention_mode=relation_attention_mode,
num_relations=len(self._relation_attention_edge_types),
)
for _ in range(num_layers)
]
@@ -801,6 +969,11 @@ def forward(
anchor_based_attention_bias_attr_names=self._anchor_based_attention_bias_attr_names,
anchor_based_input_attr_names=self._anchor_based_input_attr_names,
pairwise_attention_bias_attr_names=self._pairwise_attention_bias_attr_names,
relation_edge_types=(
self._relation_attention_edge_types
if self._relation_attention_mode == "edge_type_additive"
else None
),
)

# Free memory after sequences are built
@@ -837,6 +1010,9 @@ def forward(
sequences=sequences,
valid_mask=valid_mask,
attn_bias=attn_bias,
pairwise_relation_mask=sequence_auxiliary_data.get(
"pairwise_relation_mask"
),
)
embeddings = self._output_projection(embeddings)

@@ -1036,6 +1212,7 @@ def _encode_and_readout(
sequences: Tensor,
valid_mask: Tensor,
attn_bias: Optional[Tensor] = None,
pairwise_relation_mask: Optional[Tensor] = None,
) -> Tensor:
"""Process sequences through transformer layers and attention readout.

@@ -1044,14 +1221,21 @@
valid_mask: Boolean mask of shape ``(batch_size, max_seq_len)``.
attn_bias: Optional additive attention bias broadcastable to
``(batch_size, num_heads, seq, seq)``.
pairwise_relation_mask: Optional boolean relation mask shaped
``(batch_size, seq, seq, num_relations)``.

Returns:
Output embeddings of shape ``(batch_size, hid_dim)``.
"""
x = sequences * valid_mask.unsqueeze(-1).to(sequences.dtype)

for encoder_layer in self._encoder_layers:
x = encoder_layer(x, attn_bias=attn_bias, valid_mask=valid_mask)
x = encoder_layer(
x,
attn_bias=attn_bias,
pairwise_relation_mask=pairwise_relation_mask,
valid_mask=valid_mask,
)

x = self._final_norm(x)
x = x * valid_mask.unsqueeze(-1).to(x.dtype)