diff --git a/third_party/nvfuser/csrc/lower_index.cpp b/third_party/nvfuser/csrc/lower_index.cpp
index 608ff6841b8d8..38df35a7879f1 100644
--- a/third_party/nvfuser/csrc/lower_index.cpp
+++ b/third_party/nvfuser/csrc/lower_index.cpp
@@ -294,9 +294,26 @@ void IndexLowering::handle(const ScatterOp* sop) {
 }
 
 void IndexLowering::handle(const SelectOp* sop) {
-  const auto input = lowerSrcIndex(
-      sop->input(0), sop->output(0), sop->getIndexOverridingMap());
+  auto lowered_index = lowerSrcIndex(sop->input(1), sop->output(0));
+  auto lowered_index_cast = lowered_index;
+
+  // If the type of the index tensor is different from the kernel
+  // index type, promote it to the kernel index type
+  if (GpuLower::current()->kernel()->indexType() !=
+      sop->input(1)->getDataType().value()) {
+    lowered_index_cast =
+        IrBuilder::newScalar(GpuLower::current()->kernel()->indexType());
+    IrBuilder::create<UnaryOp>(
+        UnaryOpType::Cast, lowered_index_cast, lowered_index);
+  }
+
+  const std::unordered_map<IterDomain*, Val*> override_index = {
+      {sop->getSelectAxis(), lowered_index_cast}};
+  const auto input =
+      lowerSrcIndex(sop->input(0), sop->output(0), override_index);
+
   const auto out = lowerDstIndex(sop->output(0));
+
   pushBack(IrBuilder::create<UnaryOp>(UnaryOpType::Set, out, input));
   GpuLower::current()->propagateExprInfo(sop, back());
 }
diff --git a/third_party/nvfuser/csrc/ops/arith.cpp b/third_party/nvfuser/csrc/ops/arith.cpp
index 08a8fdba7ef90..4b7fff8b8c9de 100644
--- a/third_party/nvfuser/csrc/ops/arith.cpp
+++ b/third_party/nvfuser/csrc/ops/arith.cpp
@@ -94,7 +94,7 @@ TensorView* unaryOp(
   return unaryOp(type, cast_v1)->as<TensorView>();
 }
 
-TensorView* select(TensorView* tv, int dim, Int* index) {
+TensorView* select(TensorView* tv, int dim, Val* index) {
   auto dom = TensorDomain::noReductions(tv->getMaybeRFactorDomain());
   TORCH_CHECK(dom.size() > 0, "select can not be applied to 0d tensor.");
 
@@ -137,7 +137,13 @@ TensorView* index_select(TensorView* lookup_tv, int dim, TensorView* index_tv) {
       TensorDomain::noReductions(index_tv->getMaybeRFactorDomain());
   size_t n_dims = lookup_dom.size();
   TORCH_CHECK(n_dims > 0, "index_select can not be applied to 0d tensor.");
-  TORCH_CHECK(index_dom.size() == 1, "index array must be 1d tensor.");
+  TORCH_CHECK(
+      index_dom.size() <= 1, "index array must be 1d or scalar tensor.");
+
+  if (index_dom.size() == 0) {
+    auto select_tv = select(lookup_tv, dim, index_tv);
+    return unsqueeze(select_tv, dim);
+  }
 
   if (dim < 0) {
     dim += lookup_dom.size();
diff --git a/third_party/nvfuser/csrc/ops/arith.h b/third_party/nvfuser/csrc/ops/arith.h
index 8eb398ecf0fee..cb1cdbc39149c 100644
--- a/third_party/nvfuser/csrc/ops/arith.h
+++ b/third_party/nvfuser/csrc/ops/arith.h
@@ -140,7 +140,7 @@ TORCH_CUDA_CU_API WelfordResult WelfordRaw(
     // import IrBuilder just for this one interface.
     Int* init_N = nullptr);
 
-TORCH_CUDA_CU_API TensorView* select(TensorView* tv, int dim, Int* index);
+TORCH_CUDA_CU_API TensorView* select(TensorView* tv, int dim, Val* index);
 
 // RNG OPERATIONS
 TORCH_CUDA_CU_API TensorView* rand(
diff --git a/third_party/nvfuser/csrc/python_frontend/python_bindings.cpp b/third_party/nvfuser/csrc/python_frontend/python_bindings.cpp
index 0ab30d16853d5..25bf93b9ce709 100644
--- a/third_party/nvfuser/csrc/python_frontend/python_bindings.cpp
+++ b/third_party/nvfuser/csrc/python_frontend/python_bindings.cpp
@@ -318,10 +318,10 @@ void initNvFuserPythonBindings(PyObject* module) {
             maybe_symbolic_sizes.reserve(sizes.size());
             for (const auto i : c10::irange(sizes.size())) {
               TORCH_INTERNAL_ASSERT(
-                  sizes[i] > 0,
+                  sizes[i] >= 0,
                   "Size of ",
                   sizes[i],
-                  " is not supported in nvFuser. Expected size > 0.");
+                  " is not supported in nvFuser. Expected size >= 0.");
               if (sizes[i] == 1) {
                 maybe_symbolic_sizes.push_back(1);
               } else {
diff --git a/third_party/nvfuser/python_tests/test_python_frontend.py b/third_party/nvfuser/python_tests/test_python_frontend.py
index 6f02dca74eeab..d24664c6c95cb 100644
--- a/third_party/nvfuser/python_tests/test_python_frontend.py
+++ b/third_party/nvfuser/python_tests/test_python_frontend.py
@@ -746,6 +746,28 @@ def fusion_func(fd: FusionDefinition):
         test_fn(0)
         test_fn(1)
 
+    def test_index_select_scalar_indices(self):
+        inputs = [
+            torch.randn(8, 16, device="cuda"),
+            torch.tensor(2, device="cuda").to(dtype=torch.long),
+        ]
+
+        def test_fn(dim):
+            def fusion_func(fd: FusionDefinition):
+                t0 = fd.from_pytorch(inputs[0])
+                t1 = fd.from_pytorch(inputs[1])
+                t2 = fd.ops.index_select(t0, t1, dim)
+                fd.add_output(t2)
+
+            nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)
+
+            eager_out = torch.index_select(inputs[0], dim, inputs[1])
+            self.assertEqual(eager_out, nvf_out[0])
+
+        test_fn(0)
+        test_fn(1)
+
+
     def test_squeeze(self):
         t0_sizes = [4]
         t1_sizes = [1, 4, 1]