Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 17 additions & 27 deletions third_party/nvfuser/csrc/compute_at_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -330,25 +330,16 @@ void IterDomainGraph::build(Fusion* fusion) {
const auto& domain = tv->domain()->domain();
auto all_ids = ir_utils::allIDsOf(tv);

// Check if this domain is a consumer of a view-like operation
bool view_like_domain = tv->domain()->hasViewLikeRFactor();

for (auto id : all_ids) {
// Check if this id is a view like rfactor id
bool is_view_rfactor_id = false;
if (view_like_domain && id->isRFactorProduct()) {
// If the tensor domain is a view like domain, and the iteration domain
// is marked as an rfactor product and is in the rfactor domain, it's a
// view like rfactor iteration domain
const auto& rfactor_domain = tv->domain()->getMaybeRFactorDomain();
if (std::find(rfactor_domain.begin(), rfactor_domain.end(), id) !=
rfactor_domain.end()) {
is_view_rfactor_id = true;
}
}
// Check if this id is an rfactor id in the rfactor domain
bool is_rfactor_domain_id = id->isRFactorProduct() &&
std::find(
tv->getMaybeRFactorDomain().begin(),
tv->getMaybeRFactorDomain().end(),
id) != tv->getMaybeRFactorDomain().end();
bool is_leaf_id =
std::find(domain.begin(), domain.end(), id) != domain.end();
initializeId(id, is_view_rfactor_id, is_leaf_id);
initializeId(id, is_rfactor_domain_id, is_leaf_id);
}
}

