torchrec change for dynamic embedding #2533
Conversation
Hi @kanghui0204! Thank you for your pull request and welcome to our community.

Action Required: In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process: In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged accordingly.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!
Thanks for the proposal.

RE [1]: it would help to put together a toy example of what you're doing, so we can see how you intend to use this API; ideally, to that point, you could create a few multi-GPU tests (with appropriate mocking if needed, etc.).

RE [2]: It's not well documented, but we actually can support round-robin based RW sharding today; it's utilized in ZCH workflows (the bucketization strategy is % world_size). Basically, if you pass in RwSparseFeaturesDist(.., feature_hash_dim=[0, ..., 0]) this will trigger that logic, which calls into the FBGEMM block_bucketization kernels. Coincidentally, I just added logic in this area; take a look at the tests in PR #2538, specifically the case where we set input_hash_size=0 on ZCH modules, for the full behavior (albeit a different use case).
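To make the zero-hash-size round-robin behavior concrete, here is a minimal sketch of the mapping described above, assuming the routing is a plain modulo over world size. `round_robin_bucketize` is a hypothetical helper for illustration, not a TorchRec API; the actual routing is performed by FBGEMM's block bucketization kernels, and the local-id remapping shown here is only illustrative.

```python
import torch

def round_robin_bucketize(ids: torch.Tensor, world_size: int):
    """Hypothetical helper: route each id to a rank by `id % world_size`.

    Mirrors the behavior described above for RwSparseFeaturesDist when every
    feature hash size is 0; the real path goes through FBGEMM's block
    bucketization kernels, and the local-id remapping may differ.
    """
    ranks = ids % world_size       # destination rank per id (round-robin)
    local_ids = ids // world_size  # illustrative id within the target shard
    return ranks, local_ids

ids = torch.tensor([0, 1, 2, 7, 8, 15])
ranks, local_ids = round_robin_bucketize(ids, world_size=4)
# ranks     -> tensor([0, 1, 2, 3, 0, 3])
# local_ids -> tensor([0, 0, 0, 1, 2, 3])
```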
Hi @dstaay-fb, thank you very much for the quick reply! RE1: I will prepare an example for you as a reference as soon as possible. Do you mean that setting the hash size of each table to 0 will make the
Summary: Pull Request resolved: #2887

# context
* NVIDIA's dynamicemb package depends on an old TorchRec release (r0.7) plus a PR ([#2533](#2533)).
* The goal is to refactor that PR ([#2533](#2533)) on trunk so that TorchRec can accept customized kernels.

# design rationales
* Since [`EmbeddingComputeKernel`](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/embedding_types.py#L64-L72) is an Enum class that can't be dynamically extended outside the TorchRec codebase, we add a placeholder value named `customized_kernel` for all customized compute kernels.
* `compute_kernel` is set in [ParameterSharding](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/types.py#L694), along with `sharding_type`, `sharding_specs`, etc. Users can subclass the `ParameterSharding` dataclass to add the extra configs and parameters the customized compute kernel needs, including something like `customized_compute_kernel` to specify the exact kernel when there are several.
* To propagate some [extra config](https://fburl.com/code/bnwp44sz) to the customized kernel, we add a `get_additional_fused_params` that forwards the params to `fused_params`. (We might consider moving the [`add_params_from_parameter_sharding`](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/utils.py#L359) function into ParameterSharding as a method, so users can override it when necessary.) NOTE: `fused_params` was originally used for passing necessary parameters to the FBGEMM lookup kernels (e.g., TBE, see below). It now seems to be just a convenient way of [propagating configs to the kernel from `ParameterSharding`](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/utils.py#L359).
```
(Pdb) group_fused_params
{'optimizer': <EmbOptimType.EXACT_ADAGRAD: 'exact_adagrad'>, 'learning_rate': 0.1}
```
* Besides the lookup module, the customized kernel very often also needs a customized input_dist and/or output_dist. These all come from [EmbeddingSharding](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/embedding_sharding.py#L964) and its [child classes](https://github.com/pytorch/torchrec/tree/main/torchrec/distributed/sharding) such as cw_sharding, tw_sharding, etc.
* We make the main API, the [`create_embedding_sharding`](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/embedding.py#L150) function, public; it returns a subclass of EmbeddingSharding, which in turn creates the user-defined input_dist, output_dist, lookup modules, and so on.

WARNING: the HKV-based customized compute kernel can't handle `_initialize_torch_state`, likely because the table.weight tensor is no longer on the GPU, so it can't really be represented as a sharded tensor or DTensor. It is the user's responsibility to handle the state_dict correctly by overriding the `_initialize_torch_state` function.

Reviewed By: dstaay-fb

Differential Revision: D70723583

fbshipit-source-id: a86f2b59221a0dcbfe0577f6ecb2afd58c91207f
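As a rough illustration of the subclassing pattern described in the summary above: a user-defined `ParameterSharding` subclass carries the extra config for the customized kernel and exposes it through `get_additional_fused_params`. The field names (`customized_compute_kernel`, `hkv_config`) and the exact method signature are assumptions for this sketch, not the committed API.

```python
from dataclasses import dataclass, field
from typing import Any, Dict

from torchrec.distributed.types import ParameterSharding


@dataclass
class DynamicEmbParameterSharding(ParameterSharding):
    # The placeholder enum value `customized_kernel` is shared by all
    # out-of-tree kernels, so the exact kernel is named here
    # (field name is hypothetical).
    customized_compute_kernel: str = "hkv"
    # Extra config the customized kernel needs (hypothetical example).
    hkv_config: Dict[str, Any] = field(default_factory=dict)

    def get_additional_fused_params(self) -> Dict[str, Any]:
        # Merged into fused_params and eventually handed to the lookup kernel.
        return {
            "customized_compute_kernel": self.customized_compute_kernel,
            **self.hkv_config,
        }
```

Under the same assumption, `add_params_from_parameter_sharding` (or whatever replaces it if it becomes a ParameterSharding method) would merge this dict into `fused_params` before the lookup module is constructed.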
Hi TorchREC experts,
We would like to try incorporating NVIDIA HKV into the existing TorchREC workflow to extend TorchREC's capabilities for model-parallel dynamic embedding.
We aim to integrate HKV dynamic embedding as a new type of embedding table in the TorchREC workflow. To avoid disrupting the original TorchREC code, we have designed a registration mechanism for new embedding tables, which should help us and other users plug a customized embedding table into the TorchREC workflow. Our modifications mainly target the following two parts:
1. Registering a new customized compute kernel during the creation of the embedding table and lookup modules, and accepting its customized parameters (a toy registry sketch follows this list).
2. Since the range of indices for dynamic embedding is unlimited, we need the input distribution to perform round-robin distribution. (Our current PR serves as a reference: for example, in the input dist section we have only modified the RW code, but all sharding types, such as TWRW, will need to be supported.)
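A toy sketch of the registration idea in item 1, purely for discussion: none of these names are existing TorchRec APIs, and the actual hooks in this PR may differ.

```python
from typing import Any, Callable, Dict

# Hypothetical registry mapping a customized kernel name to a factory that
# builds its lookup module.
_CUSTOM_KERNEL_REGISTRY: Dict[str, Callable[..., Any]] = {}


def register_custom_kernel(name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Decorator registering a lookup-module factory for a customized kernel."""
    def wrap(factory: Callable[..., Any]) -> Callable[..., Any]:
        _CUSTOM_KERNEL_REGISTRY[name] = factory
        return factory
    return wrap


@register_custom_kernel("hkv_dynamic")
def create_hkv_lookup(config: Dict[str, Any], device: str, fused_params: Dict[str, Any]) -> Any:
    # Would construct an HKV-backed lookup module here; omitted in this sketch.
    raise NotImplementedError


def create_lookup(kernel_name: str, *args: Any, **kwargs: Any) -> Any:
    # Sharding-time code would dispatch here instead of the built-in TBE path
    # when the compute kernel is a customized one.
    return _CUSTOM_KERNEL_REGISTRY[kernel_name](*args, **kwargs)
```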
Our code is based on v0.7 and can be easily migrated to the latest code. We are opening this PR as a reference for further discussion with you. We hope to support a high-performance dynamic embedding feature.