Commit 14a53e6: Nullary RNGOp (#1892)
1 parent 3c3c89e

36 files changed: +374, -141 lines
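
In short: this commit retires the UnaryOpType::RandLike path and introduces a dedicated nullary RNGOp IR node. rand(shape, dtype) becomes a true tensor factory with no inputs, randlike() is reimplemented on top of it, the Philox code generation moves out of the UnaryOp handler into its own RNGOp handler, and the new node is registered with the dispatch, cloning, and graph-visualization machinery. The first 13 of the 36 changed files are shown below.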

torch/csrc/jit/codegen/cuda/arith.cpp
Lines changed: 22 additions & 27 deletions

@@ -358,6 +358,19 @@ Val* getMaximumValue(DataType v) {
 
 } // namespace
 
+// TENSOR FACTORIES
+TensorView* rand(const std::vector<Val*>& shape, DataType dtype) {
+  auto n = shape.size();
+  auto out = TensorViewBuilder()
+                 .ndims(n)
+                 .dtype(dtype)
+                 .contiguity(std::vector<bool>(n, true))
+                 .shape(shape)
+                 .build();
+  IrBuilder::create<RNGOp>(RNGOpType::Uniform, out);
+  return out;
+}
+
 Val* castOp(DataType dtype, Val* v1) {
   if (v1->getDataType().value() == dtype) {
     return set(v1);
@@ -404,17 +417,6 @@ Val* unaryOp(UnaryOpType type, Val* v1) {
   TORCH_INTERNAL_ASSERT(
       type != UnaryOpType::Address,
       "The reference operator & is not accessible in the Fusion IR");
-
-  // TODO: We should add the following, but we need to go through schedulers
-  // and make sure all calls to "fusion->inputs" includes the output of RandLike
-  //
-  // If rand like, there isn't a real dependency on the input value, so map it
-  // to a dummy scalar. if
-  //
-  // (type == UnaryOpType::RandLike) {
-  //   v1 = new NamedScalar("__rnd", v1->getDataType().value());
-  // }
-
   Val* out = newValLike(v1, v1->getDataType().value());
   IrBuilder::create<UnaryOp>(type, out, v1);
   return out;
@@ -469,28 +471,21 @@ NVFUSER_DEFINE_UNARY_OP(trunc, Trunc)
 NVFUSER_DEFINE_UNARY_OP(print, Print)
 #undef NVFUSER_DEFINE_UNARY_OP
 
-Val* randlike(Val* v) {
+TensorView* randlike(TensorView* v) {
   TORCH_CHECK(
       isFloatingPointType(v->dtype()),
       "input must have floating point type, but got ",
       v->dtype());
-  auto rand_vals = unaryOp(UnaryOpType::RandLike, v);
-  return where(
-      eq(rand_vals, IrBuilder::create<Double>(1.0)),
-      IrBuilder::create<Double>(0.0),
-      rand_vals);
+  std::vector<Val*> shape;
+  shape.reserve(v->getMaybeRFactorDomain().size());
+  for (auto id : v->getMaybeRFactorDomain()) {
+    shape.emplace_back(id->getMaybeExpandedExtent());
+  }
+  return rand(shape, v->dtype());
 }
 
-TensorView* randlike(TensorView* v) {
-  TORCH_CHECK(
-      isFloatingPointType(v->dtype()),
-      "input must have floating point type, but got ",
-      v->dtype());
-  auto rand_vals = unaryOp(UnaryOpType::RandLike, v);
-  return where(
-      eq(rand_vals, IrBuilder::create<Double>(1.0)),
-      IrBuilder::create<Double>(0.0),
-      rand_vals);
+Val* randlike(Val* v) {
+  return randlike(v->as<TensorView>());
 }
 
 Val* bitwise_not(Val* v) {
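
With the factory in place, a random tensor no longer needs a donor value, and the old where(eq(rand_vals, 1.0), 0.0, rand_vals) clamp is gone entirely; producing a uniform value from raw Philox bits is now the generated code's job. A minimal usage sketch follows; it is not code from this commit, and it assumes the standard nvFuser test harness (Fusion, FusionGuard, makeSymbolicTensor) plus arbitrary sizes:

    // Hypothetical fusion using the new nullary factory and the rewritten
    // randlike(). rand() creates a contiguous uniform-random tensor that
    // depends on no fusion input.
    Fusion fusion;
    FusionGuard fg(&fusion);

    TensorView* tv0 = makeSymbolicTensor(2); // 2D float input
    fusion.addInput(tv0);

    // Nullary factory: the shape is passed explicitly as scalar Vals.
    TensorView* tv_rand = rand(
        {IrBuilder::create<Int>(16), IrBuilder::create<Int>(32)},
        DataType::Float);

    // randlike() now just collects tv0's rfactor-domain extents and forwards
    // them to rand().
    TensorView* tv_like = randlike(tv0);

    fusion.addOutput(add(tv_rand, tv_like));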

torch/csrc/jit/codegen/cuda/arith.h
Lines changed: 5 additions & 0 deletions

@@ -121,6 +121,11 @@ TORCH_CUDA_CU_API WelfordResult Welford(
     // import IrBuilder just for this one interface.
     Int* init_N = nullptr);
 
+// TENSOR FACTORIES
+TORCH_CUDA_CU_API TensorView* rand(
+    const std::vector<Val*>& shape,
+    DataType dtype);
+
 // UNARY OPERATIONS
 // abs
 TORCH_CUDA_CU_API Val* abs(Val*);

torch/csrc/jit/codegen/cuda/codegen.cpp
Lines changed: 30 additions & 29 deletions

@@ -706,34 +706,12 @@ class CudaKernelGenerator : private OptOutConstDispatch {
   }
 
   if (!print_inline_) {
-    if (op_type == UnaryOpType::RandLike) {
-      auto out_tv = uop->out()->as<kir::TensorIndex>()->view();
-      auto index = genTensorIndex(uop->in()->as<kir::TensorIndex>());
-      int multiple = out_tv->getDataType() == DataType::Double ? 2 : 4;
-      indent() << "nvfuser_index_t rng_subseq" << uop->name() << " = ("
-               << index << ") / " << multiple << ";\n";
-      indent() << "nvfuser_index_t rng_component" << uop->name() << " = ("
-               << index << ") % " << multiple << ";\n";
-      indent() << "nvfuser_index_t rng_offset" << uop->name() << " = "
-               << uop->getRNGOffset() << ";\n";
-      indent() << "if (rng_subseq != rng_subseq" << uop->name()
-               << " || rng_offset != rng_offset" << uop->name() << ") {\n";
-      indent() << "  rng_result = philox(philox_args.seed_, rng_subseq"
-               << uop->name() << ", philox_offset / 4 + rng_offset"
-               << uop->name() << ");\n";
-      indent() << "  rng_subseq = rng_subseq" << uop->name() << ";\n";
-      indent() << "  rng_offset = rng_offset" << uop->name() << ";\n";
-      indent() << "}\n";
-    }
-
     indent() << gen(uop->out());
     if (!uop->out()->isScalar() && !uop->in()->isScalar()) {
       code_ << "\n";
       indent() << kTab;
     }
     code_ << " = ";
-  } else {
-    TORCH_INTERNAL_ASSERT(op_type != UnaryOpType::RandLike);
   }
 
   if (auto op = inline_op_str(op_type)) {
@@ -762,20 +740,43 @@ class CudaKernelGenerator : private OptOutConstDispatch {
       }
     }
 
-    code_ << "(";
-    if (op_type == UnaryOpType::RandLike) {
-      code_ << "rng_result, rng_component" << uop->name();
-    } else {
-      code_ << gen(uop->in());
-    }
-    code_ << ")";
+    code_ << "(" << gen(uop->in()) << ")";
   }
 
   if (!print_inline_) {
     code_ << ";\n";
   }
 }
 
+void handle(const RNGOp* rop) final {
+  // TODO: TORCH_INTERNAL_ASSERT that the scheduler correctly creates an
+  // innermost ID of size 4 (float) or size 2 (double)?
+  auto out_tv = rop->output(0)->as<kir::TensorIndex>()->view();
+  auto index = genTensorIndex(rop->getPhiloxIndex()->as<kir::TensorIndex>());
+  int multiple = out_tv->getDataType() == DataType::Double ? 2 : 4;
+  indent() << "nvfuser_index_t rng_subseq" << rop->name() << " = (" << index
+           << ") / " << multiple << ";\n";
+  indent() << "nvfuser_index_t rng_component" << rop->name() << " = ("
+           << index << ") % " << multiple << ";\n";
+  indent() << "nvfuser_index_t rng_offset" << rop->name() << " = "
+           << rop->getRNGOffset() << ";\n";
+  indent() << "if (rng_subseq != rng_subseq" << rop->name()
+           << " || rng_offset != rng_offset" << rop->name() << ") {\n";
+  indent() << "  rng_result = philox(philox_args.seed_, rng_subseq"
+           << rop->name() << ", philox_offset / 4 + rng_offset" << rop->name()
+           << ");\n";
+  indent() << "  rng_subseq = rng_subseq" << rop->name() << ";\n";
+  indent() << "  rng_offset = rng_offset" << rop->name() << ";\n";
+  indent() << "}\n";
+  auto op_type = rop->getRNGOpType();
+  indent() << gen(rop->output(0)) << " = " << op_type;
+  if (needFloatSuffix(op_type) &&
+      rop->output(0)->dtype() == DataType::Float) {
+    code_ << "f";
+  }
+  code_ << "(rng_result, rng_component" << rop->name() << ");\n";
+}
+
 std::string genBinaryOp(
     BinaryOpType op_type,
     DataType data_type,
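
To see what this handler emits, take a float output: multiple == 4 because one Philox call yields four 32-bit values, so four float elements share a call (doubles consume two 32-bit words each, hence multiple == 2). The generated index is split into a subsequence (which Philox call) and a component (which value within it), and the cached rng_result is refreshed only when the subsequence or offset changes. Below is a hand-written sketch of the resulting kernel code, with illustrative names (idx for the generated index expression, 3 for the op name, T0 for the output) and assuming RNGOpType::Uniform streams as "uniform" before the float suffix is appended:

    // Hypothetical generated code for one float RNGOp (not actual output).
    nvfuser_index_t rng_subseq3 = (idx) / 4;
    nvfuser_index_t rng_component3 = (idx) % 4;
    nvfuser_index_t rng_offset3 = 0;
    if (rng_subseq != rng_subseq3 || rng_offset != rng_offset3) {
      rng_result = philox(
          philox_args.seed_, rng_subseq3, philox_offset / 4 + rng_offset3);
      rng_subseq = rng_subseq3;
      rng_offset = rng_offset3;
    }
    T0[idx] = uniformf(rng_result, rng_component3);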

torch/csrc/jit/codegen/cuda/dispatch.cpp
Lines changed: 15 additions & 0 deletions

@@ -104,6 +104,9 @@ void Expr::dispatch(T handler, Expr* expr) {
     case ExprType::TernaryOp:
       ptr(handler)->handle(expr->as<TernaryOp>());
       return;
+    case ExprType::RNGOp:
+      ptr(handler)->handle(expr->as<RNGOp>());
+      return;
     case ExprType::ReductionOp:
       ptr(handler)->handle(expr->as<ReductionOp>());
       return;
@@ -284,6 +287,9 @@ void Expr::constDispatch(T handler, const Expr* expr) {
     case ExprType::TernaryOp:
       ptr(handler)->handle(expr->as<TernaryOp>());
       return;
+    case ExprType::RNGOp:
+      ptr(handler)->handle(expr->as<RNGOp>());
+      return;
     case ExprType::ReductionOp:
       ptr(handler)->handle(expr->as<ReductionOp>());
       return;
@@ -472,6 +478,9 @@ void Expr::mutatorDispatch(T mutator, Expr* expr) {
     case ExprType::TernaryOp:
       ptr(mutator)->mutate(expr->as<TernaryOp>());
       return;
+    case ExprType::RNGOp:
+      ptr(mutator)->mutate(expr->as<RNGOp>());
+      return;
     case ExprType::ReductionOp:
       ptr(mutator)->mutate(expr->as<ReductionOp>());
       return;
@@ -725,6 +734,9 @@ void OptOutConstDispatch::handle(const BinaryOp* stmt) {
 void OptOutConstDispatch::handle(const TernaryOp* stmt) {
   unhandled(stmt);
 }
+void OptOutConstDispatch::handle(const RNGOp* stmt) {
+  unhandled(stmt);
+}
 void OptOutConstDispatch::handle(const ReductionOp* stmt) {
   unhandled(stmt);
 }
@@ -875,6 +887,9 @@ void OptOutDispatch::handle(BinaryOp* stmt) {
 void OptOutDispatch::handle(TernaryOp* stmt) {
   unhandled(stmt);
 }
+void OptOutDispatch::handle(RNGOp* stmt) {
+  unhandled(stmt);
+}
 void OptOutDispatch::handle(ReductionOp* stmt) {
   unhandled(stmt);
 }
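
These entries follow nvFuser's double-dispatch pattern: Expr::dispatch (and its const and mutator variants) switches on ExprType and calls the matching virtual overload, whose default implementation is unhandled(), a no-op for the OptOut dispatchers. Any visitor can therefore opt in to the new node by overriding a single method. A minimal hypothetical example, assuming only the interfaces shown in this diff:

    // Hypothetical visitor that counts RNGOp expressions. Every other Expr
    // falls through to OptOutConstDispatch::unhandled().
    class CountRNGOps : public OptOutConstDispatch {
     public:
      using OptOutConstDispatch::handle;
      int count = 0;
      void handle(const RNGOp* op) override {
        ++count;
      }
    };

Driving it over fusion->exprs() routes each ExprType::RNGOp expression through Expr::constDispatch into the override above.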

torch/csrc/jit/codegen/cuda/dispatch.h
Lines changed: 4 additions & 0 deletions

@@ -71,6 +71,7 @@ class NamedScalar;
 class UnaryOp;
 class BinaryOp;
 class TernaryOp;
+class RNGOp;
 class ReductionOp;
 class GroupedReductionOp;
 class WelfordOp;
@@ -145,6 +146,7 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase {
   virtual void handle(const UnaryOp* stmt);
   virtual void handle(const BinaryOp* stmt);
   virtual void handle(const TernaryOp* stmt);
+  virtual void handle(const RNGOp* stmt);
   virtual void handle(const ReductionOp* stmt);
   virtual void handle(const GroupedReductionOp* stmt);
   virtual void handle(const WelfordOp* stmt);
@@ -210,6 +212,7 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase {
   virtual void handle(UnaryOp* stmt);
   virtual void handle(BinaryOp* stmt);
   virtual void handle(TernaryOp* stmt);
+  virtual void handle(RNGOp* stmt);
   virtual void handle(ReductionOp* stmt);
   virtual void handle(GroupedReductionOp* stmt);
   virtual void handle(WelfordOp* stmt);
@@ -316,6 +319,7 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase {
   virtual void mutate(UnaryOp*);
   virtual void mutate(BinaryOp*);
   virtual void mutate(TernaryOp*);
+  virtual void mutate(RNGOp*);
   virtual void mutate(ReductionOp*);
   virtual void mutate(GroupedReductionOp*);
   virtual void mutate(WelfordOp*);

torch/csrc/jit/codegen/cuda/fusion.cpp
Lines changed: 18 additions & 5 deletions

@@ -373,6 +373,18 @@ void Fusion::printMath(bool from_outputs_only) {
   std::cout << "}\n\n";
 }
 
+std::vector<Val*> Fusion::inputsAndCreated() {
+  auto result = inputs_;
+  for (auto expr : exprs()) {
+    if (expr->inputs().empty()) {
+      for (auto v : expr->outputs()) {
+        result.emplace_back(v);
+      }
+    }
+  }
+  return result;
+}
+
 void Fusion::printTransforms() {
   FUSER_PERF_SCOPE("Fusion::printTransforms");
 
@@ -531,14 +543,15 @@ Expr* Fusion::definition(const Val* val) const {
 
 // Indicate to kernel to set itself up to generate random numbers
 bool Fusion::isStochastic() {
-  for (auto expr : exprs())
-    if (expr->getExprType() == ExprType::UnaryOp)
-      if (expr->as<UnaryOp>()->getUnaryOpType() == UnaryOpType::RandLike)
-        return true;
+  for (auto expr : exprs()) {
+    if (expr->getExprType() == ExprType::RNGOp) {
+      return true;
+    }
+  }
   return false;
 }
 
-std::vector<Val*> Fusion::getTerminatingOutputs() {
+std::vector<Val*> Fusion::getTerminatingOutputs() const {
   FUSER_PERF_SCOPE("getTerminatingOutputs");
 
   auto is_reachable_to_output = [](Val* val) {
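
inputsAndCreated() exists because a nullary op breaks an old invariant: previously every tensor was reachable from fusion->inputs(), but an RNGOp output is created mid-fusion with no producers (this is exactly the concern in the TODO comment deleted from arith.cpp above). Traversals that seed themselves from the inputs can switch to the new accessor to pick up such values; a sketch of the intended substitution (illustrative, not from this diff):

    // Before: traversal seeded only from external inputs.
    //   for (auto val : fusion->inputs()) { ... }
    // After: also covers outputs of input-free exprs such as RNGOp.
    for (auto val : fusion->inputsAndCreated()) {
      // visit val ...
    }

isStochastic() is updated in the same spirit: detecting RNG is now a check for ExprType::RNGOp rather than pattern-matching a UnaryOp's op type.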

torch/csrc/jit/codegen/cuda/fusion.h
Lines changed: 3 additions & 1 deletion

@@ -175,11 +175,13 @@ class TORCH_CUDA_CU_API Fusion : public IrContainer {
     return inputs_;
   }
 
+  std::vector<Val*> inputsAndCreated();
+
   const auto& outputs() const {
     return outputs_;
   }
 
-  std::vector<Val*> getTerminatingOutputs();
+  std::vector<Val*> getTerminatingOutputs() const;
 
   // Aliasing output to input value, this is a WAR to allow inplace update on
   // input tensor.

torch/csrc/jit/codegen/cuda/ir_base_nodes.h
Lines changed: 1 addition & 0 deletions

@@ -48,6 +48,7 @@ class Expr;
 class Val;
 class UnaryOp;
 class BinaryOp;
+class RNGOp;
 class IterDomain;
 class IrCloner;
 class IrContainer;

torch/csrc/jit/codegen/cuda/ir_builder.cpp
Lines changed: 1 addition & 0 deletions

@@ -63,6 +63,7 @@ IR_BUILDER_INSTANTIATE(ViewOp)
 IR_BUILDER_INSTANTIATE(UnaryOp)
 IR_BUILDER_INSTANTIATE(BinaryOp)
 IR_BUILDER_INSTANTIATE(TernaryOp)
+IR_BUILDER_INSTANTIATE(RNGOp)
 IR_BUILDER_INSTANTIATE(ReductionOp)
 IR_BUILDER_INSTANTIATE(GroupedReductionOp)
 IR_BUILDER_INSTANTIATE(WelfordOp)

torch/csrc/jit/codegen/cuda/ir_cloner.cpp
Lines changed: 4 additions & 0 deletions

@@ -100,6 +100,10 @@ void IrCloner::handle(const TernaryOp* op) {
   clone_ = IrBuilder::clone(op, this);
 }
 
+void IrCloner::handle(const RNGOp* op) {
+  clone_ = IrBuilder::clone(op, this);
+}
+
 void IrCloner::handle(const BroadcastOp* op) {
   clone_ = IrBuilder::clone(op, this);
 }

torch/csrc/jit/codegen/cuda/ir_cloner.h
Lines changed: 1 addition & 0 deletions

@@ -71,6 +71,7 @@ class TORCH_CUDA_CU_API IrCloner : private OptInConstDispatch {
   void handle(const UnaryOp*) override;
   void handle(const BinaryOp*) override;
   void handle(const TernaryOp*) override;
+  void handle(const RNGOp*) override;
   void handle(const BroadcastOp*) override;
   void handle(const ReductionOp*) override;
   void handle(const GroupedReductionOp*) override;

torch/csrc/jit/codegen/cuda/ir_graphviz.cpp
Lines changed: 10 additions & 0 deletions

@@ -443,6 +443,16 @@ void IrGraphGenerator::handle(const TernaryOp* op) {
   addArc(op, op->out());
 }
 
+void IrGraphGenerator::handle(const RNGOp* op) {
+  // node
+  std::stringstream label;
+  label << op->getRNGOpType();
+  printExpr(op, label.str());
+
+  // inputs & outputs
+  addArc(op, op->output(0));
+}
+
 void IrGraphGenerator::handle(const BroadcastOp* op) {
   printExpr(op, "Broadcast");
   addArc(op->in(), op);

torch/csrc/jit/codegen/cuda/ir_graphviz.h
Lines changed: 1 addition & 0 deletions

@@ -85,6 +85,7 @@ class TORCH_CUDA_CU_API IrGraphGenerator : private OptInConstDispatch {
   void handle(const UnaryOp*) override;
   void handle(const BinaryOp*) override;
   void handle(const TernaryOp*) override;
+  void handle(const RNGOp*) override;
   void handle(const BroadcastOp*) override;
   void handle(const ReductionOp*) override;
