Skip to content

Commit 90a51f2

Browse files
authored
Some indexing cleanups, Add eye support (#1940)
1 parent ddc01e4 commit 90a51f2

26 files changed

+454
-95
lines changed

torch/csrc/jit/codegen/cuda/arith.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,23 @@ TensorView* arange(Val* start, Val* end, Val* step, DataType dtype) {
   return out;
 }
 
+TensorView* eye(Val* rows, Val* cols, DataType dtype) {
+  TORCH_CHECK(rows->getDataType() == DataType::Int, "rows must have type Int");
+  TORCH_CHECK(cols->getDataType() == DataType::Int, "cols must have type Int");
+  auto out = TensorViewBuilder()
+                 .ndims(2)
+                 .dtype(dtype)
+                 .contiguity({true, true})
+                 .shape(std::vector<Val*>{rows, cols})
+                 .build();
+  IrBuilder::create<EyeOp>(out, dtype);
+  return out;
+}
+
+TensorView* eye(Val* size, DataType dtype) {
+  return eye(size, size, dtype);
+}
+
 // UNARY OPERATIONS
 
 #define NVFUSER_DEFINE_UNARY_OP(op_name, op_type) \

torch/csrc/jit/codegen/cuda/arith.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,8 @@ TORCH_CUDA_CU_API TensorView* arange(
     Val* end,
     Val* step,
     DataType dtype = DataType::Int);
+TORCH_CUDA_CU_API TensorView* eye(Val* size, DataType dtype);
+TORCH_CUDA_CU_API TensorView* eye(Val* rows, Val* cols, DataType dtype);
 
 // UNARY OPERATIONS
 // abs

torch/csrc/jit/codegen/cuda/codegen.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -566,12 +566,20 @@ class CudaKernelGenerator : private OptOutConstDispatch {
   }
 
   void handle(const ARangeOp* aop) final {
-    auto index = genTensorIndex(aop->getLinearIndex()->as<kir::TensorIndex>());
+    auto index =
+        genTensorIndex(aop->getLinearLogicalIndex()->as<kir::TensorIndex>());
     indent() << gen(aop->output(0)) << " = arange<" << aop->dtype() << ">";
     code_ << "(" << index << ", " << gen(aop->start()) << ", "
           << gen(aop->step()) << ");\n";
   }
 
+  void handle(const EyeOp* aop) final {
+    auto index1 = gen(aop->getIndex1());
+    auto index2 = gen(aop->getIndex2());
+    indent() << gen(aop->output(0)) << " = (" << aop->dtype() << ")";
+    code_ << "(" << index1 << " == " << index2 << ");\n";
+  }
+
   void handle(const UnaryOp* uop) final {
     bool is_vector_op = false;
     size_t vector_word_size = 1;

torch/csrc/jit/codegen/cuda/dispatch.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,9 @@ void Expr::dispatch(T handler, Expr* expr) {
     case ExprType::ARangeOp:
       ptr(handler)->handle(expr->as<ARangeOp>());
       return;
+    case ExprType::EyeOp:
+      ptr(handler)->handle(expr->as<EyeOp>());
+      return;
     case ExprType::UnaryOp:
       ptr(handler)->handle(expr->as<UnaryOp>());
       return;
@@ -290,6 +293,9 @@ void Expr::constDispatch(T handler, const Expr* expr) {
     case ExprType::ARangeOp:
       ptr(handler)->handle(expr->as<ARangeOp>());
       return;
+    case ExprType::EyeOp:
+      ptr(handler)->handle(expr->as<EyeOp>());
+      return;
     case ExprType::UnaryOp:
       ptr(handler)->handle(expr->as<UnaryOp>());
       return;
@@ -487,6 +493,9 @@ void Expr::mutatorDispatch(T mutator, Expr* expr) {
     case ExprType::ARangeOp:
       ptr(mutator)->mutate(expr->as<ARangeOp>());
       return;
+    case ExprType::EyeOp:
+      ptr(mutator)->mutate(expr->as<EyeOp>());
+      return;
     case ExprType::UnaryOp:
       ptr(mutator)->mutate(expr->as<UnaryOp>());
       return;
@@ -749,6 +758,9 @@ void OptOutConstDispatch::handle(const FullOp* stmt) {
 void OptOutConstDispatch::handle(const ARangeOp* stmt) {
   unhandled(stmt);
 }
+void OptOutConstDispatch::handle(const EyeOp* stmt) {
+  unhandled(stmt);
+}
 void OptOutConstDispatch::handle(const UnaryOp* stmt) {
   unhandled(stmt);
 }
@@ -908,6 +920,9 @@ void OptOutDispatch::handle(FullOp* stmt) {
 void OptOutDispatch::handle(ARangeOp* stmt) {
   unhandled(stmt);
 }
+void OptOutDispatch::handle(EyeOp* stmt) {
+  unhandled(stmt);
+}
 void OptOutDispatch::handle(UnaryOp* stmt) {
   unhandled(stmt);
 }

torch/csrc/jit/codegen/cuda/dispatch.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ class NamedScalar;
 // Exprs
 class FullOp;
 class ARangeOp;
+class EyeOp;
 class UnaryOp;
 class BinaryOp;
 class TernaryOp;
@@ -147,6 +148,7 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase {
   // Exprs
   virtual void handle(const FullOp* stmt);
   virtual void handle(const ARangeOp* stmt);
+  virtual void handle(const EyeOp* stmt);
   virtual void handle(const UnaryOp* stmt);
   virtual void handle(const BinaryOp* stmt);
   virtual void handle(const TernaryOp* stmt);
@@ -215,6 +217,7 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase {
   // Exprs
   virtual void handle(FullOp* stmt);
   virtual void handle(ARangeOp* stmt);
+  virtual void handle(EyeOp* stmt);
   virtual void handle(UnaryOp* stmt);
   virtual void handle(BinaryOp* stmt);
   virtual void handle(TernaryOp* stmt);
@@ -324,6 +327,7 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase {
   // Exprs
   virtual void mutate(FullOp*);
   virtual void mutate(ARangeOp*);
+  virtual void mutate(EyeOp*);
   virtual void mutate(UnaryOp*);
   virtual void mutate(BinaryOp*);
   virtual void mutate(TernaryOp*);

torch/csrc/jit/codegen/cuda/index_compute.cpp

Lines changed: 76 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1937,52 +1937,55 @@ std::vector<Val*> Index::getNonGlobalProducerStridedIndices(
   return strided_inds;
 }
 
-std::vector<Val*> Index::getLinearIndex(
-    TensorView* consumer_tv,
-    const std::vector<kir::ForLoop*>& loops) {
+template <typename func_t>
+auto evaluateWithOverridenContiguity(
+    TensorView* tv,
+    bool contiguity,
+    const func_t& functor) -> decltype(functor()) {
   // Use domain guard to ignore the contiguity of
   // consumer tv.
-  TensorDomain* consumer_tv_no_contiguity_domain = nullptr;
-  auto contiguity_vector =
-      std::vector<bool>(consumer_tv->getMaybeRFactorDomain().size(), true);
-  if (consumer_tv->hasRFactor()) {
-    consumer_tv_no_contiguity_domain = IrBuilder::create<TensorDomain>(
-        consumer_tv->getRootDomain(),
-        consumer_tv->getRFactorDomain(),
-        consumer_tv->domain()->domain(),
+  TensorDomain* domain_with_specified_contiguity = nullptr;
+  std::vector<bool> contiguity_vector(
+      tv->getMaybeRFactorDomain().size(), contiguity);
+  if (tv->hasRFactor()) {
+    domain_with_specified_contiguity = IrBuilder::create<TensorDomain>(
+        tv->getRootDomain(),
+        tv->getRFactorDomain(),
+        tv->domain()->domain(),
         contiguity_vector);
   } else {
-    consumer_tv_no_contiguity_domain = IrBuilder::create<TensorDomain>(
-        consumer_tv->getRootDomain(),
-        consumer_tv->domain()->domain(),
-        contiguity_vector);
+    domain_with_specified_contiguity = IrBuilder::create<TensorDomain>(
+        tv->getRootDomain(), tv->domain()->domain(), contiguity_vector);
   }
 
-  ir_utils::TVDomainGuard domain_guard(
-      consumer_tv, consumer_tv_no_contiguity_domain);
+  ir_utils::TVDomainGuard domain_guard(tv, domain_with_specified_contiguity);
 
-  // TODO:
-  // More optimization on the underlying tensor layout
-  // will be done in a follow up.
-  return getGlobalConsumerStridedIndices(consumer_tv, loops);
+  return functor();
 }
 
-std::vector<Val*> Index::getGlobalConsumerStridedIndices(
-    const TensorView* consumer_tv,
+std::vector<Val*> Index::getLinearLogicalIndex(
+    TensorView* consumer_tv,
     const std::vector<kir::ForLoop*>& loops) {
-  FUSER_PERF_SCOPE("GpuLower::Lower::getGlobalConsumerIndex");
-
-  auto gpu_lower = GpuLower::current();
-
-  auto index_from_id_graph = getTensorIndexFromIdGraph(loops, consumer_tv);
+  return evaluateWithOverridenContiguity(consumer_tv, true, [&]() {
+    return getGlobalConsumerStridedIndices(consumer_tv, loops);
+  });
+}
 
-  auto consumer_indexing = index_from_id_graph.index;
+std::vector<Val*> Index::getPerDimLogicalIndex(
+    TensorView* consumer_tv,
+    const std::vector<kir::ForLoop*>& loops) {
+  return evaluateWithOverridenContiguity(consumer_tv, false, [&]() {
+    IndexFromIdGraph index_from_id_graph =
+        getTensorIndexFromIdGraph(loops, consumer_tv);
+    return getRootIndices(consumer_tv, loops, index_from_id_graph);
+  });
+}
 
+std::vector<Val*> Index::getStrides(const TensorView* tv) {
   // Indices should now be mapped onto IterDomains in consumer, so just grab
   // and use them.
-  auto root_dom = consumer_tv->getMaybeRFactorDomain();
+  auto root_dom = tv->getMaybeRFactorDomain();
 
-  // TODO: Abstract stride logic to reuse with producer indexing
   std::vector<Val*> strides(
       root_dom.size(), GpuLower::current()->kernel()->oneVal());
   {
@@ -1993,39 +1996,21 @@ std::vector<Val*> Index::getGlobalConsumerStridedIndices(
         continue;
       }
       std::stringstream ss;
-      ss << "T" << consumer_tv->name() << ".stride[" << stride_i++ << "]";
+      ss << "T" << tv->name() << ".stride[" << stride_i++ << "]";
       strides[i] =
           SimplifyingIrBuilder::create<NamedScalar>(ss.str(), DataType::Int);
     }
   }
 
-  TORCH_INTERNAL_ASSERT(
-      root_dom.size() == consumer_tv->domain()->contiguity().size());
+  TORCH_INTERNAL_ASSERT(root_dom.size() == tv->domain()->contiguity().size());
   Val* cur_contig_stride = GpuLower::current()->kernel()->oneVal();
   for (const auto i : c10::irange(root_dom.size())) {
     auto dim = root_dom.size() - i - 1;
     if (root_dom[dim]->isReduction() || root_dom[dim]->isStride()) {
       continue;
     }
 
-    Val* root_ind = nullptr;
-    if (consumer_indexing.indexMap().find(root_dom[dim]) !=
-        consumer_indexing.indexMap().end()) {
-      root_ind = consumer_indexing.indexMap().at(root_dom[dim]);
-    } else if (root_dom[dim]->isBroadcast()) {
-      root_ind = GpuLower::current()->kernel()->zeroVal();
-    }
-
-    TORCH_INTERNAL_ASSERT(
-        root_ind != nullptr,
-        "Couldn't find root mapping for ",
-        consumer_tv->toString(),
-        " dim: ",
-        dim,
-        " id: ",
-        root_dom[dim]->toString());
-
-    if (consumer_tv->domain()->contiguity()[dim]) {
+    if (tv->domain()->contiguity()[dim]) {
       // If contig, used the stored stride which may be the previous
       // dimensions stride * previous dimensions size
       strides[dim] = cur_contig_stride;
@@ -2041,12 +2026,18 @@ std::vector<Val*> Index::getGlobalConsumerStridedIndices(
           strides[dim], getHaloExtentOfRootAxis(root_dom[dim]));
     }
   }
+  return strides;
+}
 
-  auto vectorize_shift =
-      loops.empty() ? nullptr : loops.back()->vectorize_shift();
+std::vector<Val*> Index::getRootIndices(
+    const TensorView* tv,
+    const std::vector<kir::ForLoop*>& loops,
+    const IndexFromIdGraph& index_from_id_graph) {
+  auto gpu_lower = GpuLower::current();
+  auto root_dom = tv->getMaybeRFactorDomain();
+  auto indexing = index_from_id_graph.index;
 
-  // Global striding
-  std::vector<Val*> strided_inds(
+  std::vector<Val*> root_inds(
       root_dom.size(), GpuLower::current()->kernel()->zeroVal());
   for (const auto i : c10::irange(root_dom.size())) {
     // See a comment in indexing to root domains in getGlobalProducerIndex.
@@ -2057,35 +2048,55 @@ std::vector<Val*> Index::getGlobalConsumerStridedIndices(
     }
 
     TORCH_INTERNAL_ASSERT(
-        consumer_indexing.indexMap().find(root_dom[i]) !=
-            consumer_indexing.indexMap().end(),
+        indexing.indexMap().find(root_dom[i]) != indexing.indexMap().end(),
         "Couldn't find root mapping for ",
-        consumer_tv->toString(),
+        tv->toString(),
         " dim: ",
         i,
         " id: ",
         root_dom[i]->toString());
 
-    auto root_ind = consumer_indexing.indexMap().at(root_dom[i]);
+    auto root_ind = indexing.indexMap().at(root_dom[i]);
 
     // index hoist must be done before the adjustments for halo
     root_ind = hoistConsumerIndex(
         root_dom[i],
-        consumer_tv,
-        consumer_indexing,
+        tv,
+        indexing,
         index_from_id_graph.resolved_loop_domains,
         index_from_id_graph.initial_concrete_index_map,
         loops,
         root_ind);
 
     root_ind = SimplifyingIrBuilder::addExpr(
         root_ind, getGlobalConsumerOffsetWithPartialSplit(root_dom[i]));
+    root_inds[i] = root_ind;
+  }
+  return root_inds;
+}
 
-    if (root_ind->isZeroInt()) {
+std::vector<Val*> Index::getGlobalConsumerStridedIndices(
+    const TensorView* consumer_tv,
+    const std::vector<kir::ForLoop*>& loops) {
+  FUSER_PERF_SCOPE("GpuLower::Lower::getGlobalConsumerIndex");
+
+  auto index_from_id_graph = getTensorIndexFromIdGraph(loops, consumer_tv);
+  auto consumer_indexing = index_from_id_graph.index;
+  auto strides = getStrides(consumer_tv);
+  auto root_inds = getRootIndices(consumer_tv, loops, index_from_id_graph);
+
+  // Global striding
+  auto vectorize_shift =
+      loops.empty() ? nullptr : loops.back()->vectorize_shift();
+  std::vector<Val*> strided_inds(
+      root_inds.size(), GpuLower::current()->kernel()->zeroVal());
+  for (const auto i : c10::irange(root_inds.size())) {
+    if (root_inds[i]->isZeroInt()) {
       continue;
     } else {
-      auto strided_ind = SimplifyingIrBuilder::mulExpr(root_ind, strides[i]);
-      if (i == root_dom.size() - 1 && vectorize_shift != nullptr) {
+      auto strided_ind =
+          SimplifyingIrBuilder::mulExpr(root_inds[i], strides[i]);
+      if (i == strides.size() - 1 && vectorize_shift != nullptr) {
         strided_inds[i] =
             SimplifyingIrBuilder::addExpr(strided_ind, vectorize_shift);
       } else {

0 commit comments

Comments
 (0)