
Commit f10afcd

issue 1189 repro and fix (pytorch#1193)
1 parent bee312c commit f10afcd
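
In short (as read from the diff below): the commit adds a C++ regression test reproducing issue 1189 and fixes it by passing the GpuLower into executor_utils::getParallelBindingsIterDomains, so that parallelized broadcast IterDomains are no longer skipped unconditionally; only those the compute-at parallel map resolves to another domain are dropped, and unresolved parallel broadcast axes are still included when binding launch parameters.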

File tree

4 files changed: +62 -6 lines changed


test/cpp/jit/test_gpu.cpp

Lines changed: 41 additions & 0 deletions
@@ -17250,6 +17250,47 @@ TEST(NVFuserTest, FusionUnswitchPredicate_CUDA) {
   testValidate(&fusion, outputs, aten_inputs, {ref}, __LINE__, __FILE__);
 }
 
+TEST(NVFuserTest, FusionIssue1189_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeConcreteTensor({16, 16});
+  auto tv1 = makeConcreteTensor({16, 16});
+
+  auto tv0b = broadcast(tv0, {false, false, true});
+  auto tv1b = broadcast(tv1, {false, false, true});
+
+  fusion.addInput(tv0b);
+  fusion.addInput(tv1b);
+
+  auto tv2 = add(tv0b, tv1b);
+  auto tv3 = sum(tv2, {1});
+  fusion.addOutput(tv3);
+
+  auto parallelize = [](auto tv) {
+    tv->axis(0)->parallelize(ParallelType::TIDx);
+    tv->axis(1)->parallelize(ParallelType::BIDx);
+    tv->axis(2)->parallelize(ParallelType::BIDy);
+  };
+
+  parallelize(tv0b);
+  parallelize(tv1b);
+  parallelize(tv2);
+  parallelize(tv3);
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor t0 = at::randn({16, 16, 1}, options);
+  at::Tensor t1 = at::randn({16, 16, 1}, options);
+  auto outputs = fe.runFusion({t0, t1});
+
+  auto ref = (t0 + t1).sum({1});
+
+  testValidate(&fusion, outputs, {t0, t1}, {ref}, __LINE__, __FILE__);
+}
+
 TEST(NVFuserTest, FusionIssue1052_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
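
Note that each input in this repro carries an extent-1 broadcast axis that the parallelize lambda binds to BIDy and that nothing in the fusion resolves against a concrete domain; with the pre-fix filtering shown further below, that axis would have been excluded from the parallel binding IterDomains.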

torch/csrc/jit/codegen/cuda/executor.cpp

Lines changed: 5 additions & 4 deletions
@@ -339,18 +339,19 @@ LaunchParams FusionExecutor::computeLaunchParams(
 
   auto data_cache = compileTimeDataCache();
 
+  auto& lower = lowered_;
+
   auto& used_tvs = getUsedTVs();
   auto parallel_binding_ids_entry =
       executor_utils::caching::ExecutorCompileTimeEntry<
           executor_utils::caching::ParallelBindingIterDomains>(
-          data_cache, [&used_tvs]() {
+          data_cache, [&used_tvs, &lower]() {
             return std::make_unique<std::vector<IterDomain*>>(
-                executor_utils::getParallelBindingsIterDomains(used_tvs));
+                executor_utils::getParallelBindingsIterDomains(
+                    lower, used_tvs));
           });
   auto& parallel_binding_ids = parallel_binding_ids_entry.get();
 
-  auto& lower = lowered_;
-
   auto parallel_iter_extent_entry =
       executor_utils::caching::ExecutorCompileTimeEntry<
           executor_utils::caching::ParallelIterExtentMap>(
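
Here the `lowered_` alias is hoisted above the compile-time cache entry and captured by the lambda because getParallelBindingsIterDomains now takes the GpuLower (it consults the compute-at parallel map, as shown in the next file) to decide which parallel IterDomains to keep.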

torch/csrc/jit/codegen/cuda/executor_utils.cpp

Lines changed: 15 additions & 2 deletions
@@ -989,12 +989,25 @@ template class ExecutorCompileTimeEntry<OutputAliasIndices>;
 } // namespace caching
 
 std::vector<IterDomain*> getParallelBindingsIterDomains(
+    GpuLower& lower,
     const std::vector<TensorView*>& used_tvs) {
   std::vector<IterDomain*> parallel_ids;
   for (auto tv : used_tvs) {
     for (auto id : tv->domain()->domain()) {
-      if (id->isThread() && !id->isBroadcast()) {
-        parallel_ids.push_back(id);
+      if (id->isThread()) {
+        if (id->isBroadcast()) {
+          // Want to keep the broadcast dimensions if they are not resolved
+          // TODO: piping down the parallel dimension map here would
+          // be helpful
+          auto& parallel_map = lower.caParallelMap();
+          if (parallel_map.getConcreteMappedID(id) == id) {
+            parallel_ids.push_back(id);
+          }
+        } else {
+          // Non broadcast ids are directly added to the binding
+          // ids.
+          parallel_ids.push_back(id);
+        }
       }
     }
   }
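
The broadcast branch above is the substance of the fix: a parallelized broadcast IterDomain is kept as a binding id only when lower.caParallelMap() maps it to itself, i.e. when no other domain resolves it. A minimal, self-contained sketch of that rule follows; FakeIterDomain, FakeParallelMap, and collectParallelBindingIds are invented stand-ins for this example, not nvfuser types.

#include <cassert>
#include <unordered_map>
#include <vector>

// Stand-in for IterDomain: only the two flags the filter looks at.
struct FakeIterDomain {
  bool is_thread;    // parallelized with a TIDx/TIDy/BIDx/... type
  bool is_broadcast; // extent-1 broadcast axis
};

// Stand-in for the compute-at parallel map: a resolved broadcast maps to the
// concrete domain that resolves it; anything unmapped maps back to itself.
struct FakeParallelMap {
  std::unordered_map<const FakeIterDomain*, const FakeIterDomain*> concrete;
  const FakeIterDomain* getConcreteMappedID(const FakeIterDomain* id) const {
    auto it = concrete.find(id);
    return it == concrete.end() ? id : it->second;
  }
};

// Mirrors the post-fix filtering: keep every non-broadcast parallel id, and
// keep a broadcast parallel id only when the map leaves it unresolved.
std::vector<const FakeIterDomain*> collectParallelBindingIds(
    const std::vector<const FakeIterDomain*>& ids,
    const FakeParallelMap& map) {
  std::vector<const FakeIterDomain*> out;
  for (auto id : ids) {
    if (!id->is_thread) {
      continue; // serial axes never bind launch parameters
    }
    if (!id->is_broadcast || map.getConcreteMappedID(id) == id) {
      out.push_back(id);
    }
  }
  return out;
}

int main() {
  FakeIterDomain concrete_axis{true, false};
  FakeIterDomain resolved_bcast{true, true};
  FakeIterDomain unresolved_bcast{true, true};

  FakeParallelMap map;
  map.concrete[&resolved_bcast] = &concrete_axis; // resolved against a real axis

  auto ids = collectParallelBindingIds(
      {&concrete_axis, &resolved_bcast, &unresolved_bcast}, map);

  // The concrete axis and the unresolved broadcast survive the filter; the
  // pre-fix code would also have dropped the unresolved one (issue 1189).
  assert(ids.size() == 2);
  return 0;
}

Compiled as a plain C++ program this just asserts that the unresolved broadcast survives the filter while the resolved one is dropped.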

torch/csrc/jit/codegen/cuda/executor_utils.h

Lines changed: 1 addition & 0 deletions
@@ -282,6 +282,7 @@ class ExecutorCompileTimeEntry {
 //! Returns the vector of tensorviews that will be used to bind parallel
 //! dimensions.
 std::vector<IterDomain*> getParallelBindingsIterDomains(
+    GpuLower& lower,
     const std::vector<TensorView*>& used_tvs);
 
 using ParallelExtentMap =
