@@ -27,6 +27,18 @@ namespace {
 // Unused at the moment, commenting for clang tidy
 constexpr int64_t kThreadX = 128;
 
+// Returns number of non-reduction/non-broadcast dims in rfactor domain
+size_t nRootDims(const TensorView* tv) {
+  auto root_dom = tv->getMaybeRFactorDomain();
+  size_t tv_n_dims = 0;
+  for (auto dim : root_dom) {
+    if (!dim->isReduction() && !dim->isBroadcast()) {
+      tv_n_dims++;
+    }
+  }
+  return tv_n_dims;
+}
+
 // DomainMap uses the ComputeAtMap to find a reference TensorView
 // that maps to all iterDomains in the fusion.
 class DomainMap {
@@ -38,15 +50,21 @@ class DomainMap {
   // The pointwise scheduler heuristics requires a minimum number of axes.
   // The output reference tensor should respect this requirement.
   TensorView* findReferenceTensorView(int minimum_num_axes = 0) const {
+    TensorView* result = nullptr;
+    int max_dims = -1;
     for (auto output_tv :
          ir_utils::filterByType<TensorView>(fusion_->outputs())) {
       if (isValidReference(output_tv) &&
           hasMinimumSize(output_tv, minimum_num_axes) &&
           !output_tv->isFusionInput()) {
-        return output_tv;
+        int n_dims = nRootDims(output_tv);
+        if (n_dims > max_dims) {
+          result = output_tv;
+          max_dims = n_dims;
+        }
       }
     }
-    return nullptr;
+    return result;
   }
 
   static bool hasReferenceTensorView(Fusion* fusion) {
@@ -187,35 +205,11 @@ c10::optional<PointwiseParams> getPointwiseHeuristics(
   // Incase any buffer is of type DataType::Index
   DataType index_type = indexModeToDtype(runtime_info.getIndexMode());
 
-  TensorView* largest_out = nullptr;
-  int max_dims = -1;
-
   auto in_tvs = ir_utils::filterByType<TensorView>(fusion->inputs());
-  // Will want to access this with direct indexing later, convert now.
-  std::vector<TensorView*> out_tvs;
-  // Only use valid reference tensors during heuristics analysis
+
   DomainMap domain_map(fusion);
-  for (auto out_tv : ir_utils::filterByType<TensorView>(fusion->outputs())) {
-    if (domain_map.isValidReference(out_tv)) {
-      out_tvs.push_back(out_tv);
-    }
-  }
-  TORCH_INTERNAL_ASSERT(
-      !out_tvs.empty(), "No valid reference outputs were found!");
 
-  for (auto out_tv : out_tvs) {
-    int n_dims = 0;
-    for (auto id : out_tv->getMaybeRFactorDomain()) {
-      if (id->isReduction() || id->isBroadcast()) {
-        continue;
-      }
-      n_dims++;
-    }
-    if (n_dims > max_dims) {
-      largest_out = out_tv;
-      max_dims = n_dims;
-    }
-  }
+  TensorView* largest_out = domain_map.findReferenceTensorView();
 
   TORCH_INTERNAL_ASSERT(largest_out != nullptr);
 
@@ -224,15 +218,12 @@ c10::optional<PointwiseParams> getPointwiseHeuristics(
 
   // TODO: Set to 1?
   int64_t max_input_dtype_size = 2;
-  size_t n_tensors = 0;
 
   for (auto inp : in_tvs) {
     max_input_dtype_size = std::max(
         max_input_dtype_size,
         (int64_t)dataTypeSize(inp->getDataType().value(), index_type));
-    n_tensors++;
   }
-  n_tensors += std::distance(out_tvs.begin(), out_tvs.end());
 
   auto ref_root = largest_out->getMaybeRFactorDomain();
   std::vector<int64_t> elem_counts(ref_root.size(), 1);
@@ -533,7 +524,6 @@ c10::optional<PointwiseParams> getPointwiseHeuristics(
     std::cerr << "\n===== Pointwise Stats ========\n"
               << "num_elems: " << n_elems << "\n"
               << "elem_counts: " << elem_counts << "\n"
-              << "n_tensor_inputs: " << n_tensors << "\n"
               << "max_input_dtype_size: " << max_input_dtype_size << "\n"
              << "vectorize_factor: " << vectorize_factor << std::endl;
     std::cerr << "broadcast_byte_multiples: ";
@@ -563,20 +553,6 @@ LaunchParams schedulePointwise(
   return params.value().lparams;
 }
 
-namespace {
-// Returns number of non-reduction/non-broadcast dims in rfactor domain
-size_t nRootDims(const TensorView* tv) {
-  auto root_dom = tv->getMaybeRFactorDomain();
-  size_t tv_n_dims = 0;
-  for (auto dim : root_dom) {
-    if (!dim->isReduction() && !dim->isBroadcast()) {
-      tv_n_dims++;
-    }
-  }
-  return tv_n_dims;
-}
-} // namespace
-
 bool hasReferenceTensorView(Fusion* fusion) {
   return DomainMap::hasReferenceTensorView(fusion);
 }
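
The core behavioral change above is in findReferenceTensorView: instead of returning the first output that passes the validity checks, it now returns the valid output with the most non-reduction/non-broadcast rfactor dims (counted by nRootDims, which was moved up into the anonymous namespace so the heuristics path can reuse it), and getPointwiseHeuristics delegates to it rather than repeating the scan. Below is a minimal standalone sketch of that selection rule only, outside the nvFuser types; CandidateOutput, nUsefulDims, and pickReference are hypothetical stand-ins for TensorView, nRootDims, and the real member function, and the `valid` flag stands in for the isValidReference/hasMinimumSize/!isFusionInput checks.

// Hypothetical, self-contained analogue of the new reference selection.
#include <cstddef>
#include <iostream>
#include <vector>

struct CandidateOutput {
  std::vector<bool> trivial_dim; // true = reduction or broadcast dim
  bool valid = true;             // stand-in for the validity checks
};

// Mirrors nRootDims(): count dims that are neither reduction nor broadcast.
std::size_t nUsefulDims(const CandidateOutput& tv) {
  std::size_t n = 0;
  for (bool trivial : tv.trivial_dim) {
    if (!trivial) {
      n++;
    }
  }
  return n;
}

// Mirrors the updated findReferenceTensorView(): keep scanning and return
// the valid candidate with the most useful dims, or nullptr if none qualify.
const CandidateOutput* pickReference(const std::vector<CandidateOutput>& outs) {
  const CandidateOutput* result = nullptr;
  int max_dims = -1;
  for (const auto& tv : outs) {
    if (!tv.valid) {
      continue;
    }
    int n_dims = static_cast<int>(nUsefulDims(tv));
    if (n_dims > max_dims) {
      result = &tv;
      max_dims = n_dims;
    }
  }
  return result;
}

int main() {
  // Two valid outputs: a 1-D tensor and a 3-D tensor with one broadcast dim.
  std::vector<CandidateOutput> outs = {
      {{false}, true},
      {{false, true, false}, true},
  };
  const CandidateOutput* ref = pickReference(outs);
  // Prints 2: the second output wins with two useful dims.
  std::cout << (ref ? nUsefulDims(*ref) : 0) << std::endl;
  return 0;
}

As in the diff, max_dims starts at -1, so a valid candidate with zero useful dims still wins over returning nullptr; only the tie-breaking among valid outputs changes.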