Skip to content

Commit 03180aa

Browse files
authored
improve broadcast resolution (#1792)
1 parent bee6c69 commit 03180aa

File tree

1 file changed

+70
-50
lines changed

1 file changed

+70
-50
lines changed

torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp

Lines changed: 70 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -21,43 +21,79 @@ namespace {
2121

2222
//! Checks that the current loop nest is not realizing a serial
2323
//! broadcast so that each index of producer buffer will only
24-
//! be visited once.
25-
//! TODO: should refactor this utility now to use loop maps in a
26-
//! follow up.
24+
//! be visited once, which is the only case where aggressive
25+
//! inner sharing is valid.
26+
//!
2727
bool isSerialBroadcastResolution(TensorView* producer, TensorView* consumer) {
28-
auto producer_root =
29-
TensorDomain::noReductions(producer->getMaybeRFactorDomain());
30-
auto consumer_root =
31-
TensorDomain::noReductions(consumer->getMaybeRFactorDomain());
32-
33-
if (producer_root.size() != consumer_root.size()) {
34-
// This case would be a single broadcast or a single reduce
35-
// which wouldn't be a broadcast resolution
36-
return false;
28+
//! Note: see issue #1785:
29+
//! serial broadcast resolution doesn't only happen to
30+
//! immediate producers of broadcast ops. We can also have
31+
//! example:
32+
//! T1[I,B] = broadcast(T0[I])
33+
//! T3[I,I] = T1[I,B] + T2[I,I]
34+
//! T4[I,I] = T3[I,I]
35+
//! and generates the following loop:
36+
//! alloc T0[4]
37+
//! For i in 0..3
38+
//! T0[...] =
39+
//!
40+
//! For j in 0..X:
41+
//! alloc T3[4]
42+
//! for k in 0..3:
43+
//! alloc T1[1]
44+
//! T1[0] = T0[k] // <- This is actually a broadcast resolution
45+
//! T3[k] = T1[0] + T2[...]
46+
//! T4[...] = T3[...]
47+
//!
48+
//! In this case we are actually visiting each pixel of T0 in each iteration
49+
//! of the j loop while T1 was the broadcasted tensor causing this reuse.
50+
//!
51+
//! The current version of checking covers this scenario by checking the root
52+
//! ids of the consumer concrete loop id's. Any time a local tensor like T0
53+
//! appears in a re-use scenario like above, we should see a serial loop id
54+
//! that was derived from some root id that doesn't concretely map to T0's
55+
//! domain.
56+
57+
// Serial concrete loop id's that cover consumer's iter domain.
58+
std::vector<Val*> consumer_serial_loop_concrete_ids;
59+
60+
for (auto consumer_leaf_id : consumer->domain()->domain()) {
61+
auto concrete_loop_id = GpuLower::current()->caMap()->getConcreteMappedID(
62+
consumer_leaf_id, IdMappingMode::LOOP);
63+
64+
// Check for any serial loop id with non-trivial extent
65+
if (!concrete_loop_id->isThread() &&
66+
!concrete_loop_id->extent()->isOneInt()) {
67+
consumer_serial_loop_concrete_ids.push_back(concrete_loop_id);
68+
}
3769
}
3870

39-
std::vector<Val*> serial_ids;
40-
std::copy_if(
41-
producer->domain()->domain().begin(),
42-
producer->domain()->domain().end(),
43-
std::back_inserter(serial_ids),
44-
[](IterDomain* id) { return !id->isThread(); });
45-
46-
auto serial_producer_roots =
47-
InputsOf::outputs(FusionGuard::getCurFusion(), serial_ids);
48-
auto serial_root_id =
49-
ir_utils::filterByType<IterDomain>(serial_producer_roots);
50-
std::unordered_set<IterDomain*> serial_producer_root_set(
51-
serial_root_id.begin(), serial_root_id.end());
52-
53-
for (const auto idx : c10::irange(producer_root.size())) {
54-
if (producer_root[idx]->isBroadcast() &&
55-
!consumer_root[idx]->isBroadcast()) {
56-
// Check if this broadcast contributed to any serial
57-
// scheduled iterdomains:
58-
if (serial_producer_root_set.count(producer_root[idx])) {
59-
return true;
60-
}
71+
// Collect the root id's that the serial loop iterdomain
72+
// are transformed from.
73+
auto serial_loop_roots = InputsOf::outputs(
74+
FusionGuard::getCurFusion(), consumer_serial_loop_concrete_ids);
75+
76+
// Collect exact concrete id's in producer's root domain
77+
std::unordered_set<IterDomain*> producer_exact_concrete_root_ids;
78+
auto producer_root =
79+
TensorDomain::noReductions(producer->getMaybeRFactorDomain());
80+
std::transform(
81+
producer_root.begin(),
82+
producer_root.end(),
83+
std::inserter(
84+
producer_exact_concrete_root_ids,
85+
producer_exact_concrete_root_ids.begin()),
86+
ir_utils::caMapExactConcreteId);
87+
88+
// Check if the serial loop roots index any exact root id's that
89+
// are not within the set of producer's exact root id's. These
90+
// id's will imply that the same producer pixel is accessed
91+
// in multiple iterations of the materialized serial loop.
92+
for (auto serial_loop_root :
93+
ir_utils::filterByType<IterDomain>(serial_loop_roots)) {
94+
if (!producer_exact_concrete_root_ids.count(
95+
ir_utils::caMapExactConcreteId(serial_loop_root))) {
96+
return true;
6197
}
6298
}
6399

@@ -998,7 +1034,6 @@ class AllocateReuseModifier {
9981034
struct InPlaceSharingInfo {
9991035
bool has_broadcast_between = false;
10001036
bool has_unsupported_op = false;
1001-
bool has_serial_broadcast_resolution_between = false;
10021037
};
10031038

10041039
//! Careful heavy check on inner sharing candidates,
@@ -1044,13 +1079,6 @@ class AllocateReuseModifier {
10441079
return false;
10451080
}
10461081

1047-
// TODO: blanket disable reuse across broadcast concretization
1048-
// to unblock issue for now.
1049-
// Should improve the precision of this analysis in a follow up.
1050-
if (topo_info.has_serial_broadcast_resolution_between) {
1051-
return false;
1052-
}
1053-
10541082
// Get information on the allocated domains of the
10551083
// two buffers
10561084
auto& local_alloc_map = GpuLower::current()->localAllocationInfoMap();
@@ -1103,14 +1131,6 @@ class AllocateReuseModifier {
11031131
info.has_unsupported_op = true;
11041132
}
11051133
}
1106-
1107-
for (auto in_tv :
1108-
ir_utils::filterByType<TensorView>(tv_def->inputs())) {
1109-
if (all_used_val_set.count(in_tv) &&
1110-
isSerialBroadcastResolution(in_tv, tv)) {
1111-
info.has_serial_broadcast_resolution_between = true;
1112-
}
1113-
}
11141134
}
11151135
}
11161136
return info;

0 commit comments

Comments
 (0)