diff --git a/test/cpp/jit/test_gpu_fusion.cpp b/test/cpp/jit/test_gpu_fusion.cpp index 329e29f70dcac..fb90af1641f80 100644 --- a/test/cpp/jit/test_gpu_fusion.cpp +++ b/test/cpp/jit/test_gpu_fusion.cpp @@ -57,7 +57,6 @@ void testGPU_FusionSimpleArith() { Float* f1 = new Float(1.f); Float* f2 = new Float{2.f}; - Float* f3 = new Float(); //Disrupt the fusion to make sure guard works well { @@ -70,7 +69,7 @@ void testGPU_FusionSimpleArith() { ss2 << fusion2; } - new BinaryOp(BinaryOpType::Add, f3, f1, f2); + Val* f3 = add(f1, f2); ss1 << fusion; TORCH_CHECK(ss1.str().compare(ss2.str()) == 0, @@ -132,8 +131,8 @@ void testGPU_FusionRegister() { FusionGuard fg(&fusion); Float* v1 = new Float{1.f}; Float* v2 = new Float{2.f}; - Val* v3 = binaryOp(BinaryOpType::Add, v1, v2); - Val* v4 = binaryOp(BinaryOpType::Add, v1, v2); + Val* v3 = add(v1, v2); + Val* v4 = add(v1, v2); TORCH_CHECK(v1->name() + 1 == v2->name()); TORCH_CHECK(v2->name() + 1 == v3->name()); TORCH_CHECK(v3->name() + 1 == v4->name()); @@ -554,14 +553,18 @@ void testGPU_FusionParser() { std::stringstream ref; ref + << "__device__ int ceilDiv(const int a, const int b) {\n" + << " return (a + b - 1) / b;\n" + << "}\n" + << "\n" << "__global__ void kernel(Tensor T0, Tensor T1, Tensor T3){\n" << " float T2[1];\n" - << " if( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T1.size[2] ) / T1.size[1] ) < T1.size[0] ) && ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T1.size[2] ) % T1.size[1] ) < T1.size[1] ) && ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) % T1.size[2] ) < T1.size[2] ) ) {\n" + << " if( ( ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T1.size[2] ) / T1.size[1] ) < T1.size[0] ) && ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T1.size[2] ) % T1.size[1] ) < T1.size[1] ) ) && ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) % T1.size[2] ) < T1.size[2] ) ) ) {\n" << " T2[0]\n" << " = T0[( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T0.size[2] ) / T0.size[1] ) * T0.stride[0] + ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T0.size[2] ) % T0.size[1] ) * T0.stride[1] + ( ( ( blockIdx.x * 128 ) + threadIdx.x ) % T0.size[2] ) * T0.stride[2]]\n" << " * T1[( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T1.size[2] ) / T1.size[1] ) * T1.stride[0] + ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T1.size[2] ) % T1.size[1] ) * T1.stride[1] + ( ( ( blockIdx.x * 128 ) + threadIdx.x ) % T1.size[2] ) * T1.stride[2]];\n" << " }\n" - << " if( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T0.size[2] ) / T0.size[1] ) < T0.size[0] ) && ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T0.size[2] ) % T0.size[1] ) < T0.size[1] ) && ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) % T0.size[2] ) < T0.size[2] ) ) {\n" + << " if( ( ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T0.size[2] ) / T0.size[1] ) < T0.size[0] ) && ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T0.size[2] ) % T0.size[1] ) < T0.size[1] ) ) && ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) % T0.size[2] ) < T0.size[2] ) ) ) {\n" << " T3[( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T0.size[2] ) / T0.size[1] ) * T3.stride[0] + ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T0.size[2] ) % T0.size[1] ) * T3.stride[1] + ( ( ( blockIdx.x * 128 ) + threadIdx.x ) % T0.size[2] ) * T3.stride[2]]\n" << " = T2[0]\n" << " * T0[( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T0.size[2] ) / T0.size[1] ) * T0.stride[0] + ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T0.size[2] ) % T0.size[1] ) * T0.stride[1] + ( ( ( blockIdx.x * 128 ) + threadIdx.x ) % T0.size[2] ) * T0.stride[2]];\n" @@ -658,7 +661,7 @@ void 
testGPU_FusionCodeGen() { Fusion fusion; FusionGuard fg(&fusion); - TensorView* tv0 = makeDummyTensor(4); + TensorView* tv0 = makeDummyTensor(3); new BinaryOp(BinaryOpType::Add, tv0, new Float(0.0), new Float(1.0)); TensorView* tv1 = static_cast(add(tv0, new Float(2.0))); @@ -675,45 +678,35 @@ void testGPU_FusionCodeGen() { //[I0i{4}*I1, I0o, I2i{2}, I2o] fusion.addOutput(tv2); - tv0->computeAt(tv2, 1); + tv0->computeAt(tv2, -1); std::stringstream ref; ref + << "__device__ int ceilDiv(const int a, const int b) {\n" + << " return (a + b - 1) / b;\n" + << "}\n" + << "\n" << "__global__ void kernel(Tensor T2){\n" - << " float T0[( ( ( 1 * ( ceilDiv(T2.size[0], 4) ) ) * T2.size[2] ) * T2.size[3] )];\n" - << " for( size_t i27 = 0; i27 < ( 4 * T2.size[1] ); ++i27 ) {\n" - << " for( size_t i29 = 0; i29 < ( ceilDiv(T2.size[0], 4) ); ++i29 ) {\n" - << " for( size_t i31 = 0; i31 < T2.size[2]; ++i31 ) {\n" - << " for( size_t i33 = 0; i33 < T2.size[3]; ++i33 ) {\n" - << " if( ( ( ( i29 * 4 ) + ( i27 / T2.size[1] ) ) < T2.size[0] ) && ( ( i27 % T2.size[1] ) < T2.size[1] ) ) {\n" - << " T0[i29 * T2.size[2] * T2.size[3] + i31 * T2.size[3] + i33]\n" + << " float T0[1];\n" + << " for( size_t i29 = 0; i29 < ( 4 * T2.size[1] ); ++i29 ) {\n" + << " for( size_t i31 = 0; i31 < ( ceilDiv(T2.size[0], 4) ); ++i31 ) {\n" + << " for( size_t i33 = 0; i33 < 2; ++i33 ) {\n" + << " for( size_t i35 = 0; i35 < ( ceilDiv(T2.size[2], 2) ); ++i35 ) {\n" + << " if( ( ( ( ( ( i31 * 4 ) + ( i29 / T2.size[1] ) ) < T2.size[0] ) && ( ( i29 % T2.size[1] ) < T2.size[1] ) ) && ( ( ( i35 * 2 ) + i33 ) < T2.size[2] ) ) ) {\n" + << " T0[0]\n" << " = float(0)\n" << " + float(1);\n" << " }\n" - << " }\n" - << " }\n" - << " }\n" - << " float T1[( ( ( 1 * ( ceilDiv(T2.size[0], 4) ) ) * T2.size[2] ) * T2.size[3] )];\n" - << " for( size_t i55 = 0; i55 < ( ceilDiv(T2.size[0], 4) ); ++i55 ) {\n" - << " for( size_t i57 = 0; i57 < T2.size[2]; ++i57 ) {\n" - << " for( size_t i59 = 0; i59 < T2.size[3]; ++i59 ) {\n" - << " if( ( ( ( i55 * 4 ) + ( i27 / T2.size[1] ) ) < T2.size[0] ) && ( ( i27 % T2.size[1] ) < T2.size[1] ) ) {\n" - << " T1[i55 * T2.size[2] * T2.size[3] + i57 * T2.size[3] + i59]\n" - << " = T0[i55 * T2.size[2] * T2.size[3] + i57 * T2.size[3] + i59]\n" + << " float T1[1];\n" + << " if( ( ( ( ( ( i31 * 4 ) + ( i29 / T2.size[1] ) ) < T2.size[0] ) && ( ( i29 % T2.size[1] ) < T2.size[1] ) ) && ( ( ( i35 * 2 ) + i33 ) < T2.size[2] ) ) ) {\n" + << " T1[0]\n" + << " = T0[0]\n" << " + float(2);\n" << " }\n" - << " }\n" - << " }\n" - << " }\n" - << " for( size_t i85 = 0; i85 < ( ceilDiv(T2.size[0], 4) ); ++i85 ) {\n" - << " for( size_t i87 = 0; i87 < ( ceilDiv(T2.size[3], 2) ); ++i87 ) {\n" - << " for( size_t i89 = 0; i89 < T2.size[2]; ++i89 ) {\n" - << " for( size_t i91 = 0; i91 < 2; ++i91 ) {\n" - << " if( ( ( ( i85 * 4 ) + ( i27 / T2.size[1] ) ) < T2.size[0] ) && ( ( i27 % T2.size[1] ) < T2.size[1] ) && ( ( ( i87 * 2 ) + i91 ) < T2.size[3] ) ) {\n" - << " T2[( ( i85 * 4 ) + ( i27 / T2.size[1] ) ) * T2.stride[0] + ( i27 % T2.size[1] ) * T2.stride[1] + i89 * T2.stride[2] + ( ( i87 * 2 ) + i91 ) * T2.stride[3]]\n" - << " = T1[i85 * ( ceilDiv(T2.size[3], 2) ) * T2.size[2] * 2 + i87 * T2.size[2] * 2 + i89 * 2 + i91]\n" - << " + float(3);\n" - << " }\n" + << " if( ( ( ( ( ( i31 * 4 ) + ( i29 / T2.size[1] ) ) < T2.size[0] ) && ( ( i29 % T2.size[1] ) < T2.size[1] ) ) && ( ( ( i35 * 2 ) + i33 ) < T2.size[2] ) ) ) {\n" + << " T2[( ( i31 * 4 ) + ( i29 / T2.size[1] ) ) * T2.stride[0] + ( i29 % T2.size[1] ) * T2.stride[1] + ( ( i35 * 2 ) + i33 ) * 
T2.stride[2]]\n" + << " = T1[0]\n" + << " + float(3);\n" << " }\n" << " }\n" << " }\n" @@ -735,6 +728,28 @@ void testGPU_FusionCodeGen() { TORCH_CHECK(false); } + torch::jit::fuser::cuda::CudaKernel prog; + prog.device_ = 0; + // These can be set to anything as there are no bindings! + // All CTAS and threads execute the same thing. + prog.grid(4); + prog.block(32); + + auto options = + at::TensorOptions() + .dtype(at::kFloat) + .device(at::kCUDA, 0); + + at::Tensor output = at::empty({16,8,8}, options); + std::vector outputs{{output}}; + + torch::jit::fuser::cuda::compileKernel(fusion, prog); + torch::jit::fuser::cuda::runTestKernel(prog, {}, outputs); + + at::Tensor output_ref = at::zeros_like(output, options); + output_ref = output_ref + 0.0 + 1.0 + 2.0 + 3.0; + + TORCH_CHECK(output_ref.equal(output)); } void testGPU_FusionCodeGen2() { @@ -767,6 +782,10 @@ void testGPU_FusionCodeGen2() { std::stringstream ref; ref + << "__device__ int ceilDiv(const int a, const int b) {\n" + << " return (a + b - 1) / b;\n" + << "}\n" + << "\n" << "__global__ void kernel(Tensor T0, Tensor T1, Tensor T3){\n" << " float T2[1];\n" << " for( size_t i15 = 0; i15 < 4; ++i15 ) {\n" @@ -784,8 +803,10 @@ void testGPU_FusionCodeGen2() { << " }\n" << " }\n" << "}\n" - ; - std::stringstream cdg; + ; + + std::stringstream cdg; + CodeWrite cw(cdg); cw.traverse(&fusion); @@ -943,5 +964,68 @@ void testGPU_FusionExecKernel() { TORCH_CHECK(output.equal(check)); } +void testGPU_FusionForLoop() { + Fusion fusion; + FusionGuard fg(&fusion); + + const auto TV0 = new TensorView(new TensorDomain({new IterDomain(new Int(16))}), DataType::Float); + const auto TV1 = new TensorView(new TensorDomain({new IterDomain(new Int(16))}), DataType::Float); + + fusion.addInput(TV0); + fusion.addInput(TV1); + + auto ID0 = new IterDomain(new Int(8)); + + TensorView* TV2 = static_cast(add(TV0, TV1)); + BinaryOp* op = static_cast(TV2->getOrigin()); + fusion.addOutput(TV2); + + ForLoop* fl = new ForLoop(new Int(), ID0, {op}); + std::stringstream result; + std::stringstream ref; + result << fl; + ref << "for(size_t i3{0}; i3 < iS{8}; ++i3 ) {\nT2[ iS{16} ] = T0[ iS{16} ] + T1[ iS{16} ]\n}"; + + if(result.str().compare(ref.str()) == 0){ + std::stringstream err_msg; + err_msg << "ForLoop printing has changed or something has gone wrong. 
" + << result.str() << "\n does not match reference: " << ref.str() << std::endl; + TORCH_CHECK(false, err_msg.str()); + } + +} + +void testGPU_FusionConstCheck() { + Fusion fusion; + FusionGuard fg(&fusion); + + Val* cInt = new Int(2); + Val* sInt = new Int(); + Val* cFloat = new Float(2.0); + Val* sFloat = new Float(); + TORCH_CHECK(cInt->isConstScalar()); + TORCH_CHECK(cFloat->isConstScalar()); + TORCH_CHECK(!sInt->isConstScalar()); + TORCH_CHECK(!sFloat->isConstScalar()); +} + +void testGPU_FusionReductionInternals() { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv1 = makeDummyTensor(3); + TensorView* tv2 = static_cast(sum(tv1, {0, 1})); + std::cout << fusion << std::endl; + + fusion.addInput(tv1); + fusion.addOutput(tv2); + + CodeWrite cw(std::cout); + cw.traverse(&fusion); + + +} + + } // namespace jit } // namespace torch diff --git a/test/cpp/jit/tests.h b/test/cpp/jit/tests.h index 5ba6ff3a52f64..0685308d72a63 100644 --- a/test/cpp/jit/tests.h +++ b/test/cpp/jit/tests.h @@ -91,33 +91,36 @@ namespace jit { _(TorchbindIValueAPI) #if defined(USE_CUDA) && !defined(USE_ROCM) -#define TH_FORALL_TESTS_CUDA(_) \ - _(ArgumentSpec) \ - _(CompleteArgumentSpec) \ - _(Fusion) \ - _(GraphExecutor) \ - _(ModuleConversion) \ - _(Interp) \ - _(GPU_FusionDispatch) \ - _(GPU_FusionSimpleArith) \ - _(GPU_FusionSimpleTypePromote)\ - _(GPU_FusionCastOp) \ - _(GPU_FusionMutator) \ - _(GPU_FusionRegister) \ - _(GPU_FusionTopoSort) \ - _(GPU_FusionTensor) \ - _(GPU_FusionTensorContiguity) \ - _(GPU_FusionTVSplit) \ - _(GPU_FusionTVMerge) \ - _(GPU_FusionTVReorder) \ - _(GPU_FusionEquality) \ - _(GPU_FusionReplaceAll) \ - _(GPU_FusionParser) \ - _(GPU_FusionDependency) \ - _(GPU_FusionCodeGen) \ - _(GPU_FusionCodeGen2) \ - _(GPU_FusionSimplePWise) \ - _(GPU_FusionExecKernel) +#define TH_FORALL_TESTS_CUDA(_) \ + _(ArgumentSpec) \ + _(CompleteArgumentSpec) \ + _(Fusion) \ + _(GraphExecutor) \ + _(ModuleConversion) \ + _(Interp) \ + _(GPU_FusionDispatch) \ + _(GPU_FusionSimpleArith) \ + _(GPU_FusionSimpleTypePromote) \ + _(GPU_FusionCastOp) \ + _(GPU_FusionMutator) \ + _(GPU_FusionRegister) \ + _(GPU_FusionTopoSort) \ + _(GPU_FusionTensor) \ + _(GPU_FusionTensorContiguity) \ + _(GPU_FusionTVSplit) \ + _(GPU_FusionTVMerge) \ + _(GPU_FusionTVReorder) \ + _(GPU_FusionEquality) \ + _(GPU_FusionReplaceAll) \ + _(GPU_FusionParser) \ + _(GPU_FusionDependency) \ + _(GPU_FusionCodeGen) \ + _(GPU_FusionCodeGen2) \ + _(GPU_FusionSimplePWise) \ + _(GPU_FusionExecKernel) \ + _(GPU_FusionForLoop) \ + _(GPU_FusionReductionInternals) \ + _(GPU_FusionConstCheck) #else #define TH_FORALL_TESTS_CUDA(_) \ _(ArgumentSpec) \ diff --git a/torch/csrc/jit/codegen/cuda/arith.cpp b/torch/csrc/jit/codegen/cuda/arith.cpp index 38020a25dc250..dd179b88c92ed 100644 --- a/torch/csrc/jit/codegen/cuda/arith.cpp +++ b/torch/csrc/jit/codegen/cuda/arith.cpp @@ -2,6 +2,7 @@ #include #include #include +#include namespace torch { namespace jit { @@ -26,12 +27,11 @@ TORCH_CUDA_API Val* newValLike(const Val* const val, DataType dtype) { break; } - TORCH_CHECK( - false + TORCH_INTERNAL_ASSERT(false , "Could not generate a new value of type " - , val->getValType().value() - , " with data type " - , val->getDataType().value()); + , val->getValType().value() , " with data type " + , val->getDataType().value() + ); } TORCH_CUDA_API Val* newValLike(const Val* const val) { @@ -62,12 +62,11 @@ TORCH_CUDA_API Val* castOp(DataType dtype, Val* v1) { return v1; if (!is_cast_legal(v1->getDataType().value(), dtype)) { - TORCH_CHECK( - false - , 
"Illegal Cast value from DataType: " - , v1->getDataType().value() - , " to DataType: " - , dtype); + TORCH_CHECK(false + , "Illegal Cast value from DataType: " + , v1->getDataType().value() + , " to DataType: " + , dtype); } Val* out = newValLike(v1, dtype); @@ -75,13 +74,13 @@ TORCH_CUDA_API Val* castOp(DataType dtype, Val* v1) { return out; } -TORCH_CUDA_API Val* unaryOp(UnaryOpType type, Val* v1) { +Val* unaryOp(UnaryOpType type, Val* v1) { Val* out = newValLike(v1); Statement* expr = new UnaryOp(type, out, v1); return out; } -TORCH_CUDA_API Val* binaryOp(BinaryOpType type, Val* v1, Val* v2) { +Val* binaryOp(BinaryOpType type, Val* v1, Val* v2) { Val* out = promoteNew(v1, v2); if (type >= BinaryOpType::Mod) { if (out->getDataType().value() != DataType::Int) @@ -91,6 +90,7 @@ TORCH_CUDA_API Val* binaryOp(BinaryOpType type, Val* v1, Val* v2) { return out; } +// BINARY OPERATIONS TORCH_CUDA_API Val* add(Val* v1, Val* v2) { return binaryOp(BinaryOpType::Add, v1, v2); } @@ -119,6 +119,79 @@ TORCH_CUDA_API Val* ceilDiv(Val* v1, Val* v2) { return binaryOp(BinaryOpType::CeilDiv, v1, v2); } +TORCH_CUDA_API Val* andOp(Val* v1, Val* v2) { + return binaryOp(BinaryOpType::And, v1, v2); +} + + +// REDUCTION OPERATIONS + +Val* reductionOp(BinaryOpType reduction_op_type, std::vector axes, Val* init, Val* v1){ + TORCH_CHECK(v1->getValType().value() == ValType::TensorView, + "Cannot reduce on values that are not TensorViews, but recieved type ", v1->getValType().value()); + TensorView* tv = static_cast(v1); + + std::vector uint_axes; + for(int axis : axes){ + + if(axis < 0) + axis += int(tv->nDims()); + + TORCH_CHECK( + axis >= 0 && axis < tv->nDims() + , "Reduction on invalid axis, recieved: " + , axis + , " however tensor view only has " + , tv->nDims() + , " dims."); + + uint_axes.push_back((unsigned int)axis); + } + + Val* out = tv->newForReduction(uint_axes); + new ReductionOp(reduction_op_type, init, out, v1); + + return out; +} + +Val* newConstScalar(DataType dtype, long int val) { + switch (dtype) { + case (DataType::Int): + return new Int( (int) val); + default: + break; + } + TORCH_CHECK(false + , "Could not generate a new Scalar with data type " + , dtype + , "and constant value: " + , val); +} + +Val* newConstScalar(DataType dtype, double val) { + switch (dtype) { + case (DataType::Float): + return new Float(val); + default: + break; + } + TORCH_CHECK(false + , "Could not generate a new Scalar with data type " + , dtype + , "and constant value: " + , val); +} + +TORCH_CUDA_API Val* sum(Val* v1, std::vector axes){ + + return reductionOp( + BinaryOpType::Add + , axes + , newConstScalar(v1->getDataType().value(), 0.0) + , v1); +} + + } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/arith.h b/torch/csrc/jit/codegen/cuda/arith.h index 4e488077a4a54..9197cfeaaebfa 100644 --- a/torch/csrc/jit/codegen/cuda/arith.h +++ b/torch/csrc/jit/codegen/cuda/arith.h @@ -25,12 +25,13 @@ TORCH_CUDA_API Val* promoteNew(Val* v1, Val* v2); TORCH_CUDA_API Val* castOp(DataType dtype, Val* v1); // Perform unary op type and return the output -TORCH_CUDA_API Val* unaryOp(UnaryOpType type, Val* v1); +Val* unaryOp(UnaryOpType type, Val* v1); // Perform binary op type on v1 and v2 and return a type promoted output. // Mod, CeilDiv, and LT are considered Int only output operations for now. 
-TORCH_CUDA_API Val* binaryOp(BinaryOpType type, Val* v1, Val* v2); +Val* binaryOp(BinaryOpType type, Val* v1, Val* v2); +// BINARY OPERATIONS TORCH_CUDA_API Val* add(Val* v1, Val* v2); TORCH_CUDA_API Val* sub(Val* v1, Val* v2); TORCH_CUDA_API Val* mul(Val* v1, Val* v2); @@ -38,6 +39,10 @@ TORCH_CUDA_API Val* div(Val* v1, Val* v2); TORCH_CUDA_API Val* mod(Val* v1, Val* v2); TORCH_CUDA_API Val* lt(Val* v1, Val* v2); TORCH_CUDA_API Val* ceilDiv(Val* v1, Val* v2); +TORCH_CUDA_API Val* andOp(Val* v1, Val* v2); + +// REDUCTION OPERATIONS +TORCH_CUDA_API Val* sum(Val* v1, std::vector reduction_axes); } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/code_write.cpp b/torch/csrc/jit/codegen/cuda/code_write.cpp index 4a7fbdf4712d3..3a22f6e6449ed 100644 --- a/torch/csrc/jit/codegen/cuda/code_write.cpp +++ b/torch/csrc/jit/codegen/cuda/code_write.cpp @@ -16,7 +16,7 @@ namespace fuser { std::vector CodeWrite::getLoopIndices() { std::vector inds; for (auto loop : fors) - inds.push_back(loop.first); + inds.push_back(loop->index()); return inds; } @@ -49,9 +49,9 @@ void CodeWrite::printIndexInto( // assuming we've printed something bool first_index = true; - for (size_t i{size_t(tv->getComputeAtAxis())}; - i < fors.size(); i++) { - if (fors[i].second->isThread()) + for (decltype(fors.size()) i{tv->getComputeAtAxis()}; i < fors.size(); + i++) { + if (fors[i]->range()->isThread()) continue; if (!first_index) @@ -60,14 +60,14 @@ void CodeWrite::printIndexInto( first_index = false; // Index - print_inline(fors[i].first); + print_inline(fors[i]->index()); for (decltype(fors.size()) j{i + 1}; j < fors.size(); j++) { - if (fors[j].second->isThread()) + if (fors[j]->range()->isThread()) continue; os << " * "; // Strides - print_inline(fors[j].second->size()); + print_inline(fors[j]->range()->size()); } } @@ -130,16 +130,19 @@ bool CodeWrite::print_predicate(const TensorView* const pred_tv) { bool first_pred = true; os << "if( "; + Val* cond = nullptr; for (decltype(preds.size()) i{0}; i < preds.size(); i++) { if (preds[i]->sameAs(new Int(1.0))) continue; if (!first_pred) - os << " && "; - - print_inline(preds[i]); + cond = andOp(cond, preds[i]); + else + cond = preds[i]; first_pred = false; } + new IfThenElse(cond, {}); + print_inline(cond); os << " ) {\n"; ++indent_size; indent(); @@ -197,6 +200,45 @@ void CodeWrite::handle(const UnaryOp* const uop) { } } +// The UnaryOps captured here will have a TensorView as an output +void CodeWrite::handle(const ReductionOp* const rop) { + TORCH_INTERNAL_ASSERT(isTVOp(rop), + "Recieved a reduction operation that is not on a TensorView: ", rop); + + bool predicated = printConsumer(static_cast(rop->out())); + os << rop->init() <<" ;\n"; + + printConsumer(static_cast(rop->out())); + + if (auto inline_rop = inline_op_str(rop->getReductionOpType())) { + handle(rop->out()); + os << "\n"; + indent(); + os << " "; + os << inline_rop.value() << " "; + handle(rop->in()); + } else { + os << rop->getReductionOpType() << "("; + handle(rop->out()); + os << "\n"; + indent(); + os << ", "; + handle(rop->in()); + os << ")"; + } + + consumer = nullptr; + producer = false; + + os << ";\n"; + + if (predicated) { + --indent_size; + indent(); + os << "}\n"; + } +} + // The BinaryOps captured here will have a TensorView as an output void CodeWrite::handle(const BinaryOp* const bop) { if (!isTVOp(bop)) { @@ -243,8 +285,8 @@ void CodeWrite::indent() { // Pop the inner most for loop void CodeWrite::closeFor() { - IterDomain* id = fors.back().second; - Val* iterator = 
fors.back().first; + IterDomain* id = fors.back()->range(); + Val* iterator = fors.back()->index(); fors.pop_back(); // Clear overrides associated with this for loop if (id->parallel_method() != ParallelType::Serial) { @@ -275,10 +317,10 @@ void CodeWrite::bind(IterDomain* id, Val* iterator) { // Push Back a new for loop scope based on the IterDomain void CodeWrite::openFor(IterDomain* id) { - fors.push_back({ new Int(), id }); + fors.push_back(new ForLoop(new Int(), id, {})); if (id->parallel_method() != ParallelType::Serial) { - bind(id, fors.back().first); + bind(id, fors.back()->index()); return; } @@ -286,13 +328,13 @@ void CodeWrite::openFor(IterDomain* id) { indent_size++; os << "for( size_t "; - handle(fors.back().first); + handle(fors.back()->index()); os << " = " << new Int(0) << "; "; - handle(fors.back().first); + handle(fors.back()->index()); os << " < "; print_inline(id->size()); os << "; ++"; - handle(fors.back().first); + handle(fors.back()->index()); os << " ) {" << std::endl; } @@ -317,14 +359,11 @@ void CodeWrite::printAlloc(TensorView* tv) { FusionGuard::getCurFusion()->hasOutput(tv)) return; - Int* size = new Int(1); - for (decltype(tv->nDims()) i = tv->getComputeAtAxis(); i < tv->nDims(); i++) { - size = static_cast(mul(size, tv->axis(i)->size())); - } + Allocate* alloc = new Allocate(tv); indent(); - os << tv->getDataType().value() << " T" << tv->name() << "["; - print_inline(size); + os << alloc->buf_type() << " T" << alloc->buf_name() << "["; + print_inline(alloc->extent()); os << "];" << std::endl; } @@ -379,8 +418,9 @@ void CodeWrite::updateView(TensorView* tv) { bool CodeWrite::isTVOp(const Expr* expr) { if (expr->nOutputs() == 1 && expr->output(0)->getValType().value() == ValType::TensorView) - if (expr->getExprType().value() == ExprType::BinaryOp || - expr->getExprType().value() == ExprType::UnaryOp) + if ( expr->getExprType().value() == ExprType::BinaryOp + || expr->getExprType().value() == ExprType::UnaryOp + || expr->getExprType().value() == ExprType::ReductionOp) return true; return false; } @@ -417,6 +457,13 @@ void CodeWrite::setupOverrides() { // Print the header for the kernel, the inputs/outputs // TODO: Push this out to another class so we don't need dispatch implemented here void CodeWrite::header() { + // ceilDiv Helper funtion + os + << "__device__ int ceilDiv(const int a, const int b) {\n" + << " return (a + b - 1) / b;\n" + << "}\n\n" + ; + os << "__global__ void " << kernel_name_ << "("; std::deque vals; @@ -478,7 +525,7 @@ void CodeWrite::traverse(Fusion* fusion) { producer = false; consumer = nullptr; - fors = std::vector >(); + fors = std::vector(); indent_size = 0; active_view = nullptr; active_view_axis = 0; diff --git a/torch/csrc/jit/codegen/cuda/code_write.h b/torch/csrc/jit/codegen/cuda/code_write.h index a168a7cdd0ff9..d7a0c674a7484 100644 --- a/torch/csrc/jit/codegen/cuda/code_write.h +++ b/torch/csrc/jit/codegen/cuda/code_write.h @@ -17,6 +17,20 @@ namespace torch { namespace jit { namespace fuser { +/* +std::ostream& operator<<(std::ostream& os, std::vector vec) { + os << "<"; + for (int i = 0; i < vec.size(); i++) { + IRPrinter(os).print_inline(vec[i]); + if (i == vec.size() - 1) + os << ">"; + else + os << ","; + } + return os; +} +*/ + // Run through and grab all values that are used in this fusion based on // the registered outputs. 
struct FindUsedVals : public IterVisitor { @@ -55,7 +69,8 @@ struct TORCH_CUDA_API CodeWrite : public IRPrinter { void handle(const Val* const); void handle(const UnaryOp* const); void handle(const BinaryOp* const); - + void handle(const ReductionOp* const); + /****END CODE PRINTING FUNCTIONS****/ // Ignore split/merge/reorder operations, @@ -87,7 +102,7 @@ struct TORCH_CUDA_API CodeWrite : public IRPrinter { TensorView* consumer = nullptr; // Track the for loops - std::vector >fors; + std::vector fors; // Track the indentation size for pretty printing int indent_size = 0; diff --git a/torch/csrc/jit/codegen/cuda/dispatch.cpp b/torch/csrc/jit/codegen/cuda/dispatch.cpp index 31dfee9d2eb88..3efac00403fcd 100644 --- a/torch/csrc/jit/codegen/cuda/dispatch.cpp +++ b/torch/csrc/jit/codegen/cuda/dispatch.cpp @@ -86,6 +86,15 @@ void Expr::dispatch(T handler, Expr* expr) { case ExprType::BinaryOp: ptr(handler)->handle(static_cast(expr)); return; + case ExprType::ReductionOp: + ptr(handler)->handle(static_cast(expr)); + return; + case ExprType::ForLoop: + ptr(handler)->handle(static_cast(expr)); + return; + case ExprType::IfThenElse: + ptr(handler)->handle(static_cast(expr)); + return; default: TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!"); } @@ -148,6 +157,15 @@ void Expr::constDispatch(T handler, const Expr* const expr) { case ExprType::BinaryOp: ptr(handler)->handle(static_cast(expr)); return; + case ExprType::ReductionOp: + ptr(handler)->handle(static_cast(expr)); + return; + case ExprType::ForLoop: + ptr(handler)->handle(static_cast(expr)); + return; + case ExprType::IfThenElse: + ptr(handler)->handle(static_cast(expr)); + return; default: TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!"); } @@ -211,6 +229,12 @@ Statement* Expr::mutatorDispatch(T mutator, Expr* expr) { return ptr(mutator)->mutate(static_cast(expr)); case ExprType::BinaryOp: return ptr(mutator)->mutate(static_cast(expr)); + case ExprType::ReductionOp: + return ptr(mutator)->mutate(static_cast(expr)); + case ExprType::ForLoop: + return ptr(mutator)->mutate(static_cast(expr)); + case ExprType::IfThenElse: + return ptr(mutator)->mutate(static_cast(expr)); default: TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!"); } diff --git a/torch/csrc/jit/codegen/cuda/dispatch.h b/torch/csrc/jit/codegen/cuda/dispatch.h index 0b9ac8ad2e6d6..456c96ddbf748 100644 --- a/torch/csrc/jit/codegen/cuda/dispatch.h +++ b/torch/csrc/jit/codegen/cuda/dispatch.h @@ -63,6 +63,9 @@ struct Merge; struct Reorder; struct UnaryOp; struct BinaryOp; +struct ReductionOp; +struct ForLoop; +struct IfThenElse; /* * By default, all IR nodes are handled in this dispatch, and will call an empty @@ -96,6 +99,9 @@ struct TORCH_CUDA_API OptOutDispatch { virtual void handle(Reorder*) {} virtual void handle(UnaryOp*) {} virtual void handle(BinaryOp*) {} + virtual void handle(ReductionOp*) {} + virtual void handle(ForLoop*) {} + virtual void handle(IfThenElse*) {} }; struct TORCH_CUDA_API OptInConstDispatch { @@ -146,6 +152,15 @@ struct TORCH_CUDA_API OptInConstDispatch { virtual void handle(const BinaryOp* const) { AT_ERROR("Handle not overriden for BinaryOp."); } + virtual void handle(const ReductionOp* const) { + AT_ERROR("Handle not overriden for ReductionOp."); + } + virtual void handle(const ForLoop* const) { + AT_ERROR("Handle not overriden for ForLoop."); + } + virtual void handle(const IfThenElse* const) { + AT_ERROR("Handle not overriden for IfThenElse."); + } }; struct TORCH_CUDA_API OptInDispatch { @@ -196,6 +211,15 @@ struct 
TORCH_CUDA_API OptInDispatch { virtual void handle(BinaryOp*) { AT_ERROR("Handle not overriden for BinaryOp."); } + virtual void handle(ReductionOp*) { + AT_ERROR("Handle not overriden for ReductionOp."); + } + virtual void handle(ForLoop*) { + AT_ERROR("Handle not overriden for ForLoop."); + } + virtual void handle(IfThenElse*) { + AT_ERROR("Handle not overriden for IfThenElse."); + } }; struct TORCH_CUDA_API OptOutMutator { @@ -229,6 +253,9 @@ struct TORCH_CUDA_API OptOutMutator { virtual Statement* mutate(Reorder*); virtual Statement* mutate(UnaryOp*); virtual Statement* mutate(BinaryOp*); + virtual Statement* mutate(ReductionOp*); + virtual Statement* mutate(ForLoop*); + virtual Statement* mutate(IfThenElse*); }; struct TORCH_CUDA_API OptInMutator { @@ -279,6 +306,15 @@ struct TORCH_CUDA_API OptInMutator { virtual Statement* mutate(BinaryOp*) { AT_ERROR("Mutate not overriden for BinaryOp."); } + virtual Statement* mutate(ReductionOp*) { + AT_ERROR("Mutate not overriden for ReductionOp."); + } + virtual Statement* mutate(ForLoop*) { + AT_ERROR("Mutate not overriden for ForLoop."); + } + virtual Statement* mutate(IfThenElse*) { + AT_ERROR("Mutate not overriden for IfThenElse."); + } }; } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp index 982901ad16667..db52736ac91bc 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp @@ -41,6 +41,34 @@ Expr* Val::getOrigin() { return (fusion_->origin(this)); } + +struct IsConstScalar : public OptInConstDispatch{ + +private: +virtual void handle(const Float* const f){ + isConst = f->isConst(); +} + +virtual void handle(const Int* const i) { + isConst = i->isConst(); +} + +bool isConst = false; + +public: +static bool check(const Val* const val){ + IsConstScalar ics; + static_cast(&ics)->handle(val); + return ics.isConst; +} + +}; + +bool Val::isConstScalar() const { + return IsConstScalar::check(this); +} + + bool IRInputOutput::hasInput(const Val* const input) const { for (auto val : inputs_) if (val == input) diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h index f0b8befe48fe7..a74e28e9f07c1 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h @@ -155,6 +155,8 @@ struct TORCH_CUDA_API Val : public Statement { return vtype_ == ValType::Scalar; } + bool isConstScalar() const; + // Returns the Expr that this value is an output of, returns nullptr if none // was found Expr* getOrigin(); diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h index 4b079ff0eedf0..8dc9ca0e1f95d 100644 --- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h @@ -110,6 +110,10 @@ struct TORCH_CUDA_API TensorView : public Val { // (minus reduction IterDomains). TensorView* newForOutput(DataType dtype) const; + // Make a new tensor with the given dtype, same domain as this tensor, minus + // reduction IterDomains, with new reduced axes marked as so. 
+ TensorView* newForReduction(std::vector axes) const; + // Make an exact copy of this tensor with the same dtype and same domain TensorView* clone() const; diff --git a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h index a474feb12344d..baf88963792d3 100644 --- a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h @@ -83,6 +83,39 @@ struct TORCH_CUDA_API BinaryOp : public Expr { Val* const rhs_; }; +/* + * A specialization for Unary operations. Unary operations take in a single + * input and produce a single output. Examples include: + * 1) Casting operation i.e. float(a_val) + * 2) Negation i.e. val * -1 + * 3) Reduction across a dimension i.e. val.sum(axis=2) + * 4) split/merge/reorder + */ +struct TORCH_CUDA_API ReductionOp : public Expr { + ~ReductionOp() = default; + ReductionOp(BinaryOpType _reduction_op_type, Val* _init, Val* _out, Val* _in); + + ReductionOp(const ReductionOp& other) = delete; + ReductionOp& operator=(const ReductionOp& other) = delete; + + ReductionOp(ReductionOp&& other) = delete; + ReductionOp& operator=(ReductionOp&& other) = delete; + + Val* out() const noexcept { return out_; } + Val* in() const noexcept { return in_; } + Val* init() const noexcept { return init_; } + + BinaryOpType getReductionOpType() const noexcept { return reduction_op_type_; } + + bool sameAs(const ReductionOp* const other) const; + + private: + const BinaryOpType reduction_op_type_; + Val* const init_; + Val* const out_; + Val* const in_; +}; + /* * Simply a representation of an iterable from 0 to size. TensorDomains which * represent how to iterate over a tensor is made up of IterDomains. We directly @@ -267,5 +300,168 @@ struct TORCH_CUDA_API Reorder : public Expr { const std::vector pos2axis_; }; + +/* + * ForLoop provides scoping around an int iterator from 0 to range. Exprs placed + * in its body are considered inside the scope of the for loop. In the future + * the implementation should look quite different so that we can do proper + * dependency annalysis like in Fusion. + * + * TODO: Change implmentation of Exprs contained in the scope to be more similar + * to Fusion where we can do proper dependency analysis. + */ +struct TORCH_API ForLoop : public Expr { + ~ForLoop() = default; + ForLoop( + Int* _index, + IterDomain* _range, + const std::vector& _body); + + ForLoop(const ForLoop& other) = delete; + ForLoop& operator=(const ForLoop& other) = delete; + + ForLoop(ForLoop&& other) = delete; + ForLoop& operator=(ForLoop&& other) = delete; + + Int* index() const noexcept { + return index_; + } + IterDomain* range() const noexcept { + return range_; + } + + const std::vector& body() const noexcept { + return body_; + } + + void add_expr(const Expr* e) { + body_.push_back(e); + } + + void remove_expr(const Expr* e); + bool sameAs(const ForLoop* other) const; + + private: + Int* const index_; + IterDomain* const range_; + std::vector body_; +}; + + +/* + * IfThenElse provides scoping for an boolean operator. Exprs placed in its body + * are considered inside the scope of the if statement. In the future the + * implementation should look quite different so that we can do proper + * dependency annalysis like in Fusion. + * + * TODO: Change implmentation of Exprs contained in the scope to be more similar + * to Fusion where we can do proper dependency analysis. 
+ */ +struct TORCH_API IfThenElse : public Expr { + ~IfThenElse() = default; + IfThenElse( + Val* _cond, + const std::vector& _if_body, + const std::vector& _else_body = {}); + + IfThenElse(const IfThenElse& other) = delete; + IfThenElse& operator=(const IfThenElse& other) = delete; + + IfThenElse(IfThenElse&& other) = delete; + IfThenElse& operator=(IfThenElse&& other) = delete; + + Val* cond() const noexcept { + return cond_; + } + const std::vector& if_body() const noexcept { + return if_body_; + } + const std::vector& else_body() const noexcept { + return else_body_; + } + + void add_if_expr(const Expr* e) { + if_body_.push_back(e); + } + void add_else_expr(const Expr* e) { + else_body_.push_back(e); + } + + bool hasElse() const noexcept { + return !else_body_.empty(); + } + + bool sameAs(const IfThenElse* other) const; + + private: + // TODO: Why is the pointer const and not what's in the object? + Val* const cond_; + std::vector if_body_; + std::vector else_body_; +}; + + +/* + * TODO: Fill out TensorIndex, which is a list of Ints used to directly index a + * TensorView. It is not the flattened index, which needs to be computed using + * stride information. + */ +struct TORCH_API TensorIndex : public Val { + ~TensorIndex() = default; + + TensorIndex(const TensorIndex& other) = delete; + TensorIndex& operator=(const TensorIndex& other) = delete; + + TensorIndex(TensorIndex&& other) = delete; + TensorIndex& operator=(TensorIndex&& other) = delete; + + TensorIndex(std::vector _indices) + : Val(ValType::TensorIndex), indices_(_indices) {} + + std::vector::size_type size() const { + return indices_.size(); + } + + bool sameAs(const TensorIndex* const other) const; + //i here is int, as we want to accept negative value and ::size_type can be a uint. + Int* axis(int i) const; + + private: + std::vector indices_; +}; + +/* + * Allocate is a lower level Node that describes a buffer of memory that + * is required as an intermediate within a kernel. The extent is the expression + * of the size of the buffer that is generated from the TensorView that describes + * the output of an operation. + * + * TODO: + * 1.) Should extent_ be an Expr vs a Val? The Val is currently used to print + * the Expr of the size(). + * 2.) The components of Allocate like Type and Name could be separated from the + * the assocated TensorView. Perhaps that is more appropriate? 
+ */ +struct TORCH_API Allocate : public Expr { + ~Allocate() = default; + Allocate(TensorView* _tv); + + Allocate(const Allocate& other) = delete; + Allocate& operator=(const Allocate& other) = delete; + + Allocate(Allocate&& other) = delete; + Allocate& operator=(Allocate&& other) = delete; + + DataType buf_type() const noexcept; + StmtNameType buf_name() const noexcept; + const Val* extent() const noexcept; + + bool sameAs(const Allocate* other) const; + + private: + const TensorView* const buffer_; + const Val* extent_; +}; + }}} diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp index aae752586093f..c16dbe4cd8c2e 100644 --- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp @@ -1,6 +1,7 @@ #include #include +#include namespace torch { namespace jit { @@ -35,11 +36,12 @@ bool UnaryOp::sameAs(const UnaryOp* const other) const { BinaryOp::BinaryOp(BinaryOpType _type, Val* _out, Val* _lhs, Val* _rhs) - : Expr(ExprType::BinaryOp), - binary_op_type_{_type}, - out_{_out}, - lhs_{_lhs}, - rhs_{_rhs} { + : Expr(ExprType::BinaryOp) + , binary_op_type_{_type} + , out_{_out} + , lhs_{_lhs} + , rhs_{_rhs} + { addOutput(_out); addInput(_lhs); addInput(_rhs); @@ -55,6 +57,28 @@ bool BinaryOp::sameAs(const BinaryOp* other) const { } +ReductionOp::ReductionOp(BinaryOpType _reduction_op_type, Val* _init, Val* _out, Val* _in) + : Expr(ExprType::ReductionOp) + , reduction_op_type_(_reduction_op_type) + , init_(_init) + , out_{_out} + , in_{_in} +{ + TORCH_INTERNAL_ASSERT(_init->isConstScalar()); + addOutput(_out); + addInput(_in); + this->name_ = FusionGuard::getCurFusion()->registerExpr(this); +} + +bool ReductionOp::sameAs(const ReductionOp* other) const { + return ( + this->in()->sameAs(other->in()) + && this->getReductionOpType() == other->getReductionOpType() + && this->init()->sameAs(other->init()) + ); +} + + IterDomain::IterDomain( Int* _size, ParallelType _parallel_method, @@ -160,6 +184,120 @@ bool Reorder::sameAs(const Reorder* const other) const { return (out()->sameAs(other->out()) && in()->sameAs(other->in())); } + +ForLoop::ForLoop( + Int* _index, + IterDomain* _range, + const std::vector& _body) + : Expr(ExprType::ForLoop), index_{_index}, range_{_range}, body_{_body} { + addInput(_index); + addInput(_range); + this->name_ = FusionGuard::getCurFusion()->registerExpr(this); +} + +void ForLoop::remove_expr(const Expr* e) { + auto it = body_.begin(); + for (; it != body_.end(); ++it) + if (*it == e) + break; + if (it != body_.end()) + body_.erase(it); +} + +bool ForLoop::sameAs(const ForLoop* other) const { + if (this->range() != other->range()) + return false; + if (body().size() != other->body().size()) + return false; + for (decltype(body().size()) i{0}; i < body().size(); i++) + if (!body()[i]->sameAs(other->body()[i])) + return false; + return true; +} + + +IfThenElse::IfThenElse( + Val* _cond, + const std::vector& _if_body, + const std::vector& _else_body) + : Expr(ExprType::IfThenElse), + cond_{_cond}, + if_body_{_if_body}, + else_body_{_else_body} { + addInput(_cond); + this->name_ = FusionGuard::getCurFusion()->registerExpr(this); +} + +bool IfThenElse::sameAs(const IfThenElse* other) const { + if (this->cond() != other->cond()) + return false; + if (this->hasElse() != other->hasElse()) + return false; + + for (decltype(if_body().size()) i{0}; i < if_body().size(); i++) + if (!if_body()[i]->sameAs(other->if_body()[i])) + return false; + + if (hasElse()) + for (decltype(else_body().size()) i{0}; i < 
else_body().size(); i++) + if (!else_body()[i]->sameAs(other->else_body()[i])) + return false; + return true; +} + + +bool TensorIndex::sameAs(const TensorIndex* const other) const { + if (size() != other->size()) + return false; + + for (decltype(size()) i = 0; i < size(); i++) + if (!(axis(i)->sameAs(other->axis(i)))) + return false; + + return true; +} + +Int* TensorIndex::axis(int i) const { + if (i < 0) + i += size(); + assert(i >= 0 && i < size()); + return indices_[i]; +} + +Allocate::Allocate(TensorView* _tv) + : Expr(ExprType::Allocate), + buffer_(_tv), + extent_{nullptr} { + Val* size = new Int(1); + for (auto i = _tv->getComputeAtAxis(); i < _tv->nDims(); i++) { + size = mul(size, _tv->axis(i)->size()); + } + extent_ = size; + + this->name_ = FusionGuard::getCurFusion()->registerExpr(this); +} + +DataType Allocate::buf_type() const noexcept { + return buffer_->getDataType().value(); +} +StmtNameType Allocate::buf_name() const noexcept { + return buffer_->name(); +} +const Val* Allocate::extent() const noexcept { + return extent_; +} + +bool Allocate::sameAs(const Allocate* other) const { + if(this->type() != other->type()) + return false; + if(this->buf_name() != other->buf_name()) + return false; + if(!this->extent()->sameAs(other->extent())) + return false; + + return true; +} + } // namespace fuser } // namespace jit -} // namespace torch \ No newline at end of file +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/ir_printer.h b/torch/csrc/jit/codegen/cuda/ir_printer.h index b0471058e08e8..332bf3667fc81 100644 --- a/torch/csrc/jit/codegen/cuda/ir_printer.h +++ b/torch/csrc/jit/codegen/cuda/ir_printer.h @@ -50,6 +50,9 @@ class TORCH_CUDA_API IRTransformPrinter : public IRPrinter { IRPrinter::handle(bop); } + void handle(const ForLoop* const) override {} + void handle(const IfThenElse* const) override {} + void handle(Fusion* f) override { IRPrinter::handle(f); } diff --git a/torch/csrc/jit/codegen/cuda/iriostream.cpp b/torch/csrc/jit/codegen/cuda/iriostream.cpp index 4ab613ef8628c..b21268c0fce2e 100644 --- a/torch/csrc/jit/codegen/cuda/iriostream.cpp +++ b/torch/csrc/jit/codegen/cuda/iriostream.cpp @@ -59,6 +59,17 @@ void IRPrinter::handle(const IterDomain* const id) { os << "}"; } +void IRPrinter::handle(const TensorIndex* const ti) { + os << "[ "; + for(decltype(ti->size()) i{0}; i < ti->size(); i++){ + print_inline(ti->axis(i)); + if(i != ti->size() - 1) + os<<", "; + } + os<<" ]"; +} + + void IRPrinter::handle(const TensorContiguity* const t) { os << "format_tag: " << t->getContiguityTag(); } @@ -146,6 +157,39 @@ void IRPrinter::handle(const BinaryOp* const bop) { os<<"\n"; } +void IRPrinter::handle(const ReductionOp* const rop) { + os << rop->out() << " = reduction( " << rop->in() + << ", op = " << rop->getReductionOpType() + << ", initial value = " << rop->init() << ")\n"; +} + +void IRPrinter::handle(const ForLoop* const fl) { + os <<"for(size_t " << fl->index() << "{0}; " + << fl->index() << " < " << fl->range() << "; " + << "++" << fl->index() <<" ) {\n"; + + for(auto &expr : fl->body()) + handle(expr); + + os << "}\n"; +} + +void IRPrinter::handle(const IfThenElse* const ite) { + os << "if ( "; + print_inline(ite->cond()); + os << " ) { \n"; + for(auto &expr : ite->if_body()) { + handle(expr); + } + if(ite->hasElse()) { + os << "} else { \n"; + for(auto &expr : ite->else_body()) { + handle(expr); + } + } + os<<"}\n"; +} + void IRPrinter::handle(const Split* const s) { os << "Split: "; handle( s->in() ); diff --git 
a/torch/csrc/jit/codegen/cuda/iriostream.h b/torch/csrc/jit/codegen/cuda/iriostream.h index 3e6aa741239b8..16198997c1c2f 100644 --- a/torch/csrc/jit/codegen/cuda/iriostream.h +++ b/torch/csrc/jit/codegen/cuda/iriostream.h @@ -19,10 +19,15 @@ struct Expr; struct UnaryOp; struct BinaryOp; +struct ReductionOp; + +struct ForLoop; +struct IfThenElse; struct TensorDomain; struct TensorView; struct IterDomain; +struct TensorIndex; struct TensorContiguity; @@ -70,6 +75,7 @@ struct TORCH_CUDA_API IRPrinter : public OptInConstDispatch { virtual void handle(const TensorDomain* const); virtual void handle(const TensorView* const); virtual void handle(const IterDomain* const); + virtual void handle(const TensorIndex* const); virtual void handle(const TensorContiguity* const); virtual void handle(const Float* const); @@ -77,6 +83,10 @@ struct TORCH_CUDA_API IRPrinter : public OptInConstDispatch { virtual void handle(const UnaryOp* const); virtual void handle(const BinaryOp* const); + virtual void handle(const ReductionOp* const); + + virtual void handle(const ForLoop* const); + virtual void handle(const IfThenElse* const); virtual void handle(const Split* const); virtual void handle(const Merge* const); diff --git a/torch/csrc/jit/codegen/cuda/mutator.cpp b/torch/csrc/jit/codegen/cuda/mutator.cpp index e3671c8750f99..b207944933a7c 100644 --- a/torch/csrc/jit/codegen/cuda/mutator.cpp +++ b/torch/csrc/jit/codegen/cuda/mutator.cpp @@ -119,11 +119,36 @@ Statement* OptOutMutator::mutate(BinaryOp* bop) { Val* out = static_cast(mutate(bop->out())); Val* lhs = static_cast(mutate(bop->lhs())); Val* rhs = static_cast(mutate(bop->rhs())); - if (!(out != bop->out() && lhs != bop->lhs() && rhs != bop->rhs())) + if (!( + out->sameAs(bop->out()) + && lhs->sameAs(bop->lhs()) + && rhs->sameAs(bop->rhs()) + )) return new BinaryOp(bop->getBinaryOpType(), out, lhs, rhs); return bop; } +Statement* OptOutMutator::mutate(ReductionOp* rop) { + Val* out = static_cast(mutate(rop->out())); + Val* in = static_cast(mutate(rop->in())); + Val* init = rop->init(); + if(!( + out->sameAs(rop->out()) + && in->sameAs(rop->in()) + && init->sameAs(rop->init()) + )) + return new ReductionOp(rop->getReductionOpType(), init, out, in); + + return rop; +} + +Statement* OptOutMutator::mutate(ForLoop* n) { + return n; +} +Statement* OptOutMutator::mutate(IfThenElse* n) { + return n; +} + Statement* OptInMutator::mutate(Statement* s) { return Statement::mutatorDispatch(this, s); } @@ -134,6 +159,7 @@ Statement* OptInMutator::mutate(Val* v) { return Val::mutatorDispatch(this, v); } + Statement* ReplaceAll::mutate( Val* val){ if(val->sameAs(instance_)) return with_; diff --git a/torch/csrc/jit/codegen/cuda/tensor_view.cpp b/torch/csrc/jit/codegen/cuda/tensor_view.cpp index dfcdcef3ee25a..c1a12c10f817b 100644 --- a/torch/csrc/jit/codegen/cuda/tensor_view.cpp +++ b/torch/csrc/jit/codegen/cuda/tensor_view.cpp @@ -226,6 +226,33 @@ TensorView* TensorView::newForOutput(DataType dtype) const { return new TensorView(td, dtype); }; +TensorView* TensorView::newForReduction(std::vector axes) const { + std::vector domain_copy; + int ref_dim = 0; + for (decltype(this->nDims()) orig_dim = 0; orig_dim < this->nDims(); orig_dim++) { + // If reduction axis, don't copy it over. Reduction axes are owned by + // consumers and we're copying over a producer. 
+ if (this->axis(orig_dim)->isReduction()) + continue; + + //Check if this dim should be reduced based on axes + bool isReduction = false; + for(decltype(axes.size()) i{0}; iaxis(orig_dim)->size(), ParallelType::Serial, isReduction)); + ref_dim++; + + } + + TensorDomain* td = new TensorDomain(domain_copy); + return new TensorView(td, this->getDataType().value()); +}; + TensorDomain* TensorView::getRootDomain() const { return TransformIter::getRoot(this->domain()); }; diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index 732e88c4b4370..53005a0e60c2e 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -220,6 +220,7 @@ TensorView* TransformReplay::runReplay( TensorView* replay_ref, TensorView* replay_target, int compute_at_axis) { + if (compute_at_axis < 0) compute_at_axis += int(replay_ref->nDims()) + 1; @@ -230,6 +231,13 @@ TensorView* TransformReplay::runReplay( this->compute_at_axis = compute_at_axis; + // If this is a reduction operation, we may call transform_replay on the same + // tensor view. When this happens, just return thet target view. + if( replay_ref->getRootDomain()->sameAs(replay_target->getRootDomain()) + && replay_ref->getComputeAtView()->sameAs(replay_target->getComputeAtView()) + && replay_ref->getComputeAtAxis() == replay_target->getComputeAtAxis()) + return replay_target; + /* STEP 1 */ // Reset the tensor domain of the target, this is the only way we can be // certain That we can actually replay the ops of ref. @@ -247,8 +255,8 @@ TensorView* TransformReplay::runReplay( // used during replay to forward propagate influence. std::vector root_influence_vector = influence; - // Remove isReduction from the axis_map of a producer - // isReduction is only impactful when its on a consumer + // Remove isReduction from the axis_map of a producer isReduction is only + // impactful when its on a consumer. 
 auto init_size = replay_target->nDims();
 for (decltype(init_size) i = 0; i < init_size; i++)
 if (!replay_target->axis(i)->isReduction())
diff --git a/torch/csrc/jit/codegen/cuda/type.cpp b/torch/csrc/jit/codegen/cuda/type.cpp
index 1fff656e06156..c9275ca4e0f9d 100644
--- a/torch/csrc/jit/codegen/cuda/type.cpp
+++ b/torch/csrc/jit/codegen/cuda/type.cpp
@@ -44,6 +44,7 @@ static std::unordered_map expr_type_string_map {
   , {ExprType::BinaryOp, "BinaryOp"}
   , {ExprType::ForLoop, "ForLoop"}
   , {ExprType::IfThenElse, "IfThenElse"}
+  , {ExprType::Allocate, "Allocate"}
   , {ExprType::Split, "Split"}
   , {ExprType::Merge, "Merge"}
   , {ExprType::Reorder, "Reorder"}
@@ -63,6 +64,7 @@ static std::unordered_map binary_op_type_string_map {
   , {BinaryOpType::Mod, "Mod" }
   , {BinaryOpType::LT, "LessThan"}
   , {BinaryOpType::CeilDiv, "ceilDiv" }
+  , {BinaryOpType::And, "And" }
 };
 static std::unordered_map binary_op_type_inline_op_string_map {
   {BinaryOpType::Add, "+" }
@@ -71,6 +73,7 @@ static std::unordered_map binary_op_type_inline_op_st
   , {BinaryOpType::Div, "/" }
   , {BinaryOpType::Mod, "%" }
   , {BinaryOpType::LT, "<" }
+  , {BinaryOpType::And, "&&" }
 };
 
 static std::unordered_map parallel_type_string_map {
diff --git a/torch/csrc/jit/codegen/cuda/type.h b/torch/csrc/jit/codegen/cuda/type.h
index 8e30fa2690ab7..2072482648085 100644
--- a/torch/csrc/jit/codegen/cuda/type.h
+++ b/torch/csrc/jit/codegen/cuda/type.h
@@ -31,13 +31,13 @@ enum class TORCH_CUDA_API DataType {
 enum class TORCH_CUDA_API ExprType {
     UnaryOp
   , BinaryOp
+  , ReductionOp
   , ForLoop
   , IfThenElse
+  , Allocate
   , Split
   , Merge
   , Reorder
-// , Swap
-// , Index
 };
 
 enum class TORCH_CUDA_API UnaryOpType {
@@ -50,10 +50,12 @@ enum class TORCH_CUDA_API BinaryOpType {
   , Sub
   , Mul
   , Div
-  //Int operations, leave position oif Mod we depend on its location of first Int op
+  // Int operations, leave Mod as the first int operator
+  // as we depend on its location
   , Mod
   , LT
   , CeilDiv
+  , And
 };
 
 enum class TORCH_CUDA_API ParallelType {
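
The updated kernel expectations in the tests above now start with the emitted `ceilDiv` device helper. As a quick illustration, here is a minimal host-side sketch of the same integer formula, `(a + b - 1) / b`, checked against a few positive operands (standalone, not part of the patch):

```cpp
// Host-side sketch of the ceilDiv helper emitted at the top of generated kernels.
#include <cassert>
#include <cstdio>

int ceilDiv(const int a, const int b) {
  return (a + b - 1) / b;  // rounds the quotient up for positive a, b
}

int main() {
  assert(ceilDiv(16, 4) == 4);   // exact multiple
  assert(ceilDiv(17, 4) == 5);   // rounds up
  assert(ceilDiv(1, 128) == 1);  // small a, large b
  std::printf("ceilDiv formula OK\n");
  return 0;
}
```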
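
The new `reductionOp` in arith.cpp normalizes the requested axes before constructing the `ReductionOp` node: negative axes wrap around by the tensor rank, each axis is range-checked, and the result is stored as unsigned. A minimal standalone sketch of that normalization follows; the function name `normalizeAxes` is illustrative and not part of the patch:

```cpp
// Sketch of the axis normalization performed inside reductionOp().
#include <cassert>
#include <stdexcept>
#include <vector>

std::vector<unsigned int> normalizeAxes(std::vector<int> axes, int ndims) {
  std::vector<unsigned int> out;
  for (int axis : axes) {
    if (axis < 0)
      axis += ndims;  // e.g. -1 on a 3-D tensor becomes 2
    if (axis < 0 || axis >= ndims)
      throw std::out_of_range("Reduction on invalid axis");
    out.push_back(static_cast<unsigned int>(axis));
  }
  return out;
}

int main() {
  auto axes = normalizeAxes({0, -1}, 3);
  assert(axes[0] == 0 && axes[1] == 2);
  return 0;
}
```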
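
For reference, the `sum(v1, {0, 1})` helper added in arith.cpp only builds IR (a `reductionOp` with `BinaryOpType::Add` and an init value of 0); numerically it corresponds to folding the first two axes away, as exercised by GPU_FusionReductionInternals on a 3-D tensor. A small standalone sketch of that reference computation, with arbitrary sizes and data:

```cpp
// Reference semantics of reducing axes {0, 1} of a d0 x d1 x d2 tensor with Add.
#include <cassert>
#include <vector>

std::vector<float> sumAxes01(const std::vector<float>& in, int d0, int d1, int d2) {
  std::vector<float> out(d2, 0.0f);  // init value 0, matching the reductionOp init
  for (int i = 0; i < d0; ++i)
    for (int j = 0; j < d1; ++j)
      for (int k = 0; k < d2; ++k)
        out[k] += in[(i * d1 + j) * d2 + k];  // row-major flattening
  return out;
}

int main() {
  // A 2 x 2 x 3 tensor of ones: reducing axes {0, 1} leaves 4 in every slot.
  std::vector<float> in(2 * 2 * 3, 1.0f);
  auto out = sumAxes01(in, 2, 2, 3);
  for (float v : out)
    assert(v == 4.0f);
  return 0;
}
```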
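
The new `Allocate` node computes a buffer extent as the product of axis extents from the tensor's computeAt position out to the innermost dimension, which is why the updated GPU_FusionCodeGen expectation declares `float T0[1]` once `tv0` is computed at `-1`. A standalone sketch of that extent computation, using plain integers in place of the symbolic `Int*` / `mul()` IR values:

```cpp
// Sketch of the extent product built in Allocate's constructor.
#include <cassert>
#include <vector>

int allocationExtent(const std::vector<int>& axis_extents, size_t compute_at_axis) {
  int extent = 1;  // corresponds to the initial new Int(1)
  for (size_t i = compute_at_axis; i < axis_extents.size(); ++i)
    extent *= axis_extents[i];  // corresponds to mul(size, tv->axis(i)->size())
  return extent;
}

int main() {
  // A 4-D view with extents {4, 8, 2, 16} computed-at axis 2 needs a 2*16 buffer;
  // computed-at the innermost position, it collapses to a single element.
  assert(allocationExtent({4, 8, 2, 16}, 2) == 32);
  assert(allocationExtent({4, 8, 2, 16}, 4) == 1);
  return 0;
}
```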