GT relation aware attn #599
base: main
|
|
@@ -239,6 +239,14 @@ class GraphTransformerEncoderLayer(nn.Module): | |
| activation: Activation function for the feed-forward network. | ||
| Supported values: "gelu" (default), "relu", "silu", "tanh", | ||
| "geglu", "swiglu", "reglu". | ||
| relation_attention_mode: Optional relation-aware augmentation strategy | ||
| for attention scores. ``"none"`` preserves the default shared | ||
| self-attention path. ``"edge_type_additive"`` adds a learned | ||
| per-edge-type bilinear term for token pairs backed by sampled | ||
| directed graph edges. | ||
| num_relations: Number of relation channels expected in | ||
| ``pairwise_relation_mask`` when | ||
| ``relation_attention_mode="edge_type_additive"``. | ||
|
|
||
| Raises: | ||
| ValueError: If model_dim is not divisible by num_heads. | ||
|
|
@@ -252,23 +260,43 @@ def __init__( | |
| dropout_rate: float = 0.1, | ||
| attention_dropout_rate: float = 0.0, | ||
| activation: str = "gelu", | ||
| relation_attention_mode: Literal["none", "edge_type_additive"] = "none", | ||
| num_relations: int = 0, | ||
| ) -> None: | ||
| super().__init__() | ||
| if model_dim % num_heads != 0: | ||
| raise ValueError( | ||
| f"model_dim ({model_dim}) must be divisible by num_heads ({num_heads})" | ||
| ) | ||
| if relation_attention_mode not in {"none", "edge_type_additive"}: | ||
| raise ValueError( | ||
| "relation_attention_mode must be one of " | ||
| "{'none', 'edge_type_additive'}, " | ||
| f"got '{relation_attention_mode}'" | ||
| ) | ||
| if relation_attention_mode == "edge_type_additive" and num_relations <= 0: | ||
| raise ValueError( | ||
| "relation_attention_mode='edge_type_additive' requires " | ||
| "num_relations > 0." | ||
| ) | ||
|
|
||
| self._num_heads = num_heads | ||
| self._head_dim = model_dim // num_heads | ||
| self._attention_dropout_rate = attention_dropout_rate | ||
| self._relation_attention_mode = relation_attention_mode | ||
| self._num_relations = num_relations | ||
|
|
||
| self._attention_norm = nn.LayerNorm(model_dim) | ||
| self._query_projection = nn.Linear(model_dim, model_dim) | ||
| self._key_projection = nn.Linear(model_dim, model_dim) | ||
| self._value_projection = nn.Linear(model_dim, model_dim) | ||
| self._output_projection = nn.Linear(model_dim, model_dim) | ||
| self._dropout = nn.Dropout(dropout_rate) | ||
| self._relation_attention_matrices: Optional[nn.Parameter] = None | ||
| if relation_attention_mode == "edge_type_additive": | ||
| self._relation_attention_matrices = nn.Parameter( | ||
| torch.empty(num_relations, num_heads, self._head_dim, self._head_dim) | ||
| ) | ||
|
|
||
| self._ffn_norm = nn.LayerNorm(model_dim) | ||
| self._ffn = FeedForwardNetwork( | ||
|
|
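As a quick illustration of the new constructor arguments and validation (a hedged sketch: it assumes `GraphTransformerEncoderLayer` is in scope and that `model_dim`/`num_heads` are its leading keyword arguments, which this hunk only partially shows):

```python
# Assumes GraphTransformerEncoderLayer is importable; its module path is not shown in this diff.
layer = GraphTransformerEncoderLayer(
    model_dim=64,
    num_heads=4,
    relation_attention_mode="edge_type_additive",
    num_relations=3,  # must match pairwise_relation_mask.size(-1) at forward time
)

# num_relations <= 0 with the relation-aware mode is rejected up front:
try:
    GraphTransformerEncoderLayer(
        model_dim=64,
        num_heads=4,
        relation_attention_mode="edge_type_additive",
        num_relations=0,
    )
except ValueError as err:
    print(err)
```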
@@ -287,13 +315,18 @@ def reset_parameters(self) -> None: | |
| nn.init.xavier_uniform_(projection.weight) | ||
| if projection.bias is not None: | ||
| nn.init.zeros_(projection.bias) | ||
| if self._relation_attention_matrices is not None: | ||
| for relation_matrices in self._relation_attention_matrices: | ||
| for head_matrix in relation_matrices: | ||
| nn.init.xavier_uniform_(head_matrix) | ||
| self._ffn_norm.reset_parameters() | ||
| self._ffn.reset_parameters() | ||
|
|
||
| def forward( | ||
| self, | ||
| x: Tensor, | ||
| attn_bias: Optional[Tensor] = None, | ||
| pairwise_relation_mask: Optional[Tensor] = None, | ||
| valid_mask: Optional[Tensor] = None, | ||
| ) -> Tensor: | ||
| """Forward pass. | ||
|
|
@@ -303,6 +336,9 @@ def forward( | |
| attn_bias: Optional attention bias of shape | ||
| ``(batch, num_heads, seq, seq)`` or broadcastable. | ||
| Added as an additive mask to attention scores. | ||
| pairwise_relation_mask: Optional boolean multi-hot relation mask of shape | ||
| ``(batch, seq, seq, num_relations)`` that marks which sampled | ||
| directed edge types connect each token pair as ``key -> query``. | ||
| valid_mask: Optional boolean tensor of shape ``(batch, seq)`` used | ||
| to zero out padded token states after each residual block. | ||
|
|
||
|
|
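For context on the documented shape, here is a minimal sketch of how a `pairwise_relation_mask` could be assembled from sampled directed edges; the edge list below is made up and is not the sequence sampler's actual output format:

```python
import torch

batch_size, seq_len, num_relations = 2, 5, 3

# Made-up sampled edges as (batch, key_token, query_token, relation_id),
# following the documented key -> query direction.
sampled_edges = [
    (0, 1, 0, 2),
    (0, 3, 0, 0),
    (1, 2, 4, 1),
]

pairwise_relation_mask = torch.zeros(
    batch_size, seq_len, seq_len, num_relations, dtype=torch.bool
)
for b, key_idx, query_idx, relation_id in sampled_edges:
    # Dim 1 indexes the query token, dim 2 the key token (see the shape checks
    # in _build_relation_attention_bias below).
    pairwise_relation_mask[b, query_idx, key_idx, relation_id] = True
```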
@@ -330,14 +366,23 @@ def forward( | |
| batch_size, seq_len, self._num_heads, self._head_dim | ||
| ).transpose(1, 2) | ||
|
|
||
| attention_output = F.scaled_dot_product_attention( | ||
| query, | ||
| key, | ||
| value, | ||
| attn_mask=attn_bias, | ||
| dropout_p=self._attention_dropout_rate if self.training else 0.0, | ||
| is_causal=False, | ||
| ) | ||
| if self._relation_attention_mode == "none": | ||
| attention_output = F.scaled_dot_product_attention( | ||
| query, | ||
| key, | ||
| value, | ||
| attn_mask=attn_bias, | ||
| dropout_p=self._attention_dropout_rate if self.training else 0.0, | ||
| is_causal=False, | ||
| ) | ||
| else: | ||
| attention_output = self._run_relation_aware_attention( | ||
| query=query, | ||
| key=key, | ||
| value=value, | ||
| attn_bias=attn_bias, | ||
| pairwise_relation_mask=pairwise_relation_mask, | ||
| ) | ||
|
|
||
| # Reshape back to (batch, seq, model_dim) | ||
| attention_output = attention_output.transpose(1, 2).reshape( | ||
|
|
@@ -360,6 +405,102 @@ def forward( | |
|
|
||
| return x | ||
|
|
||
| def _run_relation_aware_attention( | ||
| self, | ||
| query: Tensor, | ||
| key: Tensor, | ||
| value: Tensor, | ||
| attn_bias: Optional[Tensor], | ||
| pairwise_relation_mask: Optional[Tensor], | ||
| ) -> Tensor: | ||
| relation_attention_bias = self._build_relation_attention_bias( | ||
| query, | ||
| key, | ||
| pairwise_relation_mask=pairwise_relation_mask, | ||
| ) | ||
| if relation_attention_bias is not None: | ||
| attn_bias = ( | ||
|
Collaborator
HGT uses `K W Q^T`, whereas we use `K Q^T + K W Q^T`. So basically we reparametrized HGT's formula with `W'^ATT = I + W^ATT`. I think we have the same expressiveness, but maybe we should initialize our `W^ATT` differently: zero init or a small-sigma Gaussian init could be an option, as Xavier could make the variance a bit too big for a bias term. I'm open to discussion.
||
| relation_attention_bias | ||
| if attn_bias is None | ||
| else attn_bias + relation_attention_bias | ||
| ) | ||
|
|
||
| return F.scaled_dot_product_attention( | ||
| query, | ||
| key, | ||
| value, | ||
| attn_mask=attn_bias, | ||
| dropout_p=self._attention_dropout_rate if self.training else 0.0, | ||
| is_causal=False, | ||
| ) | ||
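On the initialization question raised in the collaborator comment above: because the relation term is added on top of the shared `K Q^T` score (i.e. `W' = I + W`), one option is to start the relation matrices at or near zero so the layer initially reduces to plain self-attention. A purely illustrative sketch of the two alternatives mentioned there:

```python
import torch
import torch.nn as nn

num_relations, num_heads, head_dim = 3, 4, 16  # toy sizes

# Option 1: zero init -- the additive relation bias starts as an exact no-op.
relation_matrices_zero = nn.Parameter(
    torch.zeros(num_relations, num_heads, head_dim, head_dim)
)

# Option 2: small-sigma Gaussian init -- near zero at the start, but still breaks
# symmetry across relations and heads.
relation_matrices_gaussian = nn.Parameter(
    torch.empty(num_relations, num_heads, head_dim, head_dim)
)
nn.init.normal_(relation_matrices_gaussian, mean=0.0, std=1e-2)
```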
|
|
||
| def _build_relation_attention_bias( | ||
| self, | ||
| query: Tensor, | ||
| key: Tensor, | ||
| pairwise_relation_mask: Optional[Tensor], | ||
| ) -> Optional[Tensor]: | ||
| if pairwise_relation_mask is None: | ||
| raise ValueError( | ||
| "pairwise_relation_mask is required when " | ||
| "relation_attention_mode='edge_type_additive'." | ||
| ) | ||
| if pairwise_relation_mask.size(-1) != self._num_relations: | ||
| raise ValueError( | ||
| "pairwise_relation_mask has unexpected relation dimension " | ||
| f"{pairwise_relation_mask.size(-1)}; expected {self._num_relations}." | ||
| ) | ||
| if self._relation_attention_matrices is None: | ||
| raise ValueError("Relation attention matrices are not initialized.") | ||
| if pairwise_relation_mask.size(1) != query.size( | ||
| 2 | ||
| ) or pairwise_relation_mask.size(2) != key.size(2): | ||
| raise ValueError( | ||
| "pairwise_relation_mask must align with the query/key sequence " | ||
| "dimensions." | ||
| ) | ||
|
|
||
| relation_mask = pairwise_relation_mask.to( | ||
| device=query.device, | ||
| dtype=torch.bool, | ||
| ) | ||
| active_relation_positions = relation_mask.nonzero(as_tuple=False) | ||
| if active_relation_positions.numel() == 0: | ||
| return None | ||
|
|
||
| relation_attention_bias = query.new_zeros( | ||
| (query.size(0), query.size(2), key.size(2), self._num_heads) | ||
| ) | ||
| query_by_position = query.transpose(1, 2) | ||
| key_by_position = key.transpose(1, 2) | ||
| relation_matrices = self._relation_attention_matrices.to(dtype=query.dtype) | ||
| active_relation_ids = torch.unique(active_relation_positions[:, 3], sorted=True) | ||
|
|
||
| for relation_idx_tensor in active_relation_ids: | ||
| relation_idx = int(relation_idx_tensor.item()) | ||
|
Collaborator
this
||
| relation_positions = active_relation_positions[ | ||
| active_relation_positions[:, 3] == relation_idx | ||
| ] | ||
| batch_indices, query_indices, key_indices = relation_positions[ | ||
| :, :3 | ||
| ].unbind(dim=1) | ||
| # Only materialize bilinear scores for token pairs backed by this relation. | ||
| selected_query = query_by_position[batch_indices, query_indices] | ||
| transformed_query = torch.einsum( | ||
| "nhe,hde->nhd", | ||
| selected_query, | ||
| relation_matrices[relation_idx], | ||
| ) | ||
| selected_key = key_by_position[batch_indices, key_indices] | ||
| relation_scores = (selected_key * transformed_query).sum(dim=-1) | ||
| relation_attention_bias.index_put_( | ||
| (batch_indices, query_indices, key_indices), | ||
| relation_scores / math.sqrt(self._head_dim), | ||
| accumulate=True, | ||
| ) | ||
|
|
||
| return relation_attention_bias.permute(0, 3, 1, 2) | ||
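To make the `einsum`/`index_put_` bookkeeping above concrete, a toy, self-contained walkthrough with the same shapes (the numbers are arbitrary):

```python
import math
import torch

batch, heads, seq, head_dim, num_relations = 1, 2, 4, 8, 2

query = torch.randn(batch, heads, seq, head_dim)
key = torch.randn(batch, heads, seq, head_dim)
relation_matrices = torch.randn(num_relations, heads, head_dim, head_dim)

# One sampled edge of relation 1 connecting key token 2 -> query token 0.
mask = torch.zeros(batch, seq, seq, num_relations, dtype=torch.bool)
mask[0, 0, 2, 1] = True

bias = torch.zeros(batch, seq, seq, heads)
q_by_pos = query.transpose(1, 2)  # (batch, seq, heads, head_dim)
k_by_pos = key.transpose(1, 2)

positions = mask.nonzero(as_tuple=False)  # rows of (batch, query, key, relation)
for relation_id in positions[:, 3].unique():
    rows = positions[positions[:, 3] == relation_id]
    b_idx, q_idx, k_idx = rows[:, 0], rows[:, 1], rows[:, 2]
    selected_query = q_by_pos[b_idx, q_idx]  # (n_active, heads, head_dim)
    # Per-relation, per-head bilinear transform of the selected queries.
    transformed_query = torch.einsum(
        "nhe,hde->nhd", selected_query, relation_matrices[relation_id]
    )
    scores = (k_by_pos[b_idx, k_idx] * transformed_query).sum(dim=-1)
    bias.index_put_((b_idx, q_idx, k_idx), scores / math.sqrt(head_dim), accumulate=True)

relation_attention_bias = bias.permute(0, 3, 1, 2)  # (batch, heads, seq, seq)
print(relation_attention_bias.shape)  # torch.Size([1, 2, 4, 4])
```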
|
|
||
|
|
||
| class GraphTransformerEncoder(nn.Module): | ||
| """Graph Transformer encoder for heterogeneous graphs. | ||
|
|
@@ -450,6 +591,10 @@ class GraphTransformerEncoder(nn.Module): | |
| uses 4.0 for standard activations and 8/3 (~2.67) for XGLU variants, | ||
| following the convention that XGLU's gating doubles the effective | ||
| parameters, so a smaller ratio maintains similar parameter count. | ||
| relation_attention_mode: Optional relation-aware augmentation for | ||
| attention scores. ``"none"`` preserves the current dense transformer | ||
| path. ``"edge_type_additive"`` adds a learned per-edge-type | ||
| bilinear score term for sampled directed edges in ``"khop"`` mode. | ||
|
|
||
| Notes: | ||
| This encoder uses ``nn.LazyLinear`` for node-level PE fusion. If you wrap | ||
|
|
@@ -499,6 +644,7 @@ def __init__( | |
| pe_integration_mode: Literal["concat", "add"] = "concat", | ||
| activation: str = "gelu", | ||
| feedforward_ratio: Optional[float] = None, | ||
| relation_attention_mode: Literal["none", "edge_type_additive"] = "none", | ||
| **kwargs: object, | ||
| ) -> None: | ||
| super().__init__() | ||
|
|
@@ -540,6 +686,20 @@ def __init__( | |
| "sequence_construction_method='ppr' because khop sequences do not " | ||
| "enforce a stable token order." | ||
| ) | ||
| if relation_attention_mode not in {"none", "edge_type_additive"}: | ||
| raise ValueError( | ||
| "relation_attention_mode must be one of " | ||
| "{'none', 'edge_type_additive'}, " | ||
| f"got '{relation_attention_mode}'" | ||
| ) | ||
| if ( | ||
| relation_attention_mode == "edge_type_additive" | ||
| and sequence_construction_method != "khop" | ||
| ): | ||
| raise ValueError( | ||
| "relation_attention_mode='edge_type_additive' requires " | ||
| "sequence_construction_method='khop'." | ||
| ) | ||
| anchor_bias_attr_names = anchor_based_attention_bias_attr_names or [] | ||
| anchor_input_attr_names = anchor_based_input_attr_names or [] | ||
| pairwise_bias_attr_names = pairwise_attention_bias_attr_names or [] | ||
|
|
@@ -571,6 +731,12 @@ def __init__( | |
| self._feature_embedding_layer_dict = feature_embedding_layer_dict | ||
| self._pe_integration_mode = pe_integration_mode | ||
| self._num_heads = num_heads | ||
| self._relation_attention_mode = relation_attention_mode | ||
| self._relation_attention_edge_types = ( | ||
| sorted(edge_type_to_feat_dim_map.keys()) | ||
| if relation_attention_mode == "edge_type_additive" | ||
| else [] | ||
| ) | ||
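A small sketch of how the sorted ordering above maps edge types to relation channels (the keys here are placeholders; the real ones come from `edge_type_to_feat_dim_map`):

```python
# Placeholder edge types; the real keys come from the graph schema.
edge_type_to_feat_dim_map = {
    "user__follows__user": 16,
    "user__buys__item": 8,
    "item__belongs_to__category": 4,
}

relation_attention_edge_types = sorted(edge_type_to_feat_dim_map.keys())
# Each edge type's position in this sorted list is the relation channel it occupies
# in pairwise_relation_mask[..., r] and in the per-layer num_relations count.
relation_index = {etype: r for r, etype in enumerate(relation_attention_edge_types)}
# {'item__belongs_to__category': 0, 'user__buys__item': 1, 'user__follows__user': 2}
```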
| anchor_input_embedding_attr_names = ( | ||
| set(anchor_based_input_embedding_dict.keys()) | ||
| if anchor_based_input_embedding_dict is not None | ||
|
|
@@ -664,6 +830,8 @@ def __init__( | |
| dropout_rate=dropout_rate, | ||
| attention_dropout_rate=attention_dropout_rate, | ||
| activation=activation, | ||
| relation_attention_mode=relation_attention_mode, | ||
| num_relations=len(self._relation_attention_edge_types), | ||
| ) | ||
| for _ in range(num_layers) | ||
| ] | ||
|
|
@@ -801,6 +969,11 @@ def forward( | |
| anchor_based_attention_bias_attr_names=self._anchor_based_attention_bias_attr_names, | ||
| anchor_based_input_attr_names=self._anchor_based_input_attr_names, | ||
| pairwise_attention_bias_attr_names=self._pairwise_attention_bias_attr_names, | ||
| relation_edge_types=( | ||
| self._relation_attention_edge_types | ||
| if self._relation_attention_mode == "edge_type_additive" | ||
| else None | ||
| ), | ||
| ) | ||
|
|
||
| # Free memory after sequences are built | ||
|
|
@@ -837,6 +1010,9 @@ def forward( | |
| sequences=sequences, | ||
| valid_mask=valid_mask, | ||
| attn_bias=attn_bias, | ||
| pairwise_relation_mask=sequence_auxiliary_data.get( | ||
| "pairwise_relation_mask" | ||
| ), | ||
| ) | ||
| embeddings = self._output_projection(embeddings) | ||
|
|
||
|
|
@@ -1036,6 +1212,7 @@ def _encode_and_readout( | |
| sequences: Tensor, | ||
| valid_mask: Tensor, | ||
| attn_bias: Optional[Tensor] = None, | ||
| pairwise_relation_mask: Optional[Tensor] = None, | ||
| ) -> Tensor: | ||
| """Process sequences through transformer layers and attention readout. | ||
|
|
||
|
|
@@ -1044,14 +1221,21 @@ def _encode_and_readout( | |
| valid_mask: Boolean mask of shape ``(batch_size, max_seq_len)``. | ||
| attn_bias: Optional additive attention bias broadcastable to | ||
| ``(batch_size, num_heads, seq, seq)``. | ||
| pairwise_relation_mask: Optional boolean relation mask shaped | ||
| ``(batch_size, seq, seq, num_relations)``. | ||
|
|
||
| Returns: | ||
| Output embeddings of shape ``(batch_size, hid_dim)``. | ||
| """ | ||
| x = sequences * valid_mask.unsqueeze(-1).to(sequences.dtype) | ||
|
|
||
| for encoder_layer in self._encoder_layers: | ||
| x = encoder_layer(x, attn_bias=attn_bias, valid_mask=valid_mask) | ||
| x = encoder_layer( | ||
| x, | ||
| attn_bias=attn_bias, | ||
| pairwise_relation_mask=pairwise_relation_mask, | ||
| valid_mask=valid_mask, | ||
| ) | ||
|
|
||
| x = self._final_norm(x) | ||
| x = x * valid_mask.unsqueeze(-1).to(x.dtype) | ||
|
|
||
If we look at HGT's equation 3, the `W^(ATT)_{\phi(e)}` doesn't have a superscript `i`, so I think the edge-type-specific transformation is per edge type, not per (edge type, head index). The current implementation has more capacity, but we are not doing an apples-to-apples comparison with HGT.

Also, it seems like we are missing the `\mu` attention multiplier per (source_type, edge_type, destination_type) from equation 3 if we compare with HGT.
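For readers following this thread, a paraphrase of the HGT attention-head term the two comments above compare against (written from memory of equation 3 in the HGT paper, so please verify against the original):

$$
\text{ATT-head}^{i}(s, e, t) = \left( K^{i}(s)\, W^{\mathrm{ATT}}_{\phi(e)}\, Q^{i}(t)^{\top} \right) \cdot \frac{\mu_{\langle \tau(s),\, \phi(e),\, \tau(t) \rangle}}{\sqrt{d}}
$$

Here $W^{\mathrm{ATT}}_{\phi(e)}$ carries no head index $i$ (one matrix per edge type, shared across heads), $\mu$ is a learned scalar per (source type, edge type, destination type) meta relation, and $d$ is the per-head dimension; those are exactly the two gaps the comments point out relative to this PR's formulation.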