Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions caffe2/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/ir_iostream.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/iter_visitor.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/kernel.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/kernel_cache.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/manager.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/mutator.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/parser.cpp
Expand Down
12 changes: 6 additions & 6 deletions test/cpp/jit/test_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,7 @@ void testGPU_FusionParser() {
fuser::cuda::parseJitIR(g, fusion);

std::stringstream ref;
ref << "__global__ void kernel(Tensor<float> T0, Tensor<float> T1, Tensor<float> T3){\n"
ref << "__global__ void kernel(Tensor<float, 3> T0, Tensor<float, 3> T1, Tensor<float, 3> T3){\n"
<< " float T2[1];\n"
<< " if( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T1.size[2] ) / T1.size[1] ) < T1.size[0] ) && ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T1.size[2] ) % T1.size[1] ) < T1.size[1] ) && ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) % T1.size[2] ) < T1.size[2] ) ) {\n"
<< " T2[0]\n"
Expand Down Expand Up @@ -675,7 +675,7 @@ void testGPU_FusionCodeGen() {
tv0->computeAt(tv2, 1);

std::stringstream ref;
ref << "__global__ void kernel(Tensor<float> T2){\n"
ref << "__global__ void kernel(Tensor<float, 4> T2){\n"
<< " float T0[( ( ( 1 * ( ceilDiv(T2.size[0], 4) ) ) * T2.size[2] ) * T2.size[3] )];\n"
<< " for( size_t i27 = 0; i27 < ( 4 * T2.size[1] ); ++i27 ) {\n"
<< " for( size_t i29 = 0; i29 < ( ceilDiv(T2.size[0], 4) ); ++i29 ) {\n"
Expand Down Expand Up @@ -760,7 +760,7 @@ void testGPU_FusionCodeGen2() {
tv3->axis(-1)->parallelize(ParallelType::TIDx);

std::stringstream ref;
ref << "__global__ void kernel(Tensor<float> T0, Tensor<float> T1, Tensor<float> T3){\n"
ref << "__global__ void kernel(Tensor<float, 3> T0, Tensor<float, 3> T1, Tensor<float, 3> T3){\n"
<< " float T2[1];\n"
<< " for( size_t i15 = 0; i15 < 4; ++i15 ) {\n"
<< " for( size_t i17 = 0; i17 < T1.size[1]; ++i17 ) {\n"
Expand Down Expand Up @@ -805,7 +805,7 @@ void testGPU_FusionCodeGen2() {
std::vector<at::Tensor> inputs{{input1, input2}};
std::vector<at::Tensor> outputs{{output}};

torch::jit::fuser::cuda::compileKernel(fusion, prog);
torch::jit::fuser::cuda::compileKernel(fusion, &prog);
torch::jit::fuser::cuda::runTestKernel(prog, inputs, outputs);

at::Tensor tv2_ref = input2 + 2.0;
Expand Down Expand Up @@ -874,7 +874,7 @@ void testGPU_FusionSimplePWise() {
std::vector<at::Tensor> inputs{{input1, input2}};
std::vector<at::Tensor> outputs{{output}};

torch::jit::fuser::cuda::compileKernel(fusion, prog);
torch::jit::fuser::cuda::compileKernel(fusion, &prog);
torch::jit::fuser::cuda::runTestKernel(prog, inputs, outputs);

at::Tensor tv2_ref = input2 + 2.0;
Expand Down Expand Up @@ -926,7 +926,7 @@ void testGPU_FusionExecKernel() {
std::vector<at::Tensor> inputs{{input1, input2}};
std::vector<at::Tensor> outputs{{output}};

torch::jit::fuser::cuda::compileKernel(fusion, prog);
torch::jit::fuser::cuda::compileKernel(fusion, &prog);
torch::jit::fuser::cuda::runTestKernel(prog, inputs, outputs);

at::Tensor check = at::full({1, 128}, 4, options);
Expand Down
45 changes: 41 additions & 4 deletions test/test_jit_cuda_fuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,20 +86,57 @@ def t(x, y, z, q):
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
@skipIfRocm
def test_scalar_input(self):
def t(x, y, z):
# type: (Tensor, Tensor, float) -> Tensor
def t(x : torch.Tensor, y : torch.Tensor, z : float):
o = x + y
o = o + z
return o
t_jit = torch.jit.script(t)
x = torch.randn(4, 8, dtype=torch.float, device="cuda")
y = torch.randn(4, 8, dtype=torch.float, device="cuda")
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
y = torch.randn(4, 8, 1, 32, dtype=torch.float, device="cuda")
y = y.expand(4, 8, 32, 32)
jit_o = t_jit(x, y, 2.0)
jit_o = t_jit(x, y, 2.0)
o = t(x, y, 2.0)
self.assertEqual(o, jit_o)
self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, y, 2.0)))

@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
@skipIfRocm
def test_broadcasting(self):
def t(x : torch.Tensor, y : torch.Tensor, z : float):
o = x + y
o = o + z
return o
t_jit = torch.jit.script(t)
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
y = torch.randn(32, 32, dtype=torch.float, device="cuda")
jit_o = t_jit(x, y, 2.0)
jit_o = t_jit(x, y, 2.0)
o = t(x, y, 2.0)
self.assertEqual(o, jit_o)
self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, y, 2.0)))

