diff --git a/test/cpp/jit/test_gpu.cpp b/test/cpp/jit/test_gpu.cpp index dadce1dc99e6..7ca697c50020 100644 --- a/test/cpp/jit/test_gpu.cpp +++ b/test/cpp/jit/test_gpu.cpp @@ -1277,15 +1277,13 @@ int ceilDiv_(int a, int b) { void testGPU_FusionAdvancedComputeAt() { // Case 1 - /* - * tv1 = tv0 * 0.5 - * tv2 = tv1 * -1 - * tv3 = tv1 + 3 - * tv4 = tv1 * 2 - * tv5 = tv3 + tv2 - * tv6 = tv5 + tv4 - * tv7 = tv1 + tv4 - */ + // tv1 = tv0 * 0.5 + // tv2 = tv1 * -1 + // tv3 = tv1 + 3 + // tv4 = tv1 * 2 + // tv5 = tv3 + tv2 + // tv6 = tv5 + tv4 + // tv7 = tv1 + tv4 { Fusion fusion; FusionGuard fg(&fusion); @@ -1355,14 +1353,12 @@ void testGPU_FusionAdvancedComputeAt() { } // Case 2 - /* - * tv1 = tv0 * -1 - * tv2 = tv0 + 3 - * tv3 = tv0 * 2 - * tv4 = tv2 + tv1 - * tv5 = tv4 + tv3 - * tv6 = tv5 + tv3 - */ + // tv1 = tv0 * -1 + // tv2 = tv0 + 3 + // tv3 = tv0 * 2 + // tv4 = tv2 + tv1 + // tv5 = tv4 + tv3 + // tv6 = tv5 + tv3 { Fusion fusion; FusionGuard fg(&fusion); @@ -1550,6 +1546,44 @@ void testGPU_FusionAdvancedComputeAt() { TORCH_CHECK(at::allclose(outputs[0], t6), actual_kernel.str()); } + + // Case 5 + // tv2 = tv0 + 2.0 + // tv3 = tv1 * tv2 + { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeDummyTensor(2); + fusion.addInput(tv0); + TensorView* tv1 = makeDummyTensor(2); + fusion.addInput(tv1); + TensorView* tv2 = add(tv0, new Float(2.0)); + TensorView* tv3 = mul(tv1, tv2); + fusion.addOutput(tv3); + + tv3->merge(0); + tv3->split(-1, 8); + tv3->split(-1, 4); + + tv2->computeAt(tv3, 1); + tv2->split(-1, 4); // Kernel will break without this split + tv3->axis(0)->parallelize(ParallelType::BIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({63, 65}, options); + at::Tensor t1 = at::rand_like(t0, options); + + auto t2 = t0.add(2.0); + auto t3 = t1.mul(t2); + + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + auto outputs = fe.runFusion({t0, t1}); + + TORCH_CHECK(at::allclose(outputs[0], t3)); + } } void testGPU_FusionScalarInputs() { @@ -2577,7 +2611,6 @@ void testGPU_FusionReduction4() { TensorView* tv2 = tv1->rFactor({-3}); tv0->computeAt(tv1, 1); - tv1->axis(0)->parallelize(ParallelType::BIDy); for (auto* val : fusion.vals()) { @@ -2600,7 +2633,10 @@ void testGPU_FusionReduction4() { fe.runFusion({input}, {cg_output}); auto aten_output = input.sum({1}); - TORCH_CHECK(aten_output.allclose(cg_output)); + TORCH_CHECK( + aten_output.allclose(cg_output, 1e-5, 1e-7), + "Error of: ", + aten_output.sub(cg_output).abs().max()); } void testGPU_FusionReduction5() { diff --git a/torch/csrc/jit/codegen/cuda/compute_at.cpp b/torch/csrc/jit/codegen/cuda/compute_at.cpp index 55df035e2fb4..1b434916178b 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at.cpp @@ -220,9 +220,8 @@ unsigned int ComputeAt::forwardComputeAt_impl( } consumer_entry.setPassPosition(replay.second); - if ((consumer_entry.shouldSetComputeAt(replay.second) && - consumer != consumer_) || - (consumer == consumer_ && replay.second >= consumer_position_)) { + if (consumer_entry.shouldSetComputeAt(replay.second) && + consumer != consumer_) { consumer_entry.setComputeAtDomain(consumer->domain()); } diff --git a/torch/csrc/jit/codegen/cuda/kernel_resource_strings.h b/torch/csrc/jit/codegen/cuda/kernel_resource_strings.h index c88caaee63a2..cdc41ddab51c 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_resource_strings.h +++ b/torch/csrc/jit/codegen/cuda/kernel_resource_strings.h @@ -146,7 +146,7 @@ __device__ __inline__ float uniform(unsigned int x) { // Helper functions for Operations static auto code_helper_funcs = R"( -__device__ int ceilDiv(const int a, const int b) { +__device__ constexpr int ceilDiv(const int a, const int b) { return (a + b - 1) / b; } __device__ float clamp(const float x, const float minv, const float maxv) {