diff --git a/gigl/common/collections/frozen_dict.py b/gigl/common/collections/frozen_dict.py index 84939b0a6..f3e1314dd 100644 --- a/gigl/common/collections/frozen_dict.py +++ b/gigl/common/collections/frozen_dict.py @@ -45,7 +45,7 @@ def __eq__(self, other: object) -> bool: for self_key, self_val in self.items(): if self_key not in other: return False - if self_val != other[self_key]: + if self_val != other[self_key]: # ty: ignore[invalid-argument-type] return False return True diff --git a/gigl/common/collections/sorted_dict.py b/gigl/common/collections/sorted_dict.py index bf4e3f14c..1627f79b9 100644 --- a/gigl/common/collections/sorted_dict.py +++ b/gigl/common/collections/sorted_dict.py @@ -1,11 +1,11 @@ from collections.abc import Iterator, Mapping -from typing import Any, Protocol, TypeVar +from typing import Any, Protocol, Self, TypeVar class _Comparable(Protocol): """Protocol for types that support comparison operations.""" - def __lt__(self, other: Any) -> bool: + def __lt__(self, other: Self, /) -> bool: """Less than comparison.""" ... @@ -39,7 +39,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: *args: Positional arguments passed to dict constructor **kwargs: Keyword arguments passed to dict constructor """ - self.__dict: dict[KT, VT] = dict(*args, **kwargs) + # ty cannot resolve dict() constructor overloads with generic TypeVars + self.__dict: dict[KT, VT] = dict(*args, **kwargs) # ty: ignore[invalid-assignment] self.__needs_memoization: bool = True self.__sort_dict_if_needed() @@ -84,7 +85,7 @@ def __eq__(self, other: object) -> bool: for self_key, self_val in self.items(): if self_key not in other: return False - if self_val != other[self_key]: + if self_val != other[self_key]: # ty: ignore[invalid-argument-type] return False return True diff --git a/gigl/common/services/vertex_ai.py b/gigl/common/services/vertex_ai.py index 6cbb968b3..5f48e8e90 100644 --- a/gigl/common/services/vertex_ai.py +++ b/gigl/common/services/vertex_ai.py @@ -62,7 +62,7 @@ def get_pipeline() -> int: # NOTE: `get_pipeline` here is the Pipeline name import datetime import time from dataclasses import dataclass -from typing import Final, Optional, Union +from typing import Final, Optional, Sequence, Union, cast from google.cloud import aiplatform from google.cloud.aiplatform_v1.types import ( @@ -336,13 +336,14 @@ def launch_graph_store_job( def _submit_job( self, - worker_pool_specs: Union[list[WorkerPoolSpec], list[dict]], + worker_pool_specs: Sequence[Union[WorkerPoolSpec, dict]], job_config: VertexAiJobConfig, ) -> aiplatform.CustomJob: """Submit a job to Vertex AI and wait for it to complete.""" job = aiplatform.CustomJob( display_name=job_config.job_name, - worker_pool_specs=worker_pool_specs, + # CustomJob's annotation requires a homogeneous list, but runtime iterates and accepts mixed entries. + worker_pool_specs=cast(list[WorkerPoolSpec], worker_pool_specs), project=self._project, location=self._location, labels=job_config.labels, diff --git a/gigl/common/utils/compute/serialization/serialize_np.py b/gigl/common/utils/compute/serialization/serialize_np.py index bf85dde70..d9b50a2ca 100644 --- a/gigl/common/utils/compute/serialization/serialize_np.py +++ b/gigl/common/utils/compute/serialization/serialize_np.py @@ -31,7 +31,7 @@ def __decode_nd_array_helper(obj: EncodedNdArray): def __encode_nd_array_helper(array: np.ndarray) -> EncodedNdArray: # Using array.data is a slight optimization given that we can use it serialized_array: bytes = ( - array.data if array.flags["C_CONTIGUOUS"] else array.tobytes() + bytes(array.data) if array.flags["C_CONTIGUOUS"] else array.tobytes() ) if array.dtype == object: diff --git a/gigl/common/utils/gcs.py b/gigl/common/utils/gcs.py index fa4de4b62..3d516f1b8 100644 --- a/gigl/common/utils/gcs.py +++ b/gigl/common/utils/gcs.py @@ -3,7 +3,7 @@ import tempfile import typing from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor -from tempfile import _TemporaryFileWrapper as TemporaryFileWrapper # type: ignore +from tempfile import _TemporaryFileWrapper as TemporaryFileWrapper from typing import IO, AnyStr, Iterable, Optional, Tuple, Union import google.cloud.exceptions as google_exceptions @@ -132,8 +132,14 @@ def upload_files_to_gcs( parallel (bool): Flag indicating whether to upload files in parallel. Defaults to True. """ if parallel: + project = self.__storage_client.project + if project is None: + raise ValueError( + "GCS storage client has no associated project. " + "Set GOOGLE_CLOUD_PROJECT or pass project= to GcsUtils()." + ) _upload_files_to_gcs_parallel( - project=self.__storage_client.project, + project=project, local_file_path_to_gcs_path_map=local_file_path_to_gcs_path_map, ) else: @@ -206,11 +212,23 @@ def list_uris_with_gcs_path_pattern( ) blobs = self.__list_file_blobs_at_gcs_path(gcs_path=gcs_path) if suffix: - blobs = [blob for blob in blobs if blob.name.endswith(suffix)] + blobs = [ + blob + for blob in blobs + if blob.name is not None and blob.name.endswith(suffix) + ] if pattern: matcher = re.compile(pattern) - blobs = [blob for blob in blobs if matcher.match(blob.name)] - gcs_uris = [GcsUri.join("gs://", blob.bucket.name, blob.name) for blob in blobs] + blobs = [ + blob + for blob in blobs + if blob.name is not None and matcher.match(blob.name) + ] + gcs_uris = [ + GcsUri.join("gs://", blob.bucket.name, blob.name) + for blob in blobs + if blob.name is not None + ] return gcs_uris def __list_file_blobs_at_gcs_path(self, gcs_path: GcsUri) -> list[storage.Blob]: @@ -402,10 +420,12 @@ def __batch_copy_blobs( dst_prefix: str, src_blobs: list[storage.Blob], ): - dst_blob_names: list[str] = [ - src_blob.name.replace(src_prefix, dst_prefix, 1) - for src_blob in src_blobs - ] + dst_blob_names: list[str] = [] + for src_blob in src_blobs: + assert src_blob.name is not None, ( + "Blob from list_blobs must have a name" + ) + dst_blob_names.append(src_blob.name.replace(src_prefix, dst_prefix, 1)) with self.__storage_client.batch(): logger.debug( f"Will copy {len(src_blobs)} files from {src_bucket}://{src_prefix} to {dst_bucket}://{dst_prefix}." diff --git a/gigl/common/utils/jupyter_magics.py b/gigl/common/utils/jupyter_magics.py index adf4d9fd1..a010ca437 100644 --- a/gigl/common/utils/jupyter_magics.py +++ b/gigl/common/utils/jupyter_magics.py @@ -465,7 +465,7 @@ def visualize_graph( new_node_index_to_type[unenumerated_node.id] = unenumerated_node.type g = nx.relabel_nodes(g, mapping) # Update the node_index_to_type mapping to use global node types - node_index_to_type = new_node_index_to_type # type: ignore + node_index_to_type = new_node_index_to_type # Add positive edges to the graph if they don't already exist pos_edge_pairs = set() @@ -594,10 +594,10 @@ def visualize_graph( nx.draw_networkx_nodes( g, pos, - node_color=node_colors if node_colors else "lightblue", # type: ignore - edgecolors=node_edge_colors if node_edge_colors else CHARCOAL, # type: ignore - linewidths=node_line_widths if node_line_widths else 1, # type: ignore - node_size=node_sizes if node_sizes else 500, # type: ignore + node_color=node_colors if node_colors else "lightblue", + edgecolors=node_edge_colors if node_edge_colors else CHARCOAL, + linewidths=node_line_widths if node_line_widths else 1, + node_size=node_sizes if node_sizes else 500, ) # Draw edges - straight for homogeneous, curved for bipartite @@ -607,7 +607,7 @@ def visualize_graph( nx.draw_networkx_edges( g, pos, - edge_color=edge_colors, # type: ignore + edge_color=edge_colors, width=0.75, # 75% of default edge width alpha=0.9, # Less transparent for cleaner look ) @@ -616,7 +616,7 @@ def visualize_graph( nx.draw_networkx_edges( g, pos, - edge_color=edge_colors, # type: ignore + edge_color=edge_colors, width=0.75, # 75% of default edge width alpha=0.8, # Slightly transparent for better overlap visibility connectionstyle="arc3,rad=0.1", # Curved edges to reduce overlap diff --git a/gigl/common/utils/retry.py b/gigl/common/utils/retry.py index c71405258..6be63cb4e 100644 --- a/gigl/common/utils/retry.py +++ b/gigl/common/utils/retry.py @@ -65,7 +65,7 @@ def fn(*args, **kwargs) -> T: return timeout_individual_fn_call_decorator(f)(*args, **kwargs) return f(*args, **kwargs) - acceptable_exceptions: Tuple[Type[Exception], ...] = ( + acceptable_exceptions: Tuple[Type[Exception], ...] = ( # ty: ignore[invalid-assignment] exception_to_check if isinstance(exception_to_check, tuple) else (exception_to_check,) diff --git a/gigl/common/utils/tensorflow_schema.py b/gigl/common/utils/tensorflow_schema.py index 34752a17f..6422bc030 100644 --- a/gigl/common/utils/tensorflow_schema.py +++ b/gigl/common/utils/tensorflow_schema.py @@ -1,6 +1,7 @@ from typing import Tuple import absl +import absl.logging import tensorflow as tf from tensorflow_data_validation import load_schema_text from tensorflow_metadata.proto.v0.schema_pb2 import Schema diff --git a/gigl/distributed/dataset_factory.py b/gigl/distributed/dataset_factory.py index a990e491f..230aed66d 100644 --- a/gigl/distributed/dataset_factory.py +++ b/gigl/distributed/dataset_factory.py @@ -414,9 +414,8 @@ def build_dataset( node_world_size = torch.distributed.get_world_size() node_rank = torch.distributed.get_rank() master_ip_address = get_internal_ip_from_master_node() - master_dataset_building_ports = tuple( - get_free_ports_from_master_node(num_ports=2) - ) # type: ignore[assignment] + ports = get_free_ports_from_master_node(num_ports=2) + master_dataset_building_ports = (ports[0], ports[1]) if should_cleanup_distributed_context and torch.distributed.is_initialized(): logger.info( diff --git a/gigl/distributed/dist_dataset.py b/gigl/distributed/dist_dataset.py index d37bbc925..2323c3fe1 100644 --- a/gigl/distributed/dist_dataset.py +++ b/gigl/distributed/dist_dataset.py @@ -482,9 +482,11 @@ def _initialize_node_ids( ) else: train_nodes, val_nodes, test_nodes = splits - self._num_train = train_nodes.numel() - self._num_val = val_nodes.numel() - self._num_test = test_nodes.numel() + self._num_train = ( + train_nodes.numel() # ty: ignore[unresolved-attribute] + ) + self._num_val = val_nodes.numel() # ty: ignore[unresolved-attribute] + self._num_test = test_nodes.numel() # ty: ignore[unresolved-attribute] self._node_ids = _append_non_split_node_ids( train_nodes, val_nodes, test_nodes, node_ids_on_machine ) diff --git a/gigl/distributed/graph_store/dist_server.py b/gigl/distributed/graph_store/dist_server.py index 288929fdf..4566e6021 100644 --- a/gigl/distributed/graph_store/dist_server.py +++ b/gigl/distributed/graph_store/dist_server.py @@ -904,13 +904,13 @@ def wait_and_shutdown_server() -> None: def _call_func_on_server(func: Callable[..., R], *args: Any, **kwargs: Any) -> R: r"""A callee entry for remote requests on the server side.""" if not callable(func): - logging.warning( - f"'_call_func_on_server': receive a non-callable function target {func}" + raise TypeError( + f"'_call_func_on_server': received non-callable function target {func}" ) - return None server = get_server() - if hasattr(server, func.__name__): + func_name = getattr(func, "__name__", None) + if func_name is not None and hasattr(server, func_name): # NOTE: method does not respect inheritance. # `func` is the full name of the function, e.g. gigl.distributed.graph_store.dist_server.DistServer.get_edge_dir # And so if something subclasses DistServer, the *base* class method will be called, not the subclass method. diff --git a/gigl/env/dep_constants.py b/gigl/env/dep_constants.py index a3353b0af..b31328166 100644 --- a/gigl/env/dep_constants.py +++ b/gigl/env/dep_constants.py @@ -41,7 +41,7 @@ def get_current_jar_file(directory: LocalUri) -> LocalUri: if not list_of_files: raise FileNotFoundError(f"No .jar file found in: {directory.uri}") latest_file = max(list_of_files, key=os.path.getctime) - return LocalUri(latest_file) + return LocalUri(str(latest_file)) def get_jar_file_uri(component: GiGLComponents, use_spark35=False) -> LocalUri: diff --git a/gigl/experimental/knowledge_graph_embedding/common/graph_dataset.py b/gigl/experimental/knowledge_graph_embedding/common/graph_dataset.py index 3ae9849cf..71666a1a9 100644 --- a/gigl/experimental/knowledge_graph_embedding/common/graph_dataset.py +++ b/gigl/experimental/knowledge_graph_embedding/common/graph_dataset.py @@ -64,7 +64,7 @@ def _iterator_init(self): return current_worker_file_uris_to_process def __iter__(self) -> Iterator[Any]: - raise NotImplemented + raise NotImplementedError class GcsJSONLIterableDataset(GcsIterableDataset): diff --git a/gigl/experimental/knowledge_graph_embedding/lib/model/heterogeneous_graph_model.py b/gigl/experimental/knowledge_graph_embedding/lib/model/heterogeneous_graph_model.py index 78375618c..bdfd9f179 100644 --- a/gigl/experimental/knowledge_graph_embedding/lib/model/heterogeneous_graph_model.py +++ b/gigl/experimental/knowledge_graph_embedding/lib/model/heterogeneous_graph_model.py @@ -95,11 +95,11 @@ def _assert_sampling_config_is_valid(self): @property def active_sampling_config(self) -> SamplingConfig: if self.phase == ModelPhase.TRAIN: - return self.training_sampling_config # type: ignore + return self.training_sampling_config elif self.phase == ModelPhase.VAL: - return self.validation_sampling_config # type: ignore + return self.validation_sampling_config elif self.phase == ModelPhase.TEST: - return self.testing_sampling_config # type: ignore + return self.testing_sampling_config elif ( self.phase == ModelPhase.INFERENCE_SRC or self.phase == ModelPhase.INFERENCE_DST diff --git a/gigl/nn/models.py b/gigl/nn/models.py index 1c422ff09..d86920bdc 100644 --- a/gigl/nn/models.py +++ b/gigl/nn/models.py @@ -397,16 +397,20 @@ def _weighted_layer_sum( Returns: torch.Tensor: Weighted sum of all layer embeddings, shape [N, D]. """ - if len(all_layer_embeddings) != len(self._layer_weights): # type: ignore # https://github.com/Snapchat/GiGL/issues/408 + if len(all_layer_embeddings) != len( + self._layer_weights + ): # https://github.com/Snapchat/GiGL/issues/408 raise ValueError( - f"Got {len(all_layer_embeddings)} layer tensors but {len(self._layer_weights)} weights." # type: ignore # https://github.com/Snapchat/GiGL/issues/408 + f"Got {len(all_layer_embeddings)} layer tensors but {len(self._layer_weights)} weights." # https://github.com/Snapchat/GiGL/issues/408 ) # Stack all layer embeddings and compute weighted sum # _layer_weights is already a tensor buffer registered in __init__ stacked = torch.stack(all_layer_embeddings, dim=0) # shape [K+1, N, D] w = self._layer_weights.to(stacked.device) # shape [K+1], ensure on same device - out = (stacked * w.view(-1, 1, 1)).sum( # type: ignore # https://github.com/Snapchat/GiGL/issues/408 + out = ( + stacked * w.view(-1, 1, 1) + ).sum( # https://github.com/Snapchat/GiGL/issues/408 dim=0 ) # shape [N, D], w_0*X_0 + w_1*X_1 + ... diff --git a/gigl/orchestration/local/runner.py b/gigl/orchestration/local/runner.py index 207f4ae64..06bf4c2b2 100644 --- a/gigl/orchestration/local/runner.py +++ b/gigl/orchestration/local/runner.py @@ -89,7 +89,7 @@ def run( else: Runner.config_check(start_at, pipeline_config) - component_map: OrderedDict[GiGLComponents, Callable] = OrderedDict( + component_map: OrderedDict[GiGLComponents, Callable] = OrderedDict( # ty: ignore[invalid-assignment] { GiGLComponents.ConfigPopulator.value: Runner.run_config_populator, GiGLComponents.DataPreprocessor.value: Runner.run_data_preprocessor, diff --git a/gigl/src/common/modeling_task_specs/graphsage_template_modeling_spec.py b/gigl/src/common/modeling_task_specs/graphsage_template_modeling_spec.py index e5b2b5e73..443a3d6a8 100644 --- a/gigl/src/common/modeling_task_specs/graphsage_template_modeling_spec.py +++ b/gigl/src/common/modeling_task_specs/graphsage_template_modeling_spec.py @@ -129,7 +129,7 @@ def model(self) -> torch.nn.Module: @model.setter def model(self, model: torch.nn.Module) -> None: self.__model = model - self.__model.graph_backend = GraphBackend.PYG # type: ignore + self.__model.graph_backend = GraphBackend.PYG def init_model( self, @@ -517,9 +517,9 @@ def _compute_metrics( mrr = 1.0 / pos_rank.float() hit_rates = hit_rate_at_k( - pos_scores=pos_scores, # type: ignore - neg_scores=neg_scores, # type: ignore - ks=torch.tensor(ks, device=device, dtype=torch.long), # type: ignore + pos_scores=pos_scores, + neg_scores=neg_scores, + ks=torch.tensor(ks, device=device, dtype=torch.long), ) total_mrr += mrr.mean().item() diff --git a/gigl/src/common/modeling_task_specs/node_anchor_based_link_prediction_modeling_task_spec.py b/gigl/src/common/modeling_task_specs/node_anchor_based_link_prediction_modeling_task_spec.py index 82b0dbe3a..e61024642 100644 --- a/gigl/src/common/modeling_task_specs/node_anchor_based_link_prediction_modeling_task_spec.py +++ b/gigl/src/common/modeling_task_specs/node_anchor_based_link_prediction_modeling_task_spec.py @@ -359,6 +359,10 @@ def train( main_data_loader = data_loaders.train_main random_negative_data_loader = data_loaders.train_random_negative + if main_data_loader is None or random_negative_data_loader is None: + raise ValueError( + "train_main and train_random_negative dataloaders must be set for training" + ) val_main_data_loader_iter = iter(data_loaders.val_main) # type: ignore val_random_data_loader_iter = iter(data_loaders.val_random_negative) # type: ignore @@ -375,7 +379,7 @@ def train( self.model.train() for batch_index, (main_batch, random_negative_batch) in enumerate( - zip(main_data_loader, random_negative_data_loader), # type: ignore[arg-type] + zip(main_data_loader, random_negative_data_loader), start=1, ): batch_st = time() @@ -543,7 +547,7 @@ def validate( hr_result = hit_rate_at_k( pos_scores=batch_scores.pos_scores, neg_scores=batch_scores.random_neg_scores, - ks=ks_for_evaluation, # type: ignore + ks=ks_for_evaluation, ) mrr_result = mean_reciprocal_rank( pos_scores=batch_scores.pos_scores, diff --git a/gigl/src/common/modeling_task_specs/node_classification_modeling_task_spec.py b/gigl/src/common/modeling_task_specs/node_classification_modeling_task_spec.py index a20965e35..12e5a4e38 100644 --- a/gigl/src/common/modeling_task_specs/node_classification_modeling_task_spec.py +++ b/gigl/src/common/modeling_task_specs/node_classification_modeling_task_spec.py @@ -160,7 +160,8 @@ def _train( out = self.model(inputs) # Figure out why below is a typing issue loss = self._train_loss_fn( - input=out[root_node_indices], target=root_node_labels + input=out[root_node_indices], # ty: ignore[unknown-argument] + target=root_node_labels, # ty: ignore[unknown-argument] ) # type: ignore loss.backward() self._optimizer.step() @@ -200,7 +201,9 @@ def score( assert root_node_labels is not None results: InferBatchResults = self.infer_batch(batch=batch, device=device) - num_correct_in_batch = int((results.predictions == root_node_labels).sum()) # type: ignore # https://github.com/Snapchat/GiGL/issues/408 + num_correct_in_batch = int( + (results.predictions == root_node_labels).sum() + ) # https://github.com/Snapchat/GiGL/issues/408 num_correct += num_correct_in_batch num_evaluated += len(batch.root_node_labels) diff --git a/gigl/src/common/modeling_task_specs/utils/infer.py b/gigl/src/common/modeling_task_specs/utils/infer.py index 1f93061df..17804a3ed 100644 --- a/gigl/src/common/modeling_task_specs/utils/infer.py +++ b/gigl/src/common/modeling_task_specs/utils/infer.py @@ -137,8 +137,10 @@ def infer_task_inputs( decoder = model.module.decode batch_result_types = model.module.tasks.result_types else: - decoder = model.decode # type: ignore # https://github.com/Snapchat/GiGL/issues/408 - batch_result_types = model.tasks.result_types # type: ignore # https://github.com/Snapchat/GiGL/issues/408 + decoder = model.decode # https://github.com/Snapchat/GiGL/issues/408 + batch_result_types = ( + model.tasks.result_types + ) # https://github.com/Snapchat/GiGL/issues/408 # If we only have losses which only require the input batch, don't forward here and return the # input batch immediately to minimize computation we don't need, such as encoding and decoding. @@ -216,7 +218,7 @@ def infer_task_inputs( ].to(device=device) ) random_neg_root_embeddings[condensed_node_type] = ( - random_neg_embeddings[condensed_node_type][random_neg_root_node_indices] # type: ignore + random_neg_embeddings[condensed_node_type][random_neg_root_node_indices] if random_neg_root_node_indices.numel() else torch.FloatTensor([]).to(device=device) ) @@ -301,7 +303,7 @@ def infer_task_inputs( _batch_scores[condensed_supervision_edge_type] = BatchScores( pos_scores=pos_scores, hard_neg_scores=hard_neg_scores, - random_neg_scores=random_neg_scores_root, # type: ignore + random_neg_scores=random_neg_scores_root, ) if ModelResultType.batch_scores in batch_result_types or should_eval: @@ -321,12 +323,12 @@ def infer_task_inputs( condensed_supervision_edge_type ] pos_embeddings[condensed_supervision_edge_type] = ( - torch.cat(tuple(_pos_embeddings[condensed_supervision_edge_type])) # type: ignore + torch.cat(tuple(_pos_embeddings[condensed_supervision_edge_type])) if len(_pos_embeddings[condensed_supervision_edge_type]) else torch.tensor([]) ) hard_neg_embeddings[condensed_supervision_edge_type] = ( - torch.cat(tuple(_hard_neg_embeddings[condensed_supervision_edge_type])) # type: ignore + torch.cat(tuple(_hard_neg_embeddings[condensed_supervision_edge_type])) if len(_hard_neg_embeddings[condensed_supervision_edge_type]) else torch.tensor([]) ) @@ -336,7 +338,7 @@ def infer_task_inputs( torch.tensor(repeated_anchor_count[condensed_supervision_edge_type]).to( device=device ), - dim=0, # type: ignore + dim=0, ) ) @@ -434,21 +436,21 @@ def infer_task_inputs( batch_combined_scores[condensed_supervision_edge_type] = ( BatchCombinedScores( repeated_candidate_scores=repeated_candidate_scores, - positive_ids=global_positive_ids, # type: ignore - hard_neg_ids=global_hard_neg_ids, # type: ignore - random_neg_ids=global_random_neg_ids, # type: ignore - repeated_query_ids=repeated_global_query_ids, # type: ignore + positive_ids=global_positive_ids, + hard_neg_ids=global_hard_neg_ids, + random_neg_ids=global_random_neg_ids, + repeated_query_ids=repeated_global_query_ids, num_unique_query_ids=main_batch_root_node_indices.shape[0], ) ) # Populate all computed embeddings for task input batch_embeddings = BatchEmbeddings( - query_embeddings=query_embeddings, # type: ignore - repeated_query_embeddings=repeated_anchor_embeddings, # type: ignore - pos_embeddings=pos_embeddings, # type: ignore - hard_neg_embeddings=hard_neg_embeddings, # type: ignore - random_neg_embeddings=random_neg_root_embeddings, # type: ignore + query_embeddings=query_embeddings, + repeated_query_embeddings=repeated_anchor_embeddings, + pos_embeddings=pos_embeddings, + hard_neg_embeddings=hard_neg_embeddings, + random_neg_embeddings=random_neg_root_embeddings, ) return NodeAnchorBasedLinkPredictionTaskInputs( diff --git a/gigl/src/common/models/layers/count_min_sketch.py b/gigl/src/common/models/layers/count_min_sketch.py index 175642a4c..fb54657ff 100644 --- a/gigl/src/common/models/layers/count_min_sketch.py +++ b/gigl/src/common/models/layers/count_min_sketch.py @@ -84,7 +84,7 @@ def estimate_torch_long_tensor(self, tensor: torch.LongTensor) -> torch.LongTens Return the estimated count of all items in a torch long tensor """ tensor_cpu = tensor.cpu().numpy() - return torch.tensor( # type: ignore + return torch.tensor( [self.estimate(item) for item in tensor_cpu], dtype=torch.long, ) @@ -115,6 +115,6 @@ def calculate_in_batch_candidate_sampling_probability( because there is a larger error in P(candidate in batch | x) ~= P(candidate in batch) """ estimated_prob: torch.FloatTensor = ( - batch_size * frequency_tensor.float() / total_cnt # type: ignore + batch_size * frequency_tensor.float() / total_cnt ) return estimated_prob.clamp(max=1.0) diff --git a/gigl/src/common/models/layers/loss.py b/gigl/src/common/models/layers/loss.py index 76fa64c6a..4aab10dfc 100644 --- a/gigl/src/common/models/layers/loss.py +++ b/gigl/src/common/models/layers/loss.py @@ -65,7 +65,7 @@ def _calculate_margin_loss( input1=pos_scores_repeated, input2=all_neg_scores_repeated, target=ys, - margin=self.margin, # type: ignore + margin=self.margin, reduction="sum", ) sample_size = pos_scores_repeated.numel() @@ -142,7 +142,8 @@ def _calculate_softmax_loss( ) # shape=[num_pos_nodes] loss = F.cross_entropy( - input=all_scores / self.softmax_temperature, # type: ignore # https://github.com/Snapchat/GiGL/issues/408 + input=all_scores + / self.softmax_temperature, # https://github.com/Snapchat/GiGL/issues/408 target=ys, reduction="sum", ) @@ -349,7 +350,7 @@ def forward( batch_combined_scores.random_neg_ids.to(device=device), ) ) - if repeated_query_embeddings.numel(): # type: ignore + if repeated_query_embeddings.numel(): loss = self.calculate_batch_retrieval_loss( scores=batch_combined_scores.repeated_candidate_scores, candidate_sampling_probability=candidate_sampling_probability, @@ -357,7 +358,7 @@ def forward( candidate_ids=candidate_ids, device=device, ) - batch_size = repeated_query_embeddings.shape[0] # type: ignore + batch_size = repeated_query_embeddings.shape[0] else: loss = torch.tensor(0.0).to(device=device) batch_size = 1 @@ -587,7 +588,7 @@ def forward( sim2 = F.cosine_similarity(q2, y1.detach()).mean() neg_sim1 = F.cosine_similarity(q1, neg_y.detach()).mean() # type: ignore neg_sim2 = F.cosine_similarity(q2, neg_y.detach()).mean() # type: ignore - loss = self.neg_lambda * (neg_sim1 + neg_sim2) - (1 - self.neg_lambda) * ( # type: ignore + loss = self.neg_lambda * (neg_sim1 + neg_sim2) - (1 - self.neg_lambda) * ( sim1 + sim2 ) return loss, 1 @@ -610,7 +611,7 @@ def forward( ) -> torch.Tensor: return ( (user_embeddings - item_embeddings).norm(p=2, dim=1).pow(self.alpha).mean() - ) # type: ignore + ) class UniformityLoss(nn.Module): @@ -635,7 +636,7 @@ def forward( .exp() .mean() .log() - ) # type: ignore + ) item_uniformity = ( torch.pdist(item_embeddings, p=2) .pow(2) @@ -643,7 +644,7 @@ def forward( .exp() .mean() .log() - ) # type: ignore + ) return (user_uniformity + item_uniformity) / 2 diff --git a/gigl/src/common/models/layers/task.py b/gigl/src/common/models/layers/task.py index b82dee44b..5a5bb8e54 100644 --- a/gigl/src/common/models/layers/task.py +++ b/gigl/src/common/models/layers/task.py @@ -7,7 +7,7 @@ from torch_geometric.nn import GraphConv from gigl.common.logger import Logger -from gigl.src.common.modeling_task_specs.utils.infer import ( # type: ignore +from gigl.src.common.modeling_task_specs.utils.infer import ( infer_root_embeddings, infer_training_batch, ) @@ -29,7 +29,7 @@ UniformityLoss, WhiteningDecorrelationLoss, ) -from gigl.src.common.models.pyg.graph.augmentations import ( # type: ignore +from gigl.src.common.models.pyg.graph.augmentations import ( get_augmented_graph, ) from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper @@ -228,9 +228,9 @@ def __init__( hid_dim = self.encoder.hid_dim out_dim = self.encoder.out_dim self.head = torch.nn.Sequential( - torch.nn.Linear(out_dim, hid_dim), # type: ignore + torch.nn.Linear(out_dim, hid_dim), torch.nn.ReLU(), - torch.nn.Linear(hid_dim, out_dim), # type: ignore + torch.nn.Linear(hid_dim, out_dim), ) self.loss = GRACELoss(temperature=temperature) self.feat_drop_1 = feat_drop_1 @@ -248,8 +248,8 @@ def forward( main_batch = task_input.input_batch.main_batch augmented_graph_1 = get_augmented_graph( graph=main_batch.graph.to(device=device), - edge_drop_ratio=self.edge_drop_1, # type: ignore - feat_drop_ratio=self.edge_drop_2, # type: ignore + edge_drop_ratio=self.edge_drop_1, + feat_drop_ratio=self.edge_drop_2, ) augmented_embeddings_1 = infer_root_embeddings( model=self.encoder, @@ -260,8 +260,8 @@ def forward( ) augmented_graph_2 = get_augmented_graph( graph=main_batch.graph.to(device=device), - edge_drop_ratio=self.edge_drop_2, # type: ignore - feat_drop_ratio=self.feat_drop_2, # type: ignore + edge_drop_ratio=self.edge_drop_2, + feat_drop_ratio=self.feat_drop_2, ) augmented_embeddings_2 = infer_root_embeddings( model=self.encoder, @@ -296,8 +296,8 @@ def __init__( out_dim = self.encoder.out_dim self.loss = FeatureReconstructionLoss(alpha=alpha) self.reconstruction_decoder = GraphConv(out_dim, in_dim) - self.reconstruction_mask = torch.nn.Parameter(torch.zeros(1, in_dim)) # type: ignore - self.reconstruction_enc_dec = torch.nn.Linear(out_dim, out_dim, bias=False) # type: ignore + self.reconstruction_mask = torch.nn.Parameter(torch.zeros(1, in_dim)) + self.reconstruction_enc_dec = torch.nn.Linear(out_dim, out_dim, bias=False) self.edge_drop = edge_drop def forward( @@ -318,7 +318,7 @@ def forward( main_batch = task_input.input_batch.main_batch augmented_graph = get_augmented_graph( graph=main_batch.graph.to(device=device), - edge_drop_ratio=self.edge_drop, # type: ignore + edge_drop_ratio=self.edge_drop, feat_drop_ratio=0.0, ) root_node_indices = main_batch.root_node_indices @@ -368,9 +368,9 @@ def __init__( out_dim = self.encoder.out_dim self.loss = WhiteningDecorrelationLoss(lambd=lambd) self.head = torch.nn.Sequential( - torch.nn.Linear(out_dim, hid_dim), # type: ignore + torch.nn.Linear(out_dim, hid_dim), torch.nn.ReLU(), - torch.nn.Linear(hid_dim, out_dim), # type: ignore + torch.nn.Linear(hid_dim, out_dim), ) self.feat_drop_1 = feat_drop_1 self.edge_drop_1 = edge_drop_1 @@ -387,8 +387,8 @@ def forward( main_batch = task_input.input_batch.main_batch augmented_graph_1 = get_augmented_graph( graph=main_batch.graph.to(device=device), - edge_drop_ratio=self.edge_drop_1, # type: ignore - feat_drop_ratio=self.feat_drop_1, # type: ignore + edge_drop_ratio=self.edge_drop_1, + feat_drop_ratio=self.feat_drop_1, ) augmented_embeddings_1 = infer_root_embeddings( model=self.encoder, @@ -399,8 +399,8 @@ def forward( ) augmented_graph_2 = get_augmented_graph( graph=main_batch.graph.to(device=device), - edge_drop_ratio=self.edge_drop_2, # type: ignore - feat_drop_ratio=self.feat_drop_2, # type: ignore + edge_drop_ratio=self.edge_drop_2, + feat_drop_ratio=self.feat_drop_2, ) augmented_embeddings_2 = infer_root_embeddings( model=self.encoder, @@ -449,8 +449,8 @@ def forward( main_batch = task_input.input_batch.main_batch augmented_graph_1 = get_augmented_graph( graph=main_batch.graph.to(device=device), - edge_drop_ratio=self.edge_drop_1, # type: ignore - feat_drop_ratio=self.feat_drop_1, # type: ignore + edge_drop_ratio=self.edge_drop_1, + feat_drop_ratio=self.feat_drop_1, ) augmented_embeddings_1 = infer_root_embeddings( model=self.encoder, @@ -461,8 +461,8 @@ def forward( ) augmented_graph_2 = get_augmented_graph( graph=main_batch.graph.to(device=device), - edge_drop_ratio=self.edge_drop_2, # type: ignore - feat_drop_ratio=self.feat_drop_2, # type: ignore + edge_drop_ratio=self.edge_drop_2, + feat_drop_ratio=self.feat_drop_2, ) augmented_embeddings_2 = infer_root_embeddings( model=self.encoder, @@ -498,13 +498,13 @@ def __init__( hid_dim = self.encoder.hid_dim out_dim = self.encoder.out_dim self.offline_encoder = copy.deepcopy(encoder) - for param in self.offline_encoder.parameters(): # type: ignore + for param in self.offline_encoder.parameters(): param.requires_grad = False self.loss = BGRLLoss() self.head = torch.nn.Sequential( - torch.nn.Linear(out_dim, hid_dim), # type: ignore + torch.nn.Linear(out_dim, hid_dim), torch.nn.ReLU(), - torch.nn.Linear(hid_dim, out_dim), # type: ignore + torch.nn.Linear(hid_dim, out_dim), ) self.feat_drop_1 = feat_drop_1 self.edge_drop_1 = edge_drop_1 @@ -521,13 +521,13 @@ def forward( main_batch = task_input.input_batch.main_batch augmented_graph_1 = get_augmented_graph( graph=main_batch.graph.to(device=device), - edge_drop_ratio=self.edge_drop_1, # type: ignore - feat_drop_ratio=self.feat_drop_1, # type: ignore + edge_drop_ratio=self.edge_drop_1, + feat_drop_ratio=self.feat_drop_1, ) augmented_graph_2 = get_augmented_graph( graph=main_batch.graph.to(device=device), - edge_drop_ratio=self.edge_drop_2, # type: ignore - feat_drop_ratio=self.feat_drop_2, # type: ignore + edge_drop_ratio=self.edge_drop_2, + feat_drop_ratio=self.feat_drop_2, ) enc1 = infer_root_embeddings( model=self.encoder, @@ -581,13 +581,13 @@ def __init__( hid_dim = self.encoder.hid_dim out_dim = self.encoder.out_dim self.offline_encoder = copy.deepcopy(encoder) - for param in self.offline_encoder.parameters(): # type: ignore + for param in self.offline_encoder.parameters(): param.requires_grad = False self.loss = TBGRLLoss(neg_lambda=neg_lambda) self.head = torch.nn.Sequential( - torch.nn.Linear(out_dim, hid_dim), # type: ignore + torch.nn.Linear(out_dim, hid_dim), torch.nn.ReLU(), - torch.nn.Linear(hid_dim, out_dim), # type: ignore + torch.nn.Linear(hid_dim, out_dim), ) self.feat_drop_1 = feat_drop_1 @@ -607,18 +607,18 @@ def forward( main_batch = task_input.input_batch.main_batch augmented_graph_1 = get_augmented_graph( graph=main_batch.graph.to(device=device), - edge_drop_ratio=self.edge_drop_1, # type: ignore - feat_drop_ratio=self.feat_drop_1, # type: ignore + edge_drop_ratio=self.edge_drop_1, + feat_drop_ratio=self.feat_drop_1, ) augmented_graph_2 = get_augmented_graph( graph=main_batch.graph.to(device=device), - edge_drop_ratio=self.edge_drop_2, # type: ignore - feat_drop_ratio=self.feat_drop_2, # type: ignore + edge_drop_ratio=self.edge_drop_2, + feat_drop_ratio=self.feat_drop_2, ) augmented_graph_3 = get_augmented_graph( graph=main_batch.graph.to(device=device), - edge_drop_ratio=self.edge_drop_neg, # type: ignore - feat_drop_ratio=self.feat_drop_neg, # type: ignore + edge_drop_ratio=self.edge_drop_neg, + feat_drop_ratio=self.feat_drop_neg, graph_perm=True, ) enc1 = infer_root_embeddings( @@ -709,7 +709,9 @@ def _get_all_tasks( for task in list(self._task_to_weights_map.keys()): fn = self._task_to_fn_map[task] weight = self._task_to_weights_map[task] - tasks_list.append((fn, weight)) # type: ignore # https://github.com/Snapchat/GiGL/issues/408 + tasks_list.append( + (fn, weight) + ) # https://github.com/Snapchat/GiGL/issues/408 return tasks_list def add_task( diff --git a/gigl/src/common/models/pyg/heterogeneous.py b/gigl/src/common/models/pyg/heterogeneous.py index 62558368a..8fb5cc18a 100644 --- a/gigl/src/common/models/pyg/heterogeneous.py +++ b/gigl/src/common/models/pyg/heterogeneous.py @@ -127,7 +127,7 @@ def forward( ) if self.should_l2_normalize_embedding_layer_output: - node_typed_embeddings = l2_normalize_embeddings( # type: ignore + node_typed_embeddings = l2_normalize_embeddings( node_typed_embeddings=node_typed_embeddings ) @@ -282,7 +282,7 @@ def forward( ) if self.should_l2_normalize_embedding_layer_output: - node_typed_embeddings = l2_normalize_embeddings( # type: ignore + node_typed_embeddings = l2_normalize_embeddings( node_typed_embeddings=node_typed_embeddings ) return node_typed_embeddings diff --git a/gigl/src/common/models/pyg/homogeneous.py b/gigl/src/common/models/pyg/homogeneous.py index 71f7e35f0..5af61c82b 100644 --- a/gigl/src/common/models/pyg/homogeneous.py +++ b/gigl/src/common/models/pyg/homogeneous.py @@ -66,7 +66,7 @@ def __init__( self.feature_embedding_layer = feature_embedding_layer # Feature interaction layers self.feats_interaction = feature_interaction_layer - self.conv_layers: nn.ModuleList = self.init_conv_layers( # type: ignore + self.conv_layers: nn.ModuleList = self.init_conv_layers( in_dim=in_dim, out_dim=hid_dim if linear_layer or jk_mode else out_dim, edge_dim=edge_dim, @@ -98,7 +98,7 @@ def __init__( lstm_dim=jk_lstm_dim, ) else: - self.jk_layer = None # type: ignore + self.jk_layer = None self.return_emb = return_emb self.linear_layer = linear_layer if linear_layer: diff --git a/gigl/src/common/models/pyg/link_prediction.py b/gigl/src/common/models/pyg/link_prediction.py index 89ab589fd..274b8e436 100644 --- a/gigl/src/common/models/pyg/link_prediction.py +++ b/gigl/src/common/models/pyg/link_prediction.py @@ -69,7 +69,7 @@ def decode( @property def tasks(self) -> NodeAnchorBasedLinkPredictionTasks: - return self.__tasks # type: ignore + return self.__tasks @property def graph_backend(self) -> GraphBackend: diff --git a/gigl/src/common/models/pyg/nn/conv/gin_conv.py b/gigl/src/common/models/pyg/nn/conv/gin_conv.py index cbd780ca8..58be15a2c 100644 --- a/gigl/src/common/models/pyg/nn/conv/gin_conv.py +++ b/gigl/src/common/models/pyg/nn/conv/gin_conv.py @@ -62,7 +62,7 @@ def forward( ) -> Tensor: """""" if isinstance(x, Tensor): - x: OptPairTensor = (x, x) # type: ignore + x: OptPairTensor = (x, x) # propagate_type: (x: OptPairTensor, edge_attr: OptTensor) out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size) diff --git a/gigl/src/common/models/pyg/nn/conv/hgt_conv.py b/gigl/src/common/models/pyg/nn/conv/hgt_conv.py index 3cb7be3e0..d651dae1c 100644 --- a/gigl/src/common/models/pyg/nn/conv/hgt_conv.py +++ b/gigl/src/common/models/pyg/nn/conv/hgt_conv.py @@ -150,8 +150,8 @@ def _construct_src_node_feat( ks.append(k_dict[src]) vs.append(v_dict[src]) - ks = torch.cat(ks, dim=0).transpose(0, 1).reshape(-1, D) # type: ignore - vs = torch.cat(vs, dim=0).transpose(0, 1).reshape(-1, D) # type: ignore + ks = torch.cat(ks, dim=0).transpose(0, 1).reshape(-1, D) + vs = torch.cat(vs, dim=0).transpose(0, 1).reshape(-1, D) type_vec = torch.cat(type_list, dim=1).flatten() k = self.k_rel(ks, type_vec).view(H, -1, D).transpose(0, 1) diff --git a/gigl/src/common/models/pyg/nn/models/feature_embedding.py b/gigl/src/common/models/pyg/nn/models/feature_embedding.py index 26bc3d645..1ddde48d6 100644 --- a/gigl/src/common/models/pyg/nn/models/feature_embedding.py +++ b/gigl/src/common/models/pyg/nn/models/feature_embedding.py @@ -107,7 +107,7 @@ def __init__( ): feature_padding_value = str( self.__feature_padding_value_map[feature_name] - ) # type: ignore + ) feature_padding_idx = self.__feature_schema.feature_vocab[ feature_name ].index(feature_padding_value) diff --git a/gigl/src/common/models/pyg/nn/models/jumping_knowledge.py b/gigl/src/common/models/pyg/nn/models/jumping_knowledge.py index 96bc1b9ee..1d5f0403e 100644 --- a/gigl/src/common/models/pyg/nn/models/jumping_knowledge.py +++ b/gigl/src/common/models/pyg/nn/models/jumping_knowledge.py @@ -77,16 +77,16 @@ def __init__( self.output_linear = nn.Linear(hid_dim, out_dim) elif self.mode == "cat": assert num_layers is not None, "num_layers cannot be none for cat mode" - self.lstm_dim = None # type: ignore - self.lstm = None # type: ignore - self.att = None # type: ignore + self.lstm_dim = None + self.lstm = None + self.att = None self.num_layers = num_layers self.output_linear = nn.Linear((num_layers * hid_dim), out_dim) else: # self.mode == "max" - self.lstm_dim = None # type: ignore - self.lstm = None # type: ignore - self.att = None # type: ignore - self.num_layers = None # type: ignore + self.lstm_dim = None + self.lstm = None + self.att = None + self.num_layers = None self.output_linear = nn.Linear(hid_dim, out_dim) self.reset_parameters() diff --git a/gigl/src/common/translators/training_samples_protos_translator.py b/gigl/src/common/translators/training_samples_protos_translator.py index 66f68c9a3..56c7f7a49 100644 --- a/gigl/src/common/translators/training_samples_protos_translator.py +++ b/gigl/src/common/translators/training_samples_protos_translator.py @@ -141,9 +141,7 @@ def training_samples_from_NodeAnchorBasedLinkPredictionSamplePb( ): condensed_supervision_edge_type_to_pos_edge_feats[ condensed_edge_type - ].append( - pos_edge[1] # type: ignore - ) + ].append(pos_edge[1]) for hard_neg_edge_pb in sample.hard_neg_edges: hard_neg_edge: Tuple[Edge, Optional[torch.Tensor]] = ( @@ -167,9 +165,7 @@ def training_samples_from_NodeAnchorBasedLinkPredictionSamplePb( ): condensed_supervision_edge_type_to_hard_neg_edge_feats[ condensed_edge_type - ].append( - hard_neg_edge[1] # type: ignore - ) + ].append(hard_neg_edge[1]) for condensed_edge_type in graph_metadata_pb_wrapper.condensed_edge_types: condensed_edge_type_to_supervision_edge_data[condensed_edge_type] = ( @@ -181,8 +177,8 @@ def training_samples_from_NodeAnchorBasedLinkPredictionSamplePb( condensed_edge_type ], pos_edge_features=( - torch.stack( # type: ignore - condensed_supervision_edge_type_to_pos_edge_feats[ # type: ignore + torch.stack( + condensed_supervision_edge_type_to_pos_edge_feats[ condensed_edge_type ] ) @@ -195,8 +191,8 @@ def training_samples_from_NodeAnchorBasedLinkPredictionSamplePb( else None ), hard_neg_edge_features=( - torch.stack( # type: ignore - condensed_supervision_edge_type_to_hard_neg_edge_feats[ # type: ignore + torch.stack( + condensed_supervision_edge_type_to_hard_neg_edge_feats[ condensed_edge_type ] ) diff --git a/gigl/src/common/types/graph_data.py b/gigl/src/common/types/graph_data.py index f566392c6..13f4c45ad 100644 --- a/gigl/src/common/types/graph_data.py +++ b/gigl/src/common/types/graph_data.py @@ -72,12 +72,12 @@ def from_nodes(cls, src_node: Node, dst_node: Node, relation: Relation) -> Edge: ) return edge - @property # type: ignore + @property @lru_cache(maxsize=1) def src_node(self) -> Node: return Node(id=self.src_node_id, type=self.edge_type.src_node_type) - @property # type: ignore + @property @lru_cache(maxsize=1) def dst_node(self) -> Node: return Node(id=self.dst_node_id, type=self.edge_type.dst_node_type) diff --git a/gigl/src/common/types/pb_wrappers/graph_data_types.py b/gigl/src/common/types/pb_wrappers/graph_data_types.py index fa33dc033..858297095 100644 --- a/gigl/src/common/types/pb_wrappers/graph_data_types.py +++ b/gigl/src/common/types/pb_wrappers/graph_data_types.py @@ -87,7 +87,7 @@ def __hash__(self) -> int: ) ) object.__setattr__(self, _HASH_CACHE_KEY, h) - return cast(int, h) + return h def __getstate__(self): state = self.__dict__.copy() @@ -159,7 +159,7 @@ def __hash__(self) -> int: ) ) object.__setattr__(self, _HASH_CACHE_KEY, h) - return cast(int, h) + return h def __getstate__(self): state = self.__dict__.copy() @@ -186,24 +186,24 @@ def __repr__(self): class GraphPbWrapper: pb: graph_schema_pb2.Graph - @property # type: ignore + @property @lru_cache(maxsize=1) def nodes_wrapper(self) -> list[NodePbWrapper]: # TODO: rename to nodes_pb_wrapper for clarity return [NodePbWrapper(pb=node_pb2) for node_pb2 in self.pb.nodes] - @property # type: ignore + @property @lru_cache(maxsize=1) def nodes_pb(self) -> list[graph_schema_pb2.Node]: return list(self.pb.nodes) - @property # type: ignore + @property @lru_cache(maxsize=1) def edges_wrapper(self) -> list[EdgePbWrapper]: # TODO: rename to edges_pb_wrapper for clarity return [EdgePbWrapper(pb=edge_pb2) for edge_pb2 in self.pb.edges] - @property # type: ignore + @property @lru_cache(maxsize=1) def edges_pb(self) -> list[graph_schema_pb2.Edge]: return list(self.pb.edges) @@ -335,7 +335,7 @@ def __hash__(self) -> int: ) h = hash(sorted_graph_pb_repr) object.__setattr__(self, _HASH_CACHE_KEY, h) - return cast(int, h) + return h def __getstate__(self): state = self.__dict__.copy() diff --git a/gigl/src/common/types/pb_wrappers/graph_metadata.py b/gigl/src/common/types/pb_wrappers/graph_metadata.py index ea4aec8b0..60e20ca2f 100644 --- a/gigl/src/common/types/pb_wrappers/graph_metadata.py +++ b/gigl/src/common/types/pb_wrappers/graph_metadata.py @@ -140,7 +140,7 @@ def homogeneous_condensed_edge_type(self) -> CondensedEdgeType: ) return self.condensed_edge_types[0] - @property # type: ignore + @property @lru_cache(maxsize=1) def condensed_node_type_to_node_type_map(self) -> dict[CondensedNodeType, NodeType]: return { @@ -148,12 +148,12 @@ def condensed_node_type_to_node_type_map(self) -> dict[CondensedNodeType, NodeTy for condensed_node_type, node_type in self.graph_metadata_pb.condensed_node_type_map.items() } - @property # type: ignore + @property @lru_cache(maxsize=1) def node_type_to_condensed_node_type_map(self) -> dict[NodeType, CondensedNodeType]: return {v: k for k, v in self.condensed_node_type_to_node_type_map.items()} - @property # type: ignore + @property @lru_cache(maxsize=1) def condensed_edge_type_to_edge_type_map(self) -> dict[CondensedEdgeType, EdgeType]: return { @@ -165,32 +165,32 @@ def condensed_edge_type_to_edge_type_map(self) -> dict[CondensedEdgeType, EdgeTy for condensed_edge_type, edge_type in self.graph_metadata_pb.condensed_edge_type_map.items() } - @property # type: ignore + @property @lru_cache(maxsize=1) def edge_type_to_condensed_edge_type_map(self) -> dict[EdgeType, CondensedEdgeType]: return {v: k for k, v in self.condensed_edge_type_to_edge_type_map.items()} - @property # type: ignore + @property @lru_cache(maxsize=1) def edge_types(self) -> list[EdgeType]: return list(self.condensed_edge_type_to_edge_type_map.values()) - @property # type: ignore + @property @lru_cache(maxsize=1) def node_types(self) -> list[NodeType]: return list(self.condensed_node_type_to_node_type_map.values()) - @property # type: ignore + @property @lru_cache(maxsize=1) def condensed_edge_types(self) -> list[CondensedEdgeType]: return list(self.condensed_edge_type_to_edge_type_map.keys()) - @property # type: ignore + @property @lru_cache(maxsize=1) def condensed_node_types(self) -> list[CondensedNodeType]: return list(self.condensed_node_type_to_node_type_map.keys()) - @property # type: ignore + @property @lru_cache(maxsize=1) def is_heterogeneous(self) -> bool: return len(self.edge_types) > 1 or len(self.node_types) > 1 diff --git a/gigl/src/common/types/pb_wrappers/preprocessed_metadata.py b/gigl/src/common/types/pb_wrappers/preprocessed_metadata.py index 5f841c77a..658d071f9 100644 --- a/gigl/src/common/types/pb_wrappers/preprocessed_metadata.py +++ b/gigl/src/common/types/pb_wrappers/preprocessed_metadata.py @@ -283,12 +283,12 @@ def __get_feature_to_vocab_list_map( if isinstance(transform_fn_assets_uri, LocalUri): list_files_fn = partial( LocalFsUtils.list_at_path, entity=LocalFsUtils.FileSystemEntity.FILE - ) # type: ignore - read_file_fn = lambda path: open(path, "rb") # type: ignore + ) + read_file_fn = lambda path: open(path, "rb") elif isinstance(transform_fn_assets_uri, GcsUri): gcs_utils = GcsUtils() - list_files_fn = gcs_utils.list_uris_with_gcs_path_pattern # type: ignore - read_file_fn = gcs_utils.download_file_from_gcs_to_temp_file # type: ignore + list_files_fn = gcs_utils.list_uris_with_gcs_path_pattern + read_file_fn = gcs_utils.download_file_from_gcs_to_temp_file else: raise ValueError( f"Invalid uri: {transform_fn_assets_uri}. Must be either {GcsUri.__name__} or {LocalUri.__name__}" @@ -298,7 +298,7 @@ def __get_feature_to_vocab_list_map( feature_to_vocab_list_map = {} for asset_file_path in assets_file_paths: feature_key = asset_file_path.uri.split("/")[-1] - f = read_file_fn(asset_file_path) + f = read_file_fn(asset_file_path) # ty: ignore[invalid-argument-type] vocab_list = [line.decode().rstrip() for line in f] feature_to_vocab_list_map[feature_key] = vocab_list f.close() diff --git a/gigl/src/common/utils/eval_metrics.py b/gigl/src/common/utils/eval_metrics.py index edee5a30e..147482fff 100644 --- a/gigl/src/common/utils/eval_metrics.py +++ b/gigl/src/common/utils/eval_metrics.py @@ -1,5 +1,3 @@ -from typing import cast - import torch @@ -45,7 +43,7 @@ def hit_rate_at_k( ) ks_adjusted = ks - 1 # subtract 1 since indices are 0-indexed hits_at_ks = torch.gather(input=hit_rates_padded, dim=0, index=ks_adjusted) - return cast(torch.FloatTensor, hits_at_ks) + return hits_at_ks def mean_reciprocal_rank( @@ -70,4 +68,4 @@ def mean_reciprocal_rank( adjusted_ranks = unadjusted_ranks + 1 # +1 since ranks are 0-indexed here reciprocal_ranks = 1.0 / adjusted_ranks # compute reciprocal mrr = torch.mean(reciprocal_ranks) - return cast(torch.FloatTensor, mrr) + return mrr diff --git a/gigl/src/common/utils/file_loader.py b/gigl/src/common/utils/file_loader.py index 0a362a489..701c286a6 100644 --- a/gigl/src/common/utils/file_loader.py +++ b/gigl/src/common/utils/file_loader.py @@ -1,7 +1,7 @@ import shutil import tempfile from collections.abc import Mapping -from tempfile import _TemporaryFileWrapper as TemporaryFileWrapper # type: ignore +from tempfile import _TemporaryFileWrapper as TemporaryFileWrapper from typing import IO, AnyStr, Optional, Sequence, Tuple, Type, Union, cast from gigl.common import GcsUri, HttpUri, LocalUri, Uri, UriFactory @@ -213,10 +213,16 @@ def load_from_filelike(self, uri: Uri, filelike: IO[AnyStr]) -> None: filelike.seek(ptr) # Reset the file pointer after reading if isinstance(first, bytes): with open(uri.uri, "wb") as dest: - shutil.copyfileobj(filelike, dest) + shutil.copyfileobj( + filelike, + dest, # ty: ignore[invalid-argument-type] + ) else: with open(uri.uri, "w", encoding="utf-8") as dest: - shutil.copyfileobj(filelike, dest) + shutil.copyfileobj( + filelike, + dest, # ty: ignore[invalid-argument-type] + ) else: raise NotImplementedError( @@ -266,12 +272,12 @@ def does_uri_exist(self, uri: Union[str, Uri]) -> bool: _uri = UriFactory.create_uri(uri=uri) if isinstance(uri, str) else uri exists: bool - if type(_uri) == GcsUri: + if isinstance(_uri, GcsUri): exists = self.__gcs_utils.does_gcs_file_exist(gcs_path=_uri) - elif type(_uri) == LocalUri: - exists = does_path_exist(cast(LocalUri, _uri)) - elif type(_uri) == HttpUri: - exists = HttpUtils.does_http_path_resolve(http_path=cast(HttpUri, _uri)) + elif isinstance(_uri, LocalUri): + exists = does_path_exist(_uri) + elif isinstance(_uri, HttpUri): + exists = HttpUtils.does_http_path_resolve(http_path=_uri) else: raise NotImplementedError(f"{self.__unsupported_uri_message} : {_uri}") return exists diff --git a/gigl/src/config_populator/config_populator.py b/gigl/src/config_populator/config_populator.py index bc7cf40a5..d53a6fa11 100644 --- a/gigl/src/config_populator/config_populator.py +++ b/gigl/src/config_populator/config_populator.py @@ -559,7 +559,7 @@ def _populate_frozen_gbml_config_pb( # Build SharedConfig from constants, and merge into the content of the template / input GbmlConfig. shared_config_pb = gbml_config_pb2.GbmlConfig.SharedConfig( - preprocessed_metadata_uri=preprocessed_metadata_uri.uri, # type: ignore + preprocessed_metadata_uri=preprocessed_metadata_uri.uri, flattened_graph_metadata=flattened_graph_metadata_pb, dataset_metadata=dataset_metadata_pb, trained_model_metadata=trained_model_metadata_pb, diff --git a/gigl/src/data_preprocessor/data_preprocessor.py b/gigl/src/data_preprocessor/data_preprocessor.py index 95d063aaf..98c0bc153 100644 --- a/gigl/src/data_preprocessor/data_preprocessor.py +++ b/gigl/src/data_preprocessor/data_preprocessor.py @@ -8,9 +8,9 @@ import tensorflow as tf import tensorflow_data_validation as tfdv -import tensorflow_transform as tft from apache_beam.runners.dataflow.dataflow_runner import DataflowPipelineResult from apache_beam.runners.runner import PipelineState +from tensorflow_transform.tf_metadata import schema_utils import gigl.common.utils.dataflow import gigl.src.common.constants.gcs as gcs_constants @@ -166,7 +166,7 @@ def __import_data_preprocessor_config(self) -> DataPreprocessorConfig: data_preprocessor_cls_str: str = self.gbml_config_pb_wrapper.dataset_config.data_preprocessor_config.data_preprocessor_config_cls_path data_preprocessor_cls = os_utils.import_obj(data_preprocessor_cls_str) - kwargs = self.gbml_config_pb_wrapper.dataset_config.data_preprocessor_config.data_preprocessor_args # type: ignore + kwargs = self.gbml_config_pb_wrapper.dataset_config.data_preprocessor_config.data_preprocessor_args try: data_preprocessor_config: DataPreprocessorConfig = data_preprocessor_cls( @@ -262,9 +262,7 @@ def __get_feature_dimension_for_single_data_reference( schema_path: Uri, feature_outputs: list[str] ) -> int: schema = tfdv.load_schema_text(schema_path.uri) - feature_spec = tft.tf_metadata.schema_utils.schema_as_feature_spec( - schema - ).feature_spec + feature_spec = schema_utils.schema_as_feature_spec(schema).feature_spec feature_dimension = 0 for feature in feature_spec: if feature in feature_outputs: diff --git a/gigl/src/data_preprocessor/lib/ingest/bigquery.py b/gigl/src/data_preprocessor/lib/ingest/bigquery.py index 63811da72..b3566e393 100644 --- a/gigl/src/data_preprocessor/lib/ingest/bigquery.py +++ b/gigl/src/data_preprocessor/lib/ingest/bigquery.py @@ -37,14 +37,13 @@ def _get_bigquery_ptransform( InstanceDictPTransform, beam.io.ReadFromBigQuery( table=table_name, - method=beam.io.ReadFromBigQuery.Method.EXPORT, # type: ignore + method=beam.io.ReadFromBigQuery.Method.EXPORT, *args, **kwargs, ), ) -# Below type ignores are due to star expansion issues with the type checker: https://github.com/python/mypy/issues/6799 @dataclass(frozen=True) class BigqueryNodeDataReference(NodeDataReference): """ @@ -69,11 +68,11 @@ class BigqueryNodeDataReference(NodeDataReference): def yield_instance_dict_ptransform(self, *args, **kwargs) -> InstanceDictPTransform: return _get_bigquery_ptransform( - table_name=self.reference_uri, - sharded_read_config=self.sharded_read_config, + self.reference_uri, + self.sharded_read_config, *args, **kwargs, - ) # type: ignore + ) def __repr__(self) -> str: return f"BigqueryNodeDataReference(node_type={self.node_type}, identifier={self.identifier}, reference_uri={self.reference_uri}, sharded_read_config={self.sharded_read_config})" @@ -106,11 +105,11 @@ class BigqueryEdgeDataReference(EdgeDataReference): def yield_instance_dict_ptransform(self, *args, **kwargs) -> InstanceDictPTransform: return _get_bigquery_ptransform( - table_name=self.reference_uri, - sharded_read_config=self.sharded_read_config, + self.reference_uri, + self.sharded_read_config, *args, **kwargs, - ) # type: ignore + ) def __repr__(self) -> str: return f"BigqueryEdgeDataReference(edge_type={self.edge_type}, src_identifier={self.src_identifier}, dst_identifier={self.dst_identifier}, reference_uri={self.reference_uri}, sharded_read_config={self.sharded_read_config})" diff --git a/gigl/src/data_preprocessor/lib/ingest/reference.py b/gigl/src/data_preprocessor/lib/ingest/reference.py index d382d855f..c86807837 100644 --- a/gigl/src/data_preprocessor/lib/ingest/reference.py +++ b/gigl/src/data_preprocessor/lib/ingest/reference.py @@ -7,10 +7,8 @@ from gigl.src.common.types.graph_data import EdgeType, EdgeUsageType, NodeType from gigl.src.data_preprocessor.lib.types import InstanceDictPTransform -# Type hints for abstract dataclasses may have limited support in type checkers. https://github.com/python/mypy/issues/5374 - -@dataclass(frozen=True) # type: ignore +@dataclass(frozen=True) class DataReference(ABC): """ Contains a URI string to the data reference, and provides a means of yielding @@ -37,7 +35,7 @@ def yield_instance_dict_ptransform(self, *args, **kwargs) -> InstanceDictPTransf raise NotImplementedError -@dataclass(frozen=True) # type: ignore +@dataclass(frozen=True) class NodeDataReference(DataReference, ABC): """ DataReference which stores node data. @@ -52,7 +50,7 @@ def __repr__(self) -> str: return f"NodeDataReference(node_type={self.node_type}, identifier={self.identifier}, reference_uri={self.reference_uri})" -@dataclass(frozen=True) # type: ignore +@dataclass(frozen=True) class EdgeDataReference(DataReference, ABC): """ DataReference which stores edge data diff --git a/gigl/src/data_preprocessor/lib/transform/tf_value_encoder.py b/gigl/src/data_preprocessor/lib/transform/tf_value_encoder.py index 70d647f2c..b3da385d6 100644 --- a/gigl/src/data_preprocessor/lib/transform/tf_value_encoder.py +++ b/gigl/src/data_preprocessor/lib/transform/tf_value_encoder.py @@ -76,9 +76,13 @@ def encode_value_as_feature(value: Any, dtype: tf.dtypes.DType) -> tf.train.Feat # encode value if dtype.is_integer or dtype.is_bool: - tf_feature = TFValueEncoder.__int_values_to_tf_feature(value=value) + tf_feature = TFValueEncoder.__int_values_to_tf_feature( + value=value # ty: ignore[invalid-argument-type] + ) elif dtype.is_floating: - tf_feature = TFValueEncoder.__float_values_to_tf_feature(value=value) + tf_feature = TFValueEncoder.__float_values_to_tf_feature( + value=value # ty: ignore[invalid-argument-type] + ) else: tf_feature = TFValueEncoder.__bytes_values_to_tf_feature(value=value) diff --git a/gigl/src/data_preprocessor/lib/transform/utils.py b/gigl/src/data_preprocessor/lib/transform/utils.py index f2b990abf..7e9be7081 100644 --- a/gigl/src/data_preprocessor/lib/transform/utils.py +++ b/gigl/src/data_preprocessor/lib/transform/utils.py @@ -3,8 +3,11 @@ import apache_beam as beam import pyarrow as pa import tensorflow_data_validation as tfdv +import tensorflow_data_validation.utils.display_util import tensorflow_transform import tfx_bsl +import tfx_bsl.tfxio.tensor_adapter +import tfx_bsl.tfxio.tf_example_record from apache_beam.pvalue import PBegin, PCollection, PDone from tensorflow_metadata.proto.v0 import schema_pb2, statistics_pb2 from tensorflow_transform import beam as tft_beam @@ -12,7 +15,7 @@ from tfx_bsl.tfxio.record_based_tfxio import RecordBasedTFXIO from gigl.common import GcsUri, LocalUri, Uri -from gigl.common.beam.better_tfrecordio import BetterWriteToTFRecord # type: ignore +from gigl.common.beam.better_tfrecordio import BetterWriteToTFRecord from gigl.common.logger import Logger from gigl.env.pipelines_config import get_resource_config from gigl.src.common.constants.components import GiGLComponents diff --git a/gigl/src/inference/v1/gnn_inferencer.py b/gigl/src/inference/v1/gnn_inferencer.py index 555c10cea..759aa6707 100644 --- a/gigl/src/inference/v1/gnn_inferencer.py +++ b/gigl/src/inference/v1/gnn_inferencer.py @@ -248,7 +248,7 @@ def __run( inferencer_instance: BaseInferencer = self.generate_inferencer_instance() graph_builder = GraphBuilderFactory.get_graph_builder( - backend_name=inferencer_instance.model.graph_backend # type: ignore + backend_name=inferencer_instance.model.graph_backend ) inference_blueprint: BaseInferenceBlueprint = ( diff --git a/gigl/src/mocking/dataset_asset_mocking_suite.py b/gigl/src/mocking/dataset_asset_mocking_suite.py index e37ee472d..6b924aadf 100644 --- a/gigl/src/mocking/dataset_asset_mocking_suite.py +++ b/gigl/src/mocking/dataset_asset_mocking_suite.py @@ -281,7 +281,7 @@ def mock_dblp_node_anchor_based_link_prediction_dataset( data, node_types, edge_types = self._get_pyg_dblp_dataset() mocked_dataset_info = MockedDatasetInfo( name="dblp_node_anchor_edge_features_lp", - task_metadata_type=TaskMetadataType.NODE_ANCHOR_BASED_LINK_PREDICTION_TASK, # type: ignore + task_metadata_type=TaskMetadataType.NODE_ANCHOR_BASED_LINK_PREDICTION_TASK, edge_index={ edge_types["author_to_paper"]: data[ edge_types["author_to_paper"].tuple_repr() @@ -327,7 +327,7 @@ def mock_dblp_node_anchor_based_link_prediction_dataset_with_user_defined_labels ) mocked_dataset_info = MockedDatasetInfo( name="dblp_node_anchor_edge_features_user_defined_labels", - task_metadata_type=TaskMetadataType.NODE_ANCHOR_BASED_LINK_PREDICTION_TASK, # type: ignore + task_metadata_type=TaskMetadataType.NODE_ANCHOR_BASED_LINK_PREDICTION_TASK, edge_index={ edge_types["author_to_paper"]: data[ edge_types["author_to_paper"].tuple_repr() diff --git a/gigl/src/mocking/lib/mock_input_for_subgraph_sampler.py b/gigl/src/mocking/lib/mock_input_for_subgraph_sampler.py index ff03bfc8f..006dd99f9 100644 --- a/gigl/src/mocking/lib/mock_input_for_subgraph_sampler.py +++ b/gigl/src/mocking/lib/mock_input_for_subgraph_sampler.py @@ -295,7 +295,7 @@ def generate_preprocessed_tfrecord_data( feature_keys=node_preprocess_metadata.feature_cols, label_keys=[node_preprocess_metadata.label_col] if node_preprocess_metadata.label_col is not None - else None, # type: ignore + else None, tfrecord_uri_prefix=node_preprocess_metadata.features_uri.uri, schema_uri=node_preprocess_metadata.schema_uri.uri, feature_dim=num_features, diff --git a/gigl/src/mocking/lib/pyg_to_training_samples.py b/gigl/src/mocking/lib/pyg_to_training_samples.py index 684ca7fd4..61c3ad85b 100644 --- a/gigl/src/mocking/lib/pyg_to_training_samples.py +++ b/gigl/src/mocking/lib/pyg_to_training_samples.py @@ -111,7 +111,7 @@ def _build_graph_pb_wrapper_from_hetero_data( src_node_id=global_src_node_id, dst_node_id=global_dst_node_id, condensed_edge_type=condensed_edge_type, - feature_values=edge_feature_value, # type: ignore + feature_values=edge_feature_value, ) khop_subgraph_edges.append(edge) @@ -136,7 +136,7 @@ def _build_graph_pb_wrapper_from_hetero_data( node = graph_schema_pb2.Node( node_id=global_node_id, condensed_node_type=condensed_node_type, - feature_values=node_feature_value, # type: ignore + feature_values=node_feature_value, ) khop_subgraph_nodes.append(node) @@ -203,7 +203,7 @@ def _get_random_negative_samples_for_pos_edges( pos_node_ids = edge_index[0].repeat(num_negative_samples_per_pos_edge) neg_node_ids = torch.randint(low=0, high=num_nodes, size=[pos_node_ids.numel()]) - return torch.vstack((pos_node_ids, neg_node_ids)) # type: ignore + return torch.vstack((pos_node_ids, neg_node_ids)) def _build_rooted_node_neighborhood_samples_from_subgraphs( @@ -216,7 +216,7 @@ def _build_rooted_node_neighborhood_samples_from_subgraphs( root_node=graph_schema_pb2.Node( node_id=int(root_node_id), condensed_node_type=condensed_node_type, - feature_values=None, # type: ignore + feature_values=None, ), neighborhood=subgraph.pb, ) @@ -255,7 +255,7 @@ def build_supervised_node_classification_samples_from_pyg_heterodata( root_node=graph_schema_pb2.Node( node_id=int(root_node_id), condensed_node_type=condensed_node_type, - feature_values=None, # type: ignore + feature_values=None, ), neighborhood=subgraph.pb, root_node_labels=[ @@ -407,7 +407,7 @@ def build_node_anchor_link_prediction_samples_from_pyg_heterodata( root_node_pb = graph_schema_pb2.Node( node_id=root_node_id, condensed_node_type=condensed_src_node_type, - feature_values=None, # type: ignore + feature_values=None, ) subgraphs_to_merge.append(src_node_id_to_k_hop_subgraph[root_node_id]) diff --git a/gigl/src/mocking/lib/user_defined_edge_sampling.py b/gigl/src/mocking/lib/user_defined_edge_sampling.py index de39256ba..7e48ff903 100644 --- a/gigl/src/mocking/lib/user_defined_edge_sampling.py +++ b/gigl/src/mocking/lib/user_defined_edge_sampling.py @@ -19,14 +19,14 @@ def sample_hydrate_user_def_edge( ) edge_index = ( mocked_dataset_info.user_defined_edge_index[ - mocked_dataset_info.sample_edge_type # type: ignore + mocked_dataset_info.sample_edge_type ][edge_usage_type] if mocked_dataset_info.user_defined_edge_index else None ) edge_feats = ( mocked_dataset_info.user_defined_edge_feats[ - mocked_dataset_info.sample_edge_type # type: ignore + mocked_dataset_info.sample_edge_type ][edge_usage_type] if mocked_dataset_info.user_defined_edge_feats else None diff --git a/gigl/src/post_process/utils/unenumeration.py b/gigl/src/post_process/utils/unenumeration.py index ec42cdfc3..bc7972dbe 100644 --- a/gigl/src/post_process/utils/unenumeration.py +++ b/gigl/src/post_process/utils/unenumeration.py @@ -1,4 +1,5 @@ import concurrent +import concurrent.futures from typing import List from google.cloud import bigquery diff --git a/gigl/src/subgraph_sampler/subgraph_sampler.py b/gigl/src/subgraph_sampler/subgraph_sampler.py index 2129305fe..34a5e7d38 100644 --- a/gigl/src/subgraph_sampler/subgraph_sampler.py +++ b/gigl/src/subgraph_sampler/subgraph_sampler.py @@ -129,7 +129,7 @@ def run( graph_db_config.graph_db_ingestion_cls_path ) - graph_db_ingestion_args = graph_db_config.graph_db_ingestion_args # type: ignore + graph_db_ingestion_args = graph_db_config.graph_db_ingestion_args graph_db_args = graph_db_config.graph_db_args all_graph_db_args = {**graph_db_ingestion_args, **graph_db_args} diff --git a/gigl/src/training/v1/lib/data_loaders/node_anchor_based_link_prediction_data_loader.py b/gigl/src/training/v1/lib/data_loaders/node_anchor_based_link_prediction_data_loader.py index 818db3ac3..b0dace301 100644 --- a/gigl/src/training/v1/lib/data_loaders/node_anchor_based_link_prediction_data_loader.py +++ b/gigl/src/training/v1/lib/data_loaders/node_anchor_based_link_prediction_data_loader.py @@ -263,7 +263,7 @@ def get_default_data_loader( iterable_training_dataset, batch_size=config.batch_size, num_workers=config.num_workers, - collate_fn=collate_fn, # type: ignore + collate_fn=collate_fn, persistent_workers=False, pin_memory=config.pin_memory, ) diff --git a/gigl/src/training/v1/lib/data_loaders/rooted_node_neighborhood_data_loader.py b/gigl/src/training/v1/lib/data_loaders/rooted_node_neighborhood_data_loader.py index ef9e45ba2..47405f5ce 100644 --- a/gigl/src/training/v1/lib/data_loaders/rooted_node_neighborhood_data_loader.py +++ b/gigl/src/training/v1/lib/data_loaders/rooted_node_neighborhood_data_loader.py @@ -281,9 +281,7 @@ def get_default_data_loader( iterable_training_dataset: CombinedIterableDatasets[ RootedNodeNeighborhoodSample - ] = CombinedIterableDatasets( - iterable_dataset_map=iterable_dataset_map # type: ignore - ) + ] = CombinedIterableDatasets(iterable_dataset_map=iterable_dataset_map) collate_fn = partial( RootedNodeNeighborhoodBatch.collate_pyg_rooted_node_neighborhood_minibatch, @@ -296,7 +294,7 @@ def get_default_data_loader( iterable_training_dataset, batch_size=config.batch_size, num_workers=config.num_workers, - collate_fn=collate_fn, # type: ignore + collate_fn=collate_fn, persistent_workers=False, pin_memory=config.pin_memory, ) diff --git a/gigl/src/training/v1/lib/data_loaders/tf_records_iterable_dataset.py b/gigl/src/training/v1/lib/data_loaders/tf_records_iterable_dataset.py index f434f6522..d24fac9ec 100644 --- a/gigl/src/training/v1/lib/data_loaders/tf_records_iterable_dataset.py +++ b/gigl/src/training/v1/lib/data_loaders/tf_records_iterable_dataset.py @@ -3,8 +3,8 @@ import numpy as np import tensorflow as tf import tensorflow_data_validation as tfdv -import tensorflow_transform as tft import torch.utils.data +from tensorflow_transform.tf_metadata import schema_utils from gigl.common import Uri from gigl.src.common.types.graph_data import NodeType @@ -137,9 +137,7 @@ def __iter__(self) -> Iterator[dict[NodeType, T]]: def get_np_iterator_from_tfrecords(schema_path: Uri, tfrecord_files: list[str]): batch_size = 1 schema = tfdv.load_schema_text(schema_path.uri) - feature_spec = tft.tf_metadata.schema_utils.schema_as_feature_spec( - schema - ).feature_spec + feature_spec = schema_utils.schema_as_feature_spec(schema).feature_spec dataset = ( tf.data.TFRecordDataset(tfrecord_files) .map(lambda record: tf.io.parse_example(record, feature_spec)) diff --git a/gigl/src/validation_check/libs/resource_config_checks.py b/gigl/src/validation_check/libs/resource_config_checks.py index 98a12a360..8b85be6f7 100644 --- a/gigl/src/validation_check/libs/resource_config_checks.py +++ b/gigl/src/validation_check/libs/resource_config_checks.py @@ -227,7 +227,7 @@ def _validate_accelerator_type( """ Checks if the provided accelerator type is valid. """ - if proto_config.gpu_type == AcceleratorType.ACCELERATOR_TYPE_UNSPECIFIED.name: # type: ignore + if proto_config.gpu_type == AcceleratorType.ACCELERATOR_TYPE_UNSPECIFIED.name: assert ( proto_config.gpu_limit == 0 ), f"""gpu_limit must be equal to 0 for cpu training/inference, indicated by provided gpu_type {proto_config.gpu_type}. @@ -236,7 +236,7 @@ def _validate_accelerator_type( assert ( proto_config.gpu_limit > 0 ), f"""gpu_limit must be greater than 0 for gpu training/inference, indicated by provided gpu_type {proto_config.gpu_type}. - Got gpu_limit {proto_config.gpu_limit}. Use gpu_type {AcceleratorType.ACCELERATOR_TYPE_UNSPECIFIED.name} for cpu training.""" # type: ignore + Got gpu_limit {proto_config.gpu_limit}. Use gpu_type {AcceleratorType.ACCELERATOR_TYPE_UNSPECIFIED.name} for cpu training.""" def _validate_cloud_machine_config( diff --git a/gigl/types/graph.py b/gigl/types/graph.py index d70109637..1e3219240 100644 --- a/gigl/types/graph.py +++ b/gigl/types/graph.py @@ -275,7 +275,7 @@ def message_passing_to_positive_label( NodeType(edge_type[0]), Relation(edge_type[1]), NodeType(edge_type[2]) ) else: - return edge_type + return edge_type # ty: ignore[invalid-return-type] def message_passing_to_negative_label( @@ -299,7 +299,7 @@ def message_passing_to_negative_label( NodeType(edge_type[0]), Relation(edge_type[1]), NodeType(edge_type[2]) ) else: - return edge_type + return edge_type # ty: ignore[invalid-return-type] def is_label_edge_type( @@ -336,7 +336,11 @@ def label_edge_type_to_message_passing_edge_type( if isinstance(label_edge_type, EdgeType): return EdgeType(label_edge_type[0], Relation(relation), label_edge_type[2]) else: - return (label_edge_type[0], relation, label_edge_type[2]) + return ( + label_edge_type[0], + relation, + label_edge_type[2], + ) # ty: ignore[invalid-return-type] def select_label_edge_types( @@ -386,10 +390,7 @@ def select_label_edge_types( list, str, int, - # TODO(kmonte): Add GLT Partition book here - # We cannot at the moment as we type-ignore GLT - # And adding it as a type here will break the type checker. - # PartitionBook + PartitionBook, ) @@ -421,7 +422,7 @@ def to_heterogeneous_node( if x is None: return None if isinstance(x, dict): - return x + return x # ty: ignore[invalid-return-type] return {DEFAULT_HOMOGENEOUS_NODE_TYPE: x} @@ -453,7 +454,7 @@ def to_heterogeneous_edge( if x is None: return None if isinstance(x, dict): - return x + return x # ty: ignore[invalid-return-type] return {DEFAULT_HOMOGENEOUS_EDGE_TYPE: x} @@ -517,4 +518,8 @@ def reverse_edge_type(edge_type: _EdgeType) -> _EdgeType: if isinstance(edge_type, EdgeType): return EdgeType(edge_type[2], edge_type[1], edge_type[0]) else: - return (edge_type[2], edge_type[1], edge_type[0]) + return ( + edge_type[2], + edge_type[1], + edge_type[0], + ) # ty: ignore[invalid-return-type] diff --git a/tests/integration/pipeline/data_preprocessor/data_preprocessor_pipeline_test.py b/tests/integration/pipeline/data_preprocessor/data_preprocessor_pipeline_test.py index 60217efd3..9a651ede6 100644 --- a/tests/integration/pipeline/data_preprocessor/data_preprocessor_pipeline_test.py +++ b/tests/integration/pipeline/data_preprocessor/data_preprocessor_pipeline_test.py @@ -6,8 +6,8 @@ import numpy as np import tensorflow as tf import tensorflow_data_validation as tfdv -import tensorflow_transform as tft import torch +from tensorflow_transform.tf_metadata import schema_utils import gigl.common.utils.local_fs as local_fs_utils import gigl.src.common.constants.gcs as gcs_consts @@ -85,9 +85,7 @@ def __get_np_arrays_from_tfrecords( max_batch_size=16384, ) -> dict[str, np.ndarray]: schema = tfdv.load_schema_text(schema_path.uri) - feature_spec = tft.tf_metadata.schema_utils.schema_as_feature_spec( - schema - ).feature_spec + feature_spec = schema_utils.schema_as_feature_spec(schema).feature_spec dataset = ( tf.data.TFRecordDataset(tfrecord_files) .map(lambda record: tf.io.parse_example(record, feature_spec)) diff --git a/tests/integration/pipeline/split_generator/lib/node_anchor_based_link_prediction.py b/tests/integration/pipeline/split_generator/lib/node_anchor_based_link_prediction.py index 2163a3288..ddde7a954 100644 --- a/tests/integration/pipeline/split_generator/lib/node_anchor_based_link_prediction.py +++ b/tests/integration/pipeline/split_generator/lib/node_anchor_based_link_prediction.py @@ -102,20 +102,20 @@ def log_node_anchor_based_link_prediction_split_details( :return: """ logger.info( - f"Train split: {train_split.graph.num_nodes} nodes, " # type: ignore - f"{train_split.graph.num_edges} edges used in message passing." # type: ignore + f"Train split: {train_split.graph.num_nodes} nodes, " + f"{train_split.graph.num_edges} edges used in message passing." f" ( {len(train_split.pos_edges)} + supervision edges, " f"{len(train_split.hard_neg_edges)} - supervision edges )" ) logger.info( - f"Val split: {val_split.graph.num_nodes} nodes, " # type: ignore - f"{val_split.graph.num_edges} edges used in message passing." # type: ignore + f"Val split: {val_split.graph.num_nodes} nodes, " + f"{val_split.graph.num_edges} edges used in message passing." f" ( {len(val_split.pos_edges)} + supervision edges, " f"{len(val_split.hard_neg_edges)} - supervision edges )" ) logger.info( - f"Test split: {test_split.graph.num_nodes} nodes, " # type: ignore - f"{test_split.graph.num_edges} edges used in message passing." # type: ignore + f"Test split: {test_split.graph.num_nodes} nodes, " + f"{test_split.graph.num_edges} edges used in message passing." f" ({len(test_split.pos_edges)} + supervision edges, " f"{len(test_split.hard_neg_edges)} - supervision edges )" ) diff --git a/tests/integration/pipeline/split_generator/lib/supervised_node_classification.py b/tests/integration/pipeline/split_generator/lib/supervised_node_classification.py index ed1443a62..61ef25be8 100644 --- a/tests/integration/pipeline/split_generator/lib/supervised_node_classification.py +++ b/tests/integration/pipeline/split_generator/lib/supervised_node_classification.py @@ -76,17 +76,17 @@ def log_node_classification_split_details( :return: """ logger.info( - f"Train split: {train_split.graph.num_nodes} nodes " # type: ignore + f"Train split: {train_split.graph.num_nodes} nodes " f"({len(train_split.labeled_nodes)} labeled), " - f"{train_split.graph.num_edges} edges." # type: ignore + f"{train_split.graph.num_edges} edges." ) logger.info( - f"Val split: {val_split.graph.num_nodes} nodes " # type: ignore + f"Val split: {val_split.graph.num_nodes} nodes " f"({len(val_split.labeled_nodes)} labeled), " - f"{val_split.graph.num_edges} edges." # type: ignore + f"{val_split.graph.num_edges} edges." ) logger.info( - f"Test split: {test_split.graph.num_nodes} nodes " # type: ignore + f"Test split: {test_split.graph.num_nodes} nodes " f"({len(test_split.labeled_nodes)} labeled), " - f"{test_split.graph.num_edges} edges." # type: ignore + f"{test_split.graph.num_edges} edges." ) diff --git a/tests/test_assets/distributed/utils.py b/tests/test_assets/distributed/utils.py index cc15c9346..45a99c623 100644 --- a/tests/test_assets/distributed/utils.py +++ b/tests/test_assets/distributed/utils.py @@ -92,8 +92,8 @@ class MockGraphStoreInfo(GraphStoreInfo): """ def __init__(self, real_info: GraphStoreInfo, compute_node_rank: int): - self._real_info = real_info - self._compute_node_rank = compute_node_rank + self._real_info = real_info # ty: ignore[invalid-assignment] + self._compute_node_rank = compute_node_rank # ty: ignore[invalid-assignment] @property def num_storage_nodes(self) -> int: diff --git a/tests/unit/common/collections/frozen_dict_test.py b/tests/unit/common/collections/frozen_dict_test.py index a6847c439..b1282a530 100644 --- a/tests/unit/common/collections/frozen_dict_test.py +++ b/tests/unit/common/collections/frozen_dict_test.py @@ -7,7 +7,7 @@ def test_frozen_dict_is_frozen(self): frozen_dict: FrozenDict[int, int] = FrozenDict() def assign_dict_value(): - frozen_dict[10] = 20 # type: ignore [index] + frozen_dict[10] = 20 # type: ignore[index] # ty: ignore[invalid-assignment] self.assertRaises(Exception, assign_dict_value) diff --git a/tests/unit/distributed/dataset_input_metadata_translator_test.py b/tests/unit/distributed/dataset_input_metadata_translator_test.py index 6fa7c61aa..e5c5709b8 100644 --- a/tests/unit/distributed/dataset_input_metadata_translator_test.py +++ b/tests/unit/distributed/dataset_input_metadata_translator_test.py @@ -1,5 +1,5 @@ from collections import abc -from typing import Optional, Union +from typing import Optional, Union, cast from parameterized import param, parameterized @@ -261,22 +261,21 @@ def test_translator_correctness(self, _, mocked_dataset_info: MockedDatasetInfo) expected_entity_types=graph_metadata_pb_wrapper.edge_types, ) serialized_positive_label_info_iterable: list[SerializedTFRecordInfo] - if isinstance( - serialized_graph_metadata.positive_label_entity_info, abc.Mapping - ): + positive_label_entity_info = ( + serialized_graph_metadata.positive_label_entity_info + ) + if isinstance(positive_label_entity_info, dict): serialized_positive_label_info_iterable = list( - serialized_graph_metadata.positive_label_entity_info.values() + cast( + dict[EdgeType, SerializedTFRecordInfo], + positive_label_entity_info, + ).values() ) - elif isinstance( - serialized_graph_metadata.positive_label_entity_info, - SerializedTFRecordInfo, - ): - serialized_positive_label_info_iterable = [ - serialized_graph_metadata.positive_label_entity_info - ] + elif isinstance(positive_label_entity_info, SerializedTFRecordInfo): + serialized_positive_label_info_iterable = [positive_label_entity_info] else: raise ValueError( - f"Expected positive labels to be a dictionary or of type `SerializedTFRecordInfo`, got {type(serialized_graph_metadata.positive_label_entity_info)}" + f"Expected positive labels to be a dictionary or of type `SerializedTFRecordInfo`, got {type(positive_label_entity_info)}" ) self.assertEqual( @@ -351,22 +350,21 @@ def test_translator_correctness(self, _, mocked_dataset_info: MockedDatasetInfo) expected_entity_types=graph_metadata_pb_wrapper.edge_types, ) serialized_negative_label_info_iterable: list[SerializedTFRecordInfo] - if isinstance( - serialized_graph_metadata.negative_label_entity_info, abc.Mapping - ): + negative_label_entity_info = ( + serialized_graph_metadata.negative_label_entity_info + ) + if isinstance(negative_label_entity_info, dict): serialized_negative_label_info_iterable = list( - serialized_graph_metadata.negative_label_entity_info.values() + cast( + dict[EdgeType, SerializedTFRecordInfo], + negative_label_entity_info, + ).values() ) - elif isinstance( - serialized_graph_metadata.negative_label_entity_info, - SerializedTFRecordInfo, - ): - serialized_negative_label_info_iterable = [ - serialized_graph_metadata.negative_label_entity_info - ] + elif isinstance(negative_label_entity_info, SerializedTFRecordInfo): + serialized_negative_label_info_iterable = [negative_label_entity_info] else: raise ValueError( - f"Expected negative labels to be a dictionary or of type `SerializedTFRecordInfo`, got {type(serialized_graph_metadata.negative_label_entity_info)}" + f"Expected negative labels to be a dictionary or of type `SerializedTFRecordInfo`, got {type(negative_label_entity_info)}" ) self.assertEqual( diff --git a/tests/unit/distributed/distributed_partitioner_test.py b/tests/unit/distributed/distributed_partitioner_test.py index 1ac56df31..89cecc6d3 100644 --- a/tests/unit/distributed/distributed_partitioner_test.py +++ b/tests/unit/distributed/distributed_partitioner_test.py @@ -1,7 +1,7 @@ # Originally taken from https://github.com/alibaba/graphlearn-for-pytorch/blob/main/test/python/test_dist_random_partitioner.py from collections import abc, defaultdict -from typing import Iterable, Literal, MutableMapping, Optional, Tuple, Type, Union +from typing import Iterable, Literal, MutableMapping, Optional, Tuple, Type, Union, cast import torch import torch.multiprocessing as mp @@ -145,21 +145,25 @@ def _assert_graph_outputs( if isinstance(output_edge_partition_book, abc.Mapping): for edge_type in MOCKED_HETEROGENEOUS_EDGE_TYPES: entity_iterable.append( - ( + ( # ty: ignore[invalid-argument-type] edge_type, output_edge_partition_book[edge_type] if edge_type in output_edge_partition_book else None, - output_edge_index[edge_type], + output_edge_index[ # ty: ignore[invalid-argument-type] + edge_type + ], ) ) elif output_edge_partition_book is None: for edge_type in MOCKED_HETEROGENEOUS_EDGE_TYPES: entity_iterable.append( - ( + ( # ty: ignore[invalid-argument-type] edge_type, None, - output_edge_index[edge_type], + output_edge_index[ # ty: ignore[invalid-argument-type] + edge_type + ], ) ) else: @@ -310,7 +314,9 @@ def _assert_node_data_outputs( assert isinstance(output_graph, abc.Mapping), ( f"Homogeneous output detected from node {entity_name} for heterogeneous input" ) - entity_iterable = list(output_graph.items()) + entity_iterable = list( + cast("dict[EdgeType, GraphPartitionData]", output_graph).items() + ) else: assert isinstance(output_graph, GraphPartitionData), ( f"Heterogeneous output detected from node {entity_name} for homogeneous input" @@ -335,7 +341,9 @@ def _assert_node_data_outputs( assert isinstance(output_node_data, abc.Mapping), ( f"Found homogeneous node {entity_name} for heterogeneous input" ) - node_data = output_node_data[target_node_type] + node_data = cast( + "dict[NodeType, FeaturePartitionData]", output_node_data + )[target_node_type] else: assert isinstance(output_node_data, FeaturePartitionData), ( f"Found heterogeneous node {entity_name} for homogeneous input" @@ -358,7 +366,9 @@ def _assert_node_data_outputs( assert node_data.ids is not None node_data_ids = node_data.ids self.assert_tensor_equality( - tensor_a=node_ids, tensor_b=node_data.ids, dim=0 + tensor_a=node_ids, + tensor_b=node_data.ids, + dim=0, ) # Validate dimensions and values based on whether this is labels or features @@ -438,11 +448,15 @@ def _assert_edge_feature_outputs( assert isinstance(output_graph, abc.Mapping), ( "Homogeneous output detected from graph for heterogeneous input" ) + output_edge_feat_dict = cast( + "dict[EdgeType, FeaturePartitionData]", output_edge_feat + ) + output_graph_dict = cast("dict[EdgeType, GraphPartitionData]", output_graph) entity_iterable = [ ( edge_type, - output_edge_feat.get(edge_type, None), - output_graph[edge_type], + output_edge_feat_dict.get(edge_type, None), + output_graph_dict[edge_type], ) for edge_type in MOCKED_HETEROGENEOUS_EDGE_TYPES ] diff --git a/tests/unit/distributed/graph_store/remote_dist_dataset_test.py b/tests/unit/distributed/graph_store/remote_dist_dataset_test.py index a0f1a3594..bd9a46c4b 100644 --- a/tests/unit/distributed/graph_store/remote_dist_dataset_test.py +++ b/tests/unit/distributed/graph_store/remote_dist_dataset_test.py @@ -594,6 +594,7 @@ def test_fetch_node_types_labeled_homogeneous(self, mock_request): node_types = remote_dataset.fetch_node_types() self.assertIsNotNone(node_types) + assert node_types is not None # Type narrowing for the type checker self.assertIn(DEFAULT_HOMOGENEOUS_NODE_TYPE, node_types) def test_fetch_node_ids_auto_detects_default_node_type(self): @@ -1356,7 +1357,7 @@ def test_dispatches_server_method(self): def test_non_callable_returns_none(self): """Test that _call_func_on_server returns None for non-callable input.""" - result: None = _call_func_on_server("not_a_function") # type: ignore[arg-type] + result: None = _call_func_on_server("not_a_function") # type: ignore[arg-type] # ty: ignore[invalid-argument-type] self.assertIsNone(result) def test_falls_back_for_non_server_function(self): diff --git a/tests/unit/distributed/graph_store/shared_dist_sampling_producer_test.py b/tests/unit/distributed/graph_store/shared_dist_sampling_producer_test.py index d3ccb3f86..dee46539c 100644 --- a/tests/unit/distributed/graph_store/shared_dist_sampling_producer_test.py +++ b/tests/unit/distributed/graph_store/shared_dist_sampling_producer_test.py @@ -9,6 +9,7 @@ from gigl.distributed.graph_store.shared_dist_sampling_producer import ( EPOCH_DONE_EVENT, ActiveEpochState, + CommandPayload, SharedDistSamplingBackend, SharedMpCommand, StartEpochCmd, @@ -139,9 +140,15 @@ def test_start_new_epoch_sampling_shuffle_refreshes_per_epoch(self) -> None: ) backend._initialized = True recorded: list[tuple[int, SharedMpCommand, object]] = [] - backend._enqueue_worker_command = lambda worker_rank, command, payload: ( # type: ignore[method-assign] + + def _record_command( + worker_rank: int, + command: SharedMpCommand, + payload: CommandPayload, + ) -> None: recorded.append((worker_rank, command, payload)) - ) + + backend._enqueue_worker_command = _record_command # type: ignore[method-assign] # ty: ignore[invalid-assignment] channel = MagicMock() input_tensor = torch.arange(6, dtype=torch.long) diff --git a/tests/unit/src/common/modeling_task_spec_utils/early_stop_test.py b/tests/unit/src/common/modeling_task_spec_utils/early_stop_test.py index cdddaa450..083625ffe 100644 --- a/tests/unit/src/common/modeling_task_spec_utils/early_stop_test.py +++ b/tests/unit/src/common/modeling_task_spec_utils/early_stop_test.py @@ -93,7 +93,7 @@ def test_early_stopping( for step_num, value in enumerate(mocked_criteria_values): has_metric_improved, should_early_stop = early_stopper.step(value=value) if model is not None: - model.foo += 1 # type: ignore # https://github.com/Snapchat/GiGL/issues/408 + model.foo += 1 # https://github.com/Snapchat/GiGL/issues/408 if step_num in improvement_steps: self.assertTrue(has_metric_improved) else: diff --git a/tests/unit/src/common/models/layers/count_min_sketch_test.py b/tests/unit/src/common/models/layers/count_min_sketch_test.py index 68703c1c8..de837d135 100644 --- a/tests/unit/src/common/models/layers/count_min_sketch_test.py +++ b/tests/unit/src/common/models/layers/count_min_sketch_test.py @@ -13,7 +13,7 @@ def test_count(self): # Initialize the CountMinSketch object cms = CountMinSketch(width=20, depth=5) candidate_ids = torch.tensor([1, 2, 2, 3, 3, 3, 4, 4, 4, 4], dtype=torch.long) - cms.add_torch_long_tensor(candidate_ids) # type: ignore + cms.add_torch_long_tensor(candidate_ids) # Check the total count self.assertEqual(cms.total(), 10) # Check the estimated count diff --git a/tests/unit/src/training/lib/data_loaders/combined_iterable_dataset_test.py b/tests/unit/src/training/lib/data_loaders/combined_iterable_dataset_test.py index 4196a1641..9892ee745 100644 --- a/tests/unit/src/training/lib/data_loaders/combined_iterable_dataset_test.py +++ b/tests/unit/src/training/lib/data_loaders/combined_iterable_dataset_test.py @@ -78,7 +78,7 @@ def preprocess_raw_sample_fn( loopy_dataset = LoopyIterableDataset(iterable_dataset=tf_dataset) loopy_datasets_map[condensed_node_type_str] = loopy_dataset - dataset = CombinedIterableDatasets(iterable_dataset_map=loopy_datasets_map) # type: ignore + dataset = CombinedIterableDatasets(iterable_dataset_map=loopy_datasets_map) dataset_iter = iter(dataset) for _ in range(15): dataset_sample = next(dataset_iter) @@ -107,7 +107,7 @@ def preprocess_raw_sample_fn( ) datasets_map[condensed_node_type_str] = tf_dataset - dataset = CombinedIterableDatasets(iterable_dataset_map=datasets_map) # type: ignore + dataset = CombinedIterableDatasets(iterable_dataset_map=datasets_map) dataset_iter = iter(dataset) for _ in range(10): dataset_sample = next(dataset_iter)