Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gigl/common/collections/frozen_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __eq__(self, other: object) -> bool:
for self_key, self_val in self.items():
if self_key not in other:
return False
if self_val != other[self_key]:
if self_val != other[self_key]: # ty: ignore[invalid-argument-type]
return False
return True

Expand Down
9 changes: 5 additions & 4 deletions gigl/common/collections/sorted_dict.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from collections.abc import Iterator, Mapping
from typing import Any, Protocol, TypeVar
from typing import Any, Protocol, Self, TypeVar


class _Comparable(Protocol):
"""Protocol for types that support comparison operations."""

def __lt__(self, other: Any) -> bool:
def __lt__(self, other: Self, /) -> bool:
"""Less than comparison."""
...

Expand Down Expand Up @@ -39,7 +39,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
*args: Positional arguments passed to dict constructor
**kwargs: Keyword arguments passed to dict constructor
"""
self.__dict: dict[KT, VT] = dict(*args, **kwargs)
# ty cannot resolve dict() constructor overloads with generic TypeVars
self.__dict: dict[KT, VT] = dict(*args, **kwargs) # ty: ignore[invalid-assignment]
self.__needs_memoization: bool = True
self.__sort_dict_if_needed()

Expand Down Expand Up @@ -84,7 +85,7 @@ def __eq__(self, other: object) -> bool:
for self_key, self_val in self.items():
if self_key not in other:
return False
if self_val != other[self_key]:
if self_val != other[self_key]: # ty: ignore[invalid-argument-type]
return False
return True

Expand Down
7 changes: 4 additions & 3 deletions gigl/common/services/vertex_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def get_pipeline() -> int: # NOTE: `get_pipeline` here is the Pipeline name
import datetime
import time
from dataclasses import dataclass
from typing import Final, Optional, Union
from typing import Final, Optional, Sequence, Union, cast

from google.cloud import aiplatform
from google.cloud.aiplatform_v1.types import (
Expand Down Expand Up @@ -336,13 +336,14 @@ def launch_graph_store_job(

def _submit_job(
self,
worker_pool_specs: Union[list[WorkerPoolSpec], list[dict]],
worker_pool_specs: Sequence[Union[WorkerPoolSpec, dict]],
job_config: VertexAiJobConfig,
) -> aiplatform.CustomJob:
"""Submit a job to Vertex AI and wait for it to complete."""
job = aiplatform.CustomJob(
display_name=job_config.job_name,
worker_pool_specs=worker_pool_specs,
# CustomJob's annotation requires a homogeneous list, but runtime iterates and accepts mixed entries.
worker_pool_specs=cast(list[WorkerPoolSpec], worker_pool_specs),
project=self._project,
location=self._location,
labels=job_config.labels,
Expand Down
2 changes: 1 addition & 1 deletion gigl/common/utils/compute/serialization/serialize_np.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __decode_nd_array_helper(obj: EncodedNdArray):
def __encode_nd_array_helper(array: np.ndarray) -> EncodedNdArray:
# Using array.data is a slight optimization given that we can use it
serialized_array: bytes = (
array.data if array.flags["C_CONTIGUOUS"] else array.tobytes()
bytes(array.data) if array.flags["C_CONTIGUOUS"] else array.tobytes()
)

if array.dtype == object:
Expand Down
38 changes: 29 additions & 9 deletions gigl/common/utils/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import tempfile
import typing
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from tempfile import _TemporaryFileWrapper as TemporaryFileWrapper # type: ignore
from tempfile import _TemporaryFileWrapper as TemporaryFileWrapper
from typing import IO, AnyStr, Iterable, Optional, Tuple, Union

import google.cloud.exceptions as google_exceptions
Expand Down Expand Up @@ -132,8 +132,14 @@ def upload_files_to_gcs(
parallel (bool): Flag indicating whether to upload files in parallel. Defaults to True.
"""
if parallel:
project = self.__storage_client.project
if project is None:
raise ValueError(
"GCS storage client has no associated project. "
"Set GOOGLE_CLOUD_PROJECT or pass project= to GcsUtils()."
)
_upload_files_to_gcs_parallel(
project=self.__storage_client.project,
project=project,
local_file_path_to_gcs_path_map=local_file_path_to_gcs_path_map,
)
else:
Expand Down Expand Up @@ -206,11 +212,23 @@ def list_uris_with_gcs_path_pattern(
)
blobs = self.__list_file_blobs_at_gcs_path(gcs_path=gcs_path)
if suffix:
blobs = [blob for blob in blobs if blob.name.endswith(suffix)]
blobs = [
blob
for blob in blobs
if blob.name is not None and blob.name.endswith(suffix)
]
if pattern:
matcher = re.compile(pattern)
blobs = [blob for blob in blobs if matcher.match(blob.name)]
gcs_uris = [GcsUri.join("gs://", blob.bucket.name, blob.name) for blob in blobs]
blobs = [
blob
for blob in blobs
if blob.name is not None and matcher.match(blob.name)
]
gcs_uris = [
GcsUri.join("gs://", blob.bucket.name, blob.name)
for blob in blobs
if blob.name is not None
]
return gcs_uris

