@@ -9,7 +9,8 @@ TEST(Converters, ATenEinsumConvertsMatMulCorrectly) {
9
9
graph(%x.1 : Tensor, %x.2 : Tensor):
10
10
%0 : str = prim::Constant[value="ij,jk->ik"]()
11
11
%3 : Tensor[] = prim::ListConstruct(%x.1, %x.2)
12
- %4 : Tensor = aten::einsum(%0, %3)
12
+ %none : NoneType = prim::Constant()
13
+ %4 : Tensor = aten::einsum(%0, %3, %none)
13
14
return (%4))IR" ;
14
15
15
16
auto g = std::make_shared<torch::jit::Graph>();
@@ -34,7 +35,8 @@ TEST(Converters, ATenEinsumConvertsElementwiseProdCorrectly) {
34
35
graph(%x.1 : Tensor, %x.2 : Tensor):
35
36
%0 : str = prim::Constant[value="abcd,abcd->abcd"]()
36
37
%3 : Tensor[] = prim::ListConstruct(%x.1, %x.2)
37
- %4 : Tensor = aten::einsum(%0, %3)
38
+ %none : NoneType = prim::Constant()
39
+ %4 : Tensor = aten::einsum(%0, %3, %none)
38
40
return (%4))IR" ;
39
41
40
42
auto g = std::make_shared<torch::jit::Graph>();
@@ -59,7 +61,8 @@ TEST(Converters, ATenEinsumConvertsTransposeCorrectly) {
59
61
graph(%x.1 : Tensor):
60
62
%0 : str = prim::Constant[value="jk->kj"]()
61
63
%3 : Tensor[] = prim::ListConstruct(%x.1)
62
- %4 : Tensor = aten::einsum(%0, %3)
64
+ %none : NoneType = prim::Constant()
65
+ %4 : Tensor = aten::einsum(%0, %3, %none)
63
66
return (%4))IR" ;
64
67
65
68
auto g = std::make_shared<torch::jit::Graph>();
@@ -83,7 +86,8 @@ TEST(Converters, ATenEinsumConvertsVectorsCorrectly) {
83
86
graph(%x.1 : Tensor, %x.2 : Tensor):
84
87
%0 : str = prim::Constant[value="a,b->ab"]()
85
88
%3 : Tensor[] = prim::ListConstruct(%x.1, %x.2)
86
- %4 : Tensor = aten::einsum(%0, %3)
89
+ %none : NoneType = prim::Constant()
90
+ %4 : Tensor = aten::einsum(%0, %3, %none)
87
91
return (%4))IR" ;
88
92
89
93
auto g = std::make_shared<torch::jit::Graph>();
0 commit comments