@@ -21,43 +21,79 @@ namespace {
21
21
22
22
// ! Checks that the current loop nest is not realizing a serial
23
23
// ! broadcast so that each index of producer buffer will only
24
- // ! be visited once.
25
- // ! TODO: should refactor this utility now to use loop maps in a
26
- // ! follow up.
24
+ // ! be visited once, which is the only case where aggressive
25
+ // ! inner sharing is valid.
26
+ // !
27
27
bool isSerialBroadcastResolution (TensorView* producer, TensorView* consumer) {
28
- auto producer_root =
29
- TensorDomain::noReductions (producer->getMaybeRFactorDomain ());
30
- auto consumer_root =
31
- TensorDomain::noReductions (consumer->getMaybeRFactorDomain ());
32
-
33
- if (producer_root.size () != consumer_root.size ()) {
34
- // This case would be a single broadcast or a single reduce
35
- // which wouldn't be a broadcast resolution
36
- return false ;
28
+ // ! Note: see issue #1785:
29
+ // ! serial broadcast resolution doesn't only happen to
30
+ // ! immediate producers of broadcast ops. We can also have
31
+ // ! example:
32
+ // ! T1[I,B] = broadcast(T0[I]])
33
+ // ! T3[I,I] = T1[I,B] + T2[I,I]
34
+ // ! T4[I,I] = T3[I,I]
35
+ // ! and generates the following loop:
36
+ // ! alloc T0[4]
37
+ // ! For i in 0..3
38
+ // ! T0[...] =
39
+ // !
40
+ // ! For j in 0...X:
41
+ // ! alloc T3[4]
42
+ // ! for k in 0..3:
43
+ // ! alloc T1[1]
44
+ // ! T1[0] = T0[k] // <- This is actually a broadcast resolution
45
+ // ! T3[k] = T1[0] + T2[...]
46
+ // ! T4[...] = T3[...]
47
+ // !
48
+ // ! In this case we are actually visiting each pixel of T0 in each iteration
49
+ // ! of the j loop while T1 was the broadcasted tensor causing this reuse.
50
+ // !
51
+ // ! The current version of checking covers this scenario by checking the root
52
+ // ! ids of the consumer concrete loop id's. Any time a local tensor like T0
53
+ // ! appears in a re-use scenario like above, we should see a serial loop id
54
+ // ! that was derived from some root id that doesn't concretely map to T0's
55
+ // ! domain.
56
+
57
+ // Serial concrete loop id's that cover consumer's iter domain.
58
+ std::vector<Val*> consumer_serial_loop_concrete_ids;
59
+
60
+ for (auto consumer_leaf_id : consumer->domain ()->domain ()) {
61
+ auto concrete_loop_id = GpuLower::current ()->caMap ()->getConcreteMappedID (
62
+ consumer_leaf_id, IdMappingMode::LOOP);
63
+
64
+ // Check for any serial loop id with non-trivial extent
65
+ if (!concrete_loop_id->isThread () &&
66
+ !concrete_loop_id->extent ()->isOneInt ()) {
67
+ consumer_serial_loop_concrete_ids.push_back (concrete_loop_id);
68
+ }
37
69
}
38
70
39
- std::vector<Val*> serial_ids;
40
- std::copy_if (
41
- producer->domain ()->domain ().begin (),
42
- producer->domain ()->domain ().end (),
43
- std::back_inserter (serial_ids),
44
- [](IterDomain* id) { return !id->isThread (); });
45
-
46
- auto serial_producer_roots =
47
- InputsOf::outputs (FusionGuard::getCurFusion (), serial_ids);
48
- auto serial_root_id =
49
- ir_utils::filterByType<IterDomain>(serial_producer_roots);
50
- std::unordered_set<IterDomain*> serial_producer_root_set (
51
- serial_root_id.begin (), serial_root_id.end ());
52
-
53
- for (const auto idx : c10::irange (producer_root.size ())) {
54
- if (producer_root[idx]->isBroadcast () &&
55
- !consumer_root[idx]->isBroadcast ()) {
56
- // Check if this broadcast contributed to any serial
57
- // scheduled iterdomains:
58
- if (serial_producer_root_set.count (producer_root[idx])) {
59
- return true ;
60
- }
71
+ // Collect the root id's that the serial loop iterdomain
72
+ // are transformed from.
73
+ auto serial_loop_roots = InputsOf::outputs (
74
+ FusionGuard::getCurFusion (), consumer_serial_loop_concrete_ids);
75
+
76
+ // Collect exact concrete id's in producer's root domain
77
+ std::unordered_set<IterDomain*> producer_exact_concrete_root_ids;
78
+ auto producer_root =
79
+ TensorDomain::noReductions (producer->getMaybeRFactorDomain ());
80
+ std::transform (
81
+ producer_root.begin (),
82
+ producer_root.end (),
83
+ std::inserter (
84
+ producer_exact_concrete_root_ids,
85
+ producer_exact_concrete_root_ids.begin ()),
86
+ ir_utils::caMapExactConcreteId);
87
+
88
+ // Check if serial loop roots indexes any exact root id's that
89
+ // is not within the set of producer's root exact id's. These
90
+ // id's will imply that the same producer pixel is accessed
91
+ // in multiple iterations of the materialized serial loop.
92
+ for (auto serial_loop_root :
93
+ ir_utils::filterByType<IterDomain>(serial_loop_roots)) {
94
+ if (!producer_exact_concrete_root_ids.count (
95
+ ir_utils::caMapExactConcreteId (serial_loop_root))) {
96
+ return true ;
61
97
}
62
98
}
63
99
@@ -998,7 +1034,6 @@ class AllocateReuseModifier {
998
1034
struct InPlaceSharingInfo {
999
1035
bool has_broadcast_between = false ;
1000
1036
bool has_unsupported_op = false ;
1001
- bool has_serial_broadcast_resolution_between = false ;
1002
1037
};
1003
1038
1004
1039
// ! Careful heavy check on inner sharing candidates,
@@ -1044,13 +1079,6 @@ class AllocateReuseModifier {
1044
1079
return false ;
1045
1080
}
1046
1081
1047
- // TODO: blanket disable reuse across broadcast concretization
1048
- // to unblock issue for now.
1049
- // Should improve the precision of this analysis in a follow up.
1050
- if (topo_info.has_serial_broadcast_resolution_between ) {
1051
- return false ;
1052
- }
1053
-
1054
1082
// Get information on the allocated domains of the
1055
1083
// two buffers
1056
1084
auto & local_alloc_map = GpuLower::current ()->localAllocationInfoMap ();
@@ -1103,14 +1131,6 @@ class AllocateReuseModifier {
1103
1131
info.has_unsupported_op = true ;
1104
1132
}
1105
1133
}
1106
-
1107
- for (auto in_tv :
1108
- ir_utils::filterByType<TensorView>(tv_def->inputs ())) {
1109
- if (all_used_val_set.count (in_tv) &&
1110
- isSerialBroadcastResolution (in_tv, tv)) {
1111
- info.has_serial_broadcast_resolution_between = true ;
1112
- }
1113
- }
1114
1134
}
1115
1135
}
1116
1136
return info;
0 commit comments