@@ -45,18 +45,7 @@ class DomainMap : public pointwise_utils::DomainMap {
45
45
return result;
46
46
}
47
47
48
- static bool hasAtLeastTwoValidGroups (Fusion* fusion) {
49
- FusionGuard fg (fusion);
50
- DomainMap domain_map (fusion);
51
- auto grouped_inputs_outputs = domain_map.groupInputsOutputsByInnerDim ();
52
- if (grouped_inputs_outputs.size () < 2 ) {
53
- return false ;
54
- }
55
- return domain_map.findReferenceFor (grouped_inputs_outputs[0 ]) != nullptr &&
56
- domain_map.findReferenceFor (grouped_inputs_outputs[1 ]) != nullptr ;
57
- }
58
-
59
- int getInnerLeafDim (TensorView* tv, IterDomain* root_dim) const {
48
+ IterDomain* getMappedRootDimIn (TensorView* tv, IterDomain* root_dim) const {
60
49
// Find the root id mapped to `root_dim`
61
50
const auto & root_dom = tv->getRootDomain ();
62
51
IterDomain* mapped_id = nullptr ;
@@ -67,6 +56,29 @@ class DomainMap : public pointwise_utils::DomainMap {
67
56
break ;
68
57
}
69
58
}
59
+ return mapped_id;
60
+ }
61
+
62
+ static bool hasAtLeastTwoValidGroups (Fusion* fusion) {
63
+ FusionGuard fg (fusion);
64
+ DomainMap domain_map (fusion);
65
+ auto grouped_inputs_outputs = domain_map.groupInputsOutputsByInnerDim ();
66
+ if (grouped_inputs_outputs.size () < 2 ) {
67
+ return false ;
68
+ }
69
+ auto ref1 = domain_map.findReferenceFor (grouped_inputs_outputs[0 ]);
70
+ auto ref2 = domain_map.findReferenceFor (grouped_inputs_outputs[1 ]);
71
+ if (ref1 == nullptr || ref2 == nullptr ) {
72
+ return false ;
73
+ }
74
+ // reference 1 is the global reference, so it must have dim mapped the
75
+ // innermost dim of both groups
76
+ auto innermost2 = scheduler_utils::innerMostRootDim (ref2);
77
+ return domain_map.getMappedRootDimIn (ref1, innermost2) != nullptr ;
78
+ }
79
+
80
+ int getInnerLeafDim (TensorView* tv, IterDomain* root_dim) const {
81
+ auto mapped_id = getMappedRootDimIn (tv, root_dim);
70
82
TORCH_INTERNAL_ASSERT (
71
83
mapped_id != nullptr ,
72
84
" Can not find ID mapped to " ,
0 commit comments