
Commit ac3ab77

gs-olive and peri044 authored May 12, 2023
fix: Upgrade main to TRT 8.6, CUDA 11.8, CuDNN 8.8, Torch Dev (#1852)
Signed-off-by: Dheeraj Peri <[email protected]>
Co-authored-by: Dheeraj Peri <[email protected]>
1 parent 39585b1 commit ac3ab77

33 files changed: 1,697 additions and 1,579 deletions
 

.circleci/config.yml

Lines changed: 120 additions & 59 deletions
Large diffs are not rendered by default.

README.md

Lines changed: 5 additions & 5 deletions

@@ -116,10 +116,10 @@ torch.jit.save(trt_ts_module, "trt_torchscript_module.ts") # save the TRT embedd
 These are the following dependencies used to verify the testcases. Torch-TensorRT can work with other versions, but the tests are not guaranteed to pass.
 
 - Bazel 5.2.0
-- Libtorch 2.1.0.dev20230314 (built with CUDA 11.7)
-- CUDA 11.7
-- cuDNN 8.5.0
-- TensorRT 8.5.1.7
+- Libtorch 2.1.0.dev20230419 (built with CUDA 11.8)
+- CUDA 11.8
+- cuDNN 8.8.0
+- TensorRT 8.6.0
 
 ## Prebuilt Binaries and Wheel files
 
@@ -247,7 +247,7 @@ A tarball with the include files and library can then be found in bazel-bin
 ### Running Torch-TensorRT on a JIT Graph
 
 > Make sure to add LibTorch to your LD_LIBRARY_PATH <br>
-> `export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$(pwd)/bazel-Torch-TensorRT/external/libtorch/lib`
+> `export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$(pwd)/bazel-TensorRT/external/libtorch/lib`
 
 ``` shell
 bazel run //cpp/bin/torchtrtc -- $(realpath <PATH TO GRAPH>) out.ts <input-size>

WORKSPACE

Lines changed: 11 additions & 16 deletions

@@ -41,32 +41,27 @@ local_repository(
 new_local_repository(
     name = "cuda",
     build_file = "@//third_party/cuda:BUILD",
-    path = "/usr/local/cuda-11.7/",
+    path = "/usr/local/cuda-11.8/",
 )
 
-new_local_repository(
-    name = "cublas",
-    build_file = "@//third_party/cublas:BUILD",
-    path = "/usr",
-)
 #############################################################################################################
 # Tarballs and fetched dependencies (default - use in cases when building from precompiled bin and tarballs)
 #############################################################################################################
 
 http_archive(
     name = "libtorch",
     build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "7c4b8754830fef23ec19c5eaf414794cee9597b435df055f5c1d0471d3e81568",
+    sha256 = "1a526a9cd19c1015674d26921dbb94bcd2d632a6f9c431a21c43f4e24768d834",
     strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/nightly/cu117/libtorch-cxx11-abi-shared-with-deps-2.1.0.dev20230314%2Bcu117.zip"],
+    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.0.dev20230419%2Bcu118.zip"],
 )
 
 http_archive(
     name = "libtorch_pre_cxx11_abi",
     build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "f1e64a75dd12d0ba4c8c1f61947299e0a9c50684dff64f0cfbf355aa7a13e8cf",
+    sha256 = "60c5912a5085a6a7073b3804b10d41d6cc054693bbeb7a45e0247050c2837bac",
     strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/nightly/cu117/libtorch-shared-with-deps-2.1.0.dev20230314%2Bcu117.zip"],
+    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-shared-with-deps-2.1.0.dev20230419%2Bcu118.zip"],
 )
 
 # Download these tarballs manually from the NVIDIA website
@@ -76,20 +71,20 @@ http_archive(
 http_archive(
     name = "cudnn",
     build_file = "@//third_party/cudnn/archive:BUILD",
-    sha256 = "5454a6fd94f008728caae9adad993c4e85ef36302e26bce43bea7d458a5e7b6d",
-    strip_prefix = "cudnn-linux-x86_64-8.5.0.96_cuda11-archive",
+    sha256 = "36fff137153ef73e6ee10bfb07f4381240a86fb9fb78ce372414b528cbab2293",
+    strip_prefix = "cudnn-linux-x86_64-8.8.0.121_cuda11-archive",
     urls = [
-        "https://developer.nvidia.com/compute/cudnn/secure/8.5.0/local_installers/11.7/cudnn-linux-x86_64-8.5.0.96_cuda11-archive.tar.xz",
+        "https://developer.download.nvidia.com/compute/cudnn/secure/8.8.0/local_installers/11.8/cudnn-linux-x86_64-8.8.0.121_cuda11-archive.tar.xz",
     ],
 )
 
 http_archive(
     name = "tensorrt",
     build_file = "@//third_party/tensorrt/archive:BUILD",
-    sha256 = "39cc7f077057d1363794e8ff51c4cf21a5dbeccf1116b0020ba0dae0f3063076",
-    strip_prefix = "TensorRT-8.5.1.7",
+    sha256 = "c1732a1093c57ab79fa0b687f061be369e449c9c17792b660f3663ecd8fa7b63",
+    strip_prefix = "TensorRT-8.6.0.12",
     urls = [
-        "https://developer.nvidia.com/compute/machine-learning/tensorrt/secure/8.5.1/tars/TensorRT-8.5.1.7.Linux.x86_64-gnu.cuda-11.8.cudnn8.6.tar.gz",
+        "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/secure/8.6.0/tars/TensorRT-8.6.0.12.Linux.x86_64-gnu.cuda-11.8.tar.gz",
     ],
 )
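Since this change pins new `sha256` digests for the libtorch, cuDNN, and TensorRT archives, it can be handy to verify a manually downloaded tarball against its WORKSPACE pin before pointing Bazel at it. A minimal sketch (the file name and digest below come from the `tensorrt` entry above; adjust the path to wherever the archive was saved):

```python
import hashlib

def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Hash the file in chunks so multi-GB tarballs don't need to fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

# Digest pinned for the "tensorrt" http_archive in this WORKSPACE.
expected = "c1732a1093c57ab79fa0b687f061be369e449c9c17792b660f3663ecd8fa7b63"
actual = sha256_of("TensorRT-8.6.0.12.Linux.x86_64-gnu.cuda-11.8.tar.gz")
assert actual == expected, f"digest mismatch: got {actual}"
```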

core/runtime/TRTEngine.cpp

Lines changed: 1 addition & 1 deletion

@@ -150,10 +150,10 @@ TRTEngine::TRTEngine(
 }
 
 TRTEngine::~TRTEngine() {
-  rt.reset();
   trt_engine_profiler.reset();
   exec_ctx.reset();
   cuda_engine.reset();
+  rt.reset();
 }
 
 void TRTEngine::disable_profiling() {

cpp/include/torch_tensorrt/macros.h

Lines changed: 1 addition & 1 deletion

@@ -24,7 +24,7 @@
 #define STR(x) XSTR(x)
 
 #define TORCH_TENSORRT_MAJOR_VERSION 1
-#define TORCH_TENSORRT_MINOR_VERSION 3
+#define TORCH_TENSORRT_MINOR_VERSION 5
 #define TORCH_TENSORRT_PATCH_VERSION 0
 #define TORCH_TENSORRT_VERSION \
   STR(TORCH_TENSORRT_MAJOR_VERSION) \

docker/Dockerfile

Lines changed: 4 additions & 4 deletions

@@ -1,14 +1,14 @@
 # Base image starts with CUDA
-ARG BASE_IMG=nvidia/cuda:11.7.1-devel-ubuntu22.04
+ARG BASE_IMG=nvidia/cuda:11.8.0-devel-ubuntu22.04
 FROM ${BASE_IMG} as base
-ENV BASE_IMG=nvidia/cuda:11.7.1-devel-ubuntu22.04
+ENV BASE_IMG=nvidia/cuda:11.8.0-devel-ubuntu22.04
 
 ARG TENSORRT_VERSION
 ENV TENSORRT_VERSION=${TENSORRT_VERSION}
-RUN test -n "$TENSORRT_VERSION" || (echo "No tensorrt version specified, please use --build-arg TENSORRT_VERSION=x.y.z to specify a version." && exit 1)
+RUN test -n "$TENSORRT_VERSION" || (echo "No tensorrt version specified, please use --build-arg TENSORRT_VERSION=x.y to specify a version." && exit 1)
 ARG CUDNN_VERSION
 ENV CUDNN_VERSION=${CUDNN_VERSION}
-RUN test -n "$CUDNN_VERSION" || (echo "No cudnn version specified, please use --build-arg CUDNN_VERSION=x.y.z to specify a version." && exit 1)
+RUN test -n "$CUDNN_VERSION" || (echo "No cudnn version specified, please use --build-arg CUDNN_VERSION=x.y to specify a version." && exit 1)
 
 ARG PYTHON_VERSION=3.10
 ENV PYTHON_VERSION=${PYTHON_VERSION}

docker/README.md

Lines changed: 3 additions & 3 deletions

@@ -4,7 +4,7 @@
 
 * The `Dockerfile` currently uses <a href="https://github.com/bazelbuild/bazelisk">Bazelisk</a> to select the Bazel version, and uses the exact library versions of Torch and CUDA listed in <a href="https://github.com/pytorch/TensorRT#dependencies">dependencies</a>.
 * The desired versions of CUDNN and TensorRT must be specified as build-args, with major and minor versions as in: `--build-arg TENSORRT_VERSION=a.b --build-arg CUDNN_VERSION=x.y`
-* [**Optional**] The desired base image be changed by explicitly setting a base image, as in `--build-arg BASE_IMG=nvidia/cuda:11.7.1-devel-ubuntu22.04`, though this is optional
+* [**Optional**] The desired base image be changed by explicitly setting a base image, as in `--build-arg BASE_IMG=nvidia/cuda:11.8.0-devel-ubuntu22.04`, though this is optional
 * [**Optional**] Additionally, the desired Python version can be changed by explicitly setting a version, as in `--build-arg PYTHON_VERSION=3.10`, though this is optional as well.
 
 * This `Dockerfile` installs `pre-cxx11-abi` versions of Pytorch and builds Torch-TRT using `pre-cxx11-abi` libtorch as well.
@@ -17,14 +17,14 @@ Note: By default the container uses the `pre-cxx11-abi` version of Torch + Torch
 
 ### Instructions
 
-- The example below uses CUDNN 8.5 and TensorRT 8.5
+- The example below uses CUDNN 8.8 and TensorRT 8.6
 - See <a href="https://github.com/pytorch/TensorRT#dependencies">dependencies</a> for a list of current default dependencies.
 
 > From root of Torch-TensorRT repo
 
 Build:
 ```
-DOCKER_BUILDKIT=1 docker build --build-arg TENSORRT_VERSION=8.5 --build-arg CUDNN_VERSION=8.5 -f docker/Dockerfile -t torch_tensorrt:latest .
+DOCKER_BUILDKIT=1 docker build --build-arg TENSORRT_VERSION=8.6 --build-arg CUDNN_VERSION=8.8 -f docker/Dockerfile -t torch_tensorrt:latest .
 ```
 
 Run:

py/ci/build_whl.sh

Lines changed: 2 additions & 2 deletions

@@ -3,7 +3,7 @@
 # Example usage: docker run -it -v$(pwd)/..:/workspace/TRTorch build_trtorch_wheel /bin/bash /workspace/TRTorch/py/build_whl.sh
 
 export CXX=g++
-export CUDA_HOME=/usr/local/cuda-11.7
+export CUDA_HOME=/usr/local/cuda-11.8
 export PROJECT_DIR=/workspace/project
 
 cp -r $CUDA_HOME /usr/local/cuda
@@ -108,4 +108,4 @@ libtorchtrt_pre_cxx11_abi() {
     CUDNN_VERSION=$(cd ${PROJECT_DIR}/py && ${PY_DIR}/bin/python3 -c "from versions import __cudnn_version__;print(__cudnn_version__)")
     TORCH_VERSION=$(${PY_DIR}/bin/python -c "from torch import __version__;print(__version__.split('+')[0])")
     cp ${PROJECT_DIR}/bazel-bin/libtorchtrt.tar.gz ${PROJECT_DIR}/py/wheelhouse/libtorchtrt-${TORCHTRT_VERSION}-pre-cxx11-abi-cudnn${CUDNN_VERSION}-tensorrt${TRT_VERSION}-cuda${CUDA_VERSION}-libtorch${TORCH_VERSION}-x86_64-linux.tar.gz
-}
+}

py/requirements.txt

Lines changed: 4 additions & 4 deletions

@@ -1,8 +1,8 @@
 numpy
 packaging
 pybind11==2.6.2
---extra-index-url https://download.pytorch.org/whl/nightly/cu117
-torch==2.1.0.dev20230314+cu117
-torchvision==0.15.0.dev20230314+cu117
+--extra-index-url https://download.pytorch.org/whl/nightly/cu118
+torch==2.1.0.dev20230419+cu118
+torchvision==0.16.0.dev20230419+cu118
 --extra-index-url https://pypi.ngc.nvidia.com
-tensorrt==8.5.1.7
+tensorrt==8.6.0
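After installing with `pip install -r py/requirements.txt`, the resolved versions can be checked against these pins. A small sketch using only the standard library (the expected strings are the pins from this commit):

```python
from importlib.metadata import version

# Pins from py/requirements.txt at this commit.
expected = {
    "torch": "2.1.0.dev20230419+cu118",
    "torchvision": "0.16.0.dev20230419+cu118",
    "tensorrt": "8.6.0",
}
for pkg, want in expected.items():
    got = version(pkg)
    print(f"{pkg}: {got}" + ("" if got == want else f" (expected {want})"))
```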

py/setup.py

Lines changed: 8 additions & 2 deletions

@@ -31,6 +31,8 @@
 
 FX_ONLY = False
 
+LEGACY = False
+
 RELEASE = False
 
 CI_RELEASE = False
@@ -48,6 +50,10 @@ def get_git_revision_short_hash() -> str:
     FX_ONLY = True
     sys.argv.remove("--fx-only")
 
+if "--legacy" in sys.argv:
+    LEGACY = True
+    sys.argv.remove("--legacy")
+
 if "--release" not in sys.argv:
     __version__ = __version__ + "+" + get_git_revision_short_hash()
 else:
@@ -420,7 +426,7 @@ def run(self):
     long_description=long_description,
     ext_modules=ext_modules,
     install_requires=[
-        "torch>=1.13.1",
+        "torch >=2.1.dev,<2.2" if not LEGACY else "torch >=1.13.0,<2.0",
     ],
     setup_requires=[],
    cmdclass={
@@ -449,7 +455,7 @@ def run(self):
         "Topic :: Software Development",
         "Topic :: Software Development :: Libraries",
     ],
-    python_requires=">=3.7",
+    python_requires=">=3.8",
     include_package_data=True,
     package_data={
         "torch_tensorrt": package_data_list,

py/torch_tensorrt/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -59,7 +59,7 @@ def _find_lib(name, paths):
 
 elif sys.platform.startswith("linux"):
     LINUX_PATHS = [
-        "/usr/local/cuda/lib64",
+        "/usr/local/cuda-11.8/lib64",
     ]
 
     if "LD_LIBRARY_PATH" in os.environ:

py/torch_tensorrt/fx/README.md

Lines changed: 4 additions & 4 deletions

@@ -8,14 +8,14 @@ FX2TRT is merged as FX module in Torch-TensorRT
 ```
 $ conda create --name python_env python=3.8
 $ conda activate python_env
-# Recommend to install PyTorch 1.12 and later
-$ conda install pytorch torchvision torchtext cudatoolkit=11.3 -c pytorch-nightly
+# Recommend to install PyTorch 2.0 and later
+$ conda install pytorch torchvision torchtext cudatoolkit=11.8 -c pytorch-nightly
 # Install TensorRT python package
 $ pip3 install nvidia-pyindex
-$ pip3 install tensorrt==8.5.1.7
+$ pip3 install tensorrt==8.6.0
 $ git clone https://github.com/pytorch/TensorRT.git
 $ cd TensorRT/py && python setup.py install --fx-only && cd ..
-$ pyton -c "import torch_tensorrt.fx"
+$ python -c "import torch_tensorrt.fx"
 # Test an example by
 $ python py/torch_tensorrt/fx/example/lower_example.py
 ```

py/torch_tensorrt/fx/test/passes/test_fuse_permute_linear_trt.py

Lines changed: 0 additions & 5 deletions

@@ -54,11 +54,6 @@ def forward(self, x):
             apply_passes=[fuse_permute_linear],
         )
 
-    # TODO: The following test has been disabled due to a bug in TRT 8.5.1.7
-    # with self.linear2. Issue : https://github.com/pytorch/TensorRT/issues/1444
-    @unittest.skip(
-        reason="test_multi_fuse_permute_linear has been disabled due to a bug in TRT 8.5.1.7 https://github.com/pytorch/TensorRT/issues/1444"
-    )
     def test_multi_fuse_permute_linear(self):
         """
         Fusion when permute output is shared by multiple linears

py/versions.py

Lines changed: 4 additions & 4 deletions

@@ -1,4 +1,4 @@
-__version__ = "1.4.0.dev0"
-__cuda_version__ = "11.7"
-__cudnn_version__ = "8.5"
-__tensorrt_version__ = "8.5"
+__version__ = "1.5.0.dev0"
+__cuda_version__ = "11.8"
+__cudnn_version__ = "8.8"
+__tensorrt_version__ = "8.6"
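These module-level strings are what `py/ci/build_whl.sh` reads back via `python3 -c "from versions import ..."` when naming release artifacts; for example, run from the `py/` directory of a checkout at this commit:

```python
from versions import __version__, __cuda_version__, __cudnn_version__, __tensorrt_version__

print(__version__)           # 1.5.0.dev0
print(__cuda_version__)      # 11.8
print(__cudnn_version__)     # 8.8
print(__tensorrt_version__)  # 8.6
```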

pyproject.toml

Lines changed: 2 additions & 3 deletions

@@ -9,8 +9,7 @@ requires = [
     "cffi",
     "typing_extensions",
     "future",
-    "nvidia-pyindex",
-    "nvidia-tensorrt==8.4.3.1"
+    "tensorrt >=8.6,<8.7"
 ]
 
 # Use legacy backend to import local packages in setup.py
@@ -20,7 +19,7 @@ requires = [
 [tool.black]
 # Uncomment if pyproject.toml worked fine to ensure consistency with flake8
 # line-length = 120
-target-versions = ["py37", "py38", "py39", "py310"]
+target-versions = ["py38", "py39", "py310"]
 force-exclude = """
 elu_converter/setup.py
 """

tests/core/conversion/converters/BUILD

Lines changed: 45 additions & 0 deletions

@@ -95,6 +95,10 @@ converter_test(
     name = "test_matrix_multiply",
 )
 
+converter_test(
+    name = "test_masked_fill",
+)
+
 converter_test(
     name = "test_max",
 )
@@ -115,6 +119,10 @@ converter_test(
     name = "test_reduce",
 )
 
+converter_test(
+    name = "test_roll",
+)
+
 converter_test(
     name = "test_reflection_pad",
 )
@@ -123,6 +131,10 @@ converter_test(
     name = "test_replication_pad",
 )
 
+converter_test(
+    name = "test_scatter",
+)
+
 converter_test(
     name = "test_shuffle",
 )
@@ -139,6 +151,10 @@ converter_test(
     name = "test_interpolate",
 )
 
+converter_test(
+    name = "test_index",
+)
+
 converter_test(
     name = "test_select",
 )
@@ -147,6 +163,14 @@ converter_test(
     name = "test_stack",
 )
 
+converter_test(
+    name = "test_slice",
+)
+
+converter_test(
+    name = "test_split",
+)
+
 converter_test(
     name = "test_topk",
 )
@@ -159,10 +183,22 @@ converter_test(
     name = "test_unsqueeze",
 )
 
+converter_test(
+    name = "test_unbind",
+)
+
+converter_test(
+    name = "test_unpack",
+)
+
 converter_test(
     name = "test_squeeze",
 )
 
+converter_test(
+    name = "test_where",
+)
+
 test_suite(
     name = "converter_tests",
     tests = [
@@ -185,22 +221,31 @@ test_suite(
         ":test_expand",
         ":test_instance_norm",
         ":test_interpolate",
+        ":test_index",
         ":test_layer_norm",
         ":test_linear",
         ":test_lstm_cell",
         ":test_matrix_multiply",
+        ":test_masked_fill",
         ":test_max",
         ":test_normalize",
         ":test_pooling",
         ":test_reduce",
+        ":test_roll",
         ":test_replication_pad",
+        ":test_scatter",
         ":test_select",
         ":test_shuffle",
         ":test_softmax",
        ":test_squeeze",
         ":test_stack",
+        ":test_split",
+        ":test_slice",
         ":test_topk",
         ":test_unary",
         ":test_unsqueeze",
+        ":test_unbind",
+        ":test_unpack",
+        ":test_where",
     ],
 )
tests/core/conversion/converters/test_index.cpp (new file; filenames for this and the other new test files below are inferred from the converter_test entries added in the BUILD file)

Lines changed: 295 additions & 0 deletions

#include <torch/torch.h>
#include <string>
#include "core/compiler.h"
#include "core/lowering/passes/passes.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"

TEST(Converters, ATenIndexSelectConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%0 : Tensor, %index : Int (2)):
        %2 : int = prim::Constant[value=0]()
        %3 : Tensor = aten::index_select(%0, %2, %index)
        return (%3))IR";
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());
  auto in = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});
  auto index = at::randint(0, 4, {2}, {at::kCUDA}).to(torch::kI32);

  auto jit_in = at::clone(in);
  auto jit_index = at::clone(index);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {jit_index});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  auto trt_index = at::clone(index);
  auto trt_params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {trt_index});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, trt_params, {trt_in});

  auto trt = trt_results[0].reshape(jit_results[0].sizes());

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenIndexSelectNegativeDimConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%0 : Tensor, %index : Int (5)):
        %2 : int = prim::Constant[value=-1]()
        %3 : Tensor = aten::index_select(%0, %2, %index)
        return (%3))IR";
  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());

  auto in = at::randint(1, 10, {5, 3, 9}, {at::kCUDA});
  auto index = at::randint(0, 9, {5}, {at::kCUDA}).to(torch::kI32);

  auto jit_in = at::clone(in);
  auto jit_index = at::clone(index);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {jit_index});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  auto trt_index = at::clone(index);
  auto trt_params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {trt_index});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, trt_params, {trt_in});

  auto trt = trt_results[0].reshape(jit_results[0].sizes());

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenIndexTensorOneIndiceConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor,
            %index : Tensor):
        %18 : Tensor?[] = prim::ListConstruct(%index)
        %19 : Tensor = aten::index(%x.1, %18)
        return (%19))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto in1 = at::randint(1, 10, {5, 10}, {at::kCUDA});
  auto in2 = at::full({2}, 4, {at::kCUDA});
  auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
  auto in2_trt = at::full({2}, 4, {options});

  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2});

  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2_trt});

  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

TEST(Converters, ATenIndexTensorFullIndicesConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor,
            %index0 : Tensor,
            %index1 : Tensor,
            %index2 : Tensor):
        %18 : Tensor?[] = prim::ListConstruct(%index0, %index1, %index2)
        %19 : Tensor = aten::index(%x.1, %18)
        return (%19))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto in1 = at::randint(1, 10, {5, 10, 4}, {at::kCUDA});
  auto index0 = at::tensor({0, 1, 2, 3}, {at::kCUDA}).to(torch::kLong);
  auto index1 = at::tensor({1, 3, 4, 6}, {at::kCUDA}).to(torch::kLong);
  auto index2 = at::tensor({3, 2, 1, 0}, {at::kCUDA}).to(torch::kLong);
  auto index0_trt = index0.to(torch::kInt32);
  auto index1_trt = index1.to(torch::kInt32);
  auto index2_trt = index2.to(torch::kInt32);

  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0, index1, index2});

  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt, index1_trt, index2_trt});

  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

TEST(Converters, ATenIndexTensorRepeatedFullIndicesConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor,
            %index0 : Tensor,
            %index1 : Tensor,
            %index2 : Tensor):
        %18 : Tensor?[] = prim::ListConstruct(%index0, %index1, %index2)
        %19 : Tensor = aten::index(%x.1, %18)
        %20 : Tensor = aten::index(%x.1, %18)
        return (%19, %20))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto in1 = at::randint(1, 10, {5, 10, 4}, {at::kCUDA});
  auto index0 = at::tensor({0, 1, 2, 3}, {at::kCUDA}).to(torch::kLong);
  auto index1 = at::tensor({1, 3, 4, 6}, {at::kCUDA}).to(torch::kLong);
  auto index2 = at::tensor({3, 2, 1, 0}, {at::kCUDA}).to(torch::kLong);
  auto index0_trt = index0.to(torch::kInt32);
  auto index1_trt = index1.to(torch::kInt32);
  auto index2_trt = index2.to(torch::kInt32);

  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0, index1, index2});

  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt, index1_trt, index2_trt});

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[1], trt_results[1], 2e-6));
}

TEST(Converters, ATenIndexTensorIdx0Idx1NoneConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor,
            %index0 : Tensor,
            %index1 : Tensor):
        %5 : NoneType = prim::Constant()
        %18 : Tensor?[] = prim::ListConstruct(%index0, %index1, %5)
        %19 : Tensor = aten::index(%x.1, %18)
        return (%19))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto in1 = at::randint(1, 10, {5, 10, 4}, {at::kCUDA});
  auto index0 = at::tensor({0, 1, 2, 3}, {at::kCUDA}).to(torch::kLong);
  auto index1 = at::tensor({1, 3, 4, 6}, {at::kCUDA}).to(torch::kLong);
  auto index0_trt = index0.to(torch::kInt32);
  auto index1_trt = index1.to(torch::kInt32);

  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0, index1});

  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt, index1_trt});
  LOG_DEBUG(trt_results);

  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

TEST(Converters, ATenIndexTensorIdx0NoneIdx1ConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor,
            %index0 : Tensor,
            %index1 : Tensor):
        %5 : NoneType = prim::Constant()
        %18 : Tensor?[] = prim::ListConstruct(%index0, %5, %index1)
        %19 : Tensor = aten::index(%x.1, %18)
        return (%19))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto in1 = at::randint(1, 10, {5, 10, 4}, {at::kCUDA});
  auto index0 = at::tensor({0, 1, 2, 3}, {at::kCUDA}).to(torch::kLong);
  auto index1 = at::tensor({3, 2, 1, 0}, {at::kCUDA}).to(torch::kLong);
  auto index0_trt = index0.to(torch::kInt32);
  auto index1_trt = index1.to(torch::kInt32);

  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0, index1});

  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt, index1_trt});

  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

TEST(Converters, ATenIndexTensorNoneIdx0Idx1ConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor,
            %index0 : Tensor,
            %index1 : Tensor):
        %5 : NoneType = prim::Constant()
        %18 : Tensor?[] = prim::ListConstruct(%5, %index0, %index1)
        %19 : Tensor = aten::index(%x.1, %18)
        return (%19))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto in1 = at::randint(1, 10, {5, 10, 4}, {at::kCUDA});
  auto index0 = at::tensor({0, 1, 2, 3}, {at::kCUDA}).to(torch::kLong);
  auto index1 = at::tensor({3, 2, 1, 0}, {at::kCUDA}).to(torch::kLong);
  auto index0_trt = index0.to(torch::kInt32);
  auto index1_trt = index1.to(torch::kInt32);

  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0, index1});

  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt, index1_trt});

  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

TEST(Converters, ATenIndexTensorIdxsNoneConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor,
            %index0 : Tensor,
            %index1 : Tensor,
            %index2 : Tensor):
        %5 : NoneType = prim::Constant()
        %18 : Tensor?[] = prim::ListConstruct(%index0, %index1, %index2, %5)
        %19 : Tensor = aten::index(%x.1, %18)
        return (%19))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto in1 = at::randint(1, 10, {4, 8, 8, 4}, {at::kCUDA});
  auto index0 = at::full({4, 13, 1}, 1, {at::kCUDA}).to(torch::kLong);
  auto index1 = at::full({4, 13, 1}, 2, {at::kCUDA}).to(torch::kLong);
  auto index2 = at::full({4, 13, 1}, 3, {at::kCUDA}).to(torch::kLong);
  auto index0_trt = index0.to(torch::kInt32);
  auto index1_trt = index1.to(torch::kInt32);
  auto index2_trt = index2.to(torch::kInt32);

  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0, index1, index2});

  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt, index1_trt, index2_trt});

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

TEST(Converters, ATenIndexTensorNoneIdx1ConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor,
            %index0 : Tensor):
        %5 : NoneType = prim::Constant()
        %18 : Tensor?[] = prim::ListConstruct(%5, %index0)
        %19 : Tensor = aten::index(%x.1, %18)
        return (%19))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto in1 = at::randint(1, 10, {1, 3, 480, 928}, {at::kCUDA});
  auto index0 = at::tensor({2, 1, 0}, {at::kCUDA}).to(torch::kLong);

  auto index0_trt = index0.to(torch::kInt32);

  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0});

  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt});

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}
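For readers less familiar with the two ops exercised above: `aten::index_select` gathers whole slices along a single dimension, while `aten::index` implements advanced indexing where a `None` entry in the index list keeps that dimension untouched. A small PyTorch sketch of the eager-mode behavior the converters are checked against:

```python
import torch

x = torch.arange(24).reshape(2, 3, 4)

# aten::index_select: pick slices 2 and 0 along the last dim
print(torch.index_select(x, dim=2, index=torch.tensor([2, 0])).shape)  # (2, 3, 2)

# aten::index with a leading None (the NoneIdx1 case): x[:, idx]
print(x[:, torch.tensor([2, 1, 0])].shape)  # (2, 3, 4)

# aten::index with an index tensor for every dim (the FullIndices case)
print(x[torch.tensor([0, 1]), torch.tensor([1, 2]), torch.tensor([3, 0])])  # shape (2,)
```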
tests/core/conversion/converters/test_masked_fill.cpp (new file)

Lines changed: 100 additions & 0 deletions

#include <torch/torch.h>
#include <string>
#include "core/compiler.h"
#include "core/lowering/passes/passes.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"

TEST(Converters, ATenMaskedFillZerosConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor):
        %44 : Device = prim::Constant[value="cuda"]()
        %8 : bool = prim::Constant[value=0]()
        %7 : None = prim::Constant()
        %f32_dtype: int = prim::Constant[value=11]()
        %1 : int = prim::Constant[value=0]() # bert.py:5:26
        %2 : int = prim::Constant[value=1]() # bert.py:5:32
        %33 : int = prim::Constant[value=2]() # bert.py:6:31
        %3 : int[] = prim::ListConstruct(%1, %1, %2)
        %4 : int[] = prim::ListConstruct(%2, %2, %1)
        %5 : int[][] = prim::ListConstruct(%3, %4)
        %9 : Tensor = aten::tensor(%5, %f32_dtype, %7, %8) # bert.py:5:11
        %mask.1 : Tensor = aten::to(%9, %44, %7, %8, %8) # bert.py:5:11
        %mask.2 : Tensor = trt::const(%mask.1)
        %34 : Tensor = aten::masked_fill(%x.1, %mask.1, %33) # bert.py:6:11
        return (%34, %mask.2))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, &*g);

  auto in = at::zeros({1, 2, 3}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  torch_tensorrt::core::lowering::passes::RemoveNOPs(g);
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});

  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

TEST(Converters, ATenMaskedFillMixedTypesFloatIntConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor, %x.2 : Tensor):
        %val : float = prim::Constant[value=4.0]()
        %out : Tensor = aten::masked_fill(%x.1, %x.2, %val)
        return (%out))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, &*g);

  // Input is a float tensor, filled with an int --> expecting float tensor out
  auto in1 = at::rand({2, 3, 5, 7}, {at::kCUDA}).to(torch::kFloat32);
  auto in2 = (2 * at::rand({2, 3, 5, 7}, {at::kCUDA})).to(torch::kBool);

  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2});

  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2});

  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));

  // Ensure data types match in outputs
  ASSERT_TRUE(jit_results[0].dtype() == trt_results[0].dtype());
}

TEST(Converters, ATenMaskedFillMixedTypesIntFloatConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor, %x.2 : Tensor):
        %val : int = prim::Constant[value=4]()
        %out : Tensor = aten::masked_fill(%x.1, %x.2, %val)
        return (%out))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, &*g);

  // Input is an integer tensor, filled with a float --> expecting integer tensor out
  auto in1 = at::rand({1, 3, 5, 7}, {at::kCUDA}).to(torch::kInt32);
  auto in2 = (2 * at::rand({1, 3, 5, 7}, {at::kCUDA})).to(torch::kBool);

  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2});

  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2});

  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));

  // Ensure data types match in outputs
  ASSERT_TRUE(jit_results[0].dtype() == trt_results[0].dtype());
}
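As a quick eager-mode reference for what these tests assert: `masked_fill` writes a scalar wherever the boolean mask is true, and the output dtype follows the input tensor rather than the fill value, which is exactly the mixed-type behavior checked above:

```python
import torch

x = torch.zeros(1, 2, 3)
mask = torch.tensor([[[False, True, False],
                      [True, False, True]]])
print(torch.masked_fill(x, mask, 2))  # 2.0 at masked positions, float output

xi = torch.zeros(1, 2, 3, dtype=torch.int32)
print(torch.masked_fill(xi, mask, 4.0).dtype)  # torch.int32, as the tests expect
```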

tests/core/conversion/converters/test_reduce.cpp

Lines changed: 0 additions & 237 deletions

@@ -392,240 +392,3 @@ TEST(Converters, ATenAllDimDynamicConvertsCorrectly) {
   auto in = at::randint(0, 2, {64, 2}, at::kCUDA).to(torch::kHalf);
   test_body(graph, in, true);
 }
-
-TEST(Converters, UnpackVarLowersCorrectly) {
-  const auto graph = R"IR(
-      graph(%x.1 : Tensor):
-        %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65
-        %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50
-        %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39
-        %6 : int[] = prim::ListConstruct(%3)
-        %7 : Tensor = aten::var(%x.1, %6, %5, %4) # test_zeros.py:10:26
-        return (%7))IR";
-
-  auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA);
-
-  auto g = std::make_shared<torch::jit::Graph>();
-  torch::jit::parseIR(graph, g.get());
-
-  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
-
-  in = at::clone(in);
-  torch_tensorrt::core::lowering::passes::UnpackVar(g);
-  torch::jit::EliminateCommonSubexpression(g);
-  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
-  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
-}
-
-TEST(Converters, UnpackVarKeepDimsLowersCorrectly) {
-  const auto graph = R"IR(
-      graph(%x.1 : Tensor):
-        %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65
-        %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50
-        %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39
-        %6 : int[] = prim::ListConstruct(%3)
-        %7 : Tensor = aten::var(%x.1, %6, %5, %5) # test_zeros.py:10:26
-        return (%7))IR";
-
-  auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA);
-
-  auto g = std::make_shared<torch::jit::Graph>();
-  torch::jit::parseIR(graph, g.get());
-
-  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
-
-  in = at::clone(in);
-  torch_tensorrt::core::lowering::passes::UnpackVar(g);
-  torch::jit::EliminateCommonSubexpression(g);
-  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
-  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
-}
-
-TEST(Converters, UnpackVarUnbiasedLowersCorrectly) {
-  const auto graph = R"IR(
-      graph(%x.1 : Tensor):
-        %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65
-        %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50
-        %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39
-        %6 : int[] = prim::ListConstruct(%3)
-        %7 : Tensor = aten::var(%x.1, %6, %4, %4) # test_zeros.py:10:26
-        return (%7))IR";
-
-  auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA);
-
-  auto g = std::make_shared<torch::jit::Graph>();
-  torch::jit::parseIR(graph, g.get());
-
-  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
-
-  in = at::clone(in);
-  torch_tensorrt::core::lowering::passes::UnpackVar(g);
-  torch::jit::EliminateCommonSubexpression(g);
-  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
-  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
-}
-
-TEST(Converters, UnpackVarUnbiasedKeepDimsLowersCorrectly) {
-  const auto graph = R"IR(
-      graph(%x.1 : Tensor):
-        %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65
-        %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50
-        %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39
-        %6 : int[] = prim::ListConstruct(%3)
-        %7 : Tensor = aten::var(%x.1, %6, %4, %5) # test_zeros.py:10:26
-        return (%7))IR";
-
-  auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA);
-
-  auto g = std::make_shared<torch::jit::Graph>();
-  torch::jit::parseIR(graph, g.get());
-
-  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
-
-  in = at::clone(in);
-  torch_tensorrt::core::lowering::passes::UnpackVar(g);
-  torch::jit::EliminateCommonSubexpression(g);
-  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
-  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
-}
-
-TEST(Converters, UnpackStdLowersCorrectly) {
-  const auto graph = R"IR(
-      graph(%x.1 : Tensor):
-        %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65
-        %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50
-        %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39
-        %6 : int[] = prim::ListConstruct(%3)
-        %7 : Tensor = aten::std(%x.1, %6, %5, %4) # test_zeros.py:10:26
-        return (%7))IR";
-
-  auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA);
-
-  auto g = std::make_shared<torch::jit::Graph>();
-  torch::jit::parseIR(graph, g.get());
-
-  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
-
-  in = at::clone(in);
-  torch_tensorrt::core::lowering::passes::UnpackStd(g);
-  torch_tensorrt::core::lowering::passes::UnpackVar(g);
-  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
-  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
-}
-
-TEST(Converters, UnpackStdKeepDimsLowersCorrectly) {
-  const auto graph = R"IR(
-      graph(%x.1 : Tensor):
-        %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65
-        %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50
-        %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39
-        %6 : int[] = prim::ListConstruct(%3)
-        %7 : Tensor = aten::std(%x.1, %6, %5, %5) # test_zeros.py:10:26
-        return (%7))IR";
-
-  auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA);
-
-  auto g = std::make_shared<torch::jit::Graph>();
-  torch::jit::parseIR(graph, g.get());
-
-  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
-
-  in = at::clone(in);
-  torch_tensorrt::core::lowering::passes::UnpackStd(g);
-  torch_tensorrt::core::lowering::passes::UnpackVar(g);
-  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
-  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
-}
-
-TEST(Converters, UnpackStdUnbiasedLowersCorrectly) {
-  const auto graph = R"IR(
-      graph(%x.1 : Tensor):
-        %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65
-        %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50
-        %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39
-        %6 : int[] = prim::ListConstruct(%3)
-        %7 : Tensor = aten::std(%x.1, %6, %4, %4) # test_zeros.py:10:26
-        return (%7))IR";
-
-  auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA);
-
-  auto g = std::make_shared<torch::jit::Graph>();
-  torch::jit::parseIR(graph, g.get());
-
-  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
-
-  in = at::clone(in);
-  torch_tensorrt::core::lowering::passes::UnpackStd(g);
-  torch_tensorrt::core::lowering::passes::UnpackVar(g);
-  torch::jit::EliminateCommonSubexpression(g);
-  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
-  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
-}
-
-TEST(Converters, UnpackStdUnbiasedKeepDimsLowersCorrectly) {
-  const auto graph = R"IR(
-      graph(%x.1 : Tensor):
-        %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65
-        %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50
-        %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39
-        %one : int = prim::Constant[value=1]()
-        %6 : int[] = prim::ListConstruct(%3, %one)
-        %7 : Tensor = aten::std(%x.1, %6, %4, %5) # test_zeros.py:10:26
-        return (%7))IR";
-
-  auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA);
-
-  auto g = std::make_shared<torch::jit::Graph>();
-  torch::jit::parseIR(graph, g.get());
-
-  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
-
-  in = at::clone(in);
-  torch_tensorrt::core::lowering::passes::UnpackStd(g);
-  torch_tensorrt::core::lowering::passes::UnpackVar(g);
-  torch::jit::EliminateCommonSubexpression(g);
-  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
-  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
-}
-
-TEST(Converters, UnpackVarUnbiasedNegAxisLowersCorrectly) {
-  const auto graph = R"IR(
-      graph(%x.1 : Tensor):
-        %37 : bool = prim::Constant[value=1]()
-        %53 : int[] = prim::Constant[value=[-1]]()
-        %69 : Tensor = aten::var(%x.1, %53, %37, %37)
-        return (%69))IR";
-
-  auto in = at::randint(-5, 5, {2, 20, 768}, at::kCUDA).to(at::kFloat);
-
-  auto jit_in = at::clone(in);
-  auto g = std::make_shared<torch::jit::Graph>();
-  torch::jit::parseIR(graph, g.get());
-
-  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
-
-  in = at::clone(in);
-  torch_tensorrt::core::lowering::passes::UnpackVar(g);
-  torch::jit::EliminateCommonSubexpression(g);
-  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {jit_in});
-
-  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
-}
tests/core/conversion/converters/test_roll.cpp (new file)

Lines changed: 84 additions & 0 deletions

#include <string>
#include "core/compiler.h"
#include "core/lowering/passes/passes.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"

TEST(Converters, ATenRollConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%1 : Tensor):
        %2 : int[] = prim::Constant[value=[1, 0, 3, 7]]()
        %3 : int[] = prim::Constant[value=[0, 1, 2, 3]]()
        %4 : Tensor = aten::roll(%1, %2, %3)
        return (%4))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());

  // Run Pytorch
  auto in = at::randint(1, 10, {2, 3, 4, 5}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
  auto trt = trt_results[0].reshape(jit_results[0].sizes());

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenRollShiftsNegativeConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%1 : Tensor):
        %2 : int[] = prim::Constant[value=[0, -3, -3]]()
        %3 : int[] = prim::Constant[value=[1, 2, 3]]()
        %4 : Tensor = aten::roll(%1, %2, %3)
        return (%4))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());

  // Run Pytorch
  auto in = at::randint(1, 10, {1, 3, 4, 5}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
  auto trt = trt_results[0].reshape(jit_results[0].sizes());

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenRollDimsNegativeConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%1 : Tensor):
        %2 : int[] = prim::Constant[value=[0, -3, -3]]()
        %3 : int[] = prim::Constant[value=[1, 2, -1]]()
        %4 : Tensor = aten::roll(%1, %2, %3)
        return (%4))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());

  // Run Pytorch
  auto in = at::randint(1, 10, {1, 3, 4, 5}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
  auto trt = trt_results[0].reshape(jit_results[0].sizes());

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}
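Eager-mode reference for the cases above; negative shifts roll toward lower indices, and negative dims count from the end:

```python
import torch

x = torch.arange(8).reshape(2, 4)
print(torch.roll(x, shifts=(1, -2), dims=(0, 1)))
# A dim given as -1 names the last axis, matching ATenRollDimsNegativeConvertsCorrectly
print(torch.roll(x, shifts=(1,), dims=(-1,)))
```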
tests/core/conversion/converters/test_scatter.cpp (new file)

Lines changed: 79 additions & 0 deletions

#include <string>
#include "core/compiler.h"
#include "core/lowering/passes/passes.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"

TEST(Converters, ScatterValueConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%data : Tensor,
            %index.1 : Tensor):
        %value : int = prim::Constant[value=100]()
        %dim : int = prim::Constant[value=1]()
        %5 : NoneType = prim::Constant()
        %6 : bool = prim::Constant[value=0]()
        %7 : int = prim::Constant[value=4]()
        %index : Tensor = aten::to(%index.1, %7, %6, %6, %5)
        %10 : Tensor = aten::scatter(%data, %dim, %index, %value)
        return (%10))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());

  auto index = at::randint(0, 5, {2, 2}, {at::kCUDA});
  auto data = at::randn({5, 5}, {at::kCUDA});

  auto jit_index = at::clone(index);
  auto jit_data = at::clone(data);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_data, jit_index});

  auto trt_index = at::clone(index);
  auto trt_data = at::clone(data);
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_data, trt_index});

  for (size_t i = 0; i < jit_results.size(); i++) {
    auto trt = trt_results[i].reshape(jit_results[i].sizes());
    ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
  }
}

TEST(Converters, ScatterSrcConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%data : Tensor,
            %src : Tensor,
            %index.1 : Tensor):
        %dim : int = prim::Constant[value=1]()
        %5 : NoneType = prim::Constant()
        %6 : bool = prim::Constant[value=0]()
        %7 : int = prim::Constant[value=4]()
        %index : Tensor = aten::to(%index.1, %7, %6, %6, %5)
        %10 : Tensor = aten::scatter(%data, %dim, %index, %src)
        return (%10))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());

  auto index = at::randint(0, 4, {2, 2}, {at::kCUDA});
  auto data = at::randn({5, 5}, {at::kCUDA});
  auto src = at::randn({2, 2}, {at::kCUDA});

  auto jit_index = at::clone(index);
  auto jit_data = at::clone(data);
  auto jit_src = at::clone(src);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_data, jit_src, jit_index});

  auto trt_index = at::clone(index);
  auto trt_data = at::clone(data);
  auto trt_src = at::clone(src);
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_data, trt_src, trt_index});

  for (size_t i = 0; i < jit_results.size(); i++) {
    auto trt = trt_results[i].reshape(jit_results[i].sizes());
    ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
  }
}
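The two graphs mirror the two overloads of `aten::scatter`: one fills from a scalar value, the other copies from a `src` tensor, in both cases at positions given by `index` along `dim`. An eager-mode sketch:

```python
import torch

data = torch.zeros(5, 5)
index = torch.tensor([[0, 1], [2, 3]])

# Scalar overload: data[i][index[i][j]] = 100 along dim=1
print(data.scatter(1, index, 100.0))

# Tensor overload: data[i][index[i][j]] = src[i][j] along dim=1
src = torch.arange(1.0, 5.0).reshape(2, 2)
print(data.scatter(1, index, src))
```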

tests/core/conversion/converters/test_select.cpp

Lines changed: 0 additions & 1200 deletions
Large diffs are not rendered by default.
tests/core/conversion/converters/test_slice.cpp (new file)

Lines changed: 332 additions & 0 deletions

#include <string>
#include "core/compiler.h"
#include "core/lowering/passes/passes.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"

TEST(Converters, ATenSliceConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor):
        %2 : None = prim::Constant()
        %3 : int = prim::Constant[value=2]()
        %4 : int = prim::Constant[value=4]()
        %5 : int = prim::Constant[value=1]()
        %6 : int = prim::Constant[value=0]()
        %7 : Tensor = aten::select(%x.1, %6, %6)
        %8 : Tensor = aten::select(%7, %6, %5)
        %9 : Tensor = aten::slice(%8, %6, %5, %4, %3)
        %10 : Tensor = aten::slice(%9, %5, %2, %2, %5)
        return (%10))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());

  auto in = at::randint(1, 10, {1, 3, 5, 5}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
  auto trt = trt_results[0].reshape(jit_results[0].sizes());

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenSliceNegStartIndexConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor):
        %2 : int = prim::Constant[value=1]()
        %3 : int = prim::Constant[value=9223372036854775807]()
        %4 : int = prim::Constant[value=-2]()
        %5 : int = prim::Constant[value=0]()
        %6 : Tensor = aten::slice(%x.1, %5, %4, %3, %2)
        %7 : Tensor = aten::slice(%6, %2, %5, %3, %2)
        return (%7))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());

  auto in = at::randint(1, 10, {6, 3}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
  auto trt = trt_results[0].reshape(jit_results[0].sizes());

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenSliceNegEndIndexConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor):
        %2 : int = prim::Constant[value=3]()
        %3 : int = prim::Constant[value=9223372036854775807]()
        %4 : int = prim::Constant[value=2]()
        %5 : int = prim::Constant[value=-3]()
        %6 : int = prim::Constant[value=1]()
        %7 : int = prim::Constant[value=-2]()
        %8 : int = prim::Constant[value=0]()
        %9 : Tensor = aten::slice(%x.1, %8, %8, %7, %6)
        %10 : Tensor = aten::slice(%9, %6, %8, %5, %6)
        %11 : Tensor = aten::slice(%10, %4, %8, %3, %6)
        %12 : Tensor = aten::slice(%11, %2, %8, %3, %6)
        return (%12))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());

  auto in = at::randint(1, 10, {6, 5, 3, 3}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
  auto trt = trt_results[0].reshape(jit_results[0].sizes());

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenSliceListConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%x : Tensor):
        %1 : NoneType = prim::Constant()
        %2 : int = prim::Constant[value=2]()
        %3 : int = prim::Constant[value=1]()
        %4 : int = prim::Constant[value=3]()
        %list : Tensor[] = aten::unbind(%x, %4)
        %slice : Tensor[] = aten::slice(%list, %1, %2, %3)
        %out.1 : Tensor, %out.2 : Tensor = prim::ListUnpack(%slice)
        return (%out.1, %out.2))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());

  auto in_x = at::randint(1, 10, {6, 5, 3, 3}, {at::kCUDA});

  auto jit_in_x = at::clone(in_x);

  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in_x});

  auto trt_in_x = at::clone(in_x);
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in_x});

  for (size_t i = 0; i < jit_results.size(); i++) {
    auto trt = trt_results[i].reshape(jit_results[i].sizes());
    ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
  }
}

TEST(Converters, ATenSliceDynamicBatchConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor):
        %2 : None = prim::Constant()
        %dim : int = prim::Constant[value=0]()
        %start : int = prim::Constant[value=1]()
        %end : int = prim::Constant[value=15]()
        %step : int = prim::Constant[value=2]()
        %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step)
        return (%9))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());

  auto in = at::randint(1, 10, {16, 32}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  // dynamic shape in batch
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true);
  auto trt = trt_results[0].reshape(jit_results[0].sizes());

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenSliceDynamicBatchLargeEndConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor):
        %2 : None = prim::Constant()
        %dim : int = prim::Constant[value=0]()
        %start : int = prim::Constant[value=1]()
        %end : int = prim::Constant[value=9223372036854775807]()
        %step : int = prim::Constant[value=2]()
        %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step)
        return (%9))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());

  auto in = at::randint(1, 10, {16, 32}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  // dynamic shape in batch
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true);
  auto trt = trt_results[0].reshape(jit_results[0].sizes());

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenSliceDynamicNegStartBatchConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor):
        %2 : None = prim::Constant()
        %dim : int = prim::Constant[value=0]()
        %start : int = prim::Constant[value=-15]()
        %end : int = prim::Constant[value=15]()
        %step : int = prim::Constant[value=2]()
        %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step)
        return (%9))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());

  auto in = at::randint(1, 10, {16, 32}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {
209+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
210+
211+
auto trt_in = at::clone(in);
212+
// dynamic shape in batch
213+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true);
214+
auto trt = trt_results[0].reshape(jit_results[0].sizes());
215+
216+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
217+
}
218+
219+
TEST(Converters, ATenSliceDynamicNegEndBatchConvertsCorrectly) {
220+
const auto graph = R"IR(
221+
graph(%x.1 : Tensor):
222+
%2 : None = prim::Constant()
223+
%dim : int = prim::Constant[value=0]()
224+
%start : int = prim::Constant[value=1]()
225+
%end : int = prim::Constant[value=-2]()
226+
%step : int = prim::Constant[value=3]()
227+
%9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step)
228+
return (%9))IR";
229+
230+
auto g = std::make_shared<torch::jit::Graph>();
231+
232+
torch::jit::parseIR(graph, g.get());
233+
234+
auto in = at::randint(1, 10, {16, 32}, {at::kCUDA});
235+
236+
auto jit_in = at::clone(in);
237+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
238+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
239+
240+
auto trt_in = at::clone(in);
241+
// dynamic shape in batch
242+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true);
243+
auto trt = trt_results[0].reshape(jit_results[0].sizes());
244+
245+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
246+
}
247+
248+
TEST(Converters, ATenSliceDynamicNoneBatchConvertsCorrectly) {
249+
const auto graph = R"IR(
250+
graph(%x.1 : Tensor):
251+
%dim : int = prim::Constant[value=0]()
252+
%start : None = prim::Constant()
253+
%end : None = prim::Constant()
254+
%step : int = prim::Constant[value=3]()
255+
%9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step)
256+
return (%9))IR";
257+
258+
auto g = std::make_shared<torch::jit::Graph>();
259+
260+
torch::jit::parseIR(graph, g.get());
261+
262+
auto in = at::randint(1, 10, {16, 32}, {at::kCUDA});
263+
264+
auto jit_in = at::clone(in);
265+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
266+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
267+
268+
auto trt_in = at::clone(in);
269+
// dynamic shape in batch
270+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true);
271+
auto trt = trt_results[0].reshape(jit_results[0].sizes());
272+
273+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
274+
}
275+
276+
TEST(Converters, ATenSliceDynamicConvertsCorrectly) {
277+
const auto graph = R"IR(
278+
graph(%x.1 : Tensor):
279+
%2 : None = prim::Constant()
280+
%dim : int = prim::Constant[value=1]()
281+
%start : int = prim::Constant[value=3]()
282+
%end : int = prim::Constant[value=32]()
283+
%step : int = prim::Constant[value=3]()
284+
%9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step)
285+
return (%9))IR";
286+
287+
auto g = std::make_shared<torch::jit::Graph>();
288+
289+
torch::jit::parseIR(graph, g.get());
290+
291+
auto in = at::randint(1, 10, {16, 32}, {at::kCUDA});
292+
293+
auto jit_in = at::clone(in);
294+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
295+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
296+
297+
auto trt_in = at::clone(in);
298+
// dynamic shape in dim 1, slice in dim 1
299+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, false);
300+
auto trt = trt_results[0].reshape(jit_results[0].sizes());
301+
302+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
303+
}
304+
305+
TEST(Converters, ATenSliceDynamic2ConvertsCorrectly) {
306+
const auto graph = R"IR(
307+
graph(%x.1 : Tensor):
308+
%2 : None = prim::Constant()
309+
%dim : int = prim::Constant[value=1]()
310+
%start : int = prim::Constant[value=3]()
311+
%end : int = prim::Constant[value=17]()
312+
%step : int = prim::Constant[value=3]()
313+
%9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step)
314+
return (%9))IR";
315+
316+
auto g = std::make_shared<torch::jit::Graph>();
317+
318+
torch::jit::parseIR(graph, g.get());
319+
320+
auto in = at::randint(1, 10, {16, 32}, {at::kCUDA});
321+
322+
auto jit_in = at::clone(in);
323+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
324+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
325+
326+
auto trt_in = at::clone(in);
327+
// dynamic shape in batch, slice in dim 1
328+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true);
329+
auto trt = trt_results[0].reshape(jit_results[0].sizes());
330+
331+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
332+
}
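The converters above are validated against eager ATen results, so the semantics being pinned down are exactly libtorch's: a negative start or end wraps around the dimension, and the sentinel end 9223372036854775807 (INT64_MAX) clamps to the dimension size. A minimal eager-mode sketch of those two behaviors, assuming only a stock libtorch install (this snippet is illustrative and not part of the commit):

``` cpp
#include <torch/torch.h>

#include <iostream>
#include <limits>

int main() {
  auto x = torch::arange(18).reshape({6, 3});

  // aten::slice(x, dim=0, start=-2, end=INT64_MAX, step=1):
  // a negative start counts from the end; the INT64_MAX end clamps to the dim size.
  auto tail = x.slice(/*dim=*/0, /*start=*/-2,
                      /*end=*/std::numeric_limits<int64_t>::max(), /*step=*/1);

  // aten::slice(x, dim=0, start=0, end=-2, step=1): a negative end also wraps.
  auto head = x.slice(0, 0, -2, 1);

  std::cout << tail.sizes() << " " << head.sizes() << std::endl;  // [2, 3] [4, 3]
  return 0;
}
```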
Lines changed: 174 additions & 0 deletions
@@ -0,0 +1,174 @@
#include <string>
#include "core/compiler.h"
#include "core/lowering/passes/passes.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"

TEST(Converters, ATenSplitSizesInScriptingConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%x.1 : Tensor):
      %2 : int[] = prim::Constant[value=[1, 2]]()
      %3 : int = prim::Constant[value=1]()
      %4 : Tensor[] = aten::split(%x.1, %2, %3)
      %x1.1 : Tensor, %x2.1 : Tensor = prim::ListUnpack(%4)
      return (%x1.1, %x2.1))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());

  auto in = at::randint(1, 10, {1, 3, 4, 4}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});

  for (size_t i = 0; i < jit_results.size(); i++) {
    auto trt = trt_results[i].reshape(jit_results[i].sizes());
    ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
  }
}

TEST(Converters, ATenSplitSizesInTracingConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%argument_1.1 : Tensor):
      %2 : int[] = prim::Constant[value=[1, 2]]()
      %3 : int = prim::Constant[value=1]()
      %4 : Tensor[] = aten::split_with_sizes(%argument_1.1, %2, %3)
      %5 : Tensor, %6 : Tensor = prim::ListUnpack(%4)
      return (%5, %6))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());

  auto in = at::randint(1, 10, {1, 3, 4, 4}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});

  for (size_t i = 0; i < jit_results.size(); i++) {
    auto trt = trt_results[i].reshape(jit_results[i].sizes());
    ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
  }
}

TEST(Converters, ATenSplitFixedConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%argument_1.1 : Tensor):
      %2 : int = prim::Constant[value=1]()
      %3 : Tensor[] = aten::split(%argument_1.1, %2, %2)
      %4 : Tensor, %5 : Tensor, %6 : Tensor = prim::ListUnpack(%3)
      return (%4, %5, %6))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());

  auto in = at::randint(1, 10, {1, 3, 4, 4}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});

  for (size_t i = 0; i < jit_results.size(); i++) {
    auto trt = trt_results[i].reshape(jit_results[i].sizes());
    ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
  }
}

TEST(Converters, ATenSplitFixedHasRemainderConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%argument_1.1 : Tensor):
      %2 : int = prim::Constant[value=2]()
      %2.1 : int = prim::Constant[value=1]()
      %3 : Tensor[] = aten::split(%argument_1.1, %2, %2.1)
      %4 : Tensor, %5 : Tensor, %6 : Tensor = prim::ListUnpack(%3)
      return (%4, %5, %6))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());

  auto in = at::randint(1, 10, {1, 5, 4, 4}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});

  for (size_t i = 0; i < jit_results.size(); i++) {
    auto trt = trt_results[i].reshape(jit_results[i].sizes());
    ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
  }
}

TEST(Converters, ATenSplitAndAddConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%argument_1.1 : Tensor):
      %2 : int = prim::Constant[value=2]()
      %2.1 : int = prim::Constant[value=1]()
      %3 : Tensor[] = aten::split(%argument_1.1, %2, %2.1)
      %4 : Tensor, %5 : Tensor = prim::ListUnpack(%3)
      %6 : Tensor = aten::add(%4, %5, %2.1)
      return (%6))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());

  auto in = at::randint(1, 10, {1, 4, 4, 4}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});

  for (size_t i = 0; i < jit_results.size(); i++) {
    auto trt = trt_results[i].reshape(jit_results[i].sizes());
    ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
  }
}

TEST(Converters, ATenSplitNegativeDimsConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%x.1 : Tensor):
      %2 : int = prim::Constant[value=1]()
      %n1 : int = prim::Constant[value=-1]()
      %3 : Tensor[] = aten::split(%x.1, %2, %n1)
      %4 : Tensor, %5 : Tensor, %6 : Tensor, %7 : Tensor = prim::ListUnpack(%3)
      return (%4, %5, %6, %7))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());

  auto in = at::randint(1, 10, {1, 3, 4, 4}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});

  for (size_t i = 0; i < jit_results.size(); i++) {
    auto trt = trt_results[i].reshape(jit_results[i].sizes());
    ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
  }
}
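For reference, the eager behavior the split converters are held to: aten::split with an int[] of widths (scripting) or aten::split_with_sizes (tracing) cuts a dimension into the given widths, while aten::split with a single int yields fixed-size chunks plus a smaller remainder, which is what the HasRemainder test exercises. A short libtorch sketch under the same stock-install assumption as the earlier snippet:

``` cpp
#include <torch/torch.h>

#include <iostream>

int main() {
  auto x = torch::randint(1, 10, {1, 5, 4, 4});

  // aten::split with a fixed size: dim 1 has extent 5, so chunks of width 2
  // leave a remainder chunk of width 1 -> widths 2, 2, 1.
  auto fixed = x.split(/*split_size=*/2, /*dim=*/1);
  std::cout << fixed.size() << " chunks" << std::endl;  // 3 chunks

  // aten::split_with_sizes: the widths must sum to the dimension extent.
  auto sized = x.split_with_sizes(/*split_sizes=*/{1, 2, 2}, /*dim=*/1);
  for (const auto& t : sized) {
    std::cout << t.sizes() << std::endl;  // [1, 1, 4, 4] [1, 2, 4, 4] [1, 2, 4, 4]
  }
  return 0;
}
```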
Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
#include <string>
#include "core/compiler.h"
#include "core/lowering/passes/passes.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"

TEST(Converters, ATenUnbindConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%x.1 : Tensor):
      %2 : int = prim::Constant[value=0]()
      %3 : Tensor[] = aten::unbind(%x.1, %2)
      %o1.1 : Tensor, %o2.1 : Tensor = prim::ListUnpack(%3)
      return (%o1.1, %o2.1))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());

  auto in = at::randint(1, 10, {2, 3, 4, 4}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});

  for (size_t i = 0; i < jit_results.size(); i++) {
    auto trt = trt_results[i];
    ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
  }
}

TEST(Converters, ATenUnbindNegativeAxisConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%x.1 : Tensor):
      %2 : int = prim::Constant[value=-1]()
      %3 : Tensor[] = aten::unbind(%x.1, %2)
      %o1.1 : Tensor, %o2.1 : Tensor = prim::ListUnpack(%3)
      return (%o1.1, %o2.1))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());

  auto in = at::randint(1, 10, {5, 2}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});

  for (size_t i = 0; i < jit_results.size(); i++) {
    auto trt = trt_results[i];
    ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
  }
}

TEST(Converters, ATenUnbindEvaluatedTensor) {
  const auto graph = R"IR(
    graph(%x.1 : Tensor):
      %2 : None = prim::Constant()
      %3 : int[] = aten::size(%x.1)
      %z.1 : Tensor = aten::zeros(%3, %2, %2, %2, %2)
      %5 : int = prim::Constant[value=-1]()
      %6 : Tensor[] = aten::unbind(%z.1, %5)
      %o1.1 : Tensor, %o2.1 : Tensor = prim::ListUnpack(%6)
      return (%o1.1, %o2.1))IR";

  auto in = at::randint(1, 10, {2}, {at::kCUDA});

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());

  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});

  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});

  for (size_t i = 0; i < jit_results.size(); i++) {
    auto trt = trt_results[i];
    ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i].cuda(), trt, 2e-6));
  }
}
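The behavior being matched here is that aten::unbind removes the given dimension entirely, returning one tensor per index along it, with negative axes wrapping as usual. A tiny eager sketch, same assumptions as above:

``` cpp
#include <torch/torch.h>

#include <iostream>

int main() {
  auto x = torch::randint(1, 10, {5, 2});

  auto rows = x.unbind(0);   // the dim is removed: 5 tensors of shape [2]
  auto cols = x.unbind(-1);  // negative axis wraps: 2 tensors of shape [5]

  std::cout << rows.size() << " " << cols.size() << std::endl;  // 5 2
  return 0;
}
```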
Lines changed: 245 additions & 0 deletions
@@ -0,0 +1,245 @@
#include <string>
#include "core/compiler.h"
#include "core/lowering/passes/passes.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
#include "torch/torch.h"

TEST(Converters, UnpackVarLowersCorrectly) {
  const auto graph = R"IR(
    graph(%x.1 : Tensor):
      %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65
      %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50
      %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39
      %6 : int[] = prim::ListConstruct(%3)
      %7 : Tensor = aten::var(%x.1, %6, %5, %4) # test_zeros.py:10:26
      return (%7))IR";

  auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA);

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});

  in = at::clone(in);
  torch_tensorrt::core::lowering::passes::UnpackVar(g);
  torch::jit::EliminateCommonSubexpression(g);
  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

TEST(Converters, UnpackVarKeepDimsLowersCorrectly) {
  const auto graph = R"IR(
    graph(%x.1 : Tensor):
      %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65
      %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50
      %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39
      %6 : int[] = prim::ListConstruct(%3)
      %7 : Tensor = aten::var(%x.1, %6, %5, %5) # test_zeros.py:10:26
      return (%7))IR";

  auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA);

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});

  in = at::clone(in);
  torch_tensorrt::core::lowering::passes::UnpackVar(g);
  torch::jit::EliminateCommonSubexpression(g);
  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

TEST(Converters, UnpackVarUnbiasedLowersCorrectly) {
  const auto graph = R"IR(
    graph(%x.1 : Tensor):
      %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65
      %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50
      %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39
      %6 : int[] = prim::ListConstruct(%3)
      %7 : Tensor = aten::var(%x.1, %6, %4, %4) # test_zeros.py:10:26
      return (%7))IR";

  auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA);

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});

  in = at::clone(in);
  torch_tensorrt::core::lowering::passes::UnpackVar(g);
  torch::jit::EliminateCommonSubexpression(g);
  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

TEST(Converters, UnpackVarUnbiasedKeepDimsLowersCorrectly) {
  const auto graph = R"IR(
    graph(%x.1 : Tensor):
      %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65
      %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50
      %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39
      %6 : int[] = prim::ListConstruct(%3)
      %7 : Tensor = aten::var(%x.1, %6, %4, %5) # test_zeros.py:10:26
      return (%7))IR";

  auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA);

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});

  in = at::clone(in);
  torch_tensorrt::core::lowering::passes::UnpackVar(g);
  torch::jit::EliminateCommonSubexpression(g);
  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

TEST(Converters, UnpackStdLowersCorrectly) {
  const auto graph = R"IR(
    graph(%x.1 : Tensor):
      %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65
      %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50
      %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39
      %6 : int[] = prim::ListConstruct(%3)
      %7 : Tensor = aten::std(%x.1, %6, %5, %4) # test_zeros.py:10:26
      return (%7))IR";

  auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA);

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});

  in = at::clone(in);
  torch_tensorrt::core::lowering::passes::UnpackStd(g);
  torch_tensorrt::core::lowering::passes::UnpackVar(g);
  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

TEST(Converters, UnpackStdKeepDimsLowersCorrectly) {
  const auto graph = R"IR(
    graph(%x.1 : Tensor):
      %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65
      %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50
      %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39
      %6 : int[] = prim::ListConstruct(%3)
      %7 : Tensor = aten::std(%x.1, %6, %5, %5) # test_zeros.py:10:26
      return (%7))IR";

  auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA);

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});

  in = at::clone(in);
  torch_tensorrt::core::lowering::passes::UnpackStd(g);
  torch_tensorrt::core::lowering::passes::UnpackVar(g);
  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

TEST(Converters, UnpackStdUnbiasedLowersCorrectly) {
  const auto graph = R"IR(
    graph(%x.1 : Tensor):
      %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65
      %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50
      %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39
      %6 : int[] = prim::ListConstruct(%3)
      %7 : Tensor = aten::std(%x.1, %6, %4, %4) # test_zeros.py:10:26
      return (%7))IR";

  auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA);

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});

  in = at::clone(in);
  torch_tensorrt::core::lowering::passes::UnpackStd(g);
  torch_tensorrt::core::lowering::passes::UnpackVar(g);
  torch::jit::EliminateCommonSubexpression(g);
  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

TEST(Converters, UnpackStdUnbiasedKeepDimsLowersCorrectly) {
  const auto graph = R"IR(
    graph(%x.1 : Tensor):
      %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65
      %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50
      %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39
      %one : int = prim::Constant[value=1]()
      %6 : int[] = prim::ListConstruct(%3, %one)
      %7 : Tensor = aten::std(%x.1, %6, %4, %5) # test_zeros.py:10:26
      return (%7))IR";

  auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA);

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});

  in = at::clone(in);
  torch_tensorrt::core::lowering::passes::UnpackStd(g);
  torch_tensorrt::core::lowering::passes::UnpackVar(g);
  torch::jit::EliminateCommonSubexpression(g);
  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

TEST(Converters, UnpackVarUnbiasedNegAxisLowersCorrectly) {
  const auto graph = R"IR(
    graph(%x.1 : Tensor):
      %37 : bool = prim::Constant[value=1]()
      %53 : int[] = prim::Constant[value=[-1]]()
      %69 : Tensor = aten::var(%x.1, %53, %37, %37)
      return (%69))IR";

  auto in = at::randint(-5, 5, {2, 20, 768}, at::kCUDA).to(at::kFloat);

  auto jit_in = at::clone(in);
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  in = at::clone(in);
  torch_tensorrt::core::lowering::passes::UnpackVar(g);
  torch::jit::EliminateCommonSubexpression(g);
  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {jit_in});

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}
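These lowering tests lean on two identities: std(x) = sqrt(var(x)) for matching flags, which is how UnpackStd reduces aten::std to aten::var, and the decomposition var(x) = E[x^2] - E[x]^2, rescaled by n/(n-1) (Bessel's correction) in the unbiased case. The eager-mode check below illustrates both identities; it is a sketch of the math being relied on, not the pass implementation itself:

``` cpp
#include <torch/torch.h>

#include <iostream>

int main() {
  auto x = torch::randn({4, 4, 4});

  // Identity used by UnpackStd: std == sqrt(var) for the same dim/flags.
  auto s = x.std(/*dim=*/{0}, /*unbiased=*/true, /*keepdim=*/true);
  auto v = x.var({0}, true, true);
  std::cout << torch::allclose(s, v.sqrt()) << std::endl;  // 1

  // Variance decomposition: E[x^2] - E[x]^2, rescaled by n/(n-1) to get
  // the unbiased estimator.
  auto mean = x.mean(/*dim=*/{0}, /*keepdim=*/true);
  auto biased = (x * x).mean({0}, true) - mean * mean;
  int64_t n = x.size(0);
  auto unbiased = biased * (static_cast<double>(n) / (n - 1));
  std::cout << torch::allclose(v, unbiased, /*rtol=*/1e-4, /*atol=*/1e-6)
            << std::endl;  // 1
  return 0;
}
```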
Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
#include <torch/torch.h>
#include <string>
#include "core/compiler.h"
#include "core/lowering/passes/passes.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"

TEST(Converters, WhereConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%condition : Tensor,
          %x : Tensor,
          %y : Tensor):
      %out : Tensor = aten::where(%condition, %x, %y)
      return (%out))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());

  auto condition = at::randint(0, 2, {5, 5}, {at::kCUDA}).to(torch::kBool);
  auto x = at::randn({5, 5}, {at::kCUDA});
  auto y = at::randn({5, 5}, {at::kCUDA});

  auto jit_condition = at::clone(condition);
  auto jit_x = at::clone(x);
  auto jit_y = at::clone(y);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_condition, jit_x, jit_y});

  auto trt_condition = at::clone(condition);
  auto trt_x = at::clone(x);
  auto trt_y = at::clone(y);
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_condition, trt_x, trt_y});

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

TEST(Converters, WhereConvertsMismatchedShapesCorrectly) {
  const auto graph = R"IR(
    graph(%condition : Tensor,
          %x : Tensor,
          %y : Tensor):
      %out : Tensor = aten::where(%condition, %x, %y)
      return (%out))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());

  // Per Torch behavior, the input Tensors are broadcast to the shape of the
  // largest-rank Tensor provided
  auto condition = at::randint(0, 2, {7, 5}, {at::kCUDA}).to(torch::kBool);
  auto x = at::randn({2, 7, 5}, {at::kCUDA});
  auto y = at::randn({5}, {at::kCUDA});

  auto jit_condition = at::clone(condition);
  auto jit_x = at::clone(x);
  auto jit_y = at::clone(y);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_condition, jit_x, jit_y});

  auto trt_condition = at::clone(condition);
  auto trt_x = at::clone(x);
  auto trt_y = at::clone(y);
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_condition, trt_x, trt_y});

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}
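As the comment in the mismatched-shape test says, aten::where broadcasts all three inputs. A compact eager check of the exact shape combination used above, under the same stock-libtorch assumption as the earlier sketches:

``` cpp
#include <torch/torch.h>

#include <iostream>

int main() {
  // condition [7, 5], x [2, 7, 5], and y [5] all broadcast to [2, 7, 5].
  auto condition = torch::randint(0, 2, {7, 5}).to(torch::kBool);
  auto x = torch::randn({2, 7, 5});
  auto y = torch::randn({5});

  auto out = torch::where(condition, x, y);
  std::cout << out.sizes() << std::endl;  // [2, 7, 5]
  return 0;
}
```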

‎third_party/tensorrt/archive/BUILD

Lines changed: 2 additions & 8 deletions

@@ -46,10 +46,7 @@ cc_library(
        "nvinfer_lib",
        "@cuda//:cudart",
        "@cudnn",
-    ] + select({
-        ":windows": ["@cuda//:cublas"],
-        "//conditions:default": ["@cuda//:cublas"],
-    }),
+    ],
)

####################################################################################

@@ -186,8 +183,5 @@ cc_library(
        "nvinferplugin_lib",
        "@cuda//:cudart",
        "@cudnn",
-    ] + select({
-        ":windows": ["@cuda//:cublas"],
-        "//conditions:default": ["@cuda//:cublas"],
-    }),
+    ],
)
‎third_party/tensorrt/local/BUILD

Lines changed: 2 additions & 8 deletions

@@ -113,10 +113,7 @@ cc_library(
        "nvinfer_lib",
        "@cuda//:cudart",
        "@cudnn",
-    ] + select({
-        ":windows": ["@cuda//:cublas"],
-        "//conditions:default": ["@cuda//:cublas"],
-    }),
+    ],
)

####################################################################################

@@ -370,9 +367,6 @@ cc_library(
        "nvinfer",
        "@cuda//:cudart",
        "@cudnn",
-    ] + select({
-        ":windows": ["@cuda//:cublas"],
-        "//conditions:default": ["@cuda//:cublas"],
-    }),
+    ],
    alwayslink = True,
)

‎toolchains/ci_workspaces/WORKSPACE.x86_64

Lines changed: 1 addition & 1 deletion

@@ -41,7 +41,7 @@ local_repository(
new_local_repository(
    name = "cuda",
    build_file = "@//third_party/cuda:BUILD",
-    path = "/usr/local/cuda/",
+    path = "/usr/local/cuda-11.8/",
)

new_local_repository(

‎toolchains/ci_workspaces/WORKSPACE.x86_64.release.rhel

Lines changed: 5 additions & 5 deletions

@@ -41,7 +41,7 @@ local_repository(
new_local_repository(
    name = "cuda",
    build_file = "@//third_party/cuda:BUILD",
-    path = "/usr/local/cuda-11.7",
+    path = "/usr/local/cuda-11.8",
)

new_local_repository(

@@ -56,17 +56,17 @@ new_local_repository(
http_archive(
    name = "libtorch",
    build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "7c4b8754830fef23ec19c5eaf414794cee9597b435df055f5c1d0471d3e81568",
+    sha256 = "1a526a9cd19c1015674d26921dbb94bcd2d632a6f9c431a21c43f4e24768d834",
    strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/nightly/cu117/libtorch-cxx11-abi-shared-with-deps-2.1.0.dev20230314%2Bcu117.zip"],
+    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.0.dev20230419%2Bcu118.zip"],
)

http_archive(
    name = "libtorch_pre_cxx11_abi",
    build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "f1e64a75dd12d0ba4c8c1f61947299e0a9c50684dff64f0cfbf355aa7a13e8cf",
+    sha256 = "60c5912a5085a6a7073b3804b10d41d6cc054693bbeb7a45e0247050c2837bac",
    strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/nightly/cu117/libtorch-shared-with-deps-2.1.0.dev20230314%2Bcu117.zip"],
+    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-shared-with-deps-2.1.0.dev20230419%2Bcu118.zip"],
)

####################################################################################

‎toolchains/ci_workspaces/WORKSPACE.x86_64.release.ubuntu

Lines changed: 5 additions & 5 deletions

@@ -41,7 +41,7 @@ local_repository(
new_local_repository(
    name = "cuda",
    build_file = "@//third_party/cuda:BUILD",
-    path = "/usr/local/cuda",
+    path = "/usr/local/cuda-11.8",
)

new_local_repository(

@@ -56,17 +56,17 @@ new_local_repository(
http_archive(
    name = "libtorch",
    build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "7c4b8754830fef23ec19c5eaf414794cee9597b435df055f5c1d0471d3e81568",
+    sha256 = "1a526a9cd19c1015674d26921dbb94bcd2d632a6f9c431a21c43f4e24768d834",
    strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/nightly/cu117/libtorch-cxx11-abi-shared-with-deps-2.1.0.dev20230314%2Bcu117.zip"],
+    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.0.dev20230419%2Bcu118.zip"],
)

http_archive(
    name = "libtorch_pre_cxx11_abi",
    build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "f1e64a75dd12d0ba4c8c1f61947299e0a9c50684dff64f0cfbf355aa7a13e8cf",
+    sha256 = "60c5912a5085a6a7073b3804b10d41d6cc054693bbeb7a45e0247050c2837bac",
    strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/nightly/cu117/libtorch-shared-with-deps-2.1.0.dev20230314%2Bcu117.zip"],
+    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-shared-with-deps-2.1.0.dev20230419%2Bcu118.zip"],
)

####################################################################################

‎tools/cpp_benchmark/README.md

Lines changed: 1 addition & 1 deletion

@@ -6,7 +6,7 @@ This is a quick benchmarking application for Torch-TensorRT. It lets you run sup

Run with bazel:

-> Note: Make sure libtorch and TensorRT are in your LD_LIBRARY_PATH before running, if you need a location you can `export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:[WORKSPACE ROOT]/bazel-Torch-TensorRT/external/libtorch/lib:[WORKSPACE ROOT]/bazel-Torch-TensorRT/external/tensorrt/lib`
+> Note: Make sure libtorch and TensorRT are on your LD_LIBRARY_PATH before running; if you need a location, you can `export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:[WORKSPACE ROOT]/bazel-TensorRT/external/libtorch/lib:[WORKSPACE ROOT]/bazel-TensorRT/external/tensorrt/lib`

``` sh
bazel run //tools/cpp_benchmark --cxxopt="-DNDEBUG" --cxxopt="-DJIT" --cxxopt="-DTRT" -- [PATH TO JIT MODULE FILE] [INPUT SIZE (explicit batch)]
