162 changes: 123 additions & 39 deletions test/cpp/jit/test_gpu_fusion.cpp
@@ -57,7 +57,6 @@ void testGPU_FusionSimpleArith() {

Float* f1 = new Float(1.f);
Float* f2 = new Float{2.f};
Float* f3 = new Float();

// Disrupt the fusion to make sure the guard works well
{
@@ -70,7 +69,7 @@ void testGPU_FusionSimpleArith() {
ss2 << fusion2;
}

new BinaryOp(BinaryOpType::Add, f3, f1, f2);
Val* f3 = add(f1, f2);
ss1 << fusion;

TORCH_CHECK(ss1.str().compare(ss2.str()) == 0,
@@ -132,8 +131,8 @@ void testGPU_FusionRegister() {
FusionGuard fg(&fusion);
Float* v1 = new Float{1.f};
Float* v2 = new Float{2.f};
Val* v3 = binaryOp(BinaryOpType::Add, v1, v2);
Val* v4 = binaryOp(BinaryOpType::Add, v1, v2);
Val* v3 = add(v1, v2);
Val* v4 = add(v1, v2);
TORCH_CHECK(v1->name() + 1 == v2->name());
TORCH_CHECK(v2->name() + 1 == v3->name());
TORCH_CHECK(v3->name() + 1 == v4->name());
@@ -554,14 +553,18 @@ void testGPU_FusionParser() {

std::stringstream ref;
ref
<< "__device__ int ceilDiv(const int a, const int b) {\n"
<< " return (a + b - 1) / b;\n"
<< "}\n"
<< "\n"
<< "__global__ void kernel(Tensor<float> T0, Tensor<float> T1, Tensor<float> T3){\n"
<< " float T2[1];\n"
<< " if( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T1.size[2] ) / T1.size[1] ) < T1.size[0] ) && ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T1.size[2] ) % T1.size[1] ) < T1.size[1] ) && ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) % T1.size[2] ) < T1.size[2] ) ) {\n"
<< " if( ( ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T1.size[2] ) / T1.size[1] ) < T1.size[0] ) && ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T1.size[2] ) % T1.size[1] ) < T1.size[1] ) ) && ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) % T1.size[2] ) < T1.size[2] ) ) ) {\n"
<< " T2[0]\n"
<< " = T0[( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T0.size[2] ) / T0.size[1] ) * T0.stride[0] + ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T0.size[2] ) % T0.size[1] ) * T0.stride[1] + ( ( ( blockIdx.x * 128 ) + threadIdx.x ) % T0.size[2] ) * T0.stride[2]]\n"
<< " * T1[( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T1.size[2] ) / T1.size[1] ) * T1.stride[0] + ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T1.size[2] ) % T1.size[1] ) * T1.stride[1] + ( ( ( blockIdx.x * 128 ) + threadIdx.x ) % T1.size[2] ) * T1.stride[2]];\n"
<< " }\n"
<< " if( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T0.size[2] ) / T0.size[1] ) < T0.size[0] ) && ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T0.size[2] ) % T0.size[1] ) < T0.size[1] ) && ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) % T0.size[2] ) < T0.size[2] ) ) {\n"
<< " if( ( ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T0.size[2] ) / T0.size[1] ) < T0.size[0] ) && ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T0.size[2] ) % T0.size[1] ) < T0.size[1] ) ) && ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) % T0.size[2] ) < T0.size[2] ) ) ) {\n"
<< " T3[( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T0.size[2] ) / T0.size[1] ) * T3.stride[0] + ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T0.size[2] ) % T0.size[1] ) * T3.stride[1] + ( ( ( blockIdx.x * 128 ) + threadIdx.x ) % T0.size[2] ) * T3.stride[2]]\n"
<< " = T2[0]\n"
<< " * T0[( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T0.size[2] ) / T0.size[1] ) * T0.stride[0] + ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T0.size[2] ) % T0.size[1] ) * T0.stride[1] + ( ( ( blockIdx.x * 128 ) + threadIdx.x ) % T0.size[2] ) * T0.stride[2]];\n"
@@ -658,7 +661,7 @@ void testGPU_FusionCodeGen() {
Fusion fusion;
FusionGuard fg(&fusion);

TensorView* tv0 = makeDummyTensor(4);
TensorView* tv0 = makeDummyTensor(3);

new BinaryOp(BinaryOpType::Add, tv0, new Float(0.0), new Float(1.0));
TensorView* tv1 = static_cast<TensorView*>(add(tv0, new Float(2.0)));
@@ -675,45 +678,35 @@ void testGPU_FusionCodeGen() {
//[I0i{4}*I1, I0o, I2i{2}, I2o]
fusion.addOutput(tv2);

tv0->computeAt(tv2, 1);
tv0->computeAt(tv2, -1);

std::stringstream ref;
ref
<< "__device__ int ceilDiv(const int a, const int b) {\n"
<< " return (a + b - 1) / b;\n"
<< "}\n"
<< "\n"
<< "__global__ void kernel(Tensor<float> T2){\n"
<< " float T0[( ( ( 1 * ( ceilDiv(T2.size[0], 4) ) ) * T2.size[2] ) * T2.size[3] )];\n"
<< " for( size_t i27 = 0; i27 < ( 4 * T2.size[1] ); ++i27 ) {\n"
<< " for( size_t i29 = 0; i29 < ( ceilDiv(T2.size[0], 4) ); ++i29 ) {\n"
<< " for( size_t i31 = 0; i31 < T2.size[2]; ++i31 ) {\n"
<< " for( size_t i33 = 0; i33 < T2.size[3]; ++i33 ) {\n"
<< " if( ( ( ( i29 * 4 ) + ( i27 / T2.size[1] ) ) < T2.size[0] ) && ( ( i27 % T2.size[1] ) < T2.size[1] ) ) {\n"
<< " T0[i29 * T2.size[2] * T2.size[3] + i31 * T2.size[3] + i33]\n"
<< " float T0[1];\n"
<< " for( size_t i29 = 0; i29 < ( 4 * T2.size[1] ); ++i29 ) {\n"
<< " for( size_t i31 = 0; i31 < ( ceilDiv(T2.size[0], 4) ); ++i31 ) {\n"
<< " for( size_t i33 = 0; i33 < 2; ++i33 ) {\n"
<< " for( size_t i35 = 0; i35 < ( ceilDiv(T2.size[2], 2) ); ++i35 ) {\n"
<< " if( ( ( ( ( ( i31 * 4 ) + ( i29 / T2.size[1] ) ) < T2.size[0] ) && ( ( i29 % T2.size[1] ) < T2.size[1] ) ) && ( ( ( i35 * 2 ) + i33 ) < T2.size[2] ) ) ) {\n"
<< " T0[0]\n"
<< " = float(0)\n"
<< " + float(1);\n"
<< " }\n"
<< " }\n"
<< " }\n"
<< " }\n"
<< " float T1[( ( ( 1 * ( ceilDiv(T2.size[0], 4) ) ) * T2.size[2] ) * T2.size[3] )];\n"
<< " for( size_t i55 = 0; i55 < ( ceilDiv(T2.size[0], 4) ); ++i55 ) {\n"
<< " for( size_t i57 = 0; i57 < T2.size[2]; ++i57 ) {\n"
<< " for( size_t i59 = 0; i59 < T2.size[3]; ++i59 ) {\n"
<< " if( ( ( ( i55 * 4 ) + ( i27 / T2.size[1] ) ) < T2.size[0] ) && ( ( i27 % T2.size[1] ) < T2.size[1] ) ) {\n"
<< " T1[i55 * T2.size[2] * T2.size[3] + i57 * T2.size[3] + i59]\n"
<< " = T0[i55 * T2.size[2] * T2.size[3] + i57 * T2.size[3] + i59]\n"
<< " float T1[1];\n"
<< " if( ( ( ( ( ( i31 * 4 ) + ( i29 / T2.size[1] ) ) < T2.size[0] ) && ( ( i29 % T2.size[1] ) < T2.size[1] ) ) && ( ( ( i35 * 2 ) + i33 ) < T2.size[2] ) ) ) {\n"
<< " T1[0]\n"
<< " = T0[0]\n"
<< " + float(2);\n"
<< " }\n"
<< " }\n"
<< " }\n"
<< " }\n"
<< " for( size_t i85 = 0; i85 < ( ceilDiv(T2.size[0], 4) ); ++i85 ) {\n"
<< " for( size_t i87 = 0; i87 < ( ceilDiv(T2.size[3], 2) ); ++i87 ) {\n"
<< " for( size_t i89 = 0; i89 < T2.size[2]; ++i89 ) {\n"
<< " for( size_t i91 = 0; i91 < 2; ++i91 ) {\n"
<< " if( ( ( ( i85 * 4 ) + ( i27 / T2.size[1] ) ) < T2.size[0] ) && ( ( i27 % T2.size[1] ) < T2.size[1] ) && ( ( ( i87 * 2 ) + i91 ) < T2.size[3] ) ) {\n"
<< " T2[( ( i85 * 4 ) + ( i27 / T2.size[1] ) ) * T2.stride[0] + ( i27 % T2.size[1] ) * T2.stride[1] + i89 * T2.stride[2] + ( ( i87 * 2 ) + i91 ) * T2.stride[3]]\n"
<< " = T1[i85 * ( ceilDiv(T2.size[3], 2) ) * T2.size[2] * 2 + i87 * T2.size[2] * 2 + i89 * 2 + i91]\n"
<< " + float(3);\n"
<< " }\n"
<< " if( ( ( ( ( ( i31 * 4 ) + ( i29 / T2.size[1] ) ) < T2.size[0] ) && ( ( i29 % T2.size[1] ) < T2.size[1] ) ) && ( ( ( i35 * 2 ) + i33 ) < T2.size[2] ) ) ) {\n"
<< " T2[( ( i31 * 4 ) + ( i29 / T2.size[1] ) ) * T2.stride[0] + ( i29 % T2.size[1] ) * T2.stride[1] + ( ( i35 * 2 ) + i33 ) * T2.stride[2]]\n"
<< " = T1[0]\n"
<< " + float(3);\n"
<< " }\n"
<< " }\n"
<< " }\n"
@@ -735,6 +728,28 @@ void testGPU_FusionCodeGen() {
TORCH_CHECK(false);
}

torch::jit::fuser::cuda::CudaKernel prog;
prog.device_ = 0;
// These can be set to anything as there are no bindings!
// All CTAs and threads execute the same thing.
prog.grid(4);
prog.block(32);

auto options =
at::TensorOptions()
.dtype(at::kFloat)
.device(at::kCUDA, 0);

at::Tensor output = at::empty({16,8,8}, options);
std::vector<at::Tensor> outputs{{output}};

torch::jit::fuser::cuda::compileKernel(fusion, prog);
torch::jit::fuser::cuda::runTestKernel(prog, {}, outputs);

at::Tensor output_ref = at::zeros_like(output, options);
output_ref = output_ref + 0.0 + 1.0 + 2.0 + 3.0;

TORCH_CHECK(output_ref.equal(output));
}

void testGPU_FusionCodeGen2() {
Expand Down Expand Up @@ -767,6 +782,10 @@ void testGPU_FusionCodeGen2() {

std::stringstream ref;
ref
<< "__device__ int ceilDiv(const int a, const int b) {\n"
<< " return (a + b - 1) / b;\n"
<< "}\n"
<< "\n"
<< "__global__ void kernel(Tensor<float> T0, Tensor<float> T1, Tensor<float> T3){\n"
<< " float T2[1];\n"
<< " for( size_t i15 = 0; i15 < 4; ++i15 ) {\n"
@@ -784,8 +803,10 @@ void testGPU_FusionCodeGen2() {
<< " }\n"
<< " }\n"
<< "}\n"
;
std::stringstream cdg;
;

std::stringstream cdg;

CodeWrite cw(cdg);
cw.traverse(&fusion);

@@ -943,5 +964,68 @@ void testGPU_FusionExecKernel() {
TORCH_CHECK(output.equal(check));
}

void testGPU_FusionForLoop() {
Fusion fusion;
FusionGuard fg(&fusion);

const auto TV0 = new TensorView(new TensorDomain({new IterDomain(new Int(16))}), DataType::Float);
const auto TV1 = new TensorView(new TensorDomain({new IterDomain(new Int(16))}), DataType::Float);

fusion.addInput(TV0);
fusion.addInput(TV1);

auto ID0 = new IterDomain(new Int(8));

TensorView* TV2 = static_cast<TensorView*>(add(TV0, TV1));
BinaryOp* op = static_cast<BinaryOp*>(TV2->getOrigin());
fusion.addOutput(TV2);

ForLoop* fl = new ForLoop(new Int(), ID0, {op});
std::stringstream result;
std::stringstream ref;
result << fl;
ref << "for(size_t i3{0}; i3 < iS{8}; ++i3 ) {\nT2[ iS{16} ] = T0[ iS{16} ] + T1[ iS{16} ]\n}";

if(result.str().compare(ref.str()) != 0){
std::stringstream err_msg;
err_msg << "ForLoop printing has changed or something has gone wrong. "
<< result.str() << "\n does not match reference: " << ref.str() << std::endl;
TORCH_CHECK(false, err_msg.str());
}

}

void testGPU_FusionConstCheck() {
Fusion fusion;
FusionGuard fg(&fusion);

Val* cInt = new Int(2);
Val* sInt = new Int();
Val* cFloat = new Float(2.0);
Val* sFloat = new Float();
TORCH_CHECK(cInt->isConstScalar());
TORCH_CHECK(cFloat->isConstScalar());
TORCH_CHECK(!sInt->isConstScalar());
TORCH_CHECK(!sFloat->isConstScalar());
}

void testGPU_FusionReductionInternals() {
Fusion fusion;
FusionGuard fg(&fusion);

TensorView* tv1 = makeDummyTensor(3);
TensorView* tv2 = static_cast<TensorView*>(sum(tv1, {0, 1}));
std::cout << fusion << std::endl;

fusion.addInput(tv1);
fusion.addOutput(tv2);

CodeWrite cw(std::cout);
cw.traverse(&fusion);


}


} // namespace jit
} // namespace torch
57 changes: 30 additions & 27 deletions test/cpp/jit/tests.h
@@ -91,33 +91,36 @@ namespace jit {
_(TorchbindIValueAPI)

#if defined(USE_CUDA) && !defined(USE_ROCM)
#define TH_FORALL_TESTS_CUDA(_) \
_(ArgumentSpec) \
_(CompleteArgumentSpec) \
_(Fusion) \
_(GraphExecutor) \
_(ModuleConversion) \
_(Interp) \
_(GPU_FusionDispatch) \
_(GPU_FusionSimpleArith) \
_(GPU_FusionSimpleTypePromote)\
_(GPU_FusionCastOp) \
_(GPU_FusionMutator) \
_(GPU_FusionRegister) \
_(GPU_FusionTopoSort) \
_(GPU_FusionTensor) \
_(GPU_FusionTensorContiguity) \
_(GPU_FusionTVSplit) \
_(GPU_FusionTVMerge) \
_(GPU_FusionTVReorder) \
_(GPU_FusionEquality) \
_(GPU_FusionReplaceAll) \
_(GPU_FusionParser) \
_(GPU_FusionDependency) \
_(GPU_FusionCodeGen) \
_(GPU_FusionCodeGen2) \
_(GPU_FusionSimplePWise) \
_(GPU_FusionExecKernel)
#define TH_FORALL_TESTS_CUDA(_) \
_(ArgumentSpec) \
_(CompleteArgumentSpec) \
_(Fusion) \
_(GraphExecutor) \
_(ModuleConversion) \
_(Interp) \
_(GPU_FusionDispatch) \
_(GPU_FusionSimpleArith) \
_(GPU_FusionSimpleTypePromote) \
_(GPU_FusionCastOp) \
_(GPU_FusionMutator) \
_(GPU_FusionRegister) \
_(GPU_FusionTopoSort) \
_(GPU_FusionTensor) \
_(GPU_FusionTensorContiguity) \
_(GPU_FusionTVSplit) \
_(GPU_FusionTVMerge) \
_(GPU_FusionTVReorder) \
_(GPU_FusionEquality) \
_(GPU_FusionReplaceAll) \
_(GPU_FusionParser) \
_(GPU_FusionDependency) \
_(GPU_FusionCodeGen) \
_(GPU_FusionCodeGen2) \
_(GPU_FusionSimplePWise) \
_(GPU_FusionExecKernel) \
_(GPU_FusionForLoop) \
_(GPU_FusionReductionInternals) \
_(GPU_FusionConstCheck)
#else
#define TH_FORALL_TESTS_CUDA(_) \
_(ArgumentSpec) \