Skip to content

Commit c4195c6

Browse files
author
Yinghai Lu
authored
fix eisum signature (#1480)
1 parent 04c5870 commit c4195c6

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

core/conversion/converters/impl/einsum.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ namespace impl {
1212
namespace {
1313

1414
auto einsum_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
15-
{"aten::einsum(str equation, Tensor[] tensors) -> (Tensor)",
15+
{"aten::einsum(str equation, Tensor[] tensors, *, int[]? path=None) -> (Tensor)",
1616
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
1717
// Extract equation and list of tensors
1818
auto equation = args[0].unwrapToString();

tests/core/conversion/converters/test_einsum.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ TEST(Converters, ATenEinsumConvertsMatMulCorrectly) {
99
graph(%x.1 : Tensor, %x.2 : Tensor):
1010
%0 : str = prim::Constant[value="ij,jk->ik"]()
1111
%3 : Tensor[] = prim::ListConstruct(%x.1, %x.2)
12-
%4 : Tensor = aten::einsum(%0, %3)
12+
%none : NoneType = prim::Constant()
13+
%4 : Tensor = aten::einsum(%0, %3, %none)
1314
return (%4))IR";
1415

1516
auto g = std::make_shared<torch::jit::Graph>();
@@ -34,7 +35,8 @@ TEST(Converters, ATenEinsumConvertsElementwiseProdCorrectly) {
3435
graph(%x.1 : Tensor, %x.2 : Tensor):
3536
%0 : str = prim::Constant[value="abcd,abcd->abcd"]()
3637
%3 : Tensor[] = prim::ListConstruct(%x.1, %x.2)
37-
%4 : Tensor = aten::einsum(%0, %3)
38+
%none : NoneType = prim::Constant()
39+
%4 : Tensor = aten::einsum(%0, %3, %none)
3840
return (%4))IR";
3941

4042
auto g = std::make_shared<torch::jit::Graph>();
@@ -59,7 +61,8 @@ TEST(Converters, ATenEinsumConvertsTransposeCorrectly) {
5961
graph(%x.1 : Tensor):
6062
%0 : str = prim::Constant[value="jk->kj"]()
6163
%3 : Tensor[] = prim::ListConstruct(%x.1)
62-
%4 : Tensor = aten::einsum(%0, %3)
64+
%none : NoneType = prim::Constant()
65+
%4 : Tensor = aten::einsum(%0, %3, %none)
6366
return (%4))IR";
6467

6568
auto g = std::make_shared<torch::jit::Graph>();
@@ -83,7 +86,8 @@ TEST(Converters, ATenEinsumConvertsVectorsCorrectly) {
8386
graph(%x.1 : Tensor, %x.2 : Tensor):
8487
%0 : str = prim::Constant[value="a,b->ab"]()
8588
%3 : Tensor[] = prim::ListConstruct(%x.1, %x.2)
86-
%4 : Tensor = aten::einsum(%0, %3)
89+
%none : NoneType = prim::Constant()
90+
%4 : Tensor = aten::einsum(%0, %3, %none)
8791
return (%4))IR";
8892

8993
auto g = std::make_shared<torch::jit::Graph>();

0 commit comments

Comments
 (0)