@@ -1937,52 +1937,55 @@ std::vector<Val*> Index::getNonGlobalProducerStridedIndices(
   return strided_inds;
 }
 
-std::vector<Val*> Index::getLinearIndex(
-    TensorView* consumer_tv,
-    const std::vector<kir::ForLoop*>& loops) {
+template <typename func_t>
+auto evaluateWithOverridenContiguity(
+    TensorView* tv,
+    bool contiguity,
+    const func_t& functor) -> decltype(functor()) {
   // Use domain guard to ignore the contiguity of
   //  consumer tv.
-  TensorDomain* consumer_tv_no_contiguity_domain = nullptr;
-  auto contiguity_vector =
-      std::vector<bool>(consumer_tv->getMaybeRFactorDomain().size(), true);
-  if (consumer_tv->hasRFactor()) {
-    consumer_tv_no_contiguity_domain = IrBuilder::create<TensorDomain>(
-        consumer_tv->getRootDomain(),
-        consumer_tv->getRFactorDomain(),
-        consumer_tv->domain()->domain(),
+  TensorDomain* domain_with_specified_contiguity = nullptr;
+  std::vector<bool> contiguity_vector(
+      tv->getMaybeRFactorDomain().size(), contiguity);
+  if (tv->hasRFactor()) {
+    domain_with_specified_contiguity = IrBuilder::create<TensorDomain>(
+        tv->getRootDomain(),
+        tv->getRFactorDomain(),
+        tv->domain()->domain(),
         contiguity_vector);
   } else {
-    consumer_tv_no_contiguity_domain = IrBuilder::create<TensorDomain>(
-        consumer_tv->getRootDomain(),
-        consumer_tv->domain()->domain(),
-        contiguity_vector);
+    domain_with_specified_contiguity = IrBuilder::create<TensorDomain>(
+        tv->getRootDomain(), tv->domain()->domain(), contiguity_vector);
   }
 
-  ir_utils::TVDomainGuard domain_guard(
-      consumer_tv, consumer_tv_no_contiguity_domain);
+  ir_utils::TVDomainGuard domain_guard(tv, domain_with_specified_contiguity);
 
-  // TODO:
-  //  More optimization on the underlying tensor layout
-  //  will be done in a follow up.
-  return getGlobalConsumerStridedIndices(consumer_tv, loops);
+  return functor();
 }
 
-std::vector<Val*> Index::getGlobalConsumerStridedIndices(
-    const TensorView* consumer_tv,
+std::vector<Val*> Index::getLinearLogicalIndex(
+    TensorView* consumer_tv,
     const std::vector<kir::ForLoop*>& loops) {
-  FUSER_PERF_SCOPE("GpuLower::Lower::getGlobalConsumerIndex");
-
-  auto gpu_lower = GpuLower::current();
-
-  auto index_from_id_graph = getTensorIndexFromIdGraph(loops, consumer_tv);
+  return evaluateWithOverridenContiguity(consumer_tv, true, [&]() {
+    return getGlobalConsumerStridedIndices(consumer_tv, loops);
+  });
+}
 
-  auto consumer_indexing = index_from_id_graph.index;
+std::vector<Val*> Index::getPerDimLogicalIndex(
+    TensorView* consumer_tv,
+    const std::vector<kir::ForLoop*>& loops) {
+  return evaluateWithOverridenContiguity(consumer_tv, false, [&]() {
+    IndexFromIdGraph index_from_id_graph =
+        getTensorIndexFromIdGraph(loops, consumer_tv);
+    return getRootIndices(consumer_tv, loops, index_from_id_graph);
+  });
+}
 
+std::vector<Val*> Index::getStrides(const TensorView* tv) {
   // Indices should now be mapped onto IterDomains in consumer, so just grab
   // and use them.
-  auto root_dom = consumer_tv->getMaybeRFactorDomain();
+  auto root_dom = tv->getMaybeRFactorDomain();
 
-  // TODO: Abstract stride logic to reuse with producer indexing
   std::vector<Val*> strides(
       root_dom.size(), GpuLower::current()->kernel()->oneVal());
   {
@@ -1993,39 +1996,21 @@ std::vector<Val*> Index::getGlobalConsumerStridedIndices(
         continue;
       }
       std::stringstream ss;
-      ss << "T" << consumer_tv->name() << ".stride[" << stride_i++ << "]";
+      ss << "T" << tv->name() << ".stride[" << stride_i++ << "]";
       strides[i] =
           SimplifyingIrBuilder::create<NamedScalar>(ss.str(), DataType::Int);
     }
   }
 
-  TORCH_INTERNAL_ASSERT(
-      root_dom.size() == consumer_tv->domain()->contiguity().size());
+  TORCH_INTERNAL_ASSERT(root_dom.size() == tv->domain()->contiguity().size());
   Val* cur_contig_stride = GpuLower::current()->kernel()->oneVal();
   for (const auto i : c10::irange(root_dom.size())) {
     auto dim = root_dom.size() - i - 1;
     if (root_dom[dim]->isReduction() || root_dom[dim]->isStride()) {
       continue;
     }
 
-    Val* root_ind = nullptr;
-    if (consumer_indexing.indexMap().find(root_dom[dim]) !=
-        consumer_indexing.indexMap().end()) {
-      root_ind = consumer_indexing.indexMap().at(root_dom[dim]);
-    } else if (root_dom[dim]->isBroadcast()) {
-      root_ind = GpuLower::current()->kernel()->zeroVal();
-    }
-
-    TORCH_INTERNAL_ASSERT(
-        root_ind != nullptr,
-        "Couldn't find root mapping for ",
-        consumer_tv->toString(),
-        " dim: ",
-        dim,
-        " id: ",
-        root_dom[dim]->toString());
-
-    if (consumer_tv->domain()->contiguity()[dim]) {
+    if (tv->domain()->contiguity()[dim]) {
       // If contig, used the stored stride which may be the previous
       // dimensions stride * previous dimensions size
       strides[dim] = cur_contig_stride;
@@ -2041,12 +2026,18 @@ std::vector<Val*> Index::getGlobalConsumerStridedIndices(
           strides[dim], getHaloExtentOfRootAxis(root_dom[dim]));
     }
   }
+  return strides;
+}
 
-  auto vectorize_shift =
-      loops.empty() ? nullptr : loops.back()->vectorize_shift();
+std::vector<Val*> Index::getRootIndices(
+    const TensorView* tv,
+    const std::vector<kir::ForLoop*>& loops,
+    const IndexFromIdGraph& index_from_id_graph) {
+  auto gpu_lower = GpuLower::current();
+  auto root_dom = tv->getMaybeRFactorDomain();
+  auto indexing = index_from_id_graph.index;
 
-  // Global striding
-  std::vector<Val*> strided_inds(
+  std::vector<Val*> root_inds(
       root_dom.size(), GpuLower::current()->kernel()->zeroVal());
   for (const auto i : c10::irange(root_dom.size())) {
     // See a comment in indexing to root domains in getGlobalProducerIndex.
@@ -2057,35 +2048,55 @@ std::vector<Val*> Index::getGlobalConsumerStridedIndices(
     }
 
     TORCH_INTERNAL_ASSERT(
-        consumer_indexing.indexMap().find(root_dom[i]) !=
-            consumer_indexing.indexMap().end(),
+        indexing.indexMap().find(root_dom[i]) != indexing.indexMap().end(),
        "Couldn't find root mapping for ",
-        consumer_tv->toString(),
+        tv->toString(),
        " dim: ",
        i,
        " id: ",
        root_dom[i]->toString());
 
-    auto root_ind = consumer_indexing.indexMap().at(root_dom[i]);
+    auto root_ind = indexing.indexMap().at(root_dom[i]);
 
     // index hoist must be done before the adjustments for halo
     root_ind = hoistConsumerIndex(
         root_dom[i],
-        consumer_tv,
-        consumer_indexing,
+        tv,
+        indexing,
        index_from_id_graph.resolved_loop_domains,
        index_from_id_graph.initial_concrete_index_map,
        loops,
        root_ind);
 
     root_ind = SimplifyingIrBuilder::addExpr(
         root_ind, getGlobalConsumerOffsetWithPartialSplit(root_dom[i]));
+    root_inds[i] = root_ind;
+  }
+  return root_inds;
+}
 
-    if (root_ind->isZeroInt()) {
+std::vector<Val*> Index::getGlobalConsumerStridedIndices(
+    const TensorView* consumer_tv,
+    const std::vector<kir::ForLoop*>& loops) {
+  FUSER_PERF_SCOPE("GpuLower::Lower::getGlobalConsumerIndex");
+
+  auto index_from_id_graph = getTensorIndexFromIdGraph(loops, consumer_tv);
+  auto consumer_indexing = index_from_id_graph.index;
+  auto strides = getStrides(consumer_tv);
+  auto root_inds = getRootIndices(consumer_tv, loops, index_from_id_graph);
+
+  // Global striding
+  auto vectorize_shift =
+      loops.empty() ? nullptr : loops.back()->vectorize_shift();
+  std::vector<Val*> strided_inds(
+      root_inds.size(), GpuLower::current()->kernel()->zeroVal());
+  for (const auto i : c10::irange(root_inds.size())) {
+    if (root_inds[i]->isZeroInt()) {
       continue;
     } else {
-      auto strided_ind = SimplifyingIrBuilder::mulExpr(root_ind, strides[i]);
-      if (i == root_dom.size() - 1 && vectorize_shift != nullptr) {
+      auto strided_ind =
+          SimplifyingIrBuilder::mulExpr(root_inds[i], strides[i]);
+      if (i == strides.size() - 1 && vectorize_shift != nullptr) {
        strided_inds[i] =
            SimplifyingIrBuilder::addExpr(strided_ind, vectorize_shift);
      } else {
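
Note on the new helper (a reviewer-style sketch, not part of the patch): `evaluateWithOverridenContiguity` runs an arbitrary indexing callback while a `TVDomainGuard` temporarily installs a domain whose contiguity flags are all forced to the requested value, and it forwards the callback's result via `decltype(functor())`. The self-contained sketch below mirrors that shape under stated assumptions; `Tensor` and `ContiguityGuard` are illustrative stand-ins, not nvfuser classes.

```cpp
// Minimal sketch of the pattern behind evaluateWithOverridenContiguity:
// temporarily override a tensor property with an RAII guard, run a callback,
// and forward whatever the callback returns.
#include <cassert>
#include <vector>

struct Tensor {
  std::vector<bool> contiguity;
};

// RAII guard: overrides every contiguity flag on construction and restores
// the original flags on destruction. (TVDomainGuard swaps a whole domain;
// swapping just the flags is enough to show the idea.)
struct ContiguityGuard {
  Tensor* tv;
  std::vector<bool> saved;
  ContiguityGuard(Tensor* t, bool value) : tv(t), saved(t->contiguity) {
    tv->contiguity.assign(saved.size(), value);
  }
  ~ContiguityGuard() {
    tv->contiguity = saved;
  }
};

// Same shape as the helper in the diff: the return type is deduced from the
// functor, so any indexing routine can be wrapped without a fixed signature.
template <typename func_t>
auto evaluateWithOverridenContiguity(
    Tensor* tv, bool contiguity, const func_t& functor) -> decltype(functor()) {
  ContiguityGuard guard(tv, contiguity);
  return functor();
}

int main() {
  Tensor t{{true, false, true}};
  auto seen = evaluateWithOverridenContiguity(
      &t, true, [&]() { return t.contiguity; });
  assert(seen == std::vector<bool>({true, true, true}));          // overridden inside
  assert(t.contiguity == std::vector<bool>({true, false, true})); // restored after
  return 0;
}
```

Deducing the return type is what lets the same wrapper back both `getLinearLogicalIndex` (strided indices, contiguity forced to true) and `getPerDimLogicalIndex` (per-dimension root indices, contiguity forced to false).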
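For context on the final combination step in the refactored `getGlobalConsumerStridedIndices` (per-dimension root indices scaled by their strides and summed, with statically-zero indices skipped), here is a minimal arithmetic sketch: plain `int64_t` stands in for nvfuser `Val*`, and the `vectorize_shift` adjustment is omitted.

```cpp
// Arithmetic sketch of the global striding step: offset = sum_i ind[i] * stride[i],
// skipping indices known to be zero.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int64_t linearOffset(
    const std::vector<int64_t>& root_inds,
    const std::vector<int64_t>& strides) {
  int64_t offset = 0;
  for (std::size_t i = 0; i < root_inds.size(); ++i) {
    if (root_inds[i] == 0) {
      continue;  // mirrors the isZeroInt() early-continue in the diff
    }
    offset += root_inds[i] * strides[i];
  }
  return offset;
}

int main() {
  // Index (1, 0, 3) with strides (12, 4, 1) -> offset 1*12 + 0*4 + 3*1 = 15.
  std::cout << linearOffset({1, 0, 3}, {12, 4, 1}) << "\n";
  return 0;
}
```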