Skip to content

Commit 89330aa

Browse files
authored
Tensor factories must set the output shape as their input (#1939)
1 parent b2fd01e commit 89330aa

File tree

5 files changed

+55
-52
lines changed

5 files changed

+55
-52
lines changed

torch/csrc/jit/codegen/cuda/arith.cpp

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -362,19 +362,6 @@ Val* getMaximumValue(DataType v) {
362362

363363
} // namespace
364364

365-
// TENSOR FACTORIES
366-
TensorView* rand(const std::vector<Val*>& shape, DataType dtype) {
367-
auto n = shape.size();
368-
auto out = TensorViewBuilder()
369-
.ndims(n)
370-
.dtype(dtype)
371-
.contiguity(std::vector<bool>(n, true))
372-
.shape(shape)
373-
.build();
374-
IrBuilder::create<RNGOp>(RNGOpType::Uniform, out);
375-
return out;
376-
}
377-
378365
Val* castOp(DataType dtype, Val* v1) {
379366
if (v1->getDataType().value() == dtype) {
380367
return set(v1);
@@ -454,19 +441,27 @@ TensorView* unaryOp(
454441
}
455442

456443
// TENSOR FACTORIES
457-
TORCH_CUDA_CU_API TensorView* arange(Val* end, DataType dtype) {
444+
TensorView* rand(const std::vector<Val*>& shape, DataType dtype) {
445+
auto n = shape.size();
446+
auto out = TensorViewBuilder()
447+
.ndims(n)
448+
.dtype(dtype)
449+
.contiguity(std::vector<bool>(n, true))
450+
.shape(shape)
451+
.build();
452+
IrBuilder::create<RNGOp>(RNGOpType::Uniform, out);
453+
return out;
454+
}
455+
456+
TensorView* arange(Val* end, DataType dtype) {
458457
return arange(FusionGuard::getCurFusion()->zeroVal(), end, dtype);
459458
}
460459

461-
TORCH_CUDA_CU_API TensorView* arange(Val* start, Val* end, DataType dtype) {
460+
TensorView* arange(Val* start, Val* end, DataType dtype) {
462461
return arange(start, end, FusionGuard::getCurFusion()->oneVal(), dtype);
463462
}
464463

465-
TORCH_CUDA_CU_API TensorView* arange(
466-
Val* start,
467-
Val* end,
468-
Val* step,
469-
DataType dtype) {
464+
TensorView* arange(Val* start, Val* end, Val* step, DataType dtype) {
470465
if (isIntegralType(dtype)) {
471466
start = castOp(DataType::Int, start);
472467
end = castOp(DataType::Int, end);

torch/csrc/jit/codegen/cuda/fusion.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,8 @@ void Fusion::printMath(bool from_outputs_only) {
376376
std::vector<Val*> Fusion::inputsAndCreated() {
377377
auto result = inputs_;
378378
for (auto expr : exprs()) {
379-
if (expr->inputs().empty()) {
379+
auto tv_inputs = ir_utils::filterByType<TensorView>(expr->inputs());
380+
if (tv_inputs.empty()) {
380381
for (auto v : expr->outputs()) {
381382
result.emplace_back(v);
382383
}

torch/csrc/jit/codegen/cuda/ir_iostream.cpp

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -419,27 +419,28 @@ void IrPrinter::handle(const TernaryOp* top) {
419419
}
420420

421421
void IrPrinter::handle(const RNGOp* rop) {
422-
bool istvop = ir_utils::isTvOp(rop);
423422
if (!print_inline_) {
424423
indent();
425-
os_ << rop->output(0);
426-
427-
// tensor operations tend to be long, break them up into multiple lines
428-
if (istvop) {
429-
os_ << "\n";
430-
indent_size_++;
431-
indent();
432-
}
433-
424+
os_ << rop->output(0) << "\n";
425+
indent_size_++;
426+
indent();
434427
os_ << " = ";
435428
} else {
436429
checkInlineable(rop);
437430
}
438431

439-
os_ << rop->getRNGOpType() << "()";
432+
os_ << rop->getRNGOpType() << "(";
433+
bool first = true;
434+
for (auto i : rop->inputs()) {
435+
if (!first) {
436+
os_ << ", ";
437+
}
438+
handle(i);
439+
first = false;
440+
}
441+
os_ << ")";
440442

441-
if (istvop)
442-
indent_size_--;
443+
indent_size_--;
443444

444445
if (!print_inline_)
445446
os_ << ";\n";

torch/csrc/jit/codegen/cuda/ir_nodes.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,11 @@ RNGOp::RNGOp(
353353
rng_op_type_(type),
354354
rng_offset_(rng_offset),
355355
philox_index_(philox_index) {
356+
if (out->isA<TensorView>()) {
357+
for (auto id : out->as<TensorView>()->getRootDomain()) {
358+
addInput(id->extent());
359+
}
360+
}
356361
addOutput(out);
357362
}
358363

torch/csrc/jit/codegen/cuda/test/test_gpu_rng.cu

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -106,32 +106,33 @@ at::Tensor generate_uniform(int64_t size, at::ScalarType dtype) {
106106
} // namespace
107107
108108
TEST_F(NVFuserTest, FusionRNGValidateWithCURand_CUDA) {
109-
for (int64_t size : {16, 1024, 10001, 10002, 10003, 100000, 10000001}) {
110-
for (auto dtype : {kFloat, kDouble}) {
111-
std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
112-
auto fusion = fusion_ptr.get();
113-
FusionGuard fg(fusion);
109+
std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
110+
auto fusion = fusion_ptr.get();
111+
FusionGuard fg(fusion);
114112
115-
Int* size_val = IrBuilder::create<Int>();
116-
fusion->addInput(size_val);
117-
TensorView* tv0 = rand({size_val}, aten_to_data_type(dtype));
118-
fusion->addOutput(tv0);
113+
Int* size_val = IrBuilder::create<Int>();
114+
fusion->addInput(size_val);
115+
TensorView* tv0 = rand({size_val}, DataType::Float);
116+
TensorView* tv1 = rand({size_val}, DataType::Double);
117+
fusion->addOutput(tv0);
118+
fusion->addOutput(tv1);
119119
120-
FusionExecutorCache fec(std::move(fusion_ptr));
120+
FusionExecutorCache fec(std::move(fusion_ptr));
121121
122-
at::manual_seed(0);
123-
auto cg_outputs = fec.runFusionWithInputs({size});
124-
auto out = cg_outputs[0];
122+
for (int64_t size : {16, 1024, 10001, 10002, 10003, 100000, 10000001}) {
123+
at::manual_seed(0);
124+
auto cg_outputs = fec.runFusionWithInputs({size});
125125
126-
at::manual_seed(0);
127-
auto ref = generate_uniform(size, dtype);
126+
at::manual_seed(0);
127+
auto ref0 = generate_uniform(size, kFloat);
128+
auto ref1 = generate_uniform(size, kDouble);
128129
129-
testValidate(fec.fusion(), {out}, {size}, {ref}, __LINE__, __FILE__);
130-
}
130+
testValidate(
131+
fec.fusion(), cg_outputs, {size}, {ref0, ref1}, __LINE__, __FILE__);
131132
}
132133
}
133134
134-
TEST_F(NVFuserTest, FusionRNGSimpleValidateWithCURand_CUDA) {
135+
TEST_F(NVFuserTest, FusionRNGManualScheduleValidateWithCURand_CUDA) {
135136
int64_t size = 128;
136137
auto dtype = kFloat;
137138
std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();

0 commit comments

Comments
 (0)