Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion build_variables.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,7 @@ libtorch_cuda_core_sources = [
"torch/csrc/autograd/functions/comm.cpp",
"torch/csrc/jit/codegen/cuda/arith.cpp",
"torch/csrc/jit/codegen/cuda/compute_at.cpp",
"torch/csrc/jit/codegen/cuda/inline_propagator.cpp",
"torch/csrc/jit/codegen/cuda/inlining.cpp",
"torch/csrc/jit/codegen/cuda/compute_at_map.cpp",
"torch/csrc/jit/codegen/cuda/codegen.cpp",
"torch/csrc/jit/codegen/cuda/contiguity.cpp",
Expand Down
13 changes: 4 additions & 9 deletions torch/csrc/jit/codegen/cuda/compute_at.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,20 +213,17 @@ void ComputeAt::runAt(
auto selected = getPropagationSubgraph(producer, consumer);
ComputeAtSelector selector(selected);

InlinePropagator inline_propagator(
consumer, consumer_position, mode, selector.selected());

MaxRootDomainInfoSpanningTree path(consumer, consumer_position, &selector);

if (mode == ComputeAtMode::MostInlined) {
MostInlinedTransformPropagator propagator;
path.traverse(&propagator);
inlineMost(selected);
} else {
TransformPropagator propagator(consumer, consumer_position);
path.traverse(&propagator);
inlineSelectedAt(selected, consumer, consumer_position, mode == ComputeAtMode::BestEffort);
}

path.traverse(&inline_propagator);
}

void ComputeAt::runWith(
Expand All @@ -253,19 +250,17 @@ void ComputeAt::runWith(
auto selected = getPropagationSubgraph(producer, consumer);
ComputeAtSelector selector(selected);

InlinePropagator inline_propagator(
producer, producer_position, mode, selector.selected());

MaxRootDomainInfoSpanningTree path(producer, producer_position, &selector);

if (mode == ComputeAtMode::MostInlined) {
MostInlinedTransformPropagator propagator;
path.traverse(&propagator);
inlineMost(selected);
} else {
TransformPropagator propagator(producer, producer_position);
path.traverse(&propagator);
inlineSelectedAt(selected, producer, producer_position, mode == ComputeAtMode::BestEffort);
}
path.traverse(&inline_propagator);
}

} // namespace cuda
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/compute_at.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#pragma once

#include <torch/csrc/jit/codegen/cuda/inline_propagator.h>
#include <torch/csrc/jit/codegen/cuda/inlining.h>
#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
#include <torch/csrc/jit/codegen/cuda/transform_replay.h>

Expand Down
18 changes: 6 additions & 12 deletions torch/csrc/jit/codegen/cuda/grouped_reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ bool hasMatchingTransformations(TensorView* ref, TensorView* other) {
}

// Validate grouping of reductions and return a new max producer position
unsigned int validateReductionGrouping(
void validateReductionGrouping(
const std::vector<Val*>& inputs,
const std::vector<Val*>& outputs) {
TORCH_INTERNAL_ASSERT(inputs.size() == outputs.size());
Expand All @@ -57,7 +57,6 @@ unsigned int validateReductionGrouping(
const auto num_root_dims = ref_domain.size();
const auto num_dims = ref_tv->nDims();
const auto ref_ca_pos = ref_tv->getComputeAtPosition();
auto max_producer_pos = ref_tv->getMaxProducerPosition();
for (const auto i : c10::irange(inputs.size())) {
auto output_tv = outputs.at(i)->as<TensorView>();
const auto& output_domain = output_tv->getRootDomain();
Expand Down Expand Up @@ -136,9 +135,6 @@ unsigned int validateReductionGrouping(
ref_tv->toString(),
". Mismatched tensor: ",
output_tv->toString());

max_producer_pos =
std::max(max_producer_pos, output_tv->getMaxProducerPosition());
}

// Must not have any data dependency from outputs to inputs
Expand All @@ -152,8 +148,6 @@ unsigned int validateReductionGrouping(
}
TORCH_INTERNAL_ASSERT(all_dep_vals.empty(), ss.str());
}

return max_producer_pos;
}

} // namespace
Expand Down Expand Up @@ -194,14 +188,14 @@ void groupReductions(const std::vector<TensorView*>& reduction_outputs) {
inputs.at(i) = rop->in();
}

auto max_producer_pos = validateReductionGrouping(inputs, outputs);

for (auto output : ir_utils::filterByType<TensorView>(outputs)) {
output->setMaxProducer(max_producer_pos);
}
validateReductionGrouping(inputs, outputs);

IrBuilder::create<GroupedReductionOp>(
container, op_types, init_vals, outputs, inputs);

for (auto output : ir_utils::filterByType<TensorView>(outputs)) {
output->updateMaxProducerPosition();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks cleaner than before. Thanks for the refactoring.

}
}

} // namespace cuda
Expand Down
Loading