@@ -1,6 +1,8 @@
import torch

from torch._C._nvfuser import Fusion, FusionDefinition
import torch._prims as prims
import torch._refs as refs

# Construct and Define Fusion
fusion1 = Fusion()
@@ -20,20 +22,25 @@
fusion1.print_ir()

# Execute Fusion
input1 = torch.ones(3, device='cuda')
input2 = torch.ones(2, 3, 4, device='cuda')
input1 = torch.randn(3, device='cuda')
input2 = torch.randn(2, 3, 4, device='cuda')

# Kernel compilation should be cached for the 2nd iteration
# with input tensors of the same shape
for _ in range(5) :
outputs = fusion1.execute([input1, input2])
o = fusion1.execute([input1, input2])[0]

print(outputs[0])
assert(o.shape == torch.Size([2, 3, 4]))

# Reference in prim torch
ref_o = refs.add(prims.broadcast_in_dim(input1, [2, 3, 4], [1]), input2)
assert(ref_o.allclose(o))
assert(ref_o.shape == o.shape)

fusion2 = Fusion()

input1 = torch.ones(1, 1, 4, device='cuda')
input2 = torch.ones(2, 3, 4, device='cuda')
input1 = torch.randn(1, 1, 4, device='cuda')
input2 = torch.randn(2, 3, 4, device='cuda')

with FusionDefinition(fusion2) as fd :
t0 = fd.define_tensor(sizes=input1.size(), strides=input1.stride())
@@ -43,7 +50,6 @@
fd.add_input(t1)

t0_b = fd.Ops.broadcast_in_dim(t0, [2, 3, 4], [0, 1, 2])
print("Broadcast TensorView", t0_b)
t2 = fd.Ops.add(t0_b, t1)

fd.add_output(t2)
@@ -53,6 +59,45 @@
# Kernel compilation should be cached for the 2nd iteration
# with input tensors of the same shape
for _ in range(5) :
outputs = fusion2.execute([input1, input2])
o = fusion2.execute([input1, input2])[0]

assert(o.shape == torch.Size([2, 3, 4]))

# Reference in prim torch
ref_o = refs.add(prims.broadcast_in_dim(input1, [2, 3, 4], [0, 1, 2]), input2)
assert(ref_o.allclose(o))
assert(ref_o.shape == o.shape)

# Construct and Define Fusion
fusion3 = Fusion()

with FusionDefinition(fusion3) as fd :
# t0 = fd.define_tensor(2)
t0 = fd.define_tensor([3, 1], [1, 1])
t1 = fd.define_tensor(1)

fd.add_input(t0)
fd.add_input(t1)

t1_b = fd.Ops.broadcast_in_dim(t1, [3, 3], [0]) # 1 -> 0
t2 = fd.Ops.add(t0, t1_b)

fd.add_output(t2)

fusion3.print_ir()

# Execute Fusion
input1 = torch.randn(3, 1, device='cuda')
input2 = torch.randn(3, device='cuda')

# Kernel compilation should be cached for the 2nd iteration
# with input tensors of the same shape
for _ in range(5) :
o = fusion3.execute([input1, input2])[0]

assert(o.shape == torch.Size([3, 3]))

print(outputs[0])
# Reference in prim torch
ref_o = refs.add(input1, prims.broadcast_in_dim(input2, [3, 3], [0]))
assert(ref_o.allclose(o))
assert(ref_o.shape == o.shape)
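
All three fusions above exercise broadcast_in_dim: output dimension broadcast_dims[i] receives input dimension i, every other output dimension becomes a newly created broadcast dimension, and size-1 input dimensions are expanded to the requested output size. The following is a minimal sketch of that behavior written with plain PyTorch ops; it is an illustration for readers, not part of this PR, and it assumes broadcast_dims is given in ascending order.

import torch

def broadcast_in_dim_sketch(a, shape, broadcast_dims):
    # Insert a size-1 dimension at every output position that does not
    # receive an input dimension, then expand to the requested shape.
    v = a
    for out_dim in range(len(shape)):
        if out_dim not in broadcast_dims:
            v = v.unsqueeze(out_dim)
    return v.expand(*shape)

# Mirrors fusion1: [3] -> [2, 3, 4], with the input mapped to output dim 1.
a = torch.randn(3)
b = torch.randn(2, 3, 4)
o = broadcast_in_dim_sketch(a, [2, 3, 4], [1]) + b
assert o.shape == torch.Size([2, 3, 4])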
torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp (26 additions, 2 deletions)
@@ -607,7 +607,8 @@ void initNvFuserPythonBindings(PyObject* module) {
[](TensorView* input,
std::vector<int>& output_shape,
std::vector<int>& broadcast_dims) -> TensorView* {
const auto input_ndims = input->domain()->noReductions().size();
const auto& iter_domains = input->domain()->noReductions();
const auto input_ndims = iter_domains.size();
TORCH_CHECK(
output_shape.size() >= input_ndims,
"The new shape is expected to be greater-then-or-equal to the input",
@@ -619,7 +620,9 @@
input_ndims,
broadcast_dims.size());

// default all dimensions to be broadcasted
std::vector<bool> is_broadcast_dim(output_shape.size(), true);
std::vector<bool> is_expand_dim(output_shape.size(), true);
for (const auto idx : c10::irange(broadcast_dims.size())) {
if (idx > 0) {
TORCH_CHECK(
@@ -630,9 +633,30 @@
broadcast_dims[idx] < static_cast<int>(output_shape.size()),
"Invalid broadcast_dims value.");
is_broadcast_dim.at(broadcast_dims[idx]) = false;
// Note: when we expand a broadcasted dimension, we need to expand it
// to a concrete size, hence the need for `is_expand_dim` flag and the
// expand operation following the broadcast.
is_expand_dim.at(broadcast_dims[idx]) =
iter_domains[idx]->isBroadcast();
}

return torch::jit::fuser::cuda::broadcast(input, is_broadcast_dim);
std::vector<torch::jit::fuser::cuda::Val*> output_shape_on_bcast(
output_shape.size(), nullptr);
for (const auto idx : c10::irange(output_shape.size())) {
if (is_expand_dim[idx]) {
// TODO: this would be tricky to handle on dynamic shapes, we'll
// need to pass-in a symbol instead somehow.
output_shape_on_bcast[idx] =
IrBuilder::create<Int>(output_shape[idx]);
} else {
output_shape_on_bcast[idx] = IrBuilder::create<Int>(-1);
}
}

auto bcasted_input =
torch::jit::fuser::cuda::broadcast(input, is_broadcast_dim);
return torch::jit::fuser::cuda::expand(
bcasted_input, output_shape_on_bcast);
},
py::return_value_policy::reference);
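
The binding now lowers broadcast_in_dim into nvFuser's broadcast followed by expand: output dimensions backed by a size-1 (broadcast) input IterDomain get a concrete output size, while the remaining mapped dimensions are left untouched via -1. Below is a rough Python sketch of that control flow; it is an illustration only, not the PR's C++ code, and it uses tensor sizes as a stand-in for the isBroadcast() check on the input IterDomains.

import torch

def broadcast_in_dim_binding_sketch(t, output_shape, broadcast_dims):
    # Output dims not listed in broadcast_dims are new broadcast dims.
    is_broadcast_dim = [True] * len(output_shape)
    # Expand sizes: a concrete size where an expand is needed, -1 otherwise.
    expand_sizes = list(output_shape)
    for in_dim, out_dim in enumerate(broadcast_dims):
        is_broadcast_dim[out_dim] = False
        if t.shape[in_dim] != 1:          # stand-in for isBroadcast()
            expand_sizes[out_dim] = -1    # mirrors IrBuilder::create<Int>(-1)
    # "broadcast": insert the new size-1 dimensions.
    for out_dim, is_new in enumerate(is_broadcast_dim):
        if is_new:
            t = t.unsqueeze(out_dim)
    # "expand": give broadcasted/size-1 dimensions their concrete output size.
    return t.expand(*expand_sizes)

# Mirrors fusion2: [1, 1, 4] -> [2, 3, 4]; the first two dims need concrete
# expand sizes, the last keeps its size via -1.
t = torch.randn(1, 1, 4)
assert broadcast_in_dim_binding_sketch(t, [2, 3, 4], [0, 1, 2]).shape == torch.Size([2, 3, 4])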
