44 changes: 44 additions & 0 deletions test/test_jit_cuda_fuser.py
@@ -288,6 +288,7 @@ def test_binary_ops(self):
def test_ternary_ops(self):
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
z = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
cond = torch.randint(0, 2, (4, 8, 32, 32)).to(dtype=torch.bool, device="cuda")

def add(x : torch.Tensor, other : torch.Tensor, alpha : float):
@@ -325,6 +326,49 @@ def where(x : torch.Tensor, y : torch.Tensor, cond : torch.Tensor):
where_jit = torch.jit.script(where)
self._run_helper(where_jit, where, True, x, y, cond)

def lerp(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor):
o = torch.rand_like(x)
o = o * torch.lerp(x, y, z)
return o
lerp_jit = torch.jit.script(lerp)
self._run_helper(lerp_jit, lerp, True, x, y, z)

def lerp_scale(x : torch.Tensor, y : torch.Tensor, z: float):
o = torch.rand_like(x)
o = o * torch.lerp(x, y, z)
return o
lerp_scale_jit = torch.jit.script(lerp_scale)
self._run_helper(lerp_scale_jit, lerp_scale, True, x, y, 0.5)
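
For reference, torch.lerp computes start + weight * (end - start), where the weight is either a scalar (the lerp_scale case) or a tensor broadcastable against the inputs. A minimal CPU check of the semantics the fused kernel has to match (shapes here are arbitrary, not taken from the test above):

```python
import torch

x = torch.randn(4, 8)
y = torch.randn(4, 8)
w = torch.rand(4, 8)

# Scalar weight, as in lerp_scale above.
assert torch.allclose(torch.lerp(x, y, 0.5), x + 0.5 * (y - x))
# Tensor weight, as in lerp above.
assert torch.allclose(torch.lerp(x, y, w), x + w * (y - x))
```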

@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
@skipIfRocm
def test_addcmul_ops(self):
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
z = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")

def addcmul(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor, value : float):
o = torch.add(x, 0.5)
o = torch.addcmul(o, y, z, value=value)
return o
addcmul_jit = torch.jit.script(addcmul)
self._run_helper(addcmul_jit, addcmul, True, x, y, z, 2.0)

def addcmul_no_alpha(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor):
o = torch.add(x, 0.5)
o = torch.addcmul(o, y, z)
return o
addcmul_no_alpha_jit = torch.jit.script(addcmul_no_alpha)
self._run_helper(addcmul_no_alpha_jit, addcmul_no_alpha, True, x, y, z)

def addcmul_const_alpha(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor):
o = torch.add(x, 0.5)
o = torch.addcmul(o, y, z, value=0.75)
return o
addcmul_const_alpha_jit = torch.jit.script(addcmul_const_alpha)
self._run_helper(addcmul_const_alpha_jit, addcmul_const_alpha, True, x, y, z)
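
Similarly, torch.addcmul computes input + value * tensor1 * tensor2, with value defaulting to 1 (the addcmul_no_alpha case). A small sanity check of those semantics (illustrative only):

```python
import torch

a, b, c = torch.randn(3, 4, 8).unbind(0)

# Explicit value, as in addcmul / addcmul_const_alpha above.
assert torch.allclose(torch.addcmul(a, b, c, value=0.75), a + 0.75 * b * c)
# Default value=1, as in addcmul_no_alpha above.
assert torch.allclose(torch.addcmul(a, b, c), a + b * c)
```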

@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "Requires profiling node to run cuda fuser")
@skipIfRocm
17 changes: 11 additions & 6 deletions torch/csrc/jit/codegen/cuda/graph_fuser.cpp
@@ -166,13 +166,16 @@ struct CudaGraphFuser {
WithInsertPoint guard(*subgraph.nodes().begin());
for (auto input : n->inputs()) {
if (inputs_map.count(input) == 0) {
// TODO: we are following this convention for no good reason;
// tensors don't need to come before any other inputs.
if (input->type()->isSubtypeOf(TensorType::get())) {
auto in_group = subgraph.insertInput(tensor_insert_idx);
in_group->setType(input->type());
inputs_map[input] = in_group;
group->insertInput(tensor_insert_idx, input);
tensor_insert_idx++;
} else if (
// TODO: extend the supported input types here.
(input->type()->isSubtypeOf(FloatType::get()) &&
input->node()->kind() != prim::Constant) ||
(n->kind() == aten::_grad_sum_to_size &&
@@ -181,18 +184,20 @@
in_group->setType(input->type());
inputs_map[input] = in_group;
group->addInput(input);
-      } else {
-        // We don't support passing in scalars as arguments to fused kernels,
-        // so we generally don't allow fusing tensor-scalar operations unless
-        // the scalar is constant. In those cases we inline the constants
-        // directly in the body of the fused group.
-        AT_ASSERT(input->node()->kind() == prim::Constant);
+      } else if (input->node()->kind() == prim::Constant) {
+        // inline the constants directly in the body of the fused group.
        Node* in_const =
            subgraph.createClone(input->node(), [](Value*) -> Value* {
              throw std::runtime_error("unexpected input");
            });
        subgraph.insertNode(in_const);
        inputs_map[input] = in_const->output();
+      } else {
+        // TODO: we need to figure out what scalar inputs are supported.
+        auto in_group = subgraph.addInput();
+        in_group->setType(input->type());
+        inputs_map[input] = in_group;
+        group->addInput(input);
}
}
}
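
In rough terms, the restructuring above keeps the old behavior for scalar inputs produced by prim::Constant (they are cloned inline into the subgraph), while other scalars, which previously died on the AT_ASSERT, are now forwarded as additional inputs to the fusion group. A hypothetical TorchScript function exercising the new path, assuming a profiled run that reaches the CUDA fuser (the function is mine, not from the PR):

```python
import torch

@torch.jit.script
def scaled_add(x: torch.Tensor, y: torch.Tensor, alpha: float):
    # `alpha` arrives as a graph input rather than a prim::Constant,
    # so the fuser must pass it into the fusion group as a scalar input.
    return torch.add(x, y, alpha=alpha)
```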
38 changes: 38 additions & 0 deletions torch/csrc/jit/codegen/cuda/parser.cpp
@@ -22,6 +22,7 @@ namespace cuda {
constexpr auto NUM_UNARY_OPS = 31;
constexpr auto NUM_BINARY_OPS = 24;
constexpr auto NUM_BINARY_OPS_WITH_ALPHA = 4;
constexpr auto NUM_LERP_OPS = 2;

namespace {

@@ -416,6 +417,43 @@ class IrParser {
value_map.emplace(node->output()->unique(), out);
});
}

{
std::array<const char*, NUM_LERP_OPS> LerpOp = {
"aten::lerp(Tensor self, Tensor end, Scalar weight) -> Tensor",
"aten::lerp(Tensor self, Tensor end, Tensor weight) -> Tensor"};
for (auto signature : LerpOp) {
auto ptr_op = getOperatorForLiteral(signature);
registerParseRule(
ptr_op,
[](const Node* const node,
std::unordered_map<size_t, CgValue>& value_map) -> void {
auto self = value_map[node->inputs()[0]->unique()];
auto end = value_map[node->inputs()[1]->unique()];
auto weight = value_map[node->inputs()[2]->unique()];

auto out = lerp(self, end, weight);
value_map.emplace(node->output()->unique(), out);
});
}
}
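
Both overloads need their own schema string because TorchScript treats Scalar and Tensor weights as distinct operators, even though the parse rule body is identical. One way to see which overload a scripted function resolves to (illustrative, not part of the PR):

```python
import torch

def lerp_tensor(x: torch.Tensor, y: torch.Tensor, w: torch.Tensor):
    return torch.lerp(x, y, w)

def lerp_scalar(x: torch.Tensor, y: torch.Tensor):
    return torch.lerp(x, y, 0.5)

# The dumped graphs show aten::lerp called with a Tensor or a float
# weight, matching the two schema strings registered above.
print(torch.jit.script(lerp_tensor).graph)
print(torch.jit.script(lerp_scalar).graph)
```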

{
auto ptr_op = getOperatorForLiteral(
"aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor");
registerParseRule(
ptr_op,
[](const Node* const node,
std::unordered_map<size_t, CgValue>& value_map) -> void {
auto self = value_map[node->inputs()[0]->unique()];
auto tensor1 = value_map[node->inputs()[1]->unique()];
auto tensor2 = value_map[node->inputs()[2]->unique()];
auto value = value_map[node->inputs()[3]->unique()];

auto out = addcmul(self, tensor1, tensor2, value);
value_map.emplace(node->output()->unique(), out);
});
}
}
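
Note that value is the fourth operand (node->inputs()[3]) and is keyword-only after the * in the schema; in a scripted graph it shows up as an explicit input even when left at its default. A quick way to inspect that (illustrative):

```python
import torch

def f(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
    return torch.addcmul(x, y, z, value=2.0)

# The dumped graph shows aten::addcmul with four inputs; the last is
# the prim::Constant 2.0 read by the parse rule via node->inputs()[3].
print(torch.jit.script(f).graph)
```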

void processJitNode(const JitOp* node) {