Skip to content

Commit 306d4a6

Browse files
authored
Fix canScheduleCompileTime check of transpose scheduler (#1969)
1 parent b1bd32c commit 306d4a6

File tree

1 file changed

+24
-12
lines changed

1 file changed

+24
-12
lines changed

torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -45,18 +45,7 @@ class DomainMap : public pointwise_utils::DomainMap {
4545
return result;
4646
}
4747

48-
static bool hasAtLeastTwoValidGroups(Fusion* fusion) {
49-
FusionGuard fg(fusion);
50-
DomainMap domain_map(fusion);
51-
auto grouped_inputs_outputs = domain_map.groupInputsOutputsByInnerDim();
52-
if (grouped_inputs_outputs.size() < 2) {
53-
return false;
54-
}
55-
return domain_map.findReferenceFor(grouped_inputs_outputs[0]) != nullptr &&
56-
domain_map.findReferenceFor(grouped_inputs_outputs[1]) != nullptr;
57-
}
58-
59-
int getInnerLeafDim(TensorView* tv, IterDomain* root_dim) const {
48+
IterDomain* getMappedRootDimIn(TensorView* tv, IterDomain* root_dim) const {
6049
// Find the root id mapped to `root_dim`
6150
const auto& root_dom = tv->getRootDomain();
6251
IterDomain* mapped_id = nullptr;
@@ -67,6 +56,29 @@ class DomainMap : public pointwise_utils::DomainMap {
6756
break;
6857
}
6958
}
59+
return mapped_id;
60+
}
61+
62+
static bool hasAtLeastTwoValidGroups(Fusion* fusion) {
63+
FusionGuard fg(fusion);
64+
DomainMap domain_map(fusion);
65+
auto grouped_inputs_outputs = domain_map.groupInputsOutputsByInnerDim();
66+
if (grouped_inputs_outputs.size() < 2) {
67+
return false;
68+
}
69+
auto ref1 = domain_map.findReferenceFor(grouped_inputs_outputs[0]);
70+
auto ref2 = domain_map.findReferenceFor(grouped_inputs_outputs[1]);
71+
if (ref1 == nullptr || ref2 == nullptr) {
72+
return false;
73+
}
74+
// reference 1 is the global reference, so it must have dim mapped the
75+
// innermost dim of both groups
76+
auto innermost2 = scheduler_utils::innerMostRootDim(ref2);
77+
return domain_map.getMappedRootDimIn(ref1, innermost2) != nullptr;
78+
}
79+
80+
int getInnerLeafDim(TensorView* tv, IterDomain* root_dim) const {
81+
auto mapped_id = getMappedRootDimIn(tv, root_dim);
7082
TORCH_INTERNAL_ASSERT(
7183
mapped_id != nullptr,
7284
"Can not find ID mapped to ",

0 commit comments

Comments
 (0)