
feat: support 1d ITensor offsets for embedding_bag converter #2677


42 changes: 8 additions & 34 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -231,26 +231,7 @@ def aten_ops_cat(
    )


def embedding_param_validator(embedding_node: Node) -> bool:
    scale_grad_by_freq = args_bounds_check(embedding_node.args, 3)
    sparse = args_bounds_check(embedding_node.args, 4)

    if scale_grad_by_freq is not None:
        _LOGGER.debug(
            f"Currently we don't support specifying scale gradient by word frequency, got {scale_grad_by_freq}."
        )
        return False

    if sparse is not None:
        _LOGGER.debug(f"Currently we don't support sparse gradient, got {sparse}.")
        return False

    return True


@dynamo_tensorrt_converter(
    torch.ops.aten.embedding.default, capability_validator=embedding_param_validator
)
@dynamo_tensorrt_converter(torch.ops.aten.embedding.default)
def aten_ops_embedding(
    ctx: ConversionContext,
    target: Target,
@@ -265,22 +246,19 @@ def aten_ops_embedding(
        name,
        input=args[1],
        weight=args[0],
        # args[2] is the padding index, which is useful for training only
        scale_grad_by_freq=args_bounds_check(args, 3),
        sparse=args_bounds_check(args, 4),
    )
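
For reference, at inference time aten.embedding is a pure row gather from the weight matrix; padding_idx, scale_grad_by_freq, and sparse only affect gradient computation, which is why the capability validator could be dropped. A minimal eager-mode sketch of the lookup semantics (illustrative values only):

import torch
import torch.nn.functional as F

weight = torch.randn(10, 4)           # vocabulary of 10 embeddings, dim 4
indices = torch.tensor([1, 3, 3, 7])  # 1D lookup indices

# equivalent to torch.ops.aten.embedding.default(weight, indices):
# each index selects one row of the weight matrix
out = F.embedding(indices, weight)
assert out.shape == (4, 4)
assert torch.equal(out[1], weight[3])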


def embedding_bag_validator(node: Node) -> bool:
    mode = args_bounds_check(node.args, 4, 0)
    indices = node.args[1].meta.get("tensor_meta")
    if not one_user_validator(node):
        return False
    meta = node.args[1].meta
    indices = meta.get("tensor_meta")
    if indices is None:
        indices = meta.get("val")
    if indices is None:
        return False
    return (
        bool(node.args[2].op == "get_attr")
        and (mode == 0 or mode == 1 or mode == 2)
        and len(indices.shape) == 1
    )
    return len(indices.shape) == 1  # currently only support 1D indices
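
The relaxed validator keeps only the 1D-indices requirement; in particular, offsets no longer has to be a get_attr constant, so a runtime ITensor is accepted. A small eager-mode sketch of the embedding_bag semantics being converted (illustrative values only):

import torch
import torch.nn.functional as F

weight = torch.randn(10, 4)
indices = torch.tensor([1, 2, 4, 5, 4, 3])  # 1D indices, as the validator requires
offsets = torch.tensor([0, 2, 4])           # bag i covers indices[offsets[i]:offsets[i+1]]

# modes 0/1/2 map to "sum"/"mean"/"max"
out = F.embedding_bag(indices, weight, offsets, mode="sum")
assert out.shape == (3, 4)                  # one reduced row per bag
assert torch.allclose(out[0], weight[1] + weight[2])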


@dynamo_tensorrt_converter(
@@ -293,7 +271,6 @@ def embedding_bag_validator(node: Node) -> bool:
    {
        0: (TRTTensor,),
        1: (TRTTensor,),
        2: (np.ndarray, torch.Tensor),
    }
)
def aten_ops_embedding_bag(
@@ -311,12 +288,9 @@ def aten_ops_embedding_bag(
        weight=args[0],
        indices=args[1],
        offsets=args[2],
        scale_grad_by_freq=args_bounds_check(args, 3, False),
        mode=args_bounds_check(args, 4, 0),
        sparse=args_bounds_check(args, 5, False),
        per_sample_weights=args_bounds_check(args, 6, None),
        include_last_offset=args_bounds_check(args, 7, False),
        # padding index is useful for training only
    )


109 changes: 109 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -5,6 +5,7 @@

import numpy as np
import torch
import torch_tensorrt.dynamo.conversion.impl as impl
from torch import SymBool, SymFloat, SymInt
from torch.fx.node import Argument, Target
from torch_tensorrt import _enums
@@ -530,3 +531,111 @@ def flatten_dims(
    new_shape = tuple(shape[:start_dim]) + (num_elements,) + tuple(shape[end_dim + 1 :])

    return new_shape


def append(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    original_tensor: TRTTensor,
    new_value: Union[TRTTensor, int, float, torch.Tensor, np.ndarray],
    dim: int = 0,
) -> TRTTensor:
    """
    Append a new value to the end of the original tensor along the specified dimension (default 0).
    For example, if the original tensor is [1, 2, 3], the new value is 4, and the dim is 0,
    the new tensor will be [1, 2, 3, 4].

    Args:
        ctx (ConversionContext): A ConversionContext containing the TensorRT network
        target (Target): Target of calling node
        source_ir (Optional[SourceIR]): SourceIR of calling converter
        name (str): Name of the calling layer
        original_tensor (TRTTensor): A TRTTensor to append the new value to
        new_value (Union[TRTTensor, int, float, torch.Tensor, np.ndarray]): A new value to append
        dim (int, optional): Dimension along which to append the new value. Defaults to 0.

    Returns:
        TRTTensor: A new TRTTensor that is the result of appending the new value to the original tensor
    """
    # wrap scalars so they can be materialized as a 1-element TRT constant
    if isinstance(new_value, (int, float)):
        new_value = np.array([new_value])
    new_value = get_trt_tensor(ctx, new_value, name, original_tensor.dtype)

    return impl.cat.cat(
        ctx,
        target,
        source_ir,
        f"{name}_concat",
        [original_tensor, new_value],
        get_positive_dim(dim, len(original_tensor.shape)),
    )
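
In network terms, append is just a concatenation layer. An eager-mode analogue of the result (hypothetical values, plain torch ops in place of TensorRT layers):

import torch

original = torch.tensor([1, 2, 3])
new_value = torch.tensor([4])  # int/float inputs are wrapped into a 1-element tensor first

# append(..., dim=0) reduces to a concat along the requested dimension
result = torch.cat([original, new_value], dim=0)
assert torch.equal(result, torch.tensor([1, 2, 3, 4]))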


def set_item(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    original_tensor: TRTTensor,
    index: int,
    new_value: Union[TRTTensor, int, float, torch.Tensor, np.ndarray],
) -> TRTTensor:
    """
    Set a new value in the original tensor at the specified index. For example,
    if the original tensor is [1, 2, 3], the new value is 4, and the index is 1,
    the new tensor will be [1, 4, 3].
    If the index is out of bounds, the new value is appended to the end.

    Args:
        ctx (ConversionContext): A ConversionContext containing the TensorRT network
        target (Target): Target of calling node
        source_ir (Optional[SourceIR]): SourceIR of calling converter
        name (str): Name of the calling layer
        original_tensor (TRTTensor): A TRTTensor in which to set the new value
        index (int): The index at which to set the new value
        new_value (Union[TRTTensor, int, float, torch.Tensor, np.ndarray]): A new value to set

    Returns:
        TRTTensor: A new TRTTensor that is the result of setting the new value in the original tensor
    """
    # wrap scalars so they can be materialized as a 1-element TRT constant
    if isinstance(new_value, (int, float)):
        new_value = np.array([new_value])
    new_value = get_trt_tensor(ctx, new_value, name, original_tensor.dtype)

    len_original_tensor = original_tensor.shape[0]
    index = get_positive_dim(index, len_original_tensor)

    # slice off everything before the index
    front_tensor = impl.slice.slice_op(
        ctx,
        target,
        source_ir,
        f"{name}_slice_front",
        original_tensor,
        dim=0,
        start=0,
        stop=index,
        step=1,
    )
    # slice off everything after the index
    rear_tensor = impl.slice.slice_op(
        ctx,
        target,
        source_ir,
        f"{name}_slice_rear",
        original_tensor,
        dim=0,
        start=index + 1,
        stop=len_original_tensor,
        step=1,
    )

    # re-assemble: front slice + new value + rear slice
    ans = impl.cat.cat(
        ctx,
        target,
        source_ir,
        f"{name}_concat",
        [front_tensor, new_value, rear_tensor],
        0,
    )
    return ans
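
Because TensorRT networks have no in-place indexed assignment, set_item rebuilds the tensor from three pieces. An eager-mode analogue under the same assumptions:

import torch

original = torch.tensor([1, 2, 3])
index, new_value = 1, torch.tensor([4])

front = original[:index]       # everything before the index
rear = original[index + 1:]    # everything after the index
result = torch.cat([front, new_value, rear], dim=0)
assert torch.equal(result, torch.tensor([1, 4, 3]))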