Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions third_party/nvfuser/csrc/lower_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,9 +294,26 @@ void IndexLowering::handle(const ScatterOp* sop) {
}

void IndexLowering::handle(const SelectOp* sop) {
const auto input = lowerSrcIndex(
sop->input(0), sop->output(0), sop->getIndexOverridingMap());
auto lowered_index = lowerSrcIndex(sop->input(1), sop->output(0));
auto lowered_index_cast = lowered_index;

// If the type of the index tensor is different from the kernel
// index type, promote it to the kernel index type
if (GpuLower::current()->kernel()->indexType() !=
sop->input(1)->getDataType().value()) {
lowered_index_cast =
IrBuilder::newScalar(GpuLower::current()->kernel()->indexType());
IrBuilder::create<UnaryOp>(
UnaryOpType::Cast, lowered_index_cast, lowered_index);
}

const std::unordered_map<IterDomain*, Val*> override_index = {
{sop->getSelectAxis(), lowered_index_cast}};
const auto input =
lowerSrcIndex(sop->input(0), sop->output(0), override_index);

const auto out = lowerDstIndex(sop->output(0));

pushBack(IrBuilder::create<UnaryOp>(UnaryOpType::Set, out, input));
GpuLower::current()->propagateExprInfo(sop, back());
}
Expand Down
10 changes: 8 additions & 2 deletions third_party/nvfuser/csrc/ops/arith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ TensorView* unaryOp(
return unaryOp(type, cast_v1)->as<TensorView>();
}

TensorView* select(TensorView* tv, int dim, Int* index) {
TensorView* select(TensorView* tv, int dim, Val* index) {
auto dom = TensorDomain::noReductions(tv->getMaybeRFactorDomain());
TORCH_CHECK(dom.size() > 0, "select can not be applied to 0d tensor.");

Expand Down Expand Up @@ -137,7 +137,13 @@ TensorView* index_select(TensorView* lookup_tv, int dim, TensorView* index_tv) {
TensorDomain::noReductions(index_tv->getMaybeRFactorDomain());
size_t n_dims = lookup_dom.size();
TORCH_CHECK(n_dims > 0, "index_select can not be applied to 0d tensor.");
TORCH_CHECK(index_dom.size() == 1, "index array must be 1d tensor.");
TORCH_CHECK(
index_dom.size() <= 1, "index array must be 1d or scalar tensor.");

if (index_dom.size() == 0) {
auto select_tv = select(lookup_tv, dim, index_tv);
return unsqueeze(select_tv, dim);
}

if (dim < 0) {
dim += lookup_dom.size();
Expand Down
2 changes: 1 addition & 1 deletion third_party/nvfuser/csrc/ops/arith.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ TORCH_CUDA_CU_API WelfordResult WelfordRaw(
// import IrBuilder just for this one interface.
Int* init_N = nullptr);

TORCH_CUDA_CU_API TensorView* select(TensorView* tv, int dim, Int* index);
TORCH_CUDA_CU_API TensorView* select(TensorView* tv, int dim, Val* index);

// RNG OPERATIONS
TORCH_CUDA_CU_API TensorView* rand(
Expand Down
4 changes: 2 additions & 2 deletions third_party/nvfuser/csrc/python_frontend/python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,10 +318,10 @@ void initNvFuserPythonBindings(PyObject* module) {
maybe_symbolic_sizes.reserve(sizes.size());
for (const auto i : c10::irange(sizes.size())) {
TORCH_INTERNAL_ASSERT(
sizes[i] > 0,
sizes[i] >= 0,
"Size of ",
sizes[i],
" is not supported in nvFuser. Expected size > 0.");
" is not supported in nvFuser. Expected size >= 0.");
if (sizes[i] == 1) {
maybe_symbolic_sizes.push_back(1);
} else {
Expand Down
22 changes: 22 additions & 0 deletions third_party/nvfuser/python_tests/test_python_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,28 @@ def fusion_func(fd: FusionDefinition):
test_fn(0)
test_fn(1)

def test_index_select_scalar_indices(self):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@naoyam tests added and verified the failing after reverting changes in codegen.cpp. It's all yours now.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, BTW, current branch I'm on is broken. You might want to revert the fouling commit for scatter 9340f80

inputs = [
torch.randn(8, 16, device="cuda"),
torch.tensor(2, device="cuda").to(dtype=torch.long),
]

def test_fn(dim):
def fusion_func(fd: FusionDefinition):
t0 = fd.from_pytorch(inputs[0])
t1 = fd.from_pytorch(inputs[1])
t2 = fd.ops.index_select(t0, t1, dim)
fd.add_output(t2)

nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)

eager_out = torch.index_select(inputs[0], dim, inputs[1])
self.assertEqual(eager_out, nvf_out[0])

test_fn(0)
test_fn(1)


def test_squeeze(self):
t0_sizes = [4]
t1_sizes = [1, 4, 1]
Expand Down