def __list_file_blobs_at_gcs_path(self, gcs_path: GcsUri) -> list[storage.Blob]:
Expand Down Expand Up @@ -402,10 +420,12 @@ def __batch_copy_blobs(
dst_prefix: str,
src_blobs: list[storage.Blob],
):
dst_blob_names: list[str] = [
src_blob.name.replace(src_prefix, dst_prefix, 1)
for src_blob in src_blobs
]
dst_blob_names: list[str] = []
for src_blob in src_blobs:
assert src_blob.name is not None, (
"Blob from list_blobs must have a name"
)
dst_blob_names.append(src_blob.name.replace(src_prefix, dst_prefix, 1))
with self.__storage_client.batch():
logger.debug(
f"Will copy {len(src_blobs)} files from {src_bucket}://{src_prefix} to {dst_bucket}://{dst_prefix}."
Expand Down
14 changes: 7 additions & 7 deletions gigl/common/utils/jupyter_magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ def visualize_graph(
new_node_index_to_type[unenumerated_node.id] = unenumerated_node.type
g = nx.relabel_nodes(g, mapping)
# Update the node_index_to_type mapping to use global node types
node_index_to_type = new_node_index_to_type # type: ignore
node_index_to_type = new_node_index_to_type

# Add positive edges to the graph if they don't already exist
pos_edge_pairs = set()
Expand Down Expand Up @@ -594,10 +594,10 @@ def visualize_graph(
nx.draw_networkx_nodes(
g,
pos,
node_color=node_colors if node_colors else "lightblue", # type: ignore
edgecolors=node_edge_colors if node_edge_colors else CHARCOAL, # type: ignore
linewidths=node_line_widths if node_line_widths else 1, # type: ignore
node_size=node_sizes if node_sizes else 500, # type: ignore
node_color=node_colors if node_colors else "lightblue",
edgecolors=node_edge_colors if node_edge_colors else CHARCOAL,
linewidths=node_line_widths if node_line_widths else 1,
node_size=node_sizes if node_sizes else 500,
)

# Draw edges - straight for homogeneous, curved for bipartite
Expand All @@ -607,7 +607,7 @@ def visualize_graph(
nx.draw_networkx_edges(
g,
pos,
edge_color=edge_colors, # type: ignore
edge_color=edge_colors,
width=0.75, # 75% of default edge width
alpha=0.9, # Less transparent for cleaner look
)
Expand All @@ -616,7 +616,7 @@ def visualize_graph(
nx.draw_networkx_edges(
g,
pos,
edge_color=edge_colors, # type: ignore
edge_color=edge_colors,
width=0.75, # 75% of default edge width
alpha=0.8, # Slightly transparent for better overlap visibility
connectionstyle="arc3,rad=0.1", # Curved edges to reduce overlap
Expand Down
2 changes: 1 addition & 1 deletion gigl/common/utils/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def fn(*args, **kwargs) -> T:
return timeout_individual_fn_call_decorator(f)(*args, **kwargs)
return f(*args, **kwargs)

acceptable_exceptions: Tuple[Type[Exception], ...] = (
acceptable_exceptions: Tuple[Type[Exception], ...] = ( # ty: ignore[invalid-assignment]
exception_to_check
if isinstance(exception_to_check, tuple)
else (exception_to_check,)
Expand Down
1 change: 1 addition & 0 deletions gigl/common/utils/tensorflow_schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Tuple

import absl
import absl.logging
import tensorflow as tf
from tensorflow_data_validation import load_schema_text
from tensorflow_metadata.proto.v0.schema_pb2 import Schema
Expand Down
5 changes: 2 additions & 3 deletions gigl/distributed/dataset_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,9 +414,8 @@ def build_dataset(
node_world_size = torch.distributed.get_world_size()
node_rank = torch.distributed.get_rank()
master_ip_address = get_internal_ip_from_master_node()
master_dataset_building_ports = tuple(
get_free_ports_from_master_node(num_ports=2)
) # type: ignore[assignment]
ports = get_free_ports_from_master_node(num_ports=2)
master_dataset_building_ports = (ports[0], ports[1])

if should_cleanup_distributed_context and torch.distributed.is_initialized():
logger.info(
Expand Down
8 changes: 5 additions & 3 deletions gigl/distributed/dist_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,9 +482,11 @@ def _initialize_node_ids(
)
else:
train_nodes, val_nodes, test_nodes = splits
self._num_train = train_nodes.numel()
self._num_val = val_nodes.numel()
self._num_test = test_nodes.numel()
self._num_train = (
train_nodes.numel() # ty: ignore[unresolved-attribute]
)
self._num_val = val_nodes.numel() # ty: ignore[unresolved-attribute]
self._num_test = test_nodes.numel() # ty: ignore[unresolved-attribute]
self._node_ids = _append_non_split_node_ids(
train_nodes, val_nodes, test_nodes, node_ids_on_machine
)
Expand Down
8 changes: 4 additions & 4 deletions gigl/distributed/graph_store/dist_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,13 +904,13 @@ def wait_and_shutdown_server() -> None:
def _call_func_on_server(func: Callable[..., R], *args: Any, **kwargs: Any) -> R:
r"""A callee entry for remote requests on the server side."""
if not callable(func):
logging.warning(
f"'_call_func_on_server': receive a non-callable function target {func}"
raise TypeError(
f"'_call_func_on_server': received non-callable function target {func}"
)
return None

server = get_server()
if hasattr(server, func.__name__):
func_name = getattr(func, "__name__", None)
if func_name is not None and hasattr(server, func_name):
# NOTE: method does not respect inheritance.
# `func` is the full name of the function, e.g. gigl.distributed.graph_store.dist_server.DistServer.get_edge_dir
# And so if something subclasses DistServer, the *base* class method will be called, not the subclass method.
Expand Down
2 changes: 1 addition & 1 deletion gigl/env/dep_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def get_current_jar_file(directory: LocalUri) -> LocalUri:
if not list_of_files:
raise FileNotFoundError(f"No .jar file found in: {directory.uri}")
latest_file = max(list_of_files, key=os.path.getctime)
return LocalUri(latest_file)
return LocalUri(str(latest_file))


def get_jar_file_uri(component: GiGLComponents, use_spark35=False) -> LocalUri:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def _iterator_init(self):
return current_worker_file_uris_to_process

def __iter__(self) -> Iterator[Any]:
raise NotImplemented
raise NotImplementedError


class GcsJSONLIterableDataset(GcsIterableDataset):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,11 @@ def _assert_sampling_config_is_valid(self):
@property
def active_sampling_config(self) -> SamplingConfig:
if self.phase == ModelPhase.TRAIN:
return self.training_sampling_config # type: ignore
return self.training_sampling_config
elif self.phase == ModelPhase.VAL:
return self.validation_sampling_config # type: ignore
return self.validation_sampling_config
elif self.phase == ModelPhase.TEST:
return self.testing_sampling_config # type: ignore
return self.testing_sampling_config
elif (
self.phase == ModelPhase.INFERENCE_SRC
or self.phase == ModelPhase.INFERENCE_DST
Expand Down
10 changes: 7 additions & 3 deletions gigl/nn/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,16 +397,20 @@ def _weighted_layer_sum(
Returns:
torch.Tensor: Weighted sum of all layer embeddings, shape [N, D].
"""
if len(all_layer_embeddings) != len(self._layer_weights): # type: ignore # https://github.com/Snapchat/GiGL/issues/408
if len(all_layer_embeddings) != len(
self._layer_weights
): # https://github.com/Snapchat/GiGL/issues/408
raise ValueError(
f"Got {len(all_layer_embeddings)} layer tensors but {len(self._layer_weights)} weights." # type: ignore # https://github.com/Snapchat/GiGL/issues/408
f"Got {len(all_layer_embeddings)} layer tensors but {len(self._layer_weights)} weights." # https://github.com/Snapchat/GiGL/issues/408
)

# Stack all layer embeddings and compute weighted sum
# _layer_weights is already a tensor buffer registered in __init__
stacked = torch.stack(all_layer_embeddings, dim=0) # shape [K+1, N, D]
w = self._layer_weights.to(stacked.device) # shape [K+1], ensure on same device
out = (stacked * w.view(-1, 1, 1)).sum( # type: ignore # https://github.com/Snapchat/GiGL/issues/408
out = (
stacked * w.view(-1, 1, 1)
).sum( # https://github.com/Snapchat/GiGL/issues/408
dim=0
) # shape [N, D], w_0*X_0 + w_1*X_1 + ...

Expand Down
2 changes: 1 addition & 1 deletion gigl/orchestration/local/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def run(
else:
Runner.config_check(start_at, pipeline_config)

component_map: OrderedDict[GiGLComponents, Callable] = OrderedDict(
component_map: OrderedDict[GiGLComponents, Callable] = OrderedDict( # ty: ignore[invalid-assignment]
{
GiGLComponents.ConfigPopulator.value: Runner.run_config_populator,
GiGLComponents.DataPreprocessor.value: Runner.run_data_preprocessor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def model(self) -> torch.nn.Module:
@model.setter
def model(self, model: torch.nn.Module) -> None:
self.__model = model
self.__model.graph_backend = GraphBackend.PYG # type: ignore
self.__model.graph_backend = GraphBackend.PYG

def init_model(
self,
Expand Down Expand Up @@ -517,9 +517,9 @@ def _compute_metrics(
mrr = 1.0 / pos_rank.float()

hit_rates = hit_rate_at_k(
pos_scores=pos_scores, # type: ignore
neg_scores=neg_scores, # type: ignore
ks=torch.tensor(ks, device=device, dtype=torch.long), # type: ignore
pos_scores=pos_scores,
neg_scores=neg_scores,
ks=torch.tensor(ks, device=device, dtype=torch.long),
)

total_mrr += mrr.mean().item()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,10 @@ def train(

main_data_loader = data_loaders.train_main
random_negative_data_loader = data_loaders.train_random_negative
if main_data_loader is None or random_negative_data_loader is None:
raise ValueError(
"train_main and train_random_negative dataloaders must be set for training"
)
val_main_data_loader_iter = iter(data_loaders.val_main) # type: ignore
val_random_data_loader_iter = iter(data_loaders.val_random_negative) # type: ignore

Expand All @@ -375,7 +379,7 @@ def train(
self.model.train()

for batch_index, (main_batch, random_negative_batch) in enumerate(
zip(main_data_loader, random_negative_data_loader), # type: ignore[arg-type]
zip(main_data_loader, random_negative_data_loader),
start=1,
):
batch_st = time()
Expand Down Expand Up @@ -543,7 +547,7 @@ def validate(
hr_result = hit_rate_at_k(
pos_scores=batch_scores.pos_scores,
neg_scores=batch_scores.random_neg_scores,
ks=ks_for_evaluation, # type: ignore
ks=ks_for_evaluation,
)
mrr_result = mean_reciprocal_rank(
pos_scores=batch_scores.pos_scores,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,8 @@ def _train(
out = self.model(inputs)
# Figure out why below is a typing issue
loss = self._train_loss_fn(
input=out[root_node_indices], target=root_node_labels
input=out[root_node_indices], # ty: ignore[unknown-argument]
target=root_node_labels, # ty: ignore[unknown-argument]
) # type: ignore
loss.backward()
self._optimizer.step()
Expand Down Expand Up @@ -200,7 +201,9 @@ def score(
assert root_node_labels is not None

results: InferBatchResults = self.infer_batch(batch=batch, device=device)
num_correct_in_batch = int((results.predictions == root_node_labels).sum()) # type: ignore # https://github.com/Snapchat/GiGL/issues/408
num_correct_in_batch = int(
(results.predictions == root_node_labels).sum()
) # https://github.com/Snapchat/GiGL/issues/408
num_correct += num_correct_in_batch
num_evaluated += len(batch.root_node_labels)

Expand Down
Loading