Skip to content

Commit c8b246a

Browse files
apaszkefacebook-github-bot
authored andcommitted
Prevent JIT from overspecializing to every single size configuration (#10844)
Summary: Please review the expects carefully to make sure there are no regressions. I tried to go over them one by one when they changed, but it's sometimes easy to miss finer details. Summary of changes: - Renamed `TensorType` to `CompleteTensorType`. Added a new `TensorType` which records only the scalar type, number of dimensions, and device of a value. The argument behind the rename is to encourage people to use `CompleteTensorType` less, as most passes will only have limited information available. To make transition easier `complete_type->cast<TensorType>()` works, and makes our passes work with both kinds of specialization if they don't need extra the extra detail. - Renamed `ArgumentSpec` to `CompleteArgumentSpec`. Added a new `ArgumentSpec`, which matches argument only at the level of the new `TensorType`. - Shape analysis can process graphs with both `CompleteTensorType` and `TensorType`. - Fuser was a part that heavily relied on full shape information being available. Now, we simply try to fuse the largest possible graphs, and have to do run-time checks to make sure they match the code we generate. If they don't, we fall back to regular interpretation. The shape checks are implementing using an optimized method exploiting algebraic properties of shapes with broadcasting, and the relations of broadcasting with pointwise ops. A full written proof of correctness of the shape checking algorithm is included in a comment in `graph_fuser.cpp`. zdevito ezyang mruberry ngimel csarofeen Pull Request resolved: #10844 Differential Revision: D9498705 Pulled By: apaszke fbshipit-source-id: 0c53c2fcebd871cc2a29c260f8d012276479cc61
1 parent 9679fc5 commit c8b246a

File tree

53 files changed

+1746
-1057
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+1746
-1057
lines changed
Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,14 @@
1-
graph(%0 : Float(4, 4)
2-
%1 : Float(4)
3-
%2 : Float(4)) {
4-
%3 : int[] = prim::Constant[value=[4, 4]]()
5-
%4 : int = prim::Constant[value=0]()
6-
%5 : Float(4!, 4) = aten::expand(%1, %3, %4)
7-
%6 : Float(4!, 4) = aten::expand(%2, %3, %4)
8-
%7 : Float(4, 4) = prim::FusionGroup_0[device=0](%6, %0, %5)
9-
return (%7);
1+
graph(%0 : Float(*, *)
2+
%1 : Float(*)
3+
%2 : Float(*)) {
4+
%3 : Float(*, *) = prim::FusionGroup_0[device=0](%2, %0, %1)
5+
return (%3);
106
}
11-
with prim::FusionGroup_0 = graph(%1 : Float(4!, 4)
12-
%4 : Float(4, 4)
13-
%5 : Float(4!, 4)) {
14-
%6 : Float(4, 4) = aten::mul(%4, %5)
7+
with prim::FusionGroup_0 = graph(%1 : Float(*)
8+
%4 : Float(*, *)
9+
%5 : Float(*)) {
10+
%6 : Float(*, *) = aten::mul(%4, %5)
1511
%2 : int = prim::Constant[value=1]()
16-
%3 : Float(4, 4) = aten::add(%6, %1, %2)
12+
%3 : Float(*, *) = aten::add(%6, %1, %2)
1713
return (%3);
1814
}
Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
graph(%0 : Float(3, 20)
2-
%1 : Float(3, 20)) {
3-
%2 : Float(6, 20) = prim::FusionGroup_0[device=0](%0, %1)
1+
graph(%0 : Float(*, *)
2+
%1 : Float(*, *)) {
3+
%2 : Float(*, *) = prim::FusionGroup_0[device=0](%0, %1)
44
return (%2);
55
}
6-
with prim::FusionGroup_0 = graph(%3 : Float(3, 20)
7-
%4 : Float(3, 20)) {
6+
with prim::FusionGroup_0 = graph(%3 : Float(*, *)
7+
%4 : Float(*, *)) {
88
%6 : int = prim::Constant[value=1]()
9-
%7 : Float(3, 20) = aten::add(%3, %4, %6)
10-
%5 : Float(3, 20) = aten::mul(%3, %4)
11-
%2 : Float(6, 20) = prim::FusedConcat[dim=0](%7, %5)
9+
%7 : Float(*, *) = aten::add(%3, %4, %6)
10+
%5 : Float(*, *) = aten::mul(%3, %4)
11+
%2 : Float(*, *) = prim::FusedConcat[dim=0](%7, %5)
1212
return (%2);
1313
}
Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
1-
graph(%0 : Float(2, 2)
2-
%1 : Float(2, 2)
3-
%2 : Float(4, 2)) {
1+
graph(%0 : Float(*, *)
2+
%1 : Float(*, *)
3+
%2 : Float(*, *)) {
44
%3 : int = prim::Constant[value=1]()
5-
%4 : Float(4, 2) = prim::FusionGroup_0[device=0](%0, %1)
6-
%5 : Float(4, 2) = aten::add(%4, %2, %3)
5+
%4 : Float(*, *) = prim::FusionGroup_0[device=0](%0, %1)
6+
%5 : Float(*, *) = aten::add(%4, %2, %3)
77
return (%5);
88
}
9-
with prim::FusionGroup_0 = graph(%3 : Float(2, 2)
10-
%4 : Float(2, 2)) {
9+
with prim::FusionGroup_0 = graph(%3 : Float(*, *)
10+
%4 : Float(*, *)) {
1111
%7 : int = prim::Constant[value=1]()
12-
%8 : Float(2, 2) = aten::add(%3, %4, %7)
12+
%8 : Float(*, *) = aten::add(%3, %4, %7)
1313
%5 : int = prim::Constant[value=1]()
14-
%6 : Float(2, 2) = aten::sub(%3, %4, %5)
15-
%2 : Float(4, 2) = prim::FusedConcat[dim=0](%8, %6)
14+
%6 : Float(*, *) = aten::sub(%3, %4, %5)
15+
%2 : Float(*, *) = prim::FusedConcat[dim=0](%8, %6)
1616
return (%2);
1717
}
Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
1-
graph(%0 : Float(1)
2-
%1 : Float(1)) {
3-
%2 : Float(1) = prim::FusionGroup_0[device=1](%0, %1)
1+
graph(%0 : Float(*)
2+
%1 : Float(*)) {
3+
%2 : Float(*) = prim::FusionGroup_0[device=1](%0, %1)
44
return (%2);
55
}
6-
with prim::FusionGroup_0 = graph(%5 : Float(1)
7-
%10 : Float(1)) {
6+
with prim::FusionGroup_0 = graph(%5 : Float(*)
7+
%10 : Float(*)) {
88
%11 : int = prim::Constant[value=1]()
9-
%12 : Float(1) = aten::add(%5, %10, %11)
10-
%9 : Float(1) = aten::mul(%5, %12)
9+
%12 : Float(*) = aten::add(%5, %10, %11)
10+
%9 : Float(*) = aten::mul(%5, %12)
1111
%6 : int = prim::Constant[value=1]()
12-
%7 : Float(1) = aten::add(%9, %5, %6)
13-
%3 : Float(1) = aten::tanh(%7)
14-
%1 : Float(1) = aten::sigmoid(%3)
12+
%7 : Float(*) = aten::add(%9, %5, %6)
13+
%3 : Float(*) = aten::tanh(%7)
14+
%1 : Float(*) = aten::sigmoid(%3)
1515
return (%1);
1616
}
Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
1-
graph(%0 : Float(4, 4)
2-
%1 : Float(4, 4)) {
3-
%2 : Float(4, 2) = prim::FusionGroup_0[device=0](%0, %1)
4-
return (%2);
1+
graph(%0 : Float(*, *)
2+
%1 : Float(*, *)) {
3+
%2 : Dynamic[] = prim::ListConstruct(%0, %1)
4+
%3 : Dynamic, %4 : Dynamic = aten::broadcast_tensors(%2)
5+
%5 : Float(*, *) = prim::FusionGroup_0[device=0](%3, %4)
6+
return (%5);
57
}
6-
with prim::FusionGroup_0 = graph(%11 : Float(4, 4)
7-
%14 : Float(4, 4)) {
8+
with prim::FusionGroup_0 = graph(%11 : Dynamic
9+
%14 : Dynamic) {
810
%15 : Dynamic, %16 : Dynamic = prim::FusedChunk[chunks=2, dim=1](%14)
911
%12 : Dynamic, %13 : Dynamic = prim::FusedChunk[chunks=2, dim=1](%11)
1012
%9 : int = prim::Constant[value=1]()
11-
%10 : Float(4, 2) = aten::add(%13, %16, %9)
13+
%10 : Float(*, *) = aten::add(%13, %16, %9)
1214
%5 : int = prim::Constant[value=1]()
13-
%6 : Float(4, 2) = aten::add(%12, %15, %5)
14-
%2 : Float(4, 2) = aten::mul(%6, %10)
15+
%6 : Float(*, *) = aten::add(%12, %15, %5)
16+
%2 : Float(*, *) = aten::mul(%6, %10)
1517
return (%2);
1618
}
Lines changed: 32 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,43 @@
1-
graph(%0 : Float(3, 10)
2-
%1 : Float(3, 20)
3-
%2 : Float(3, 20)
4-
%3 : Float(80, 10)
5-
%4 : Float(80, 20)
6-
%5 : Float(80)
7-
%6 : Float(80)) {
8-
%7 : Float(10!, 80!) = aten::t(%3)
1+
graph(%0 : Float(*, *)
2+
%1 : Float(*, *)
3+
%2 : Float(*, *)
4+
%3 : Float(*, *)
5+
%4 : Float(*, *)
6+
%5 : Float(*)
7+
%6 : Float(*)) {
8+
%7 : Float(*, *) = aten::t(%3)
99
%8 : int = prim::Constant[value=1]()
10-
%9 : Float(3, 80) = aten::addmm(%5, %0, %7, %8, %8)
11-
%10 : Float(20!, 80!) = aten::t(%4)
12-
%11 : Float(3, 80) = aten::addmm(%6, %1, %10, %8, %8)
13-
%12 : Float(6, 20) = prim::FusionGroup_0[device=0](%2, %9, %11)
14-
return (%12);
10+
%9 : Float(*, *) = aten::addmm(%5, %0, %7, %8, %8)
11+
%10 : Float(*, *) = aten::t(%4)
12+
%11 : Float(*, *) = aten::addmm(%6, %1, %10, %8, %8)
13+
%12 : Dynamic[] = prim::ListConstruct(%9, %11)
14+
%13 : Dynamic, %14 : Dynamic = aten::broadcast_tensors(%12)
15+
%15 : Float(*, *) = prim::FusionGroup_0[device=0](%2, %13, %14)
16+
return (%15);
1517
}
16-
with prim::FusionGroup_0 = graph(%15 : Float(3, 20)
17-
%41 : Float(3, 80)
18-
%46 : Float(3, 80)) {
18+
with prim::FusionGroup_0 = graph(%15 : Float(*, *)
19+
%41 : Dynamic
20+
%46 : Dynamic) {
1921
%47 : Dynamic, %48 : Dynamic, %49 : Dynamic, %50 : Dynamic = prim::FusedChunk[chunks=4, dim=1](%46)
2022
%42 : Dynamic, %43 : Dynamic, %44 : Dynamic, %45 : Dynamic = prim::FusedChunk[chunks=4, dim=1](%41)
2123
%39 : int = prim::Constant[value=1]()
22-
%40 : Float(3, 20) = aten::add(%42, %47, %39)
24+
%40 : Float(*, *) = aten::add(%42, %47, %39)
2325
%35 : int = prim::Constant[value=1]()
24-
%36 : Float(3, 20) = aten::add(%43, %48, %35)
26+
%36 : Float(*, *) = aten::add(%43, %48, %35)
2527
%31 : int = prim::Constant[value=1]()
26-
%32 : Float(3, 20) = aten::add(%44, %49, %31)
28+
%32 : Float(*, *) = aten::add(%44, %49, %31)
2729
%27 : int = prim::Constant[value=1]()
28-
%28 : Float(3, 20) = aten::add(%45, %50, %27)
29-
%24 : Float(3, 20) = aten::sigmoid(%40)
30-
%22 : Float(3, 20) = aten::sigmoid(%36)
31-
%20 : Float(3, 20) = aten::tanh(%32)
32-
%18 : Float(3, 20) = aten::sigmoid(%28)
33-
%16 : Float(3, 20) = aten::mul(%22, %15)
34-
%13 : Float(3, 20) = aten::mul(%24, %20)
30+
%28 : Float(*, *) = aten::add(%45, %50, %27)
31+
%24 : Float(*, *) = aten::sigmoid(%40)
32+
%22 : Float(*, *) = aten::sigmoid(%36)
33+
%20 : Float(*, *) = aten::tanh(%32)
34+
%18 : Float(*, *) = aten::sigmoid(%28)
35+
%16 : Float(*, *) = aten::mul(%22, %15)
36+
%13 : Float(*, *) = aten::mul(%24, %20)
3537
%9 : int = prim::Constant[value=1]()
36-
%10 : Float(3, 20) = aten::add(%16, %13, %9)
37-
%6 : Float(3, 20) = aten::tanh(%10)
38-
%5 : Float(3, 20) = aten::mul(%18, %6)
39-
%2 : Float(6, 20) = prim::FusedConcat[dim=0](%5, %10)
38+
%10 : Float(*, *) = aten::add(%16, %13, %9)
39+
%6 : Float(*, *) = aten::tanh(%10)
40+
%5 : Float(*, *) = aten::mul(%18, %6)
41+
%2 : Float(*, *) = prim::FusedConcat[dim=0](%5, %10)
4042
return (%2);
4143
}
Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,42 @@
1-
graph(%0 : Float(3, 10)
2-
%1 : Float(3, 20)
3-
%2 : Float(3, 20)
4-
%3 : Float(80, 10)
5-
%4 : Float(80, 20)
6-
%5 : Float(80)
7-
%6 : Float(80)) {
8-
%7 : Float(10!, 80!) = aten::t(%3)
1+
graph(%0 : Float(*, *)
2+
%1 : Float(*, *)
3+
%2 : Float(*, *)
4+
%3 : Float(*, *)
5+
%4 : Float(*, *)
6+
%5 : Float(*)
7+
%6 : Float(*)) {
8+
%7 : Float(*, *) = aten::t(%3)
99
%8 : int = prim::Constant[value=1]()
10-
%9 : Float(3, 80) = aten::addmm(%5, %0, %7, %8, %8)
11-
%10 : Float(20!, 80!) = aten::t(%4)
12-
%11 : Float(3, 80) = aten::addmm(%6, %1, %10, %8, %8)
13-
%12 : Float(3, 20), %13 : Float(3, 20) = prim::FusionGroup_0[device=0](%2, %9, %11)
14-
return (%12, %13);
10+
%9 : Float(*, *) = aten::addmm(%5, %0, %7, %8, %8)
11+
%10 : Float(*, *) = aten::t(%4)
12+
%11 : Float(*, *) = aten::addmm(%6, %1, %10, %8, %8)
13+
%12 : Dynamic[] = prim::ListConstruct(%9, %11)
14+
%13 : Dynamic, %14 : Dynamic = aten::broadcast_tensors(%12)
15+
%15 : Float(*, *), %16 : Float(*, *) = prim::FusionGroup_0[device=0](%2, %13, %14)
16+
return (%15, %16);
1517
}
16-
with prim::FusionGroup_0 = graph(%13 : Float(3, 20)
17-
%39 : Float(3, 80)
18-
%44 : Float(3, 80)) {
18+
with prim::FusionGroup_0 = graph(%13 : Float(*, *)
19+
%39 : Dynamic
20+
%44 : Dynamic) {
1921
%45 : Dynamic, %46 : Dynamic, %47 : Dynamic, %48 : Dynamic = prim::FusedChunk[chunks=4, dim=1](%44)
2022
%40 : Dynamic, %41 : Dynamic, %42 : Dynamic, %43 : Dynamic = prim::FusedChunk[chunks=4, dim=1](%39)
2123
%37 : int = prim::Constant[value=1]()
22-
%38 : Float(3, 20) = aten::add(%40, %45, %37)
24+
%38 : Float(*, *) = aten::add(%40, %45, %37)
2325
%33 : int = prim::Constant[value=1]()
24-
%34 : Float(3, 20) = aten::add(%41, %46, %33)
26+
%34 : Float(*, *) = aten::add(%41, %46, %33)
2527
%29 : int = prim::Constant[value=1]()
26-
%30 : Float(3, 20) = aten::add(%42, %47, %29)
28+
%30 : Float(*, *) = aten::add(%42, %47, %29)
2729
%25 : int = prim::Constant[value=1]()
28-
%26 : Float(3, 20) = aten::add(%43, %48, %25)
29-
%22 : Float(3, 20) = aten::sigmoid(%38)
30-
%20 : Float(3, 20) = aten::sigmoid(%34)
31-
%18 : Float(3, 20) = aten::tanh(%30)
32-
%16 : Float(3, 20) = aten::sigmoid(%26)
33-
%14 : Float(3, 20) = aten::mul(%20, %13)
34-
%11 : Float(3, 20) = aten::mul(%22, %18)
30+
%26 : Float(*, *) = aten::add(%43, %48, %25)
31+
%22 : Float(*, *) = aten::sigmoid(%38)
32+
%20 : Float(*, *) = aten::sigmoid(%34)
33+
%18 : Float(*, *) = aten::tanh(%30)
34+
%16 : Float(*, *) = aten::sigmoid(%26)
35+
%14 : Float(*, *) = aten::mul(%20, %13)
36+
%11 : Float(*, *) = aten::mul(%22, %18)
3537
%7 : int = prim::Constant[value=1]()
36-
%8 : Float(3, 20) = aten::add(%14, %11, %7)
37-
%4 : Float(3, 20) = aten::tanh(%8)
38-
%2 : Float(3, 20) = aten::mul(%16, %4)
38+
%8 : Float(*, *) = aten::add(%14, %11, %7)
39+
%4 : Float(*, *) = aten::tanh(%8)
40+
%2 : Float(*, *) = aten::mul(%16, %4)
3941
return (%2, %8);
4042
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
graph(%0 : Double(3, 4)
22
%1 : Double(4, 5)) {
33
%2 : Double(3, 5) = aten::mm(%0, %1), scope: TracedModule
4-
%3 : Double(3, 5) = aten::neg(%2), scope: TracedModule/ScriptModule
4+
%3 : Double(*, *) = aten::neg(%2), scope: TracedModule/ScriptModule
55
return (%3);
66
}

test/expect/TestScript.test_call_script_fn_from_tracing_fn.expect

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
graph(%0 : Double(3, 4)) {
2-
%1 : Double(3, 4) = aten::neg(%0), scope: ScriptModule
2+
%1 : Double(*, *) = aten::neg(%0), scope: ScriptModule
33
%2 : Long() = prim::Constant[value={1}]()
44
%3 : int = prim::Constant[value=1]()
55
%4 : Double(3, 4) = aten::add(%1, %2, %3)

test/expect/TestScript.test_call_script_mod_from_tracing_fn.expect

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
graph(%0 : Double(3, 4)) {
22
%1 : Double(4, 3) = prim::Constant[value=<Tensor>](), scope: ScriptMod
3-
%2 : Double(3, 3) = aten::mm(%0, %1), scope: ScriptMod
3+
%2 : Double(*, *) = aten::mm(%0, %1), scope: ScriptMod
44
%3 : Long() = prim::Constant[value={1}]()
55
%4 : int = prim::Constant[value=1]()
66
%5 : Double(3, 3) = aten::add(%2, %3, %4)

test/expect/TestScript.test_call_script_module_from_traced_module.expect

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ graph(%0 : Double(3, 4)
22
%1 : Double(4, 5)
33
%2 : Double(5, 7)) {
44
%3 : Double(3, 5) = aten::mm(%0, %1), scope: TracedModule
5-
%4 : Double(3, 7) = aten::mm(%3, %2), scope: TracedModule/ScriptMod[mod]
5+
%4 : Double(*, *) = aten::mm(%3, %2), scope: TracedModule/ScriptMod[mod]
66
%5 : Long() = prim::Constant[value={1}](), scope: TracedModule
77
%6 : int = prim::Constant[value=1](), scope: TracedModule
88
%7 : Double(3, 7) = aten::add(%4, %5, %6), scope: TracedModule
Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
graph(%x : Float(10, 6)) {
2-
%1 : Float(10, 2) = prim::FusionGroup_0[device=0](%x)
1+
graph(%x : Float(*, *)) {
2+
%1 : Float(*, *) = prim::FusionGroup_0[device=0](%x)
33
return (%1);
44
}
5-
with prim::FusionGroup_0 = graph(%7 : Float(10, 6)) {
5+
with prim::FusionGroup_0 = graph(%7 : Float(*, *)) {
66
%8 : Dynamic, %9 : Dynamic, %10 : Dynamic = prim::FusedChunk[chunks=3, dim=1](%7)
7-
%6 : Float(10, 2) = aten::mul(%8, %9)
7+
%6 : Float(*, *) = aten::mul(%8, %9)
88
%2 : int = prim::Constant[value=1]()
9-
%3 : Float(10, 2) = aten::add(%6, %10, %2)
9+
%3 : Float(*, *) = aten::add(%6, %10, %2)
1010
return (%3);
1111
}
Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,30 @@
1-
graph(%s : Float(5, 2, 3)
2-
%x : Float(5, 6, 3)
3-
%y : Float(10, 2, 3)
4-
%z : Float(5, 2, 6)) {
5-
%4 : Float(5, 2, 3) = prim::FusionGroup_0[device=0](%s, %y, %x, %z)
1+
graph(%s : Float(*, *, *)
2+
%x : Float(*, *, *)
3+
%y : Float(*, *, *)
4+
%z : Float(*, *, *)) {
5+
%4 : Float(*, *, *) = prim::FusionGroup_0[device=0](%s, %y, %x, %z)
66
return (%4);
77
}
8-
with prim::FusionGroup_0 = graph(%24 : Float(5, 2, 3)
9-
%28 : Float(10, 2, 3)
10-
%31 : Float(5, 6, 3)
11-
%35 : Float(5, 2, 6)) {
8+
with prim::FusionGroup_0 = graph(%24 : Float(*, *, *)
9+
%28 : Float(*, *, *)
10+
%31 : Float(*, *, *)
11+
%35 : Float(*, *, *)) {
1212
%36 : Dynamic, %37 : Dynamic = prim::FusedChunk[chunks=2, dim=2](%35)
1313
%32 : Dynamic, %33 : Dynamic, %34 : Dynamic = prim::FusedChunk[chunks=3, dim=1](%31)
1414
%29 : Dynamic, %30 : Dynamic = prim::FusedChunk[chunks=2, dim=0](%28)
1515
%26 : int = prim::Constant[value=1]()
16-
%27 : Float(5, 2, 3) = aten::add(%24, %32, %26)
16+
%27 : Float(*, *, *) = aten::add(%24, %32, %26)
1717
%22 : int = prim::Constant[value=1]()
18-
%23 : Float(5, 2, 3) = aten::add(%27, %33, %22)
18+
%23 : Float(*, *, *) = aten::add(%27, %33, %22)
1919
%18 : int = prim::Constant[value=1]()
20-
%19 : Float(5, 2, 3) = aten::add(%23, %34, %18)
20+
%19 : Float(*, *, *) = aten::add(%23, %34, %18)
2121
%14 : int = prim::Constant[value=1]()
22-
%15 : Float(5, 2, 3) = aten::add(%19, %29, %14)
22+
%15 : Float(*, *, *) = aten::add(%19, %29, %14)
2323
%10 : int = prim::Constant[value=1]()
24-
%11 : Float(5, 2, 3) = aten::add(%15, %30, %10)
24+
%11 : Float(*, *, *) = aten::add(%15, %30, %10)
2525
%6 : int = prim::Constant[value=1]()
26-
%7 : Float(5, 2, 3) = aten::add(%11, %36, %6)
26+
%7 : Float(*, *, *) = aten::add(%11, %36, %6)
2727
%2 : int = prim::Constant[value=1]()
28-
%3 : Float(5, 2, 3) = aten::add(%7, %37, %2)
28+
%3 : Float(*, *, *) = aten::add(%7, %37, %2)
2929
return (%3);
3030
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
graph(%x : Double(*, *)) {
2+
%1 : int = prim::Constant[value=1]()
3+
%c : Dynamic[] = prim::If(%1)
4+
block0() {
5+
%c.1 : Dynamic[] = prim::ListConstruct(%x, %x)
6+
-> (%c.1)
7+
}
8+
block1() {
9+
%c.2 : Dynamic[] = prim::ListConstruct(%x, %x, %x)
10+
-> (%c.2)
11+
}
12+
%5 : int = prim::Constant[value=0]()
13+
%6 : Dynamic = aten::cat(%c, %5)
14+
return (%6);
15+
}

0 commit comments

Comments
 (0)