From 5c30b91f319d5e2c4dbf20c11d59d31cc322d249 Mon Sep 17 00:00:00 2001 From: nshah Date: Tue, 29 Jul 2025 19:59:32 +0000 Subject: [PATCH 1/4] add common utils --- .../common/README.md | 1 + .../common/__init__.py | 0 .../common/dist_checkpoint.py | 260 +++++++++++++++ .../common/distributed.py | 52 +++ .../common/graph_dataset.py | 297 ++++++++++++++++++ .../common/iterator_utils.py | 35 +++ .../common/torchrec/__init__.py | 0 .../common/torchrec/batch.py | 91 ++++++ .../common/torchrec/large_embedding_lookup.py | 28 ++ .../common/torchrec/utils.py | 170 ++++++++++ 10 files changed, 934 insertions(+) create mode 100644 python/gigl/experimental/knowledge_graph_embedding/common/README.md create mode 100644 python/gigl/experimental/knowledge_graph_embedding/common/__init__.py create mode 100644 python/gigl/experimental/knowledge_graph_embedding/common/dist_checkpoint.py create mode 100644 python/gigl/experimental/knowledge_graph_embedding/common/distributed.py create mode 100644 python/gigl/experimental/knowledge_graph_embedding/common/graph_dataset.py create mode 100644 python/gigl/experimental/knowledge_graph_embedding/common/iterator_utils.py create mode 100644 python/gigl/experimental/knowledge_graph_embedding/common/torchrec/__init__.py create mode 100644 python/gigl/experimental/knowledge_graph_embedding/common/torchrec/batch.py create mode 100644 python/gigl/experimental/knowledge_graph_embedding/common/torchrec/large_embedding_lookup.py create mode 100644 python/gigl/experimental/knowledge_graph_embedding/common/torchrec/utils.py diff --git a/python/gigl/experimental/knowledge_graph_embedding/common/README.md b/python/gigl/experimental/knowledge_graph_embedding/common/README.md new file mode 100644 index 00000000..f19e2657 --- /dev/null +++ b/python/gigl/experimental/knowledge_graph_embedding/common/README.md @@ -0,0 +1 @@ +These utilities may be more generically reusable inside GiGL for other applications. diff --git a/python/gigl/experimental/knowledge_graph_embedding/common/__init__.py b/python/gigl/experimental/knowledge_graph_embedding/common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/gigl/experimental/knowledge_graph_embedding/common/dist_checkpoint.py b/python/gigl/experimental/knowledge_graph_embedding/common/dist_checkpoint.py new file mode 100644 index 00000000..e1d5e22b --- /dev/null +++ b/python/gigl/experimental/knowledge_graph_embedding/common/dist_checkpoint.py @@ -0,0 +1,260 @@ +""" +This module provides functions to load and save distributed checkpoints +using the Torch Distributed Checkpointing API. +""" + +import tempfile +from typing import Optional, Union +from concurrent.futures import ThreadPoolExecutor, Future +import torch.nn as nn +import torch.optim as optim +import torch.distributed.checkpoint as dcp +from torch.distributed.checkpoint.stateful import Stateful +from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE + + +from gigl.common import GcsUri, LocalUri, Uri +from gigl.common.logger import Logger +from gigl.src.common.utils.file_loader import FileLoader + +logger = Logger() + +class AppState(Stateful): + """ + This is a useful wrapper for checkpointing an application state. Since this + object is compliant with the Stateful protocol, DCP will automatically + call state_dict/loade_stat_dict as needed in the dcp.save/load APIs. + + We take advantage of this wrapper to hande calling distributed state dict + methods on the model and optimizer. 
+ + See https://docs.pytorch.org/tutorials/recipes/distributed_async_checkpoint_recipe.html + for more details. + """ + + MODEL_KEY = "model" + OPTIMIZER_KEY = "optimizer" + APP_STATE_KEY = "app" + + + def __init__(self, model: nn.Module, optimizer: Optional[optim.Optimizer] = None): + self.model = model + self.optimizer = optimizer + + def state_dict(self): + model_state_dict = self.model.state_dict() + optimizer_state_dict = self.optimizer.state_dict() if self.optimizer else None + return { + self.MODEL_KEY: model_state_dict, + self.OPTIMIZER_KEY: optimizer_state_dict, + } + + def load_state_dict(self, state_dict): + # sets our state dicts on the model and optimizer, now that we've loaded + self.model.load_state_dict(state_dict[self.MODEL_KEY]) + if self.optimizer and state_dict.get(self.OPTIMIZER_KEY): + self.optimizer.load_state_dict(state_dict[self.OPTIMIZER_KEY]) + + def to_state_dict(self) -> STATE_DICT_TYPE: + """ + Converts the AppState to a state dict that can be used with DCP. + """ + return { + self.APP_STATE_KEY: self, + } + +def load_checkpoint_from_uri( + state_dict: STATE_DICT_TYPE, + checkpoint_id: Uri, +): + assert isinstance(checkpoint_id, LocalUri) or isinstance( + checkpoint_id, GcsUri + ), "checkpoint_id must be a LocalUri or GcsUri." + local_uri = ( + checkpoint_id + if isinstance(checkpoint_id, LocalUri) + else LocalUri(tempfile.mkdtemp(prefix="checkpoint")) + ) + if isinstance(checkpoint_id, GcsUri): + # If the URI is a GCS URI, we need to download it first + file_loader = FileLoader() + file_loader.load_directory(dir_uri_src=checkpoint_id, dir_uri_dst=local_uri) + logger.info(f"Downloaded checkpoint from GCS: {checkpoint_id} to {local_uri}") + + reader = dcp.FileSystemReader(path=local_uri.uri) + dcp.load(state_dict=state_dict, storage_reader=reader) + logger.info(f"Loaded checkpoint from {checkpoint_id}") + + +# def save_checkpoint_to_uri( +# state_dict: STATE_DICT_TYPE, +# checkpoint_id: Uri, +# should_save_asynchronously: bool = False, +# ): +# """ +# Saves the state_dict to a specified checkpoint_id URI using the Torch Distributed Checkpointing API. + +# If the checkpoint_id is a GCS URI, it will first save the checkpoint +# locally and then upload it to GCS. +# """ + +# assert isinstance(checkpoint_id, LocalUri) or isinstance( +# checkpoint_id, GcsUri +# ), "checkpoint_id must be a LocalUri or GcsUri." +# local_uri = ( +# checkpoint_id +# if isinstance(checkpoint_id, LocalUri) +# else LocalUri(tempfile.mkdtemp(prefix="checkpoint")) +# ) + +# writer = dcp.FileSystemWriter(path=local_uri.uri) +# dcp_fn = dcp.async_save if should_save_asynchronously else dcp.save +# dcp_fn(state_dict, storage_writer=writer) +# if isinstance(checkpoint_id, GcsUri): +# # If the URI is a GCS URI, we need to ensure the file is uploaded +# # to GCS after saving it locally. +# file_loader = FileLoader() +# file_loader.load_directory(dir_uri_src=local_uri, dir_uri_dst=checkpoint_id) +# logger.info(f"Uploaded checkpoint to GCS: {checkpoint_id}") + + +# def save_checkpoint_to_uri( +# state_dict: STATE_DICT_TYPE, +# checkpoint_id: Uri, +# should_save_asynchronously: bool = False, +# ) -> Union[Future[Uri], Uri]: +# """ +# Saves the state_dict to a specified checkpoint_id URI using the Torch Distributed Checkpointing API. + +# If the checkpoint_id is a GCS URI, it will first save the checkpoint +# locally and then upload it to GCS. + +# If `should_save_asynchronously` is True, the save operation will be +# performed asynchronously, returning a Future object. 
Otherwise, it will +# block until the save operation is complete. + +# Args: +# state_dict (STATE_DICT_TYPE): The state dictionary to save. +# checkpoint_id (Uri): The URI where the checkpoint will be saved. +# should_save_asynchronously (bool): If True, saves the checkpoint asynchronously. +# Returns: +# Union[Future[Uri], Uri]: The URI where the checkpoint was saved, or +# a Future object if saved asynchronously. +# Raises: +# AssertionError: If checkpoint_id is not a LocalUri or GcsUri. +# """ +# logger.info(f"inside save_checkpoint_to_uri with condition: {should_save_asynchronously}") +# def _save_checkpoint( +# state_dict: STATE_DICT_TYPE, checkpoint_id: Uri, should_save_asynchronously: bool = False +# ) -> Uri: +# assert isinstance(checkpoint_id, LocalUri) or isinstance( +# checkpoint_id, GcsUri +# ), "checkpoint_id must be a LocalUri or GcsUri." +# local_uri = ( +# checkpoint_id +# if isinstance(checkpoint_id, LocalUri) +# else LocalUri(tempfile.mkdtemp(prefix="checkpoint")) +# ) + +# writer = dcp.FileSystemWriter(path=local_uri.uri) + +# checkpoint_future: Optional[Future] = None +# if should_save_asynchronously: +# checkpoint_future = dcp.async_save(state_dict, storage_writer=writer) +# else: +# dcp.save(state_dict, storage_writer=writer) + +# if checkpoint_future: +# checkpoint_future.result() # Wait for the async save to complete + +# if isinstance(checkpoint_id, GcsUri): +# # If the URI is a GCS URI, we need to ensure the file is uploaded +# # to GCS after saving it locally. +# file_loader = FileLoader() +# file_loader.load_directory(dir_uri_src=local_uri, dir_uri_dst=checkpoint_id) +# logger.info(f"Uploaded checkpoint to GCS: {checkpoint_id}") + +# return checkpoint_id + + +# if should_save_asynchronously: +# logger.info(f"Saving checkpoint asynchronously to {checkpoint_id}") +# executor = ThreadPoolExecutor(max_workers=1) +# # Use a ThreadPoolExecutor to run the save operation asynchronously +# # This allows the main thread to continue while the checkpoint is being saved. +# # The Future object will be returned, which can be used to check the status of the +# # save operation or to wait for it to complete. +# future = executor.submit( +# _save_checkpoint, state_dict, checkpoint_id, should_save_asynchronously +# ) +# return future +# else: +# logger.info(f"Saving checkpoint synchronously to {checkpoint_id}") +# return _save_checkpoint(state_dict, checkpoint_id, should_save_asynchronously) + + +def save_checkpoint_to_uri( + state_dict: STATE_DICT_TYPE, + checkpoint_id: Uri, + should_save_asynchronously: bool = False, +) -> Union[Future[Uri], Uri]: + """ + Saves the state_dict to a specified checkpoint_id URI using the Torch Distributed Checkpointing API. + + If the checkpoint_id is a GCS URI, it will first save the checkpoint + locally and then upload it to GCS. + + If `should_save_asynchronously` is True, the save operation will be + performed asynchronously, returning a Future object. Otherwise, it will + block until the save operation is complete. + + Args: + state_dict (STATE_DICT_TYPE): The state dictionary to save. + checkpoint_id (Uri): The URI where the checkpoint will be saved. + should_save_asynchronously (bool): If True, saves the checkpoint asynchronously. + Returns: + Union[Future[Uri], Uri]: The URI where the checkpoint was saved, or + a Future object if saved asynchronously. + Raises: + AssertionError: If checkpoint_id is not a LocalUri or GcsUri. 
+ """ + + def _save_checkpoint( + checkpoint_id: Uri, local_uri: LocalUri, checkpoint_future: Optional[Future] = None + ) -> Uri: + # If we have a checkpoint future, we will wait for it to complete (async save) + if checkpoint_future: + checkpoint_future.result() + + if isinstance(checkpoint_id, GcsUri): + # If the URI is a GCS URI, we need to ensure the file is uploaded + # to GCS after saving it locally. + file_loader = FileLoader() + file_loader.load_directory(dir_uri_src=local_uri, dir_uri_dst=checkpoint_id) + logger.info(f"Uploaded checkpoint to GCS: {checkpoint_id}") + + return checkpoint_id + + assert isinstance(checkpoint_id, LocalUri) or isinstance( + checkpoint_id, GcsUri + ), "checkpoint_id must be a LocalUri or GcsUri." + local_uri = ( + checkpoint_id + if isinstance(checkpoint_id, LocalUri) + else LocalUri(tempfile.mkdtemp(prefix="checkpoint")) + ) + + writer = dcp.FileSystemWriter(path=local_uri.uri) + + if should_save_asynchronously: + logger.info(f"Saving checkpoint asynchronously to {checkpoint_id}") + checkpoint_future = dcp.async_save(state_dict, storage_writer=writer) + executor = ThreadPoolExecutor(max_workers=1) + future = executor.submit( + _save_checkpoint, checkpoint_id, local_uri, checkpoint_future + ) + return future + else: + logger.info(f"Saving checkpoint synchronously to {checkpoint_id}") + dcp.save(state_dict, storage_writer=writer) + return _save_checkpoint(checkpoint_id, local_uri, None) diff --git a/python/gigl/experimental/knowledge_graph_embedding/common/distributed.py b/python/gigl/experimental/knowledge_graph_embedding/common/distributed.py new file mode 100644 index 00000000..fd4befc8 --- /dev/null +++ b/python/gigl/experimental/knowledge_graph_embedding/common/distributed.py @@ -0,0 +1,52 @@ +import os +from typing import Tuple + +from gigl.common.logger import Logger +from gigl.distributed.dist_context import DistributedContext + +logger = Logger() + + +def set_process_env_vars_for_torch_dist( + process_number_on_current_machine: int, + num_processes_on_current_machine: int, + machine_context: DistributedContext, + port: int = 29500, +) -> Tuple[int, int, int, int]: + """ + This function sets the environment variables required for + distributed training with PyTorch. It assumes a multi-machine + setup where each machine has a number of processes running. + The number of machines and rendevous is determined by the + `machine_context` provided. + + Args: + process_number_on_current_machine (int): The process number on the current machine. + num_processes_on_current_machine (int): The total number of processes on the current machine. + machine_context (DistributedContext): The context containing information about the distributed setup. + + Returns: + Tuple[int, int, int, int]: A tuple containing: + - local_rank (int): The local rank of the process on the current machine. + - rank (int): The global rank of the process across all machines. + - local_world_size (int): The number of processes on the current machine. + - world_size (int): The total number of processes across all machines. 
+ """ + # Set the environment variables for the current process + # This is required for distributed training + os.environ["LOCAL_RANK"] = str(process_number_on_current_machine) + os.environ["RANK"] = str( + machine_context.global_rank * num_processes_on_current_machine + + process_number_on_current_machine + ) + os.environ["WORLD_SIZE"] = str( + num_processes_on_current_machine * machine_context.global_world_size + ) + os.environ["LOCAL_WORLD_SIZE"] = str(num_processes_on_current_machine) + os.environ["MASTER_ADDR"] = machine_context.main_worker_ip_address + os.environ["MASTER_PORT"] = str(port) + local_rank = int(os.environ["LOCAL_RANK"]) + rank = int(os.environ["RANK"]) + local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) + world_size = int(os.environ["WORLD_SIZE"]) + return local_rank, rank, local_world_size, world_size diff --git a/python/gigl/experimental/knowledge_graph_embedding/common/graph_dataset.py b/python/gigl/experimental/knowledge_graph_embedding/common/graph_dataset.py new file mode 100644 index 00000000..db532c86 --- /dev/null +++ b/python/gigl/experimental/knowledge_graph_embedding/common/graph_dataset.py @@ -0,0 +1,297 @@ +# This file can probably be gigl-generic utilities. +# We include a few graph-related IterableDatasets backed by GCS and BigQuery + +from typing import Dict, Iterator, List, Optional, TypedDict + +import numpy as np +import orjson +import pyarrow.parquet as pq +import torch +from google.cloud.bigquery_storage import BigQueryReadClient, types +from torch.utils.data._utils.worker import WorkerInfo + +from gigl.common.types.uri.gcs_uri import GcsUri +from gigl.common.types.uri.uri_factory import UriFactory +from gigl.common.utils.torch_training import get_rank, get_world_size +from gigl.src.common.utils.file_loader import FileLoader +from gigl.src.training.v1.lib.data_loaders.utils import ( + get_data_split_for_current_worker, +) + +SRC_FIELD = "src" +DST_FIELD = "dst" +CONDENSED_EDGE_TYPE_FIELD = "condensed_edge_type" + + +HeterogeneousGraphEdgeDict = TypedDict( + "HeterogeneousGraphEdgeDict", + { + SRC_FIELD: str, + DST_FIELD: str, + CONDENSED_EDGE_TYPE_FIELD: str, + }, +) + + +class GcsIterableDataset(torch.utils.data.IterableDataset): + def __init__( + self, + file_uris: List[GcsUri], + seed: int = 42, + ) -> None: + """ + Args: + file_uris (List[UriType]): Holds all the uris for the dataset. + Note: for now only uris supported are ones that `tf.data.TFRecordDataset` + can load from default; i.e .GcsUri and LocalUri. + We permute the file list based on a seed as a means of "shuffling" the data + on a file-level (rather than sample-level, as would be possible in cases + where the data fits in memory. + """ + assert isinstance(file_uris, list) + self._file_uris: np.ndarray = np.random.RandomState(seed).permutation( + np.array([uri.uri for uri in file_uris]) + ) + self._file_loader = None + + def _iterator_init(self): + # Initialize it here to avoid client pickling issues for multiprocessing. + if not self._file_loader: + self._file_loader = FileLoader() + + # Need to first split the work based on worker information + current_worker_file_uris_to_process = get_data_split_for_current_worker( + self._file_uris + ) + + return current_worker_file_uris_to_process + + def __iter__(self) -> Iterator[Dict]: + raise NotImplemented + + +class GcsJSONLIterableDataset(GcsIterableDataset): + def __init__( + self, + file_uris: List[GcsUri], + seed: int = 42, + ) -> None: + """ + Args: + file_uris (List[UriType]): Holds all the uris for the dataset. 
+ Note: for now only uris supported are ones that `tf.data.TFRecordDataset` + can load from default; i.e .GcsUri and LocalUri. + We permute the file list based on a seed as a means of "shuffling" the data + on a file-level (rather than sample-level, as would be possible in cases + where the data fits in memory. + """ + super().__init__(file_uris=file_uris, seed=seed) + + def __iter__(self) -> Iterator[Dict]: + current_worker_file_uris_to_process = self._iterator_init() + + for file in current_worker_file_uris_to_process: + tfh = self._file_loader.load_to_temp_file( + file_uri_src=UriFactory.create_uri(file), delete=True + ) + with open(tfh.name, "rb") as f: + # Read the file and yield each line + for line in f: + data = orjson.loads(line) + yield data + + +class GcsParquetIterableDataset(GcsIterableDataset): + def __init__( + self, file_uris: List[GcsUri], seed: int = 42, batch_size: Optional[int] = None + ) -> None: + """ + Args: + file_uris (List[UriType]): Holds all the uris for the dataset. + Note: for now only uris supported are ones that `tf.data.TFRecordDataset` + can load from default; i.e .GcsUri and LocalUri. + We permute the file list based on a seed as a means of "shuffling" the data + on a file-level (rather than sample-level, as would be possible in cases + where the data fits in memory. + """ + self._iter_batches_kwargs = {"batch_size": batch_size} if batch_size else {} + super().__init__(file_uris=file_uris, seed=seed) + + def __iter__(self) -> Iterator[Dict]: + # Need to first split the work based on worker information + current_worker_file_uris_to_process = self._iterator_init() + + for file in current_worker_file_uris_to_process: + tfh = self._file_loader.load_to_temp_file( + file_uri_src=UriFactory.create_uri(file), delete=True + ) + parquet_file = pq.ParquetFile(tfh.name) + + for batch in parquet_file.iter_batches(**self._iter_batches_kwargs): + df = batch.to_pandas( + split_blocks=True, self_destruct=True + ) # Fast, memory-friendly + for row in df.itertuples(index=False, name=None): + yield dict(zip(df.columns, row)) + + +class BigQueryIterableDataset(torch.utils.data.IterableDataset): + def __init__( + self, + table: str, # Format: "project.dataset.table" + random_column: str, + project: Optional[str] = None, + selected_fields=None, + ): + """ + Enables reading from a BigQuery table in a sharded manner. + This is done by using a random column to split the data into bins + based on the number of workers in the global dataloading process id. + + The dataset is read in a sharded manner, where each worker reads a specific + range of rows designated by conditions on the random column. + The random column is used to ensure that the data is evenly distributed + across the workers. + + Args: + table (str): BigQuery table in the format "project.dataset.table" + random_column (str): Column name used for random sampling. Used to ensure sharded reading of data. 
+ project (Optional[str]): Project ID if not included in the table string + selected_fields (Optional[List[str]]): List of fields to select from the table + """ + + self.project = f"projects/{project}" if project else None + self.table = table + self.selected_fields = selected_fields or [] + if self.selected_fields and (random_column not in self.selected_fields): + self.selected_fields.append(random_column) + self.random_column = random_column + + def _create_read_session( + self, client: BigQueryReadClient, row_restriction: str = "" + ): + project, dataset, table = self.table.split(".") + table_path = f"projects/{project}/datasets/{dataset}/tables/{table}" + + read_options = types.ReadSession.TableReadOptions( + selected_fields=self.selected_fields, + row_restriction=row_restriction, + ) + + session = types.ReadSession( + table=table_path, + data_format=types.DataFormat.ARROW, + read_options=read_options, + ) + + return client.create_read_session( + parent=self.project, + read_session=session, + max_stream_count=1, + ) + + def __iter__(self): + client = BigQueryReadClient() + + worker_info: Optional[WorkerInfo] = torch.utils.data.get_worker_info() + num_workers = worker_info.num_workers if worker_info else 1 + worker_id = worker_info.id if worker_info else 0 + global_worker_id = (get_rank() * num_workers) + worker_id + global_num_workers = num_workers * get_world_size() + + bin_width = 1.0 / global_num_workers + bin_start, bin_end = ( + global_worker_id * bin_width, + (global_worker_id + 1) * bin_width, + ) + row_restriction = f"row_id BETWEEN {bin_start} AND {bin_end}" + + session = self._create_read_session( + client=client, row_restriction=row_restriction + ) + stream = session.streams[0].name + reader = client.read_rows(stream) + rows = reader.rows(session) + + for row in rows: + yield {key: value.as_py() for key, value in row.items()} + + +class GcsJSONLHeterogeneousGraphIterableDataset(GcsJSONLIterableDataset): + def __init__( + self, + file_uris: List[GcsUri], + src_field: str = SRC_FIELD, + dst_field: str = DST_FIELD, + condensed_edge_type_field: str = CONDENSED_EDGE_TYPE_FIELD, + seed: int = 42, + ) -> None: + self._src_field = src_field + self._dst_field = dst_field + self._condensed_edge_type_field = condensed_edge_type_field + super().__init__(file_uris=file_uris, seed=seed) + + def __iter__(self) -> Iterator[HeterogeneousGraphEdgeDict]: + for data in super().__iter__(): + # Convert the data to a filtered dictionary with just essential keys. + yield { + SRC_FIELD: data[self._src_field], + DST_FIELD: data[self._dst_field], + CONDENSED_EDGE_TYPE_FIELD: data[self._condensed_edge_type_field], + } + + +class GcsParquetHeterogeneousGraphIterableDataset(GcsParquetIterableDataset): + def __init__( + self, + file_uris: List[GcsUri], + src_field: str = SRC_FIELD, + dst_field: str = DST_FIELD, + condensed_edge_type_field: str = CONDENSED_EDGE_TYPE_FIELD, + seed: int = 42, + ) -> None: + self._src_field = src_field + self._dst_field = dst_field + self._condensed_edge_type_field = condensed_edge_type_field + super().__init__(file_uris=file_uris, seed=seed) + + def __iter__(self) -> Iterator[HeterogeneousGraphEdgeDict]: + for data in super().__iter__(): + # Convert the data to a filtered dictionary with just essential keys. 
+ yield { + SRC_FIELD: data[self._src_field], + DST_FIELD: data[self._dst_field], + CONDENSED_EDGE_TYPE_FIELD: data[self._condensed_edge_type_field], + } + + +class BigQueryHeterogeneousGraphIterableDataset(BigQueryIterableDataset): + def __init__( + self, + table: str, + random_column: str, + project: Optional[str] = None, + src_field: str = SRC_FIELD, + dst_field: str = DST_FIELD, + condensed_edge_type_field: str = CONDENSED_EDGE_TYPE_FIELD, + **kwargs, + ) -> None: + self._src_field = src_field + self._dst_field = dst_field + self._condensed_edge_type_field = condensed_edge_type_field + super().__init__( + table=table, + project=project, + random_column=random_column, + selected_fields=[src_field, dst_field, condensed_edge_type_field], + **kwargs, + ) + + def __iter__(self) -> Iterator[HeterogeneousGraphEdgeDict]: + for row in super().__iter__(): + # Convert the data to a filtered dictionary with just essential keys. + yield { + SRC_FIELD: row[self._src_field], + DST_FIELD: row[self._dst_field], + CONDENSED_EDGE_TYPE_FIELD: row[self._condensed_edge_type_field], + } diff --git a/python/gigl/experimental/knowledge_graph_embedding/common/iterator_utils.py b/python/gigl/experimental/knowledge_graph_embedding/common/iterator_utils.py new file mode 100644 index 00000000..edf69ee9 --- /dev/null +++ b/python/gigl/experimental/knowledge_graph_embedding/common/iterator_utils.py @@ -0,0 +1,35 @@ +import itertools +from typing import Iterator + + +def batched(it: Iterator, n: int): + """ + Create batches of up to n elements from an iterator. + + Takes an input iterator and yields sub-iterators, each containing up to n elements. + This is useful for processing data in chunks or creating batched operations for + efficient data pipeline processing. + + Args: + it (Iterator): The input iterator to batch. + n (int): Maximum number of elements per batch. Must be >= 1. + + Yields: + Iterator: Sub-iterators containing up to n elements from the input iterator. + The last batch may contain fewer than n elements if the input + iterator is exhausted. + + Raises: + AssertionError: If n < 1. + + Example: + >>> data = iter([1, 2, 3, 4, 5, 6, 7]) + >>> for batch in batched(data, 3): + ... print(list(batch)) + [1, 2, 3] + [4, 5, 6] + [7] + """ + assert n >= 1 + for x in it: + yield itertools.chain((x,), itertools.islice(it, n - 1)) diff --git a/python/gigl/experimental/knowledge_graph_embedding/common/torchrec/__init__.py b/python/gigl/experimental/knowledge_graph_embedding/common/torchrec/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/gigl/experimental/knowledge_graph_embedding/common/torchrec/batch.py b/python/gigl/experimental/knowledge_graph_embedding/common/torchrec/batch.py new file mode 100644 index 00000000..22bd128a --- /dev/null +++ b/python/gigl/experimental/knowledge_graph_embedding/common/torchrec/batch.py @@ -0,0 +1,91 @@ +import abc +from dataclasses import dataclass, field, make_dataclass +from typing import Dict + +import torch +from torchrec.streamable import Pipelineable + + +class BatchBase(Pipelineable, abc.ABC): + """ + This class extends https://github.com/pytorch/torchrec/blob/main/torchrec/datasets/utils.py#L28 + to be reusable for any batch. + + This enables use with certain torchrec tools like pipelined training, which overlaps + dataloading device transfer (copy to GPU), inter-device ocmmunications, and fwd/bkwd. 
+ """ + + @abc.abstractmethod + def as_dict(self) -> Dict: + raise NotImplementedError + + def to(self, device: torch.device, non_blocking: bool = False): + args = {} + for feature_name, feature_value in self.as_dict().items(): + args[feature_name] = feature_value.to( + device=device, non_blocking=non_blocking + ) + return self.__class__(**args) + + def record_stream(self, stream: torch.cuda.streams.Stream) -> None: + for feature_value in self.as_dict().values(): + feature_value.record_stream(stream) + + def pin_memory(self): + args = {} + for feature_name, feature_value in self.as_dict().items(): + args[feature_name] = feature_value.pin_memory() + return self.__class__(**args) + + def __repr__(self) -> str: + def obj2str(v): + return f"{v.size()}" if hasattr(v, "size") else f"{v.length_per_key()}" + + return "\n".join([f"{k}: {obj2str(v)}," for k, v in self.as_dict().items()]) + + @property + def batch_size(self) -> int: + for tensor in self.as_dict().values(): + if tensor is None: + continue + if not isinstance(tensor, torch.Tensor): + continue + return tensor.shape[0] + raise Exception("Could not determine batch size from tensors.") + + +@dataclass +class DataclassBatch(BatchBase): + """ + Makes it easy to create a Batch with some generic dataclass. + """ + + @classmethod + def feature_names(cls): + return list(cls.__dataclass_fields__.keys()) + + def as_dict(self): + return { + feature_name: getattr(self, feature_name) + for feature_name in self.feature_names() + if hasattr(self, feature_name) + } + + @staticmethod + def from_schema(name: str, schema): + """Instantiates a custom batch subclass if all columns can be represented as a torch.Tensor.""" + return make_dataclass( + cls_name=name, + fields=[(name, torch.Tensor, field(default=None)) for name in schema.names], + bases=(DataclassBatch,), + ) + + @staticmethod + def from_fields(name: str, fields: dict): + return make_dataclass( + cls_name=name, + fields=[ + (_name, _type, field(default=None)) for _name, _type in fields.items() + ], + bases=(DataclassBatch,), + ) diff --git a/python/gigl/experimental/knowledge_graph_embedding/common/torchrec/large_embedding_lookup.py b/python/gigl/experimental/knowledge_graph_embedding/common/torchrec/large_embedding_lookup.py new file mode 100644 index 00000000..e0050aca --- /dev/null +++ b/python/gigl/experimental/knowledge_graph_embedding/common/torchrec/large_embedding_lookup.py @@ -0,0 +1,28 @@ +from typing import List + +import torch +import torch.nn as nn +import torchrec + +from gigl.common.logger import Logger + +logger = Logger() + + +class LargeEmbeddingLookup(nn.Module): + def __init__(self, embeddings_config: List[torchrec.EmbeddingBagConfig]): + super().__init__() + self.ebc = torchrec.EmbeddingBagCollection( + tables=embeddings_config, + device=torch.device("meta"), + ) + + logger.info( + f"EmbeddingBagCollection named parameters: {list(self.ebc.named_parameters())}" + ) + + def forward( + self, sparse_features: torchrec.KeyedJaggedTensor + ) -> torchrec.KeyedTensor: + # Forward pass through the embedding bag collection + return self.ebc(sparse_features) diff --git a/python/gigl/experimental/knowledge_graph_embedding/common/torchrec/utils.py b/python/gigl/experimental/knowledge_graph_embedding/common/torchrec/utils.py new file mode 100644 index 00000000..b78a3f9d --- /dev/null +++ b/python/gigl/experimental/knowledge_graph_embedding/common/torchrec/utils.py @@ -0,0 +1,170 @@ +from typing import Any, Dict, Iterable, Optional, Type + +import torch +import torch.nn as nn +from 
torch.distributed.optim import ( + _apply_optimizer_in_backward as apply_optimizer_in_backward, +) +from torch.optim import Optimizer +from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder +from torchrec.distributed.fbgemm_qcomm_codec import ( + CommType, + QCommsConfig, + get_qcomm_codecs_registry, +) +from torchrec.distributed.model_parallel import DistributedModelParallel +from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology +from torchrec.distributed.planner.storage_reservations import ( + HeuristicalStorageReservation, +) +from torchrec.distributed.types import ShardingPlan +from torchrec.optim.keyed import KeyedOptimizerWrapper +from torchrec.optim.optimizers import in_backward_optimizer_filter +from torchrec.optim.rowwise_adagrad import RowWiseAdagrad + +from gigl.common.logger import Logger + +logger = Logger() + + +def maybe_shard_model( + model, + device: torch.device, + sharding_plan: ShardingPlan = None, +): + """ + If in a distributed environment, apply DistributedModelParallel to the model, + using an optionally specified ShardingPlan. + If not in a distributed environment, return the model directly. + Args: + model: The model to be wrapped. + device: The device to use for the model. + sharding_plan: An optional ShardingPlan to use for the DistributedModelParallel. + Returns: + The model wrapped in DistributedModelParallel if in a distributed environment, + otherwise the model itself. + """ + + if torch.distributed.is_initialized(): + # Build a sharding plan + logger.info("***** Wrapping in DistributedModelParallel *****") + logger.info(f"Model before wrapping: {model}") + model = DistributedModelParallel( + module=model, + device=device, + plan=sharding_plan, + ) + logger.info(f"Model after wrapping: {model}") + + return model + + +def get_sharding_plan( + model: nn.Module, + batch_size: int, + local_world_size: int, + world_size: int, + use_cuda: bool = False, + storage_reservation_percentage: float = 0.15, + qcomm_forward_precision: CommType = CommType.FP32, + qcomm_backward_precision: CommType = CommType.FP32, +) -> ShardingPlan: + """ + Create a sharding plan for the model using the EmbeddingShardingPlanner. + Args: + model: The model to be sharded. + batch_size: The batch size for the sharding plan. + use_cuda: Whether to use CUDA for the sharding plan. + storage_reservation_percentage: The percentage of storage reservation. + qcomm_forward_precision: The precision for forward communication (can be FP32, FP16, etc.). + qcomm_backward_precision: The precision for backward communication (can be FP32, FP16, etc.). + Returns: + A ShardingPlan object representing the sharding plan for the model. 
+ """ + + topology = Topology( + world_size=world_size, + local_world_size=local_world_size, # TODO(nshah): We should expose this in torch_training.py + compute_device="cuda" if use_cuda else "cpu", + hbm_cap=torch.cuda.get_device_properties(0).total_memory if use_cuda else 0, + ) + + planner = EmbeddingShardingPlanner( + topology=topology, + batch_size=batch_size, + storage_reservation=HeuristicalStorageReservation( + percentage=storage_reservation_percentage + ), # bumping this % can alleviate OOM issues by being more conservative + ) + + # Enable custom fwd/bkwd precisions for QComms when using GPU + qcomm_codecs_registry = ( + get_qcomm_codecs_registry( + qcomms_config=QCommsConfig( + forward_precision=qcomm_forward_precision, + backward_precision=qcomm_backward_precision, + ) + ) + if use_cuda + else None + ) + ebc_sharder = EmbeddingBagCollectionSharder( + qcomm_codecs_registry=qcomm_codecs_registry + ) + plan = planner.collective_plan( + model, [ebc_sharder], torch.distributed.GroupMember.WORLD + ) + return plan + + +def apply_sparse_optimizer( + parameters: Iterable[nn.Parameter], + optimizer_cls: Type[Optimizer] = None, + optimizer_kwargs: Dict[str, Any] = dict(), +) -> None: + """ + Apply a sparse optimizer to the sparse/EBC parts of a model. + This optimizer is fused, so it will be applied directly in the backward pass. + + This should only be used for sparse parameters. + + Args: + parameters (Iterable[nn.Parameter]): The sparse parameters to apply the optimizer to. + optimizer_cls (Type[Optimizer], optional): The optimizer class to use. Defaults to RowWiseAdagrad. + optimizer_kwargs (Dict[str, Any], optional): Additional keyword arguments for the optimizer. + """ + + if not optimizer_cls and optimizer_kwargs: + optimizer_cls = RowWiseAdagrad + optimizer_kwargs = {"lr": 0.01} + apply_optimizer_in_backward(optimizer_cls, parameters, optimizer_kwargs) + + +def apply_dense_optimizer( + model: nn.Module, + optimizer_cls: Type[Optimizer], + optimizer_kwargs: Dict[str, Any] = dict(), +) -> Optional[KeyedOptimizerWrapper]: + """ + This creates an optimizer for the dense parts of the model. + It uses the `KeyedOptimizerWrapper` to wrap the optimizer. + + Args: + model (nn.Module): The model containing dense parameters. + optimizer_cls (Type[Optimizer]): The optimizer class to use for dense parameters. + optimizer_kwargs (Dict[str, Any], optional): Additional keyword arguments for the optimizer. + + Returns: + Optional[KeyedOptimizerWrapper]: A wrapped optimizer for dense parameters, or + None if no dense parameters are found. + """ + dense_params = dict(in_backward_optimizer_filter(model.named_parameters())) + if not dense_params: + # We cannot apply a dense optimizer if there are no dense parameters. 
+ logger.warning("No dense parameters found in the model.") + return None + dense_optimizer = KeyedOptimizerWrapper( + dict(in_backward_optimizer_filter(model.named_parameters())), + lambda params: optimizer_cls(params, **optimizer_kwargs), + ) + return dense_optimizer From a475af2aecd76a69ed098a2c05e427fe7b3ec4d6 Mon Sep 17 00:00:00 2001 From: nshah Date: Tue, 29 Jul 2025 20:52:12 +0000 Subject: [PATCH 2/4] fix --- .../common/dist_checkpoint.py | 109 +----------------- 1 file changed, 1 insertion(+), 108 deletions(-) diff --git a/python/gigl/experimental/knowledge_graph_embedding/common/dist_checkpoint.py b/python/gigl/experimental/knowledge_graph_embedding/common/dist_checkpoint.py index e1d5e22b..41d3a715 100644 --- a/python/gigl/experimental/knowledge_graph_embedding/common/dist_checkpoint.py +++ b/python/gigl/experimental/knowledge_graph_embedding/common/dist_checkpoint.py @@ -23,7 +23,7 @@ class AppState(Stateful): """ This is a useful wrapper for checkpointing an application state. Since this object is compliant with the Stateful protocol, DCP will automatically - call state_dict/loade_stat_dict as needed in the dcp.save/load APIs. + call state_dict/load_state_dict as needed in the dcp.save/load APIs. We take advantage of this wrapper to hande calling distributed state dict methods on the model and optimizer. @@ -86,113 +86,6 @@ def load_checkpoint_from_uri( logger.info(f"Loaded checkpoint from {checkpoint_id}") -# def save_checkpoint_to_uri( -# state_dict: STATE_DICT_TYPE, -# checkpoint_id: Uri, -# should_save_asynchronously: bool = False, -# ): -# """ -# Saves the state_dict to a specified checkpoint_id URI using the Torch Distributed Checkpointing API. - -# If the checkpoint_id is a GCS URI, it will first save the checkpoint -# locally and then upload it to GCS. -# """ - -# assert isinstance(checkpoint_id, LocalUri) or isinstance( -# checkpoint_id, GcsUri -# ), "checkpoint_id must be a LocalUri or GcsUri." -# local_uri = ( -# checkpoint_id -# if isinstance(checkpoint_id, LocalUri) -# else LocalUri(tempfile.mkdtemp(prefix="checkpoint")) -# ) - -# writer = dcp.FileSystemWriter(path=local_uri.uri) -# dcp_fn = dcp.async_save if should_save_asynchronously else dcp.save -# dcp_fn(state_dict, storage_writer=writer) -# if isinstance(checkpoint_id, GcsUri): -# # If the URI is a GCS URI, we need to ensure the file is uploaded -# # to GCS after saving it locally. -# file_loader = FileLoader() -# file_loader.load_directory(dir_uri_src=local_uri, dir_uri_dst=checkpoint_id) -# logger.info(f"Uploaded checkpoint to GCS: {checkpoint_id}") - - -# def save_checkpoint_to_uri( -# state_dict: STATE_DICT_TYPE, -# checkpoint_id: Uri, -# should_save_asynchronously: bool = False, -# ) -> Union[Future[Uri], Uri]: -# """ -# Saves the state_dict to a specified checkpoint_id URI using the Torch Distributed Checkpointing API. - -# If the checkpoint_id is a GCS URI, it will first save the checkpoint -# locally and then upload it to GCS. - -# If `should_save_asynchronously` is True, the save operation will be -# performed asynchronously, returning a Future object. Otherwise, it will -# block until the save operation is complete. - -# Args: -# state_dict (STATE_DICT_TYPE): The state dictionary to save. -# checkpoint_id (Uri): The URI where the checkpoint will be saved. -# should_save_asynchronously (bool): If True, saves the checkpoint asynchronously. -# Returns: -# Union[Future[Uri], Uri]: The URI where the checkpoint was saved, or -# a Future object if saved asynchronously. 
-# Raises: -# AssertionError: If checkpoint_id is not a LocalUri or GcsUri. -# """ -# logger.info(f"inside save_checkpoint_to_uri with condition: {should_save_asynchronously}") -# def _save_checkpoint( -# state_dict: STATE_DICT_TYPE, checkpoint_id: Uri, should_save_asynchronously: bool = False -# ) -> Uri: -# assert isinstance(checkpoint_id, LocalUri) or isinstance( -# checkpoint_id, GcsUri -# ), "checkpoint_id must be a LocalUri or GcsUri." -# local_uri = ( -# checkpoint_id -# if isinstance(checkpoint_id, LocalUri) -# else LocalUri(tempfile.mkdtemp(prefix="checkpoint")) -# ) - -# writer = dcp.FileSystemWriter(path=local_uri.uri) - -# checkpoint_future: Optional[Future] = None -# if should_save_asynchronously: -# checkpoint_future = dcp.async_save(state_dict, storage_writer=writer) -# else: -# dcp.save(state_dict, storage_writer=writer) - -# if checkpoint_future: -# checkpoint_future.result() # Wait for the async save to complete - -# if isinstance(checkpoint_id, GcsUri): -# # If the URI is a GCS URI, we need to ensure the file is uploaded -# # to GCS after saving it locally. -# file_loader = FileLoader() -# file_loader.load_directory(dir_uri_src=local_uri, dir_uri_dst=checkpoint_id) -# logger.info(f"Uploaded checkpoint to GCS: {checkpoint_id}") - -# return checkpoint_id - - -# if should_save_asynchronously: -# logger.info(f"Saving checkpoint asynchronously to {checkpoint_id}") -# executor = ThreadPoolExecutor(max_workers=1) -# # Use a ThreadPoolExecutor to run the save operation asynchronously -# # This allows the main thread to continue while the checkpoint is being saved. -# # The Future object will be returned, which can be used to check the status of the -# # save operation or to wait for it to complete. -# future = executor.submit( -# _save_checkpoint, state_dict, checkpoint_id, should_save_asynchronously -# ) -# return future -# else: -# logger.info(f"Saving checkpoint synchronously to {checkpoint_id}") -# return _save_checkpoint(state_dict, checkpoint_id, should_save_asynchronously) - - def save_checkpoint_to_uri( state_dict: STATE_DICT_TYPE, checkpoint_id: Uri, From d23a73eca151e0930b202d1eb98c4285d8e3b724 Mon Sep 17 00:00:00 2001 From: nshah Date: Wed, 20 Aug 2025 19:01:37 +0000 Subject: [PATCH 3/4] comments --- .../common/dist_checkpoint.py | 31 ++++++++++++------- .../common/torchrec/batch.py | 1 + 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/python/gigl/experimental/knowledge_graph_embedding/common/dist_checkpoint.py b/python/gigl/experimental/knowledge_graph_embedding/common/dist_checkpoint.py index 41d3a715..2859fd59 100644 --- a/python/gigl/experimental/knowledge_graph_embedding/common/dist_checkpoint.py +++ b/python/gigl/experimental/knowledge_graph_embedding/common/dist_checkpoint.py @@ -4,21 +4,23 @@ """ import tempfile +from concurrent.futures import Future, ThreadPoolExecutor from typing import Optional, Union -from concurrent.futures import ThreadPoolExecutor, Future + +import torch.distributed.checkpoint as dcp import torch.nn as nn import torch.optim as optim -import torch.distributed.checkpoint as dcp -from torch.distributed.checkpoint.stateful import Stateful from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE - +from torch.distributed.checkpoint.stateful import Stateful from gigl.common import GcsUri, LocalUri, Uri from gigl.common.logger import Logger +from gigl.common.utils.local_fs import delete_local_directory from gigl.src.common.utils.file_loader import FileLoader logger = Logger() + class AppState(Stateful): """ This is a 
useful wrapper for checkpointing an application state. Since this @@ -36,7 +38,6 @@ class AppState(Stateful): OPTIMIZER_KEY = "optimizer" APP_STATE_KEY = "app" - def __init__(self, model: nn.Module, optimizer: Optional[optim.Optimizer] = None): self.model = model self.optimizer = optimizer @@ -63,6 +64,7 @@ def to_state_dict(self) -> STATE_DICT_TYPE: self.APP_STATE_KEY: self, } + def load_checkpoint_from_uri( state_dict: STATE_DICT_TYPE, checkpoint_id: Uri, @@ -70,21 +72,26 @@ def load_checkpoint_from_uri( assert isinstance(checkpoint_id, LocalUri) or isinstance( checkpoint_id, GcsUri ), "checkpoint_id must be a LocalUri or GcsUri." - local_uri = ( - checkpoint_id - if isinstance(checkpoint_id, LocalUri) - else LocalUri(tempfile.mkdtemp(prefix="checkpoint")) - ) + + created_temp_local_uri = False if isinstance(checkpoint_id, GcsUri): # If the URI is a GCS URI, we need to download it first + local_uri = LocalUri(tempfile.mkdtemp(prefix="checkpoint")) + created_temp_local_uri = True file_loader = FileLoader() file_loader.load_directory(dir_uri_src=checkpoint_id, dir_uri_dst=local_uri) logger.info(f"Downloaded checkpoint from GCS: {checkpoint_id} to {local_uri}") + else: + local_uri = checkpoint_id reader = dcp.FileSystemReader(path=local_uri.uri) dcp.load(state_dict=state_dict, storage_reader=reader) logger.info(f"Loaded checkpoint from {checkpoint_id}") + # Clean up the temp local uri if it was created + if created_temp_local_uri: + delete_local_directory(local_path=local_uri) + def save_checkpoint_to_uri( state_dict: STATE_DICT_TYPE, @@ -113,7 +120,9 @@ def save_checkpoint_to_uri( """ def _save_checkpoint( - checkpoint_id: Uri, local_uri: LocalUri, checkpoint_future: Optional[Future] = None + checkpoint_id: Uri, + local_uri: LocalUri, + checkpoint_future: Optional[Future] = None, ) -> Uri: # If we have a checkpoint future, we will wait for it to complete (async save) if checkpoint_future: diff --git a/python/gigl/experimental/knowledge_graph_embedding/common/torchrec/batch.py b/python/gigl/experimental/knowledge_graph_embedding/common/torchrec/batch.py index 22bd128a..37308597 100644 --- a/python/gigl/experimental/knowledge_graph_embedding/common/torchrec/batch.py +++ b/python/gigl/experimental/knowledge_graph_embedding/common/torchrec/batch.py @@ -54,6 +54,7 @@ def batch_size(self) -> int: raise Exception("Could not determine batch size from tensors.") +# TODO(nshah-sc): Consider folding BatchBase into this class. 
@dataclass class DataclassBatch(BatchBase): """ From 5be12b4b9015f385e06e7ed11995f3dbab250b03 Mon Sep 17 00:00:00 2001 From: nshah Date: Thu, 21 Aug 2025 02:40:12 +0000 Subject: [PATCH 4/4] type issues --- mypy.ini | 5 +- .../common/graph_dataset.py | 66 +++++++++---------- .../common/torchrec/utils.py | 2 +- 3 files changed, 37 insertions(+), 36 deletions(-) diff --git a/mypy.ini b/mypy.ini index 27dbda79..644bde68 100644 --- a/mypy.ini +++ b/mypy.ini @@ -34,7 +34,7 @@ ignore_missing_imports = True [mypy-parameterized.*] ignore_missing_imports = True -[mypy-pyarrow] +[mypy-pyarrow.*] ignore_missing_imports = True [mypy-matplotlib.*] @@ -57,3 +57,6 @@ ignore_missing_imports = True [mypy-google.cloud.storage.*] ignore_missing_imports = True + +[mypy-torchrec.*] +ignore_missing_imports = True diff --git a/python/gigl/experimental/knowledge_graph_embedding/common/graph_dataset.py b/python/gigl/experimental/knowledge_graph_embedding/common/graph_dataset.py index db532c86..d434bde4 100644 --- a/python/gigl/experimental/knowledge_graph_embedding/common/graph_dataset.py +++ b/python/gigl/experimental/knowledge_graph_embedding/common/graph_dataset.py @@ -1,13 +1,14 @@ # This file can probably be gigl-generic utilities. # We include a few graph-related IterableDatasets backed by GCS and BigQuery -from typing import Dict, Iterator, List, Optional, TypedDict +from typing import Any, Iterator, List, Mapping, Optional, TypedDict import numpy as np import orjson import pyarrow.parquet as pq import torch -from google.cloud.bigquery_storage import BigQueryReadClient, types +from google.cloud.bigquery_storage import BigQueryReadClient +from google.cloud.bigquery_storage_v1.types import DataFormat, ReadSession from torch.utils.data._utils.worker import WorkerInfo from gigl.common.types.uri.gcs_uri import GcsUri @@ -23,14 +24,10 @@ CONDENSED_EDGE_TYPE_FIELD = "condensed_edge_type" -HeterogeneousGraphEdgeDict = TypedDict( - "HeterogeneousGraphEdgeDict", - { - SRC_FIELD: str, - DST_FIELD: str, - CONDENSED_EDGE_TYPE_FIELD: str, - }, -) +class HeterogeneousGraphEdgeDict(TypedDict): + src: str + dst: str + condensed_edge_type: str class GcsIterableDataset(torch.utils.data.IterableDataset): @@ -52,7 +49,7 @@ def __init__( self._file_uris: np.ndarray = np.random.RandomState(seed).permutation( np.array([uri.uri for uri in file_uris]) ) - self._file_loader = None + self._file_loader: Optional[FileLoader] = None def _iterator_init(self): # Initialize it here to avoid client pickling issues for multiprocessing. 
@@ -66,7 +63,7 @@ def _iterator_init(self): return current_worker_file_uris_to_process - def __iter__(self) -> Iterator[Dict]: + def __iter__(self) -> Iterator[Any]: raise NotImplemented @@ -87,8 +84,9 @@ def __init__( """ super().__init__(file_uris=file_uris, seed=seed) - def __iter__(self) -> Iterator[Dict]: + def __iter__(self) -> Iterator[Mapping[str, Any]]: current_worker_file_uris_to_process = self._iterator_init() + assert self._file_loader is not None, "File loader not initialized" for file in current_worker_file_uris_to_process: tfh = self._file_loader.load_to_temp_file( @@ -117,9 +115,10 @@ def __init__( self._iter_batches_kwargs = {"batch_size": batch_size} if batch_size else {} super().__init__(file_uris=file_uris, seed=seed) - def __iter__(self) -> Iterator[Dict]: + def __iter__(self) -> Iterator[Mapping[str, Any]]: # Need to first split the work based on worker information current_worker_file_uris_to_process = self._iterator_init() + assert self._file_loader is not None, "File loader not initialized" for file in current_worker_file_uris_to_process: tfh = self._file_loader.load_to_temp_file( @@ -173,14 +172,14 @@ def _create_read_session( project, dataset, table = self.table.split(".") table_path = f"projects/{project}/datasets/{dataset}/tables/{table}" - read_options = types.ReadSession.TableReadOptions( + read_options = ReadSession.TableReadOptions( selected_fields=self.selected_fields, row_restriction=row_restriction, ) - session = types.ReadSession( + session = ReadSession( table=table_path, - data_format=types.DataFormat.ARROW, + data_format=DataFormat.ARROW, read_options=read_options, ) @@ -190,7 +189,7 @@ def _create_read_session( max_stream_count=1, ) - def __iter__(self): + def __iter__(self) -> Iterator[Mapping[str, Any]]: client = BigQueryReadClient() worker_info: Optional[WorkerInfo] = torch.utils.data.get_worker_info() @@ -234,11 +233,11 @@ def __init__( def __iter__(self) -> Iterator[HeterogeneousGraphEdgeDict]: for data in super().__iter__(): # Convert the data to a filtered dictionary with just essential keys. - yield { - SRC_FIELD: data[self._src_field], - DST_FIELD: data[self._dst_field], - CONDENSED_EDGE_TYPE_FIELD: data[self._condensed_edge_type_field], - } + yield HeterogeneousGraphEdgeDict( + src=data[self._src_field], + dst=data[self._dst_field], + condensed_edge_type=data[self._condensed_edge_type_field], + ) class GcsParquetHeterogeneousGraphIterableDataset(GcsParquetIterableDataset): @@ -257,12 +256,11 @@ def __init__( def __iter__(self) -> Iterator[HeterogeneousGraphEdgeDict]: for data in super().__iter__(): - # Convert the data to a filtered dictionary with just essential keys. - yield { - SRC_FIELD: data[self._src_field], - DST_FIELD: data[self._dst_field], - CONDENSED_EDGE_TYPE_FIELD: data[self._condensed_edge_type_field], - } + yield HeterogeneousGraphEdgeDict( + src=data[self._src_field], + dst=data[self._dst_field], + condensed_edge_type=data[self._condensed_edge_type_field], + ) class BigQueryHeterogeneousGraphIterableDataset(BigQueryIterableDataset): @@ -290,8 +288,8 @@ def __init__( def __iter__(self) -> Iterator[HeterogeneousGraphEdgeDict]: for row in super().__iter__(): # Convert the data to a filtered dictionary with just essential keys. 
- yield { - SRC_FIELD: row[self._src_field], - DST_FIELD: row[self._dst_field], - CONDENSED_EDGE_TYPE_FIELD: row[self._condensed_edge_type_field], - } + yield HeterogeneousGraphEdgeDict( + src=row[self._src_field], + dst=row[self._dst_field], + condensed_edge_type=row[self._condensed_edge_type_field], + ) diff --git a/python/gigl/experimental/knowledge_graph_embedding/common/torchrec/utils.py b/python/gigl/experimental/knowledge_graph_embedding/common/torchrec/utils.py index b78a3f9d..5b39af6a 100644 --- a/python/gigl/experimental/knowledge_graph_embedding/common/torchrec/utils.py +++ b/python/gigl/experimental/knowledge_graph_embedding/common/torchrec/utils.py @@ -119,7 +119,7 @@ def get_sharding_plan( def apply_sparse_optimizer( parameters: Iterable[nn.Parameter], - optimizer_cls: Type[Optimizer] = None, + optimizer_cls: Optional[Type[Optimizer]] = None, optimizer_kwargs: Dict[str, Any] = dict(), ) -> None: """
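
Two illustrative usage sketches for the utilities introduced in this series (editorial examples only, not part of the patch). The module paths mirror the new files, while the toy model, table config, bucket path, batch size, and world-size values are placeholder assumptions; in a real job each snippet would run per rank after torch.distributed has been initialized.

Checkpointing a model/optimizer pair with AppState and the dist_checkpoint helpers, roughly as intended by the module:

    import torch.nn as nn
    import torch.optim as optim

    from gigl.common import GcsUri
    from gigl.experimental.knowledge_graph_embedding.common.dist_checkpoint import (
        AppState,
        load_checkpoint_from_uri,
        save_checkpoint_to_uri,
    )

    model = nn.Linear(16, 4)  # stand-in for the real embedding model
    optimizer = optim.Adam(model.parameters())
    app_state = AppState(model=model, optimizer=optimizer)

    checkpoint_uri = GcsUri("gs://example-bucket/checkpoints/step_100")  # hypothetical path

    # Synchronous save: writes the checkpoint locally, then uploads it to GCS
    # before returning.
    save_checkpoint_to_uri(
        state_dict=app_state.to_state_dict(),
        checkpoint_id=checkpoint_uri,
        should_save_asynchronously=False,
    )

    # Asynchronous save: returns a Future; resolve it once the upload must be done.
    future = save_checkpoint_to_uri(
        state_dict=app_state.to_state_dict(),
        checkpoint_id=checkpoint_uri,
        should_save_asynchronously=True,
    )
    future.result()

    # Restore: DCP calls AppState.load_state_dict under the hood, repopulating the
    # model (and optimizer, if one was provided) in place.
    load_checkpoint_from_uri(
        state_dict=app_state.to_state_dict(),
        checkpoint_id=checkpoint_uri,
    )

Wiring the TorchRec helpers together (sharding plan, fused sparse optimizer, dense optimizer); the embedding table and sizes below are illustrative only:

    import torch
    import torchrec
    from torchrec.optim.rowwise_adagrad import RowWiseAdagrad

    from gigl.experimental.knowledge_graph_embedding.common.torchrec.large_embedding_lookup import (
        LargeEmbeddingLookup,
    )
    from gigl.experimental.knowledge_graph_embedding.common.torchrec.utils import (
        apply_dense_optimizer,
        apply_sparse_optimizer,
        get_sharding_plan,
        maybe_shard_model,
    )

    tables = [
        torchrec.EmbeddingBagConfig(
            name="node_embeddings",
            embedding_dim=64,
            num_embeddings=1_000_000,
            feature_names=["node_id"],
        )
    ]
    model = LargeEmbeddingLookup(embeddings_config=tables)

    # The fused sparse optimizer must be attached to the EBC parameters before the
    # model is sharded/wrapped.
    apply_sparse_optimizer(
        model.ebc.parameters(),
        optimizer_cls=RowWiseAdagrad,
        optimizer_kwargs={"lr": 0.01},
    )

    use_cuda = torch.cuda.is_available()
    plan = get_sharding_plan(
        model=model,
        batch_size=512,
        local_world_size=8,  # e.g. the values returned by set_process_env_vars_for_torch_dist
        world_size=8,
        use_cuda=use_cuda,
    )
    device = torch.device("cuda" if use_cuda else "cpu")
    model = maybe_shard_model(model, device=device, sharding_plan=plan)

    # Returns None (with a warning) when the model has no dense parameters, as is
    # the case for this embedding-only example.
    dense_optimizer = apply_dense_optimizer(model, torch.optim.Adam, {"lr": 1e-3})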