
Add some 'shared' KGE utils #233


Draft · wants to merge 2 commits into main
Conversation

nshah171 (Contributor):

These utils are used in KGE experimentation and are likely to be generically useful for other graph applications (outside KGE).

Where is the documentation for this feature? N/A

Did you add automated tests or write a test plan?

Updated Changelog.md? NO

Ready for code review? YES

@svij-sc (Collaborator) left a comment:

Still working through this, but approving to unblock

from torchrec.streamable import Pipelineable


class BatchBase(Pipelineable, abc.ABC):
Collaborator:

Does this base class get used anywhere except DataclassBatch?
If not, it might be easier for us to just maintain DataclassBatch and have it subclass torchrec.datasets.utils.Batch.
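For reference, a minimal sketch of what a Pipelineable batch base class typically provides; the field-handling approach and names here are illustrative assumptions, not this PR's implementation:

```python
import abc
from dataclasses import dataclass, fields
from typing import Dict

import torch
from torchrec.streamable import Pipelineable


class BatchBase(Pipelineable, abc.ABC):
    """Generic batch: subclasses expose their tensor fields via as_dict()."""

    @abc.abstractmethod
    def as_dict(self) -> Dict[str, torch.Tensor]:
        ...

    def to(self, device: torch.device, non_blocking: bool = False) -> "BatchBase":
        # Rebuild the batch with every field moved to the target device.
        moved = {k: v.to(device=device, non_blocking=non_blocking) for k, v in self.as_dict().items()}
        return type(self)(**moved)

    def record_stream(self, stream: torch.cuda.Stream) -> None:
        for v in self.as_dict().values():
            v.record_stream(stream)


@dataclass
class DataclassBatch(BatchBase):
    """Concrete batch whose fields come from the dataclass definition."""

    def as_dict(self) -> Dict[str, torch.Tensor]:
        return {f.name: getattr(self, f.name) for f in fields(self)}
```

If DataclassBatch instead subclassed torchrec.datasets.utils.Batch as suggested above, that class already implements the equivalent to()/record_stream() plumbing, but only for its fixed dense/sparse/label fields.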



class LargeEmbeddingLookup(nn.Module):
    def __init__(self, embeddings_config: List[torchrec.EmbeddingBagConfig]):
Collaborator:

nit: should embeddings_config be called tables?

        super().__init__()
        self.ebc = torchrec.EmbeddingBagCollection(
            tables=embeddings_config,
            device=torch.device("meta"),
Collaborator:

For my own knowledge, trying to understand: will this always be "meta"?

        )
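For context, a small sketch of the usual torchrec pattern this appears to follow: tables are declared on the "meta" device (no memory is allocated), and the real, sharded parameters are only materialized when the module is later wrapped in DistributedModelParallel. The table name and sizes below are made up for illustration:

```python
import torch
import torchrec

# Hypothetical table config, just to show the shape of the API.
tables = [
    torchrec.EmbeddingBagConfig(
        name="node_emb",
        embedding_dim=128,
        num_embeddings=1_000_000,
        feature_names=["node_id"],
    )
]

# "meta" tensors carry only shape/dtype; DistributedModelParallel later decides
# where each shard actually lives and allocates it there.
ebc = torchrec.EmbeddingBagCollection(tables=tables, device=torch.device("meta"))
```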


class GcsIterableDataset(torch.utils.data.IterableDataset):
Collaborator:

I like these abstractions!
Albeit we probably don't need so many classes - the right one can be auto-inferred from the file suffix, etc. We can revisit these though.
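A sketch of the suffix-based dispatch idea, using plain local files and stdlib readers for simplicity (a GCS path would go through a GCS client, but the dispatch mechanism is the same):

```python
import csv
import json
from typing import Callable, Dict, Iterator

import torch


def _read_csv_rows(path: str) -> Iterator[dict]:
    with open(path, newline="") as f:
        yield from csv.DictReader(f)


def _read_jsonl_rows(path: str) -> Iterator[dict]:
    with open(path) as f:
        for line in f:
            yield json.loads(line)


class SuffixIterableDataset(torch.utils.data.IterableDataset):
    """One dataset class; the row reader is picked from the file suffix."""

    READERS: Dict[str, Callable[[str], Iterator[dict]]] = {
        ".csv": _read_csv_rows,
        ".jsonl": _read_jsonl_rows,
    }

    def __init__(self, path: str):
        super().__init__()
        suffix = path[path.rfind(".") :] if "." in path else ""
        if suffix not in self.READERS:
            raise ValueError(f"No reader registered for suffix {suffix!r}")
        self.path = path
        self._reader = self.READERS[suffix]

    def __iter__(self) -> Iterator[dict]:
        yield from self._reader(self.path)
```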

otherwise the model itself.
"""

if torch.distributed.is_initialized():
Collaborator:

I'm assuming this also only works on the NCCL backend?
i.e. both CUDA needs to be available and the NCCL backend enabled?
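For reference, a generic sketch (not this PR's init code) of how that assumption is usually made explicit: NCCL requires CUDA, so a CPU-only run falls back to Gloo.

```python
import torch
import torch.distributed as dist

if not dist.is_initialized():
    # NCCL requires CUDA; Gloo works on CPU-only hosts (e.g. local tests).
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    # Rank/world size come from the usual env:// variables (MASTER_ADDR, RANK, ...).
    dist.init_process_group(backend=backend)
```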

# Build a sharding plan
logger.info("***** Wrapping in DistributedModelParallel *****")
logger.info(f"Model before wrapping: {model}")
model = DistributedModelParallel(
Collaborator:

QQ: I see there are a lot more configurable params here.
Curious why we only chose to parameterize sharding_plan?

return model
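For comparison, DistributedModelParallel exposes several other knobs besides the plan; a hedged sketch of the commonly used ones (the wrapper function and defaults below are illustrative, not this PR's code):

```python
import torch
from torchrec.distributed.model_parallel import DistributedModelParallel


def wrap_model(model: torch.nn.Module, sharding_plan=None) -> DistributedModelParallel:
    device = (
        torch.device("cuda", torch.cuda.current_device())
        if torch.cuda.is_available()
        else torch.device("cpu")
    )
    return DistributedModelParallel(
        module=model,        # unsharded model whose tables live on the meta device
        device=device,       # where non-sharded modules are materialized
        plan=sharding_plan,  # if None, DMP builds a default plan internally
        # other knobs: env=..., sharders=..., init_data_parallel=...
    )
```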


def get_sharding_plan(
Collaborator:

I am surprised DistributedModelParallel doesn't do this for us already?
Is there something here that I am not seeing that would be model/business-logic specific?
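For context on the question above: DistributedModelParallel will generate a default plan when plan=None, but building one explicitly lets you pin the topology and constraints. A rough sketch of what such a helper often looks like (topology values are placeholders; this is not necessarily the PR's version):

```python
import torch
import torch.distributed as dist
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology


def get_sharding_plan(model: torch.nn.Module, world_size: int, local_world_size: int):
    planner = EmbeddingShardingPlanner(
        topology=Topology(
            world_size=world_size,
            local_world_size=local_world_size,
            compute_device="cuda",
        )
    )
    # collective_plan syncs the plan across ranks so every process shards identically.
    return planner.collective_plan(
        model, [EmbeddingBagCollectionSharder()], dist.GroupMember.WORLD
    )
```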

        if self.optimizer and state_dict.get(self.OPTIMIZER_KEY):
            self.optimizer.load_state_dict(state_dict[self.OPTIMIZER_KEY])

    def to_state_dict(self) -> STATE_DICT_TYPE:
Collaborator:

This function doesn't seem to follow the traditional Stateful protocol.
Curious where it is used?
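For reference, the conventional Stateful protocol (as used by torch.distributed.checkpoint) expects state_dict()/load_state_dict() method names, so a to_state_dict() method would have to be called explicitly by the caller. A minimal sketch of the conventional shape (attribute names are assumptions):

```python
class TrainState:
    """Follows the Stateful duck-typed protocol: state_dict / load_state_dict."""

    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self) -> dict:
        return {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }

    def load_state_dict(self, state_dict: dict) -> None:
        self.model.load_state_dict(state_dict["model"])
        self.optimizer.load_state_dict(state_dict["optimizer"])
```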

local_uri = (
    checkpoint_id
    if isinstance(checkpoint_id, LocalUri)
    else LocalUri(tempfile.mkdtemp(prefix="checkpoint"))
Collaborator:

Very much a nit.

The user of mkdtemp() is responsible for deleting the temporary directory and its contents when done with it.

One thing we can do is always write to a tmpdir, then copy to either the local or GCS dir, then delete the local dir.
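A sketch of that write-to-tmp-then-copy idea using the stdlib's TemporaryDirectory, which cleans up even on errors (this is a generic pattern, not the PR's checkpoint writer; a gs:// destination would use the GCS client for the copy):

```python
import shutil
import tempfile
from pathlib import Path

import torch


def save_then_publish(state_dict: dict, destination: str) -> None:
    """Write to a throwaway tmp dir first, copy to the final location, then clean up."""
    with tempfile.TemporaryDirectory(prefix="checkpoint") as tmp:
        torch.save(state_dict, Path(tmp) / "state.pt")
        # For a gs:// destination this copy would go through the GCS client instead.
        shutil.copytree(tmp, destination, dirs_exist_ok=True)
        # TemporaryDirectory removes tmp (and its contents) on exit, even on error.
```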
