diff --git a/test/cpp/jit/test_gpu.cpp b/test/cpp/jit/test_gpu.cpp index 4c8482df3780..d335e0ae654d 100644 --- a/test/cpp/jit/test_gpu.cpp +++ b/test/cpp/jit/test_gpu.cpp @@ -2835,6 +2835,104 @@ void testGPU_FusionSimpleBCast() { TORCH_CHECK(t4.allclose(cg_output)); } + + { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + std::vector dom; + dom.push_back(new IterDomain( + new Int(0), + new Int(1), + ParallelType::Serial, + IterType::BroadcastWithStride)); + dom.push_back(new IterDomain(new Int(0), new Int())); + TensorView* tv0 = new TensorView(new TensorDomain(dom), DataType::Float); + + TensorView* tv1 = makeDummyTensor(3); + fusion.addInput(tv0); + fusion.addInput(tv1); + + TensorView* tv2 = add(tv0, tv1); + + tv2->merge(0); + tv2->merge(0); + + fusion.addOutput(tv2); + + tv0->computeAt(tv2, -1); + tv1->computeAt(tv2, -1); + + tv2->axis(0)->parallelize(ParallelType::BIDx); + + constexpr int x = 63, y = 33, z = 15; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({1, z}, options); + at::Tensor t1 = at::randn({x, y, z}, options); + + at::Tensor cg_output = at::empty({x, y, z}, options); + + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + fe.runFusion({t0, t1}, {cg_output}); + + auto t2 = t0.add(t1); + + TORCH_CHECK(t2.allclose(cg_output)); + } + + // Disabling the tests because of failing test on broadcast indexing + // { + // Fusion fusion; + // FusionGuard fg(&fusion); + + // // Set up your input tensor views + // std::vector dom; + // dom.push_back(new IterDomain(new Int(0), new Int())); + // dom.push_back(new IterDomain( + // new Int(0), + // new Int(1), + // ParallelType::Serial, + // IterType::BroadcastWithStride)); + // TensorView* tv0 = new TensorView(new TensorDomain(dom), DataType::Float); + + // TensorView* tv1 = makeDummyTensor(3); + // fusion.addInput(tv0); + // fusion.addInput(tv1); + + // TensorView* tv2 = add(tv0, tv1); + + // tv2->merge(0); + // tv2->merge(0); + + // fusion.addOutput(tv2); + + // tv0->computeAt(tv2, -1); + // tv1->computeAt(tv2, -1); + + // tv2->axis(0)->parallelize(ParallelType::BIDx); + + // constexpr int x = 63, y = 33, z = 15; + + // auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, + // 0); + + // at::Tensor t0 = at::randn({y, 1}, options); + // at::Tensor t1 = at::randn({x, y, z}, options); + + // at::Tensor cg_output = at::empty({x, y, z}, options); + + // torch::jit::fuser::cuda::FusionExecutor fe; + // fe.compileFusion(&fusion); + // fe.runFusion({t0, t1}, {cg_output}); + + // auto t2 = t0.add(t1); + + // TORCH_CHECK(t2.allclose(cg_output)); + // } } // Test a simple Gemm but also play around with fusion executor features diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp index 816bdeb17076..25c377feabc2 100644 --- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp @@ -200,6 +200,10 @@ BroadcastOp::BroadcastOp(Val* _out, Val* _in) if (!dom->isBroadcast()) ndims++; + for (auto dom : in()->as()->getRootDomain()) + if (dom->isBroadcast()) + ndims++; + TORCH_INTERNAL_ASSERT( ndims == (int)TensorDomain::noReductions(