87 changes: 42 additions & 45 deletions test/cpp/jit/test_gpu.cpp
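Every hunk in this file applies the same call-site change: the fuser arithmetic helpers (add, mul, sub, sum, reductionOp, castOp) now return TensorView* directly, so the surrounding static_cast<TensorView*> wrappers are dropped, while ops passed as function pointers (add_alpha, sub_alpha, where, lerp, addcmul) gain an explicit cast to pick the Val* overload. A minimal before/after sketch, using only helpers that appear in the diff; the name jit_func is illustrative:

// Before: arithmetic ops returned Val*, so results were cast back to TensorView*.
TensorView* tv1 = static_cast<TensorView*>(add(tv0, new Float(2.0)));

// After: the ops return TensorView* directly, so no cast is needed at the call site.
TensorView* tv1 = add(tv0, new Float(2.0));

// When an op is passed as a function pointer, an explicit cast selects the
// Val* overload (presumably the ops are now overloaded), as in the hunks below.
auto jit_func = static_cast<Val* (*)(Val*, Val*, Val*)>(&add_alpha);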
@@ -853,8 +853,8 @@ void testGPU_FusionCodeGen() {
TensorView* tv0 = makeDummyTensor(3);

new BinaryOp(BinaryOpType::Add, tv0, new Float(0.0), new Float(1.0));
TensorView* tv1 = static_cast<TensorView*>(add(tv0, new Float(2.0)));
TensorView* tv2 = static_cast<TensorView*>(add(tv1, new Float(3.0)));
TensorView* tv1 = add(tv0, new Float(2.0));
TensorView* tv2 = add(tv1, new Float(3.0));
fusion.addOutput(tv2);

//[I0, I1, I2]
@@ -895,8 +895,8 @@ void testGPU_FusionCodeGen2() {

TensorView* tv0 = makeDummyTensor(3);
TensorView* tv1 = makeDummyTensor(3);
TensorView* tv2 = static_cast<TensorView*>(add(tv1, new Float(2.0)));
TensorView* tv3 = static_cast<TensorView*>(add(tv0, tv2));
TensorView* tv2 = add(tv1, new Float(2.0));
TensorView* tv3 = add(tv0, tv2);

fusion.addInput(tv0);
fusion.addInput(tv1);
@@ -953,8 +953,8 @@ void testGPU_FusionSimplePWise() {

// Do math with it, it returns a `Val*` but can be static_casted back to
// TensorView
TensorView* tv2 = static_cast<TensorView*>(add(tv1, new Float(2.0)));
TensorView* tv3 = static_cast<TensorView*>(add(tv0, tv2));
TensorView* tv2 = add(tv1, new Float(2.0));
TensorView* tv3 = add(tv0, tv2);

// Register your outputs
fusion.addOutput(tv3);
@@ -1012,8 +1012,8 @@ void testGPU_FusionExecKernel() {

// Do math with it, it returns a `Val*` but can be static_casted back to
// TensorView
TensorView* tv2 = static_cast<TensorView*>(add(tv1, new Float(2.0)));
TensorView* tv3 = static_cast<TensorView*>(add(tv0, tv2));
TensorView* tv2 = add(tv1, new Float(2.0));
TensorView* tv3 = add(tv0, tv2);

// Register your outputs
fusion.addOutput(tv3);
@@ -1075,13 +1075,13 @@ void testGPU_FusionAdvancedComputeAt() {
TensorView* tv0 = makeDummyTensor(2);
fusion.addInput(tv0);

TensorView* tv1 = static_cast<TensorView*>(mul(tv0, new Float(-1.0)));
TensorView* tv2 = static_cast<TensorView*>(add(tv0, new Float(3.0)));
TensorView* tv3 = static_cast<TensorView*>(mul(tv0, new Float(2.0)));
TensorView* tv4 = static_cast<TensorView*>(add(tv2, tv1));
TensorView* tv1 = mul(tv0, new Float(-1.0));
TensorView* tv2 = add(tv0, new Float(3.0));
TensorView* tv3 = mul(tv0, new Float(2.0));
TensorView* tv4 = add(tv2, tv1);

TensorView* tv5 = static_cast<TensorView*>(add(tv4, tv3));
TensorView* tv6 = static_cast<TensorView*>(add(tv0, tv3));
TensorView* tv5 = add(tv4, tv3);
TensorView* tv6 = add(tv0, tv3);

fusion.addOutput(tv5);
fusion.addOutput(tv6);
@@ -1168,13 +1168,13 @@ void testGPU_FusionAdvancedComputeAt() {
TensorView* tv0 = makeDummyTensor(2);
fusion.addInput(tv0);

TensorView* tv1 = static_cast<TensorView*>(mul(tv0, new Float(-1.0)));
TensorView* tv2 = static_cast<TensorView*>(add(tv0, new Float(3.0)));
TensorView* tv3 = static_cast<TensorView*>(mul(tv0, new Float(2.0)));
TensorView* tv4 = static_cast<TensorView*>(add(tv2, tv1));
TensorView* tv1 = mul(tv0, new Float(-1.0));
TensorView* tv2 = add(tv0, new Float(3.0));
TensorView* tv3 = mul(tv0, new Float(2.0));
TensorView* tv4 = add(tv2, tv1);

TensorView* tv5 = static_cast<TensorView*>(add(tv4, tv3));
TensorView* tv6 = static_cast<TensorView*>(add(tv5, tv3));
TensorView* tv5 = add(tv4, tv3);
TensorView* tv6 = add(tv5, tv3);

fusion.addOutput(tv5);
fusion.addOutput(tv6);
@@ -1252,8 +1252,8 @@ void testGPU_FusionAdvancedComputeAt() {
TensorView* tv1 = makeDummyTensor(4);
fusion.addInput(tv1);

TensorView* tv2 = static_cast<TensorView*>(mul(tv1, new Float(.979361)));
TensorView* tv3 = static_cast<TensorView*>(mul(tv2, tv0));
TensorView* tv2 = mul(tv1, new Float(.979361));
TensorView* tv3 = mul(tv2, tv0);

fusion.addOutput(tv3);

@@ -1325,9 +1325,9 @@ void testGPU_FusionAdvancedComputeAt() {
TensorView* tv3 = makeDummyTensor(4);
fusion.addInput(tv3);

TensorView* tv4 = static_cast<TensorView*>(sub(tv2, tv3));
TensorView* tv5 = static_cast<TensorView*>(add(tv1, tv4));
TensorView* tv6 = static_cast<TensorView*>(sub(tv5, tv0));
TensorView* tv4 = sub(tv2, tv3);
TensorView* tv5 = add(tv1, tv4);
TensorView* tv6 = sub(tv5, tv0);

fusion.addOutput(tv6);

@@ -1406,9 +1406,9 @@ void testGPU_FusionScalarInputs() {
Val* f4 = mul(f0, f1);
Val* f5 = sub(f2, f3);

TensorView* tv2 = static_cast<TensorView*>(sub(tv1, f4));
TensorView* tv3 = static_cast<TensorView*>(add(tv0, f5));
TensorView* tv4 = static_cast<TensorView*>(mul(tv3, tv2));
TensorView* tv2 = sub(tv1, f4);
TensorView* tv3 = add(tv0, f5);
TensorView* tv4 = mul(tv3, tv2);

fusion.addOutput(tv4);

@@ -1499,8 +1499,8 @@ void testGPU_FusionLoopUnroll() {

// Do math with it, it returns a `Val*` but can be static_casted back to
// TensorView
TensorView* tv2 = static_cast<TensorView*>(add(tv1, new Float(2.0)));
TensorView* tv3 = static_cast<TensorView*>(add(tv0, tv2));
TensorView* tv2 = add(tv1, new Float(2.0));
TensorView* tv3 = add(tv0, tv2);

// Register your outputs
fusion.addOutput(tv3);
@@ -1560,7 +1560,7 @@ void testGPU_FusionForLoop() {

auto ID0 = new IterDomain(new Int(0), new Int(8));

TensorView* TV2 = static_cast<TensorView*>(add(TV0, TV1));
TensorView* TV2 = add(TV0, TV1);
BinaryOp* op = static_cast<BinaryOp*>(TV2->getOrigin());
fusion.addOutput(TV2);

@@ -1917,7 +1917,7 @@ void testGPU_FusionBinaryOps() {
return at::add(
vals[0].toTensor(), vals[1].toTensor(), vals[2].toScalar());
},
/*JIT Func */ add_alpha,
/*JIT Func */ static_cast<Val* (*)(Val*, Val*, Val*)>(&add_alpha),
/*Output */ std::make_pair(ValType::TensorView, DataType::Float),
/*Inputs Tuple*/
std::make_tuple(
@@ -1933,7 +1933,7 @@ void testGPU_FusionBinaryOps() {
return at::sub(
vals[0].toTensor(), vals[1].toTensor(), vals[2].toScalar());
},
/*JIT Func */ sub_alpha,
/*JIT Func */ static_cast<Val* (*)(Val*, Val*, Val*)>(&sub_alpha),
/*Output */ std::make_pair(ValType::TensorView, DataType::Float),
/*Inputs Tuple*/
std::make_tuple(
@@ -1982,7 +1982,7 @@ void testGPU_FusionTernaryOps() {
return at::where(
vals[0].toTensor(), vals[1].toTensor(), vals[2].toTensor());
},
/*JIT Func */ where,
/*JIT Func */ static_cast<Val* (*)(Val*, Val*, Val*)>(&where),
/*Output */ std::make_pair(ValType::TensorView, DataType::Float),
/*Inputs Tuple*/
std::make_tuple(
@@ -2001,7 +2001,7 @@ void testGPU_FusionCompoundOps() {
return at::lerp(
vals[0].toTensor(), vals[1].toTensor(), vals[2].toTensor());
},
/*JIT Func */ lerp,
/*JIT Func */ static_cast<Val* (*)(Val*, Val*, Val*)>(&lerp),
/*Output */ std::make_pair(ValType::TensorView, DataType::Float),
/*Inputs Tuple*/
std::make_tuple(
@@ -2020,7 +2020,7 @@ void testGPU_FusionCompoundOps() {
vals[2].toTensor(),
vals[3].toScalar());
},
/*JIT Func */ addcmul,
/*JIT Func */ static_cast<Val* (*)(Val*, Val*, Val*, Val*)>(&addcmul),
/*Output */ std::make_pair(ValType::TensorView, DataType::Float),
/*Inputs Tuple*/
std::make_tuple(
Expand All @@ -2036,8 +2036,8 @@ void testGPU_FusionCastOps() {

TensorView* tv0 = makeDummyTensor(2, DataType::Half);

Val* intrm1 = castOp(DataType::Float, tv0);
TensorView* out = static_cast<TensorView*>(castOp(DataType::Half, intrm1));
TensorView* intrm1 = castOp(DataType::Float, tv0);
TensorView* out = castOp(DataType::Half, intrm1);

fusion.addInput(tv0);
fusion.addOutput(out);
@@ -2096,7 +2096,7 @@ void testGPU_FusionRFactorReplay() {

// Do math with it, it returns a `Val*` but can be static_casted back to
// TensorView
TensorView* tv1 = static_cast<TensorView*>(sum(tv0, {1}));
TensorView* tv1 = sum(tv0, {1});
// tv1[I0, R1]
tv1->split(0, 32);
// tv1[I0o, I0i{32}, R1]
@@ -2183,8 +2183,7 @@ void testGPU_FusionReduction() {
fusion.addInput(tv0);

// tv1[I0, R1] = tv0[I0, I1]
TensorView* tv1 = static_cast<TensorView*>(
reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0));
TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
fusion.addOutput(tv1);

TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
@@ -2252,8 +2251,7 @@ void testGPU_FusionReduction2() {
fusion.addInput(tv0);

// tv1[I0, R1] = tv0[I0, I1]
TensorView* tv1 = static_cast<TensorView*>(
reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0));
TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);

fusion.addOutput(tv1);

@@ -2331,8 +2329,7 @@ void testGPU_FusionReduction2() {
fusion.addInput(tv0);

// tv1[I0, R1] = tv0[I0, I1]
TensorView* tv1 = static_cast<TensorView*>(
reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0));
TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);

fusion.addOutput(tv1);
