40 changes: 40 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
@@ -462,6 +462,46 @@ class TORCH_CUDA_API TensorDomain : public Val {
static bool hasBroadcast(const std::vector<IterDomain*>&);
static bool hasReduction(const std::vector<IterDomain*>&);

// return mapping of consumer_domain[i] = producer_domain[result_vector[i]]
// assuming there exists a direct consumer-producer mapping. If axis exists in
// consumer (broadcast) but not in producer, mapping will be result_vector[i]
// = -1.
static std::vector<int64_t> mapDomainCtoP(
const std::vector<IterDomain*>& consumer,
const std::vector<IterDomain*>& producer);

// Create a map from consumer root IterDomains -> producer root IterDomains.
// Constrain will restrict which consumer root IterDomains we map to the
// producer IterDomains. Only those root consumer IDs present in
// consumer_root_dims_to_map will be mapped to their corresponding
// producer IDs.
static std::unordered_map<IterDomain*, IterDomain*> mapRootCtoP(
const TensorDomain* consumer,
const TensorDomain* producer,
bool constrain = false,
const std::unordered_set<IterDomain*>& consumer_root_dims_to_map =
std::unordered_set<IterDomain*>());

// return mapping of producer_domain[i] = consumer_domain[result_vector[i]]
// assuming there exists a direct producer-consumer mapping. If axis exists in
// producer (reduction) but not in consumer, mapping will be result_vector[i]
// = -1.
static std::vector<int64_t> mapDomainPtoC(
const std::vector<IterDomain*>& producer,
const std::vector<IterDomain*>& consumer);

// Create a map from producer root IterDomains -> consumer root IterDomains.
// Constrain will restrict which producer root IterDomains we map to the
// consumer IterDomains. Only those root producer IDs present in
// producer_root_dims_to_map will be mapped to their corresponding
// consumer IDs.
static std::unordered_map<IterDomain*, IterDomain*> mapRootPtoC(
const TensorDomain* producer,
const TensorDomain* consumer,
bool constrain = false,
const std::unordered_set<IterDomain*>& producer_root_dims_to_map =
std::unordered_set<IterDomain*>());

// pair is in order where second is the consumer of first
std::pair<TensorDomain*, TensorDomain*> rFactor(const std::vector<int>& axes);
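
For orientation, here is a minimal, self-contained sketch of the positional mapping convention that mapDomainCtoP/mapDomainPtoC implement. The ToyDomain struct and toyMapDomainCtoP function are hypothetical stand-ins, not the nvFuser API; the skip logic simply mirrors the implementation added in ir_nodes.cpp below.

#include <cstdint>
#include <iostream>
#include <vector>

// Toy stand-in for IterDomain: only the broadcast/reduction flags matter here.
struct ToyDomain {
  bool is_broadcast;
  bool is_reduction;
};

// Mirrors TensorDomain::mapDomainCtoP: walk both root domains left to right,
// skip consumer-only broadcast axes and producer-only reduction axes, and
// pair up everything else by position.
std::vector<int64_t> toyMapDomainCtoP(
    const std::vector<ToyDomain>& consumer,
    const std::vector<ToyDomain>& producer) {
  std::vector<int64_t> consumer_to_producer(consumer.size(), -1);
  size_t itc = 0, itp = 0;
  while (itc < consumer.size() && itp < producer.size()) {
    if (consumer[itc].is_broadcast && !producer[itp].is_broadcast) {
      itc++; // broadcast axis exists only in the consumer; it stays at -1
      continue;
    }
    if (producer[itp].is_reduction) {
      itp++; // reduction axes of the producer never appear in the consumer
      continue;
    }
    consumer_to_producer[itc] = itp;
    itc++;
    itp++;
  }
  return consumer_to_producer;
}

int main() {
  // producer root: [i0, i1]; consumer root: [i0, b1, i1] (broadcast inserted)
  std::vector<ToyDomain> producer = {{false, false}, {false, false}};
  std::vector<ToyDomain> consumer = {
      {false, false}, {true, false}, {false, false}};
  for (int64_t idx : toyMapDomainCtoP(consumer, producer)) {
    std::cout << idx << " ";
  }
  std::cout << std::endl; // prints: 0 -1 1
  return 0;
}

Compiled standalone, this prints "0 -1 1": the consumer-only broadcast axis is left at -1, which is the convention documented in the header comments above.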

121 changes: 121 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_nodes.cpp
@@ -849,6 +849,127 @@ bool TensorDomain::hasReduction(const std::vector<IterDomain*>& td) {
return false;
}

// return mapping of consumer_domain[i] = producer_domain[result_vector[i]]
// assuming there exists a direct consumer-producer mapping. If axis exists in
// consumer (broadcast) but not in producer, mapping will be result_vector[i] =
// -1.
std::vector<int64_t> TensorDomain::mapDomainCtoP(
Review comment (Collaborator): This seems very similar to mapDomainPtoC. I have some refactoring idea. Will send a PR once this is merged.

const std::vector<IterDomain*>& consumer,
const std::vector<IterDomain*>& producer) {
std::vector<int64_t> consumer_to_producer(consumer.size(), -1);

size_t itc = 0, itp = 0;
while (itc < consumer.size() && itp < producer.size()) {
if (consumer[itc]->isBroadcast() && !producer[itp]->isBroadcast()) {
itc++;
continue;
}
if (producer[itp]->isReduction()) {
itp++;
continue;
}

consumer_to_producer[itc] = itp;
itc++;
itp++;
}
return consumer_to_producer;
}

// Create a map from consumer root IterDomains -> producer root IterDomains.
// Constrain will restrict which consumer root IterDomains we map to the
// producer IterDomains. Only those root consumer IDs present in
// consumer_root_dims_to_map will be mapped to their corresponding
// producer IDs.
std::unordered_map<IterDomain*, IterDomain*> TensorDomain::mapRootCtoP(
const TensorDomain* consumer,
const TensorDomain* producer,
bool constrain,
std::unordered_set<IterDomain*> consumer_root_dims_to_map) {
Review comment (Collaborator): Make it const &?

auto consumer_root = consumer->rootDomain();
auto producer_root = producer->hasRFactor() ? producer->rfactorDomain()
: producer->rootDomain();

auto c_to_p = mapDomainCtoP(consumer_root, producer_root);

std::unordered_map<IterDomain*, IterDomain*> root_id_map;

for (int64_t itc = 0; itc < (int64_t)c_to_p.size(); itc++) {
int64_t itp = c_to_p[itc];
if (itp == -1)
continue;

if (!constrain ||
(constrain &&
consumer_root_dims_to_map.find(consumer_root[itc]) !=
consumer_root_dims_to_map.end())) {
root_id_map[consumer_root[itc]] = producer_root[itp];
}
}
return root_id_map;
}
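
The constrain path above only filters which pairings make it into the returned map. Here is a standalone sketch of that filtering, using strings as stand-ins for IterDomain pointers (toyMapRootCtoP and the axis names are hypothetical, not the nvFuser API):

#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

// Mirrors only the constrain check of TensorDomain::mapRootCtoP: keep every
// pairing when unconstrained, otherwise keep only the pairs whose
// consumer-side axis was explicitly requested.
std::unordered_map<std::string, std::string> toyMapRootCtoP(
    const std::vector<std::pair<std::string, std::string>>& paired_roots,
    bool constrain,
    const std::unordered_set<std::string>& consumer_root_dims_to_map) {
  std::unordered_map<std::string, std::string> root_id_map;
  for (const auto& cp : paired_roots) {
    if (!constrain || consumer_root_dims_to_map.count(cp.first) != 0) {
      root_id_map[cp.first] = cp.second;
    }
  }
  return root_id_map;
}

int main() {
  // Pretend the positional pass (mapDomainCtoP) already paired these
  // consumer root axes with producer root axes.
  std::vector<std::pair<std::string, std::string>> paired = {
      {"c_i0", "p_i0"}, {"c_i2", "p_i1"}};

  auto constrained = toyMapRootCtoP(paired, true, {"c_i0"});
  auto unconstrained = toyMapRootCtoP(paired, false, {});

  std::cout << "constrained: " << constrained.size() // 1: only c_i0 -> p_i0
            << ", unconstrained: " << unconstrained.size() // 2: both pairs
            << std::endl;
  return 0;
}

With constrain == true only the requested consumer axis survives, which matches how replayPasC in transform_replay.cpp below passes consumer_CA_root_ids.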

// return mapping of producer_domain[i] = consumer_domain[result_vector[i]]
// assuming there exists a direct producer-consumer mapping. If axis exists in
// producer (reduction) but not in consumer, mapping will be result_vector[i] =
// -1.
std::vector<int64_t> TensorDomain::mapDomainPtoC(
const std::vector<IterDomain*>& producer,
const std::vector<IterDomain*>& consumer) {
std::vector<int64_t> producer_to_consumer(producer.size(), -1);

size_t itc = 0, itp = 0;
while (itc < consumer.size() && itp < producer.size()) {
if (consumer[itc]->isBroadcast() && !producer[itp]->isBroadcast()) {
itc++;
continue;
}
if (producer[itp]->isReduction()) {
itp++;
continue;
}

producer_to_consumer[itp] = itc;
itc++;
itp++;
}

return producer_to_consumer;
}

// Create a map from producer root IterDomains -> consumer root IterDomains.
// Constrain will restrict which producer root IterDomains we map to the
// consumer IterDomains. Only those root producer IDs present in
// producer_root_dims_to_map will be mapped to their corresponding
// consumer IDs.
std::unordered_map<IterDomain*, IterDomain*> TensorDomain::mapRootPtoC(
const TensorDomain* producer,
const TensorDomain* consumer,
bool constrain,
std::unordered_set<IterDomain*> producer_root_dims_to_map) {
Review comment (Collaborator): Make it const &?

auto consumer_root = consumer->rootDomain();
auto producer_root = producer->hasRFactor() ? producer->rfactorDomain()
: producer->rootDomain();

auto p_to_c = mapDomainPtoC(producer_root, consumer_root);

std::unordered_map<IterDomain*, IterDomain*> root_id_map;

for (int64_t itp = 0; itp < (int64_t)p_to_c.size(); itp++) {
int64_t itc = p_to_c[itp];
if (itc == -1)
continue;

if (!constrain ||
(constrain &&
producer_root_dims_to_map.find(producer_root[itp]) !=
producer_root_dims_to_map.end())) {
root_id_map[producer_root[itp]] = consumer_root[itc];
}
}
return root_id_map;
}
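
For the reverse direction the same walk yields the inverse pairing. With the toy shapes from the standalone sketch earlier (producer root [i0, i1], consumer root [i0, b1, i1]), mapDomainPtoC would return [0, 2]: producer axis 0 pairs with consumer axis 0, producer axis 1 with consumer axis 2, and the consumer-only broadcast axis simply has no producer-side entry. Since mapDomainCtoP and mapDomainPtoC skip exactly the same axes, the two results appear to be inverses of each other on the axes that do get paired; that is an observation about the two loops above, not a documented guarantee.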

// pair is in order where second is the consumer of first
std::pair<TensorDomain*, TensorDomain*> TensorDomain::rFactor(
const std::vector<int>& axes_) {
130 changes: 38 additions & 92 deletions torch/csrc/jit/codegen/cuda/transform_replay.cpp
@@ -193,50 +193,24 @@ std::pair<TensorDomain*, unsigned int> TransformReplay::replayPasC(
consumer->domain().begin() + consumer_compute_at_axis);

// Figure out all inputs required to generate the compute_at dimensions
std::unordered_set<Val*> consumer_CA_root_ids = IterVisitor::getInputsTo(
std::unordered_set<Val*> consumer_CA_root_vals = IterVisitor::getInputsTo(
std::vector<Val*>(consumer_CA_ids.begin(), consumer_CA_ids.end()));

// Map of consumer_CA_root_ids to related producer_CA_ids
id_map replay_root_map;

// Grab root domains of producer and consumer
std::vector<IterDomain*> consumer_root = consumer->rootDomain();
std::vector<IterDomain*> producer_root = producer->rootDomain();
std::unordered_set<IterDomain*> consumer_CA_root_ids;
for (auto val : consumer_CA_root_vals) {
if (val->getValType().value() == ValType::IterDomain) {
consumer_CA_root_ids.emplace(val->as<IterDomain>());
}
}

// If producer has an rfactor root, that's what will match with consumer,
// as it means the consumer was a result of the rfactor operation.
if (producer->hasRFactor())
producer_root = producer->rfactorDomain();
// Map of consumer_CA_root_ids to related producer_CA_ids
auto replay_root_map =
TensorDomain::mapRootCtoP(consumer, producer, true, consumer_CA_root_ids);

// Track which root axes in producer we will send to replay
std::unordered_set<IterDomain*> producer_roots4replay;

// Map related axes from producer and consumer roots. Make sure we go to the
// end of both.
{
size_t itc = 0, itp = 0;
while (itc < consumer_root.size() || itp < producer_root.size()) {
if (itc < consumer_root.size() && consumer_root[itc]->isBroadcast() &&
(itp >= producer_root.size() || !producer_root[itp]->isBroadcast())) {
itc++;
continue;
}
if (itp < producer_root.size() && producer_root[itp]->isReduction()) {
itp++;
continue;
}
TORCH_INTERNAL_ASSERT(
itc < consumer_root.size() && itp < producer_root.size(),
"Error during replay, wanted to keep going, but ran out of root dimensions.");

if (consumer_CA_root_ids.find(consumer_root[itc]) !=
consumer_CA_root_ids.end()) {
replay_root_map[consumer_root[itc]] = producer_root[itp];
producer_roots4replay.emplace(producer_root[itp]);
}
itc++;
itp++;
}
for (auto entry : replay_root_map) {
producer_roots4replay.emplace(entry.second);
}

// Instead of replaying from the root, let's try to play forward the history of
@@ -282,6 +256,9 @@ std::pair<TensorDomain*, unsigned int> TransformReplay::replayPasC(
for (auto entry : leaf_ids)
producer_self_replay_map[entry.first] = entry.first;

auto producer_root = producer->hasRFactor() ? producer->rfactorDomain()
: producer->rootDomain();

// Any root domain that was not used to generate computeIDs we can also put in
// the map to forward their transformations.
for (auto producer_root_id : producer_root)
@@ -382,71 +359,40 @@ std::pair<TensorDomain*, unsigned int> TransformReplay::replayCasP(
"Invalid axis in transform replayCasP.");

// producer ids we need to match in consumer
std::vector<IterDomain*> producer_CA_ids;
{
int itp = 0;
while (itp < producer_compute_at_axis) {
if (producer->axis(itp)->isReduction()) {
itp++;
} else {
producer_CA_ids.emplace_back(producer->axis(itp++));
}
}
}

// Map of producer_CA_root_ids to related producer_CA_ids
id_map replay_root_map;
std::vector<IterDomain*> producer_CA_ids(
producer->domain().begin(),
producer->domain().begin() + producer_compute_at_axis);
producer_CA_ids = TensorDomain::noReductions(producer_CA_ids);

// Grab root domains of producer and consumer
std::vector<IterDomain*> consumer_root = consumer->rootDomain();
std::vector<IterDomain*> producer_root = producer->rootDomain();

// If producer has an rfactor root, that's the one that will match the
// consumer
if (producer->hasRFactor())
producer_root = producer->rfactorDomain();
// If producer has an rfactor root, that's what will match the consumer
std::vector<IterDomain*> producer_root = producer->hasRFactor()
? producer->rfactorDomain()
: producer->rootDomain();

// Figure out all inputs required to generate the compute_at dimensions
// Figure out all inputs required to generate the compute_at dimensions. We
// need all deps because inputs on the producer may be in rootDomain, but we
// may need them in rFactorDomain
std::unordered_set<Val*> all_CA_id_deps = DependencyCheck::getAllValsBetween(
std::unordered_set<Val*>(
producer->rootDomain().begin(), producer->rootDomain().end()),
std::vector<Val*>(producer_CA_ids.begin(), producer_CA_ids.end()));
{producer_root.begin(), producer_root.end()},
{producer_CA_ids.begin(), producer_CA_ids.end()});

// Figure out which root IDs we need:
std::unordered_set<Val*> producer_CA_root_ids;
for (Val* val : producer_root) {
if (all_CA_id_deps.find(val) != all_CA_id_deps.end())
producer_CA_root_ids.emplace(val);
std::unordered_set<IterDomain*> producer_CA_root_ids;
for (IterDomain* id : producer_root) {
if (all_CA_id_deps.find(id) != all_CA_id_deps.end())
producer_CA_root_ids.emplace(id);
}

// Track which root axes in consumer we send to replay
std::unordered_set<IterDomain*> consumer_roots4replay;
// Map related axes from producer and consumer roots. Make sure we go to the
// end of both.
{
size_t itc = 0, itp = 0;
while (itc < consumer_root.size() || itp < producer_root.size()) {
if (itc < consumer_root.size() && consumer_root[itc]->isBroadcast() &&
(itp >= producer_root.size() || !producer_root[itp]->isBroadcast())) {
itc++;
continue;
}
if (itp < producer_root.size() && producer_root[itp]->isReduction()) {
itp++;
continue;
}
TORCH_INTERNAL_ASSERT(
itc < consumer_root.size() && itp < producer_root.size(),
"Error during replay, wanted to keep going, but ran out of root dimensions.");
auto replay_root_map =
TensorDomain::mapRootPtoC(producer, consumer, true, producer_CA_root_ids);

if (producer_CA_root_ids.find(producer_root[itp]) !=
producer_CA_root_ids.end()) {
replay_root_map[producer_root[itp]] = consumer_root[itc];
consumer_roots4replay.emplace(consumer_root[itc]);
}
itc++;
itp++;
}
// Track which root axes in consumer we will send to replay
std::unordered_set<IterDomain*> consumer_roots4replay;
for (auto entry : replay_root_map) {
consumer_roots4replay.emplace(entry.second);
}

// Instead of replaying from the root, let's try to forward the history of