
Commit 72b814e

emlin authored and facebook-github-bot committed
fix non sharding model publish (#3333)
Summary:
Pull Request resolved: #3333

Since the ZCH weight tensor includes extra metadata in the checkpoint, we need to shift the start and end column indices used for weight processing. Currently, weight processing relies on the DI sharding pass to generate shifted_weight_shard, but not every model has DI sharding. This diff adds a default behavior that generates shifted_weight_shard when it is empty and the tensor is a ZCH weight tensor.

Reviewed By: EddyLXJ

Differential Revision: D80434727

fbshipit-source-id: a40903217e4c661780915b3675bc25bca0f4b9c0
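A minimal sketch of the fallback described in the summary above. The publish-path code itself is not part of this diff, so the names here (WeightSpec, shard_offsets, is_zch_weight, ensure_shifted_weight_shard) are hypothetical stand-ins for the idea: when the DI sharding pass has not populated shifted_weight_shard, derive it by offsetting the column range past the eviction metaheader.

# Hypothetical sketch only -- the real publish code is not shown in this diff.
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class WeightSpec:  # hypothetical stand-in for the real weight spec
    shard_offsets: Tuple[int, int]  # (start_col, end_col) of the weight shard
    shifted_weight_shard: Optional[Tuple[int, int]] = None
    is_zch_weight: bool = False
    meta_header_len: int = 0  # metaheader length in elements (dtype-dependent)


def ensure_shifted_weight_shard(spec: WeightSpec) -> WeightSpec:
    # DI sharding normally fills shifted_weight_shard; for models without
    # DI sharding, fall back to a default for ZCH weight tensors.
    if spec.shifted_weight_shard is None and spec.is_zch_weight:
        start_col, end_col = spec.shard_offsets
        # Shift past the metaheader columns that precede the embedding values.
        spec.shifted_weight_shard = (
            start_col + spec.meta_header_len,
            end_col + spec.meta_header_len,
        )
    return spec

The key point is that the shift amount is the metaheader length in elements, which depends on the weight dtype (see the embedding_configs.py change below).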
1 parent 61b7449 commit 72b814e

File tree

2 files changed: +21 -8 lines changed

torchrec/distributed/tests/test_infer_shardings.py
torchrec/modules/embedding_configs.py

torchrec/distributed/tests/test_infer_shardings.py

Lines changed: 15 additions & 7 deletions
@@ -380,7 +380,9 @@ def test_rw_with_virtual_table_eviction(
         batch_size = 4
         local_device = torch.device(f"{device_type}:0")
         eviction_policy = TimestampBasedEvictionPolicy()
-        eviction_policy.init_metaheader_config(dtype_to_data_type(torch.float16))
+        eviction_policy.init_metaheader_config(
+            dtype_to_data_type(torch.float16), emb_dim
+        )
         mi = create_test_model(
             num_embeddings,
             emb_dim,
@@ -392,6 +394,13 @@ def test_rw_with_virtual_table_eviction(
             weight_dtype=weight_dtype,
             virtual_table_eviction_policy=eviction_policy,
         )
+        for t in mi.tables:
+            self.assertIsNotNone(t.virtual_table_eviction_policy)
+            self.assertEqual(
+                # pyre-ignore [16]
+                t.virtual_table_eviction_policy.get_embedding_dim(),
+                emb_dim,
+            )
 
         non_sharded_model = mi.quant_model
         num_emb_half = num_embeddings // 2
@@ -430,19 +439,18 @@ def test_rw_with_virtual_table_eviction(
             ["table_0"],
             ShardingType.ROW_WISE.value,
         )
-        print(weights_spec)
-        assert (
+
+        self.assertIsNotNone(
             weights_spec[
                 "_module.sparse.ebc.tbes.0.0.table_0.weight"
             ].virtual_table_dim_offsets
-            is not None
         )
-        assert (
+        self.assertEqual(
             # pyre-ignore [16]
             weights_spec[
                 "_module.sparse.ebc.tbes.0.0.table_0.weight"
-            ].virtual_table_dim_offsets[0]
-            == 8
+            ].virtual_table_dim_offsets[0],
+            8,
         )
 
     @unittest.skipIf(
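For reference, a hedged usage sketch of the updated init_metaheader_config signature exercised by this test. It assumes TimestampBasedEvictionPolicy and dtype_to_data_type are importable from torchrec.modules.embedding_configs; the test's own import block is not shown in this diff.

import torch

# Assumed import location; not shown in this commit's diff.
from torchrec.modules.embedding_configs import (
    TimestampBasedEvictionPolicy,
    dtype_to_data_type,
)

emb_dim = 16  # example embedding dimension
policy = TimestampBasedEvictionPolicy()
# The embedding dimension is now passed alongside the training data type.
policy.init_metaheader_config(dtype_to_data_type(torch.float16), emb_dim)
assert policy.get_embedding_dim() == emb_dim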

torchrec/modules/embedding_configs.py

Lines changed: 6 additions & 1 deletion
@@ -172,24 +172,29 @@ def data_type_to_dtype(data_type: DataType) -> torch.dtype:
 class VirtualTableEvictionPolicy:
     # metadata header length in element size for virtual table in weight tensor value
     meta_header_len: int = 0
+    embedding_dim: int = 0
     initialized: bool = False
 
     """
     Eviction policy for virtual table.
     """
 
-    def init_metaheader_config(self, data_type: DataType) -> None:
+    def init_metaheader_config(self, data_type: DataType, embedding_dim: int) -> None:
         # the eviction metaheader is set for training data type only. Once initialized, we don't need to reinitialize again
         if self.initialized:
             return
         # 8 bytes for key, 4 bytes timestamp, 4 bytes shared by used and count: 1 bit for used, 31 bits for count
         # for more details, please refer to: https://github.com/pytorch/FBGEMM/pull/4187
         self.meta_header_len = 16 // data_type_to_dtype(data_type).itemsize
+        self.embedding_dim = embedding_dim
         self.initialized = True
 
     def get_meta_header_len(self) -> int:
         return self.meta_header_len
 
+    def get_embedding_dim(self) -> int:
+        return self.embedding_dim
+
 
 @dataclass
 class CountBasedEvictionPolicy(VirtualTableEvictionPolicy):
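A quick check of the metaheader arithmetic above, and of why the test expects virtual_table_dim_offsets[0] == 8: the 16-byte metaheader divided by the element size of the training dtype gives the header length in elements.

import torch

META_HEADER_BYTES = 16  # 8-byte key + 4-byte timestamp + 4 bytes used/count

# For float16 (2 bytes per element), the metaheader occupies 16 // 2 = 8
# elements, consistent with the test asserting
# virtual_table_dim_offsets[0] == 8.
meta_header_len = META_HEADER_BYTES // torch.float16.itemsize
assert meta_header_len == 8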
