From de9bdf6afe35a70e0a06033f3b6220ab83804e36 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Wed, 2 Nov 2022 15:25:23 -0700 Subject: [PATCH 1/2] fix: Add check to ensure einsum converter has correct args - TRT einsum implementation currently supports 2 inputs, however the converter will accept any number of inputs and TRT throws an error at compilation - `aten::einsum` converter now checks that the tensor argument list does not exceed 2 elements, and throws an informative error otherwise --- core/conversion/converters/impl/einsum.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/core/conversion/converters/impl/einsum.cpp b/core/conversion/converters/impl/einsum.cpp index fb031f6c38..4cb501e3f7 100644 --- a/core/conversion/converters/impl/einsum.cpp +++ b/core/conversion/converters/impl/einsum.cpp @@ -18,6 +18,13 @@ auto einsum_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pat auto equation = args[0].unwrapToString(); auto in = args[1].IValue()->toListRef(); + TORCHTRT_CHECK( + in.size() <= 2, + "TensorRT currently supports up to 2 input tensors " + << "to einsum but operation had " << in.size() + << " input tensors, please specify torch_executed_ops=[aten::einsum] " + << "at compilation time to avoid this error."); + std::vector tensors; // Populate vector of ITensor pointers From 240db32c0724200e292062764fac693d92a43ca7 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Wed, 2 Nov 2022 16:47:53 -0700 Subject: [PATCH 2/2] Add escaped quotes so user can copy-paste printed solution --- core/conversion/converters/impl/einsum.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/conversion/converters/impl/einsum.cpp b/core/conversion/converters/impl/einsum.cpp index 4cb501e3f7..fdaafe4e33 100644 --- a/core/conversion/converters/impl/einsum.cpp +++ b/core/conversion/converters/impl/einsum.cpp @@ -22,7 +22,7 @@ auto einsum_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pat in.size() <= 2, "TensorRT currently supports up to 2 input tensors " << "to einsum but operation had " << in.size() - << " input tensors, please specify torch_executed_ops=[aten::einsum] " + << " input tensors, please specify torch_executed_ops=[\"aten::einsum\"] " << "at compilation time to avoid this error."); std::vector tensors;