Skip to content

Commit 6cf7eb0

Browse files
authored
Transpose scheduler small dim sizes better support (#1910)
1 parent 9341ea9 commit 6cf7eb0

13 files changed

+649
-81
lines changed

c10/util/hash.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,14 @@ struct hash<std::tuple<Types...>> {
304304
}
305305
};
306306

307+
template <typename T1, typename T2>
308+
struct hash<std::pair<T1, T2>> {
309+
size_t operator()(const std::pair<T1, T2>& pair) const {
310+
std::tuple<T1, T2> tuple = std::make_tuple(pair.first, pair.second);
311+
return _hash_detail::simple_get_hash(tuple);
312+
};
313+
};
314+
307315
template <typename T>
308316
struct hash<c10::ArrayRef<T>> {
309317
size_t operator()(c10::ArrayRef<T> v) const {

test/cpp/jit/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ if(USE_CUDA)
103103
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_view.cpp)
104104
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp)
105105
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_rng.cu)
106+
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_scheduler_utils.cpp)
106107
endif()
107108

108109
add_executable(test_jit

torch/csrc/jit/codegen/cuda/ir_nodes.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1765,7 +1765,7 @@ void TensorDomain::split(
17651765
resetDomains();
17661766
}
17671767

1768-
// Merge "axis" and "axis+1" into 1 dimension
1768+
// Merge "axis_o" and "axis_i" into 1 dimension
17691769
void TensorDomain::merge(int axis_o, int axis_i) {
17701770
TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to do merge on a 0-dim domain");
17711771
if (axis_o < 0)

torch/csrc/jit/codegen/cuda/scheduler/registry.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -829,7 +829,7 @@ class ReductionScheduler : public SchedulerEntry {
829829

830830
//! Check if the reduction heuristics apply in given fusion
831831
static bool canScheduleCompileTime(Fusion* fusion) {
832-
// Temporarily allow view in reduction scheduler
832+
// Temporarily disallow view in reduction scheduler
833833
// TODO Add more testing before enabling
834834
auto view_tvs = scheduler_utils::getViewTVs(fusion);
835835
if (view_tvs.size() > 0) {
@@ -1259,6 +1259,15 @@ class TransposeScheduler : public SchedulerEntry {
12591259
// Not enabling this yet. Needs more validation.
12601260
return false;
12611261
#if 0
1262+
// Temporarily disallow view in transpose scheduler
1263+
// TODO Add more testing before enabling
1264+
auto view_tvs = scheduler_utils::getViewTVs(fusion);
1265+
if (view_tvs.size() > 0) {
1266+
scheduler_debug_utils::canScheduleRejectReason(
1267+
ScheduleHeuristic::Reduction, "No support for view op");
1268+
return false;
1269+
}
1270+
12621271
if (!hasAtLeastTwoValidGroups(fusion)) {
12631272
scheduler_debug_utils::canScheduleRejectReason(
12641273
ScheduleHeuristic::Transpose,

0 commit comments

Comments
 (0)