@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
@skipIfRocm
def test_broadcasting_multiple_output_shape(self):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does the JIT IR look like in this case?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's a graph with branches and joined at the end.

graph(%x.1 : Tensor,
      %y.1 : Tensor,
      %z.1 : Tensor):
  %16 : None = prim::Constant()
  %5 : int = prim::Constant[value=1]()
  %4 : int = prim::Constant[value=12]() # test_jit_cuda_fuser.py:125:20
  %o.1 : Tensor = aten::add(%x.1, %4, %5) # test_jit_cuda_fuser.py:125:16
  %o1.1 : Tensor = aten::add(%o.1, %y.1, %5) # test_jit_cuda_fuser.py:126:17
  %o2.1 : Tensor = aten::add(%o.1, %z.1, %5) # test_jit_cuda_fuser.py:127:17
  %17 : Tensor = aten::sum(%o1.1, %16) # test_jit_cuda_fuser.py:128:17
  %20 : Tensor = aten::sum(%o2.1, %16) # test_jit_cuda_fuser.py:128:28
  %oo.1 : Tensor = aten::add(%17, %20, %5) # test_jit_cuda_fuser.py:128:17
  return (%oo.1)

After profiling, if you ignore the bailout node, you can say all three add node has different output size, hence they are not fused with each other.

graph(%x.1 : Tensor,
      %y.1 : Tensor,
      %z.1 : Tensor):
  %5 : None = prim::Constant()
  %4 : int = prim::Constant[value=1]()
  %3 : int = prim::Constant[value=12]() # test_jit_cuda_fuser.py:125:20
  %37 : int = prim::BailoutTemplate_0()
  %32 : Float(4, 32, 32) = prim::BailOut[index=0](%37, %z.1, %x.1, %y.1)
  %33 : Float(2, 32, 32) = prim::BailOut[index=1](%37, %y.1, %x.1, %32)
  %34 : Float(32, 32) = prim::BailOut[index=2](%37, %x.1, %33, %32)
  %o.1 : Float(32, 32) = aten::add(%34, %3, %4) # test_jit_cuda_fuser.py:125:16
  %o1.1 : Float(2, 32, 32) = aten::add(%o.1, %33, %4) # test_jit_cuda_fuser.py:126:17
  %o2.1 : Float(4, 32, 32) = aten::add(%o.1, %32, %4) # test_jit_cuda_fuser.py:127:17
  %15 : Tensor = aten::sum(%o1.1, %5) # test_jit_cuda_fuser.py:128:17
  %35 : Float() = prim::BailOut[index=3](%37, %15, %o2.1)
  %17 : Tensor = aten::sum(%o2.1, %5) # test_jit_cuda_fuser.py:128:28
  %36 : Float() = prim::BailOut[index=4](%37, %17, %35)
  %oo.1 : Float() = aten::add(%35, %36, %4) # test_jit_cuda_fuser.py:128:17
  return (%oo.1)

def t(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor):
o = x + 12
o1 = o + y
o2 = o + z
oo = o1.sum() + o2.sum()
return oo
t_jit = torch.jit.script(t)
x = torch.randn(32, 32, dtype=torch.float, device="cuda")
y = torch.randn(2, 32, 32, dtype=torch.float, device="cuda")
z = torch.randn(4, 32, 32, dtype=torch.float, device="cuda")
jit_o = t_jit(x, y, z)
jit_o = t_jit(x, y, z)
o = t(x, y, z)
self.assertEqual(o, jit_o)
#can't fuse it now
self.assertFalse(self._has_cuda_fusion_group(t_jit.graph_for(x, y, z)))

if __name__ == '__main__':
run_tests()
1 change: 1 addition & 0 deletions tools/build_variables.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ libtorch_cuda_sources = [
"torch/csrc/jit/codegen/cuda/ir_iostream.cpp",
"torch/csrc/jit/codegen/cuda/iter_visitor.cpp",
"torch/csrc/jit/codegen/cuda/kernel.cpp",
"torch/csrc/jit/codegen/cuda/kernel_cache.cpp",
"torch/csrc/jit/codegen/cuda/manager.cpp",
"torch/csrc/jit/codegen/cuda/mutator.cpp",
"torch/csrc/jit/codegen/cuda/parser.cpp",
Expand Down
8 changes: 6 additions & 2 deletions torch/csrc/jit/codegen/cuda/code_write.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -427,19 +427,23 @@ void CodeWrite::header() {
for (Val* val : vals) {
switch (val->getValType().value()) {
case (ValType::TensorView):
{
switch (val->getDataType().value()) {
case (DataType::Float):
os << "Tensor<float> T";
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I moved this to ir_iostream. I will rebase on code_lowering and fix this.

os << "Tensor<float, ";
break;
case (DataType::Int):
os << "Tensor<int> T";
os << "Tensor<int, ";
break;
default:
TORCH_CHECK(
false,
"CodeWrite::header() found an input to the fusion of unexpected val type.");
}

os << static_cast<const TensorView*>(val)->getRootDomain()->size() << "> T";
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, nice usage.

break;
}
case (ValType::Scalar):
switch (val->getDataType().value()) {
case (DataType::Float):
Expand Down
Loading