diff --git a/cpp/bin/torchtrtc/README.md b/cpp/bin/torchtrtc/README.md index 242c0eebad..d889d8ffdd 100644 --- a/cpp/bin/torchtrtc/README.md +++ b/cpp/bin/torchtrtc/README.md @@ -89,10 +89,12 @@ torchtrtc [input_file_path] [output_file_path] used to select kernels --workspace-size=[workspace_size] Maximum size of workspace given to TensorRT - -t[threshold], - --threshold=[threshold] Maximum acceptable numerical deviation - from standard torchscript output - (default 2e-5) + --atol=[atol] Absolute tolerance threshold for acceptable + numerical deviation from standard torchscript + output (default 1e-8) + --rtol=[rtol] Relative tolerance threshold for acceptable + numerical deviation from standard torchscript + output (default 1e-5) --no-threshold-check Skip checking threshold compliance --truncate-long-double, --truncate, --truncate-64bit Truncate weights that are provided in diff --git a/cpp/bin/torchtrtc/accuracy.cpp b/cpp/bin/torchtrtc/accuracy.cpp index 255bfdb1fa..3268920a4b 100644 --- a/cpp/bin/torchtrtc/accuracy.cpp +++ b/cpp/bin/torchtrtc/accuracy.cpp @@ -19,9 +19,24 @@ bool check_rtol(const at::Tensor& diff, const std::vector inputs, fl return diff.abs().max().item() <= threshold * maxValue; } -bool almost_equal(const at::Tensor& a, const at::Tensor& b, float threshold) { - return check_rtol(a - b, {a, b}, threshold); +bool almost_equal( + const at::Tensor& computed_tensor, + const at::Tensor& gt_tensor, // gt_tensor : Ground Truth Tensor + float atol, + float rtol) { + auto computed_tensor_float = computed_tensor.toType(at::kFloat); + auto gt_tensor_float = gt_tensor.toType(at::kFloat); + + auto diff = computed_tensor_float - gt_tensor_float; + auto result = diff.abs().max().item(); + auto threshold = atol + (rtol * gt_tensor.abs().max().item()); + + torchtrt::logging::log(torchtrt::logging::Level::kDEBUG, std::string("Max Difference: ") + std::to_string(result)); + torchtrt::logging::log( + torchtrt::logging::Level::kDEBUG, std::string("Acceptable Threshold: ") + std::to_string(threshold)); + + return result <= threshold; } } // namespace accuracy -} // namespace torchtrtc \ No newline at end of file +} // namespace torchtrtc diff --git a/cpp/bin/torchtrtc/accuracy.h b/cpp/bin/torchtrtc/accuracy.h index ee54cc2eef..9c751f568c 100644 --- a/cpp/bin/torchtrtc/accuracy.h +++ b/cpp/bin/torchtrtc/accuracy.h @@ -12,7 +12,7 @@ namespace torchtrtc { namespace accuracy { bool check_rtol(const at::Tensor& diff, const std::vector inputs, float threshold); -bool almost_equal(const at::Tensor& a, const at::Tensor& b, float threshold); +bool almost_equal(const at::Tensor& computed_tensor, const at::Tensor& gt_tensor, float atol = 1e-8, float rtol = 1e-5); } // namespace accuracy -} // namespace torchtrtc \ No newline at end of file +} // namespace torchtrtc diff --git a/cpp/bin/torchtrtc/main.cpp b/cpp/bin/torchtrtc/main.cpp index 4d733f274d..f43642584e 100644 --- a/cpp/bin/torchtrtc/main.cpp +++ b/cpp/bin/torchtrtc/main.cpp @@ -119,11 +119,16 @@ int main(int argc, char** argv) { parser, "num_iters", "Number of averaging timing iterations used to select kernels", {"num-avg-timing-iters"}); args::ValueFlag workspace_size( parser, "workspace_size", "Maximum size of workspace given to TensorRT", {"workspace-size"}); - args::ValueFlag threshold( + args::ValueFlag atol( parser, - "threshold", - "Maximum acceptable numerical deviation from standard torchscript output (default 2e-5)", - {'t', "threshold"}); + "atol", + "Absolute tolerance threshold for acceptable numerical deviation from standard torchscript output (default 1e-8)", + {"atol"}); + args::ValueFlag rtol( + parser, + "rtol", + "Relative tolerance threshold for acceptable numerical deviation from standard torchscript output (default 1e-5)", + {"rtol"}); args::Flag no_threshold_check( parser, "no-threshold-check", "Skip checking threshold compliance", {"no-threshold-check", "no-threshold-check"}); @@ -392,9 +397,13 @@ int main(int argc, char** argv) { (compile_settings.enabled_precisions.size() == 1 && compile_settings.enabled_precisions.find(torchtrt::DataType::kFloat) != compile_settings.enabled_precisions.end())) { - double threshold_val = 2e-5; - if (threshold) { - threshold_val = args::get(threshold); + double atol_val = 1e-8; + double rtol_val = 1e-5; + if (atol) { + atol_val = args::get(atol); + } + if (rtol) { + rtol_val = args::get(rtol); } std::vector jit_inputs_ivalues; @@ -431,14 +440,18 @@ int main(int argc, char** argv) { } for (size_t i = 0; i < trt_results.size(); i++) { + std::ostringstream threshold_ss; + threshold_ss << "atol: " << atol_val << " rtol: " << rtol_val; if (!torchtrtc::accuracy::almost_equal( - jit_results[i], trt_results[i].reshape_as(jit_results[i]), threshold_val)) { - std::ostringstream threshold_ss; - threshold_ss << threshold_val; + jit_results[i], trt_results[i].reshape_as(jit_results[i]), atol_val, rtol_val)) { torchtrt::logging::log( torchtrt::logging::Level::kWARNING, - std::string("Maximum numerical deviation for output exceeds set threshold (") + threshold_ss.str() + - std::string(")")); + std::string("Maximum numerical deviation for output exceeds tolerance thresholds (") + + threshold_ss.str() + std::string(")")); + } else { + torchtrt::logging::log( + torchtrt::logging::Level::kDEBUG, + std::string("Maximum numerical deviation within threshold limits ") + threshold_ss.str()); } } } else { diff --git a/docsrc/tutorials/torchtrtc.rst b/docsrc/tutorials/torchtrtc.rst index b841c891e5..dc9e4b6768 100644 --- a/docsrc/tutorials/torchtrtc.rst +++ b/docsrc/tutorials/torchtrtc.rst @@ -92,10 +92,12 @@ to standard TorchScript. Load with ``torch.jit.load()`` and run like you would r used to select kernels --workspace-size=[workspace_size] Maximum size of workspace given to TensorRT - -t[threshold], - --threshold=[threshold] Maximum acceptable numerical deviation - from standard torchscript output - (default 2e-5) + --atol=[atol] Absolute tolerance threshold for acceptable + numerical deviation from standard torchscript + output (default 1e-8) + --rtol=[rtol] Relative tolerance threshold for acceptable + numerical deviation from standard torchscript + output (default 1e-5) --no-threshold-check Skip checking threshold compliance --truncate-long-double, --truncate, --truncate-64bit Truncate weights that are provided in