Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 29 additions & 29 deletions torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -462,45 +462,45 @@ class TORCH_CUDA_API TensorDomain : public Val {
static bool hasBroadcast(const std::vector<IterDomain*>&);
static bool hasReduction(const std::vector<IterDomain*>&);

// return mapping of consumer_domain[i] = producer_domain[result_vector[i]]
// assuming there exists a direct consumer-producer mapping. If axis exists in
// consumer (broadcast) but not in producer, mapping will be result_vector[i]
// = -1.
static std::vector<int64_t> mapDomainCtoP(
const std::vector<IterDomain*>& consumer,
const std::vector<IterDomain*>& producer);
// return pairs of producer axes and consumer axes that represent
// mapping between corresponding axies. Not all axes have
// corresponding mapping, e.g., broadcast axis in consumer
// does not have any corresponding axis in producer.
static std::vector<std::pair<int, int>> mapDomainPandC(
const std::vector<IterDomain*>& producer,
const std::vector<IterDomain*>& consumer);

// Create a map between producer root IterDomains and consumer root
// IterDomains.
static std::vector<std::pair<IterDomain*, IterDomain*>> mapRootPandC(
const TensorDomain* producer,
const TensorDomain* consumer);

// Create a map from consumer root IterDomains -> producer root IterDomains.
// Constrain will restrict which consumer root IterDomains we map to the
// producer IterDomains. Only those root consumer IDs present in
// consumer_root_dims_to_map will be attempted to map to their corresponding
// producer IDs.
// Only those root consumer IDs present in consumer_root_dims_to_map
// will be attempted to map to their corresponding producer IDs.
static std::unordered_map<IterDomain*, IterDomain*> mapRootCtoP(
const TensorDomain* consumer,
const TensorDomain* producer,
bool constrain = false,
const std::unordered_set<IterDomain*>& consumer_root_dims_to_map =
std::unordered_set<IterDomain*>());

// return mapping of consumer_domain[i] = producer_domain[result_vector[i]]
// assuming there exists a direct consumer-producer mapping. If axis exists in
// consumer (broadcast) but not in producer, mapping will be result_vector[i]
// = -1.
static std::vector<int64_t> mapDomainPtoC(
const std::vector<IterDomain*>& producer,
const std::vector<IterDomain*>& consumer);
const std::unordered_set<IterDomain*>& consumer_root_dims_to_map);
static std::unordered_map<IterDomain*, IterDomain*> mapRootCtoP(
const TensorDomain* consumer,
const TensorDomain* producer) {
return mapRootCtoP(consumer, producer, {});
}

// Create a map from producer root IterDomains -> consumer root IterDomains.
// Constrain will restrict which producer root IterDomains we map to the
// consumer IterDomains. Only those root producer IDs present in
// producer_root_dims_to_map will be attempted to map to their corresponding
// consumer IDs.
// Only those root producer IDs present in producer_root_dims_to_map
// will be attempted to map to their corresponding consumer IDs.
static std::unordered_map<IterDomain*, IterDomain*> mapRootPtoC(
const TensorDomain* producer,
const TensorDomain* consumer,
bool constrain = false,
const std::unordered_set<IterDomain*>& producer_root_dims_to_map =
std::unordered_set<IterDomain*>());
const std::unordered_set<IterDomain*>& producer_root_dims_to_map);
static std::unordered_map<IterDomain*, IterDomain*> mapRootPtoC(
const TensorDomain* producer,
const TensorDomain* consumer) {
return mapRootPtoC(producer, consumer, {});
}

// pair is in order where second is the consumer of first
std::pair<TensorDomain*, TensorDomain*> rFactor(const std::vector<int>& axes);
Expand Down
113 changes: 31 additions & 82 deletions torch/csrc/jit/codegen/cuda/ir_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -849,14 +849,10 @@ bool TensorDomain::hasReduction(const std::vector<IterDomain*>& td) {
return false;
}

