Add some 'shared' KGE utils #233
base: main
Conversation
Still working through this, but approving to unblock
from torchrec.streamable import Pipelineable


class BatchBase(Pipelineable, abc.ABC):
Does this base class get used anywhere except DataclassBatch? If not, it might be easier for us to just maintain DataclassBatch, which subclasses torchrec.datasets.utils.Batch.
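For context on the trade-off, here is a minimal sketch of what the Pipelineable contract asks of a batch class, and therefore what a shared BatchBase would presumably centralize; the field names below are illustrative, not taken from the PR:

import dataclasses

import torch
from torchrec.streamable import Pipelineable


@dataclasses.dataclass
class ExampleBatch(Pipelineable):
    # Illustrative batch: Pipelineable requires `to` and `record_stream` so the
    # train pipeline can move/prefetch batches across devices and CUDA streams.
    features: torch.Tensor
    labels: torch.Tensor

    def to(self, device: torch.device, non_blocking: bool = False) -> "ExampleBatch":
        return ExampleBatch(
            features=self.features.to(device=device, non_blocking=non_blocking),
            labels=self.labels.to(device=device, non_blocking=non_blocking),
        )

    def record_stream(self, stream: torch.cuda.Stream) -> None:
        self.features.record_stream(stream)
        self.labels.record_stream(stream)

    def pin_memory(self) -> "ExampleBatch":
        # Not part of Pipelineable itself, but conventional for batches used
        # with pinned-memory DataLoaders (torchrec's Batch defines it too).
        return ExampleBatch(
            features=self.features.pin_memory(),
            labels=self.labels.pin_memory(),
        )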
class LargeEmbeddingLookup(nn.Module):
    def __init__(self, embeddings_config: List[torchrec.EmbeddingBagConfig]):
nit: should embeddings_config be called tables?
        super().__init__()
        self.ebc = torchrec.EmbeddingBagCollection(
            tables=embeddings_config,
            device=torch.device("meta"),
For my own knowledge, trying to understand: will this always be "meta"?
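For what it's worth, the usual torchrec pattern is to declare the EmbeddingBagCollection on "meta" so no storage is allocated until DistributedModelParallel shards it and materializes each shard on its target device; a sketch of that pattern, with the table name and sizes made up:

import torch
import torchrec

# Table configs; the name and sizes here are illustrative only.
tables = [
    torchrec.EmbeddingBagConfig(
        name="entity_embeddings",
        embedding_dim=128,
        num_embeddings=1_000_000,
        feature_names=["entity_id"],
    ),
]

# On the "meta" device, parameters carry shape/dtype but no storage, so very
# large tables can be declared cheaply; DistributedModelParallel later decides
# where each shard actually lives and materializes it there.
ebc = torchrec.EmbeddingBagCollection(
    tables=tables,
    device=torch.device("meta"),
)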
        )


class GcsIterableDataset(torch.utils.data.IterableDataset):
I like these abstractions! Although we probably don't need so many classes - the format can be auto-inferred from the file suffix, etc. We can revisit these, though.
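To illustrate the suffix-based inference suggested above, a rough sketch; the factory and the per-format dataset classes are hypothetical stand-ins, not the PR's actual API:

import torch


def make_iterable_dataset(uri: str, **kwargs) -> torch.utils.data.IterableDataset:
    # Hypothetical factory: pick a reader from the file suffix instead of
    # asking callers to choose among several dataset classes.
    if uri.endswith(".parquet"):
        return ParquetIterableDataset(uri, **kwargs)   # hypothetical class
    if uri.endswith(".tfrecord"):
        return TFRecordIterableDataset(uri, **kwargs)  # hypothetical class
    if uri.endswith((".csv", ".tsv")):
        return CsvIterableDataset(uri, **kwargs)       # hypothetical class
    raise ValueError(f"Unrecognized file suffix in {uri!r}")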
    otherwise the model itself.
    """

    if torch.distributed.is_initialized():
I'm assuming this also only works on the NCCL backend? i.e. CUDA needs to be available and the NCCL backend enabled?
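For reference, torch.distributed itself can be initialized with either backend; the GPU collectives torchrec relies on do need NCCL (and hence CUDA), but a common guard looks roughly like this sketch:

import torch
import torch.distributed as dist

# NCCL requires CUDA; fall back to gloo for CPU-only runs.
backend = "nccl" if torch.cuda.is_available() else "gloo"

if not dist.is_initialized():
    # Assumes the usual launcher env vars (RANK, WORLD_SIZE, MASTER_ADDR/PORT) are set.
    dist.init_process_group(backend=backend)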
        # Build a sharding plan
        logger.info("***** Wrapping in DistributedModelParallel *****")
        logger.info(f"Model before wrapping: {model}")
        model = DistributedModelParallel(
QQ: I see there are a lot more configurable params here. Curious why we chose to parameterize only sharding_plan?
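For reference, DistributedModelParallel exposes several other knobs (env, device, sharders, init flags); a sketch of a fuller call, where the values are typical placeholders rather than what this PR actually passes:

import torch
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.model_parallel import DistributedModelParallel


def wrap_model(model: torch.nn.Module, plan, device: torch.device):
    # `plan` is the sharding plan this PR parameterizes; the other arguments
    # are the knobs the comment is pointing at, shown with typical values.
    return DistributedModelParallel(
        module=model,
        plan=plan,
        device=device,  # e.g. torch.device("cuda", local_rank)
        sharders=[EmbeddingBagCollectionSharder()],
        init_data_parallel=True,  # wrap remaining dense params in DDP
        init_parameters=True,     # materialize meta-device params after sharding
    )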
    return model


def get_sharding_plan(
I am surprised DistributedModelParallel doesn't do this for us already? Is there something here that I am not seeing that would be model/business-logic specific?
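My understanding (hedged) is that DistributedModelParallel does generate a plan on its own when plan=None; building one explicitly with the planner is mainly for customizing topology or constraints. A sketch of what that default-style planning looks like with the standard torchrec planner API:

import torch
import torch.distributed as dist
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology


def build_sharding_plan(model: torch.nn.Module, world_size: int, local_world_size: int):
    # Roughly what DistributedModelParallel does internally when no plan is given.
    topology = Topology(
        world_size=world_size,
        local_world_size=local_world_size,
        compute_device="cuda" if torch.cuda.is_available() else "cpu",
    )
    planner = EmbeddingShardingPlanner(topology=topology)
    # collective_plan has all ranks agree on a single plan over the process group.
    return planner.collective_plan(
        model, [EmbeddingBagCollectionSharder()], dist.GroupMember.WORLD
    )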
        if self.optimizer and state_dict.get(self.OPTIMIZER_KEY):
            self.optimizer.load_state_dict(state_dict[self.OPTIMIZER_KEY])

    def to_state_dict(self) -> STATE_DICT_TYPE:
This function doesn't seem to follow the traditional Stateful protocol. Curious where it is used?
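For comparison, the Stateful convention that torch.distributed.checkpoint expects is just paired state_dict()/load_state_dict() methods; a minimal sketch of that shape, assuming that is the protocol being referred to (OPTIMIZER_KEY is copied from the diff, everything else is illustrative):

from typing import Any, Dict, Optional

import torch


class CheckpointableComponent:
    # Minimal sketch of the Stateful convention: state_dict() returns the
    # serializable state, load_state_dict() restores it in place.
    OPTIMIZER_KEY = "optimizer"

    def __init__(self, model: torch.nn.Module, optimizer: Optional[torch.optim.Optimizer] = None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self) -> Dict[str, Any]:
        state = {"model": self.model.state_dict()}
        if self.optimizer is not None:
            state[self.OPTIMIZER_KEY] = self.optimizer.state_dict()
        return state

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.model.load_state_dict(state_dict["model"])
        if self.optimizer and state_dict.get(self.OPTIMIZER_KEY):
            self.optimizer.load_state_dict(state_dict[self.OPTIMIZER_KEY])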
        local_uri = (
            checkpoint_id
            if isinstance(checkpoint_id, LocalUri)
            else LocalUri(tempfile.mkdtemp(prefix="checkpoint"))
Very much a nit: the user of mkdtemp() is responsible for deleting the temporary directory and its contents when done with it. One thing we can do is always write to a tmpdir, then copy to either the local or GCS dir, then delete the local dir.
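A sketch of the write-to-tmpdir-then-copy-then-clean-up flow described above; upload_dir_to_gcs is a hypothetical placeholder for whatever GCS helper the PR uses:

import shutil
import tempfile


def save_checkpoint(write_fn, destination: str) -> None:
    # Hypothetical helper: stage the checkpoint in a temp dir, copy it to its
    # final location, and let the context manager delete the temp dir for us.
    with tempfile.TemporaryDirectory(prefix="checkpoint") as tmpdir:
        write_fn(tmpdir)  # writes checkpoint files into tmpdir
        if destination.startswith("gs://"):
            upload_dir_to_gcs(tmpdir, destination)  # hypothetical GCS upload helper
        else:
            shutil.copytree(tmpdir, destination, dirs_exist_ok=True)
    # tmpdir and its contents are removed automatically when the block exits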
These utils are used in KGE experimentation and are likely to be generically useful for other graph applications (outside KGE).
Where is the documentation for this feature?: N/A
Did you add automated tests or write a test plan?
Updated Changelog.md?: NO
Ready for code review?: YES