Commit 5bd321c

apbose authored and gs-olive committed

Embedding operator in dynamo

1 parent 4b0dfa6 · commit 5bd321c
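This change adds dynamo lowering for torch.ops.aten.embedding.default, so modules that use nn.Embedding (or F.embedding) can be converted. A rough usage sketch follows; the frontend call and module names here are assumptions for illustration, not taken from this commit:

    import torch
    import torch_tensorrt

    class Lookup(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = torch.nn.Embedding(5, 10)  # 5 rows, 10-dim vectors

        def forward(self, idx):
            return self.emb(idx)

    model = Lookup().eval().cuda()
    # Indices must be int32: the converter below rejects int64 index tensors.
    idx = torch.tensor([[3, 1, 2]], dtype=torch.int32).cuda()
    trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=[idx])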

File tree

4 files changed (+230, -1 lines)


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 57 additions & 1 deletion
@@ -4,7 +4,7 @@
 import tensorrt as trt
 from torch_tensorrt.fx.converters import acc_ops_converters
 from ..converter_registry import dynamo_tensorrt_converter
-from torch.fx.node import Argument, Target
+from torch.fx.node import Argument, Target, Node
 
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
 from torch_tensorrt.dynamo.conversion import SourceIR, impl
@@ -71,6 +71,62 @@ def aten_ops_div(
     )
 
 
+def embedding_param_validator(embedding_node: Node):
+
+    max_norm = args_bounds_check(embedding_node.args, 2)
+    norm_type = args_bounds_check(embedding_node.args, 3)
+    scale_grad_by_freq = args_bounds_check(embedding_node.args, 4)
+    sparse = args_bounds_check(embedding_node.args, 5)
+
+    if max_norm is not None:
+        _LOGGER.debug(
+            f"Currently we don't support specifying max_norm, got {max_norm}."
+        )
+        return False
+
+    if norm_type is not None and norm_type != 2.0:
+        _LOGGER.debug(
+            f"Currently we don't support specifying norm_type, got {norm_type}."
+        )
+        return False
+
+    if scale_grad_by_freq is not None:
+        _LOGGER.debug(
+            f"Currently we don't support specifying scale gradient by word frequency, got {scale_grad_by_freq}."
+        )
+        return False
+
+    if sparse is not None:
+        _LOGGER.debug(f"Currently we don't support sparse gradient, got {sparse}.")
+        return False
+
+    return True
+
+
+@dynamo_tensorrt_converter(
+    torch.ops.aten.embedding.default, capability_validator=embedding_param_validator
+)
+def aten_ops_embedding(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.embedding.embedding(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        input=args[1],
+        weight=args[0],
+        max_norm=or_none(args, 2),
+        norm_type=or_none(args, 3),
+        scale_grad_by_freq=or_none(args, 4),
+        sparse=or_none(args, 5),
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.fmod.Scalar)
 @dynamo_tensorrt_converter(torch.ops.aten.fmod.Tensor)
 def aten_ops_fmod(
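The validator reads the optional positional arguments through args_bounds_check, a helper that is not part of this diff. A minimal sketch of such a helper, assuming it simply returns a default when the argument was not supplied (the signature is an assumption, not taken from this commit):

    from typing import Any, Optional, Sequence

    def args_bounds_check(
        args: Sequence[Any], i: int, replacement: Optional[Any] = None
    ) -> Any:
        # Return the i-th positional argument when supplied, otherwise the replacement.
        return args[i] if len(args) > i else replacement

Returning None for an absent argument is what lets the validator treat "parameter not given" and "parameter left at its unsupported-by-TRT default" the same way.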

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 from torch_tensorrt.fx.converters.impl import convolution
 from . import condition
 from . import elementwise
+from . import embedding
 from . import normalization
 from . import slice
 from . import unary
py/torch_tensorrt/dynamo/conversion/impl/embedding.py

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
import operator
import warnings
from typing import Optional, cast, Any

import numpy as np

import tensorrt as trt
import torch
from torch.fx.node import Target

from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

from torch_tensorrt.fx.converters.converter_utils import (
    SourceIR,
    set_layer_name,
)

from torch_tensorrt.fx.converters.converter_utils import get_trt_tensor


def embedding(
    network: TRTNetwork,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    weight: TRTTensor,
    max_norm: None,
    norm_type: None,
    scale_grad_by_freq: bool,
    sparse: bool,
) -> TRTTensor:

    if network.has_implicit_batch_dimension:
        raise RuntimeError(
            "The `embedding` function should be called with explicit batch dimension."
        )

    indices_tensor = input
    embedding_tensor = weight
    if isinstance(indices_tensor, torch.Tensor) and indices_tensor.dtype == torch.int64:
        raise RuntimeError(
            "The `embedding` op has indices_tensor dtype=int64. This is incorrect since it has to be int32 to run on TRT."
        )
    indices_tensor = get_trt_tensor(network, indices_tensor, f"{name}_indices_tensor")
    embedding_tensor = get_trt_tensor(
        network, embedding_tensor, f"{name}_embedding_tensor"
    )
    # unsupported parameters
    # ignore padding_idx since it is meaningful for training only

    if max_norm is not None:
        raise RuntimeError(
            f"Currently we don't support specifying max_norm, got {max_norm}."
        )

    if norm_type is not None and norm_type != 2.0:
        raise RuntimeError(
            f"Currently we don't support specifying norm_type, got {norm_type}."
        )

    if scale_grad_by_freq:
        raise RuntimeError(
            "Currently we don't support scale gradient by word frequency."
        )

    if sparse:
        raise RuntimeError("Currently we don't support sparse gradient.")

    # Implement embedding lookup with gather layer
    gather_layer = network.add_gather(embedding_tensor, indices_tensor, axis=0)
    set_layer_name(gather_layer, target, name + "_gather")
    return gather_layer.get_output(0)
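The lookup itself is lowered to a single TensorRT gather layer over axis 0 of the weight tensor. As a quick sanity check of that equivalence, an embedding lookup is just row indexing of the weight matrix; a standalone PyTorch sketch (not part of the commit, shapes are illustrative):

    import torch

    # Row gather along dim 0 is what the TRT gather layer computes.
    weight = torch.randn(5, 10)                      # 5 embedding vectors of size 10
    indices = torch.tensor([[3, 1, 2], [4, 1, 3]])   # integer indices into the rows

    via_embedding = torch.nn.functional.embedding(indices, weight)
    via_gather = weight[indices]                     # shape (2, 3, 10)

    assert torch.equal(via_embedding, via_gather)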
Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
import torch
import torch.nn as nn
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt.dynamo.test_utils import DispatchTestCase
from parameterized import param, parameterized
from torch_tensorrt import Input


class TestEmbeddingConverter(DispatchTestCase):
    @parameterized.expand(
        [
            param(
                test_name="1d_indices",
                indices_tensor=torch.tensor([3, 1, 2]),
                weights_tensor=torch.randn(5, 10),
            ),
            param(
                test_name="2d_indices",
                indices_tensor=torch.tensor([[3, 1, 2], [4, 1, 3]]),
                weights_tensor=torch.randn(5, 10),
            ),
            param(
                test_name="3d_indices",
                indices_tensor=torch.tensor([[[0, 1], [2, 3]], [[3, 4], [4, 0]]]),
                weights_tensor=torch.randn(5, 10),
            ),
        ]
    )
    def test_embedding(
        self,
        test_name,
        indices_tensor,
        weights_tensor,
        padding_idx=None,
        max_norm=None,
        norm_type=2.0,
        scale_grad_by_freq=False,
        sparse=False,
    ):
        class TestEmbedding(torch.nn.Module):
            def forward(self, indices, weights):
                return torch.nn.functional.embedding(
                    input=indices,
                    weight=weights,
                    padding_idx=padding_idx,
                    max_norm=max_norm,
                    norm_type=norm_type,
                    scale_grad_by_freq=scale_grad_by_freq,
                    sparse=sparse,
                )

        self.run_test(
            TestEmbedding(),
            inputs=[indices_tensor.int(), weights_tensor.float()],
            expected_ops={torch.ops.aten.embedding.default},
        )

    def test_embedding_with_dynamic_shape_four_dimensions(
        self,
        padding_idx=None,
        max_norm=None,
        norm_type=2.0,
        scale_grad_by_freq=False,
        sparse=False,
    ):
        class TestEmbedding(torch.nn.Module):
            def forward(self, input, weights):
                return torch.nn.functional.embedding(
                    input=input,
                    weight=weights,
                    padding_idx=padding_idx,
                    max_norm=max_norm,
                    norm_type=norm_type,
                    scale_grad_by_freq=scale_grad_by_freq,
                    sparse=sparse,
                )

        input_specs = [
            Input(
                shape=(-1, -1, -1, -1),
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))],
            ),
            Input(
                shape=(-1, -1, -1, -1),
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))],
            ),
        ]

        self.run_test_with_dynamic_shape(
            TestEmbedding(),
            input_specs,
            expected_ops={torch.ops.aten.embedding.default},
        )


if __name__ == "__main__":
    run_tests()
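In the dynamic-shape test, each shape_ranges entry on an Input lists the (min, opt, max) shapes used to build one TensorRT optimization profile. A small sketch of such a spec next to a concrete tensor that satisfies it (sizes are illustrative only, mirroring the constructor usage from the test above):

    import torch
    from torch_tensorrt import Input

    # Dynamic 4-D input: any shape between the min and max tuples is legal at
    # run time; the engine is tuned for the opt shape in the middle.
    spec = Input(
        shape=(-1, -1, -1, -1),
        dtype=torch.float32,
        shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))],
    )

    # A concrete tensor that falls inside the declared range.
    example = torch.randn(2, 3, 4, 5, dtype=torch.float32)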