Expand Down Expand Up @@ -687,7 +678,7 @@ void IterDomainGraph::build(Fusion* fusion) {

void IterDomainGraph::initializeId(
IterDomain* id,
bool is_view_rfactor_id,
bool is_rfactor_id,
bool is_leaf_id) {
permissive_nodes_.initializeSet(id);
exact_nodes_.initializeSet(id);
Expand All @@ -700,8 +691,8 @@ void IterDomainGraph::initializeId(

all_ids_.pushBack(id);

if (is_view_rfactor_id) {
view_rfactor_ids_.emplace(id);
if (is_rfactor_id) {
rfactor_ids_.emplace(id);
}
}

Expand Down Expand Up @@ -994,7 +985,7 @@ IterDomain* ComputeAtMap::computeConcreteId(
if (std::none_of(
exact_set->vector().begin(),
exact_set->vector().end(),
[&](IterDomain* id) { return isViewRfactor(id); })) {
[&](IterDomain* id) { return isRfactor(id); })) {
continue;
}
VectorOfUniqueEntries<std::shared_ptr<VectorOfUniqueEntries<IterDomain*>>>
Expand Down Expand Up @@ -1372,19 +1363,18 @@ std::string ComputeAtMap::toString() const {
return ss.str();
}

bool ComputeAtMap::isViewRfactor(IterDomain* ref_id) const {
return id_graph_.viewRfactorIds().find(ref_id) !=
id_graph_.viewRfactorIds().end();
// Returns true if ref_id was registered as an rfactor ID in the underlying
// IterDomainGraph when it was built.
bool ComputeAtMap::isRfactor(IterDomain* ref_id) const {
  const auto& rfactor_set = id_graph_.rfactorIds();
  return rfactor_set.count(ref_id) > 0;
}

std::vector<IterDomain*> ComputeAtMap::getViewRfactorDomainsOfIdGroup(
std::vector<IterDomain*> ComputeAtMap::getRfactorDomainsOfIdGroup(
IterDomain* ref_id,
IdMappingMode mode) const {
auto disjoint_set = disjointSetOf(ref_id, mode);
std::vector<IterDomain*> rfactor_ids;
for (auto disjoint_id : disjoint_set->vector()) {
if (id_graph_.viewRfactorIds().find(disjoint_id) !=
id_graph_.viewRfactorIds().end()) {
if (id_graph_.rfactorIds().find(disjoint_id) !=
id_graph_.rfactorIds().end()) {
rfactor_ids.push_back(disjoint_id);
}
}
Expand Down Expand Up @@ -1453,7 +1443,7 @@ ComputeAtMap::getInputDisjointSetsOf(IterDomain* of_id, bool stop_at_rfactor) {
std::any_of(
currently_visiting->vector().begin(),
currently_visiting->vector().end(),
[&](IterDomain* id) { return isViewRfactor(id); })) {
[&](IterDomain* id) { return isRfactor(id); })) {
input_disjoint_sets.pushBack(currently_visiting);
continue;
}
Expand Down
20 changes: 11 additions & 9 deletions third_party/nvfuser/csrc/compute_at_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ class TORCH_CUDA_CU_API IterDomainGraph {
return all_ids_;
}

const std::unordered_set<IterDomain*>& viewRfactorIds() const {
return view_rfactor_ids_;
const std::unordered_set<IterDomain*>& rfactorIds() const {
return rfactor_ids_;
}

// Returns if first and second are expressions through which the provided
Expand All @@ -115,7 +115,7 @@ class TORCH_CUDA_CU_API IterDomainGraph {
private:
void build(Fusion* fusion);

void initializeId(IterDomain* id, bool is_view_rfactor_id, bool is_leaf_id);
void initializeId(IterDomain* id, bool is_rfactor_id, bool is_leaf_id);

// Checks if exprsMap then if forward will map outputs else inputs in exact
// and permissive map.
Expand All @@ -136,7 +136,9 @@ class TORCH_CUDA_CU_API IterDomainGraph {

VectorOfUniqueEntries<IterDomain*> all_ids_;

std::unordered_set<IterDomain*> view_rfactor_ids_;
// Tracks all rfactor IDs. This set used to contain only non-reduction
// (view-like) rfactor IDs; it was changed in PR #2562 to also include
// reduction rfactor IDs.
std::unordered_set<IterDomain*> rfactor_ids_;

c10::optional<std::tuple<TensorView*, IterDomain*, IterDomain*, std::string>>
self_mapping_info_ = c10::nullopt;
Expand Down Expand Up @@ -214,13 +216,13 @@ class TORCH_CUDA_CU_API ComputeAtMap {
// Prints mapping information, forwards to an internal IterDomainGraph
std::string toString() const;

// Returns if the provided ID is a view like rfactor id
bool isViewRfactor(IterDomain* ref_id) const;
// Returns if the provided ID is an rfactor id
bool isRfactor(IterDomain* ref_id) const;

// Returns all rfactor domains in rfactor_concrete_count_reset_domains_ that
// are in the disjoint set of the provided IterDomain. This will be every view
// like rfactor ID the provided ID "depends" on in the map.
std::vector<IterDomain*> getViewRfactorDomainsOfIdGroup(
// are in the disjoint set of the provided IterDomain. This will be every
// rfactor ID the provided ID "depends" on in the map.
std::vector<IterDomain*> getRfactorDomainsOfIdGroup(
IterDomain* ref_id,
IdMappingMode mode) const;

Expand Down
2 changes: 1 addition & 1 deletion third_party/nvfuser/csrc/lower_index_compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1381,7 +1381,7 @@ IterDomain* getRfactorIDToTraverse(
IterDomain* id,
const std::vector<Val*>& consumer_all_ids) {
const auto& rfactor_ids =
GpuLower::current()->caMap()->getViewRfactorDomainsOfIdGroup(
GpuLower::current()->caMap()->getRfactorDomainsOfIdGroup(
id, IdMappingMode::PERMISSIVE);

if (rfactor_ids.empty()) {
Expand Down
4 changes: 2 additions & 2 deletions third_party/nvfuser/test/test_gpu3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6089,14 +6089,14 @@ TEST_F(NVFuserTest, FusionRepro2094_CUDA) {
auto tv0 = TensorViewBuilder()
.ndims(1)
.shape(neg_one_vec)
.contiguity({true})
.contiguity(true)
.dtype(DataType::Float)
.build();
fusion->addInput(tv0);
auto tv1 = TensorViewBuilder()
.ndims(1)
.shape(neg_one_vec)
.contiguity({true})
.contiguity(true)
.dtype(DataType::Float)
.build();
fusion->addInput(tv1);
Expand Down
41 changes: 41 additions & 0 deletions third_party/nvfuser/test/test_gpu_indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include <gtest/gtest.h>

#include <executor.h>
#include <inlining.h>
#include <ir_all_nodes.h>
#include <ir_builder.h>
#include <ops/arith.h>
Expand Down Expand Up @@ -783,4 +784,44 @@ TEST_F(NVFuserTest, FusionIndexing17_CUDA) {
&fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__);
}

// Repro of issue #2560.
// Exercises indexing through a reduction rfactor: a broadcast producer is
// added to a 2D tensor and the result is fully reduced via a two-stage
// (rFactor'ed) reduction. Relies on the rfactor-ID tracking that was
// extended to include reduction rfactor IDs (see PR #2562).
TEST_F(NVFuserTest, FusionIndexing18_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);

// Inputs: a 1D tensor and a 2D tensor.
auto tv0 = makeSymbolicTensor(1);
fusion.addInput(tv0);
auto tv1 = makeSymbolicTensor(2);
fusion.addInput(tv1);

// Broadcast tv0 to 2D, add, then reduce both axes to a scalar.
auto tv2 = broadcast(tv0, {false, true});
auto tv3 = add(tv2, tv1);
auto tv4 = sum(tv3, {0, 1});
fusion.addOutput(tv4);

// Schedule the output: merge axes, split, and rFactor axis 1 so the
// reduction is performed in two stages (tv5 holds the partial reduction).
tv4->merge(0);
tv4->split(0, 4);
auto tv5 = tv4->rFactor({1});

// Propagate tv5's scheduling transforms to the rest of the fusion.
MaxRootDomainInfoSpanningTree tree(tv5);
TransformPropagator tp(tv5);
tree.traverse(&tp);

inlineAllAt(tv4, 1, true);

// Concrete inputs; seeded for determinism.
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::manual_seed(1);
at::Tensor t0 = at::randn({5}, options);
at::Tensor t1 = at::randn({5, 3}, options);
std::vector<c10::IValue> inputs = {t0, t1};

FusionExecutor fe;
fe.compileFusion(&fusion, inputs);
auto cg_outputs = fe.runFusion(inputs);

// ATen reference: broadcast-add then reduce everything to a scalar.
auto ref = (t0.unsqueeze(-1) + t1).sum();

testValidate(fe.kernel(), cg_outputs, inputs, {ref}, __LINE__, __FILE__);
}

} // namespace nvfuser