
Commit 0f41cea

TroyGarden authored and facebook-github-bot committed
change ir_custom_op output to list of tensors
Summary:

# context
* the original implementation of the "ir_custom_op" strategy has a logic flaw:
* it takes the sum of the dims as input, has the op return a single contiguous tensor, and then splits that tensor back into multiple tensors
* from the dynamic shape (ds) perspective, there is a sum(ds_i) before the op and another split back to (ds_i) afterwards; the range calculations for these ds are unnecessary and create a lot of complexity
* it is better to keep these ds transparent going into and coming out of the op

Differential Revision: D53558783
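As a minimal sketch of the two designs (plain Python with illustrative names and shapes, not the torchrec API), the old strategy materializes sum(dims) and then splits, while the new one emits one tensor per dim:

import torch
from typing import List

# Old strategy (sketch): the op returns a single contiguous (batch, sum(dims))
# tensor, which the caller then splits back into per-table tensors, so the
# graph carries a sum(ds_i) followed by a split back to (ds_i).
def old_style(batch: int, dims: List[int]) -> List[torch.Tensor]:
    contiguous = torch.empty(batch, sum(dims))
    return list(torch.split(contiguous, dims, dim=1))

# New strategy (sketch): the op returns one (batch, dim_i) tensor per table,
# so each dynamic shape passes through the op unchanged, with no sum/split
# range bookkeeping.
def new_style(batch: int, dims: List[int]) -> List[torch.Tensor]:
    return [torch.empty(batch, d) for d in dims]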
1 parent abd692a commit 0f41cea

2 files changed (+14 −14 lines)


torchrec/ir/serializer.py

Lines changed: 6 additions & 6 deletions
@@ -89,17 +89,17 @@ def ebc_meta_forward(
     features: KeyedJaggedTensor,
 ) -> KeyedTensor:
     batch_size = features.stride()
-    dim = sum(ebc._lengths_per_embedding)
+    dims = ebc._lengths_per_embedding
     arg_list = [
         features.values(),
         features.weights_or_none(),
         features.lengths_or_none(),
         features.offsets_or_none(),
     ]  # if want to include the weights: `+ [bag.weight for bag in self.embedding_bags.values()]`
-    output = torch.ops.torchrec.ir_custom_op(arg_list, batch_size, dim)
+    outputs = torch.ops.torchrec.ir_custom_op(arg_list, batch_size, dims)
     return KeyedTensor(
         keys=ebc._embedding_names,
-        values=output,
+        values=torch.cat(outputs, dim=1),
         length_per_key=ebc._lengths_per_embedding,
     )

@@ -110,17 +110,17 @@ def fpebc_meta_forward(
 ) -> KeyedTensor:
     batch_size = features.stride()
     ebc = fpebc._embedding_bag_collection
-    dim = sum(ebc._lengths_per_embedding)
+    dims = ebc._lengths_per_embedding
     arg_list = [
         features.values(),
         features.weights_or_none(),
         features.lengths_or_none(),
         features.offsets_or_none(),
     ]  # if want to include the weights: `+ [bag.weight for bag in self.embedding_bags.values()]`
-    output = torch.ops.torchrec.ir_custom_op(arg_list, batch_size, dim)
+    outputs = torch.ops.torchrec.ir_custom_op(arg_list, batch_size, dims)
     return KeyedTensor(
         keys=ebc._embedding_names,
-        values=output,
+        values=torch.cat(outputs, dim=1),
         length_per_key=ebc._lengths_per_embedding,
     )
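For illustration (the dims and batch size below are assumed, not from this commit): with per-embedding lengths [8, 16] and a batch size of 4, the op now returns one tensor per table, and torch.cat(outputs, dim=1) reproduces the same contiguous values tensor that the KeyedTensor previously received as a single output.

import torch

batch_size, dims = 4, [8, 16]                          # assumed example values
outputs = [torch.empty(batch_size, d) for d in dims]   # what ir_custom_op returns now
values = torch.cat(outputs, dim=1)                      # contiguous (4, 24), as before
assert values.shape == (batch_size, sum(dims))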

torchrec/ir/utils.py

Lines changed: 8 additions & 8 deletions
@@ -30,28 +30,28 @@

 @torch.library.custom_op("torchrec::ir_custom_op", mutates_args={})
 def ir_custom_op_impl(
-    tensors: List[Optional[torch.Tensor]], batch_size: int, dim: int
-) -> torch.Tensor:
+    tensors: List[Optional[torch.Tensor]], batch_size: int, dims: List[int]
+) -> List[torch.Tensor]:
     device = None
     for t in tensors:
         if t is not None:
             device = t.device
             break
-    logger.info(f"torch.ops.torchrec.ir_custom_op -> ({batch_size}, {dim}) {device}")
-    return torch.empty(batch_size, dim, device=device)
+    logger.info(f"torch.ops.torchrec.ir_custom_op -> ({batch_size}, {dims}) {device}")
+    return [torch.empty(batch_size, dim, device=device) for dim in dims]


 @torch.library.register_fake("torchrec::ir_custom_op")
 def ir_custom_op_fake(
-    tensors: List[Optional[torch.Tensor]], batch_size: int, dim: int
-) -> torch.Tensor:
+    tensors: List[Optional[torch.Tensor]], batch_size: int, dims: List[int]
+) -> List[torch.Tensor]:
     device = None
     for t in tensors:
         if t is not None:
             device = t.device
             break
-    logger.info(f"ir_custom_op_fake -> ({batch_size}, {dim}) {device}")
-    return torch.empty(batch_size, dim, device=device)
+    logger.info(f"ir_custom_op_fake -> ({batch_size}, {dims}) {device}")
+    return [torch.empty(batch_size, dim, device=device) for dim in dims]


 def encapsulate_ir_modules(
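Because the fake (meta) registration mirrors the new signature, the op can be traced without real data. A sanity-check sketch under FakeTensorMode, assuming torchrec.ir.utils has been imported so the op is registered and using made-up sizes, might look like:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
import torchrec.ir.utils  # noqa: F401  registers torchrec::ir_custom_op (assumed import path)

with FakeTensorMode():
    values = torch.empty(12)  # placeholder KJT values; contents never matter here
    outs = torch.ops.torchrec.ir_custom_op([values, None, None, None], 4, [8, 16])
    # One fake tensor per dim, each (batch_size, dim_i), with no sum/split involved.
    assert [tuple(o.shape) for o in outs] == [(4, 8), (4, 16)]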

0 commit comments
