Add some 'shared' KGE utils #233
Merged
Commits (6)
- 5c30b91 add common utils (nshah-sc)
- a475af2 fix (nshah-sc)
- 00b175b Merge branch 'main' into nshah/add-common-torchrec-utils (nshah-sc)
- d23a73e comments (nshah-sc)
- bb39ce8 Merge branch 'main' into nshah/add-common-torchrec-utils (nshah-sc)
- 5be12b4 type issues (nshah-sc)
python/gigl/experimental/knowledge_graph_embedding/common/README.md (1 change: 1 addition & 0 deletions)
@@ -0,0 +1 @@
These utilities may be more generically reusable inside GiGL for other applications.
Empty file.
python/gigl/experimental/knowledge_graph_embedding/common/dist_checkpoint.py (162 changes: 162 additions & 0 deletions)
@@ -0,0 +1,162 @@
""" | ||
This module provides functions to load and save distributed checkpoints | ||
using the Torch Distributed Checkpointing API. | ||
""" | ||
|
||
import tempfile | ||
from concurrent.futures import Future, ThreadPoolExecutor | ||
from typing import Optional, Union | ||
|
||
import torch.distributed.checkpoint as dcp | ||
import torch.nn as nn | ||
import torch.optim as optim | ||
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE | ||
from torch.distributed.checkpoint.stateful import Stateful | ||
|
||
from gigl.common import GcsUri, LocalUri, Uri | ||
from gigl.common.logger import Logger | ||
from gigl.common.utils.local_fs import delete_local_directory | ||
from gigl.src.common.utils.file_loader import FileLoader | ||
|
||
logger = Logger() | ||
|
||
|
||
class AppState(Stateful): | ||
""" | ||
This is a useful wrapper for checkpointing an application state. Since this | ||
object is compliant with the Stateful protocol, DCP will automatically | ||
call state_dict/load_state_dict as needed in the dcp.save/load APIs. | ||
|
||
We take advantage of this wrapper to hande calling distributed state dict | ||
methods on the model and optimizer. | ||
|
||
See https://docs.pytorch.org/tutorials/recipes/distributed_async_checkpoint_recipe.html | ||
for more details. | ||
""" | ||
|
||
MODEL_KEY = "model" | ||
OPTIMIZER_KEY = "optimizer" | ||
APP_STATE_KEY = "app" | ||
|
||
def __init__(self, model: nn.Module, optimizer: Optional[optim.Optimizer] = None): | ||
self.model = model | ||
self.optimizer = optimizer | ||
|
||
def state_dict(self): | ||
model_state_dict = self.model.state_dict() | ||
optimizer_state_dict = self.optimizer.state_dict() if self.optimizer else None | ||
return { | ||
self.MODEL_KEY: model_state_dict, | ||
self.OPTIMIZER_KEY: optimizer_state_dict, | ||
} | ||
|
||
def load_state_dict(self, state_dict): | ||
# sets our state dicts on the model and optimizer, now that we've loaded | ||
self.model.load_state_dict(state_dict[self.MODEL_KEY]) | ||
if self.optimizer and state_dict.get(self.OPTIMIZER_KEY): | ||
self.optimizer.load_state_dict(state_dict[self.OPTIMIZER_KEY]) | ||
|
||
def to_state_dict(self) -> STATE_DICT_TYPE: | ||
""" | ||
Converts the AppState to a state dict that can be used with DCP. | ||
""" | ||
return { | ||
self.APP_STATE_KEY: self, | ||
} | ||
|
||
|
||
def load_checkpoint_from_uri( | ||
state_dict: STATE_DICT_TYPE, | ||
checkpoint_id: Uri, | ||
): | ||
assert isinstance(checkpoint_id, LocalUri) or isinstance( | ||
checkpoint_id, GcsUri | ||
), "checkpoint_id must be a LocalUri or GcsUri." | ||
|
||
created_temp_local_uri = False | ||
if isinstance(checkpoint_id, GcsUri): | ||
# If the URI is a GCS URI, we need to download it first | ||
local_uri = LocalUri(tempfile.mkdtemp(prefix="checkpoint")) | ||
created_temp_local_uri = True | ||
file_loader = FileLoader() | ||
file_loader.load_directory(dir_uri_src=checkpoint_id, dir_uri_dst=local_uri) | ||
logger.info(f"Downloaded checkpoint from GCS: {checkpoint_id} to {local_uri}") | ||
else: | ||
local_uri = checkpoint_id | ||
|
||
reader = dcp.FileSystemReader(path=local_uri.uri) | ||
dcp.load(state_dict=state_dict, storage_reader=reader) | ||
logger.info(f"Loaded checkpoint from {checkpoint_id}") | ||
|
||
# Clean up the temp local uri if it was created | ||
if created_temp_local_uri: | ||
delete_local_directory(local_path=local_uri) | ||
|
||
|
||
def save_checkpoint_to_uri( | ||
state_dict: STATE_DICT_TYPE, | ||
checkpoint_id: Uri, | ||
should_save_asynchronously: bool = False, | ||
) -> Union[Future[Uri], Uri]: | ||
""" | ||
Saves the state_dict to a specified checkpoint_id URI using the Torch Distributed Checkpointing API. | ||
|
||
If the checkpoint_id is a GCS URI, it will first save the checkpoint | ||
locally and then upload it to GCS. | ||
|
||
If `should_save_asynchronously` is True, the save operation will be | ||
performed asynchronously, returning a Future object. Otherwise, it will | ||
block until the save operation is complete. | ||
|
||
Args: | ||
state_dict (STATE_DICT_TYPE): The state dictionary to save. | ||
checkpoint_id (Uri): The URI where the checkpoint will be saved. | ||
should_save_asynchronously (bool): If True, saves the checkpoint asynchronously. | ||
Returns: | ||
Union[Future[Uri], Uri]: The URI where the checkpoint was saved, or | ||
a Future object if saved asynchronously. | ||
Raises: | ||
AssertionError: If checkpoint_id is not a LocalUri or GcsUri. | ||
""" | ||
|
||
def _save_checkpoint( | ||
checkpoint_id: Uri, | ||
local_uri: LocalUri, | ||
checkpoint_future: Optional[Future] = None, | ||
) -> Uri: | ||
# If we have a checkpoint future, we will wait for it to complete (async save) | ||
if checkpoint_future: | ||
checkpoint_future.result() | ||
|
||
if isinstance(checkpoint_id, GcsUri): | ||
# If the URI is a GCS URI, we need to ensure the file is uploaded | ||
# to GCS after saving it locally. | ||
file_loader = FileLoader() | ||
file_loader.load_directory(dir_uri_src=local_uri, dir_uri_dst=checkpoint_id) | ||
logger.info(f"Uploaded checkpoint to GCS: {checkpoint_id}") | ||
|
||
return checkpoint_id | ||
|
||
assert isinstance(checkpoint_id, LocalUri) or isinstance( | ||
checkpoint_id, GcsUri | ||
), "checkpoint_id must be a LocalUri or GcsUri." | ||
local_uri = ( | ||
checkpoint_id | ||
if isinstance(checkpoint_id, LocalUri) | ||
else LocalUri(tempfile.mkdtemp(prefix="checkpoint")) | ||
nshah-sc marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) | ||
|
||
writer = dcp.FileSystemWriter(path=local_uri.uri) | ||
|
||
if should_save_asynchronously: | ||
logger.info(f"Saving checkpoint asynchronously to {checkpoint_id}") | ||
checkpoint_future = dcp.async_save(state_dict, storage_writer=writer) | ||
executor = ThreadPoolExecutor(max_workers=1) | ||
future = executor.submit( | ||
_save_checkpoint, checkpoint_id, local_uri, checkpoint_future | ||
) | ||
return future | ||
else: | ||
logger.info(f"Saving checkpoint synchronously to {checkpoint_id}") | ||
dcp.save(state_dict, storage_writer=writer) | ||
return _save_checkpoint(checkpoint_id, local_uri, None) |
python/gigl/experimental/knowledge_graph_embedding/common/distributed.py (52 changes: 52 additions & 0 deletions)
@@ -0,0 +1,52 @@
import os
from typing import Tuple

