diff --git a/.circleci/config.yml b/.circleci/config.yml index dcbc84cc9a..16dda8609f 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -435,6 +435,7 @@ commands: mkdir -p /tmp/artifacts/test_results cd tests/py pytest --junitxml=/tmp/artifacts/test_results/api/api_test_results.xml api/ + pytest --junitxml=/tmp/artifacts/test_results/models/models_test_results.xml models/ pytest --junitxml=/tmp/artifacts/test_results/integrations/integrations_test_results.xml integrations/ cd ~/project diff --git a/.github/workflows/docgen.yml b/.github/workflows/docgen.yml index 7b66b98be5..61af5bc5d9 100644 --- a/.github/workflows/docgen.yml +++ b/.github/workflows/docgen.yml @@ -31,7 +31,7 @@ jobs: - name: Set up Python 3.9.4 uses: actions/setup-python@v2 with: - python-version: 3.9.4 + python-version: 3.9.4 - uses: actions/checkout@v2 with: ref: ${{github.head_ref}} diff --git a/.github/workflows/linter.yml b/.github/workflows/linter.yml index 58c8440684..b56a233169 100644 --- a/.github/workflows/linter.yml +++ b/.github/workflows/linter.yml @@ -39,7 +39,7 @@ jobs: pip3 install -r $GITHUB_WORKSPACE/.github/scripts/requirements.txt pip3 install -r $GITHUB_WORKSPACE/requirements-dev.txt - name: Lint C++ - run: | + run: | cd $GITHUB_WORKSPACE python3 $GITHUB_WORKSPACE/.github/scripts/run_cpp_linter.py env: diff --git a/noxfile.py b/noxfile.py index 41926b5ee1..eff8136fbb 100644 --- a/noxfile.py +++ b/noxfile.py @@ -30,6 +30,9 @@ if USE_HOST_DEPS: print("Using dependencies from host python") +# Set epochs to train VGG model for accuracy tests +EPOCHS = 25 + SUPPORTED_PYTHON_VERSIONS = ["3.7", "3.8", "3.9", "3.10"] nox.options.sessions = [ @@ -63,31 +66,6 @@ def install_torch_trt(session): session.run("python", "setup.py", "develop") -def download_datasets(session): - print( - "Downloading dataset to path", - os.path.join(TOP_DIR, "examples/int8/training/vgg16"), - ) - session.chdir(os.path.join(TOP_DIR, "examples/int8/training/vgg16")) - session.run_always( - "wget", "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz", external=True - ) - session.run_always("tar", "-xvzf", "cifar-10-binary.tar.gz", external=True) - session.run_always( - "mkdir", - "-p", - os.path.join(TOP_DIR, "tests/accuracy/datasets/data"), - external=True, - ) - session.run_always( - "cp", - "-rpf", - os.path.join(TOP_DIR, "examples/int8/training/vgg16/cifar-10-batches-bin"), - os.path.join(TOP_DIR, "tests/accuracy/datasets/data/cidar-10-batches-bin"), - external=True, - ) - - def train_model(session): session.chdir(os.path.join(TOP_DIR, "examples/int8/training/vgg16")) session.install("-r", "requirements.txt") @@ -107,14 +85,14 @@ def train_model(session): "--ckpt-dir", "vgg16_ckpts", "--epochs", - "25", + str(EPOCHS), env={"PYTHONPATH": PYT_PATH}, ) session.run_always( "python", "export_ckpt.py", - "vgg16_ckpts/ckpt_epoch25.pth", + "vgg16_ckpts/ckpt_epoch" + str(EPOCHS) + ".pth", env={"PYTHONPATH": PYT_PATH}, ) else: @@ -130,10 +108,12 @@ def train_model(session): "--ckpt-dir", "vgg16_ckpts", "--epochs", - "25", + str(EPOCHS), ) - session.run_always("python", "export_ckpt.py", "vgg16_ckpts/ckpt_epoch25.pth") + session.run_always( + "python", "export_ckpt.py", "vgg16_ckpts/ckpt_epoch" + str(EPOCHS) + ".pth" + ) def finetune_model(session): @@ -156,9 +136,9 @@ def finetune_model(session): "--ckpt-dir", "vgg16_ckpts", "--start-from", - "25", + str(EPOCHS), "--epochs", - "26", + str(EPOCHS + 1), env={"PYTHONPATH": PYT_PATH}, ) @@ -166,7 +146,7 @@ def finetune_model(session): session.run_always( "python", "export_qat.py", - "vgg16_ckpts/ckpt_epoch26.pth", + "vgg16_ckpts/ckpt_epoch" + str(EPOCHS + 1) + ".pth", env={"PYTHONPATH": PYT_PATH}, ) else: @@ -182,13 +162,17 @@ def finetune_model(session): "--ckpt-dir", "vgg16_ckpts", "--start-from", - "25", + str(EPOCHS), "--epochs", - "26", + str(EPOCHS + 1), ) # Export model - session.run_always("python", "export_qat.py", "vgg16_ckpts/ckpt_epoch26.pth") + session.run_always( + "python", + "export_qat.py", + "vgg16_ckpts/ckpt_epoch" + str(EPOCHS + 1) + ".pth", + ) def cleanup(session): @@ -219,6 +203,19 @@ def run_base_tests(session): session.run_always("pytest", test) +def run_model_tests(session): + print("Running model tests") + session.chdir(os.path.join(TOP_DIR, "tests/py")) + tests = [ + "models", + ] + for test in tests: + if USE_HOST_DEPS: + session.run_always("pytest", test, env={"PYTHONPATH": PYT_PATH}) + else: + session.run_always("pytest", test) + + def run_accuracy_tests(session): print("Running accuracy tests") session.chdir(os.path.join(TOP_DIR, "tests/py")) @@ -268,8 +265,8 @@ def run_trt_compatibility_tests(session): copy_model(session) session.chdir(os.path.join(TOP_DIR, "tests/py")) tests = [ - "test_trt_intercompatibility.py", - "test_ptq_trt_calibrator.py", + "integrations/test_trt_intercompatibility.py", + # "ptq/test_ptq_trt_calibrator.py", ] for test in tests: if USE_HOST_DEPS: @@ -282,7 +279,7 @@ def run_dla_tests(session): print("Running DLA tests") session.chdir(os.path.join(TOP_DIR, "tests/py")) tests = [ - "test_api_dla.py", + "hw/test_api_dla.py", ] for test in tests: if USE_HOST_DEPS: @@ -295,7 +292,7 @@ def run_multi_gpu_tests(session): print("Running multi GPU tests") session.chdir(os.path.join(TOP_DIR, "tests/py")) tests = [ - "test_multi_gpu.py", + "hw/test_multi_gpu.py", ] for test in tests: if USE_HOST_DEPS: @@ -322,13 +319,12 @@ def run_l0_dla_tests(session): cleanup(session) -def run_l1_accuracy_tests(session): +def run_l1_model_tests(session): if not USE_HOST_DEPS: install_deps(session) install_torch_trt(session) - download_datasets(session) - train_model(session) - run_accuracy_tests(session) + download_models(session) + run_model_tests(session) cleanup(session) @@ -336,7 +332,6 @@ def run_l1_int8_accuracy_tests(session): if not USE_HOST_DEPS: install_deps(session) install_torch_trt(session) - download_datasets(session) train_model(session) finetune_model(session) run_int8_accuracy_tests(session) @@ -347,9 +342,6 @@ def run_l2_trt_compatibility_tests(session): if not USE_HOST_DEPS: install_deps(session) install_torch_trt(session) - download_models(session) - download_datasets(session) - train_model(session) run_trt_compatibility_tests(session) cleanup(session) @@ -376,9 +368,9 @@ def l0_dla_tests(session): @nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True) -def l1_accuracy_tests(session): - """Checking accuracy performance on various usecases""" - run_l1_accuracy_tests(session) +def l1_model_tests(session): + """When a user needs to test the functionality of standard models compilation and results""" + run_l1_model_tests(session) @nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True) @@ -397,13 +389,3 @@ def l2_trt_compatibility_tests(session): def l2_multi_gpu_tests(session): """Makes sure that Torch-TensorRT can operate on multi-gpu systems""" run_l2_multi_gpu_tests(session) - - -@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True) -def download_test_models(session): - """Grab all the models needed for testing""" - try: - import torch - except ModuleNotFoundError: - install_deps(session) - download_models(session) diff --git a/py/torch_tensorrt/ptq.py b/py/torch_tensorrt/ptq.py index 326f35f942..e7f3411cd5 100644 --- a/py/torch_tensorrt/ptq.py +++ b/py/torch_tensorrt/ptq.py @@ -56,6 +56,13 @@ def write_calibration_cache(self, cache): return b"" +# deepcopy (which involves pickling) is performed on the compile_spec internally during compilation. +# We register this __reduce__ function for pickler to identity the calibrator object returned by DataLoaderCalibrator during deepcopy. +# This should be the object's local name relative to the module https://docs.python.org/3/library/pickle.html#object.__reduce__ +def __reduce__(self): + return self.__class__.__name__ + + class DataLoaderCalibrator(object): """ Constructs a calibrator class in TensorRT and uses pytorch dataloader to load/preproces @@ -114,24 +121,27 @@ def __new__(cls, *args, **kwargs): "get_batch": get_cache_mode_batch if use_cache else get_batch, "read_calibration_cache": read_calibration_cache, "write_calibration_cache": write_calibration_cache, + "__reduce__": __reduce__, # used when you deepcopy the DataLoaderCalibrator object } # Using type metaclass to construct calibrator class based on algorithm type if algo_type == CalibrationAlgo.ENTROPY_CALIBRATION: return type( - "DataLoaderCalibrator", (_C.IInt8EntropyCalibrator,), attribute_mapping + "Int8EntropyCalibrator", (_C.IInt8EntropyCalibrator,), attribute_mapping )() elif algo_type == CalibrationAlgo.ENTROPY_CALIBRATION_2: return type( - "DataLoaderCalibrator", (_C.IInt8MinMaxCalibrator,), attribute_mapping + "Int8EntropyCalibrator2", + (_C.IInt8EntropyCalibrator2,), + attribute_mapping, )() elif algo_type == CalibrationAlgo.LEGACY_CALIBRATION: return type( - "DataLoaderCalibrator", (_C.IInt8LegacyCalibrator,), attribute_mapping + "Int8LegacyCalibrator", (_C.IInt8LegacyCalibrator,), attribute_mapping )() elif algo_type == CalibrationAlgo.MINMAX_CALIBRATION: return type( - "DataLoaderCalibrator", (_C.IInt8MinMaxCalibrator,), attribute_mapping + "Int8MinMaxCalibrator", (_C.IInt8MinMaxCalibrator,), attribute_mapping )() else: log( diff --git a/py/torch_tensorrt/ts/_compile_spec.py b/py/torch_tensorrt/ts/_compile_spec.py index 154b29dd7b..9616111caa 100644 --- a/py/torch_tensorrt/ts/_compile_spec.py +++ b/py/torch_tensorrt/ts/_compile_spec.py @@ -225,7 +225,7 @@ def _parse_input_signature(input_signature: Any): def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec: - # TODO: Remove deep copy once collections does not need partial compilation + # TODO: Use deepcopy to support partial compilation of collections compile_spec = deepcopy(compile_spec_) info = _ts_C.CompileSpec() @@ -301,7 +301,7 @@ def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec: compile_spec["enabled_precisions"] ) - if "calibrator" in compile_spec: + if "calibrator" in compile_spec and compile_spec["calibrator"]: info.ptq_calibrator = compile_spec["calibrator"] if "sparse_weights" in compile_spec: diff --git a/tests/core/lowering/test_module_fallback_passes.cpp b/tests/core/lowering/test_module_fallback_passes.cpp index f11882df8b..e6eb098079 100644 --- a/tests/core/lowering/test_module_fallback_passes.cpp +++ b/tests/core/lowering/test_module_fallback_passes.cpp @@ -124,5 +124,5 @@ TEST(Lowering, LowerAndPartitionSimpleModuleFallbackCorrectly) { } auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor(); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results, trt_results, 2e-6)); + ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(jit_results, trt_results, 0.99)); } diff --git a/tests/core/partitioning/test_fallback_graph_output.cpp b/tests/core/partitioning/test_fallback_graph_output.cpp index 98fc4e6128..3da717074a 100644 --- a/tests/core/partitioning/test_fallback_graph_output.cpp +++ b/tests/core/partitioning/test_fallback_graph_output.cpp @@ -34,7 +34,7 @@ TEST(Partitioning, ComputeResNet50FallbackGraphCorrectly) { auto jit_results = mod.forward(jit_inputs_ivalues).toTensor(); auto trt_mod = torch_tensorrt::core::CompileGraph(mod, cfg); auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor(); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results, trt_results, 2e-6)); + ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(jit_results, trt_results, 0.99)); } TEST(Partitioning, ComputeMobileNetFallbackGraphCorrectly) { @@ -64,6 +64,6 @@ TEST(Partitioning, ComputeMobileNetFallbackGraphCorrectly) { auto jit_results = mod.forward(jit_inputs_ivalues).toTensor(); auto trt_mod = torch_tensorrt::core::CompileGraph(mod, cfg); auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor(); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results, trt_results, 2e-6)); + ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(jit_results, trt_results, 0.99)); } #endif diff --git a/tests/cpp/test_collections.cpp b/tests/cpp/test_collections.cpp index d01665adcd..e3f0d91dfe 100644 --- a/tests/cpp/test_collections.cpp +++ b/tests/cpp/test_collections.cpp @@ -42,7 +42,7 @@ TEST(CppAPITests, TestCollectionStandardTensorInput) { auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings); auto trt_out = trt_mod.forward(inputs_); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toTensor(), trt_out.toTensor(), 1e-5)); + ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(out.toTensor(), trt_out.toTensor(), 0.99)); } TEST(CppAPITests, TestCollectionTupleInput) { @@ -85,7 +85,7 @@ TEST(CppAPITests, TestCollectionTupleInput) { auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings); auto trt_out = trt_mod.forward(complex_inputs); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toTensor(), trt_out.toTensor(), 1e-5)); + ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(out.toTensor(), trt_out.toTensor(), 0.99)); } TEST(CppAPITests, TestCollectionListInput) { @@ -144,7 +144,7 @@ TEST(CppAPITests, TestCollectionListInput) { LOG_DEBUG("Finish compile"); auto trt_out = trt_mod.forward(complex_inputs); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toTensor(), trt_out.toTensor(), 1e-5)); + ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(out.toTensor(), trt_out.toTensor(), 0.99)); } TEST(CppAPITests, TestCollectionTupleInputOutput) { @@ -317,4 +317,4 @@ TEST(CppAPITests, TestCollectionComplexModel) { out.toTuple()->elements()[0].toTensor(), trt_out.toTuple()->elements()[0].toTensor(), 1e-5)); ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual( out.toTuple()->elements()[1].toTensor(), trt_out.toTuple()->elements()[1].toTensor(), 1e-5)); -} \ No newline at end of file +} diff --git a/tests/cpp/test_compiled_modules.cpp b/tests/cpp/test_compiled_modules.cpp index 595dd7044f..3a81f0a531 100644 --- a/tests/cpp/test_compiled_modules.cpp +++ b/tests/cpp/test_compiled_modules.cpp @@ -42,7 +42,7 @@ TEST_P(CppAPITests, CompiledModuleIsClose) { for (size_t i = 0; i < trt_results.size(); i++) { ASSERT_TRUE( - torch_tensorrt::tests::util::almostEqual(jit_results[i], trt_results[i].reshape_as(jit_results[i]), threshold)); + torch_tensorrt::tests::util::cosineSimEqual(jit_results[i], trt_results[i].reshape_as(jit_results[i]), 0.99)); } } @@ -52,11 +52,7 @@ INSTANTIATE_TEST_SUITE_P( CompiledModuleForwardIsCloseSuite, CppAPITests, testing::Values( - PathAndInput({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), - PathAndInput({"tests/modules/resnet50_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), - PathAndInput({"tests/modules/mobilenet_v2_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), PathAndInput({"tests/modules/resnet18_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), - PathAndInput({"tests/modules/resnet50_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), PathAndInput({"tests/modules/mobilenet_v2_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), PathAndInput({"tests/modules/efficientnet_b0_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 8e-3}), PathAndInput({"tests/modules/bert_base_uncased_traced.jit.pt", {{1, 14}, {1, 14}}, {at::kInt, at::kInt}, 8e-2}), diff --git a/tests/cpp/test_module_fallback.cpp b/tests/cpp/test_module_fallback.cpp index d1221cde4d..bfdfc46b04 100644 --- a/tests/cpp/test_module_fallback.cpp +++ b/tests/cpp/test_module_fallback.cpp @@ -30,7 +30,7 @@ TEST(CppAPITest, ResNetModuleFallbacksCorrectly) { auto jit_results = mod.forward(jit_inputs_ivalues).toTensor(); auto trt_mod = torch_tensorrt::ts::compile(mod, cfg); auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor(); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results, trt_results, 2e-6)); + ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(jit_results, trt_results, 0.99)); } TEST(CppAPITest, MobileNetModuleFallbacksCorrectlyWithOneEngine) { @@ -69,6 +69,6 @@ TEST(CppAPITest, MobileNetModuleFallbacksCorrectlyWithOneEngine) { ASSERT_TRUE(trt_count == 1); auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor(); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results, trt_results, 2e-6)); + ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(jit_results, trt_results, 0.99)); } #endif diff --git a/tests/cpp/test_modules_as_engines.cpp b/tests/cpp/test_modules_as_engines.cpp index 4437b1218c..11b7a54fb0 100644 --- a/tests/cpp/test_modules_as_engines.cpp +++ b/tests/cpp/test_modules_as_engines.cpp @@ -14,41 +14,8 @@ TEST_P(CppAPITests, ModuleAsEngineIsClose) { jit_results.push_back(jit_results_ivalues.toTensor()); auto trt_results = torch_tensorrt::tests::util::RunModuleForwardAsEngine(mod, inputs); - ASSERT_TRUE( - torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), threshold)); -} - -TEST_P(CppAPITests, ModuleToEngineToModuleIsClose) { - std::vector inputs; - std::vector inputs_ivalues; - for (uint64_t i = 0; i < input_shapes.size(); i++) { - inputs.push_back(at::randint(5, input_shapes[i], {at::kCUDA}).to(input_types[i])); - inputs_ivalues.push_back(inputs[inputs.size() - 1].clone()); - } - - torch::jit::IValue jit_results_ivalues = torch_tensorrt::tests::util::RunModuleForward(mod, inputs_ivalues); - std::vector jit_results; - jit_results.push_back(jit_results_ivalues.toTensor()); - - std::vector> input_ranges; - for (auto in : inputs) { - input_ranges.push_back(in.sizes()); - } - - auto compile_spec = torch_tensorrt::ts::CompileSpec({input_ranges}); - int device_id = 0; - cudaGetDevice(&device_id); - compile_spec.device.device_type = torch_tensorrt::Device::DeviceType::kGPU; - compile_spec.device.gpu_id = device_id; - auto engine = torch_tensorrt::ts::convert_method_to_trt_engine(mod, "forward", input_ranges); - auto trt_mod = torch_tensorrt::ts::embed_engine_in_new_module(engine, compile_spec.device); - - torch::jit::IValue trt_results_ivalues = torch_tensorrt::tests::util::RunModuleForward(trt_mod, inputs_ivalues); - std::vector trt_results; - trt_results.push_back(trt_results_ivalues.toTensor()); - - ASSERT_TRUE( - torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), threshold)); + ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual( + jit_results[0], trt_results[0].reshape_as(jit_results[0]), threshold)); } #ifndef DISABLE_TEST_IN_CI @@ -57,12 +24,8 @@ INSTANTIATE_TEST_SUITE_P( ModuleAsEngineForwardIsCloseSuite, CppAPITests, testing::Values( - PathAndInput({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), - PathAndInput({"tests/modules/resnet50_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), - PathAndInput({"tests/modules/mobilenet_v2_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), - PathAndInput({"tests/modules/resnet18_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), - PathAndInput({"tests/modules/resnet50_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), - PathAndInput({"tests/modules/mobilenet_v2_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), - PathAndInput({"tests/modules/efficientnet_b0_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 1e-4}), - PathAndInput({"tests/modules/vit_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 8e-2}))); + PathAndInput({"tests/modules/resnet18_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 0.99}), + PathAndInput({"tests/modules/mobilenet_v2_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 0.99}), + PathAndInput({"tests/modules/efficientnet_b0_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 0.99}), + PathAndInput({"tests/modules/vit_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 0.99}))); #endif diff --git a/tests/cpp/test_multi_gpu_serde.cpp b/tests/cpp/test_multi_gpu_serde.cpp index 8672ae9517..0b3944125b 100644 --- a/tests/cpp/test_multi_gpu_serde.cpp +++ b/tests/cpp/test_multi_gpu_serde.cpp @@ -23,12 +23,12 @@ TEST_P(CppAPITests, CompiledModuleIsClose) { trt_results.push_back(trt_results_ivalues.toTensor()); for (size_t i = 0; i < trt_results.size(); i++) { - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual( - jit_results[i], trt_results[i].reshape_as(jit_results[i]).to(torch::Device("cuda:0")), 2e-5)); + ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual( + jit_results[i], trt_results[i].reshape_as(jit_results[i]).to(torch::Device("cuda:0")), threshold)); } } INSTANTIATE_TEST_SUITE_P( CompiledModuleForwardIsCloseSuite, CppAPITests, - testing::Values(PathAndInput({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}))); + testing::Values(PathAndInput({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 0.99}))); diff --git a/tests/cpp/test_multiple_registered_engines.cpp b/tests/cpp/test_multiple_registered_engines.cpp index 2746687f68..658f59ca74 100644 --- a/tests/cpp/test_multiple_registered_engines.cpp +++ b/tests/cpp/test_multiple_registered_engines.cpp @@ -10,7 +10,7 @@ TEST(CppAPITest, CanRunMultipleEngines) { torch::jit::script::Module mod1; torch::jit::script::Module mod2; try { - mod1 = torch::jit::load("tests/modules/resnet50_traced.jit.pt"); + mod1 = torch::jit::load("tests/modules/resnet18_traced.jit.pt"); mod2 = torch::jit::load("tests/modules/resnet18_traced.jit.pt"); } catch (const c10::Error& e) { std::cerr << "error loading the model\n"; @@ -56,13 +56,13 @@ TEST(CppAPITest, CanRunMultipleEngines) { trt2_results.push_back(trt2_results_ivalues.toTensor()); for (size_t i = 0; i < trt1_results.size(); i++) { - ASSERT_TRUE( - torch_tensorrt::tests::util::almostEqual(jit1_results[i], trt1_results[i].reshape_as(jit1_results[i]), 2e-5)); + ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual( + jit1_results[i], trt1_results[i].reshape_as(jit1_results[i]), 0.99)); } for (size_t i = 0; i < trt2_results.size(); i++) { - ASSERT_TRUE( - torch_tensorrt::tests::util::almostEqual(jit2_results[i], trt2_results[i].reshape_as(jit2_results[i]), 2e-5)); + ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual( + jit2_results[i], trt2_results[i].reshape_as(jit2_results[i]), 0.99)); } } #endif diff --git a/tests/py/api/test_collections.py b/tests/py/api/test_collections.py index dfae3f18c9..936a4d5c73 100644 --- a/tests/py/api/test_collections.py +++ b/tests/py/api/test_collections.py @@ -3,6 +3,7 @@ import torch import torchvision.models as models import os +from utils import cosine_similarity, COSINE_THRESHOLD def find_repo_root(max_depth=10): @@ -40,12 +41,13 @@ def test_compile(self): } trt_mod = torchtrt.ts.compile(self.model, **compile_spec) - same = ( - (trt_mod(self.input, self.input) - self.model(self.input, self.input)) - .abs() - .max() + cos_sim = cosine_similarity( + self.model(self.input, self.input), trt_mod(self.input, self.input) + ) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"standard_tensor_input_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) - self.assertTrue(same < 2e-2) class TestTupleInput(unittest.TestCase): @@ -68,12 +70,13 @@ def test_compile(self): } trt_mod = torchtrt.ts.compile(self.model, **compile_spec) - same = ( - (trt_mod((self.input, self.input)) - self.model((self.input, self.input))) - .abs() - .max() + cos_sim = cosine_similarity( + self.model((self.input, self.input)), trt_mod((self.input, self.input)) + ) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"tuple_input_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) - self.assertTrue(same < 2e-2) class TestListInput(unittest.TestCase): @@ -94,12 +97,13 @@ def test_compile(self): } trt_mod = torchtrt.ts.compile(self.model, **compile_spec) - same = ( - (trt_mod([self.input, self.input]) - self.model([self.input, self.input])) - .abs() - .max() + cos_sim = cosine_similarity( + self.model([self.input, self.input]), trt_mod([self.input, self.input]) + ) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"list_input_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) - self.assertTrue(same < 2e-2) class TestTupleInputOutput(unittest.TestCase): @@ -124,8 +128,12 @@ def test_compile(self): trt_mod = torchtrt.ts.compile(self.model, **compile_spec) trt_out = trt_mod((self.input, self.input)) pyt_out = self.model((self.input, self.input)) - results = [(t - p).abs().max() < 2e-2 for (t, p) in zip(trt_out, pyt_out)] - self.assertTrue(all(results)) + for (t, p) in zip(trt_out, pyt_out): + cos_sim = cosine_similarity(t, p) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"tuple_input_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) class TestListInputOutput(unittest.TestCase): @@ -150,8 +158,13 @@ def test_compile(self): trt_mod = torchtrt.ts.compile(self.model, **compile_spec) trt_out = trt_mod((self.input, self.input)) pyt_out = self.model((self.input, self.input)) - results = [(t - p).abs().max() < 2e-2 for (t, p) in zip(trt_out, pyt_out)] - self.assertTrue(all(results)) + + for (t, p) in zip(trt_out, pyt_out): + cos_sim = cosine_similarity(t, p) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"list_input_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) class TestListInputTupleOutput(unittest.TestCase): @@ -176,8 +189,12 @@ def test_compile(self): trt_mod = torchtrt.ts.compile(self.model, **compile_spec) trt_out = trt_mod((self.input, self.input)) pyt_out = self.model((self.input, self.input)) - results = [(t - p).abs().max() < 2e-2 for (t, p) in zip(trt_out, pyt_out)] - self.assertTrue(all(results)) + for (t, p) in zip(trt_out, pyt_out): + cos_sim = cosine_similarity(t, p) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"list_input_tuple_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) if __name__ == "__main__": diff --git a/tests/py/api/test_e2e_behavior.py b/tests/py/api/test_e2e_behavior.py index d1da3e0465..385fe916f4 100644 --- a/tests/py/api/test_e2e_behavior.py +++ b/tests/py/api/test_e2e_behavior.py @@ -6,102 +6,6 @@ from typing import Dict -class TestCompileHalf(unittest.TestCase): - def test_compile_script_half(self): - self.model = models.resnet18(pretrained=True).eval().to("cuda") - self.input = torch.randn((1, 3, 224, 224)).to("cuda") - self.scripted_model = torch.jit.script(self.model) - self.scripted_model.half() - - compile_spec = { - "inputs": [torchtrt.Input(shape=self.input.shape, dtype=torch.half)], - "device": { - "device_type": torchtrt.DeviceType.GPU, - "gpu_id": 0, - }, - "enabled_precisions": {torch.half}, - } - - trt_mod = torchtrt.ts.compile(self.scripted_model, **compile_spec) - same = ( - (trt_mod(self.input.half()) - self.scripted_model(self.input.half())) - .abs() - .max() - ) - torchtrt.logging.log(torchtrt.logging.Level.Debug, "Max diff: " + str(same)) - self.assertTrue(same < 3e-2) - - def test_compile_script_half_by_default(self): - self.model = models.resnet18(pretrained=True).eval().to("cuda") - self.input = torch.randn((1, 3, 224, 224)).to("cuda") - self.scripted_model = torch.jit.script(self.model) - self.scripted_model.half() - - compile_spec = { - "inputs": [torchtrt.Input(shape=self.input.shape)], - "device": { - "device_type": torchtrt.DeviceType.GPU, - "gpu_id": 0, - }, - "enabled_precisions": {torch.float, torch.half}, - } - - trt_mod = torchtrt.ts.compile(self.scripted_model, **compile_spec) - same = ( - (trt_mod(self.input.half()) - self.scripted_model(self.input.half())) - .abs() - .max() - ) - torchtrt.logging.log(torchtrt.logging.Level.Debug, "Max diff: " + str(same)) - self.assertTrue(same < 3e-2) - - -class TestFallbackToTorch(unittest.TestCase): - def test_fallback(self): - self.model = models.resnet18(pretrained=True).eval().to("cuda") - self.input = torch.randn((1, 3, 224, 224)).to("cuda") - self.scripted_model = torch.jit.script(self.model) - - compile_spec = { - "inputs": [torchtrt.Input(self.input.shape)], - "device": { - "device_type": torchtrt.DeviceType.GPU, - "gpu_id": 0, - "allow_gpu_fallback": False, - "disable_tf32": False, - }, - "require_full_compilation": False, - "torch_executed_ops": ["aten::max_pool2d"], - "min_block_size": 1, - } - - trt_mod = torchtrt.ts.compile(self.scripted_model, **compile_spec) - same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max() - self.assertTrue(same < 2e-3) - - def test_module_fallback(self): - self.model = models.resnet18(pretrained=True).eval().to("cuda") - self.input = torch.randn((1, 3, 224, 224)).to("cuda") - self.scripted_model = torch.jit.script(self.model) - - compile_spec = { - "inputs": [torchtrt.Input(self.input.shape)], - "device": { - "device_type": torchtrt.DeviceType.GPU, - "gpu_id": 0, - "allow_gpu_fallback": False, - "disable_tf32": False, - }, - "require_full_compilation": False, - "torch_executed_modules": ["torchvision.models.resnet.BasicBlock"], - "min_block_size": 1, - } - - trt_mod = torchtrt.ts.compile(self.scripted_model, **compile_spec) - same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max() - self.assertTrue(same < 2e-3) - - class TestInputTypeDefaultsFP32Model(unittest.TestCase): def test_input_use_default_fp32(self): self.model = models.resnet18(pretrained=True).eval().to("cuda") diff --git a/tests/py/api/test_embed_engines.py b/tests/py/api/test_embed_engines.py new file mode 100644 index 0000000000..d21e139eca --- /dev/null +++ b/tests/py/api/test_embed_engines.py @@ -0,0 +1,73 @@ +import unittest +import torch_tensorrt as torchtrt +import torch +import torchvision.models as models +import copy +import timm +from typing import Dict +from utils import cosine_similarity, COSINE_THRESHOLD + + +class TestModelToEngineToModel(unittest.TestCase): + def test_resnet50(self): + self.model = models.resnet50(pretrained=True).eval().to("cuda") + self.input = torch.randn((1, 3, 224, 224)).to("cuda") + + compile_spec = { + "inputs": [ + torchtrt.Input( + self.input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "device": { + "device_type": torchtrt.DeviceType.GPU, + "gpu_id": 0, + }, + "enabled_precisions": {torch.float}, + } + + self.scripted_model = torch.jit.script(self.model) + trt_engine = torchtrt.ts.convert_method_to_trt_engine( + self.scripted_model, "forward", **compile_spec + ) + trt_mod = torchtrt.ts.embed_engine_in_new_module(trt_engine) + cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input)) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"Resnet50 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + def test_efficientnet_b0(self): + self.model = ( + timm.create_model("efficientnet_b0", pretrained=True).eval().to("cuda") + ) + self.input = torch.randn((1, 3, 224, 224)).to("cuda") + + compile_spec = { + "inputs": [ + torchtrt.Input( + self.input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "device": { + "device_type": torchtrt.DeviceType.GPU, + "gpu_id": 0, + }, + "enabled_precisions": {torch.float}, + } + + self.scripted_model = torch.jit.script(self.model) + trt_engine = torchtrt.ts.convert_method_to_trt_engine( + self.scripted_model, "forward", **compile_spec + ) + trt_mod = torchtrt.ts.embed_engine_in_new_module(trt_engine) + + cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input)) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/py/api/test_module_fallback.py b/tests/py/api/test_module_fallback.py new file mode 100644 index 0000000000..5eda2cdbfc --- /dev/null +++ b/tests/py/api/test_module_fallback.py @@ -0,0 +1,62 @@ +import unittest +import torch_tensorrt as torchtrt +import torch +import torchvision.models as models +import copy +from typing import Dict +from utils import cosine_similarity, COSINE_THRESHOLD + + +class TestModuleFallback(unittest.TestCase): + def test_fallback_resnet18(self): + self.model = models.resnet18(pretrained=True).eval().to("cuda") + self.input = torch.randn((1, 3, 224, 224)).to("cuda") + compile_spec = { + "inputs": [ + torchtrt.Input( + self.input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "device": { + "device_type": torchtrt.DeviceType.GPU, + "gpu_id": 0, + }, + "enabled_precisions": {torch.float}, + "torch_executed_modules": ["torchvision.models.resnet.BasicBlock"], + } + trt_mod = torchtrt.compile(self.model, **compile_spec) + cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input)) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + def test_fallback_mobilenet_v2(self): + self.model = models.mobilenet_v2(pretrained=True).eval().to("cuda") + self.input = torch.randn((1, 3, 224, 224)).to("cuda") + compile_spec = { + "inputs": [ + torchtrt.Input( + self.input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "device": { + "device_type": torchtrt.DeviceType.GPU, + "gpu_id": 0, + }, + "enabled_precisions": {torch.float}, + "torch_executed_modules": [ + "torchvision.models.mobilenetv2.ConvBNActivation" + ], + "min_block_size": 5, + } + trt_mod = torchtrt.compile(self.model, **compile_spec) + cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input)) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"Mobilenet V2 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/py/api/test_operator_fallback.py b/tests/py/api/test_operator_fallback.py new file mode 100644 index 0000000000..302a663e24 --- /dev/null +++ b/tests/py/api/test_operator_fallback.py @@ -0,0 +1,59 @@ +import unittest +import torch_tensorrt as torchtrt +import torch +import torchvision.models as models +import copy +from typing import Dict +from utils import cosine_similarity, COSINE_THRESHOLD + + +class TestFallbackModels(unittest.TestCase): + def test_fallback_resnet18(self): + self.model = models.resnet18(pretrained=True).eval().to("cuda") + self.input = torch.randn((1, 3, 224, 224)).to("cuda") + compile_spec = { + "inputs": [ + torchtrt.Input( + self.input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "device": { + "device_type": torchtrt.DeviceType.GPU, + "gpu_id": 0, + }, + "enabled_precisions": {torch.float}, + "torch_executed_ops": ["aten::add"], + } + trt_mod = torchtrt.compile(self.model, **compile_spec) + cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input)) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + def test_fallback_mobilenet_v2(self): + self.model = models.mobilenet_v2(pretrained=True).eval().to("cuda") + self.input = torch.randn((1, 3, 224, 224)).to("cuda") + compile_spec = { + "inputs": [ + torchtrt.Input( + self.input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "device": { + "device_type": torchtrt.DeviceType.GPU, + "gpu_id": 0, + }, + "enabled_precisions": {torch.float}, + "torch_executed_ops": ["aten::hardtanh"], + } + trt_mod = torchtrt.compile(self.model, **compile_spec) + cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input)) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"Mobilenet V2 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/py/api/test_ts_backend.py b/tests/py/api/test_ts_backend.py index d0654a8f75..e56ab4f902 100644 --- a/tests/py/api/test_ts_backend.py +++ b/tests/py/api/test_ts_backend.py @@ -4,6 +4,7 @@ import torchvision.models as models import copy from typing import Dict +from utils import cosine_similarity, COSINE_THRESHOLD class TestCompile(unittest.TestCase): @@ -26,8 +27,11 @@ def test_compile_traced(self): } trt_mod = torchtrt.ts.compile(self.traced_model, **compile_spec) - same = (trt_mod(self.input) - self.traced_model(self.input)).abs().max() - self.assertTrue(same < 2e-2) + cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input)) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"VGG16 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) def test_compile_script(self): self.model = models.vgg16(pretrained=True).eval().to("cuda") @@ -40,8 +44,11 @@ def test_compile_script(self): device=torchtrt.Device(gpu_id=0), enabled_precisions={torch.float}, ) - same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max() - self.assertTrue(same < 2e-2) + cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input)) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"VGG16 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) def test_compile_global(self): self.model = models.vgg16(pretrained=True).eval().to("cuda") @@ -53,21 +60,11 @@ def test_compile_global(self): device=torchtrt.Device(gpu_id=0), enabled_precisions={torch.float}, ) - same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max() - self.assertTrue(same < 2e-2) - - def test_compile_global_nn_mod(self): - self.model = models.vgg16(pretrained=True).eval().to("cuda") - self.input = torch.randn((1, 3, 224, 224)).to("cuda") - with torch.no_grad(): - trt_mod = torchtrt.compile( - self.model, - inputs=[self.input], - device=torchtrt.Device(gpu_id=0), - enabled_precisions={torch.float}, - ) - same = (trt_mod(self.input) - self.model(self.input)).abs().max() - self.assertTrue(same < 2e-2) + cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input)) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"VGG16 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) def test_from_torch_tensor(self): self.model = models.vgg16(pretrained=True).eval().to("cuda") @@ -83,8 +80,11 @@ def test_from_torch_tensor(self): } trt_mod = torchtrt.ts.compile(self.traced_model, **compile_spec) - same = (trt_mod(self.input) - self.traced_model(self.input)).abs().max() - self.assertTrue(same < 2e-2) + cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input)) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"VGG16 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) def test_device(self): self.model = models.vgg16(pretrained=True).eval().to("cuda") @@ -97,8 +97,11 @@ def test_device(self): } trt_mod = torchtrt.ts.compile(self.traced_model, **compile_spec) - same = (trt_mod(self.input) - self.traced_model(self.input)).abs().max() - self.assertTrue(same < 2e-2) + cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input)) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"VGG16 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) def test_default_device(self): self.model = models.vgg16(pretrained=True).eval().to("cuda") @@ -107,51 +110,11 @@ def test_default_device(self): compile_spec = {"inputs": [self.input], "enabled_precisions": {torch.float}} trt_mod = torchtrt.ts.compile(self.traced_model, **compile_spec) - same = (trt_mod(self.input) - self.traced_model(self.input)).abs().max() - self.assertTrue(same < 2e-2) - - def test_compile_script_from_dict(self): - self.model = models.vgg16(pretrained=True).eval().to("cuda") - self.input = torch.randn((1, 3, 224, 224)).to("cuda") - self.traced_model = torch.jit.trace(self.model, [self.input]) - compile_spec = { - "inputs": [torchtrt.Input(shape=self.input.shape)], - "device": { - "device_type": torchtrt.DeviceType.GPU, - "gpu_id": 0, - }, - "enabled_precisions": {torch.float}, - } - - trt_mod = torchtrt.ts.compile(self.traced_model, **compile_spec) - same = (trt_mod(self.input) - self.traced_model(self.input)).abs().max() - self.assertTrue(same < 2e-2) - - -class TestPTtoTRTtoPT(unittest.TestCase): - def test_pt_to_trt_to_pt(self): - self.model = models.vgg16(pretrained=True).eval().to("cuda") - self.input = torch.randn((1, 3, 224, 224)).to("cuda") - self.ts_model = torch.jit.trace(self.model, [self.input]) - - compile_spec = { - "inputs": [torchtrt.Input(self.input.shape)], - "device": { - "device_type": torchtrt.DeviceType.GPU, - "gpu_id": 0, - "allow_gpu_fallback": False, - "disable_tf32": False, - }, - } - - trt_engine = torchtrt.ts.convert_method_to_trt_engine( - self.ts_model, "forward", **compile_spec - ) - trt_mod = torchtrt.ts.embed_engine_in_new_module( - trt_engine, torchtrt.Device("cuda:0") + cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input)) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"VGG16 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) - same = (trt_mod(self.input) - self.ts_model(self.input)).abs().max() - self.assertTrue(same < 2e-3) class TestCheckMethodOpSupport(unittest.TestCase): diff --git a/tests/py/api/utils.py b/tests/py/api/utils.py new file mode 100644 index 0000000000..b1e6632ec3 --- /dev/null +++ b/tests/py/api/utils.py @@ -0,0 +1,15 @@ +import torch + +COSINE_THRESHOLD = 0.99 + + +def cosine_similarity(gt_tensor, pred_tensor): + gt_tensor = gt_tensor.flatten().to(torch.float32) + pred_tensor = pred_tensor.flatten().to(torch.float32) + if torch.sum(gt_tensor) == 0.0 or torch.sum(pred_tensor) == 0.0: + if torch.allclose(gt_tensor, pred_tensor, atol=1e-4, rtol=1e-4, equal_nan=True): + return 1.0 + res = torch.nn.functional.cosine_similarity(gt_tensor, pred_tensor, dim=0, eps=1e-6) + res = res.cpu().detach().item() + + return res diff --git a/tests/py/hw/test_api_dla.py b/tests/py/hw/test_api_dla.py index 57b149faa7..5328b92233 100644 --- a/tests/py/hw/test_api_dla.py +++ b/tests/py/hw/test_api_dla.py @@ -2,6 +2,7 @@ import torch_tensorrt as torchtrt import torch import torchvision.models as models +from utils import cosine_similarity, COSINE_THRESHOLD class ModelTestCaseOnDLA(unittest.TestCase): @@ -39,8 +40,11 @@ def test_compile_traced(self): } trt_mod = torchtrt.ts.compile(self.traced_model, **compile_spec) - same = (trt_mod(self.input) - self.traced_model(self.input)).abs().max() - self.assertTrue(same < 2e-2) + cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input)) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"ModelTestCaseOnDLA traced TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) def test_compile_script(self): compile_spec = { @@ -55,8 +59,11 @@ def test_compile_script(self): } trt_mod = torchtrt.ts.compile(self.scripted_model, **compile_spec) - same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max() - self.assertTrue(same < 2e-2) + cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input)) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"ModelTestCaseOnDLA scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) def test_suite(): diff --git a/tests/py/hw/test_multi_gpu.py b/tests/py/hw/test_multi_gpu.py index c068cc71b0..b6fa3f220b 100644 --- a/tests/py/hw/test_multi_gpu.py +++ b/tests/py/hw/test_multi_gpu.py @@ -35,9 +35,12 @@ def test_compile_traced(self): trt_mod = torchtrt.ts.compile(self.traced_model, **compile_spec) torchtrt.set_device(self.target_gpu) - same = (trt_mod(self.input) - self.traced_model(self.input)).abs().max() + cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input)) torchtrt.set_device(0) - self.assertTrue(same < 2e-3) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"TestMultiGpuSwitching traced TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) def test_compile_script(self): torchtrt.set_device(0) @@ -54,9 +57,12 @@ def test_compile_script(self): trt_mod = torchtrt.ts.compile(self.scripted_model, **compile_spec) torchtrt.set_device(self.target_gpu) - same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max() + cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input)) torchtrt.set_device(0) - self.assertTrue(same < 2e-3) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"TestMultiGpuSwitching scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) class TestMultiGpuSerializeDeserializeSwitching(ModelTestCase): @@ -89,8 +95,11 @@ def test_compile_traced(self): trt_mod = torchtrt.ts.compile(self.traced_model, **compile_spec) # Changing the device ID deliberately. It should still run on correct device ID by context switching torchtrt.set_device(1) - same = (trt_mod(self.input) - self.traced_model(self.input)).abs().max() - self.assertTrue(same < 2e-3) + cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input)) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"TestMultiGpuSerializeDeserializeSwitching traced TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) def test_compile_script(self): torchtrt.set_device(0) @@ -108,8 +117,11 @@ def test_compile_script(self): trt_mod = torchtrt.ts.compile(self.scripted_model, **compile_spec) # Changing the device ID deliberately. It should still run on correct device ID by context switching torchtrt.set_device(1) - same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max() - self.assertTrue(same < 2e-3) + cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input)) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"TestMultiGpuSerializeDeserializeSwitching scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) def test_suite(): diff --git a/tests/py/hw/utils.py b/tests/py/hw/utils.py new file mode 100644 index 0000000000..b1e6632ec3 --- /dev/null +++ b/tests/py/hw/utils.py @@ -0,0 +1,15 @@ +import torch + +COSINE_THRESHOLD = 0.99 + + +def cosine_similarity(gt_tensor, pred_tensor): + gt_tensor = gt_tensor.flatten().to(torch.float32) + pred_tensor = pred_tensor.flatten().to(torch.float32) + if torch.sum(gt_tensor) == 0.0 or torch.sum(pred_tensor) == 0.0: + if torch.allclose(gt_tensor, pred_tensor, atol=1e-4, rtol=1e-4, equal_nan=True): + return 1.0 + res = torch.nn.functional.cosine_similarity(gt_tensor, pred_tensor, dim=0, eps=1e-6) + res = res.cpu().detach().item() + + return res diff --git a/tests/py/integrations/test_to_backend_api.py b/tests/py/integrations/test_to_backend_api.py index 16d839b1b0..0f74a3af15 100644 --- a/tests/py/integrations/test_to_backend_api.py +++ b/tests/py/integrations/test_to_backend_api.py @@ -2,6 +2,7 @@ import torch_tensorrt as torchtrt import torch import torchvision.models as models +from utils import cosine_similarity, COSINE_THRESHOLD class TestToBackendLowering(unittest.TestCase): @@ -31,10 +32,11 @@ def setUp(self): def test_to_backend_lowering(self): trt_mod = torch._C._jit_to_backend("tensorrt", self.scripted_model, self.spec) - same = ( - (trt_mod.forward(self.input) - self.scripted_model(self.input)).abs().max() + cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input)) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"TestToBackendLowering TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) - self.assertTrue(same < 2e-3) if __name__ == "__main__": diff --git a/tests/py/integrations/test_trt_intercompatibility.py b/tests/py/integrations/test_trt_intercompatibility.py index 96b47b7ccc..b938e4a1ac 100644 --- a/tests/py/integrations/test_trt_intercompatibility.py +++ b/tests/py/integrations/test_trt_intercompatibility.py @@ -3,6 +3,7 @@ import torch import torchvision.models as models import tensorrt as trt +from utils import cosine_similarity, COSINE_THRESHOLD class TestPyTorchToTRTEngine(unittest.TestCase): @@ -42,8 +43,11 @@ def test_pt_to_trt(self): device="cuda:0" ).cuda_stream, ) - same = (out - self.ts_model(self.input)).abs().max() - self.assertTrue(same < 2e-3) + cos_sim = cosine_similarity(self.model(self.input), out) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"TestPyTorchToTRTEngine TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) if __name__ == "__main__": diff --git a/tests/py/integrations/utils.py b/tests/py/integrations/utils.py new file mode 100644 index 0000000000..b1e6632ec3 --- /dev/null +++ b/tests/py/integrations/utils.py @@ -0,0 +1,15 @@ +import torch + +COSINE_THRESHOLD = 0.99 + + +def cosine_similarity(gt_tensor, pred_tensor): + gt_tensor = gt_tensor.flatten().to(torch.float32) + pred_tensor = pred_tensor.flatten().to(torch.float32) + if torch.sum(gt_tensor) == 0.0 or torch.sum(pred_tensor) == 0.0: + if torch.allclose(gt_tensor, pred_tensor, atol=1e-4, rtol=1e-4, equal_nan=True): + return 1.0 + res = torch.nn.functional.cosine_similarity(gt_tensor, pred_tensor, dim=0, eps=1e-6) + res = res.cpu().detach().item() + + return res diff --git a/tests/py/models/custom_models.py b/tests/py/models/custom_models.py new file mode 100644 index 0000000000..a19b9ca81c --- /dev/null +++ b/tests/py/models/custom_models.py @@ -0,0 +1,28 @@ +import torch +from transformers import BertModel, BertTokenizer, BertConfig + + +def BertModule(): + model_name = "bert-base-uncased" + enc = BertTokenizer.from_pretrained(model_name) + text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]" + tokenized_text = enc.tokenize(text) + masked_index = 8 + tokenized_text[masked_index] = "[MASK]" + indexed_tokens = enc.convert_tokens_to_ids(tokenized_text) + segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1] + tokens_tensor = torch.tensor([indexed_tokens]) + segments_tensors = torch.tensor([segments_ids]) + config = BertConfig( + vocab_size_or_config_json_file=32000, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + torchscript=True, + ) + model = BertModel(config) + model.eval() + model = BertModel.from_pretrained(model_name, torchscript=True) + traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors]) + return traced_model diff --git a/tests/py/models/test_models.py b/tests/py/models/test_models.py new file mode 100644 index 0000000000..6cc9759626 --- /dev/null +++ b/tests/py/models/test_models.py @@ -0,0 +1,153 @@ +import unittest +import torch_tensorrt as torchtrt +import torch +import torchvision.models as models +import copy +import timm +import custom_models as cm +from typing import Dict +from utils import cosine_similarity, COSINE_THRESHOLD + + +class TestModels(unittest.TestCase): + def test_resnet18(self): + self.model = models.resnet18(pretrained=True).eval().to("cuda") + self.input = torch.randn((1, 3, 224, 224)).to("cuda") + + compile_spec = { + "inputs": [ + torchtrt.Input( + self.input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "device": { + "device_type": torchtrt.DeviceType.GPU, + "gpu_id": 0, + }, + "enabled_precisions": {torch.float}, + } + + trt_mod = torchtrt.compile(self.model, **compile_spec) + cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input)) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"Resnet50 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + def test_mobilenet_v2(self): + self.model = models.mobilenet_v2(pretrained=True).eval().to("cuda") + self.input = torch.randn((1, 3, 224, 224)).to("cuda") + + compile_spec = { + "inputs": [ + torchtrt.Input( + self.input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "device": { + "device_type": torchtrt.DeviceType.GPU, + "gpu_id": 0, + }, + "enabled_precisions": {torch.float}, + } + + trt_mod = torchtrt.compile(self.model, **compile_spec) + cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input)) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"Mobilenet v2 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + def test_efficientnet_b0(self): + self.model = ( + timm.create_model("efficientnet_b0", pretrained=True).eval().to("cuda") + ) + self.input = torch.randn((1, 3, 224, 224)).to("cuda") + + compile_spec = { + "inputs": [ + torchtrt.Input( + self.input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "device": { + "device_type": torchtrt.DeviceType.GPU, + "gpu_id": 0, + }, + "enabled_precisions": {torch.float}, + } + + trt_mod = torchtrt.compile(self.model, **compile_spec) + cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input)) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + def test_bert_base_uncased(self): + self.model = cm.BertModule().cuda() + self.input = torch.randint(0, 5, (1, 14), dtype=torch.int32).to("cuda") + + compile_spec = { + "inputs": [ + torchtrt.Input( + self.input.shape, + dtype=self.input.dtype, + format=torch.contiguous_format, + ), + torchtrt.Input( + self.input.shape, + dtype=self.input.dtype, + format=torch.contiguous_format, + ), + ], + "device": { + "device_type": torchtrt.DeviceType.GPU, + "gpu_id": 0, + }, + "enabled_precisions": {torch.float}, + "truncate_long_and_double": True, + } + with torchtrt.logging.errors(): + trt_mod = torchtrt.ts.compile(self.model, **compile_spec) + + model_outputs = self.model(self.input, self.input) + trt_model_outputs = trt_mod(self.input, self.input) + for out, trt_out in zip(model_outputs, trt_model_outputs): + cos_sim = cosine_similarity(out, trt_out) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"HF BERT base-uncased TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + def test_resnet18_half(self): + self.model = models.resnet18(pretrained=True).eval().to("cuda") + self.input = torch.randn((1, 3, 224, 224)).to("cuda") + self.scripted_model = torch.jit.script(self.model) + self.scripted_model.half() + + compile_spec = { + "inputs": [ + torchtrt.Input( + self.input.shape, dtype=torch.half, format=torch.contiguous_format + ) + ], + "device": { + "device_type": torchtrt.DeviceType.GPU, + "gpu_id": 0, + }, + "enabled_precisions": {torch.half}, + } + + trt_mod = torchtrt.compile(self.scripted_model, **compile_spec) + cos_sim = cosine_similarity( + self.model.half()(self.input.half()), trt_mod(self.input.half()) + ) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"Resnet50 Half TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/py/models/test_multiple_registered_engines.py b/tests/py/models/test_multiple_registered_engines.py new file mode 100644 index 0000000000..98f012597b --- /dev/null +++ b/tests/py/models/test_multiple_registered_engines.py @@ -0,0 +1,52 @@ +import unittest +import torch_tensorrt as torchtrt +import torch +import torchvision.models as models +import copy +import timm +import custom_models as cm +from typing import Dict +from utils import cosine_similarity, COSINE_THRESHOLD + + +class TestModelToEngineToModel(unittest.TestCase): + def test_multiple_engines(self): + self.resnet18 = models.resnet18(pretrained=True).eval().to("cuda") + self.resnet50 = models.resnet50(pretrained=True).eval().to("cuda") + self.input1 = torch.randn((1, 3, 224, 224)).to("cuda") + self.input2 = torch.randn((1, 3, 224, 224)).to("cuda") + + compile_spec = { + "inputs": [ + torchtrt.Input( + self.input1.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "device": { + "device_type": torchtrt.DeviceType.GPU, + "gpu_id": 0, + }, + "enabled_precisions": {torch.float}, + } + rn18_trt_mod = torchtrt.compile(self.resnet18, **compile_spec) + rn50_trt_mod = torchtrt.compile(self.resnet50, **compile_spec) + + cos_sim = cosine_similarity( + self.resnet18(self.input1), rn18_trt_mod(self.input1) + ) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + cos_sim = cosine_similarity( + self.resnet50(self.input1), rn50_trt_mod(self.input1) + ) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"Resnet50 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/py/models/utils.py b/tests/py/models/utils.py new file mode 100644 index 0000000000..b1e6632ec3 --- /dev/null +++ b/tests/py/models/utils.py @@ -0,0 +1,15 @@ +import torch + +COSINE_THRESHOLD = 0.99 + + +def cosine_similarity(gt_tensor, pred_tensor): + gt_tensor = gt_tensor.flatten().to(torch.float32) + pred_tensor = pred_tensor.flatten().to(torch.float32) + if torch.sum(gt_tensor) == 0.0 or torch.sum(pred_tensor) == 0.0: + if torch.allclose(gt_tensor, pred_tensor, atol=1e-4, rtol=1e-4, equal_nan=True): + return 1.0 + res = torch.nn.functional.cosine_similarity(gt_tensor, pred_tensor, dim=0, eps=1e-6) + res = res.cpu().detach().item() + + return res diff --git a/tests/py/ptq/test_ptq_dataloader_calibrator.py b/tests/py/ptq/test_ptq_dataloader_calibrator.py index 2ee1fa5b08..79c19dadbf 100644 --- a/tests/py/ptq/test_ptq_dataloader_calibrator.py +++ b/tests/py/ptq/test_ptq_dataloader_calibrator.py @@ -81,9 +81,6 @@ def test_compile_script(self): device=torch.device("cuda:0"), ) - fp32_test_acc = compute_accuracy(self.testing_dataloader, self.model) - log(Level.Info, "[Pyt FP32] Test Acc: {:.2f}%".format(100 * fp32_test_acc)) - compile_spec = { "inputs": [torchtrt.Input([1, 3, 32, 32])], "enabled_precisions": {torch.float, torch.int8}, @@ -96,8 +93,11 @@ def test_compile_script(self): "allow_gpu_fallback": False, }, } - trt_mod = torchtrt.ts.compile(self.model, **compile_spec) + + fp32_test_acc = compute_accuracy(self.testing_dataloader, self.model) + log(Level.Info, "[Pyt FP32] Test Acc: {:.2f}%".format(100 * fp32_test_acc)) + int8_test_acc = compute_accuracy(self.testing_dataloader, trt_mod) log(Level.Info, "[TRT INT8] Test Acc: {:.2f}%".format(100 * int8_test_acc)) acc_diff = fp32_test_acc - int8_test_acc diff --git a/tests/py/utils.py b/tests/py/utils.py new file mode 100644 index 0000000000..b1e6632ec3 --- /dev/null +++ b/tests/py/utils.py @@ -0,0 +1,15 @@ +import torch + +COSINE_THRESHOLD = 0.99 + + +def cosine_similarity(gt_tensor, pred_tensor): + gt_tensor = gt_tensor.flatten().to(torch.float32) + pred_tensor = pred_tensor.flatten().to(torch.float32) + if torch.sum(gt_tensor) == 0.0 or torch.sum(pred_tensor) == 0.0: + if torch.allclose(gt_tensor, pred_tensor, atol=1e-4, rtol=1e-4, equal_nan=True): + return 1.0 + res = torch.nn.functional.cosine_similarity(gt_tensor, pred_tensor, dim=0, eps=1e-6) + res = res.cpu().detach().item() + + return res diff --git a/tests/util/util.cpp b/tests/util/util.cpp index 13d0d18566..8359d31576 100644 --- a/tests/util/util.cpp +++ b/tests/util/util.cpp @@ -1,10 +1,23 @@ #include "core/util/prelude.h" #include "torch/script.h" +#include "torch/torch.h" namespace torch_tensorrt { namespace tests { namespace util { +bool cosineSimEqual(const at::Tensor& computed_tensor, const at::Tensor& gt_tensor, float threshold = 0.99f) { + torch::Tensor cosine_sim = torch::nn::functional::cosine_similarity( + computed_tensor.flatten(), gt_tensor.flatten(), torch::nn::functional::CosineSimilarityFuncOptions().dim(0)); + std::ostringstream ss; + ss << computed_tensor << std::endl << gt_tensor << std::endl; + LOG_GRAPH(ss.str()); + LOG_GRAPH(std::string("Cosine Similarity score: ") + std::to_string(cosine_sim.item())); + LOG_GRAPH(std::string("Acceptable Threshold: ") + std::to_string(threshold)); + + return cosine_sim.item() >= threshold; +} + bool almostEqual(const at::Tensor& computed_tensor, const at::Tensor& gt_tensor, float atol = 1e-8, float rtol = 1e-5) { std::ostringstream ss; ss << computed_tensor << std::endl << gt_tensor << std::endl; diff --git a/tests/util/util.h b/tests/util/util.h index f39e2a5766..1ea62a16e0 100644 --- a/tests/util/util.h +++ b/tests/util/util.h @@ -11,6 +11,8 @@ namespace torch_tensorrt { namespace tests { namespace util { +bool cosineSimEqual(const at::Tensor& computed_tensor, const at::Tensor& gt_tensor, float threshold); + bool almostEqual(const at::Tensor& computed_tensor, const at::Tensor& gt_tensor, float atol = 1e-8, float rtol = 1e-5); bool exactlyEqual(const at::Tensor& a, const at::Tensor& b);