
Commit deb58f8

naoyam, zasdfgbnm, and csarofeen authored
Vec perf testing devel merge (#2041)
* Fix vectorize size calculation (#2035)
* Allow non-root trivial reductions (#2037). Fixes #2008.
* Test file cleanup (#2040): move test_gpu.cpp to test_gpu1.cpp, then split it into test_gpu1.cpp, test_gpu2.cpp, and test_gpu3.cpp. Each file should be at most 10K LoC; new tests should be added to test_gpu3.cpp until it reaches 10K LoC.

Co-authored-by: Gao, Xiang <[email protected]>
Co-authored-by: Christian Sarofeen <[email protected]>
1 parent: 8bbb00e

12 files changed: +26592 -26205 lines


test/cpp/jit/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
@@ -99,7 +99,9 @@ if(USE_CUDA)
   list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/python_frontend/test/test_nvfuser_fusion_definition.cpp)
   list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/python_frontend/test/test_nvfuser_fusion_cache.cpp)
   list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/python_frontend/test/test_nvfuser_fusion_record.cpp)
-  list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu.cpp)
+  list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu1.cpp)
+  list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu2.cpp)
+  list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu3.cpp)
   list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_tensor_factories.cpp)
   list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp)
   list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_shift.cpp)

torch/csrc/jit/codegen/cuda/ir_internal_nodes.h

Lines changed: 2 additions & 10 deletions
@@ -1416,16 +1416,8 @@ class TORCH_CUDA_CU_API IterDomain : public Val {
   }

   //! Check if IterDomain is a reduction axis with size of 1, i.e.
-  //! a "squeeze" operator.
-  //!
-  //! NOTE: Detection of trivial reduction here is not
-  //! comprehensive. See detectTrivialReductionDerivedDomains for more
-  //! comprehensive analysis. We typically use this for root domain trivial
-  //! reduction checks. So we ship to the correct scheduler. It may
-  //! not be incredibly robust, but it makes sense to keep it for now.
-  bool isTrivialReduction() const {
-    return isReduction() && extent()->isOneInt();
-  }
+  //! a "squeeze" operator, or solely derived from such axes.
+  bool isTrivialReduction() const;

   //! Split for stride by a given factor. It effectively does an inner
   //! split by the factor and sets the inner domain as a Stride

torch/csrc/jit/codegen/cuda/ir_nodes.cpp

Lines changed: 36 additions & 1 deletion
@@ -1720,6 +1720,37 @@ IterDomain* IterDomain::cloneWithoutRFactor() const {
   return cloned;
 }

+bool IterDomain::isTrivialReduction() const {
+  if (!isReduction()) {
+    return false;
+  }
+
+  if (extent()->isOneInt()) {
+    return true;
+  }
+
+  // If this domain is an output of an expression, i.e., not a root
+  // domain, check if all root domains are trivial reductions. This is
+  // almost the same as the analysis done in TrivialReductionInfo, but
+  // is limited within a single tensor, whereas TrivialReductionInfo
+  // does more expensive analysis potentially traversing through
+  // rfactor domains
+  if (definition()) {
+    // Note: There's no const version of IterVisitor.
+    auto id_inputs = InputsOf::output(fusion(), const_cast<IterDomain*>(this));
+    if (std::all_of(
+            ir_utils::filterByType<IterDomain>(id_inputs).begin(),
+            ir_utils::filterByType<IterDomain>(id_inputs).end(),
+            [](IterDomain* root_id) {
+              return root_id->isReduction() && root_id->extent()->isOneInt();
+            })) {
+      return true;
+    }
+  }
+
+  return false;
+}
+
 std::vector<IterDomain*> IterDomain::clone(
     const std::vector<IterDomain*>& domains) {
   std::vector<IterDomain*> cloned_domains;
@@ -1744,7 +1775,11 @@ IterDomain* IterDomain::merge(IterDomain* outer, IterDomain* inner) {
       outer->isReduction() == inner->isReduction() ||
           (!outer->isReduction() && inner->isTrivialReduction()) ||
           (outer->isTrivialReduction() && !inner->isReduction()),
-      "Merging IterDomains requires that their iteration types match.");
+      "Merging IterDomains requires that their iteration types match. ",
+      "Outer: ",
+      outer->toString(),
+      ", Inner: ",
+      inner->toString());
   TORCH_CHECK(
       (outer->isGather() && inner->isGather()) ||
           (!outer->isGather() && !inner->isGather()),
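The out-of-line definition above is the heart of #2037: a reduction domain is trivial not only when its extent is statically 1, but also when it is derived solely from size-1 reduction root domains, which matters when the derived extent is symbolic and cannot be constant-folded. A minimal standalone sketch of that rule, with a plain struct standing in for IterDomain (this is an illustration, not the nvFuser API):

#include <algorithm>
#include <vector>

// Simplified stand-in for IterDomain. extent is -1 when not statically
// known; nvFuser instead asks extent()->isOneInt() symbolically.
struct Domain {
  bool is_reduction = false;
  long extent = -1;
  std::vector<Domain*> root_inputs; // empty for a root domain
};

bool isTrivialReduction(const Domain& d) {
  if (!d.is_reduction) {
    return false;
  }
  if (d.extent == 1) {
    return true;
  }
  // Non-root domain whose extent did not fold to 1: trivial iff every
  // root domain it derives from is a size-1 reduction.
  return !d.root_inputs.empty() &&
      std::all_of(d.root_inputs.begin(), d.root_inputs.end(), [](Domain* r) {
        return r->is_reduction && r->extent == 1;
      });
}

For example, a domain merged from two size-1 reduction axes may carry an extent expression that does not fold to a constant, yet the root-input check still classifies it as trivial.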

torch/csrc/jit/codegen/cuda/scheduler/registry.cpp

Lines changed: 25 additions & 0 deletions
@@ -463,6 +463,24 @@ void SchedulerRuntimeInfo::initialize(
     auto fusion_inp = complete_fusion_->inputs()[inp_i];
     auto data_ptr = tensor_arg_abstract->getPointer();
     input_ptrs_[fusion_inp] = (size_t)data_ptr;
+
+    // find and push discontiguous stride
+    auto dtype_size = dataTypeSize(tensor_arg_abstract->getDataType());
+    input_discontig_strides_[fusion_inp] = {};
+    auto dims = tensor_arg_abstract->getRank();
+    auto expected_stride = 1;
+    for (auto dim = dims - 1; dim >= 0; dim--) {
+      auto size = tensor_arg_abstract->getSize(dim);
+      if (size <= 1) {
+        continue;
+      }
+      auto stride = tensor_arg_abstract->getStride(dim);
+      if (stride != expected_stride) {
+        input_discontig_strides_[fusion_inp].push_back(stride * dtype_size);
+        expected_stride = stride;
+      }
+      expected_stride *= size;
+    }
   }
 }
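The loop above walks dimensions innermost-first, ignores size-0/1 dimensions, and records the byte stride of every dimension that breaks the running contiguous pattern. Extracted into a runnable trace (the sizes and strides here are made up for illustration):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical float tensor: sizes [8, 4, 16], strides [128, 16, 1]
  // (in elements), e.g. produced by slicing dim 1 of an [8, 8, 16] tensor.
  std::vector<int64_t> sizes = {8, 4, 16};
  std::vector<int64_t> strides = {128, 16, 1};
  const int64_t dtype_size = 4; // sizeof(float)

  std::vector<int64_t> discontig_strides;
  int64_t expected_stride = 1;
  for (int64_t dim = (int64_t)sizes.size() - 1; dim >= 0; dim--) {
    if (sizes[dim] <= 1) {
      continue; // size-1 dims impose no stride constraint
    }
    if (strides[dim] != expected_stride) {
      // Discontiguity: record the byte stride and restart the pattern here.
      discontig_strides.push_back(strides[dim] * dtype_size);
      expected_stride = strides[dim];
    }
    expected_stride *= sizes[dim];
  }

  for (auto s : discontig_strides) {
    std::cout << s << "\n"; // prints 512: only dim 0 breaks contiguity
  }
}

Dims 2 and 1 match the expected strides 1 and 16, so only dim 0 (stride 128 elements, 512 bytes, where 64 was expected) is recorded.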

@@ -529,6 +547,13 @@ size_t SchedulerRuntimeInfo::getAlignmentSize(TensorView* tv) {
   }

   auto alignment_size = SchedulerRuntimeInfo::computeAlignmentSize(ptrOf(tv));
+  auto strides_it = input_discontig_strides_.find(tv);
+  if (strides_it != input_discontig_strides_.end()) {
+    for (auto stride : strides_it->second) {
+      alignment_size = std::min(
+          alignment_size, SchedulerRuntimeInfo::computeAlignmentSize(stride));
+    }
+  }
   alignment_map_[tv] = alignment_size;
   return alignment_size;
 }
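computeAlignmentSize is not shown in this diff; assuming it returns the largest power-of-two alignment, capped at the 16-byte max vector width, that evenly divides its argument, the change means a discontiguous inner stride can now only lower the alignment reported for a tensor, since a vectorized access must not straddle a discontiguous jump. A sketch under that assumption:

#include <algorithm>
#include <cstddef>
#include <vector>

// Assumed behavior (not part of this diff): largest power-of-two
// alignment up to 16 bytes that evenly divides an address or stride.
size_t computeAlignmentSize(size_t bytes) {
  size_t alignment = 16;
  while (alignment > 1 && bytes % alignment != 0) {
    alignment /= 2;
  }
  return alignment;
}

// Mirrors the updated getAlignmentSize: the pointer's alignment is
// clamped by every discontiguous byte stride recorded for the input.
size_t effectiveAlignment(
    size_t ptr,
    const std::vector<size_t>& discontig_strides) {
  size_t alignment = computeAlignmentSize(ptr);
  for (auto stride : discontig_strides) {
    alignment = std::min(alignment, computeAlignmentSize(stride));
  }
  return alignment;
}

With the earlier example, a 16-byte-aligned pointer and a 512-byte discontiguous stride keep 16-byte vectorization legal; a 4-byte stride would cap it at 4.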

torch/csrc/jit/codegen/cuda/scheduler/registry.h

Lines changed: 4 additions & 0 deletions
@@ -27,6 +27,7 @@ class ExpressionEvaluator;
 //! segmenter and schedulers.
 //! It is important that input id encoding should be up to date with any change
 //! of this class to avoid launching compiled kernels with illegal inputs.
+
 class TORCH_CUDA_CU_API SchedulerRuntimeInfo : public NonCopyable {
  public:
   // Max vector size we will consider, in bytes,

@@ -112,6 +113,9 @@ class TORCH_CUDA_CU_API SchedulerRuntimeInfo : public NonCopyable {
   // TODO: Support output tensor pointers
   std::unordered_map<Val*, size_t> input_ptrs_;

+  // Copy of aten input tensor strides (in bytes)
+  std::unordered_map<Val*, std::vector<size_t>> input_discontig_strides_;
+
   // Cache for getAlignmentSize
   std::unordered_map<TensorView*, size_t> alignment_map_;
   // Cache for getMaxVectorizableWidth

torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.cpp

Lines changed: 14 additions & 16 deletions
@@ -82,18 +82,6 @@ size_t collectMaxVectorizeSizeWithContigMerge(
     size_t max_vector_size_in_byte,
     ExpressionEvaluator& expression_evaluator,
     DataType index_type) {
-  // Maybe too conservative, but only handles fully contiguous tensors
-  // TODO: Relax the contiguity constraint to be similar to that in index
-  // computing. Just looking for all merged root domains in the right order,
-  // all merged root dimensions are contiguous, all merged root dimensions are
-  // next to eachother (exlcuding broadcast).
-  if (std::any_of(
-          tv->domain()->contiguity().begin(),
-          tv->domain()->contiguity().end(),
-          [](const auto contig) { return !contig; })) {
-    return 1;
-  }
-
   auto dtype_size = dataTypeSize(tv->dtype(), index_type);
   const size_t max_vector_size = max_vector_size_in_byte / dtype_size;

@@ -205,8 +193,16 @@ size_t expandVectorizationToContigMergedDomains(

   // Merge the domains right of the break point
   const auto& ref_root = reference_tv->getMaybeRFactorDomain();
-  const int num_merged_domains =
+  const int max_num_merged_domains =
       static_cast<int>(ref_root.size()) - static_cast<int>(break_point);
+  int64_t num_merged_domains = 0;
+  while (num_merged_domains < max_num_merged_domains) {
+    auto pos = (int64_t)ref_root.size() - 1 - num_merged_domains;
+    if (!reference_tv->domain()->contiguity()[pos]) {
+      break;
+    }
+    num_merged_domains++;
+  }

   // No expansion with no merged domain
   if (num_merged_domains == 0) {

@@ -245,14 +241,16 @@ size_t expandVectorizationToContigMergedDomains(
   const auto& tv_root = tv->getMaybeRFactorDomain();

   int tv_num_merged_domains = 0;
-  for (const auto i : c10::irange(num_merged_domains)) {
+  for (const auto i : c10::irange(max_num_merged_domains)) {
     if (i == tv_root.size()) {
       break;
     }
     auto ref_id = ref_root.at(ref_root.size() - 1 - i);
-    IterDomain* tv_id = tv_root.at(tv_root.size() - 1 - i);
+    auto pos = tv_root.size() - 1 - i;
+    IterDomain* tv_id = tv_root.at(pos);
     // If not mapped, stop expanding.
-    if (!ca_map.areMapped(ref_id, tv_id, IdMappingMode::EXACT)) {
+    if (!ca_map.areMapped(ref_id, tv_id, IdMappingMode::EXACT) ||
+        !tv->domain()->contiguity()[pos]) {
       break;
     } else {
       ++tv_num_merged_domains;