from gigl.common.logger import Logger
from gigl.distributed.dist_context import DistributedContext

logger = Logger()


def set_process_env_vars_for_torch_dist(
    process_number_on_current_machine: int,
    num_processes_on_current_machine: int,
    machine_context: DistributedContext,
    port: int = 29500,
) -> Tuple[int, int, int, int]:
    """
    This function sets the environment variables required for
    distributed training with PyTorch. It assumes a multi-machine
    setup where each machine runs a number of processes.
    The number of machines and the rendezvous address are determined
    by the `machine_context` provided.

    Args:
        process_number_on_current_machine (int): The process number on the current machine.
        num_processes_on_current_machine (int): The total number of processes on the current machine.
        machine_context (DistributedContext): The context containing information about the distributed setup.

    Returns:
        Tuple[int, int, int, int]: A tuple containing:
            - local_rank (int): The local rank of the process on the current machine.
            - rank (int): The global rank of the process across all machines.
            - local_world_size (int): The number of processes on the current machine.
            - world_size (int): The total number of processes across all machines.
    """
    # Set the environment variables for the current process.
    # This is required for distributed training.
    os.environ["LOCAL_RANK"] = str(process_number_on_current_machine)
    os.environ["RANK"] = str(
        machine_context.global_rank * num_processes_on_current_machine
        + process_number_on_current_machine
    )
    os.environ["WORLD_SIZE"] = str(
        num_processes_on_current_machine * machine_context.global_world_size
    )
    os.environ["LOCAL_WORLD_SIZE"] = str(num_processes_on_current_machine)
    os.environ["MASTER_ADDR"] = machine_context.main_worker_ip_address
    os.environ["MASTER_PORT"] = str(port)
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
    world_size = int(os.environ["WORLD_SIZE"])
    return local_rank, rank, local_world_size, world_size