Skip to content

Commit 543f436

Browse files
authored
chore: Make from and to methods use the same TRT API (#2858)
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 4b076c7 commit 543f436

File tree

4 files changed

+26
-20
lines changed

4 files changed

+26
-20
lines changed

.github/workflows/build-test.yml renamed to .github/workflows/build-test-linux.yml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: Build and test linux wheels
1+
name: Build and test Linux wheels
22

33
on:
44
pull_request:
@@ -86,7 +86,7 @@ jobs:
8686
popd
8787
pushd .
8888
cd tests/py/ts
89-
${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
89+
${CONDA_RUN} python -m pip install --pre -r ../requirements.txt
9090
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/ts_api_test_results.xml api/
9191
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/ts_models_test_results.xml models/
9292
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/ts_integrations_test_results.xml integrations/
@@ -117,7 +117,7 @@ jobs:
117117
export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.1.6/lib:$LD_LIBRARY_PATH
118118
pushd .
119119
cd tests/py/dynamo
120-
${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
120+
${CONDA_RUN} python -m pip install --pre -r ../requirements.txt
121121
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 10 conversion/
122122
popd
123123
@@ -146,7 +146,7 @@ jobs:
146146
export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.1.6/lib:$LD_LIBRARY_PATH
147147
pushd .
148148
cd tests/py/dynamo
149-
${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
149+
${CONDA_RUN} python -m pip install --pre -r ../requirements.txt
150150
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_models_export.py
151151
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/test_dyn_models.py
152152
popd
@@ -176,7 +176,7 @@ jobs:
176176
export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.1.6/lib:$LD_LIBRARY_PATH
177177
pushd .
178178
cd tests/py/dynamo
179-
${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
179+
${CONDA_RUN} python -m pip install --pre -r ../requirements.txt
180180
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py
181181
popd
182182
@@ -205,7 +205,7 @@ jobs:
205205
export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.1.6/lib:$LD_LIBRARY_PATH
206206
pushd .
207207
cd tests/py/dynamo
208-
${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
208+
${CONDA_RUN} python -m pip install --pre -r ../requirements.txt
209209
${CONDA_RUN} python -m pytest -n 10 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_compile_be_test_results.xml backend/
210210
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_comple_be_e2e_test_results.xml --ir torch_compile models/test_models.py
211211
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_compile_dyn_models_export.xml --ir torch_compile models/test_dyn_models.py
@@ -236,7 +236,7 @@ jobs:
236236
export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.1.6/lib:$LD_LIBRARY_PATH
237237
pushd .
238238
cd tests/py/dynamo
239-
${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
239+
${CONDA_RUN} python -m pip install --pre -r ../requirements.txt
240240
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_runtime_test_results.xml runtime/
241241
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_partitioning_test_results.xml partitioning/
242242
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_lowering_test_results.xml lowering/
@@ -266,6 +266,6 @@ jobs:
266266
export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.1.6/lib:$LD_LIBRARY_PATH
267267
pushd .
268268
cd tests/py/core
269-
${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
269+
${CONDA_RUN} python -m pip install --pre -r ../requirements.txt
270270
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_core_test_results.xml .
271271
popd

.github/workflows/build-test-windows.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ jobs:
7272
export USE_HOST_DEPS=1
7373
pushd .
7474
cd tests/py/dynamo
75-
${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
75+
${CONDA_RUN} python -m pip install --pre -r ../requirements.txt
7676
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 10 conversion/
7777
popd
7878
@@ -98,7 +98,7 @@ jobs:
9898
export USE_HOST_DEPS=1
9999
pushd .
100100
cd tests/py/dynamo
101-
${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
101+
${CONDA_RUN} python -m pip install --pre -r ../requirements.txt
102102
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_models_export.py
103103
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/test_dyn_models.py
104104
popd
@@ -125,7 +125,7 @@ jobs:
125125
export USE_HOST_DEPS=1
126126
pushd .
127127
cd tests/py/dynamo
128-
${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
128+
${CONDA_RUN} python -m pip install --pre -r ../requirements.txt
129129
${CONDA_RUN} python -m pytest -n 10 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_compile_be_test_results.xml backend/
130130
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_comple_be_e2e_test_results.xml --ir torch_compile models/test_models.py
131131
popd
@@ -152,7 +152,7 @@ jobs:
152152
export USE_HOST_DEPS=1
153153
pushd .
154154
cd tests/py/dynamo
155-
${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
155+
${CONDA_RUN} python -m pip install --pre -r ../requirements.txt
156156
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_runtime_test_results.xml runtime/
157157
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_partitioning_test_results.xml partitioning/
158158
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_lowering_test_results.xml lowering/

py/torch_tensorrt/_enums.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -99,21 +99,21 @@ def _from(
9999
f"Provided an unsupported data type as a data type for translation (support: bool, int, long, half, float, bfloat16), got: {t}"
100100
)
101101
elif isinstance(t, trt.DataType):
102-
if t == trt.uint8:
102+
if t == trt.DataType.UINT8:
103103
return dtype.u8
104-
elif t == trt.int8:
104+
elif t == trt.DataType.INT8:
105105
return dtype.i8
106-
elif t == trt.int32:
106+
elif t == trt.DataType.INT32:
107107
return dtype.i32
108-
elif t == trt.int64:
108+
elif t == trt.DataType.INT64:
109109
return dtype.i64
110-
elif t == trt.float16:
110+
elif t == trt.DataType.HALF:
111111
return dtype.f16
112-
elif t == trt.float32:
112+
elif t == trt.DataType.FLOAT:
113113
return dtype.f32
114-
elif t == trt.bool:
114+
elif t == trt.DataType.BOOL:
115115
return dtype.b
116-
elif t == trt.bf16:
116+
elif t == trt.DataType.BF16:
117117
return dtype.bf16
118118
else:
119119
raise TypeError(

tests/py/requirements.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pytest>=8.2.1
2+
pytest-xdist>=3.6.1
3+
timm>=1.0.3
4+
transformers==4.39.3
5+
parameterized>=0.2.0
6+
expecttest==0.1.6

0 commit comments

Comments
 (0)