// return mapping of consumer_domain[i] = producer_domain[result_vector[i]]
// assuming there exists a direct consumer-producer mapping. If axis exists in
// consumer (broadcast) but not in producer, mapping will be result_vector[i] =
// -1.
std::vector<int64_t> TensorDomain::mapDomainCtoP(
const std::vector<IterDomain*>& consumer,
const std::vector<IterDomain*>& producer) {
std::vector<int64_t> consumer_to_producer(consumer.size(), -1);
std::vector<std::pair<int, int>> TensorDomain::mapDomainPandC(
const std::vector<IterDomain*>& producer,
const std::vector<IterDomain*>& consumer) {
std::vector<std::pair<int, int>> dom_map;

size_t itc = 0, itp = 0;
while (itc < consumer.size() && itp < producer.size()) {
Expand All @@ -869,102 +865,55 @@ std::vector<int64_t> TensorDomain::mapDomainCtoP(
continue;
}

consumer_to_producer[itc] = itp;
dom_map.emplace_back(std::make_pair(itp, itc));
itc++;
itp++;
}
return consumer_to_producer;
return dom_map;
}

// Create a map from consumer root IterDomains -> producer root IterDomains.
// Constrain will restrict which consumer root IterDomains we map to the
// producer IterDomains. Only those root consumer IDs present in
// consumer_root_dims_to_map will be attempted to map to their corresponding
// producer IDs.
std::unordered_map<IterDomain*, IterDomain*> TensorDomain::mapRootCtoP(
const TensorDomain* consumer,
std::vector<std::pair<IterDomain*, IterDomain*>> TensorDomain::mapRootPandC(
const TensorDomain* producer,
bool constrain,
const std::unordered_set<IterDomain*>& consumer_root_dims_to_map) {
const TensorDomain* consumer) {
auto consumer_root = consumer->rootDomain();
auto producer_root = producer->hasRFactor() ? producer->rfactorDomain()
: producer->rootDomain();

auto c_to_p = mapDomainCtoP(consumer_root, producer_root);

std::unordered_map<IterDomain*, IterDomain*> root_id_map;

for (int64_t itc = 0; itc < (int64_t)c_to_p.size(); itc++) {
int64_t itp = c_to_p[itc];
if (itp == -1)
continue;

if (!constrain ||
(constrain &&
consumer_root_dims_to_map.find(consumer_root[itc]) !=
consumer_root_dims_to_map.end())) {
root_id_map[consumer_root[itc]] = producer_root[itp];
}
std::vector<std::pair<IterDomain*, IterDomain*>> root_id_map;
for (const auto& m : mapDomainPandC(producer_root, consumer_root)) {
auto producer_axis = producer_root[m.first];
auto consumer_axis = consumer_root[m.second];
root_id_map.emplace_back(std::make_pair(producer_axis, consumer_axis));
}
return root_id_map;
}

// return mapping of consumer_domain[i] = producer_domain[result_vector[i]]
// assuming there exists a direct consumer-producer mapping. If axis exists in
// consumer (broadcast) but not in producer, mapping will be result_vector[i] =
// -1.
std::vector<int64_t> TensorDomain::mapDomainPtoC(
const std::vector<IterDomain*>& producer,
const std::vector<IterDomain*>& consumer) {
std::vector<int64_t> producer_to_consumer(producer.size(), -1);

size_t itc = 0, itp = 0;
while (itc < consumer.size() && itp < producer.size()) {
if (consumer[itc]->isBroadcast() && !producer[itp]->isBroadcast()) {
itc++;
continue;
}
if (producer[itp]->isReduction()) {
itp++;
continue;
std::unordered_map<IterDomain*, IterDomain*> TensorDomain::mapRootCtoP(
const TensorDomain* consumer,
const TensorDomain* producer,
const std::unordered_set<IterDomain*>& consumer_root_dims_to_map) {
std::unordered_map<IterDomain*, IterDomain*> root_id_map;
for (const auto& kv : mapRootPandC(producer, consumer)) {
auto producer_axis = kv.first;
auto consumer_axis = kv.second;
if (consumer_root_dims_to_map.find(consumer_axis) !=
consumer_root_dims_to_map.end()) {
root_id_map[consumer_axis] = producer_axis;
}

producer_to_consumer[itp] = itc;
itc++;
itp++;
}

return producer_to_consumer;
return root_id_map;
}

// Create a map from producer root IterDomains -> consumer root IterDomains.
// Constrain will restrict which producer root IterDomains we map to the
// consumer IterDomains. Only those root producer IDs present in
// producer_root_dims_to_map will be attempted to map to their corresponding
// consumer IDs.
std::unordered_map<IterDomain*, IterDomain*> TensorDomain::mapRootPtoC(
const TensorDomain* producer,
const TensorDomain* consumer,
bool constrain,
const std::unordered_set<IterDomain*>& producer_root_dims_to_map) {
auto consumer_root = consumer->rootDomain();
auto producer_root = producer->hasRFactor() ? producer->rfactorDomain()
: producer->rootDomain();

auto p_to_c = mapDomainPtoC(producer_root, consumer_root);

std::unordered_map<IterDomain*, IterDomain*> root_id_map;

for (int64_t itp = 0; itp < (int64_t)p_to_c.size(); itp++) {
int64_t itc = p_to_c[itp];
if (itc == -1)
continue;

if (!constrain ||
(constrain &&
producer_root_dims_to_map.find(producer_root[itp]) !=
producer_root_dims_to_map.end())) {
root_id_map[producer_root[itp]] = consumer_root[itc];
for (const auto& kv : mapRootPandC(producer, consumer)) {
auto producer_axis = kv.first;
auto consumer_axis = kv.second;
if (producer_root_dims_to_map.find(producer_axis) !=
producer_root_dims_to_map.end()) {
root_id_map[producer_axis] = consumer_axis;
}
}
return root_id_map;
Expand Down
4 changes: 2 additions & 2 deletions torch/csrc/jit/codegen/cuda/transform_replay.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ std::pair<TensorDomain*, unsigned int> TransformReplay::replayPasC(

// Map of consumer_CA_root_ids to related producer_CA_ids
auto replay_root_map =
TensorDomain::mapRootCtoP(consumer, producer, true, consumer_CA_root_ids);
TensorDomain::mapRootCtoP(consumer, producer, consumer_CA_root_ids);

// Track which root axes in producer we will send to replay
std::unordered_set<IterDomain*> producer_roots4replay;
Expand Down Expand Up @@ -387,7 +387,7 @@ std::pair<TensorDomain*, unsigned int> TransformReplay::replayCasP(
}

auto replay_root_map =
TensorDomain::mapRootPtoC(producer, consumer, true, producer_CA_root_ids);
TensorDomain::mapRootPtoC(producer, consumer, producer_CA_root_ids);

// Track which root axes in producer we will send to replay
std::unordered_set<IterDomain*> consumer_roots4replay;
Expand Down