Commit 40e2703

Reduction rand like patch (#2031)
*_like operations were not filtering out reduction domains on their inputs. As a result, the output shape differed from the logical (non-reduction) shape of the input. We ran into this issue on a Hugging Face benchmark through the Python stack. 1. Updated the operations to filter the input domain with TensorDomain::noReductions; 2. Added a test case to verify the breakage and the fix.
1 parent bc77266 commit 40e2703
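
For context, a minimal sketch of the failing pattern, mirroring the test added below (tensor names are illustrative):

// tv1 carries a reduction IterDomain from sum(); before this patch,
// rand_like(tv1) built its output shape from the full rFactor domain,
// so the random tensor kept the reduced axis instead of matching tv1's
// logical (non-reduction) shape.
TensorView* tv0 = makeSymbolicTensor(2, aten_to_data_type(kFloat));
fusion->addInput(tv0);
auto tv1 = sum(tv0, {0});   // logical shape drops the reduced axis
auto tv2 = rand_like(tv1);  // must match tv1's logical shape
auto tv3 = add(tv1, tv2);
fusion->addOutput(tv3);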

File tree: 2 files changed (+40, -8 lines)


torch/csrc/jit/codegen/cuda/arith.cpp

Lines changed: 10 additions & 8 deletions
@@ -471,17 +471,18 @@ TensorView* uniform(
   return out;
 }
 
-TensorView* rand_like(TensorView* v) {
+TensorView* rand_like(TensorView* tv) {
   TORCH_CHECK(
-      isFloatingPointType(v->dtype()),
+      isFloatingPointType(tv->dtype()),
       "input must have floating point type, but got ",
-      v->dtype());
+      tv->dtype());
   std::vector<Val*> shape;
-  shape.reserve(v->getMaybeRFactorDomain().size());
-  for (auto id : v->getMaybeRFactorDomain()) {
+  auto dom = TensorDomain::noReductions(tv->getMaybeRFactorDomain());
+  shape.reserve(dom.size());
+  for (auto id : dom) {
     shape.emplace_back(id->getMaybeExpandedExtent());
   }
-  return rand(shape, v->dtype());
+  return rand(shape, tv->dtype());
 }
 
 Val* rand_like(Val* v) {
@@ -505,8 +506,9 @@ TensorView* full(
 
 TensorView* full_like(TensorView* tv, Val* fill_value) {
   std::vector<Val*> shape;
-  shape.reserve(tv->getMaybeRFactorDomain().size());
-  for (auto id : tv->getMaybeRFactorDomain()) {
+  auto dom = TensorDomain::noReductions(tv->getMaybeRFactorDomain());
+  shape.reserve(dom.size());
+  for (auto id : dom) {
     shape.emplace_back(id->getMaybeExpandedExtent());
   }
   return full(shape, fill_value, tv->dtype());
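
The same pattern now appears in both rand_like and full_like: build the output shape only from the non-reduction iteration domains of the input. A rough, framework-agnostic sketch of the idea follows; the IterDomain struct and isReduction() predicate here are illustrative stand-ins, not the exact nvFuser API:

#include <vector>

// Minimal stand-in for an iteration domain: an extent plus a flag that
// marks domains produced by a reduction.
struct IterDomain {
  long extent;
  bool reduction;
  bool isReduction() const { return reduction; }
};

// Collect extents of the non-reduction domains only, i.e. the tensor's
// logical shape as seen by *_like operations after this patch.
std::vector<long> logicalShape(const std::vector<IterDomain>& rfactor) {
  std::vector<long> shape;
  shape.reserve(rfactor.size());
  for (const auto& id : rfactor) {
    if (!id.isReduction()) {
      shape.push_back(id.extent);
    }
  }
  return shape;
}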

torch/csrc/jit/codegen/cuda/test/test_gpu_rng.cu

Lines changed: 30 additions & 0 deletions
@@ -365,5 +365,35 @@ TEST_F(NVFuserTest, FusionUniform_CUDA) {
   }
 }
 
+TEST_F(NVFuserTest, FusionRandLikeReduction_CUDA) {
+  auto dtype = kFloat;
+  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
+  auto fusion = fusion_ptr.get();
+  FusionGuard fg(fusion);
+
+  TensorView* tv0 = makeSymbolicTensor(2, aten_to_data_type(dtype));
+  fusion->addInput(tv0);
+  auto tv1 = sum(tv0, {0});
+  auto tv2 = rand_like(tv1);
+  auto tv3 = add(tv1, tv2);
+  fusion->addOutput(tv3);
+
+  FusionExecutorCache fec(std::move(fusion_ptr));
+
+  auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA, 0);
+  at::Tensor t0 = at::zeros({2, 3}, options);
+
+  at::manual_seed(0);
+  auto cg_outputs = fec.runFusionWithInputs({t0});
+  auto out = cg_outputs[0];
+
+  at::manual_seed(0);
+  auto t1 = t0.sum(0);
+  auto t2 = generate_uniform(3, dtype).expand_as(t1);
+  auto t3 = t1.add(t2);
+
+  testValidate(fec.fusion(), {out}, {t0}, {t3}, __LINE__, __FILE__);
+}
+
 } // namespace jit
 } // namespace torch
