
Commit e842a9b

Minor cleanup in pointwise scheduler (#1858)
1 parent 9ee850c · commit e842a9b

File tree: 1 file changed (+22, -46 lines)

torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp

Lines changed: 22 additions & 46 deletions
@@ -27,6 +27,18 @@ namespace {
 // Unused at the moment, commenting for clang tidy
 constexpr int64_t kThreadX = 128;
 
+// Returns number of non-reduction/non-broadcast dims in rfactor domain
+size_t nRootDims(const TensorView* tv) {
+  auto root_dom = tv->getMaybeRFactorDomain();
+  size_t tv_n_dims = 0;
+  for (auto dim : root_dom) {
+    if (!dim->isReduction() && !dim->isBroadcast()) {
+      tv_n_dims++;
+    }
+  }
+  return tv_n_dims;
+}
+
 // DomainMap uses the ComputeAtMap to find a reference TensorView
 // that maps to all iterDomains in the fusion.
 class DomainMap {
@@ -38,15 +50,21 @@ class DomainMap {
   // The pointwise scheduler heuristics requires a minimum number of axes.
   // The output reference tensor should respect this requirement.
   TensorView* findReferenceTensorView(int minimum_num_axes = 0) const {
+    TensorView* result = nullptr;
+    int max_dims = -1;
     for (auto output_tv :
          ir_utils::filterByType<TensorView>(fusion_->outputs())) {
       if (isValidReference(output_tv) &&
           hasMinimumSize(output_tv, minimum_num_axes) &&
           !output_tv->isFusionInput()) {
-        return output_tv;
+        int n_dims = nRootDims(output_tv);
+        if (n_dims > max_dims) {
+          result = output_tv;
+          max_dims = n_dims;
+        }
       }
     }
-    return nullptr;
+    return result;
   }
 
   static bool hasReferenceTensorView(Fusion* fusion) {
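
With the nRootDims helper hoisted into the anonymous namespace, findReferenceTensorView no longer returns the first valid output it encounters; it scans every fusion output and keeps the valid one with the most non-reduction/non-broadcast rfactor dims. Below is a minimal, self-contained sketch of that selection rule. Dim, Tensor, nNonTrivialDims, and pickReference are toy stand-ins for illustration only, not nvFuser's IterDomain/TensorView or anything defined in pointwise.cpp, and the real validity checks (isValidReference, hasMinimumSize, !isFusionInput) are omitted.

// Illustrative sketch only: toy analogue of nRootDims + the new selection loop.
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

struct Dim {
  bool is_reduction = false;
  bool is_broadcast = false;
};

struct Tensor {
  std::string name;
  std::vector<Dim> dims;
};

// Analogue of nRootDims: count dims that are neither reduction nor broadcast.
size_t nNonTrivialDims(const Tensor& t) {
  size_t n = 0;
  for (const auto& d : t.dims) {
    if (!d.is_reduction && !d.is_broadcast) {
      n++;
    }
  }
  return n;
}

// Analogue of the updated findReferenceTensorView loop: keep the candidate
// with the most non-trivial dims instead of returning the first one found.
// (The real code additionally filters candidates by validity checks.)
const Tensor* pickReference(const std::vector<Tensor>& outputs) {
  const Tensor* result = nullptr;
  int max_dims = -1;
  for (const auto& out : outputs) {
    int n_dims = static_cast<int>(nNonTrivialDims(out));
    if (n_dims > max_dims) {
      result = &out;
      max_dims = n_dims;
    }
  }
  return result;
}

int main() {
  std::vector<Tensor> outputs = {
      {"t0", {{/*is_reduction=*/false, /*is_broadcast=*/true}, {false, false}}},
      {"t1", {{false, false}, {false, false}, {false, false}}}};
  // t1 has three non-trivial dims vs. one for t0, so it is picked.
  std::cout << pickReference(outputs)->name << std::endl;  // prints "t1"
  return 0;
}

Ties keep the earlier candidate, since only a strictly larger count replaces the current pick, matching the loop in the hunk above.
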
@@ -187,35 +205,11 @@ c10::optional<PointwiseParams> getPointwiseHeuristics(
   // Incase any buffer is of type DataType::Index
   DataType index_type = indexModeToDtype(runtime_info.getIndexMode());
 
-  TensorView* largest_out = nullptr;
-  int max_dims = -1;
-
   auto in_tvs = ir_utils::filterByType<TensorView>(fusion->inputs());
-  // Will want to access this with direct indexing later, convert now.
-  std::vector<TensorView*> out_tvs;
-  // Only use valid reference tensors during heuristics analysis
+
   DomainMap domain_map(fusion);
-  for (auto out_tv : ir_utils::filterByType<TensorView>(fusion->outputs())) {
-    if (domain_map.isValidReference(out_tv)) {
-      out_tvs.push_back(out_tv);
-    }
-  }
-  TORCH_INTERNAL_ASSERT(
-      !out_tvs.empty(), "No valid reference outputs were found!");
 
-  for (auto out_tv : out_tvs) {
-    int n_dims = 0;
-    for (auto id : out_tv->getMaybeRFactorDomain()) {
-      if (id->isReduction() || id->isBroadcast()) {
-        continue;
-      }
-      n_dims++;
-    }
-    if (n_dims > max_dims) {
-      largest_out = out_tv;
-      max_dims = n_dims;
-    }
-  }
+  TensorView* largest_out = domain_map.findReferenceTensorView();
 
   TORCH_INTERNAL_ASSERT(largest_out != nullptr);
 
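
After this hunk, getPointwiseHeuristics no longer builds its own out_tvs list and nested loops to find the largest output; it delegates to DomainMap::findReferenceTensorView and only checks the result. A rough, self-contained sketch of that shape follows. OutputSketch, DomainMapSketch, and their fields (n_non_trivial_dims, valid_reference) are hypothetical stand-ins, and plain assert stands in for TORCH_INTERNAL_ASSERT.

// Illustrative sketch only: the call site reduces to build-map / find / check.
#include <cassert>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

struct OutputSketch {
  std::string name;
  int n_non_trivial_dims = 0;   // what nRootDims() would report
  bool valid_reference = true;  // what DomainMap::isValidReference() would report
};

class DomainMapSketch {
 public:
  explicit DomainMapSketch(std::vector<OutputSketch> outputs)
      : outputs_(std::move(outputs)) {}

  // Encapsulates the "largest valid output" search that the heuristic
  // previously performed inline with its own out_tvs vector and loops.
  const OutputSketch* findReferenceTensorView() const {
    const OutputSketch* result = nullptr;
    int max_dims = -1;
    for (const auto& out : outputs_) {
      if (out.valid_reference && out.n_non_trivial_dims > max_dims) {
        result = &out;
        max_dims = out.n_non_trivial_dims;
      }
    }
    return result;
  }

 private:
  std::vector<OutputSketch> outputs_;
};

int main() {
  // The heuristic body now amounts to: build the map, ask for the reference,
  // and check that one was found (TORCH_INTERNAL_ASSERT in the real code).
  DomainMapSketch domain_map({{"out0", 2}, {"out1", 3}});
  const OutputSketch* largest_out = domain_map.findReferenceTensorView();
  assert(largest_out != nullptr);
  std::cout << largest_out->name << std::endl;  // prints "out1"
  return 0;
}

The practical effect of the refactor is that the heuristic and DomainMap now share one definition of the reference output instead of two slightly different ones.
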
@@ -224,15 +218,12 @@ c10::optional<PointwiseParams> getPointwiseHeuristics(
 
   // TODO: Set to 1?
   int64_t max_input_dtype_size = 2;
-  size_t n_tensors = 0;
 
   for (auto inp : in_tvs) {
     max_input_dtype_size = std::max(
         max_input_dtype_size,
         (int64_t)dataTypeSize(inp->getDataType().value(), index_type));
-    n_tensors++;
   }
-  n_tensors += std::distance(out_tvs.begin(), out_tvs.end());
 
   auto ref_root = largest_out->getMaybeRFactorDomain();
   std::vector<int64_t> elem_counts(ref_root.size(), 1);
@@ -533,7 +524,6 @@ c10::optional<PointwiseParams> getPointwiseHeuristics(
     std::cerr << "\n===== Pointwise Stats ========\n"
               << "num_elems: " << n_elems << "\n"
               << "elem_counts: " << elem_counts << "\n"
-              << "n_tensor_inputs: " << n_tensors << "\n"
               << "max_input_dtype_size: " << max_input_dtype_size << "\n"
               << "vectorize_factor: " << vectorize_factor << std::endl;
     std::cerr << "broadcast_byte_multiples: ";
@@ -563,20 +553,6 @@ LaunchParams schedulePointwise(
   return params.value().lparams;
 }
 
-namespace {
-// Returns number of non-reduction/non-broadcast dims in rfactor domain
-size_t nRootDims(const TensorView* tv) {
-  auto root_dom = tv->getMaybeRFactorDomain();
-  size_t tv_n_dims = 0;
-  for (auto dim : root_dom) {
-    if (!dim->isReduction() && !dim->isBroadcast()) {
-      tv_n_dims++;
-    }
-  }
-  return tv_n_dims;
-}
-} // namespace
-
 bool hasReferenceTensorView(Fusion* fusion) {
   return DomainMap::hasReferenceTensorView(fusion);
 }
