Skip to content

Conversation

kanghui0204
Copy link

@kanghui0204 kanghui0204 commented Nov 4, 2024

Hi TorchREC experts,

We would like to try incorporating NVIDIA HKV into the existing TorchREC workflow to extend TorchREC's capabilities for model-parallel dynamic embedding.

We aim to integrate HKV dynamic embedding as a new type of embedding table into the TorchREC workflow. To avoid disrupting the original TorchREC code, we have designed some code for registering new embedding tables, which will help us and other users to better register a customized embedding table into the TorchREC workflow. Our modifications mainly target the following two parts:

  1. Registering a new customized compute table during the creation of the embedding table and lookup, and accepting its customized parameters.

  2. Since the range of indices for dynamic embedding is unlimited, we need the input distribution to perform round-robin distribution.(Our current PR serves as a reference. For example, in the input dist section, we have only modified the RW code. However, it is necessary to support all sharding types, such as TWRW)

Our code is based on v0.7, and it can be easily migrated to the latest code. We are initiating this PR as a reference for further discussions with you. We hope to support a high-performance dynamic embedding feature.

@facebook-github-bot
Copy link
Contributor

Hi @kanghui0204!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@dstaay-fb
Copy link
Contributor

Thanks for proposal;

RE [1]: it would help to put together a toy example of what your doing potentially; so we can see how you intend to use this api; ideally to point you could create a few multi-gpu tests (with appropriate mocking if needed etc).

RE [2]: So its not well documented, but we actually can support round robin based RW sharding today; its utilized in ZCH workflows (bucketization strategy is % world_size). Basically if you pass in RwSparseFeaturesDist(.., feature_hash_dim = [0,....0]) this will trigger this logic. This calls into FBGEMM block_bucketization kernels. Coincidently just added logic in this area, take a look at tests in PR: #2538 - specifically the case we set input_hash_size=0 on ZCH modules for full behavior (albeit a different use case).

@kanghui0204
Copy link
Author

kanghui0204 commented Nov 7, 2024

Hi @dstaay-fb thank you very much for quickly reply!

RE1: I will prepare a example for you as a reference as soon as possible.
RE2: Sorry , I didn't find the test for input_hash_size=0 on ZCH modules in PR2538,

