From 3b0955dbe80ed9d2c8523cec7302316badf3cb70 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 21 Feb 2023 17:23:28 -0800 Subject: [PATCH 01/14] adding python tests for gather/index_select cache; python black --- .lintrunner.toml | 23 ++++++++++ .../python_tests/test_python_frontend.py | 45 +------------------ 2 files changed, 24 insertions(+), 44 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 55a9474277300..722daafbb858e 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -41,6 +41,29 @@ init_command = [ 'pyflakes==2.2.0', ] +[[linter]] +code = 'FLAKE8' +include_patterns = ['third_party/nvfuser/**/*.py'] +command = [ + 'python3', + 'tools/linter/adapters/flake8_linter.py', + '--', + '@{{PATHSFILE}}' +] +init_command = [ + 'python3', + 'tools/linter/adapters/pip_init.py', + '--dry-run={{DRYRUN}}', + 'flake8==3.8.2', + 'flake8-bugbear==20.1.4', + 'flake8-comprehensions==3.3.0', + 'flake8-executable==2.0.4', + 'flake8-pyi==20.5.0', + 'mccabe==0.6.1', + 'pycodestyle==2.6.0', + 'pyflakes==2.2.0', +] + [[linter]] code = 'CLANGFORMAT' include_patterns = [ diff --git a/third_party/nvfuser/python_tests/test_python_frontend.py b/third_party/nvfuser/python_tests/test_python_frontend.py index fd82a5d593bff..a2d256e77e0a1 100644 --- a/third_party/nvfuser/python_tests/test_python_frontend.py +++ b/third_party/nvfuser/python_tests/test_python_frontend.py @@ -1082,49 +1082,6 @@ def fuser_function(correction): torch_result = torch.var_mean(inputs[0], [0, 1, 2], bool(correction)) self.assertEqual(fuser_result, torch_result) - def test_scalar_only_inputs(self): - # We don't allow scalar outputs, currently, - # so a tensor has to be returned - def fusion_func(fd: FusionDefinition): - s0 = fd.define_scalar() - s1 = fd.define_scalar() - s2 = fd.ops.add(s0, s1) - c0 = fd.define_constant(1.0, DataType.Float) - t3 = fd.ops.full(size=[2, 2], arg=c0, dtype=DataType.Float) - t4 = fd.ops.mul(t3, s2) - fd.add_output(t4) - - with FusionDefinition() as fd: - fusion_func(fd) - - # TODO: full is broken and does not print its proper definition - # Issue: https://github.com/csarofeen/pytorch/issues/2502 - nvf_out = fd.execute([2.0, 3.0]) - eager_out = torch.full([2, 2], 1.0) * 5.0 - self.assertEqual(eager_out, nvf_out[0]) - - def test_addcmul(self): - inputs = [ - torch.randn(4, device="cuda", dtype=torch.float32), - torch.randn(4, device="cuda", dtype=torch.float32), - torch.randn(4, device="cuda", dtype=torch.float32), - ] - - def fusion_func(fd: FusionDefinition): - t0 = fd.from_pytorch(inputs[0]) - t1 = fd.from_pytorch(inputs[1]) - t2 = fd.from_pytorch(inputs[2]) - c0 = fd.define_constant(0.1) - - t3 = fd.ops.addcmul(t0, t1, t2, c0) - - fd.add_output(t3) - - nvfout, _ = self.exec_nvfuser(fusion_func, inputs) - - torch_out = torch.addcmul(*inputs, value=0.1) - - self.assertEqual(nvfout[0], torch_out) -if __name__ == '__main__': +if __name__ == "__main__": run_tests() From 9f46d081a07d85e37b84bba03a7e90182109eb6b Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 21 Feb 2023 18:40:39 -0800 Subject: [PATCH 02/14] reverting lintrunner --- .lintrunner.toml | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 722daafbb858e..55a9474277300 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -41,29 +41,6 @@ init_command = [ 'pyflakes==2.2.0', ] -[[linter]] -code = 'FLAKE8' -include_patterns = ['third_party/nvfuser/**/*.py'] -command = [ - 'python3', - 'tools/linter/adapters/flake8_linter.py', - '--', - '@{{PATHSFILE}}' -] -init_command = [ - 'python3', - 'tools/linter/adapters/pip_init.py', - '--dry-run={{DRYRUN}}', - 'flake8==3.8.2', - 'flake8-bugbear==20.1.4', - 'flake8-comprehensions==3.3.0', - 'flake8-executable==2.0.4', - 'flake8-pyi==20.5.0', - 'mccabe==0.6.1', - 'pycodestyle==2.6.0', - 'pyflakes==2.2.0', -] - [[linter]] code = 'CLANGFORMAT' include_patterns = [ From 756e8fbbe32ce5830854aa6404246d78827ffe29 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 23 Feb 2023 14:14:32 -0800 Subject: [PATCH 03/14] patching empty & scalar tensor --- third_party/nvfuser/csrc/ops/arith.cpp | 8 ++++++-- .../nvfuser/csrc/python_frontend/python_bindings.cpp | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/third_party/nvfuser/csrc/ops/arith.cpp b/third_party/nvfuser/csrc/ops/arith.cpp index 2873512f453fb..6606f7a92126a 100644 --- a/third_party/nvfuser/csrc/ops/arith.cpp +++ b/third_party/nvfuser/csrc/ops/arith.cpp @@ -137,7 +137,7 @@ TensorView* index_select(TensorView* lookup_tv, int dim, TensorView* index_tv) { TensorDomain::noReductions(index_tv->getMaybeRFactorDomain()); size_t n_dims = lookup_dom.size(); TORCH_CHECK(n_dims > 0, "index_select can not be applied to 0d tensor."); - TORCH_CHECK(index_dom.size() == 1, "index array must be 1d tensor."); + TORCH_CHECK(index_dom.size() <= 1, "index array must be 1d tensor."); if (dim < 0) { dim += lookup_dom.size(); @@ -156,8 +156,12 @@ TensorView* index_select(TensorView* lookup_tv, int dim, TensorView* index_tv) { for (auto i : c10::irange(lookup_dom.size())) { if ((int)i != dim) { new_root.emplace_back(lookup_dom[i]->cloneWithoutRFactor()); - } else { + } else if (index_dom.size() == 1) { new_root.emplace_back(index_dom[0]->cloneWithoutRFactor()); + } else { + new_root.emplace_back( + IterDomainBuilder(FusionGuard::getCurFusion()->zeroVal(), FusionGuard::getCurFusion()->oneVal()).iter_type(IterType::Broadcast).build()); + } } diff --git a/third_party/nvfuser/csrc/python_frontend/python_bindings.cpp b/third_party/nvfuser/csrc/python_frontend/python_bindings.cpp index 414b6e8c74b7c..a86384a23d017 100644 --- a/third_party/nvfuser/csrc/python_frontend/python_bindings.cpp +++ b/third_party/nvfuser/csrc/python_frontend/python_bindings.cpp @@ -251,7 +251,7 @@ void initNvFuserPythonBindings(PyObject* module) { maybe_symbolic_sizes.reserve(sizes.size()); for (const auto i : c10::irange(sizes.size())) { TORCH_INTERNAL_ASSERT( - sizes[i] > 0, + sizes[i] >= 0, "Size of ", sizes[i], " is not supported in nvFuser. Expected size > 0."); From f7ca114a505dfe681083fe62e83216da540ec190 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 23 Feb 2023 15:58:35 -0800 Subject: [PATCH 04/14] remove checks for scalar tensor --- third_party/nvfuser/csrc/ops/utils.cpp | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/third_party/nvfuser/csrc/ops/utils.cpp b/third_party/nvfuser/csrc/ops/utils.cpp index e2edc559fdf8b..bf841b3d9af7b 100644 --- a/third_party/nvfuser/csrc/ops/utils.cpp +++ b/third_party/nvfuser/csrc/ops/utils.cpp @@ -28,8 +28,8 @@ TensorView* maybe_broadcast_index_tv(TensorView* t, size_t dim, size_t rank) { size_t ori_rank = TensorDomain::noReductions(t->getMaybeRFactorDomain()).size(); TORCH_INTERNAL_ASSERT( - ori_rank == 1, - "The rank of index tensorview in index_select must be 1, but got ", + ori_rank <= 1, + "The rank of index tensorview in index_select must be less than or equal to 1, but got ", ori_rank); TORCH_INTERNAL_ASSERT( dim < rank, @@ -37,19 +37,11 @@ TensorView* maybe_broadcast_index_tv(TensorView* t, size_t dim, size_t rank) { dim, " >= ", rank); - std::vector bcast_dims(rank, false); - // broadcast outter on inp to match rank with other. - if (dim + 1 < rank) { - std::fill(bcast_dims.begin() + dim + 1, bcast_dims.end(), true); + std::vector bcast_dims(rank, true); + if (ori_rank == 1) { + bcast_dims[dim] = false; } - // broadcast inner on inp to match rank with other. - if (dim > 0) { - std::fill(bcast_dims.begin(), bcast_dims.begin() + dim, true); - } - if (dim + 1 < rank || dim > 0) { - t = broadcast(t, bcast_dims); - } - return t; + return broadcast(t, bcast_dims); } Val* simplifiedInt(Val* val) { From 3e7ae7cdc42d6421049d719ced8c46ddfd03bdc4 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 23 Feb 2023 15:58:44 -0800 Subject: [PATCH 05/14] Revert "remove checks for scalar tensor" This reverts commit ef7b91d848bce06e69a872c8d5fa2318ab65d21c. --- third_party/nvfuser/csrc/ops/utils.cpp | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/third_party/nvfuser/csrc/ops/utils.cpp b/third_party/nvfuser/csrc/ops/utils.cpp index bf841b3d9af7b..e2edc559fdf8b 100644 --- a/third_party/nvfuser/csrc/ops/utils.cpp +++ b/third_party/nvfuser/csrc/ops/utils.cpp @@ -28,8 +28,8 @@ TensorView* maybe_broadcast_index_tv(TensorView* t, size_t dim, size_t rank) { size_t ori_rank = TensorDomain::noReductions(t->getMaybeRFactorDomain()).size(); TORCH_INTERNAL_ASSERT( - ori_rank <= 1, - "The rank of index tensorview in index_select must be less than or equal to 1, but got ", + ori_rank == 1, + "The rank of index tensorview in index_select must be 1, but got ", ori_rank); TORCH_INTERNAL_ASSERT( dim < rank, @@ -37,11 +37,19 @@ TensorView* maybe_broadcast_index_tv(TensorView* t, size_t dim, size_t rank) { dim, " >= ", rank); - std::vector bcast_dims(rank, true); - if (ori_rank == 1) { - bcast_dims[dim] = false; + std::vector bcast_dims(rank, false); + // broadcast outter on inp to match rank with other. + if (dim + 1 < rank) { + std::fill(bcast_dims.begin() + dim + 1, bcast_dims.end(), true); } - return broadcast(t, bcast_dims); + // broadcast inner on inp to match rank with other. + if (dim > 0) { + std::fill(bcast_dims.begin(), bcast_dims.begin() + dim, true); + } + if (dim + 1 < rank || dim > 0) { + t = broadcast(t, bcast_dims); + } + return t; } Val* simplifiedInt(Val* val) { From 9bc2b1e558a8d512d3effeb7ef26642e5e78521c Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 23 Feb 2023 17:53:50 -0800 Subject: [PATCH 06/14] Revert "patching empty & scalar tensor" This reverts commit df4d2789948d5e061367d07c424389a417e9cff8. --- third_party/nvfuser/csrc/ops/arith.cpp | 8 ++------ .../nvfuser/csrc/python_frontend/python_bindings.cpp | 2 +- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/third_party/nvfuser/csrc/ops/arith.cpp b/third_party/nvfuser/csrc/ops/arith.cpp index 6606f7a92126a..2873512f453fb 100644 --- a/third_party/nvfuser/csrc/ops/arith.cpp +++ b/third_party/nvfuser/csrc/ops/arith.cpp @@ -137,7 +137,7 @@ TensorView* index_select(TensorView* lookup_tv, int dim, TensorView* index_tv) { TensorDomain::noReductions(index_tv->getMaybeRFactorDomain()); size_t n_dims = lookup_dom.size(); TORCH_CHECK(n_dims > 0, "index_select can not be applied to 0d tensor."); - TORCH_CHECK(index_dom.size() <= 1, "index array must be 1d tensor."); + TORCH_CHECK(index_dom.size() == 1, "index array must be 1d tensor."); if (dim < 0) { dim += lookup_dom.size(); @@ -156,12 +156,8 @@ TensorView* index_select(TensorView* lookup_tv, int dim, TensorView* index_tv) { for (auto i : c10::irange(lookup_dom.size())) { if ((int)i != dim) { new_root.emplace_back(lookup_dom[i]->cloneWithoutRFactor()); - } else if (index_dom.size() == 1) { - new_root.emplace_back(index_dom[0]->cloneWithoutRFactor()); } else { - new_root.emplace_back( - IterDomainBuilder(FusionGuard::getCurFusion()->zeroVal(), FusionGuard::getCurFusion()->oneVal()).iter_type(IterType::Broadcast).build()); - + new_root.emplace_back(index_dom[0]->cloneWithoutRFactor()); } } diff --git a/third_party/nvfuser/csrc/python_frontend/python_bindings.cpp b/third_party/nvfuser/csrc/python_frontend/python_bindings.cpp index a86384a23d017..414b6e8c74b7c 100644 --- a/third_party/nvfuser/csrc/python_frontend/python_bindings.cpp +++ b/third_party/nvfuser/csrc/python_frontend/python_bindings.cpp @@ -251,7 +251,7 @@ void initNvFuserPythonBindings(PyObject* module) { maybe_symbolic_sizes.reserve(sizes.size()); for (const auto i : c10::irange(sizes.size())) { TORCH_INTERNAL_ASSERT( - sizes[i] >= 0, + sizes[i] > 0, "Size of ", sizes[i], " is not supported in nvFuser. Expected size > 0."); From b250630396a63c2acd77cd9ccc97fdd02a5bf56d Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 23 Feb 2023 17:54:58 -0800 Subject: [PATCH 07/14] allow empty tensor (numel()==0) --- third_party/nvfuser/csrc/python_frontend/python_bindings.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/nvfuser/csrc/python_frontend/python_bindings.cpp b/third_party/nvfuser/csrc/python_frontend/python_bindings.cpp index 414b6e8c74b7c..20f3179433b45 100644 --- a/third_party/nvfuser/csrc/python_frontend/python_bindings.cpp +++ b/third_party/nvfuser/csrc/python_frontend/python_bindings.cpp @@ -251,10 +251,10 @@ void initNvFuserPythonBindings(PyObject* module) { maybe_symbolic_sizes.reserve(sizes.size()); for (const auto i : c10::irange(sizes.size())) { TORCH_INTERNAL_ASSERT( - sizes[i] > 0, + sizes[i] >= 0, "Size of ", sizes[i], - " is not supported in nvFuser. Expected size > 0."); + " is not supported in nvFuser. Expected size >= 0."); if (sizes[i] == 1) { maybe_symbolic_sizes.push_back(1); } else { From 9b9b912553ca7d2eb7bf9096423e8f52e32b82fc Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 23 Feb 2023 17:59:28 -0800 Subject: [PATCH 08/14] scalar tensor for index_select attempt 2 --- third_party/nvfuser/csrc/codegen.cpp | 6 ++++-- third_party/nvfuser/csrc/ops/arith.cpp | 11 +++++++++-- third_party/nvfuser/csrc/ops/arith.h | 2 +- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/third_party/nvfuser/csrc/codegen.cpp b/third_party/nvfuser/csrc/codegen.cpp index 396e44224c848..783b71392d41a 100644 --- a/third_party/nvfuser/csrc/codegen.cpp +++ b/third_party/nvfuser/csrc/codegen.cpp @@ -492,8 +492,10 @@ class CudaKernelGenerator : private OptOutConstDispatch { TORCH_INTERNAL_ASSERT(false, "Unreachable"); } - void handle(const TensorView*) final { - TORCH_INTERNAL_ASSERT(false, "Unreachable"); + void handle(const TensorView* tv) final { + // This allows us to access scalar tensor as if they are just scalar + TORCH_INTERNAL_ASSERT(tv->isZeroDim(), "TensorView can only be handled as scalar tensor"); + code_ << ir_utils::varName(tv) << "[0]"; } //! Utility for generating vectorized pointer access in ldsm and diff --git a/third_party/nvfuser/csrc/ops/arith.cpp b/third_party/nvfuser/csrc/ops/arith.cpp index 2873512f453fb..8063a1ea3c7a3 100644 --- a/third_party/nvfuser/csrc/ops/arith.cpp +++ b/third_party/nvfuser/csrc/ops/arith.cpp @@ -94,7 +94,7 @@ TensorView* unaryOp( return unaryOp(type, cast_v1)->as(); } -TensorView* select(TensorView* tv, int dim, Int* index) { +TensorView* select(TensorView* tv, int dim, Val* index) { auto dom = TensorDomain::noReductions(tv->getMaybeRFactorDomain()); TORCH_CHECK(dom.size() > 0, "select can not be applied to 0d tensor."); @@ -137,7 +137,14 @@ TensorView* index_select(TensorView* lookup_tv, int dim, TensorView* index_tv) { TensorDomain::noReductions(index_tv->getMaybeRFactorDomain()); size_t n_dims = lookup_dom.size(); TORCH_CHECK(n_dims > 0, "index_select can not be applied to 0d tensor."); - TORCH_CHECK(index_dom.size() == 1, "index array must be 1d tensor."); + TORCH_CHECK(index_dom.size() <= 1, "index array must be 1d or scalar tensor."); + + if (index_dom.size() == 0) { + std::vector squeeze_take_dim(n_dims, false); + squeeze_take_dim[dim] = true; + auto select_tv = select(lookup_tv, dim, index_tv); + return squeeze(select_tv, squeeze_take_dim); + } if (dim < 0) { dim += lookup_dom.size(); diff --git a/third_party/nvfuser/csrc/ops/arith.h b/third_party/nvfuser/csrc/ops/arith.h index f6d03006810cf..18552828831d6 100644 --- a/third_party/nvfuser/csrc/ops/arith.h +++ b/third_party/nvfuser/csrc/ops/arith.h @@ -140,7 +140,7 @@ TORCH_CUDA_CU_API WelfordResult WelfordRaw( // import IrBuilder just for this one interface. Int* init_N = nullptr); -TORCH_CUDA_CU_API TensorView* select(TensorView* tv, int dim, Int* index); +TORCH_CUDA_CU_API TensorView* select(TensorView* tv, int dim, Val* index); // RNG OPERATIONS TORCH_CUDA_CU_API TensorView* rand( From 8863916894d527dc8641827b8a8196e966c6ac7d Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 23 Feb 2023 19:59:35 -0800 Subject: [PATCH 09/14] fixing unsqueeze --- third_party/nvfuser/csrc/ops/arith.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/third_party/nvfuser/csrc/ops/arith.cpp b/third_party/nvfuser/csrc/ops/arith.cpp index 8063a1ea3c7a3..31c497e650ae5 100644 --- a/third_party/nvfuser/csrc/ops/arith.cpp +++ b/third_party/nvfuser/csrc/ops/arith.cpp @@ -140,10 +140,8 @@ TensorView* index_select(TensorView* lookup_tv, int dim, TensorView* index_tv) { TORCH_CHECK(index_dom.size() <= 1, "index array must be 1d or scalar tensor."); if (index_dom.size() == 0) { - std::vector squeeze_take_dim(n_dims, false); - squeeze_take_dim[dim] = true; auto select_tv = select(lookup_tv, dim, index_tv); - return squeeze(select_tv, squeeze_take_dim); + return unsqueeze(select_tv, dim); } if (dim < 0) { From 6d17917e55d7e99d27532cfe20e81bd2c5563781 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 23 Feb 2023 22:56:09 -0800 Subject: [PATCH 10/14] reverting unwanted changes from rebase --- .../python_tests/test_python_frontend.py | 45 ++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/third_party/nvfuser/python_tests/test_python_frontend.py b/third_party/nvfuser/python_tests/test_python_frontend.py index a2d256e77e0a1..fd82a5d593bff 100644 --- a/third_party/nvfuser/python_tests/test_python_frontend.py +++ b/third_party/nvfuser/python_tests/test_python_frontend.py @@ -1082,6 +1082,49 @@ def fuser_function(correction): torch_result = torch.var_mean(inputs[0], [0, 1, 2], bool(correction)) self.assertEqual(fuser_result, torch_result) + def test_scalar_only_inputs(self): + # We don't allow scalar outputs, currently, + # so a tensor has to be returned + def fusion_func(fd: FusionDefinition): + s0 = fd.define_scalar() + s1 = fd.define_scalar() + s2 = fd.ops.add(s0, s1) + c0 = fd.define_constant(1.0, DataType.Float) + t3 = fd.ops.full(size=[2, 2], arg=c0, dtype=DataType.Float) + t4 = fd.ops.mul(t3, s2) + fd.add_output(t4) + + with FusionDefinition() as fd: + fusion_func(fd) + + # TODO: full is broken and does not print its proper definition + # Issue: https://github.com/csarofeen/pytorch/issues/2502 + nvf_out = fd.execute([2.0, 3.0]) + eager_out = torch.full([2, 2], 1.0) * 5.0 + self.assertEqual(eager_out, nvf_out[0]) + + def test_addcmul(self): + inputs = [ + torch.randn(4, device="cuda", dtype=torch.float32), + torch.randn(4, device="cuda", dtype=torch.float32), + torch.randn(4, device="cuda", dtype=torch.float32), + ] + + def fusion_func(fd: FusionDefinition): + t0 = fd.from_pytorch(inputs[0]) + t1 = fd.from_pytorch(inputs[1]) + t2 = fd.from_pytorch(inputs[2]) + c0 = fd.define_constant(0.1) + + t3 = fd.ops.addcmul(t0, t1, t2, c0) + + fd.add_output(t3) + + nvfout, _ = self.exec_nvfuser(fusion_func, inputs) + + torch_out = torch.addcmul(*inputs, value=0.1) + + self.assertEqual(nvfout[0], torch_out) -if __name__ == "__main__": +if __name__ == '__main__': run_tests() From 4ed617a6ac504666da7611e006c0fb094a5e40c9 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 23 Feb 2023 23:32:26 -0800 Subject: [PATCH 11/14] python tests added --- .../python_tests/test_python_frontend.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/third_party/nvfuser/python_tests/test_python_frontend.py b/third_party/nvfuser/python_tests/test_python_frontend.py index fd82a5d593bff..23dd187e314ef 100644 --- a/third_party/nvfuser/python_tests/test_python_frontend.py +++ b/third_party/nvfuser/python_tests/test_python_frontend.py @@ -743,6 +743,28 @@ def fusion_func(fd: FusionDefinition): test_fn(0) test_fn(1) + def test_index_select_scalar_indices(self): + inputs = [ + torch.randn(8, 16, device="cuda"), + torch.tensor(2, device="cuda").to(dtype=torch.long), + ] + + def test_fn(dim): + def fusion_func(fd: FusionDefinition): + t0 = fd.from_pytorch(inputs[0]) + t1 = fd.from_pytorch(inputs[1]) + t2 = fd.ops.index_select(t0, t1, dim) + fd.add_output(t2) + + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) + + eager_out = torch.index_select(inputs[0], dim, inputs[1]) + self.assertEqual(eager_out, nvf_out[0]) + + test_fn(0) + test_fn(1) + + def test_squeeze(self): t0_sizes = [4] t1_sizes = [1, 4, 1] From ef77a11cb8bf251577f11bb73be09c94157d7d96 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 24 Feb 2023 10:20:16 -0800 Subject: [PATCH 12/14] Lower index of SelectOp --- third_party/nvfuser/csrc/codegen.cpp | 6 ------ third_party/nvfuser/csrc/lower_index.cpp | 21 +++++++++++++++++++-- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/third_party/nvfuser/csrc/codegen.cpp b/third_party/nvfuser/csrc/codegen.cpp index 783b71392d41a..0704c998f5baf 100644 --- a/third_party/nvfuser/csrc/codegen.cpp +++ b/third_party/nvfuser/csrc/codegen.cpp @@ -492,12 +492,6 @@ class CudaKernelGenerator : private OptOutConstDispatch { TORCH_INTERNAL_ASSERT(false, "Unreachable"); } - void handle(const TensorView* tv) final { - // This allows us to access scalar tensor as if they are just scalar - TORCH_INTERNAL_ASSERT(tv->isZeroDim(), "TensorView can only be handled as scalar tensor"); - code_ << ir_utils::varName(tv) << "[0]"; - } - //! Utility for generating vectorized pointer access in ldsm and //! cpasync. //! TODO: this access pattern as is could be merged with exisiting diff --git a/third_party/nvfuser/csrc/lower_index.cpp b/third_party/nvfuser/csrc/lower_index.cpp index 17db4654641b3..e870ec96dda95 100644 --- a/third_party/nvfuser/csrc/lower_index.cpp +++ b/third_party/nvfuser/csrc/lower_index.cpp @@ -262,9 +262,26 @@ void IndexLowering::handle(const ScatterOp* sop) { } void IndexLowering::handle(const SelectOp* sop) { - const auto input = lowerSrcIndex( - sop->input(0), sop->output(0), sop->getIndexOverridingMap()); + auto lowered_index = lowerSrcIndex(sop->input(1), sop->output(0)); + auto lowered_index_cast = lowered_index; + + // If the type of the index tensor is different from the kernel + // index type, promote it to the kernel index type + if (GpuLower::current()->kernel()->indexType() != + sop->input(1)->getDataType().value()) { + lowered_index_cast = + IrBuilder::newScalar(GpuLower::current()->kernel()->indexType()); + IrBuilder::create( + UnaryOpType::Cast, lowered_index_cast, lowered_index); + } + + const std::unordered_map override_index = { + {sop->getSelectAxis(), lowered_index_cast}}; + const auto input = + lowerSrcIndex(sop->input(0), sop->output(0), override_index); + const auto out = lowerDstIndex(sop->output(0)); + pushBack(IrBuilder::create(UnaryOpType::Set, out, input)); GpuLower::current()->propagateExprInfo(sop, back()); } From dea736b863e8f777d579dcc796e66f0bad7c7b61 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 24 Feb 2023 10:47:40 -0800 Subject: [PATCH 13/14] Revert codegen change --- third_party/nvfuser/csrc/codegen.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/third_party/nvfuser/csrc/codegen.cpp b/third_party/nvfuser/csrc/codegen.cpp index 0704c998f5baf..396e44224c848 100644 --- a/third_party/nvfuser/csrc/codegen.cpp +++ b/third_party/nvfuser/csrc/codegen.cpp @@ -492,6 +492,10 @@ class CudaKernelGenerator : private OptOutConstDispatch { TORCH_INTERNAL_ASSERT(false, "Unreachable"); } + void handle(const TensorView*) final { + TORCH_INTERNAL_ASSERT(false, "Unreachable"); + } + //! Utility for generating vectorized pointer access in ldsm and //! cpasync. //! TODO: this access pattern as is could be merged with exisiting From 54e9037a197723f07d8e960555c908a6624d4525 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Fri, 10 Mar 2023 14:04:35 -0800 Subject: [PATCH 14/14] lintrunner --- third_party/nvfuser/csrc/ops/arith.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/third_party/nvfuser/csrc/ops/arith.cpp b/third_party/nvfuser/csrc/ops/arith.cpp index ef98405fb23fe..f70bdb7d12bdb 100644 --- a/third_party/nvfuser/csrc/ops/arith.cpp +++ b/third_party/nvfuser/csrc/ops/arith.cpp @@ -137,7 +137,8 @@ TensorView* index_select(TensorView* lookup_tv, int dim, TensorView* index_tv) { TensorDomain::noReductions(index_tv->getMaybeRFactorDomain()); size_t n_dims = lookup_dom.size(); TORCH_CHECK(n_dims > 0, "index_select can not be applied to 0d tensor."); - TORCH_CHECK(index_dom.size() <= 1, "index array must be 1d or scalar tensor."); + TORCH_CHECK( + index_dom.size() <= 1, "index array must be 1d or scalar tensor."); if (index_dom.size() == 0) { auto select_tv = select(lookup_tv, dim, index_tv);