Closed
Changes from all commits
86 commits
840da18
Major refactor of code lowering and associated parts.
csarofeen Apr 3, 2020
f2b3982
Don't remove guards around gpu tests.
csarofeen Apr 3, 2020
a568e7c
Minor revisions.
csarofeen Apr 4, 2020
f24db56
Option to return vals from fusion in registered order.
csarofeen Apr 4, 2020
fdadf51
Clang.
csarofeen Apr 4, 2020
a632541
Continue lowering refactor, split out loop nest generator, create sco…
csarofeen Apr 5, 2020
d083d82
ForLoop::range renamed to ForLoop::iter_domain
csarofeen Apr 5, 2020
c87d495
Rename IterDomain::size -> IterDomain::extent.
csarofeen Apr 5, 2020
28ddcf4
Last working test before unrolling. Add incrementally better scalar c…
csarofeen Apr 6, 2020
e761dd7
Add basic infrastructure for unrolling pass.
csarofeen Apr 6, 2020
bb995a8
Factor out ir utilities that can be reused during lowering.
csarofeen Apr 6, 2020
d6531e9
Unrolling loops seemingly working.
csarofeen Apr 8, 2020
7e52a32
Clang.
csarofeen Apr 8, 2020
81e8de4
Test fix.
csarofeen Apr 9, 2020
c984dc8
Major refactor of code lowering and associated parts.
csarofeen Apr 3, 2020
f82a603
Minor revisions.
csarofeen Apr 4, 2020
2373ae6
Improve const scalar check. Add some parallelization guards. Move con…
csarofeen Apr 4, 2020
517f15c
Option to return vals from fusion in registered order.
csarofeen Apr 4, 2020
0df14ef
Clang.
csarofeen Apr 4, 2020
3da13d9
Continue lowering refactor, split out loop nest generator, create sco…
csarofeen Apr 5, 2020
ed0d394
Refactor split/merge/reorder so they can be called directly on tensorD…
csarofeen Apr 9, 2020
6ea8820
tmp, working.
csarofeen Apr 9, 2020
1a4c089
Transform iter now based on tensor domains, not tensor views.
csarofeen Apr 10, 2020
413e649
Rename TensorDomain::size() -> ::nDims
csarofeen Apr 10, 2020
05467e9
Make transformations based on TensorDomains, not TensorViews. TensorV…
csarofeen Apr 10, 2020
07dfe39
TensorIndex::size renamed to ::nDims.
csarofeen Apr 10, 2020
4c735d3
Re-enable being able to compile a fusion multiple times.
csarofeen Apr 10, 2020
87ac1e0
Major refactor of code lowering and associated parts.
csarofeen Apr 3, 2020
a13162e
Minor revisions.
csarofeen Apr 4, 2020
84b689d
Improve const scalar check. Add some parallelization guards. Move con…
csarofeen Apr 4, 2020
9ac7841
Option to return vals from fusion in registered order.
csarofeen Apr 4, 2020
3745274
Clang.
csarofeen Apr 4, 2020
acf9a25
Improve error message in promote. Multiply IterDomain->size() for loc…
csarofeen Apr 4, 2020
b762029
Unroll 2.0
csarofeen Apr 10, 2020
04ce65a
Found an indexing mistake. Fixed.
csarofeen Apr 10, 2020
ec95211
Minor test cleanup.
csarofeen Apr 10, 2020
fe5b7c2
Clang.
csarofeen Apr 10, 2020
bc0dd59
Post rebase cleanup.
csarofeen Apr 10, 2020
7dab57d
Clang.
csarofeen Apr 10, 2020
2191e8c
[Integration code refactor]
jjsjann123 Apr 1, 2020
34bf12d
Test cleanup, merge cleanup, clang format.
csarofeen Apr 11, 2020
0776504
Add unrolling to pointwise kernels in the fuser.
csarofeen Apr 11, 2020
e35219c
Clang.
csarofeen Apr 11, 2020
83302d1
Flake.
csarofeen Apr 11, 2020
3c319bd
Switch to int64 indexing.
csarofeen Apr 13, 2020
bb14652
Refactor kernel argument parsing.
csarofeen Apr 14, 2020
985f290
Clang.
csarofeen Apr 14, 2020
a1f84b9
Jie review.
csarofeen Apr 14, 2020
b81e4b7
Clang.
csarofeen Apr 14, 2020
76290fa
clang-tidy, build warning->error.
csarofeen Apr 14, 2020
dba93b4
Clang tidy.
csarofeen Apr 14, 2020
c7d7488
Missed clang-tidy.
csarofeen Apr 14, 2020
5b20c5c
Merge branch 'master' of https://www.github.com/pytorch/pytorch into …
csarofeen Apr 14, 2020
c54f2b3
Clang.
csarofeen Apr 14, 2020
5f98e98
Foundation to start working on adding reduction support.
csarofeen Apr 15, 2020
be76a33
Basic reduction, no parallelization.
csarofeen Apr 17, 2020
788f0c0
Working towards block reductions, added reduction template kernels to…
csarofeen Apr 19, 2020
abfd4df
[WIP] rfactor replay.
csarofeen Apr 22, 2020
677583f
Fix circular compute at references. Extra error checking on codegen u…
csarofeen Apr 22, 2020
0cf75c1
Continue to fix computeAt support.
csarofeen Apr 23, 2020
34c9d8a
Further fixes/improvements for replay rfactor.
csarofeen Apr 23, 2020
dd707c1
Remove stop condition in iter_visitor.
csarofeen Apr 24, 2020
369e10f
Refactor iter visitor. Using a fully tracked stack approach.
csarofeen Apr 25, 2020
9d27318
Refactor dependency checking.
csarofeen Apr 25, 2020
e409ba3
Traverse all dependency chains in computeAt.
csarofeen Apr 25, 2020
1e0d775
Refactor computeAt, works with multiple consumers correctly now. Does…
csarofeen Apr 25, 2020
3e48c88
Clang.
csarofeen Apr 25, 2020
f3e6fd7
Add unrolling test, update test reference code.
csarofeen Apr 25, 2020
89c12aa
Refactor computeAt again, taking a whole graph approach, back prop co…
csarofeen Apr 27, 2020
08fc770
Finish computeAt refactor without rfactor. Tests working again added …
csarofeen Apr 27, 2020
cf438ba
Remove some non-deterministic behavior, re-add block binding, add mor…
csarofeen Apr 28, 2020
c1e0c7e
Support floating point intermediate values in codegen.
csarofeen Apr 29, 2020
641e2ac
Add rfactor tracking in IterDomain.
csarofeen Apr 29, 2020
9905cd2
Refactor transform replay and transform iter. Remove transform iter f…
csarofeen Apr 30, 2020
0bf9264
Move replay functionality to transform iter.
csarofeen Apr 30, 2020
0b0ffc0
Move backward replay to transform iter, continue refactoring of trans…
csarofeen Apr 30, 2020
f1f7ebc
Continue transform replay refactoring, in preparation for rfactor.
csarofeen May 1, 2020
8aa712c
Another fix to transform replay, add some comments.
csarofeen May 1, 2020
7a25bcc
Rewrite replay(reorder) for the 42nd time.
csarofeen May 2, 2020
0ad949c
Initial RFactor transform replay validation.
csarofeen May 3, 2020
03503e3
Remove old transform replay functions.
csarofeen May 3, 2020
66ad362
Rename maps axis2pos -> old2new, pos2axis -> new2old.
csarofeen May 3, 2020
c731936
RFactor transforms with computeAt all seem correct, need to work on c…
csarofeen May 3, 2020
91908d9
Rework for reductions/rfactor, workout code lowering to work with red…
csarofeen May 7, 2020
5a712df
Clang.
csarofeen May 7, 2020
9d3ce77
Remove printing in test.
csarofeen May 7, 2020
4 changes: 4 additions & 0 deletions caffe2/CMakeLists.txt
@@ -586,8 +586,11 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/ir_iostream.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/iter_visitor.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/kernel.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/kernel_cache.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/manager.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/mutator.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_loops.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_utils.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower2device.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/parser.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/partition.cpp
@@ -596,6 +599,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/tensor_view.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/transform_iter.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/transform_replay.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/transform_rfactor.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/type.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/utils.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/register_interface.cpp
904 changes: 728 additions & 176 deletions test/cpp/jit/test_gpu.cpp

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion test/cpp/jit/tests.h
@@ -117,7 +117,12 @@ namespace jit {
   _(GPU_FusionCodeGen2) \
   _(GPU_FusionSimplePWise) \
   _(GPU_FusionExecKernel) \
-  _(GPU_FusionForLoop)
+  _(GPU_FusionForLoop) \
+  _(GPU_FusionLoopUnroll) \
+  _(GPU_FusionAdvancedComputeAt) \
+  _(GPU_FusionScalarInputs) \
+  _(GPU_FusionRFactorReplay) \
+  _(GPU_FusionSimpleReduction)
 #else
 #define TH_FORALL_TESTS_CUDA(_) \
   _(ArgumentSpec) \
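The bodies of the newly registered GPU_Fusion* tests live in test_gpu.cpp, whose 904-line diff is not rendered above. As a rough sketch only (Fusion, FusionGuard, and the makeDummyTensor helper are assumed from this branch's codegen/cuda headers and test file; the real test bodies will differ), a reduction test in this API takes roughly this shape:

    // Hedged sketch of a GPU_FusionSimpleReduction-style body; illustrative
    // only, the actual implementation is in the un-rendered test_gpu.cpp diff.
    void testGPU_FusionSimpleReduction() {
      Fusion fusion;
      FusionGuard fg(&fusion);              // make `fusion` the active fusion

      TensorView* tv0 = makeDummyTensor(2); // 2D symbolic input (helper assumed)
      fusion.addInput(tv0);

      Val* tv1 = sum(tv0, {1});             // the new reduction API from arith.h
      fusion.addOutput(tv1);

      // ...lower to CUDA, run, and compare against at::sum on a real tensor...
    }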
45 changes: 41 additions & 4 deletions test/test_jit_cuda_fuser.py
@@ -86,20 +86,57 @@ def t(x, y, z, q):
     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
     @skipIfRocm
     def test_scalar_input(self):
-        def t(x, y, z):
-            # type: (Tensor, Tensor, float) -> Tensor
+        def t(x : torch.Tensor, y : torch.Tensor, z : float):
             o = x + y
             o = o + z
             return o
         t_jit = torch.jit.script(t)
-        x = torch.randn(4, 8, dtype=torch.float, device="cuda")
-        y = torch.randn(4, 8, dtype=torch.float, device="cuda")
+        x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
+        y = torch.randn(4, 8, 1, 32, dtype=torch.float, device="cuda")
+        y = y.expand(4, 8, 32, 32)
         jit_o = t_jit(x, y, 2.0)
         jit_o = t_jit(x, y, 2.0)
         o = t(x, y, 2.0)
         self.assertEqual(o, jit_o)
         self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, y, 2.0)))
+
+    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
+    @skipIfRocm
+    def test_broadcasting(self):
+        def t(x : torch.Tensor, y : torch.Tensor, z : float):
+            o = x + y
+            o = o + z
+            return o
+        t_jit = torch.jit.script(t)
+        x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
+        y = torch.randn(32, 32, dtype=torch.float, device="cuda")
+        jit_o = t_jit(x, y, 2.0)
+        jit_o = t_jit(x, y, 2.0)
+        o = t(x, y, 2.0)
+        self.assertEqual(o, jit_o)
+        self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, y, 2.0)))
+
+    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
+    @skipIfRocm
+    def test_broadcasting_multiple_output_shape(self):
+        def t(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor):
+            o = x + 12
+            o1 = o + y
+            o2 = o + z
+            oo = o1.sum() + o2.sum()
+            return oo
+        t_jit = torch.jit.script(t)
+        x = torch.randn(32, 32, dtype=torch.float, device="cuda")
+        y = torch.randn(2, 32, 32, dtype=torch.float, device="cuda")
+        z = torch.randn(4, 32, 32, dtype=torch.float, device="cuda")
+        jit_o = t_jit(x, y, z)
+        jit_o = t_jit(x, y, z)
+        o = t(x, y, z)
+        self.assertEqual(o, jit_o)
+        # Currently cannot fuse this
+        self.assertFalse(self._has_cuda_fusion_group(t_jit.graph_for(x, y, z)))
 
 if __name__ == '__main__':
     run_tests()
4 changes: 4 additions & 0 deletions tools/build_variables.bzl
@@ -239,6 +239,9 @@ libtorch_cuda_sources = [
"torch/csrc/jit/codegen/cuda/ir_iostream.cpp",
"torch/csrc/jit/codegen/cuda/iter_visitor.cpp",
"torch/csrc/jit/codegen/cuda/kernel.cpp",
"torch/csrc/jit/codegen/cuda/kernel_cache.cpp",
"torch/csrc/jit/codegen/cuda/lower_loops.cpp",
"torch/csrc/jit/codegen/cuda/lower_utils.cpp",
"torch/csrc/jit/codegen/cuda/lower2device.cpp",
"torch/csrc/jit/codegen/cuda/manager.cpp",
"torch/csrc/jit/codegen/cuda/mutator.cpp",
@@ -249,6 +252,7 @@ libtorch_cuda_sources = [
"torch/csrc/jit/codegen/cuda/tensor_view.cpp",
"torch/csrc/jit/codegen/cuda/transform_iter.cpp",
"torch/csrc/jit/codegen/cuda/transform_replay.cpp",
"torch/csrc/jit/codegen/cuda/transform_rfactor.cpp",
"torch/csrc/jit/codegen/cuda/type.cpp",
"torch/csrc/jit/codegen/cuda/utils.cpp",
"torch/csrc/jit/codegen/cuda/register_interface.cpp",
87 changes: 87 additions & 0 deletions torch/csrc/jit/codegen/cuda/arith.cpp
@@ -57,6 +57,36 @@ TORCH_CUDA_API Val* promoteNew(Val* v1, Val* v2) {
return newValLike(v1, out_dtype);
}

Val* newConstScalar(DataType dtype, long int val) {
switch (dtype) {
case (DataType::Int):
return new Int((int)val);
default:
break;
}
TORCH_CHECK(
false,
"Could not generate a new Scalar with data type ",
dtype,
"and constant value: ",
val);
}

Val* newConstScalar(DataType dtype, double val) {
switch (dtype) {
case (DataType::Float):
return new Float(val);
default:
break;
}
TORCH_CHECK(
false,
"Could not generate a new Scalar with data type ",
dtype,
"and constant value: ",
val);
}

TORCH_CUDA_API Val* castOp(DataType dtype, Val* v1) {
if (v1->getDataType().value() == dtype)
return v1;
@@ -75,12 +105,16 @@ TORCH_CUDA_API Val* castOp(DataType dtype, Val* v1) {
return out;
}

// UNARY OPERATIONS

TORCH_CUDA_API Val* unaryOp(UnaryOpType type, Val* v1) {
Val* out = newValLike(v1);
Statement* expr = new UnaryOp(type, out, v1);
return out;
}

// BINARY OPERATIONS

TORCH_CUDA_API Val* binaryOp(BinaryOpType type, Val* v1, Val* v2) {
Val* out = promoteNew(v1, v2);
if (type >= BinaryOpType::Mod) {
@@ -123,6 +157,59 @@ TORCH_CUDA_API Val* andOp(Val* v1, Val* v2) {
return binaryOp(BinaryOpType::And, v1, v2);
}

// REDUCTION OPERATIONS

Val* reductionOp(
BinaryOpType reduction_op_type,
std::vector<int> axes,
Val* init,
Val* v1) {
TORCH_CHECK(
v1->getValType().value() == ValType::TensorView,
"Cannot reduce on values that are not TensorViews, but recieved type ",
v1->getValType().value());

TORCH_CHECK(
init->isConstScalar(),
"Cannot create a reduction operation where the initial value is not a const scalar.");

TensorView* tv = static_cast<TensorView*>(v1);

TORCH_CHECK(
tv->getRootDomain() == tv->domain(),
"Reducing a tensor once it's gone under transformations is not permitted at this time. Please set reductions before calling split/merge/reorder/computeAt.");

std::vector<unsigned int> uint_axes;
for (int axis : axes) {
if (axis < 0)
axis += int(tv->nDims());

TORCH_CHECK(
axis >= 0 && axis < tv->nDims(),
"Reduction on invalid axis, recieved: ",
axis,
" however tensor view only has ",
tv->nDims(),
" dims.");

uint_axes.push_back((unsigned int)axis);
}

Val* out = tv->newForReduction(uint_axes);
if (init->getDataType().value() != v1->getDataType().value())
init = castOp(v1->getDataType().value(), init);
new ReductionOp(reduction_op_type, init, out, v1);
return out;
}

TORCH_CUDA_API Val* sum(Val* v1, std::vector<int> axes) {
return reductionOp(
BinaryOpType::Add,
axes,
newConstScalar(v1->getDataType().value(), 0.0),
v1);
}

} // namespace fuser
} // namespace jit
} // namespace torch
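sum() above shows the intended pattern: a reduction is a binary op paired with that op's identity as the init value. Purely as an illustration (no such op is added in this PR, and BinaryOpType::Mul is assumed to be handled like Add), a product reduction would be built the same way:

    // Hypothetical prod(), not part of this PR; it reuses reductionOp() exactly
    // as sum() does, swapping in multiply and its identity element.
    TORCH_CUDA_API Val* prod(Val* v1, std::vector<int> axes) {
      return reductionOp(
          BinaryOpType::Mul,                               // combine by multiplying
          axes,
          newConstScalar(v1->getDataType().value(), 1.0),  // 1 is Mul's identity
          v1);
    }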
9 changes: 8 additions & 1 deletion torch/csrc/jit/codegen/cuda/arith.h
@@ -32,6 +32,7 @@ TORCH_CUDA_API Val* unaryOp(UnaryOpType type, Val* v1);
// Mod, CeilDiv, and LT are considered Int only output operations for now.
TORCH_CUDA_API Val* binaryOp(BinaryOpType type, Val* v1, Val* v2);

// Binary operations
TORCH_CUDA_API Val* add(Val* v1, Val* v2);
TORCH_CUDA_API Val* sub(Val* v1, Val* v2);
TORCH_CUDA_API Val* mul(Val* v1, Val* v2);
@@ -40,7 +41,13 @@ TORCH_CUDA_API Val* mod(Val* v1, Val* v2);
TORCH_CUDA_API Val* lt(Val* v1, Val* v2);
TORCH_CUDA_API Val* ceilDiv(Val* v1, Val* v2);
TORCH_CUDA_API Val* andOp(Val* v1, Val* v2);

TORCH_CUDA_API Val* reductionOp(
BinaryOpType reduction_op_type,
std::vector<int> axes,
Val* init,
Val* v1);
// REDUCTION OPERATIONS
TORCH_CUDA_API Val* sum(Val* v1, std::vector<int> reduction_axes);
} // namespace fuser
} // namespace jit
} // namespace torch
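Note the ordering contract these entry points inherit from reductionOp(): the root-domain check means a reduction must be declared before any split/merge/reorder/computeAt. A minimal call-side sketch (tv is assumed to be a 2D TensorView* fusion input; split() per this branch's TensorView API):

    // Declare the reduction while the tensor is still on its root domain,
    // then schedule the result.
    TensorView* out = static_cast<TensorView*>(sum(tv, {1}));
    out->split(0, 128);  // scheduling after the reduction is fine
    // Calling sum() on a tensor that was already split/merged/reordered would
    // trip the root-domain TORCH_CHECK inside reductionOp().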
11 changes: 0 additions & 11 deletions torch/csrc/jit/codegen/cuda/data_struct_str.h

This file was deleted.

9 changes: 8 additions & 1 deletion torch/csrc/jit/codegen/cuda/dispatch.cpp
@@ -1,6 +1,5 @@
 #include <torch/csrc/jit/codegen/cuda/fusion.h>
 #include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/tensor.h>
 #include <torch/csrc/jit/codegen/cuda/type.h>
 
 #include <torch/csrc/jit/codegen/cuda/dispatch.h>
@@ -93,6 +92,9 @@ void Expr::dispatch(T handler, Expr* expr) {
case ExprType::BinaryOp:
ptr(handler)->handle(static_cast<BinaryOp*>(expr));
return;
case ExprType::ReductionOp:
ptr(handler)->handle(static_cast<ReductionOp*>(expr));
return;
case ExprType::ForLoop:
ptr(handler)->handle(static_cast<ForLoop*>(expr));
return;
@@ -170,6 +172,9 @@ void Expr::constDispatch(T handler, const Expr* const expr) {
case ExprType::BinaryOp:
ptr(handler)->handle(static_cast<const BinaryOp* const>(expr));
return;
case ExprType::ReductionOp:
ptr(handler)->handle(static_cast<const ReductionOp* const>(expr));
return;
case ExprType::ForLoop:
ptr(handler)->handle(static_cast<const ForLoop* const>(expr));
return;
@@ -246,6 +251,8 @@ Statement* Expr::mutatorDispatch(T mutator, Expr* expr) {
return ptr(mutator)->mutate(static_cast<UnaryOp*>(expr));
case ExprType::BinaryOp:
return ptr(mutator)->mutate(static_cast<BinaryOp*>(expr));
case ExprType::ReductionOp:
return ptr(mutator)->mutate(static_cast<ReductionOp*>(expr));
case ExprType::ForLoop:
return ptr(mutator)->mutate(static_cast<ForLoop*>(expr));
case ExprType::IfThenElse:
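Each added case follows the file's one pattern: switch on ExprType, static_cast, and forward to handle()/mutate(), so a pass picks up ReductionOp by overriding a single hook. A hedged sketch of a consumer, assuming the OptOutConstDispatch defaults declared in dispatch.h:

    // Counts ReductionOps seen during const dispatch; every other node type
    // falls through to OptOutConstDispatch's empty default handlers.
    struct ReductionCounter : public OptOutConstDispatch {
      int count = 0;
      void handle(const ReductionOp* const) override {
        ++count;
      }
    };
    // Hypothetical driver: for each Expr* e in a fusion's exprs,
    //   Expr::constDispatch(&counter, e);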
13 changes: 13 additions & 0 deletions torch/csrc/jit/codegen/cuda/dispatch.h
@@ -70,6 +70,7 @@ struct Merge;
struct Reorder;
struct UnaryOp;
struct BinaryOp;
struct ReductionOp;
struct ForLoop;
struct IfThenElse;
struct Allocate;
@@ -108,6 +109,7 @@ struct TORCH_CUDA_API OptOutConstDispatch {
virtual void handle(const Reorder* const) {}
virtual void handle(const UnaryOp* const) {}
virtual void handle(const BinaryOp* const) {}
virtual void handle(const ReductionOp* const) {}
virtual void handle(const ForLoop* const) {}
virtual void handle(const IfThenElse* const) {}
virtual void handle(const Allocate* const) {}
@@ -143,6 +145,7 @@ struct TORCH_CUDA_API OptOutDispatch {
virtual void handle(Reorder*) {}
virtual void handle(UnaryOp*) {}
virtual void handle(BinaryOp*) {}
virtual void handle(ReductionOp*) {}
virtual void handle(ForLoop*) {}
virtual void handle(IfThenElse*) {}
virtual void handle(Allocate*) {}
@@ -202,6 +205,9 @@ struct TORCH_CUDA_API OptInConstDispatch {
virtual void handle(const BinaryOp* const) {
TORCH_INTERNAL_ASSERT(false, "Handle not overridden for BinaryOp.");
}
virtual void handle(const ReductionOp* const) {
TORCH_INTERNAL_ASSERT(false, "Handle not overridden for ReductionOp.");
}
virtual void handle(const ForLoop* const) {
AT_ERROR("Handle not overriden for ForLoop.");
}
@@ -267,6 +273,9 @@ struct TORCH_CUDA_API OptInDispatch {
virtual void handle(BinaryOp*) {
TORCH_INTERNAL_ASSERT(false, "Handle not overridden for BinaryOp.");
}
virtual void handle(ReductionOp*) {
TORCH_INTERNAL_ASSERT(false, "Handle not overridden for ReductionOp.");
}
virtual void handle(ForLoop*) {
TORCH_INTERNAL_ASSERT(false, "Handle not overridden for ForLoop.");
}
@@ -332,6 +341,7 @@ struct TORCH_CUDA_API OptOutMutator {
virtual Statement* mutate(Reorder*);
virtual Statement* mutate(UnaryOp*);
virtual Statement* mutate(BinaryOp*);
virtual Statement* mutate(ReductionOp*);
virtual Statement* mutate(ForLoop*);
virtual Statement* mutate(IfThenElse*);
virtual Statement* mutate(Allocate*);
@@ -401,6 +411,9 @@ struct TORCH_CUDA_API OptInMutator {
virtual Statement* mutate(BinaryOp*) {
TORCH_INTERNAL_ASSERT(false, "Mutate not overridden for BinaryOp.");
}
virtual Statement* mutate(ReductionOp*) {
TORCH_INTERNAL_ASSERT(false, "Mutate not overridden for ReductionOp.");
}
virtual Statement* mutate(ForLoop*) {
TORCH_INTERNAL_ASSERT(false, "Mutate not overridden for ForLoop.");
}
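The split between the OptOut* and OptIn* hierarchies is what makes adding a node type like ReductionOp tractable: OptOut* bases default to no-ops so existing passes keep compiling, while OptIn* bases assert on anything unhandled, flagging every pass that must now deal with reductions. A sketch of the opt-in side (hypothetical pass name; the real reduction lowering lands in lower_loops.cpp/lower2device.cpp):

    // Deriving from OptInDispatch turns "forgot to handle ReductionOp" into a
    // TORCH_INTERNAL_ASSERT at dispatch time instead of a silent skip.
    struct ReductionLoweringPass : public OptInDispatch {
      using OptInDispatch::handle;  // keep the asserting defaults for other nodes
      void handle(ReductionOp* rop) override {
        // ...emit the reduction loop nest for rop here (sketch only)...
      }
    };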