|
1 | 1 | import torch
|
2 | 2 |
|
3 | 3 | from torch._C._nvfuser import Fusion, FusionDefinition
|
| 4 | +import torch._prims as prims |
| 5 | +import torch._refs as refs |
4 | 6 |
|
5 | 7 | # Construct and Define Fusion
|
6 | 8 | fusion1 = Fusion()
|
|
# Print the fusion IR so the generated graph can be inspected before running.
fusion1.print_ir()

# Execute Fusion
input1 = torch.randn(3, device='cuda')
input2 = torch.randn(2, 3, 4, device='cuda')

# Kernel compilation should be cached for the 2nd iteration
# with input tensors of the same shape
for _ in range(5):
    o = fusion1.execute([input1, input2])[0]

# The fused op broadcasts input1 (shape [3]) into dim 1 of input2's
# [2, 3, 4] shape, so the output keeps input2's shape.
# NOTE(review): fusion1's definition is assumed to match the reference
# broadcast below ([2, 3, 4], mapped dim [1]) — confirm against its
# FusionDefinition block.
assert o.shape == torch.Size([2, 3, 4])

# Reference in prim torch: same broadcast-then-add, computed eagerly.
ref_o = refs.add(prims.broadcast_in_dim(input1, [2, 3, 4], [1]), input2)
assert ref_o.allclose(o)
assert ref_o.shape == o.shape
|
33 | 40 | fusion2 = Fusion()
|
34 | 41 |
|
35 |
| -input1 = torch.ones(1, 1, 4, device='cuda') |
36 |
| -input2 = torch.ones(2, 3, 4, device='cuda') |
| 42 | +input1 = torch.randn(1, 1, 4, device='cuda') |
| 43 | +input2 = torch.randn(2, 3, 4, device='cuda') |
37 | 44 |
|
38 | 45 | with FusionDefinition(fusion2) as fd :
|
39 | 46 | t0 = fd.define_tensor(sizes=input1.size(), strides=input1.stride())
|
|
43 | 50 | fd.add_input(t1)
|
44 | 51 |
|
45 | 52 | t0_b = fd.Ops.broadcast_in_dim(t0, [2, 3, 4], [0, 1, 2])
|
46 |
| - print("Broadcast TensorView", t0_b) |
47 | 53 | t2 = fd.Ops.add(t0_b, t1)
|
48 | 54 |
|
49 | 55 | fd.add_output(t2)
|
|
# Kernel compilation should be cached for the 2nd iteration
# with input tensors of the same shape
for _ in range(5):
    o = fusion2.execute([input1, input2])[0]

# input1 ([1, 1, 4]) maps onto all three output dims ([0, 1, 2]), expanding
# the two size-1 dims to input2's [2, 3, 4] shape.
assert o.shape == torch.Size([2, 3, 4])

# Reference in prim torch: same broadcast-then-add, computed eagerly.
ref_o = refs.add(prims.broadcast_in_dim(input1, [2, 3, 4], [0, 1, 2]), input2)
assert ref_o.allclose(o)
assert ref_o.shape == o.shape
| 70 | + |
# Construct and Define Fusion: mixed implicit ([3, 1] input) and explicit
# (broadcast_in_dim on the 1-D input) broadcasting.
fusion3 = Fusion()

with FusionDefinition(fusion3) as fd:
    # Explicit sizes/strides: a [3, 1] tensor (the size-1 dim broadcasts
    # implicitly against t1_b below).
    t0 = fd.define_tensor([3, 1], [1, 1])
    t1 = fd.define_tensor(1)

    fd.add_input(t0)
    fd.add_input(t1)

    # Map t1's only dim onto dim 0 of the [3, 3] output shape.
    t1_b = fd.Ops.broadcast_in_dim(t1, [3, 3], [0])  # 1 -> 0
    t2 = fd.Ops.add(t0, t1_b)

    fd.add_output(t2)

fusion3.print_ir()

# Execute Fusion
input1 = torch.randn(3, 1, device='cuda')
input2 = torch.randn(3, device='cuda')

# Kernel compilation should be cached for the 2nd iteration
# with input tensors of the same shape
for _ in range(5):
    o = fusion3.execute([input1, input2])[0]

assert o.shape == torch.Size([3, 3])

# Reference in prim torch: broadcast input2 to [3, 3] along dim 0, then add.
ref_o = refs.add(input1, prims.broadcast_in_dim(input2, [3, 3], [0]))
assert ref_o.allclose(o)
assert ref_o.shape == o.shape
0 commit comments