[WIP] feat: support 1d ITensor offsets for embedding_bag converter #2676

Closed · wants to merge 7 commits
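For context on the title: in PyTorch, `_embedding_bag` reduces variable-length "bags" of embedding rows, and the 1D `offsets` tensor marks where each bag starts. A minimal eager-mode sketch of those semantics (illustrative only, not code from this PR):

```python
import torch
import torch.nn.functional as F

# Toy embedding table: 10 rows, 3 features each.
weight = torch.randn(10, 3)
# Flat index list; offsets mark the start of each bag.
indices = torch.tensor([1, 2, 4, 5, 4, 3])
offsets = torch.tensor([0, 3])  # bag 0 = indices[0:3], bag 1 = indices[3:6]

# Each bag is reduced (sum/mean/max) to a single output row.
out = F.embedding_bag(indices, weight, offsets, mode="sum")
print(out.shape)  # torch.Size([2, 3]) -> one row per bag
```

This PR extends the Dynamo converter so that `offsets` can arrive as a runtime TensorRT `ITensor` rather than only as a frozen constant.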
4 changes: 2 additions & 2 deletions .github/scripts/install-torch-tensorrt.sh
@@ -4,6 +4,6 @@ set -eou pipefail
source ${BUILD_ENV_FILE}
${CONDA_RUN} ${PIP_INSTALL_TORCH} torchvision pyyaml
export TRT_VERSION=$(${CONDA_RUN} python -c "import versions; versions.tensorrt_version()")
${CONDA_RUN} python -m pip install /opt/torch-tensorrt-builds/torch_tensorrt*+${CU_VERSION}*.whl tensorrt~=${TRT_VERSION} tensorrt-bindings~=${TRT_VERSION} --extra-index-url=https://pypi.ngc.nvidia.com
${CONDA_RUN} python -m pip install /opt/torch-tensorrt-builds/torch_tensorrt*+${CU_VERSION}*.whl tensorrt~=${TRT_VERSION} tensorrt-bindings~=${TRT_VERSION} --extra-index-url=https://pypi.nvidia.com

echo -e "Running test script";
echo -e "Running test script";
2 changes: 1 addition & 1 deletion README.md
@@ -119,7 +119,7 @@ These are the following dependencies used to verify the testcases. Torch-TensorR
- Libtorch 2.3.0.dev (latest nightly) (built with CUDA 12.1)
- CUDA 12.1
- cuDNN 8.9.5
- TensorRT 8.6.1
- TensorRT 9.2.0

## Prebuilt Binaries and Wheel files

6 changes: 3 additions & 3 deletions WORKSPACE
@@ -81,10 +81,10 @@ http_archive(
http_archive(
name = "tensorrt",
build_file = "@//third_party/tensorrt/archive:BUILD",
sha256 = "0f8157a5fc5329943b338b893591373350afa90ca81239cdadd7580cd1eba254",
strip_prefix = "TensorRT-8.6.1.6",
sha256 = "3dd505a9e0d0adf9257080b543f51d91df736dbd1f75417b9dde1a7b7a5d87f2",
strip_prefix = "tensorrt-9.2.0.5",
urls = [
"https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/secure/8.6.1/tars/TensorRT-8.6.1.6.Linux.x86_64-gnu.cuda-12.0.tar.gz",
"https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/9.2.0/tensorrt-9.2.0.5.linux.x86_64-gnu.cuda-12.2.tar.gz",
],
)

2 changes: 1 addition & 1 deletion dev_dep_versions.yml
@@ -1,4 +1,4 @@
__version__: "2.3.0.dev0"
__cuda_version__: "12.1"
__cudnn_version__: "8.9"
__tensorrt_version__: "8.6"
__tensorrt_version__: "9.2.0.post12.dev5"
4 changes: 2 additions & 2 deletions docker/README.md
@@ -17,14 +17,14 @@ Note: By default the container uses the `pre-cxx11-abi` version of Torch + Torch

### Instructions

- The example below uses CUDNN 8.9 and TensorRT 8.6
- The example below uses CUDNN 8.9 and TensorRT 9.2
- See <a href="https://github.com/pytorch/TensorRT#dependencies">dependencies</a> for a list of current default dependencies.

> From root of Torch-TensorRT repo
Build:
```
DOCKER_BUILDKIT=1 docker build --build-arg TENSORRT_VERSION=8.6 --build-arg CUDNN_VERSION=8.9 -f docker/Dockerfile -t torch_tensorrt:latest .
DOCKER_BUILDKIT=1 docker build --build-arg TENSORRT_VERSION=9.2 --build-arg CUDNN_VERSION=8.9 -f docker/Dockerfile -t torch_tensorrt:latest .
```

Run:
24 changes: 22 additions & 2 deletions packaging/pre_build_script.sh
@@ -2,15 +2,35 @@

# Install dependencies
python3 -m pip install pyyaml
TRT_VERSION=$(python3 -c "import versions; versions.tensorrt_version()")
CUDNN_VERSION=$(python3 -c "import versions; print(versions.__cudnn_version__.split('.')[0])")
yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo
yum check-update
yum install -y ninja-build gettext tensorrt-${TRT_VERSION}.*
yum install -y ninja-build gettext libcudnn${CUDNN_VERSION} libcudnn${CUDNN_VERSION}-devel
wget https://github.com/bazelbuild/bazelisk/releases/download/v1.17.0/bazelisk-linux-amd64 \
&& mv bazelisk-linux-amd64 /usr/bin/bazel \
&& chmod +x /usr/bin/bazel

wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/9.2.0/tensorrt-9.2.0.5.linux.x86_64-gnu.cuda-12.2.tar.gz
mkdir -p /usr/tensorrt
tar -xzvf tensorrt-9.2.0.5.linux.x86_64-gnu.cuda-12.2.tar.gz -C /usr/tensorrt --strip-components=1
mkdir -p /usr/lib
cp /usr/tensorrt/lib/* /usr/lib/ || :
mkdir -p /usr/lib64
cp /usr/tensorrt/lib/* /usr/lib64/ || :
mkdir -p /usr/include
cp /usr/tensorrt/include/* /usr/include/ || :

mkdir -p /usr/lib/x86_64-linux-gnu
cp /usr/tensorrt/targets/x86_64-linux-gnu/lib/* /usr/lib/x86_64-linux-gnu/ || :
mkdir -p /usr/include/x86_64-linux-gnu
cp /usr/tensorrt/targets/x86_64-linux-gnu/include/* /usr/include/x86_64-linux-gnu/ || :

rm tensorrt-9.2.0.5.linux.x86_64-gnu.cuda-12.2.tar.gz
rm -rf /usr/tensorrt

export TORCH_BUILD_NUMBER=$(python -c "import torch, urllib.parse as ul; print(ul.quote_plus(torch.__version__))")

cat toolchains/ci_workspaces/WORKSPACE.x86_64.release.rhel.tmpl | envsubst > WORKSPACE
export CI_BUILD=1

python -m pip config set global.extra-index-url "https://pypi.nvidia.com"
4 changes: 3 additions & 1 deletion py/ci/soname_excludes.params
@@ -24,16 +24,18 @@
--exclude libcudart.so.11
--exclude libcudart.so.11.7.60
--exclude libnvrtc.so.11.2
--exclude libnvinfer_plugin.so.9
--exclude libnvinfer_plugin.so.8
--exclude libcublas.so.11
--exclude libcuda.so.1
--exclude libcuda.so.515
--exclude libcublasLt.so.11
--exclude libnvinfer.so.9
--exclude libnvinfer.so.8
--exclude libcudnn.so.8
--exclude libcublas.so.12
--exclude libcublasLt.so.12
--exclude libcublas.so.12.1.3.1
--exclude libcublasLt.so.12.1.3.1
--exclude libcudart.so.11.8.89
--exclude libcudart.so.11
--exclude libcudart.so.11
4 changes: 2 additions & 2 deletions py/requirements.txt
@@ -4,6 +4,6 @@ pybind11==2.6.2
--extra-index-url https://download.pytorch.org/whl/nightly/cu121
torch>=2.3.0.dev,<2.4.0
torchvision>=0.18.0.dev,<0.19.0
--extra-index-url https://pypi.ngc.nvidia.com
tensorrt==8.6.1
--extra-index-url https://pypi.nvidia.com
tensorrt==9.2.0.post12.dev5
pyyaml
43 changes: 10 additions & 33 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -229,26 +229,7 @@ def aten_ops_cat(
)


def embedding_param_validator(embedding_node: Node) -> bool:
scale_grad_by_freq = args_bounds_check(embedding_node.args, 3)
sparse = args_bounds_check(embedding_node.args, 4)

if scale_grad_by_freq is not None:
_LOGGER.debug(
f"Currently we don't support specifying scale gradient by word frequency, got {scale_grad_by_freq}."
)
return False

if sparse is not None:
_LOGGER.debug(f"Currently we don't support sparse gradient, got {sparse}.")
return False

return True


@dynamo_tensorrt_converter(
torch.ops.aten.embedding.default, capability_validator=embedding_param_validator
)
@dynamo_tensorrt_converter(torch.ops.aten.embedding.default)
def aten_ops_embedding(
ctx: ConversionContext,
target: Target,
@@ -263,22 +244,18 @@ def aten_ops_embedding(
name,
input=args[1],
weight=args[0],
# args[2] is the padding index, which is useful for training only
scale_grad_by_freq=args_bounds_check(args, 3),
sparse=args_bounds_check(args, 4),
# args[2, 3, 4] are useful for training only
padding_idx=args_bounds_check(args, 2, -1),
scale_grad_by_freq=args_bounds_check(args, 3, False),
sparse=args_bounds_check(args, 4, False),
)


def embedding_bag_validator(node: Node) -> bool:
mode = args_bounds_check(node.args, 4, 0)
indices = node.args[1].meta.get("tensor_meta")
if indices is None:
return False
return (
bool(node.args[2].op == "get_attr")
and (mode == 0 or mode == 1 or mode == 2)
and len(indices.shape) == 1
)
return len(indices.shape) == 1  # currently only supports 1D indices


@dynamo_tensorrt_converter(
@@ -291,7 +268,6 @@ def embedding_bag_validator(node: Node) -> bool:
{
0: (TRTTensor,),
1: (TRTTensor,),
2: (np.ndarray, torch.Tensor),
}
)
def aten_ops_embedding_bag(
@@ -309,12 +285,13 @@ def aten_ops_embedding_bag(
weight=args[0],
indices=args[1],
offsets=args[2],
scale_grad_by_freq=args_bounds_check(args, 3, False),
mode=args_bounds_check(args, 4, 0),
sparse=args_bounds_check(args, 5, False),
per_sample_weights=args_bounds_check(args, 6, None),
include_last_offset=args_bounds_check(args, 7, False),
# padding index is useful for training only
# scale_grad_by_freq, sparse, and padding_idx are useful for training only
scale_grad_by_freq=args_bounds_check(args, 3, False),
sparse=args_bounds_check(args, 5, False),
padding_idx=args_bounds_check(args, 8, -1),
)
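
The converter now fills the optional trailing arguments explicitly through `args_bounds_check`. Roughly, that helper returns `args[i]` when the caller supplied it and a replacement default otherwise; a standalone sketch of the assumed behavior:

```python
from typing import Any, Optional, Sequence

def args_bounds_check(
    args: Sequence[Any], i: int, replacement: Optional[Any] = None
) -> Any:
    # Assumed behavior: return the i-th argument if present, otherwise the default.
    return args[i] if len(args) > i else replacement

# aten._embedding_bag called with only (weight, indices, offsets):
args = ("weight", "indices", "offsets")
mode = args_bounds_check(args, 4, 0)          # -> 0 (sum)
padding_idx = args_bounds_check(args, 8, -1)  # -> -1 (disabled)
```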


71 changes: 71 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -6,6 +6,7 @@
import numpy as np
import tensorrt as trt
import torch
import torch_tensorrt.dynamo.conversion.impl as impl
from torch import SymBool, SymFloat, SymInt
from torch.fx.node import Argument, Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -541,3 +542,73 @@ def flatten_dims(
new_shape = tuple(shape[:start_dim]) + (num_elements,) + tuple(shape[end_dim + 1 :])

return new_shape


def append(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
original_tensor: TRTTensor,
new_value: Union[TRTTensor, int, float, torch.Tensor, np.ndarray],
) -> TRTTensor:
if isinstance(new_value, (int, float)):
new_value = np.array([new_value])

return impl.cat.cat(
ctx,
target,
source_ir,
f"{name}_concat",
[original_tensor, new_value],
0,
)


def set_item(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
original_tensor: TRTTensor,
index: int,
new_value: Union[TRTTensor, int, float, torch.Tensor, np.ndarray],
) -> TRTTensor:
if isinstance(new_value, (int, float)):
new_value = np.array([new_value])

len_original_tensor = original_tensor.shape[0]
index = get_positive_dim(index, len_original_tensor)

front_tensor = impl.slice.slice_op(
ctx,
target,
source_ir,
f"{name}_slice_front",
original_tensor,
dim=0,
start=0,
stop=index,
step=1,
)
rear_tensor = impl.slice.slice_op(
ctx,
target,
source_ir,
f"{name}_slice_rear",
original_tensor,
dim=0,
start=index + 1,
stop=len_original_tensor,
step=1,
)

ans = impl.cat.cat(
ctx,
target,
source_ir,
f"{name}_concat",
[front_tensor, new_value, rear_tensor],
0,
)
return ans
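
To see what the two new helpers compute, here is an eager NumPy analog of the same slice-and-concatenate logic (the real helpers emit TensorRT slice/concat layers instead of operating on arrays; the `_ref` names are illustrative):

```python
import numpy as np

def append_ref(original, new_value):
    # Concatenate a scalar or 1D value onto the end of a 1D array.
    return np.concatenate([original, np.atleast_1d(new_value)], axis=0)

def set_item_ref(original, index, new_value):
    # Replace element `index` with `new_value` (mimics get_positive_dim for negatives).
    index = index % len(original)
    return np.concatenate(
        [original[:index], np.atleast_1d(new_value), original[index + 1 :]], axis=0
    )

x = np.array([10, 20, 30, 40])
print(append_ref(x, 50))       # [10 20 30 40 50]
print(set_item_ref(x, 2, 99))  # [10 20 99 40]
print(set_item_ref(x, -1, 7))  # [10 20 30  7]
```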
350 changes: 282 additions & 68 deletions py/torch_tensorrt/dynamo/conversion/impl/embedding.py

Large diffs are not rendered by default.
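
The rewritten `impl/embedding.py` is too large for the diff view. As a rough reference for what the converter must reproduce with TensorRT layers, this is how `_embedding_bag` reduces each bag in eager PyTorch (a sketch of the semantics exercised by the tests below, not the converter code itself):

```python
import torch

def embedding_bag_ref(weight, indices, offsets, mode=0, include_last_offset=False):
    # mode 0 = sum, 1 = mean, 2 = max; offsets mark where each bag starts.
    offsets = offsets.tolist()
    starts = offsets[:-1] if include_last_offset else offsets
    ends = offsets[1:] if include_last_offset else offsets[1:] + [len(indices)]
    rows = []
    for start, end in zip(starts, ends):
        bag = weight[indices[start:end].long()]
        if bag.numel() == 0:
            rows.append(torch.zeros(weight.shape[1]))  # empty bags come back as zeros
        elif mode == 0:
            rows.append(bag.sum(dim=0))
        elif mode == 1:
            rows.append(bag.mean(dim=0))
        else:
            rows.append(bag.max(dim=0).values)
    return torch.stack(rows)

weight = torch.arange(12, dtype=torch.float32).reshape(6, 2)
indices = torch.tensor([0, 1, 2, 3, 4, 5], dtype=torch.int32)
offsets = torch.tensor([0, 2, 4], dtype=torch.int32)
print(embedding_bag_ref(weight, indices, offsets, mode=0))
# tensor([[ 2.,  4.],
#         [10., 12.],
#         [18., 20.]])
```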

4 changes: 2 additions & 2 deletions pyproject.toml
@@ -8,8 +8,8 @@ requires = [
"cffi>=1.15.1",
"typing-extensions>=4.7.0",
"future>=0.18.3",
"tensorrt>=8.6,<8.7",
"torch >=2.3.0.dev,<2.4.0",
"tensorrt==9.2.0.post12.dev5",
"pybind11==2.6.2",
"numpy",
]
@@ -42,7 +42,7 @@ requires-python = ">=3.8"
keywords = ["pytorch", "torch", "tensorrt", "trt", "ai", "artificial intelligence", "ml", "machine learning", "dl", "deep learning", "compiler", "dynamo", "torchscript", "inference"]
dependencies = [
"torch >=2.3.0.dev,<2.4.0",
"tensorrt>=8.6,<8.7",
"tensorrt==9.2.0.post12.dev5",
"packaging>=23",
"numpy",
"typing-extensions>=4.7.0",
6 changes: 4 additions & 2 deletions tests/py/dynamo/conversion/test_embedding_aten.py
@@ -1,5 +1,4 @@
import torch
import torch.nn as nn
from parameterized import param, parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input
@@ -14,18 +13,21 @@ class TestEmbeddingConverter(DispatchTestCase):
test_name="1d_indices",
indices_tensor=torch.tensor([3, 1, 2], dtype=torch.int32),
weights_tensor=torch.randn((5, 10), dtype=torch.float32),
sparse=False,
),
param(
test_name="2d_indices",
indices_tensor=torch.tensor([[3, 1, 2], [4, 1, 3]], dtype=torch.int32),
weights_tensor=torch.randn((5, 10), dtype=torch.float32),
sparse=True,
),
param(
test_name="3d_indices",
indices_tensor=torch.tensor(
[[[0, 1], [2, 3]], [[3, 4], [4, 0]]], dtype=torch.int32
),
weights_tensor=torch.randn((5, 10), dtype=torch.float32),
sparse=True,
),
]
)
@@ -38,7 +40,7 @@ def test_embedding(
max_norm=None,
norm_type=2.0,
scale_grad_by_freq=None,
sparse=None,
sparse=False,
):
class TestEmbedding(torch.nn.Module):
def forward(self, indices, weights):
236 changes: 223 additions & 13 deletions tests/py/dynamo/conversion/test_embedding_bag_aten.py
@@ -8,43 +8,252 @@
class TestEmbeddingBagConverter(DispatchTestCase):
@parameterized.expand(
[
# mode=0: sum, mode=1: mean, mode=2: max
# 1D input
param(
test_name="1d_indices_1",
weight=torch.randn((10, 3), dtype=torch.float32),
indices=torch.tensor([1, 2, 4, 5, 4, 3], dtype=torch.int32),
offsets=torch.tensor([0, 3], dtype=torch.int32),
weight=torch.randn((10, 2), dtype=torch.float32),
indices=torch.tensor(
[1, 2, 4, 5, 4, 3, 2, 6, 8, 1, 2], dtype=torch.int32
),
offsets=torch.tensor([0, 2, 4], dtype=torch.int32),
scale_grad_by_freq=False,
mode=0,
sparse=True,
per_sample_weights=None,
include_last_offset=False,
padding_idx=-1,
),
param(
test_name="1d_indices_2",
weight=torch.randn((10, 2), dtype=torch.float32),
indices=torch.tensor(
[1, 2, 4, 5, 4, 3, 2, 6, 8, 1, 2], dtype=torch.int32
),
offsets=torch.tensor([0, 2, 4], dtype=torch.int32),
scale_grad_by_freq=False,
mode=1,
sparse=True,
per_sample_weights=None,
include_last_offset=True,
padding_idx=-1,
),
param(
test_name="1d_indices_3",
weight=torch.randn((10, 4), dtype=torch.float32),
indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
offsets=torch.tensor([0, 2, 8], dtype=torch.int32),
scale_grad_by_freq=False,
mode=2,
sparse=False,
per_sample_weights=None,
include_last_offset=False,
padding_idx=-1,
),
param(
test_name="1d_indices_4",
weight=torch.randn((10, 4), dtype=torch.float32),
indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
offsets=torch.tensor([0, 2, 8], dtype=torch.int32),
scale_grad_by_freq=False,
mode=0,
sparse=False,
per_sample_weights=torch.randn((8,)),
include_last_offset=True,
padding_idx=-1,
),
param(
test_name="1d_indices_2",
weight=torch.randn((10, 3), dtype=torch.float32),
indices=torch.tensor([1, 2, 4, 5, 4, 3], dtype=torch.int32),
offsets=torch.tensor([0, 5], dtype=torch.int32),
test_name="1d_indices_5",
weight=torch.randn((10, 4), dtype=torch.float32),
indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
offsets=torch.tensor([0, 5, 5], dtype=torch.int32),
scale_grad_by_freq=False,
mode=1,
sparse=False,
per_sample_weights=None,
include_last_offset=True,
padding_idx=-1,
),
param(
test_name="1d_indices_6",
weight=torch.randn((10, 4), dtype=torch.float32),
indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
offsets=torch.tensor([0, 5, 5], dtype=torch.int32),
scale_grad_by_freq=False,
mode=2,
sparse=False,
per_sample_weights=None,
include_last_offset=False,
padding_idx=-1,
),
param(
test_name="1d_indices_7",
weight=torch.randn((10, 4), dtype=torch.float32),
indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
offsets=torch.tensor([0, 8, 8], dtype=torch.int32),
scale_grad_by_freq=False,
mode=0,
sparse=False,
per_sample_weights=torch.randn((6,)),
per_sample_weights=None,
include_last_offset=True,
padding_idx=-1,
),
param(
test_name="1d_indices_8",
weight=torch.randn((10, 4), dtype=torch.float32),
indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
offsets=torch.tensor([0, 8, 8], dtype=torch.int32),
scale_grad_by_freq=False,
mode=1,
sparse=False,
per_sample_weights=None,
include_last_offset=False,
padding_idx=-1,
),
]
)
def test_embedding_bag_with_traversable_offsets(
self,
test_name,
weight,
indices,
offsets,
scale_grad_by_freq,
mode,
sparse,
per_sample_weights,
include_last_offset,
padding_idx,
):
class TestEmbeddingBag(torch.nn.Module):
def forward(self, weight, indices):
return torch.ops.aten._embedding_bag.default(
weight,
indices,
offsets,
scale_grad_by_freq,
mode,
sparse,
per_sample_weights,
include_last_offset,
padding_idx,
)[0]

self.run_test(
TestEmbeddingBag(),
inputs=[weight, indices],
# use_dynamo_tracer=True,
enable_passes=True,
)

@parameterized.expand(
[
# mode=0: sum, mode=1: mean, mode=2: max
# 1D input
param(
test_name="1d_indices_1",
weight=torch.randn((10, 2), dtype=torch.float32),
indices=torch.tensor(
[1, 2, 4, 5, 4, 3, 2, 6, 8, 1, 2], dtype=torch.int32
),
offsets=torch.tensor([0, 2, 4], dtype=torch.int32),
scale_grad_by_freq=False,
mode=0,
sparse=True,
per_sample_weights=None,
include_last_offset=False,
padding_idx=-1,
),
# TODO: BUG! outputs of tensor and ITensor not matched
# param(
# test_name="1d_indices_2",
# # weight=torch.randn((10, 2), dtype=torch.float32),
# # indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 6, 8, 1, 2], dtype=torch.int32),
# weight=torch.arange(12, dtype=torch.float32).reshape(6, 2),
# indices=torch.tensor([0,1, 2,3, 4,5], dtype=torch.int32),
# offsets=torch.tensor([0, 2, 4], dtype=torch.int32),
# scale_grad_by_freq=False,
# mode=0,
# sparse=True,
# per_sample_weights=None,
# include_last_offset=True,
# padding_idx=-1,
# ),
param(
test_name="1d_indices_3",
weight=torch.randn((10, 3), dtype=torch.float32),
weight=torch.randn((10, 4), dtype=torch.float32),
indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
offsets=torch.tensor([0, 2, 4], dtype=torch.int32),
offsets=torch.tensor([0, 2, 8], dtype=torch.int32),
scale_grad_by_freq=False,
mode=2,
sparse=False,
per_sample_weights=None,
include_last_offset=False,
padding_idx=-1,
),
param(
test_name="1d_indices_4",
weight=torch.randn((10, 4), dtype=torch.float32),
indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
offsets=torch.tensor([0, 2, 8], dtype=torch.int32),
scale_grad_by_freq=False,
mode=0,
sparse=False,
per_sample_weights=torch.randn((8,)),
include_last_offset=True,
padding_idx=-1,
),
# TODO: BUG! outputs of tensor and ITensor not matched
# param(
# test_name="1d_indices_5",
# # weight=torch.randn((10, 4), dtype=torch.float32),
# # indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
# weight=torch.arange(12, dtype=torch.float32).reshape(6, 2),
# indices=torch.tensor([0,1,2, 3,4,5], dtype=torch.int32),
# offsets=torch.tensor([0, 3, 3], dtype=torch.int32),
# scale_grad_by_freq=False,
# mode=1,
# sparse=False,
# per_sample_weights=None,
# include_last_offset=True,
# padding_idx=-1,
# ),
param(
test_name="1d_indices_6",
weight=torch.randn((10, 4), dtype=torch.float32),
indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
offsets=torch.tensor([0, 5, 5], dtype=torch.int32),
scale_grad_by_freq=False,
mode=2,
sparse=False,
per_sample_weights=None,
include_last_offset=False,
padding_idx=-1,
),
param(
test_name="1d_indices_7",
weight=torch.randn((10, 4), dtype=torch.float32),
indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
offsets=torch.tensor([0, 8, 8], dtype=torch.int32),
scale_grad_by_freq=False,
mode=0,
sparse=False,
per_sample_weights=None,
include_last_offset=True,
padding_idx=-1,
),
param(
test_name="1d_indices_8",
weight=torch.randn((10, 4), dtype=torch.float32),
indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
offsets=torch.tensor([0, 8, 8], dtype=torch.int32),
scale_grad_by_freq=False,
mode=1,
sparse=False,
per_sample_weights=None,
include_last_offset=False,
padding_idx=-1,
),
# 2D input
# param(
# test_name="2d_indices_1",
@@ -103,7 +312,7 @@ class TestEmbeddingBagConverter(DispatchTestCase):
# ),
]
)
def test_embedding_bag(
def test_embedding_bag_with_ITensor_offsets(
self,
test_name,
weight,
@@ -117,7 +326,7 @@ def test_embedding_bag(
padding_idx,
):
class TestEmbeddingBag(torch.nn.Module):
def forward(self, weight, indices):
def forward(self, weight, indices, offsets):
return torch.ops.aten._embedding_bag.default(
weight,
indices,
@@ -132,7 +341,8 @@ def forward(self, weight, indices):

self.run_test(
TestEmbeddingBag(),
inputs=[weight, indices],
inputs=[weight, indices, offsets],
# use_dynamo_tracer=True,
enable_passes=True,
)
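
Several of the new cases differ only in `include_last_offset`. In PyTorch that flag decides whether the final offset starts one more bag or ends the last one, which changes the number of output rows; a small eager-mode check using the same shapes as the parameters above (illustrative only):

```python
import torch
import torch.nn.functional as F

weight = torch.randn(10, 4)
indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
offsets = torch.tensor([0, 2, 8])

# include_last_offset=False: bags are [0:2], [2:8], [8:8] (the last is empty
# and returned as zeros), so there are 3 output rows.
print(F.embedding_bag(indices, weight, offsets, mode="sum").shape)  # (3, 4)

# include_last_offset=True: the final offset ends the last bag, giving bags
# [0:2] and [2:8], so there are 2 output rows.
print(
    F.embedding_bag(
        indices, weight, offsets, mode="sum", include_last_offset=True
    ).shape
)  # (2, 4)
```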

Binary file modified tests/py/ts/models/hw_compat.ts
Binary file not shown.
4 changes: 2 additions & 2 deletions third_party/cudnn/local/BUILD
@@ -28,8 +28,8 @@ config_setting(

cc_library(
name = "cudnn_headers",
hdrs = ["include/cudnn.h"] + glob([
"include/cudnn_*.h",
hdrs = glob([
"include/cudnn*.h",
]),
includes = ["include/"],
visibility = ["//visibility:private"],
4 changes: 1 addition & 3 deletions third_party/tensorrt/local/BUILD
@@ -40,9 +40,7 @@ cc_library(
"include/aarch64-linux-gnu/NvInferPluginUtils.h",
],
),
":ci_rhel_x86_64_linux": [
"include/NvUtils.h",
] + glob(
":ci_rhel_x86_64_linux": glob(
[
"include/NvInfer*.h",
],
4 changes: 2 additions & 2 deletions toolchains/legacy/pyproject.toml
@@ -8,7 +8,7 @@ requires = [
"cffi>=1.15.1",
"typing-extensions>=4.7.0",
"future>=0.18.3",
"tensorrt>=8.6,<8.7",
"tensorrt==9.2.0.post12.dev5",
"torch>=1.13.0,<2.0",
"pybind11==2.6.2",
"numpy",
@@ -42,7 +42,7 @@ requires-python = ">=3.8"
keywords = ["pytorch", "torch", "tensorrt", "trt", "ai", "artificial intelligence", "ml", "machine learning", "dl", "deep learning", "compiler", "dynamo", "torchscript", "inference"]
dependencies = [
"torch>=1.13.0,<2.0",
"tensorrt>=8.6,<8.7",
"tensorrt==9.2.0.post12.dev5",
"packaging>=23",
"numpy",
"typing-extensions>=4.7.0",