Do you mean that setting the hash size of each table to 0 will make the block_bucketize_sparse_features in FBGEMM switch from contiguous block partitioning to round-robin partitioning? It looks like we need to modify the information of sharding_infos input to BaseRwEmbeddingSharding(https://github.com/dstaay-fb/torchrec/blob/export-D62483238/torchrec/distributed/sharding/rw_sharding.py#L115), is that correct?

TroyGarden added a commit to TroyGarden/torchrec that referenced this pull request Apr 15, 2025
Summary:
# context
* NVIDIA dynamicemb package depends on an old TorchRec release (r0.7) plus a PR ([pytorch#2533](pytorch#2533))
* The goal is to refactor the PR ([pytorch#2533](pytorch#2533)) on trunk so that torchrec can accept customized kernel.

# design rationales
* Given the fact that the [`EmbeddingComputeKernel`](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/embedding_types.py#L64-L72) is a Enum class which can't be dynamically extended outside of TorchRec codebase, we are adding a placeholder type named `customized_kernel` for all customized compute kernels.
* `compute_kernel` is set in [ParameterSharding](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/types.py#L694), along with `sharding_type`, `sharding_specs`, etc. User can subclass the `ParameterSharding` dataclass to add more configs and parameters needed by the customized compute kernel, including something like `customized_compute_kernel` to specify the exact one in case there are many. 
* In order to propagate some [extra config](https://fburl.com/code/bnwp44sz) to the customized kernel, we add a `get_additional_fused_params` to propagate the params to `fused_params`. (we might consider to move the [`add_params_from_parameter_sharding`](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/utils.py#L359) function to a class function of ParameterSharding, so that the user can modify the function when necessary.

NOTE: `fused_params` is originally used for passing necessary parameters to the fbgemm lookup kernels (e.g., TBE, see below). It now seems to be just a convenient way of [propagating configs to the kernel from `ParametersSharding`](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/utils.py#L359).
```
(Pdb) group_fused_params
{'optimizer': <EmbOptimType.EXACT_ADAGRAD: 'exact_adagrad'>, 'learning_rate': 0.1}
```

* besides the lookup module, very often the customized kernel also needs a customized input_dist and/or a customized output_dist. they all come from [EmbeddingSharding](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/embedding_sharding.py#L964) and its [child classes](https://github.com/pytorch/torchrec/tree/main/torchrec/distributed/sharding) like cw_sharding, tw_sharding, etc.
* we make it public for the main API [`create_embedding_sharding`](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/embedding.py#L150) function that return a subclass of EmbeddingSharding, which further creates the user-defined input_dist, output_dist, lookup modules and so on.

WARNING: somehow the HKV-based customized compute kernel can't handle `_initialize_torch_state` likely due to the table.weight tensor is no long on the GPU, so it can't really be represented with sharded tensor or DTensor. It's the user's responsibility to correctly handle the state_dict by overriding the `_initialize_torch_state` function.

Differential Revision: D70723583
TroyGarden added a commit to TroyGarden/torchrec that referenced this pull request Apr 15, 2025
Summary:

# context
* NVIDIA dynamicemb package depends on an old TorchRec release (r0.7) plus a PR ([pytorch#2533](pytorch#2533))
* The goal is to refactor the PR ([pytorch#2533](pytorch#2533)) on trunk so that torchrec can accept customized kernel.

# design rationales
* Given the fact that the [`EmbeddingComputeKernel`](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/embedding_types.py#L64-L72) is a Enum class which can't be dynamically extended outside of TorchRec codebase, we are adding a placeholder type named `customized_kernel` for all customized compute kernels.
* `compute_kernel` is set in [ParameterSharding](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/types.py#L694), along with `sharding_type`, `sharding_specs`, etc. User can subclass the `ParameterSharding` dataclass to add more configs and parameters needed by the customized compute kernel, including something like `customized_compute_kernel` to specify the exact one in case there are many. 
* In order to propagate some [extra config](https://fburl.com/code/bnwp44sz) to the customized kernel, we add a `get_additional_fused_params` to propagate the params to `fused_params`. (we might consider to move the [`add_params_from_parameter_sharding`](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/utils.py#L359) function to a class function of ParameterSharding, so that the user can modify the function when necessary.

NOTE: `fused_params` is originally used for passing necessary parameters to the fbgemm lookup kernels (e.g., TBE, see below). It now seems to be just a convenient way of [propagating configs to the kernel from `ParametersSharding`](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/utils.py#L359).
```
(Pdb) group_fused_params
{'optimizer': <EmbOptimType.EXACT_ADAGRAD: 'exact_adagrad'>, 'learning_rate': 0.1}
```

* besides the lookup module, very often the customized kernel also needs a customized input_dist and/or a customized output_dist. they all come from [EmbeddingSharding](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/embedding_sharding.py#L964) and its [child classes](https://github.com/pytorch/torchrec/tree/main/torchrec/distributed/sharding) like cw_sharding, tw_sharding, etc.
* we make it public for the main API [`create_embedding_sharding`](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/embedding.py#L150) function that return a subclass of EmbeddingSharding, which further creates the user-defined input_dist, output_dist, lookup modules and so on.

WARNING: somehow the HKV-based customized compute kernel can't handle `_initialize_torch_state` likely due to the table.weight tensor is no long on the GPU, so it can't really be represented with sharded tensor or DTensor. It's the user's responsibility to correctly handle the state_dict by overriding the `_initialize_torch_state` function.

Reviewed By: dstaay-fb

Differential Revision: D70723583
TroyGarden added a commit to TroyGarden/torchrec that referenced this pull request Apr 15, 2025
Summary:

# context
* NVIDIA dynamicemb package depends on an old TorchRec release (r0.7) plus a PR ([pytorch#2533](pytorch#2533))
* The goal is to refactor the PR ([pytorch#2533](pytorch#2533)) on trunk so that torchrec can accept customized kernel.

# design rationales
* Given the fact that the [`EmbeddingComputeKernel`](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/embedding_types.py#L64-L72) is a Enum class which can't be dynamically extended outside of TorchRec codebase, we are adding a placeholder type named `customized_kernel` for all customized compute kernels.
* `compute_kernel` is set in [ParameterSharding](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/types.py#L694), along with `sharding_type`, `sharding_specs`, etc. User can subclass the `ParameterSharding` dataclass to add more configs and parameters needed by the customized compute kernel, including something like `customized_compute_kernel` to specify the exact one in case there are many. 
* In order to propagate some [extra config](https://fburl.com/code/bnwp44sz) to the customized kernel, we add a `get_additional_fused_params` to propagate the params to `fused_params`. (we might consider to move the [`add_params_from_parameter_sharding`](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/utils.py#L359) function to a class function of ParameterSharding, so that the user can modify the function when necessary.

NOTE: `fused_params` is originally used for passing necessary parameters to the fbgemm lookup kernels (e.g., TBE, see below). It now seems to be just a convenient way of [propagating configs to the kernel from `ParametersSharding`](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/utils.py#L359).
```
(Pdb) group_fused_params
{'optimizer': <EmbOptimType.EXACT_ADAGRAD: 'exact_adagrad'>, 'learning_rate': 0.1}
```

* besides the lookup module, very often the customized kernel also needs a customized input_dist and/or a customized output_dist. they all come from [EmbeddingSharding](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/embedding_sharding.py#L964) and its [child classes](https://github.com/pytorch/torchrec/tree/main/torchrec/distributed/sharding) like cw_sharding, tw_sharding, etc.
* we make it public for the main API [`create_embedding_sharding`](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/embedding.py#L150) function that return a subclass of EmbeddingSharding, which further creates the user-defined input_dist, output_dist, lookup modules and so on.

WARNING: somehow the HKV-based customized compute kernel can't handle `_initialize_torch_state` likely due to the table.weight tensor is no long on the GPU, so it can't really be represented with sharded tensor or DTensor. It's the user's responsibility to correctly handle the state_dict by overriding the `_initialize_torch_state` function.

Reviewed By: dstaay-fb

Differential Revision: D70723583
facebook-github-bot pushed a commit that referenced this pull request Apr 16, 2025
Summary:
Pull Request resolved: #2887

# context
* NVIDIA dynamicemb package depends on an old TorchRec release (r0.7) plus a PR ([#2533](#2533))
* The goal is to refactor the PR ([#2533](#2533)) on trunk so that torchrec can accept customized kernel.

# design rationales
* Given the fact that the [`EmbeddingComputeKernel`](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/embedding_types.py#L64-L72) is a Enum class which can't be dynamically extended outside of TorchRec codebase, we are adding a placeholder type named `customized_kernel` for all customized compute kernels.
* `compute_kernel` is set in [ParameterSharding](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/types.py#L694), along with `sharding_type`, `sharding_specs`, etc. User can subclass the `ParameterSharding` dataclass to add more configs and parameters needed by the customized compute kernel, including something like `customized_compute_kernel` to specify the exact one in case there are many.
* In order to propagate some [extra config](https://fburl.com/code/bnwp44sz) to the customized kernel, we add a `get_additional_fused_params` to propagate the params to `fused_params`. (we might consider to move the [`add_params_from_parameter_sharding`](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/utils.py#L359) function to a class function of ParameterSharding, so that the user can modify the function when necessary.

NOTE: `fused_params` is originally used for passing necessary parameters to the fbgemm lookup kernels (e.g., TBE, see below). It now seems to be just a convenient way of [propagating configs to the kernel from `ParametersSharding`](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/utils.py#L359).
```
(Pdb) group_fused_params
{'optimizer': <EmbOptimType.EXACT_ADAGRAD: 'exact_adagrad'>, 'learning_rate': 0.1}
```

* besides the lookup module, very often the customized kernel also needs a customized input_dist and/or a customized output_dist. they all come from [EmbeddingSharding](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/embedding_sharding.py#L964) and its [child classes](https://github.com/pytorch/torchrec/tree/main/torchrec/distributed/sharding) like cw_sharding, tw_sharding, etc.
* we make it public for the main API [`create_embedding_sharding`](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/embedding.py#L150) function that return a subclass of EmbeddingSharding, which further creates the user-defined input_dist, output_dist, lookup modules and so on.

WARNING: somehow the HKV-based customized compute kernel can't handle `_initialize_torch_state` likely due to the table.weight tensor is no long on the GPU, so it can't really be represented with sharded tensor or DTensor. It's the user's responsibility to correctly handle the state_dict by overriding the `_initialize_torch_state` function.

Reviewed By: dstaay-fb

Differential Revision: D70723583

fbshipit-source-id: a86f2b59221a0dcbfe0577f6ecb2afd58c91207f
@TroyGarden TroyGarden closed this Jun 20, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants