Skip to content

Commit 46280e5

Browse files
Chilleefacebook-github-bot
authored andcommitted
Added remove_duplicate parameter to nn.Module (#39)
Summary: Pull Request resolved: #39 Pull Request resolved: #6 This makes it so that shared parameters get their own entry in `named_parameters`. More broadly, this makes it so that ``` params_and_buffers = {**mod.named_named_parameters(remove_duplicate=False), **mod.named_buffers(remove_duplicate=False)} _stateless.functional_call(mod, params_and_buffers, args, kwargs) ``` is identical to calling the original module's forwards pass. cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang Pull Request resolved: pytorch/pytorch#71542 Reviewed By: jbschlosser, albanD Differential Revision: D33716716 Pulled By: Chillee fbshipit-source-id: 056727167dc1206dfd0f15a8a69cea3a95ce3ac4
1 parent 0aef069 commit 46280e5

8 files changed

+48
-31
lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -424,12 +424,12 @@ def fused_optimizer(self) -> FusedOptimizer:
424424
return self._optim
425425

426426
def named_parameters(
427-
self, prefix: str = "", recurse: bool = True
427+
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
428428
) -> Iterator[Tuple[str, nn.Parameter]]:
429429
yield from ()
430430

431431
def named_buffers(
432-
self, prefix: str = "", recurse: bool = True
432+
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
433433
) -> Iterator[Tuple[str, torch.Tensor]]:
434434
for config, param in zip(
435435
self._config.embedding_tables,
@@ -471,7 +471,7 @@ def emb_module(
471471
return self._emb_module
472472

473473
def named_parameters(
474-
self, prefix: str = "", recurse: bool = True
474+
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
475475
) -> Iterator[Tuple[str, nn.Parameter]]:
476476
combined_key = "/".join(
477477
[config.name for config in self._config.embedding_tables]
@@ -678,12 +678,12 @@ def fused_optimizer(self) -> FusedOptimizer:
678678
return self._optim
679679

680680
def named_parameters(
681-
self, prefix: str = "", recurse: bool = True
681+
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
682682
) -> Iterator[Tuple[str, nn.Parameter]]:
683683
yield from ()
684684

685685
def named_buffers(
686-
self, prefix: str = "", recurse: bool = True
686+
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
687687
) -> Iterator[Tuple[str, torch.Tensor]]:
688688
for config, param in zip(
689689
self._config.embedding_tables,
@@ -725,7 +725,7 @@ def emb_module(
725725
return self._emb_module
726726

727727
def named_parameters(
728-
self, prefix: str = "", recurse: bool = True
728+
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
729729
) -> Iterator[Tuple[str, nn.Parameter]]:
730730
combined_key = "/".join(
731731
[config.name for config in self._config.embedding_tables]

torchrec/distributed/embedding_kernel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def state_dict(
174174
)
175175

176176
def named_parameters(
177-
self, prefix: str = "", recurse: bool = True
177+
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
178178
) -> Iterator[Tuple[str, nn.Parameter]]:
179179
for config, emb_module in zip(
180180
self._config.embedding_tables,
@@ -320,7 +320,7 @@ def state_dict(
320320
)
321321

322322
def named_parameters(
323-
self, prefix: str = "", recurse: bool = True
323+
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
324324
) -> Iterator[Tuple[str, nn.Parameter]]:
325325
for config, emb_module in zip(
326326
self._config.embedding_tables,

torchrec/distributed/embedding_lookup.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -184,13 +184,13 @@ def load_state_dict(
184184
return _IncompatibleKeys(missing_keys=m, unexpected_keys=u)
185185

186186
def named_parameters(
187-
self, prefix: str = "", recurse: bool = True
187+
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
188188
) -> Iterator[Tuple[str, nn.Parameter]]:
189189
for emb_module in self._emb_modules:
190190
yield from emb_module.named_parameters(prefix, recurse)
191191

192192
def named_buffers(
193-
self, prefix: str = "", recurse: bool = True
193+
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
194194
) -> Iterator[Tuple[str, torch.Tensor]]:
195195
for emb_module in self._emb_modules:
196196
yield from emb_module.named_buffers(prefix, recurse)
@@ -370,15 +370,15 @@ def load_state_dict(
370370
return _IncompatibleKeys(missing_keys=m1 + m2, unexpected_keys=u1 + u2)
371371

372372
def named_parameters(
373-
self, prefix: str = "", recurse: bool = True
373+
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
374374
) -> Iterator[Tuple[str, nn.Parameter]]:
375375
for emb_module in self._emb_modules:
376376
yield from emb_module.named_parameters(prefix, recurse)
377377
for emb_module in self._score_emb_modules:
378378
yield from emb_module.named_parameters(prefix, recurse)
379379

380380
def named_buffers(
381-
self, prefix: str = "", recurse: bool = True
381+
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
382382
) -> Iterator[Tuple[str, torch.Tensor]]:
383383
for emb_module in self._emb_modules:
384384
yield from emb_module.named_buffers(prefix, recurse)
@@ -466,13 +466,13 @@ def load_state_dict(
466466
)
467467

468468
def named_parameters(
469-
self, prefix: str = "", recurse: bool = True
469+
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
470470
) -> Iterator[Tuple[str, nn.Parameter]]:
471471
for rank_modules in self._embedding_lookups_per_rank:
472472
yield from rank_modules.named_parameters(prefix, recurse)
473473

474474
def named_buffers(
475-
self, prefix: str = "", recurse: bool = True
475+
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
476476
) -> Iterator[Tuple[str, torch.Tensor]]:
477477
for rank_modules in self._embedding_lookups_per_rank:
478478
yield from rank_modules.named_buffers(prefix, recurse)

torchrec/distributed/embeddingbag.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,7 @@ def named_modules(
441441
yield from [(prefix, self)]
442442

443443
def named_parameters(
444-
self, prefix: str = "", recurse: bool = True
444+
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
445445
) -> Iterator[Tuple[str, nn.Parameter]]:
446446
for lookup in self._lookups:
447447
yield from lookup.named_parameters(
@@ -460,7 +460,7 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
460460
yield name
461461

462462
def named_buffers(
463-
self, prefix: str = "", recurse: bool = True
463+
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
464464
) -> Iterator[Tuple[str, torch.Tensor]]:
465465
for lookup in self._lookups:
466466
yield from lookup.named_buffers(
@@ -731,7 +731,7 @@ def named_modules(
731731
yield from [(prefix, self)]
732732

733733
def named_parameters(
734-
self, prefix: str = "", recurse: bool = True
734+
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
735735
) -> Iterator[Tuple[str, nn.Parameter]]:
736736
for name, parameter in self._lookup.named_parameters("", recurse):
737737
# update name to match embeddingBag parameter name
@@ -745,7 +745,7 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
745745
yield append_prefix(prefix, name.split(".")[-1])
746746

747747
def named_buffers(
748-
self, prefix: str = "", recurse: bool = True
748+
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
749749
) -> Iterator[Tuple[str, torch.Tensor]]:
750750
for name, buffer in self._lookup.named_buffers("", recurse):
751751
yield append_prefix(prefix, name.split(".")[-1]), buffer

torchrec/distributed/grouped_position_weighted.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,13 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor:
6969
)
7070

7171
def named_parameters(
72-
self, prefix: str = "", recurse: bool = True
72+
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
7373
) -> Iterator[Tuple[str, nn.Parameter]]:
7474
for name, param in self.position_weights.items():
7575
yield append_prefix(prefix, f"position_weights.{name}"), param
7676

7777
def named_buffers(
78-
self, prefix: str = "", recurse: bool = True
78+
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
7979
) -> Iterator[Tuple[str, torch.Tensor]]:
8080
yield from ()
8181

torchrec/distributed/model_parallel.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,11 @@ def _load_state_dict(
396396
)
397397

398398
def _named_parameters(
399-
self, module: nn.Module, prefix: str = "", recurse: bool = True
399+
self,
400+
module: nn.Module,
401+
prefix: str = "",
402+
recurse: bool = True,
403+
remove_duplicate: bool = True,
400404
) -> Iterator[Tuple[str, torch.nn.Parameter]]:
401405
if isinstance(module, ShardedModule):
402406
yield from module.named_parameters(prefix, recurse)
@@ -408,9 +412,11 @@ def _named_parameters(
408412
)
409413

410414
def named_parameters(
411-
self, prefix: str = "", recurse: bool = True
415+
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
412416
) -> Iterator[Tuple[str, torch.nn.Parameter]]:
413-
yield from self._named_parameters(self.dmp_module, prefix, recurse)
417+
yield from self._named_parameters(
418+
self.dmp_module, prefix, recurse, remove_duplicate
419+
)
414420

415421
@staticmethod
416422
def _sharded_parameter_names(module: nn.Module, prefix: str = "") -> Iterator[str]:
@@ -423,21 +429,32 @@ def _sharded_parameter_names(module: nn.Module, prefix: str = "") -> Iterator[st
423429
)
424430

425431
def _named_buffers(
426-
self, module: nn.Module, prefix: str = "", recurse: bool = True
432+
self,
433+
module: nn.Module,
434+
prefix: str = "",
435+
recurse: bool = True,
436+
remove_duplicate: bool = True,
427437
) -> Iterator[Tuple[str, torch.Tensor]]:
428438
if isinstance(module, ShardedModule):
429-
yield from module.named_buffers(prefix, recurse)
439+
yield from module.named_buffers(prefix, recurse, remove_duplicate)
430440
else:
431-
yield from module.named_buffers(prefix, recurse=False)
441+
yield from module.named_buffers(
442+
prefix, recurse=False, remove_duplicate=True
443+
)
432444
for name, child in module.named_children():
433445
yield from self._named_buffers(
434-
child, append_prefix(prefix, name), recurse
446+
child, append_prefix(prefix, name), recurse, remove_duplicate
435447
)
436448

437449
def named_buffers(
438-
self, prefix: str = "", recurse: bool = True
450+
self,
451+
prefix: str = "",
452+
recurse: bool = True,
453+
remove_duplicate: bool = True,
439454
) -> Iterator[Tuple[str, torch.Tensor]]:
440-
yield from self._named_buffers(self.dmp_module, prefix, recurse)
455+
yield from self._named_buffers(
456+
self.dmp_module, prefix, recurse, remove_duplicate
457+
)
441458

442459
@property
443460
def fused_optimizer(self) -> KeyedOptimizer:

torchrec/distributed/quant_embedding_kernel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
103103
)
104104

105105
def named_buffers(
106-
self, prefix: str = "", recurse: bool = True
106+
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
107107
) -> Iterator[Tuple[str, torch.Tensor]]:
108108
for config, weight in zip(
109109
self._config.embedding_tables,

torchrec/quant/embedding_modules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def state_dict(
211211
return destination
212212

213213
def named_buffers(
214-
self, prefix: str = "", recurse: bool = True
214+
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
215215
) -> Iterator[Tuple[str, nn.Parameter]]:
216216
state_dict = self.state_dict(prefix=prefix, keep_vars=True)
217217
for key, value in state_dict.items():

0 commit comments

Comments
 (0)