diff --git a/test/cpp/jit/CMakeLists.txt b/test/cpp/jit/CMakeLists.txt
index c418ee54e457..589248880178 100644
--- a/test/cpp/jit/CMakeLists.txt
+++ b/test/cpp/jit/CMakeLists.txt
@@ -98,6 +98,7 @@ if(USE_CUDA)
   list(APPEND JIT_TEST_SRCS ${JIT_TEST_ROOT}/test_gpu_fused_reduction.cpp)
   list(APPEND JIT_TEST_SRCS ${JIT_TEST_ROOT}/test_gpu_shift.cpp)
   list(APPEND JIT_TEST_SRCS ${JIT_TEST_ROOT}/test_gpu_tensorcore.cpp)
+  list(APPEND JIT_TEST_SRCS ${JIT_TEST_ROOT}/test_gpu_view.cpp)
 endif()
 
 add_executable(test_jit
diff --git a/test/cpp/jit/test_gpu.cpp b/test/cpp/jit/test_gpu.cpp
index 7bbd401d6801..5e34684e3acb 100644
--- a/test/cpp/jit/test_gpu.cpp
+++ b/test/cpp/jit/test_gpu.cpp
@@ -14716,472 +14716,6 @@ TEST_F(NVFuserTest, FusionVectorizeMisalignedStrideFail_CUDA) {
   ASSERT_ANY_THROW(fe.runFusion(aten_inputs));
 }
 
-TEST_F(NVFuserTest, FusionViewDtypeSameSizeOutput_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  std::vector<int64_t> input_shape{2, 10, 40};
-
-  TensorView* x = makeSymbolicTensor(input_shape.size(), DataType::Float);
-  TensorView* bias = makeSymbolicTensor(input_shape.size());
-  fusion.addInput(x);
-  fusion.addInput(bias);
-
-  auto x_add_bias = add(x, bias);
-  auto x_view = view(x_add_bias, DataType::Int32);
-  fusion.addOutput(x_view);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor at_x = at::randn(input_shape, options);
-  at::Tensor at_bias = at::randn(input_shape, options);
-  std::vector<IValue> aten_inputs = {at_x, at_bias};
-
-  auto lparams = schedulePointwise(&fusion, aten_inputs);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion, aten_inputs, lparams);
-  auto outputs = fe.runFusion(aten_inputs, lparams);
-
-  auto at_x_add_bias = at_x + at_bias;
-  auto at_x_view = at_x_add_bias.view(at::ScalarType::Int);
-
-  testValidate(&fusion, outputs, aten_inputs, {at_x_view}, __LINE__, __FILE__);
-}
-
-TEST_F(NVFuserTest, FusionViewDtypeFailMismatchSize_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  std::vector<int64_t> input_shape{2, 10, 40};
-
-  TensorView* x = makeSymbolicTensor(input_shape.size(), DataType::Float);
-  TensorView* bias = makeSymbolicTensor(input_shape.size());
-  fusion.addInput(x);
-  fusion.addInput(bias);
-
-  auto x_add_bias = add(x, bias);
-  ASSERT_ANY_THROW(view(x_add_bias, DataType::Int));
-}
-
-TEST_F(NVFuserTest, FusionViewOutput_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  std::vector<int64_t> input_shape{2, 10, 40};
-  std::vector<int64_t> output_shape{2, 10, 4, 10};
-
-  TensorView* x = makeSymbolicTensor(input_shape.size());
-  TensorView* bias = makeSymbolicTensor(input_shape.size());
-  fusion.addInput(x);
-  fusion.addInput(bias);
-
-  auto x_add_bias = add(x, bias);
-  auto x_view = view(x_add_bias, input_shape, output_shape);
-  fusion.addOutput(x_view);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor at_x = at::randn(input_shape, options);
-  at::Tensor at_bias = at::randn(input_shape, options);
-  std::vector<IValue> aten_inputs = {at_x, at_bias};
-
-  auto lparams = schedulePointwise(&fusion, aten_inputs);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion, aten_inputs, lparams);
-  auto outputs = fe.runFusion(aten_inputs, lparams);
-
-  auto at_x_add_bias = at_x + at_bias;
-  auto at_x_view = at::native::view(at_x_add_bias, output_shape);
-
-  testValidate(&fusion, outputs, aten_inputs, {at_x_view}, __LINE__, __FILE__);
-}
-
-TEST_F(NVFuserTest, FusionViewFailMismatchSize_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // The number of elements in input and output shapes do not match,
-  // so this view transformation is invalid.
-  // 2 * 10 * 40 != 2 * 50 * 4 * 10
-
-  std::vector<int64_t> input_shape{2, 10, 40};
-  std::vector<int64_t> output_shape{2, 50, 4, 10};
-
-  TensorView* x = makeSymbolicTensor(input_shape.size());
-  TensorView* bias = makeSymbolicTensor(input_shape.size());
-  fusion.addInput(x);
-  fusion.addInput(bias);
-
-  auto x_add_bias = add(x, bias);
-  ASSERT_ANY_THROW(view(x_add_bias, input_shape, output_shape));
-}
-
-TEST_F(NVFuserTest, FusionViewFailMulitDimInference_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Only one dimension can be inferred in the output shape.
-  // Otherwise, the size of the dimensions is ambiguous.
-  std::vector<int64_t> input_shape{2, 10, 40};
-  std::vector<int64_t> output_shape{2, -1, 4, -1};
-
-  TensorView* x = makeSymbolicTensor(input_shape.size());
-  TensorView* bias = makeSymbolicTensor(input_shape.size());
-  fusion.addInput(x);
-  fusion.addInput(bias);
-
-  auto x_add_bias = add(x, bias);
-  ASSERT_ANY_THROW(view(x_add_bias, input_shape, output_shape));
-}
-
-void reductionViewAddFusion(
-    std::vector<int64_t>& input_shape,
-    std::vector<int64_t>& output_shape,
-    bool view_before_reduction) {
-  constexpr int kReductionAxis = -1;
-
-  // Drop size for reduction axis from view_shape
-  std::vector<int64_t> view_shape;
-  {
-    const auto kAxis = (kReductionAxis < 0)
-        ? (kReductionAxis + input_shape.size())
-        : kReductionAxis;
-    for (auto i : c10::irange(input_shape.size())) {
-      if (view_before_reduction || i != kAxis) {
-        view_shape.push_back(input_shape[i]);
-      }
-    }
-  }
-
-  auto bias_shape = (view_before_reduction) ? input_shape : output_shape;
-  for (auto has_implicit_broadcast : {false, true}) {
-    std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
-    Fusion& fusion = *fusion_ptr.get();
-    FusionGuard fg(&fusion);
-
-    TensorView* x = (has_implicit_broadcast)
-        ? makeConcreteTensor(input_shape)
-        : makeSymbolicTensor(input_shape.size());
-    TensorView* bias = (has_implicit_broadcast)
-        ? makeConcreteTensor(bias_shape)
-        : makeSymbolicTensor(bias_shape.size());
-    fusion.addInput(x);
-    fusion.addInput(bias);
-
-    auto tv1 =
-        (view_before_reduction) ? add(x, bias) : sum(x, {kReductionAxis});
-    auto x_view = view(tv1, view_shape, output_shape);
-    auto y = (view_before_reduction) ? sum(x_view, {kReductionAxis})
-                                     : add(x_view, bias);
-    fusion.addOutput(y);
-
-    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-    at::Tensor at_x = at::randn(input_shape, options);
-    at::Tensor at_bias = at::randn(bias_shape, options);
-    std::vector<IValue> aten_inputs = {at_x, at_bias};
-
-    FusionExecutorCache fusion_executor_cache(std::move(fusion_ptr));
-    auto outputs = fusion_executor_cache.runFusionWithInputs(aten_inputs);
-
-    auto at_tv1 = (view_before_reduction) ? (at_x + at_bias)
-                                          : at::sum(at_x, kReductionAxis);
-    auto at_x_view = at::native::view(at_tv1, output_shape);
-    auto at_y = (view_before_reduction) ? at::sum(at_x_view, kReductionAxis)
-                                        : at::add(at_x_view, at_bias);
-
-    testValidate(&fusion, outputs, aten_inputs, {at_y}, __LINE__, __FILE__);
-  }
-}
-
-TEST_F(NVFuserTest, FusionViewReductionShmoo_CUDA) {
-  typedef std::vector<int64_t> shape;
-  typedef std::pair<shape, shape> view_example;
-
-  std::vector<view_example> view_before_examples = {
-      {{19, 12, 7, 99}, {19, 3, 2772}},
-      {{1, 19, 1, 12, 7, 1, 99}, {1, 19, 1, 3, 2772}},
-      // Incorrect Result - Broadcast Issue - Pointwise
-      // {{3, 17, 80, 1}, {51, 2, 4, 1, 10}},
-      // {{3, 17, 80, 1, 9}, {51, 2, 4, 1, 10, 9}},
-      {{2, 3, 4, 5}, {1, 6, 1, 2, 2, 5, 1}},
-      {{22, 22, 2}, {22, 11, 1, 1, 4}},
-      {{37, 9, 7, 6, 10}, {333, 2, 2, 3, 35}},
-      {{1, 1, 333, 1}, {1, 1, 333, 1}},
-      {{8, 1, 1, 8, 1, 8}, {8, 2, 4, 1, 8}},
-      {{1, 333, 1}, {1, 37, 9, 1}},
-      {{1, 333}, {1, 1, 1, 111, 1, 3}},
-      {{22, 1, 22, 1}, {484}},
-      {{1, 333, 1}, {333}},
-      // Incorrect Result - Broadcast Issue - Reduction
-      {{1, 27454, 1, 2}, {1, 7844, 1, 7}},
-      {{1, 7844, 1, 7}, {1, 27454, 2}}};
-
-  for (auto e : view_before_examples) {
-    reductionViewAddFusion(e.first, e.second, true /* view_before_reduction */);
-  }
-
-  std::vector<view_example> view_after_examples = {
-      {{19, 12, 7, 99}, {19, 3, 28}},
-      {{1, 19, 1, 12, 7, 1, 99}, {1, 19, 1, 3, 28}},
-      {{3, 17, 80, 1}, {51, 1, 2, 4, 10}},
-      {{3, 17, 80, 1, 9}, {51, 1, 2, 4, 10}},
-      {{2, 3, 4, 5}, {1, 6, 1, 2, 2, 1}},
-      {{22, 22, 2}, {22, 11, 1, 1, 2}},
-      {{37, 9, 7, 6, 10}, {333, 2, 21}},
-      {{1, 1, 333, 1}, {1, 1, 333, 1}},
-      {{8, 1, 1, 8, 1, 8}, {8, 2, 4, 1}},
-      {{1, 333, 1}, {1, 37, 9, 1}},
-      {{22, 1, 22, 1}, {484}},
-      {{1, 333, 1}, {333}},
-      {{1, 27454, 1, 2}, {1, 3922, 1, 7}},
-      {{1, 7844, 1, 7}, {1, 1961, 4}}};
-
-  for (auto e : view_after_examples) {
-    reductionViewAddFusion(
-        e.first, e.second, false /* view_before_reduction */);
-  }
-}
-
-TEST_F(NVFuserTest, FusionViewFailPersistent_CUDA) {
-  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
-  Fusion& fusion = *fusion_ptr.get();
-  FusionGuard fg(&fusion);
-
-  // View is only supported by the pointwise scheduler,
-  // so it should fail with any persistent normalization operations
-  std::vector<int64_t> input_shape{2, 10, 40};
-  std::vector<int64_t> output_shape{2, 10, 2, 20};
-
-  TensorView* x = makeSymbolicTensor(input_shape.size());
-  TensorView* bias = makeSymbolicTensor(input_shape.size());
-  fusion.addInput(x);
-  fusion.addInput(bias);
-
-  auto x_add_bias = add(x, bias);
-  auto x_view = view(x_add_bias, input_shape, output_shape);
-  auto x_softmax = softmax(x_view, -1);
-
-  fusion.addOutput(x_softmax);
-
-  const auto options =
-      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
-  at::Tensor at_x = at::randn(input_shape, options);
-  at::Tensor at_bias = at::randn(input_shape, options);
-
-  FusionExecutorCache fusion_executor_cache(std::move(fusion_ptr));
-  ASSERT_ANY_THROW(fusion_executor_cache.runFusionWithInputs({at_x, at_bias}));
-}
-
-void addViewGeluFusion(
-    std::vector<int64_t>& input_shape,
-    std::vector<int64_t>& output_shape) {
-  for (auto has_implicit_broadcast : {false, true}) {
-    Fusion fusion;
-    FusionGuard fg(&fusion);
-
-    TensorView* x = (has_implicit_broadcast)
-        ? makeConcreteTensor(input_shape)
-        : makeSymbolicTensor(input_shape.size());
-    TensorView* bias = (has_implicit_broadcast)
-        ? makeConcreteTensor(input_shape)
-        : makeSymbolicTensor(input_shape.size());
-    fusion.addInput(x);
-    fusion.addInput(bias);
-
-    auto x_add_bias = add(x, bias);
-    auto x_view = view(x_add_bias, input_shape, output_shape);
-    auto y = gelu(x_view);
-    fusion.addOutput(y);
-
-    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-    at::Tensor at_x = at::randn(input_shape, options);
-    at::Tensor at_bias = at::randn(input_shape, options);
-    std::vector<IValue> aten_inputs = {at_x, at_bias};
-
-    auto lparams = schedulePointwise(&fusion, aten_inputs);
-
-    FusionExecutor fe;
-    fe.compileFusion(&fusion, aten_inputs, lparams);
-    auto outputs = fe.runFusion(aten_inputs, lparams);
-
-    auto at_x_add_bias = at_x + at_bias;
-    auto at_x_view = at::native::view(at_x_add_bias, output_shape);
-    auto at_y = at::gelu(at_x_view);
-
-    testValidate(&fusion, outputs, aten_inputs, {at_y}, __LINE__, __FILE__);
-  }
-}
-
-TEST_F(NVFuserTest, FusionViewSplit_CUDA) {
-  std::vector<int64_t> input_shape{80};
-  std::vector<int64_t> output_shape{2, 4, 10};
-  addViewGeluFusion(input_shape, output_shape);
-}
-
-TEST_F(NVFuserTest, FusionViewBroadcast_CUDA) {
-  std::vector<int64_t> input_shape{80};
-  std::vector<int64_t> output_shape{1, 80};
-  addViewGeluFusion(input_shape, output_shape);
-}
-
-TEST_F(NVFuserTest, FusionViewMerge_CUDA) {
-  std::vector<int64_t> input_shape{2, 40, 7};
-  std::vector<int64_t> output_shape{560};
-  addViewGeluFusion(input_shape, output_shape);
-}
-
-TEST_F(NVFuserTest, FusionViewAllShmoo_CUDA) {
-  typedef std::vector<int64_t> shape;
-  typedef std::pair<shape, shape> view_example;
-
-  std::vector<view_example> examples = {
-      {{1, 19, 1, 12, 7, 1, 99}, {1, 19, 1, 3, 2772}},
-      {{3, 17, 80, 1}, {51, 1, 2, 4, 10}},
-      {{3, 17, 80, 1, 9}, {51, 1, 2, 4, 10, 9}},
-      {{2, 3, 4, 5}, {1, 6, 1, 2, 2, 5, 1}},
-      {{22, 22, 2}, {22, 11, 1, 1, 4}},
-      {{37, 9, 7, 6, 10}, {333, 2, 2, 3, 35}},
-      {{1, 1, 333, 1}, {1, 1, 333, 1}},
-      {{8, 1, 1, 8, 1, 8}, {8, 2, 4, 1, 8}},
-      {{1, 333, 1}, {1, 37, 9, 1}},
-      {{1, 333}, {1, 1, 1, 111, 1, 3}},
-      {{22, 1, 22, 1}, {484}},
-      {{1, 333, 1}, {333}},
-      {{1, 27454, 1, 2}, {1, 7844, 1, 7}},
-      {{1, 7844, 1, 7}, {1, 27454, 2}}};
-
-  for (auto e : examples) {
-    addViewGeluFusion(e.first, e.second);
-  }
-}
-
-TEST_F(NVFuserTest, FusionViewInferShmoo_CUDA) {
-  typedef std::vector<int64_t> shape;
-  typedef std::pair<shape, shape> view_example;
-
-  std::vector<view_example> examples = {
-      {{1, 19, 1, 12, 7, 1, 99}, {1, 19, -1, 3, 2772}},
-      {{3, 17, 80, 1}, {51, 1, 2, 4, -1}},
-      {{3, 17, 80, 1, 9}, {-1, 1, 2, 4, 10, 9}},
-      {{2, 3, 4, 5}, {1, 6, 1, -1, 2, 5, 1}},
-      {{22, 22, 2}, {22, -1, 1, 1, 4}},
-      {{37, 9, 7, 6, 10}, {333, 2, -1, 3, 35}},
-      {{1, 1, 333, 1}, {1, 1, -1, 1}},
-      {{8, 1, 1, 8, 1, 8}, {8, 2, 4, 1, -1}},
-      {{1, 333, 1}, {1, 37, -1, 1}},
-      {{1, 333}, {1, 1, 1, -1, 1, 3}},
-      {{22, 1, 22, 1}, {-1}},
-      {{1, 333, 1}, {-1}},
-      {{1, 27454, 1, 2}, {1, 7844, 1, -1}},
-      {{1, 7844, 1, 7}, {1, -1, 2}}};
-
-  for (auto e : examples) {
-    addViewGeluFusion(e.first, e.second);
-  }
-}
-
-void geluViewAddFusion(
-    std::vector<int64_t> input_shape,
-    std::vector<int64_t> output_shape) {
-  for (auto hasImplicitBroadcast : {false, true}) {
-    Fusion fusion;
-    FusionGuard fg(&fusion);
-
-    TensorView* x = (hasImplicitBroadcast)
-        ? makeConcreteTensor(input_shape)
-        : makeSymbolicTensor(input_shape.size());
-    TensorView* bias = (hasImplicitBroadcast)
-        ? makeConcreteTensor(output_shape)
-        : makeSymbolicTensor(output_shape.size());
-    fusion.addInput(x);
-    fusion.addInput(bias);
-
-    auto x_gelu = gelu(x);
-    auto x_view = view(x_gelu, input_shape, output_shape);
-    auto y = add(x_view, bias);
-    fusion.addOutput(y);
-
-    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-    at::Tensor at_x = at::randn(input_shape, options);
-    at::Tensor at_bias = at::randn(output_shape, options);
-    std::vector<IValue> aten_inputs = {at_x, at_bias};
-
-    auto lparams = schedulePointwise(&fusion, aten_inputs);
-
-    FusionExecutor fe;
-    fe.compileFusion(&fusion, aten_inputs, lparams);
-    auto outputs = fe.runFusion(aten_inputs, lparams);
-
-    auto at_x_gelu = at::gelu(at_x);
-    auto at_x_view = at::native::view(at_x_gelu, output_shape);
-    auto at_y = at_x_view + at_bias;
-
-    testValidate(&fusion, outputs, aten_inputs, {at_y}, __LINE__, __FILE__);
-  }
-}
-
-TEST_F(NVFuserTest, FusionViewStride_CUDA) {
-  typedef std::vector<int64_t> shape;
-  typedef std::pair<shape, shape> view_example;
-
-  std::vector<view_example> examples = {
-      {{1, 27454, 2}, {1, 7844, 7}},
-      {{1, 19, 1, 12, 7, 1, 99}, {1, 19, 1, 3, 2772}},
-      {{1, 7844, 1, 7}, {1, 27454, 2}}};
-
-  for (auto e : examples) {
-    geluViewAddFusion(e.first, e.second);
-  }
-}
-
-void geluViewBinaryAddFusion(
-    std::vector<int64_t> input_shape1,
-    std::vector<int64_t> input_shape2,
-    std::vector<int64_t> output_shape) {
-  for (auto hasImplicitBroadcast : {false, true}) {
-    Fusion fusion;
-    FusionGuard fg(&fusion);
-
-    TensorView* x = (hasImplicitBroadcast)
-        ? makeConcreteTensor(input_shape1)
-        : makeSymbolicTensor(input_shape1.size());
-    TensorView* bias = (hasImplicitBroadcast)
-        ? makeConcreteTensor(input_shape2)
-        : makeSymbolicTensor(input_shape2.size());
-    fusion.addInput(x);
-    fusion.addInput(bias);
-
-    auto x_gelu = gelu(x);
-    auto x_view = view(x_gelu, input_shape1, output_shape);
-    auto bias_view = view(bias, input_shape2, output_shape);
-    auto y = add(x_view, bias_view);
-    fusion.addOutput(y);
-
-    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-    at::Tensor at_x = at::randn(input_shape1, options);
-    at::Tensor at_bias = at::randn(input_shape2, options);
-    std::vector<IValue> aten_inputs = {at_x, at_bias};
-
-    auto lparams = schedulePointwise(&fusion, aten_inputs);
-
-    FusionExecutor fe;
-    fe.compileFusion(&fusion, aten_inputs, lparams);
-    auto outputs = fe.runFusion(aten_inputs, lparams);
-
-    auto at_x_gelu = at::gelu(at_x);
-    auto at_x_view = at::native::view(at_x_gelu, output_shape);
-    auto at_bias_view = at::native::view(at_bias, output_shape);
-    auto at_y = at_x_view + at_bias_view;
-
-    testValidate(&fusion, outputs, aten_inputs, {at_y}, __LINE__, __FILE__);
-  }
-}
-
-TEST_F(NVFuserTest, FusionViewBinary_CUDA) {
-  geluViewBinaryAddFusion({27454, 2}, {54908}, {7844, 7});
-}
-
 TEST_F(NVFuserTest, FusionVectorization1_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
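The two dtype tests removed above (and re-added verbatim in the new file below) pin down view's bitcast rule: reinterpreting Float data as Int32 keeps the per-element size, so the shape is unchanged, while DataType::Int is 64-bit in the fuser and fails the size check. A minimal ATen-level sketch of the same rule, standalone and not part of the patch:

  // Bitcast views require matching element sizes.
  at::Tensor f32 = at::randn({2, 10, 40});
  at::Tensor i32 = f32.view(at::ScalarType::Int); // ok: 4 bytes -> 4 bytes
  // f32.view(at::ScalarType::Long) would throw: 4 bytes -> 8 bytes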
diff --git a/test/cpp/jit/test_gpu_view.cpp b/test/cpp/jit/test_gpu_view.cpp
new file mode 100644
index 000000000000..7bc4a6576014
--- /dev/null
+++ b/test/cpp/jit/test_gpu_view.cpp
@@ -0,0 +1,628 @@
+#if defined(USE_CUDA)
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+// fuser and IR parser
+#include
+#include
+
+#include "test_gpu_validator.h"
+
+#include
+#include
+#include
+
+#include
+#include
+
+// Tests go in torch::jit
+namespace torch {
+namespace jit {
+
+using namespace torch::jit::fuser::cuda;
+using namespace at::indexing;
+
+namespace {
+
+// Make a tensor that is known to be fully contiguous of dimensionality=ndims,
+// but unknown sizes
+TensorView* makeContigTensor(size_t ndims, DataType dtype = DataType::Float) {
+  return TensorViewBuilder()
+      .ndims(ndims)
+      .dtype(dtype)
+      .contiguity(std::vector<bool>(ndims, true))
+      .build();
+}
+
+// Make a tensor that is known to be non-contiguous of dimensionality=ndims,
+// but unknown sizes
+TensorView* makeSymbolicTensor(size_t ndims, DataType dtype = DataType::Float) {
+  return TensorViewBuilder().ndims(ndims).dtype(dtype).build();
+}
+
+// Make a non-contiguous tensor of compile-time known sizes
+TensorView* makeConcreteTensor(
+    std::vector<int64_t> shape,
+    DataType dtype = DataType::Float) {
+  return TensorViewBuilder().shape(shape).dtype(dtype).build();
+}
+
+} // namespace
+
+TEST_F(NVFuserTest, FusionViewDtypeSameSizeOutput_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  std::vector<int64_t> input_shape{2, 10, 40};
+
+  TensorView* x = makeSymbolicTensor(input_shape.size(), DataType::Float);
+  TensorView* bias = makeSymbolicTensor(input_shape.size());
+  fusion.addInput(x);
+  fusion.addInput(bias);
+
+  auto x_add_bias = add(x, bias);
+  auto x_view = view(x_add_bias, DataType::Int32);
+  fusion.addOutput(x_view);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor at_x = at::randn(input_shape, options);
+  at::Tensor at_bias = at::randn(input_shape, options);
+  std::vector<IValue> aten_inputs = {at_x, at_bias};
+
+  auto lparams = schedulePointwise(&fusion, aten_inputs);
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion, aten_inputs, lparams);
+  auto outputs = fe.runFusion(aten_inputs, lparams);
+
+  auto at_x_add_bias = at_x + at_bias;
+  auto at_x_view = at_x_add_bias.view(at::ScalarType::Int);
+
+  testValidate(&fusion, outputs, aten_inputs, {at_x_view}, __LINE__, __FILE__);
+}
+
+TEST_F(NVFuserTest, FusionViewDtypeFailMismatchSize_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  std::vector<int64_t> input_shape{2, 10, 40};
+
+  TensorView* x = makeSymbolicTensor(input_shape.size(), DataType::Float);
+  TensorView* bias = makeSymbolicTensor(input_shape.size());
+  fusion.addInput(x);
+  fusion.addInput(bias);
+
+  auto x_add_bias = add(x, bias);
+  ASSERT_ANY_THROW(view(x_add_bias, DataType::Int));
+}
+
+TEST_F(NVFuserTest, FusionViewRfactorExtentReplacement_CUDA) {
+  auto fusion = std::make_unique<Fusion>();
+  FusionGuard fg(fusion.get());
+
+  auto tv0 = makeSymbolicTensor(2);
+  fusion->addInput(tv0);
+  auto tv1 = makeContigTensor(2);
+  fusion->addInput(tv1);
+
+  auto tv2 = view(tv0, {12, 8}, {4, 3, 8});
+  auto tv3 = sum(tv2, {-1});
+  auto tv4 = add(tv3, IrBuilder::create<Double>(1));
+  auto tv5 = add(tv1, tv4);
+  fusion->addOutput(tv5);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::manual_seed(0);
+  auto t0 = at::randn({12, 8}, options);
+  auto t1 = at::randn({4, 3}, options);
+
+  FusionExecutorCache executor_cache(std::move(fusion));
+  auto cg_outputs = executor_cache.runFusionWithInputs({t0, t1});
+
+  auto ref = at::native::view(t0, {4, 3, 8}).sum({-1}) + 1 + t1;
+
+  testValidate(
+      executor_cache.fusion(), cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__);
+}
+
+TEST_F(NVFuserTest, FusionViewOutput_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  std::vector<int64_t> input_shape{2, 10, 40};
+  std::vector<int64_t> output_shape{2, 10, 4, 10};
+
+  TensorView* x = makeSymbolicTensor(input_shape.size());
+  TensorView* bias = makeSymbolicTensor(input_shape.size());
+  fusion.addInput(x);
+  fusion.addInput(bias);
+
+  auto x_add_bias = add(x, bias);
+  auto x_view = view(x_add_bias, input_shape, output_shape);
+  fusion.addOutput(x_view);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor at_x = at::randn(input_shape, options);
+  at::Tensor at_bias = at::randn(input_shape, options);
+  std::vector<IValue> aten_inputs = {at_x, at_bias};
+
+  auto lparams = schedulePointwise(&fusion, aten_inputs);
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion, aten_inputs, lparams);
+  auto outputs = fe.runFusion(aten_inputs, lparams);
+
+  auto at_x_add_bias = at_x + at_bias;
+  auto at_x_view = at::native::view(at_x_add_bias, output_shape);
+
+  testValidate(&fusion, outputs, aten_inputs, {at_x_view}, __LINE__, __FILE__);
+}
+
+TEST_F(NVFuserTest, FusionViewFailMismatchSize_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  // The number of elements in input and output shapes do not match,
+  // so this view transformation is invalid.
+  // 2 * 10 * 40 != 2 * 50 * 4 * 10
+
+  std::vector<int64_t> input_shape{2, 10, 40};
+  std::vector<int64_t> output_shape{2, 50, 4, 10};
+
+  TensorView* x = makeSymbolicTensor(input_shape.size());
+  TensorView* bias = makeSymbolicTensor(input_shape.size());
+  fusion.addInput(x);
+  fusion.addInput(bias);
+
+  auto x_add_bias = add(x, bias);
+  ASSERT_ANY_THROW(view(x_add_bias, input_shape, output_shape));
+}
+
+TEST_F(NVFuserTest, FusionViewFailMulitDimInference_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  // Only one dimension can be inferred in the output shape.
+  // Otherwise, the size of the dimensions is ambiguous.
+  std::vector<int64_t> input_shape{2, 10, 40};
+  std::vector<int64_t> output_shape{2, -1, 4, -1};
+
+  TensorView* x = makeSymbolicTensor(input_shape.size());
+  TensorView* bias = makeSymbolicTensor(input_shape.size());
+  fusion.addInput(x);
+  fusion.addInput(bias);
+
+  auto x_add_bias = add(x, bias);
+  ASSERT_ANY_THROW(view(x_add_bias, input_shape, output_shape));
+}
+
+void reductionViewAddFusion(
+    std::vector<int64_t>& input_shape,
+    std::vector<int64_t>& output_shape,
+    bool view_before_reduction) {
+  constexpr int kReductionAxis = -1;
+
+  // Drop size for reduction axis from view_shape
+  std::vector<int64_t> view_shape;
+  {
+    const auto kAxis = (kReductionAxis < 0)
+        ? (kReductionAxis + input_shape.size())
+        : kReductionAxis;
+    for (auto i : c10::irange(input_shape.size())) {
+      if (view_before_reduction || i != kAxis) {
+        view_shape.push_back(input_shape[i]);
+      }
+    }
+  }
+
+  auto bias_shape = (view_before_reduction) ? input_shape : output_shape;
+  for (auto has_implicit_broadcast : {false, true}) {
+    std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
+    Fusion& fusion = *fusion_ptr.get();
+    FusionGuard fg(&fusion);
+
+    TensorView* x = (has_implicit_broadcast)
+        ? makeConcreteTensor(input_shape)
+        : makeSymbolicTensor(input_shape.size());
+    TensorView* bias = (has_implicit_broadcast)
+        ? makeConcreteTensor(bias_shape)
+        : makeSymbolicTensor(bias_shape.size());
+    fusion.addInput(x);
+    fusion.addInput(bias);
+
+    auto tv1 =
+        (view_before_reduction) ? add(x, bias) : sum(x, {kReductionAxis});
+    auto x_view = view(tv1, view_shape, output_shape);
+    auto y = (view_before_reduction) ? sum(x_view, {kReductionAxis})
+                                     : add(x_view, bias);
+    fusion.addOutput(y);
+
+    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+    at::Tensor at_x = at::randn(input_shape, options);
+    at::Tensor at_bias = at::randn(bias_shape, options);
+    std::vector<IValue> aten_inputs = {at_x, at_bias};
+
+    FusionExecutorCache fusion_executor_cache(std::move(fusion_ptr));
+    auto outputs = fusion_executor_cache.runFusionWithInputs(aten_inputs);
+
+    auto at_tv1 = (view_before_reduction) ? (at_x + at_bias)
+                                          : at::sum(at_x, kReductionAxis);
+    auto at_x_view = at::native::view(at_tv1, output_shape);
+    auto at_y = (view_before_reduction) ? at::sum(at_x_view, kReductionAxis)
+                                        : at::add(at_x_view, at_bias);
+
+    testValidate(&fusion, outputs, aten_inputs, {at_y}, __LINE__, __FILE__);
+  }
+}
+
+TEST_F(NVFuserTest, FusionViewReductionShmoo_CUDA) {
+  typedef std::vector<int64_t> shape;
+  typedef std::pair<shape, shape> view_example;
+
+  std::vector<view_example> view_before_examples = {
+      {{19, 12, 7, 99}, {19, 3, 2772}},
+      {{1, 19, 1, 12, 7, 1, 99}, {1, 19, 1, 3, 2772}},
+      // Incorrect Result - Broadcast Issue - Pointwise
+      // {{3, 17, 80, 1}, {51, 2, 4, 1, 10}},
+      // {{3, 17, 80, 1, 9}, {51, 2, 4, 1, 10, 9}},
+      {{2, 3, 4, 5}, {1, 6, 1, 2, 2, 5, 1}},
+      {{22, 22, 2}, {22, 11, 1, 1, 4}},
+      {{37, 9, 7, 6, 10}, {333, 2, 2, 3, 35}},
+      {{1, 1, 333, 1}, {1, 1, 333, 1}},
+      {{8, 1, 1, 8, 1, 8}, {8, 2, 4, 1, 8}},
+      {{1, 333, 1}, {1, 37, 9, 1}},
+      {{1, 333}, {1, 1, 1, 111, 1, 3}},
+      {{22, 1, 22, 1}, {484}},
+      {{1, 333, 1}, {333}},
+      // Incorrect Result - Broadcast Issue - Reduction
+      {{1, 27454, 1, 2}, {1, 7844, 1, 7}},
+      {{1, 7844, 1, 7}, {1, 27454, 2}}};
+
+  for (auto e : view_before_examples) {
+    reductionViewAddFusion(e.first, e.second, true /* view_before_reduction */);
+  }
+
+  std::vector<view_example> view_after_examples = {
+      {{19, 12, 7, 99}, {19, 3, 28}},
+      {{1, 19, 1, 12, 7, 1, 99}, {1, 19, 1, 3, 28}},
+      {{3, 17, 80, 1}, {51, 1, 2, 4, 10}},
+      {{3, 17, 80, 1, 9}, {51, 1, 2, 4, 10}},
+      {{2, 3, 4, 5}, {1, 6, 1, 2, 2, 1}},
+      {{22, 22, 2}, {22, 11, 1, 1, 2}},
+      {{37, 9, 7, 6, 10}, {333, 2, 21}},
+      {{1, 1, 333, 1}, {1, 1, 333, 1}},
+      {{8, 1, 1, 8, 1, 8}, {8, 2, 4, 1}},
+      {{1, 333, 1}, {1, 37, 9, 1}},
+      {{22, 1, 22, 1}, {484}},
+      {{1, 333, 1}, {333}},
+      {{1, 27454, 1, 2}, {1, 3922, 1, 7}},
+      {{1, 7844, 1, 7}, {1, 1961, 4}}};
+
+  for (auto e : view_after_examples) {
+    reductionViewAddFusion(
+        e.first, e.second, false /* view_before_reduction */);
+  }
+}
+
+void persistentViewAddFusion(
+    std::vector<int64_t>& input_shape,
+    std::vector<int64_t>& output_shape,
+    bool view_before_persistent) {
+  constexpr int kAxis = -1;
+
+  auto bias_shape = (view_before_persistent) ? input_shape : output_shape;
+  for (auto has_implicit_broadcast : {false, true}) {
+    std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
+    Fusion& fusion = *fusion_ptr.get();
+    FusionGuard fg(&fusion);
+
+    TensorView* x = (has_implicit_broadcast)
+        ? makeConcreteTensor(input_shape)
+        : makeSymbolicTensor(input_shape.size());
+    TensorView* bias = (has_implicit_broadcast)
+        ? makeConcreteTensor(bias_shape)
+        : makeSymbolicTensor(bias_shape.size());
+    fusion.addInput(x);
+    fusion.addInput(bias);
+
+    auto tv1 = (view_before_persistent) ? add(x, bias) : softmax(x, kAxis);
+    auto x_view = view(tv1, input_shape, output_shape);
+    auto y =
+        (view_before_persistent) ? softmax(x_view, kAxis) : add(x_view, bias);
+    fusion.addOutput(y);
+
+    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+    at::Tensor at_x = at::randn(input_shape, options);
+    at::Tensor at_bias = at::randn(bias_shape, options);
+    std::vector<IValue> aten_inputs = {at_x, at_bias};
+
+    FusionExecutorCache fusion_executor_cache(std::move(fusion_ptr));
+    auto outputs = fusion_executor_cache.runFusionWithInputs(aten_inputs);
+
+    auto at_tv1 = (view_before_persistent)
+        ? (at_x + at_bias)
+        : at::_softmax(at_x, kAxis, false /* half_to_float */);
+    auto at_x_view = at::native::view(at_tv1, output_shape);
+    auto at_y = (view_before_persistent)
+        ? at::_softmax(at_x_view, kAxis, false /* half_to_float */)
+        : at::add(at_x_view, at_bias);
+
+    testValidate(&fusion, outputs, aten_inputs, {at_y}, __LINE__, __FILE__);
+  }
+}
+
+TEST_F(NVFuserTest, FusionViewPersistentShmoo_CUDA) {
+  typedef std::vector<int64_t> shape;
+  typedef std::pair<shape, shape> view_example;
+
+  std::vector<view_example> view_examples = {
+      {{19, 12, 7, 99}, {19, 3, 2772}},
+      {{1, 19, 1, 12, 7, 1, 99}, {1, 19, 1, 3, 2772}},
+      // Incorrect Result - Broadcast Issue - Pointwise
+      // {{3, 17, 80, 1}, {51, 2, 4, 1, 10}},
+      // {{3, 17, 80, 1, 9}, {51, 2, 4, 1, 10, 9}},
+      {{2, 3, 4, 5}, {1, 6, 1, 2, 2, 5, 1}},
+      {{22, 22, 2}, {22, 11, 1, 1, 4}},
+      {{37, 9, 7, 6, 10}, {333, 2, 2, 3, 35}},
+      {{1, 1, 333, 1}, {1, 1, 333, 1}},
+      {{8, 1, 1, 8, 1, 8}, {8, 2, 4, 1, 8}},
+      {{1, 333, 1}, {1, 37, 9, 1}},
+      {{1, 333}, {1, 1, 1, 111, 1, 3}},
+      {{22, 1, 22, 1}, {484}},
+      {{1, 333, 1}, {333}},
+      // Incorrect Result - Broadcast Issue - Reduction
+      {{1, 27454, 1, 2}, {1, 7844, 1, 7}},
+      {{1, 7844, 1, 7}, {1, 27454, 2}}};
+
+  for (auto e : view_examples) {
+    persistentViewAddFusion(
+        e.first, e.second, true /* view_before_persistent */);
+  }
+
+  // Disabled: How to select post-view concrete ID?
+  // for (auto e : view_examples) {
+  //   persistentViewAddFusion(
+  //       e.first, e.second, false /* view_before_persistent */);
+  // }
+}
+
+void addViewGeluFusion(
+    std::vector<int64_t>& input_shape,
+    std::vector<int64_t>& output_shape) {
+  for (auto has_implicit_broadcast : {false, true}) {
+    Fusion fusion;
+    FusionGuard fg(&fusion);
+
+    TensorView* x = (has_implicit_broadcast)
+        ? makeConcreteTensor(input_shape)
+        : makeSymbolicTensor(input_shape.size());
+    TensorView* bias = (has_implicit_broadcast)
+        ? makeConcreteTensor(input_shape)
+        : makeSymbolicTensor(input_shape.size());
+    fusion.addInput(x);
+    fusion.addInput(bias);
+
+    auto x_add_bias = add(x, bias);
+    auto x_view = view(x_add_bias, input_shape, output_shape);
+    auto y = gelu(x_view);
+    fusion.addOutput(y);
+
+    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+    at::Tensor at_x = at::randn(input_shape, options);
+    at::Tensor at_bias = at::randn(input_shape, options);
+    std::vector<IValue> aten_inputs = {at_x, at_bias};
+
+    auto lparams = schedulePointwise(&fusion, aten_inputs);
+
+    FusionExecutor fe;
+    fe.compileFusion(&fusion, aten_inputs, lparams);
+    auto outputs = fe.runFusion(aten_inputs, lparams);
+
+    auto at_x_add_bias = at_x + at_bias;
+    auto at_x_view = at::native::view(at_x_add_bias, output_shape);
+    auto at_y = at::gelu(at_x_view);
+
+    testValidate(&fusion, outputs, aten_inputs, {at_y}, __LINE__, __FILE__);
+  }
+}
+
+TEST_F(NVFuserTest, FusionViewSplit_CUDA) {
+  std::vector<int64_t> input_shape{80};
+  std::vector<int64_t> output_shape{2, 4, 10};
+  addViewGeluFusion(input_shape, output_shape);
+}
+
+TEST_F(NVFuserTest, FusionViewBroadcast_CUDA) {
+  std::vector<int64_t> input_shape{80};
+  std::vector<int64_t> output_shape{1, 80};
+  addViewGeluFusion(input_shape, output_shape);
+}
+
+TEST_F(NVFuserTest, FusionViewMerge_CUDA) {
+  std::vector<int64_t> input_shape{2, 40, 7};
+  std::vector<int64_t> output_shape{560};
+  addViewGeluFusion(input_shape, output_shape);
+}
+
+TEST_F(NVFuserTest, FusionViewAllShmoo_CUDA) {
+  typedef std::vector<int64_t> shape;
+  typedef std::pair<shape, shape> view_example;
+
+  std::vector<view_example> examples = {
+      {{1, 19, 1, 12, 7, 1, 99}, {1, 19, 1, 3, 2772}},
+      {{3, 17, 80, 1}, {51, 1, 2, 4, 10}},
+      {{3, 17, 80, 1, 9}, {51, 1, 2, 4, 10, 9}},
+      {{2, 3, 4, 5}, {1, 6, 1, 2, 2, 5, 1}},
+      {{22, 22, 2}, {22, 11, 1, 1, 4}},
+      {{37, 9, 7, 6, 10}, {333, 2, 2, 3, 35}},
+      {{1, 1, 333, 1}, {1, 1, 333, 1}},
+      {{8, 1, 1, 8, 1, 8}, {8, 2, 4, 1, 8}},
+      {{1, 333, 1}, {1, 37, 9, 1}},
+      {{1, 333}, {1, 1, 1, 111, 1, 3}},
+      {{22, 1, 22, 1}, {484}},
+      {{1, 333, 1}, {333}},
+      {{1, 27454, 1, 2}, {1, 7844, 1, 7}},
+      {{1, 7844, 1, 7}, {1, 27454, 2}}};
+
+  for (auto e : examples) {
+    addViewGeluFusion(e.first, e.second);
+  }
+}
+
+TEST_F(NVFuserTest, FusionViewInferShmoo_CUDA) {
+  typedef std::vector<int64_t> shape;
+  typedef std::pair<shape, shape> view_example;
+
+  std::vector<view_example> examples = {
+      {{1, 19, 1, 12, 7, 1, 99}, {1, 19, -1, 3, 2772}},
+      {{3, 17, 80, 1}, {51, 1, 2, 4, -1}},
+      {{3, 17, 80, 1, 9}, {-1, 1, 2, 4, 10, 9}},
+      {{2, 3, 4, 5}, {1, 6, 1, -1, 2, 5, 1}},
+      {{22, 22, 2}, {22, -1, 1, 1, 4}},
+      {{37, 9, 7, 6, 10}, {333, 2, -1, 3, 35}},
+      {{1, 1, 333, 1}, {1, 1, -1, 1}},
+      {{8, 1, 1, 8, 1, 8}, {8, 2, 4, 1, -1}},
+      {{1, 333, 1}, {1, 37, -1, 1}},
+      {{1, 333}, {1, 1, 1, -1, 1, 3}},
+      {{22, 1, 22, 1}, {-1}},
+      {{1, 333, 1}, {-1}},
+      {{1, 27454, 1, 2}, {1, 7844, 1, -1}},
+      {{1, 7844, 1, 7}, {1, -1, 2}}};
+
+  for (auto e : examples) {
+    addViewGeluFusion(e.first, e.second);
+  }
+}
+
+void geluViewAddFusion(
+    std::vector<int64_t> input_shape,
+    std::vector<int64_t> output_shape) {
+  for (auto hasImplicitBroadcast : {false, true}) {
+    Fusion fusion;
+    FusionGuard fg(&fusion);
+
+    TensorView* x = (hasImplicitBroadcast)
+        ? makeConcreteTensor(input_shape)
+        : makeSymbolicTensor(input_shape.size());
+    TensorView* bias = (hasImplicitBroadcast)
+        ? makeConcreteTensor(output_shape)
+        : makeSymbolicTensor(output_shape.size());
+    fusion.addInput(x);
+    fusion.addInput(bias);
+
+    auto x_gelu = gelu(x);
+    auto x_view = view(x_gelu, input_shape, output_shape);
+    auto y = add(x_view, bias);
+    fusion.addOutput(y);
+
+    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+    at::Tensor at_x = at::randn(input_shape, options);
+    at::Tensor at_bias = at::randn(output_shape, options);
+    std::vector<IValue> aten_inputs = {at_x, at_bias};
+
+    auto lparams = schedulePointwise(&fusion, aten_inputs);
+
+    FusionExecutor fe;
+    fe.compileFusion(&fusion, aten_inputs, lparams);
+    auto outputs = fe.runFusion(aten_inputs, lparams);
+
+    auto at_x_gelu = at::gelu(at_x);
+    auto at_x_view = at::native::view(at_x_gelu, output_shape);
+    auto at_y = at_x_view + at_bias;
+
+    testValidate(&fusion, outputs, aten_inputs, {at_y}, __LINE__, __FILE__);
+  }
+}
+
+TEST_F(NVFuserTest, FusionViewStride_CUDA) {
+  typedef std::vector<int64_t> shape;
+  typedef std::pair<shape, shape> view_example;
+
+  std::vector<view_example> examples = {
+      {{1, 27454, 2}, {1, 7844, 7}},
+      {{1, 19, 1, 12, 7, 1, 99}, {1, 19, 1, 3, 2772}},
+      {{1, 7844, 1, 7}, {1, 27454, 2}}};
+
+  for (auto e : examples) {
+    geluViewAddFusion(e.first, e.second);
+  }
+}
+
+void geluViewBinaryAddFusion(
+    std::vector<int64_t> input_shape1,
+    std::vector<int64_t> input_shape2,
+    std::vector<int64_t> output_shape) {
+  for (auto hasImplicitBroadcast : {false, true}) {
+    Fusion fusion;
+    FusionGuard fg(&fusion);
+
+    TensorView* x = (hasImplicitBroadcast)
+        ? makeConcreteTensor(input_shape1)
+        : makeSymbolicTensor(input_shape1.size());
+    TensorView* bias = (hasImplicitBroadcast)
+        ? makeConcreteTensor(input_shape2)
+        : makeSymbolicTensor(input_shape2.size());
+    fusion.addInput(x);
+    fusion.addInput(bias);
+
+    auto x_gelu = gelu(x);
+    auto x_view = view(x_gelu, input_shape1, output_shape);
+    auto bias_view = view(bias, input_shape2, output_shape);
+    auto y = add(x_view, bias_view);
+    fusion.addOutput(y);
+
+    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+    at::Tensor at_x = at::randn(input_shape1, options);
+    at::Tensor at_bias = at::randn(input_shape2, options);
+    std::vector<IValue> aten_inputs = {at_x, at_bias};
+
+    auto lparams = schedulePointwise(&fusion, aten_inputs);
+
+    FusionExecutor fe;
+    fe.compileFusion(&fusion, aten_inputs, lparams);
+    auto outputs = fe.runFusion(aten_inputs, lparams);
+
+    auto at_x_gelu = at::gelu(at_x);
+    auto at_x_view = at::native::view(at_x_gelu, output_shape);
+    auto at_bias_view = at::native::view(at_bias, output_shape);
+    auto at_y = at_x_view + at_bias_view;
+
+    testValidate(&fusion, outputs, aten_inputs, {at_y}, __LINE__, __FILE__);
+  }
+}
+
+TEST_F(NVFuserTest, FusionViewBinary_CUDA) {
+  geluViewBinaryAddFusion({27454, 2}, {54908}, {7844, 7});
+}
+
+} // namespace jit
+} // namespace torch
+#endif // #if defined(USE_CUDA)
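FusionViewFailMulitDimInference in the new file rests on view's inference rule: at most one output extent may be -1, and it is recovered as the input element count divided by the product of the known output extents. A self-contained sketch of that arithmetic; the helper name is illustrative, not the fuser's actual view-analysis entry point:

  #include <cassert>
  #include <cstdint>
  #include <vector>

  // Infer a single -1 dimension in output_shape from input_shape's element
  // count. More than one -1 is ambiguous, and the known extents must divide
  // the total evenly.
  std::vector<int64_t> inferViewShape(
      const std::vector<int64_t>& input_shape,
      std::vector<int64_t> output_shape) {
    int64_t total = 1;
    for (auto s : input_shape) {
      total *= s;
    }
    int64_t known = 1;
    int64_t infer_idx = -1;
    for (size_t i = 0; i < output_shape.size(); ++i) {
      if (output_shape[i] == -1) {
        assert(infer_idx == -1 && "only one dimension can be inferred");
        infer_idx = static_cast<int64_t>(i);
      } else {
        known *= output_shape[i];
      }
    }
    if (infer_idx >= 0) {
      assert(total % known == 0 && "known extents must divide element count");
      output_shape[infer_idx] = total / known;
    } else {
      assert(total == known && "element counts must match");
    }
    return output_shape;
  }

For example, inferViewShape({2, 10, 40}, {2, -1, 4, 10}) yields {2, 10, 4, 10}, while the {2, -1, 4, -1} request in the test trips the single-inference assertion.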
diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
index d3b7dcb33764..bfc76acdfccd 100644
--- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
+++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
@@ -474,11 +474,6 @@ class TORCH_CUDA_CU_API TensorView : public Val {
 
   void setMaxProducer(unsigned int this_pos, bool decrease = false);
 
-  //! Create a new root domain and replacement TensorDomain.
-  //! If a new symbolic extent exists for the original iterDomain,
-  //! we create a new iterDomain.
-  void createReplacementDomain(const std::vector<Val*>& domain_extents);
-
  private:
   int normalizeAxisPos(int pos) const {
     if (pos < 0) {
diff --git a/torch/csrc/jit/codegen/cuda/ir_utils.cpp b/torch/csrc/jit/codegen/cuda/ir_utils.cpp
index 491f4ce6376a..6abc8cce0723 100644
--- a/torch/csrc/jit/codegen/cuda/ir_utils.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_utils.cpp
@@ -499,6 +499,73 @@ std::vector<Expr*> getReductionOps(Fusion* fusion, bool ignore_trivial) {
   return red_ops;
 }
 
+namespace {
+
+class ValReplacementMutator : private OptOutMutator {
+ public:
+  ValReplacementMutator(
+      Fusion* fusion,
+      const std::unordered_map<Val*, Val*>& replacement_map)
+      : replacement_map_(replacement_map) {
+    FusionGuard fg(fusion);
+
+    // Welford makes this a little annoying since it holds a count which is
+    // typically not used by anything else. If we don't grab that count, then
+    // it would be a tensorview that doesn't get updated extents. Therefore,
+    // first grab all leaves towards outputs and grab stmts from there.
+    auto stmts = StmtSort::getStmts(fusion, allLeafOuts(fusion), true);
+    for (auto stmt : stmts) {
+      mutate(stmt);
+    }
+  }
+
+ private:
+  using OptOutMutator::mutate;
+  void mutate(Val* val) final {
+    if (replacement_map_.find(val) == replacement_map_.end()) {
+      return OptOutMutator::mutate(val);
+    }
+    auto replaced_val = replacement_map_.at(val);
+    registerMutation(val, replaced_val);
+  }
+
+  std::vector<Val*> allLeafOuts(Fusion* fusion) {
+    auto exprs = StmtSort::getExprs(fusion, true);
+    std::unordered_set<Val*> inputs;
+    std::unordered_set<Val*> outputs;
+    std::vector<Val*> ordered_outputs;
+    for (auto expr : exprs) {
+      inputs.insert(expr->inputs().begin(), expr->inputs().end());
+      outputs.insert(expr->outputs().begin(), expr->outputs().end());
+      ordered_outputs.insert(
+          ordered_outputs.end(),
+          expr->outputs().begin(),
+          expr->outputs().end());
+    }
+    for (auto input : inputs) {
+      outputs.erase(input);
+    }
+
+    std::vector<Val*> ordered_leaf_outs;
+    for (auto out : ordered_outputs) {
+      if (outputs.find(out) != outputs.end()) {
+        ordered_leaf_outs.push_back(out);
+      }
+    }
+    return ordered_leaf_outs;
+  }
+
+  const std::unordered_map<Val*, Val*>& replacement_map_;
+};
+
+} // namespace
+
+void replaceValue(
+    Fusion* fusion,
+    const std::unordered_map<Val*, Val*>& replacement_map) {
+  ValReplacementMutator(fusion, replacement_map);
+}
+
 } // namespace ir_utils
 } // namespace cuda
 } // namespace fuser
diff --git a/torch/csrc/jit/codegen/cuda/ir_utils.h b/torch/csrc/jit/codegen/cuda/ir_utils.h
index bbebfe797138..dd5c9dd13e83 100644
--- a/torch/csrc/jit/codegen/cuda/ir_utils.h
+++ b/torch/csrc/jit/codegen/cuda/ir_utils.h
@@ -12,6 +12,11 @@ namespace fuser {
 namespace cuda {
 namespace ir_utils {
 
+// Replace values in fusion using ValReplacementMutator
+void replaceValue(
+    Fusion*,
+    const std::unordered_map<Val*, Val*>& replacement_map);
+
 template <typename FilterType, typename Iterator>
 class FilterIterator {
  public:
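The replaceValue entry point added above gives any pass a one-call way to substitute a Val everywhere it is consumed, with the Welford-aware leaf traversal handled inside ValReplacementMutator. A hedged usage sketch; the extent variables are illustrative, an active FusionGuard is assumed, and the call shape follows the declaration in ir_utils.h:

  // Retire an IterDomain's current extent in favor of a fresh symbolic one,
  // everywhere it appears in the fusion.
  std::unordered_map<Val*, Val*> replacement_map;
  Val* old_extent = id->extent(); // extent being replaced (illustrative)
  Val* new_extent = IrBuilder::create<Int>(); // fresh symbolic extent
  replacement_map.emplace(old_extent, new_extent);
  ir_utils::replaceValue(fusion, replacement_map);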
diff --git a/torch/csrc/jit/codegen/cuda/lower_replace_size.cpp b/torch/csrc/jit/codegen/cuda/lower_replace_size.cpp
index 582b6d91d067..beec550e537f 100644
--- a/torch/csrc/jit/codegen/cuda/lower_replace_size.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_replace_size.cpp
@@ -147,61 +147,6 @@ std::unordered_map<Val*, Val*> getSimplificationMap(Fusion* fusion) {
   return extent_to_min_input_id_extent;
 }
 
-std::vector<Val*> allLeafOuts(Fusion* fusion) {
-  auto exprs = StmtSort::getExprs(fusion, true);
-  std::unordered_set<Val*> inputs;
-  std::unordered_set<Val*> outputs;
-  std::vector<Val*> ordered_outputs;
-  for (auto expr : exprs) {
-    inputs.insert(expr->inputs().begin(), expr->inputs().end());
-    outputs.insert(expr->outputs().begin(), expr->outputs().end());
-    ordered_outputs.insert(
-        ordered_outputs.end(), expr->outputs().begin(), expr->outputs().end());
-  }
-  for (auto input : inputs) {
-    outputs.erase(input);
-  }
-
-  std::vector<Val*> ordered_leaf_outs;
-  for (auto out : ordered_outputs) {
-    if (outputs.find(out) != outputs.end()) {
-      ordered_leaf_outs.push_back(out);
-    }
-  }
-  return ordered_leaf_outs;
-}
-
-class ValReplacementMutator : private OptOutMutator {
- public:
-  ValReplacementMutator(
-      Fusion* fusion,
-      const std::unordered_map<Val*, Val*>& replacement_map)
-      : replacement_map_(replacement_map) {
-    FusionGuard fg(fusion);
-
-    // Welford makes this a little annoying since it holds a count which is
-    // typically not used by anything else. If we don't grab that count, then
-    // it would be a tensorview that doesn't get updated extents. Therefore,
-    // first grab all leaves towards outputs and grab stmts from there.
-    auto stmts = StmtSort::getStmts(fusion, allLeafOuts(fusion), true);
-    for (auto stmt : stmts) {
-      mutate(stmt);
-    }
-  }
-
- private:
-  using OptOutMutator::mutate;
-  void mutate(Val* val) final {
-    if (replacement_map_.find(val) == replacement_map_.end()) {
-      return OptOutMutator::mutate(val);
-    }
-    auto replaced_val = replacement_map_.at(val);
-    registerMutation(val, replaced_val);
-  }
-
-  const std::unordered_map<Val*, Val*>& replacement_map_;
-};
-
 } // namespace
 
 void replaceSymbolicSizes(Fusion* fusion) {
@@ -279,7 +224,7 @@ void replaceSymbolicSizes(Fusion* fusion) {
   }
 
   // Run mutation on the fusion with the tensor_dim_map
-  ValReplacementMutator(fusion, tensor_dim_map);
+  ir_utils::replaceValue(fusion, tensor_dim_map);
 }
 
 } // namespace cuda
diff --git a/torch/csrc/jit/codegen/cuda/tensor_view.cpp b/torch/csrc/jit/codegen/cuda/tensor_view.cpp
index dca0f1e85f1e..2c067b2a6bfd 100644
--- a/torch/csrc/jit/codegen/cuda/tensor_view.cpp
+++ b/torch/csrc/jit/codegen/cuda/tensor_view.cpp
@@ -148,37 +148,6 @@ TensorView::TensorView(
       "Function invalid for kernel container.");
 }
 
-void TensorView::createReplacementDomain(
-    const std::vector<Val*>& replacement_extents) {
-  TORCH_INTERNAL_ASSERT(
-      !replacement_extents.empty() &&
-      getMaybeRFactorDomain().size() == replacement_extents.size());
-  // Given an rfactor domain, create a new IterDomain.
-  // Otherwise, clone the previous IterDomain
-  size_t idx = 0;
-  std::vector<IterDomain*> new_root_domain(getMaybeRFactorDomain().size());
-  for (const auto& id : getMaybeRFactorDomain()) {
-    if (replacement_extents[idx] != nullptr) {
-      new_root_domain[idx] = IrBuilder::create<IterDomain>(
-          container(),
-          id->start(),
-          replacement_extents[idx],
-          id->stopOffset(),
-          id->getParallelType(),
-          id->getIterType());
-      ++idx;
-    } else {
-      TORCH_INTERNAL_ASSERT(!id->isRFactorProduct());
-      new_root_domain[idx++] = id->clone();
-    }
-  }
-
-  TORCH_INTERNAL_ASSERT(
-      new_root_domain.size() == domain()->contiguity().size());
-  setDomain(IrBuilder::create<TensorDomain>(
-      container(), new_root_domain, domain()->contiguity()));
-}
-
 void TensorView::convertRfactorToRootDomain() {
   // For a given TensorView, does its domain (root / rfactor) contain any
   // concrete sized extents?
@@ -191,8 +160,42 @@ void TensorView::convertRfactorToRootDomain() {
     return true;
   };
 
-  const auto kThisIsConcreteTensor = is_concrete_tensor(this);
+  // Create a new root domain and replacement TensorDomain.
+  // Given an rfactor domain, create a new IterDomain.
+  // Otherwise, clone the previous IterDomain
+  auto createReplacementDomain =
+      [this](const std::vector<Val*>& replacement_extents) {
+        TORCH_INTERNAL_ASSERT(
+            !replacement_extents.empty() &&
+            getMaybeRFactorDomain().size() == replacement_extents.size());
+        size_t idx = 0;
+        std::vector<IterDomain*> new_root_domain(
+            getMaybeRFactorDomain().size());
+        for (const auto& id : getMaybeRFactorDomain()) {
+          if (replacement_extents[idx] != nullptr) {
+            new_root_domain[idx] = IrBuilder::create<IterDomain>(
+                container(),
+                id->start(),
+                replacement_extents[idx],
+                id->stopOffset(),
+                id->getParallelType(),
+                id->getIterType());
+            ++idx;
+          } else {
+            TORCH_INTERNAL_ASSERT(!id->isRFactorProduct());
+            new_root_domain[idx++] = id->clone();
+          }
+        }
+
+        TORCH_INTERNAL_ASSERT(
+            new_root_domain.size() == domain()->contiguity().size());
+        setDomain(IrBuilder::create<TensorDomain>(
+            container(), new_root_domain, domain()->contiguity()));
+      };
+
   std::vector<Val*> rfactor_extents;
+  std::unordered_map<Val*, Val*> replacement_map;
+  const auto kThisIsConcreteTensor = is_concrete_tensor(this);
   for (const auto& id : getMaybeRFactorDomain()) {
     if (id->isRFactorProduct()) {
       // Create new symbolic extents for rfactor iterDomains
@@ -200,38 +203,15 @@ void TensorView::convertRfactorToRootDomain() {
           ? IrBuilder::create<Int>(container())
           : id->extent();
       rfactor_extents.push_back(domain_extent);
+      replacement_map.emplace(id->extent(), domain_extent);
     } else {
       rfactor_extents.push_back(nullptr);
     }
   }
 
   createReplacementDomain(rfactor_extents);
 
-  auto getBroadcastReplacementExtents = [&rfactor_extents](auto bcast_def) {
-    TORCH_INTERNAL_ASSERT(bcast_def != nullptr);
-    std::vector<Val*> bcast_rfactor_extents;
-    size_t i = 0;
-    for (auto flag : bcast_def->getBroadcastDimFlags()) {
-      auto domain_extent = (flag) ? nullptr : rfactor_extents[i++];
-      bcast_rfactor_extents.push_back(domain_extent);
-    }
-    return bcast_rfactor_extents;
-  };
-
-  for (auto expr : uses()) {
-    auto out_tv = ir_utils::getTvOutput(expr);
-    if (out_tv != nullptr) {
-      TORCH_INTERNAL_ASSERT(!out_tv->hasRFactor());
-      TORCH_INTERNAL_ASSERT(
-          kThisIsConcreteTensor == is_concrete_tensor(out_tv));
-      if (out_tv->isDefinitionType(ExprType::BroadcastOp)) {
-        auto bcast_def = out_tv->definition()->as<BroadcastOp>();
-        out_tv->createReplacementDomain(
-            getBroadcastReplacementExtents(bcast_def));
-      } else {
-        out_tv->createReplacementDomain(rfactor_extents);
-      }
-    }
-  }
+  // Propagate new extent throughout fusion using ValReplacementMutator
+  ir_utils::replaceValue(fusion(), replacement_map);
 }
 
 TensorView::TensorView(const TensorView* src, IrCloner* ir_cloner)
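Net effect of the tensor_view.cpp change: instead of walking uses() and rebuilding each consumer's root domain by hand (including the deleted BroadcastOp special case), convertRfactorToRootDomain now records old-extent to new-extent pairs and lets one fusion-wide mutator pass rewrite every remaining occurrence. A condensed sketch of the resulting flow, assuming the surrounding member context and eliding the domain rebuild shown in the hunk above:

  // 1. Mint new symbolic extents for rfactor IDs, recording old -> new.
  std::unordered_map<Val*, Val*> replacement_map;
  for (const auto& id : getMaybeRFactorDomain()) {
    if (id->isRFactorProduct()) {
      Val* new_extent = IrBuilder::create<Int>(container());
      replacement_map.emplace(id->extent(), new_extent);
    }
  }
  // 2. Rebuild this tensor's own root domain (createReplacementDomain above).
  // 3. One mutator pass propagates the new extents everywhere else, which is
  //    what made the per-consumer broadcast handling unnecessary.
  ir_utils::replaceValue(fusion(), replacement_map);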