Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions torch/csrc/jit/codegen/cuda/compute_at.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ ComputeAtData::ComputeAtData(TensorView* tv)
: tv_ref_(tv),
original_has_compute_at_(tv->hasComputeAt()),
original_compute_at_position(tv->getThisComputeAtAxis()),
original_domain_(tv->domain()) {}
original_domain_(tv->domain()),
new_compute_at_domain_(tv->domain()) {}

// Clear pass based data
void ComputeAtData::clearPass() {
Expand Down Expand Up @@ -187,16 +188,17 @@ unsigned int ComputeAt::backwardComputeAt_impl(
TensorView* producer,
TensorView* consumer,
unsigned int consumer_compute_at_axis) {
auto& entry = tv_data.at(producer);
auto& producer_entry = tv_data.at(producer);

// Use TensorDomain interface so it doesn't set computeAt automatically
auto replay = TransformReplay::replayPasC(
producer, consumer, (int)consumer_compute_at_axis);

entry.setPassPosition(replay.second);
producer_entry.setPassPosition(replay.second);

if (entry.shouldSetComputeAt(replay.second)) {
if (producer_entry.shouldSetComputeAt(replay.second)) {
producer->setComputeAt(consumer, (int)consumer_compute_at_axis);
producer_entry.setComputeAtDomain(producer->domain());
}

return replay.second;
Expand All @@ -213,11 +215,17 @@ unsigned int ComputeAt::forwardComputeAt_impl(
auto replay = TransformReplay::replayCasP(
consumer, producer, (int)producer_compute_at_axis);

consumer_entry.setPassPosition(replay.second);
if (producer_entry.shouldSetComputeAt(producer_compute_at_axis)) {
producer->setComputeAt(consumer, replay.second);
}

consumer_entry.setPassPosition(replay.second);
if ((consumer_entry.shouldSetComputeAt(replay.second) &&
consumer != consumer_) ||
(consumer == consumer_ && replay.second >= consumer_position_)) {
consumer_entry.setComputeAtDomain(consumer->domain());
}

return replay.second;
}

Expand Down Expand Up @@ -359,9 +367,19 @@ void ComputeAt::runPass() {

setupOutputs();

for (const auto entry : tv_data) {
for (const auto& entry : tv_data) {
entry.first->setDomain(entry.second.getComputeAtDomain());
entry.second.validateNewComputeAt();
}

TORCH_INTERNAL_ASSERT(
BestEffortReplay::findFirstMismatchedID(
consumer_->domain(), tv_data.at(consumer_).getOriginalDomain()) ==
consumer_->domain()->nDims(),
"ComputeAt logic changed the consumer domain which should not happen. Domain was ",
tv_data.at(consumer_).getOriginalDomain(),
" but is now: ",
consumer_->domain());
}

void ComputeAt::setupOutputs() {
Expand Down
30 changes: 23 additions & 7 deletions torch/csrc/jit/codegen/cuda/compute_at.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class ComputeAtData {
// an invalid compute_at that would require tensor replication.
void setPassPosition(unsigned int pos);

// Returns if new postion is greater or equal to previous seen
// Returns whether the new position is greater than or equal to the previously seen position
bool shouldSetComputeAt(unsigned int pos) const {
return pos > original_compute_at_position &&
pos > new_compute_at_position && pos >= current_traversal_position;
Expand All @@ -48,14 +48,23 @@ class ComputeAtData {
return touched_;
}

// Traversal domain, public as it can freely be set without impacting any
// other data. Just a convenience to have it included here.
TensorDomain* traversal_domain = nullptr;
TensorDomain* getOriginalDomain() const {
return original_domain_;
}

private:
// Position to update after a traversal
unsigned int new_compute_at_position = 0;
// If we set computeAt, save the domain so we can reset it after traversal.
// Traversal state can deviate from the domain we will want to save after the
// entire computeAt pass.
void setComputeAtDomain(TensorDomain* td) {
new_compute_at_domain_ = td;
}

// Return domain set in setComputeAtDomain
TensorDomain* getComputeAtDomain() const {
return new_compute_at_domain_;
}

private:
// Was the position ever modified?
bool touched_ = false;

Expand All @@ -76,6 +85,13 @@ class ComputeAtData {

// Did this traversal set a position or not yet
bool current_traversal_position_set = false;

// Position to update after a traversal
unsigned int new_compute_at_position = 0;

// Domain when we actually set computeAt, will set back to this after the
// pass.
TensorDomain* new_compute_at_domain_;
};

class ComputeAt {
Expand Down
3 changes: 3 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,9 @@ IterDomain::IterDomain(const IterDomain* src, IrCloner* ir_cloner)
is_broadcast_domain_(src->is_broadcast_domain_) {}

bool IterDomain::sameAs(const IterDomain* const other) const {
if (other == this)
return true;

bool is_same = isReduction() == other->isReduction() &&
parallel_method() == other->parallel_method();
is_same = is_same && ScalarCheck::sameAs(extent(), other->extent());
Expand Down
17 changes: 12 additions & 5 deletions torch/csrc/jit/codegen/cuda/transform_iter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -369,16 +369,17 @@ int BestEffortReplay::findFirstMismatchedID(
const TensorDomain* td2) {
std::unordered_map<IterDomain*, IterDomain*> id_map;
auto rd1 = td1->rootDomain();
auto rd2 = td2->rootDomain();
std::unordered_set<IterDomain*> rd2_set(
td2->rootDomain().begin(), td2->rootDomain().end());

// Find matching root IterDomains, we could make this O(nlog(n)) if we could
// sort IterDomains.
for (auto rd1i : rd1) {
for (IterDomain* rd2_id : rd2_set) {
if (rd1i->sameAs(rd2_id)) {
id_map[rd1i] = rd2_id;
rd2_set.erase(rd2_id);
for (auto rd2i : rd2) {
if (rd1i->sameAs(rd2i) && rd2_set.find(rd2i) != rd2_set.end()) {
id_map[rd1i] = rd2i;
rd2_set.erase(rd2i);
break;
}
}
Expand All @@ -387,8 +388,14 @@ int BestEffortReplay::findFirstMismatchedID(
BestEffortReplay ber(td2->domain(), td1->domain(), id_map);

for (size_t i = 0; i < td1->domain().size(); i++) {
if (ber.getReplay().find(td1->axis(i)) == ber.getReplay().end())
if (ber.getReplay().find(td1->axis(i)) == ber.getReplay().end()) {
return i;
}
// Order is important.
auto td2_axis = ber.getReplay().at(td1->axis(i));
if (td2->axis(i) != td2_axis) {
return i;
}
}
return td1->nDims();
}
Expand Down