Commit 3ba6a5f

Broadcast in dim with expand (#1794)
Fixes #1788. Adds an expand after the broadcast in broadcast_in_dim, so that broadcast dimensions can be expanded to a concrete size. Note that dynamic shapes are not yet supported for the concrete size.
1 parent fd4be12 commit 3ba6a5f
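
For orientation: broadcast_in_dim maps input dimension i to output dimension broadcast_dims[i] and materializes every other output dimension at the requested concrete size. A minimal sketch of these semantics in plain PyTorch (the helper name broadcast_in_dim_ref is ours, not part of the commit):

import torch

def broadcast_in_dim_ref(t, output_shape, broadcast_dims):
    # Insert a size-1 axis for every output dim not named in
    # broadcast_dims, keeping each input dim at its mapped position.
    sizes = [1] * len(output_shape)
    for i, d in enumerate(broadcast_dims):
        sizes[d] = t.size(i)
    # Expand the size-1 axes to the concrete output shape.
    return t.reshape(sizes).expand(output_shape)

t = torch.randn(3)
assert broadcast_in_dim_ref(t, [2, 3, 4], [1]).shape == torch.Size([2, 3, 4])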

2 files changed: 80 additions & 11 deletions


torch/csrc/jit/codegen/cuda/python_frontend/examples/python_example_broadcast_in_dim.py

Lines changed: 54 additions & 9 deletions
@@ -1,6 +1,8 @@
 import torch
 
 from torch._C._nvfuser import Fusion, FusionDefinition
+import torch._prims as prims
+import torch._refs as refs
 
 # Construct and Define Fusion
 fusion1 = Fusion()
@@ -20,20 +22,25 @@
 fusion1.print_ir()
 
 # Execute Fusion
-input1 = torch.ones(3, device='cuda')
-input2 = torch.ones(2, 3, 4, device='cuda')
+input1 = torch.randn(3, device='cuda')
+input2 = torch.randn(2, 3, 4, device='cuda')
 
 # Kernel compilation should be cached for the 2nd iteration
 # with input tensors of the same shape
 for _ in range(5) :
-    outputs = fusion1.execute([input1, input2])
+    o = fusion1.execute([input1, input2])[0]
 
-print(outputs[0])
+assert(o.shape == torch.Size([2, 3, 4]))
+
+# Reference in prim torch
+ref_o = refs.add(prims.broadcast_in_dim(input1, [2, 3, 4], [1]), input2)
+assert(ref_o.allclose(o))
+assert(ref_o.shape == o.shape)
 
 fusion2 = Fusion()
 
-input1 = torch.ones(1, 1, 4, device='cuda')
-input2 = torch.ones(2, 3, 4, device='cuda')
+input1 = torch.randn(1, 1, 4, device='cuda')
+input2 = torch.randn(2, 3, 4, device='cuda')
 
 with FusionDefinition(fusion2) as fd :
     t0 = fd.define_tensor(sizes=input1.size(), strides=input1.stride())
@@ -43,7 +50,6 @@
     fd.add_input(t1)
 
     t0_b = fd.Ops.broadcast_in_dim(t0, [2, 3, 4], [0, 1, 2])
-    print("Broadcast TensorView", t0_b)
     t2 = fd.Ops.add(t0_b, t1)
 
     fd.add_output(t2)
@@ -53,6 +59,45 @@
 # Kernel compilation should be cached for the 2nd iteration
 # with input tensors of the same shape
 for _ in range(5) :
-    outputs = fusion2.execute([input1, input2])
+    o = fusion2.execute([input1, input2])[0]
+
+assert(o.shape == torch.Size([2, 3, 4]))
+
+# Reference in prim torch
+ref_o = refs.add(prims.broadcast_in_dim(input1, [2, 3, 4], [0, 1, 2]), input2)
+assert(ref_o.allclose(o))
+assert(ref_o.shape == o.shape)
+
+# Construct and Define Fusion
+fusion3 = Fusion()
+
+with FusionDefinition(fusion3) as fd :
+    # t0 = fd.define_tensor(2)
+    t0 = fd.define_tensor([3, 1], [1, 1])
+    t1 = fd.define_tensor(1)
+
+    fd.add_input(t0)
+    fd.add_input(t1)
+
+    t1_b = fd.Ops.broadcast_in_dim(t1, [3, 3], [0])  # 1 -> 0
+    t2 = fd.Ops.add(t0, t1_b)
+
+    fd.add_output(t2)
+
+fusion3.print_ir()
+
+# Execute Fusion
+input1 = torch.randn(3, 1, device='cuda')
+input2 = torch.randn(3, device='cuda')
+
+# Kernel compilation should be cached for the 2nd iteration
+# with input tensors of the same shape
+for _ in range(5) :
+    o = fusion3.execute([input1, input2])[0]
+
+assert(o.shape == torch.Size([3, 3]))
 
-print(outputs[0])
+# Reference in prim torch
+ref_o = refs.add(input1, prims.broadcast_in_dim(input2, [3, 3], [0]))
+assert(ref_o.allclose(o))
+assert(ref_o.shape == o.shape)
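
For readers unfamiliar with the mapping convention, fusion3 above broadcasts the 1-D input along a new trailing dimension before the add. A hedged eager-mode equivalent, using unsqueeze/expand instead of nvFuser, purely for illustration:

import torch

input1 = torch.randn(3, 1)
input2 = torch.randn(3)
# broadcast_in_dim(input2, [3, 3], [0]): input dim 0 maps to output dim 0,
# so a new broadcast dim appears at index 1 and is expanded to size 3.
ref = input1 + input2.unsqueeze(1).expand(3, 3)
assert ref.shape == torch.Size([3, 3])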

torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp

Lines changed: 26 additions & 2 deletions
@@ -607,7 +607,8 @@ void initNvFuserPythonBindings(PyObject* module) {
       [](TensorView* input,
          std::vector<int>& output_shape,
          std::vector<int>& broadcast_dims) -> TensorView* {
-        const auto input_ndims = input->domain()->noReductions().size();
+        const auto& iter_domains = input->domain()->noReductions();
+        const auto input_ndims = iter_domains.size();
         TORCH_CHECK(
             output_shape.size() >= input_ndims,
             "The new shape is expected to be greater-then-or-equal to the input",
@@ -619,7 +620,9 @@ void initNvFuserPythonBindings(PyObject* module) {
             input_ndims,
             broadcast_dims.size());
 
+        // default all dimensions to be broadcasted
         std::vector<bool> is_broadcast_dim(output_shape.size(), true);
+        std::vector<bool> is_expand_dim(output_shape.size(), true);
         for (const auto idx : c10::irange(broadcast_dims.size())) {
           if (idx > 0) {
             TORCH_CHECK(
@@ -630,9 +633,30 @@ void initNvFuserPythonBindings(PyObject* module) {
               broadcast_dims[idx] < static_cast<int>(output_shape.size()),
               "Invalid broadcast_dims value.");
           is_broadcast_dim.at(broadcast_dims[idx]) = false;
+          // Note: when we expand a broadcasted dimension, we need to expand it
+          // to a concrete size, hence the need for `is_expand_dim` flag and the
+          // expand operation following the broadcast.
+          is_expand_dim.at(broadcast_dims[idx]) =
+              iter_domains[idx]->isBroadcast();
         }
 
-        return torch::jit::fuser::cuda::broadcast(input, is_broadcast_dim);
+        std::vector<torch::jit::fuser::cuda::Val*> output_shape_on_bcast(
+            output_shape.size(), nullptr);
+        for (const auto idx : c10::irange(output_shape.size())) {
+          if (is_expand_dim[idx]) {
+            // TODO: this would be tricky to handle on dynamic shapes, we'll
+            // need to pass-in a symbol instead somehow.
+            output_shape_on_bcast[idx] =
+                IrBuilder::create<Int>(output_shape[idx]);
+          } else {
+            output_shape_on_bcast[idx] = IrBuilder::create<Int>(-1);
+          }
+        }
+
+        auto bcasted_input =
+            torch::jit::fuser::cuda::broadcast(input, is_broadcast_dim);
+        return torch::jit::fuser::cuda::expand(
+            bcasted_input, output_shape_on_bcast);
       },
       py::return_value_policy::reference);
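
To summarize the binding's new bookkeeping: output dimensions not named in broadcast_dims become new broadcast dimensions, a mapped dimension only needs expanding when the corresponding input IterDomain is itself a broadcast dimension, and expand receives -1 wherever the input's concrete extent should be kept. A small Python model of that logic (our own sketch, with input sizes standing in for IterDomains and size 1 playing the role of isBroadcast()):

def broadcast_in_dim_flags(input_sizes, output_shape, broadcast_dims):
    # Default: every output dim is a new broadcast dim that must be expanded.
    is_broadcast_dim = [True] * len(output_shape)
    is_expand_dim = [True] * len(output_shape)
    for idx, d in enumerate(broadcast_dims):
        is_broadcast_dim[d] = False
        # Mirrors iter_domains[idx]->isBroadcast() in the binding.
        is_expand_dim[d] = input_sizes[idx] == 1
    expand_sizes = [
        output_shape[i] if is_expand_dim[i] else -1
        for i in range(len(output_shape))
    ]
    return is_broadcast_dim, expand_sizes

# Broadcasting a [1, 1, 4] input to [2, 3, 4] with broadcast_dims [0, 1, 2]:
# no new dims are inserted, and only the two size-1 dims get concrete extents.
assert broadcast_in_dim_flags([1, 1, 4], [2, 3, 4], [0, 1, 2]) == (
    [False, False, False],
    [2, 3, -1],
)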