Skip to content

Commit 15091c4

Browse files
authored
Rework RNG to correctly support broadcasted dropout (#1888)
1 parent aafe2d0 commit 15091c4

21 files changed

+435
-177
lines changed

benchmarks/cpp/nvfuser/bert.cpp

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -342,11 +342,6 @@ static void MagicScheduler_BiasDropoutAddLayernormFwd(
342342
bytes * int64_t(benchmark_state.iterations()));
343343
}
344344

345-
static void MagicScheduler_fp32_BiasDropoutAddLayernormFwd(
346-
benchmark::State& benchmark_state) {
347-
MagicScheduler_BiasDropoutAddLayernormFwd(benchmark_state, DataType::Float);
348-
}
349-
350345
static void setupBiasDropoutAddLayernormBwd1(Fusion* fusion, DataType dtype) {
351346
FusionGuard fg(fusion);
352347

@@ -677,6 +672,16 @@ static void DivMaxSoftDropBwd_fp16(benchmark::State& benchmark_state) {
677672
MagicScheduler_DivMaxSoftDropBwd(benchmark_state, DataType::Half);
678673
}
679674

675+
// Benchmark entry point: forwards `benchmark_state` to
// MagicScheduler_BiasDropoutAddLayernormFwd with DataType::Float.
static void BiasDropoutAddLayernormFwd_fp32(
676+
benchmark::State& benchmark_state) {
677+
MagicScheduler_BiasDropoutAddLayernormFwd(benchmark_state, DataType::Float);
678+
}
679+
680+
// Benchmark entry point: forwards `benchmark_state` to
// MagicScheduler_BiasDropoutAddLayernormFwd with DataType::Float.
// NOTE(review): the body is identical to BiasDropoutAddLayernormFwd_fp32;
// nothing tf32-specific is selected here. Presumably only the Ranges used
// when registering this benchmark are meant to differ -- confirm.
static void BiasDropoutAddLayernormFwd_tf32(
681+
benchmark::State& benchmark_state) {
682+
MagicScheduler_BiasDropoutAddLayernormFwd(benchmark_state, DataType::Float);
683+
}
684+
680685
static void BiasDropoutAddLayernormBwd1_fp32(
681686
benchmark::State& benchmark_state) {
682687
MagicScheduler_BiasDropoutAddLayernormBwd1(benchmark_state, DataType::Float);
@@ -724,6 +729,19 @@ BENCHMARK(DivMaxSoftDropBwd_fp16)
724729
->Unit(benchmark::kMicrosecond)
725730
->UseManualTime();
726731

732+
BENCHMARK(BiasDropoutAddLayernormFwd_fp32)
733+
// ->RangeMultiplier(2)
734+
->Ranges({{32, 1024}, {128, 128}, {1024, 1024}})
735+
->Unit(benchmark::kMicrosecond)
736+
->UseManualTime();
737+
738+
// Use a full Ampere wave here
739+
BENCHMARK(BiasDropoutAddLayernormFwd_tf32)
740+
// ->RangeMultiplier(2)
741+
->Ranges({{32, 1024}, {128, 128}, {864, 864}})
742+
->Unit(benchmark::kMicrosecond)
743+
->UseManualTime();
744+
727745
BENCHMARK(BiasDropoutAddLayernormBwd1_fp32)
728746
// ->RangeMultiplier(2)
729747
->Ranges({{32, 1024}, {128, 128}, {1024, 1024}})

test/cpp/jit/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ if(USE_CUDA)
101101
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_shift.cpp)
102102
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp)
103103
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_view.cpp)
104+
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_rng.cu)
104105
endif()
105106

106107
add_executable(test_jit

test/test_jit_cuda_fuser.py

Lines changed: 4 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2357,10 +2357,13 @@ def t(x: torch.Tensor):
23572357
self.assertEqual(o, jit_o)
23582358
self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)
23592359

2360+
@unittest.skip("Skipped due to rand_like behavior change")
23602361
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
23612362
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
23622363
"Requires fusion optimization pass to be effective")
23632364
def test_profiling_node(self):
2365+
# TODO: should we change this test to not use rand_like, or just
2366+
# remove this test?
23642367
dtype = torch.float
23652368
device = "cuda"
23662369
x = torch.randn(4, 8, 8, 8, dtype=dtype, device=device)
@@ -2372,26 +2375,6 @@ def repro(x: torch.Tensor, alpha: float):
23722375
repro_jit = torch.jit.script(repro)
23732376
self._run_helper(repro_jit, repro, x, 0.6)
23742377

2375-
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
2376-
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
2377-
"Requires fusion optimization pass to be effective")
2378-
def test_rand_like(self):
2379-
dtype = torch.float
2380-
device = "cuda"
2381-
2382-
def t(x: torch.Tensor, alpha: float):
2383-
o = torch.rand_like(x)
2384-
o = torch.add(o, alpha)
2385-
return o
2386-
2387-
# disabling cache so new inputs would generate new graph
2388-
t.__disable_jit_function_caching__ = True
2389-
2390-
for m_format in [torch.contiguous_format, torch.channels_last]:
2391-
x = torch.randn(4, 5, 6, 7, dtype=dtype, device=device).to(memory_format=m_format)
2392-
t_jit = torch.jit.script(t)
2393-
self._run_helper(t_jit, t, x, 0.6, check_stride=True)
2394-
23952378
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
23962379
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
23972380
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
@@ -4864,19 +4847,13 @@ def clamp_min(x):
48644847
def test_device_constant(self):
48654848
x = torch.randn(4, 2, device="cuda")
48664849

4867-
def t(x):
4868-
return torch.rand_like(x, device=torch.device(type='cuda'))
4869-
48704850
# cpu tensor shouldn't be fused
48714851
def t_cpu(x):
48724852
return torch.rand_like(x, device=torch.device(type='cpu'))
48734853

48744854
with nvfuser_singleton_fusion(True):
4875-
t_jit = torch.jit.script(t)
4876-
self._run_helper(t_jit, t, x)
4877-
48784855
t_cpu_jit = torch.jit.script(t_cpu)
4879-
for i in range(5):
4856+
for _ in range(5):
48804857
t_cpu_jit(x)
48814858

48824859
self.assertGraphContainsExactly(t_cpu_jit.graph_for(x), FUSION_GUARD, 0)

torch/csrc/jit/codegen/cuda/codegen.cpp

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
247247
}
248248

249249
// Kernels generating random numbers take extra (seed, offset) arguments
250-
if (kernel_summary.is_stochastic) {
250+
if (kernel_summary.max_rng_offsets >= 0) {
251251
code_ << ", at::PhiloxCudaState philox_args";
252252
}
253253

@@ -259,14 +259,14 @@ class CudaKernelGenerator : private OptOutConstDispatch {
259259
const auto& kernel_summary = kernel_->summary();
260260

261261
// Random number generator (optional)
262-
if (kernel_summary.is_stochastic) {
263-
indent()
264-
<< "const auto idx = ((((blockIdx.z * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x) * blockDim.z + threadIdx.z) * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x;";
262+
if (kernel_summary.max_rng_offsets >= 0) {
265263
indent() << "auto offset = philox_args.captured_ ?\n";
266264
indent()
267265
<< " static_cast<uint64_t>(*(philox_args.offset_.ptr) + philox_args.offset_intragraph_) :\n";
268266
indent() << " philox_args.offset_.val;\n";
269-
indent() << "Philox rnd(philox_args.seed_, idx, offset);\n";
267+
indent() << "uint4 rng_result;\n";
268+
indent() << "nvfuser_index_t rng_subseq = -1;\n";
269+
indent() << "nvfuser_index_t rng_offset = -1;\n";
270270
}
271271

272272
// Do we have any dynamic shared memory buffers?
@@ -695,8 +695,9 @@ class CudaKernelGenerator : private OptOutConstDispatch {
695695
}
696696
}
697697

698+
const auto op_type = uop->getUnaryOpType();
699+
698700
if (uop->out()->isA<NamedScalar>()) {
699-
const auto op_type = uop->getUnaryOpType();
700701
if (auto op = inline_op_str(op_type)) {
701702
indent() << gen(uop->out()) << " = " << *op << genInline(uop->in())
702703
<< ";\n";
@@ -705,15 +706,36 @@ class CudaKernelGenerator : private OptOutConstDispatch {
705706
}
706707

707708
if (!print_inline_) {
709+
if (op_type == UnaryOpType::RandLike) {
710+
auto out_tv = uop->out()->as<kir::TensorIndex>()->view();
711+
auto index = genTensorIndex(uop->out()->as<kir::TensorIndex>());
712+
int multiple = out_tv->getDataType() == DataType::Double ? 2 : 4;
713+
indent() << "nvfuser_index_t subseq" << uop->name() << " = (" << index
714+
<< ") / " << multiple << ";\n";
715+
indent() << "nvfuser_index_t component" << uop->name() << " = ("
716+
<< index << ") % " << multiple << ";\n";
717+
indent() << "nvfuser_index_t offset" << uop->name() << " = "
718+
<< uop->getRNGOffset() << ";\n";
719+
indent() << "if (rng_subseq != subseq" << uop->name()
720+
<< " || rng_offset != offset" << uop->name() << ") {\n";
721+
indent() << " rng_result = philox(philox_args.seed_, subseq"
722+
<< uop->name() << ", offset / 4 + offset" << uop->name()
723+
<< ");\n";
724+
indent() << " rng_subseq = subseq" << uop->name() << ";\n";
725+
indent() << " rng_offset = offset" << uop->name() << ";\n";
726+
indent() << "}\n";
727+
}
728+
708729
indent() << gen(uop->out());
709730
if (!uop->out()->isScalar() && !uop->in()->isScalar()) {
710731
code_ << "\n";
711732
indent() << kTab;
712733
}
713734
code_ << " = ";
735+
} else {
736+
TORCH_INTERNAL_ASSERT(op_type != UnaryOpType::RandLike);
714737
}
715738

716-
const auto op_type = uop->getUnaryOpType();
717739
if (auto op = inline_op_str(op_type)) {
718740
if (alsoBooleanOperator(op_type) &&
719741
uop->out()->dtype() == DataType::Bool) {
@@ -742,7 +764,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
742764

743765
code_ << "(";
744766
if (op_type == UnaryOpType::RandLike) {
745-
code_ << "rnd";
767+
code_ << "rng_result, component" << uop->name();
746768
} else {
747769
code_ << gen(uop->in());
748770
}

torch/csrc/jit/codegen/cuda/executor.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -916,18 +916,14 @@ std::vector<at::Tensor> FusionExecutor::runFusion(
916916

917917
global_buffers = allocGlobalVals(expr_eval);
918918

919-
if (kernel()->summary().is_stochastic) {
919+
if (kernel()->summary().max_rng_offsets >= 0) {
920920
// NOTE: this is how we map offset to PW kernels in order to have
921921
// identical random number generator to match native PyTorch results.
922922
// But it doesn't really work, as it makes assumptions about how threads are
922922
// bound, which is not generally how we handle that in the scheduler.
924924
// Refer to `Philox` in generated kernel to understand how the mapping
925925
// works.
926-
rand_offset = 4 *
927-
(std::ceil(
928-
allocated_outputs[0].numel() /
929-
(4.0 * 128 * launch_params_.gdimx())) + // NOLINT
930-
1);
926+
rand_offset = (kernel()->summary().max_rng_offsets + 1) * 4;
931927
}
932928

933929
// This is the entry when we have provided `opt_code` but the entry has not
@@ -961,7 +957,7 @@ std::vector<at::Tensor> FusionExecutor::runFusion(
961957
kernel_arguments.push(inputs);
962958
kernel_arguments.push(allocated_outputs);
963959
kernel_arguments.push(global_buffers.buffers);
964-
if (lowered_->kernel()->summary().is_stochastic) {
960+
if (lowered_->kernel()->summary().max_rng_offsets >= 0) {
965961
kernel_arguments.appendPhiloxRNGSeed(rand_offset);
966962
}
967963
}

torch/csrc/jit/codegen/cuda/index_compute.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1866,6 +1866,36 @@ std::vector<Val*> Index::getNonGlobalProducerStridedIndices(
18661866
return strided_inds;
18671867
}
18681868

1869+
// Returns the strided indices for `consumer_tv` within `loops`, computed by
// getGlobalConsumerStridedIndices while a TVDomainGuard holds a replacement
// TensorDomain for `consumer_tv` whose contiguity flags are all `true`.
// NOTE(review): presumably the guard swaps the domain in for its own lifetime
// and restores the original afterwards -- confirm against
// ir_utils::TVDomainGuard.
std::vector<Val*> Index::getRandomTensorStridedIndices(
1870+
TensorView* consumer_tv,
1871+
const std::vector<kir::ForLoop*>& loops) {
1872+
// Use domain guard to ignore the contiguity of
1873+
// consumer tv.
1874+
TensorDomain* consumer_tv_no_contiguity_domain = nullptr;
1875+
// One `true` entry per axis of the maybe-rfactor domain, i.e. every axis is
// marked contiguous regardless of the tensor's own contiguity settings.
auto contiguity_vector =
1876+
std::vector<bool>(consumer_tv->getMaybeRFactorDomain().size(), true);
1877+
// Build the replacement domain from the same root / (rfactor) / domain()
// axes as the original; only the contiguity vector differs. The rfactor
// domain is passed only when the tensor has one.
if (consumer_tv->hasRFactor()) {
1878+
consumer_tv_no_contiguity_domain = IrBuilder::create<TensorDomain>(
1879+
consumer_tv->getRootDomain(),
1880+
consumer_tv->getRFactorDomain(),
1881+
consumer_tv->domain()->domain(),
1882+
contiguity_vector);
1883+
} else {
1884+
consumer_tv_no_contiguity_domain = IrBuilder::create<TensorDomain>(
1885+
consumer_tv->getRootDomain(),
1886+
consumer_tv->domain()->domain(),
1887+
contiguity_vector);
1888+
}
1889+
1890+
// The guard must stay alive until the index computation below has run.
ir_utils::TVDomainGuard domain_guard(
1891+
consumer_tv, consumer_tv_no_contiguity_domain);
1892+
1893+
// TODO:
1894+
// More optimization on the underlying tensor layout
1895+
// will be done in a follow-up.
1896+
return getGlobalConsumerStridedIndices(consumer_tv, loops);
1897+
}
1898+
18691899
std::vector<Val*> Index::getGlobalConsumerStridedIndices(
18701900
const TensorView* consumer_tv,
18711901
const std::vector<kir::ForLoop*>& loops) {

torch/csrc/jit/codegen/cuda/index_compute.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,15 @@ class Index {
341341
const TensorView* consumer,
342342
const std::vector<kir::ForLoop*>& loops);
343343

344+
//! Returns a vector of strided indices mapped onto the (rfactor)
345+
//! root domain of a consumer tensor. The returned indices are intended
346+
//! to be used to index into Philox pseudo-random sequences so that
347+
//! multiple inlined visits to the same element in a random tensor return
348+
//! consistent values.
349+
static std::vector<Val*> getRandomTensorStridedIndices(
350+
TensorView* consumer_tv,
351+
const std::vector<kir::ForLoop*>& loops);
352+
344353
//! Take a consumer tensorview and loop nest and generates predicates
345354
//! associated with the concrete roots of the loop nest. Returns a list of
346355
//! predicates, and a list of concrete roots they're associated with. It is

torch/csrc/jit/codegen/cuda/ir_internal_nodes.h

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,12 @@ bool areEqualScalars(Val* v1, Val* v2);
3737
//! 4) split/merge
3838
class TORCH_CUDA_CU_API UnaryOp : public Expr {
3939
public:
40-
UnaryOp(IrBuilderPasskey, UnaryOpType type, Val* out, Val* in);
40+
UnaryOp(
41+
IrBuilderPasskey,
42+
UnaryOpType type,
43+
Val* out,
44+
Val* in,
45+
int rng_offset = -1);
4146

4247
UnaryOp(const UnaryOp* src, IrCloner* ir_cloner);
4348

@@ -52,12 +57,23 @@ class TORCH_CUDA_CU_API UnaryOp : public Expr {
5257
return unary_op_type_;
5358
}
5459

60+
//! Returns the RNG offset stored on this op (`rng_offset_`, which is
//! initialized to -1 unless a value is passed to the constructor or set
//! through setRNGOffset).
int getRNGOffset() const {
61+
return rng_offset_;
62+
}
63+
64+
//! Overwrites the RNG offset stored on this op with `val`.
void setRNGOffset(int val) {
65+
rng_offset_ = val;
66+
}
67+
5568
bool sameAs(const Statement* other) const override;
5669

5770
private:
5871
const UnaryOpType unary_op_type_;
5972
Val* const out_ = nullptr;
6073
Val* const in_ = nullptr;
74+
// TODO: pull RNG op out of Unary ops
75+
// https://github.com/csarofeen/pytorch/pull/1892
76+
int rng_offset_ = -1;
6177
};
6278

6379
//! A specialization for Binary operations. Binary operations take in two inputs

torch/csrc/jit/codegen/cuda/ir_nodes.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,11 +182,17 @@ bool ComplexDouble::sameAs(const Statement* other) const {
182182
return false;
183183
}
184184

185-
UnaryOp::UnaryOp(IrBuilderPasskey passkey, UnaryOpType type, Val* out, Val* in)
185+
UnaryOp::UnaryOp(
186+
IrBuilderPasskey passkey,
187+
UnaryOpType type,
188+
Val* out,
189+
Val* in,
190+
int rng_offset)
186191
: Expr(passkey, ExprType::UnaryOp),
187192
unary_op_type_{type},
188193
out_{out},
189-
in_{in} {
194+
in_{in},
195+
rng_offset_(rng_offset) {
190196
addOutput(out);
191197
addInput(in);
192198
}
@@ -195,7 +201,8 @@ UnaryOp::UnaryOp(const UnaryOp* src, IrCloner* ir_cloner)
195201
: Expr(src, ir_cloner),
196202
unary_op_type_(src->unary_op_type_),
197203
out_(ir_cloner->clone(src->out_)),
198-
in_(ir_cloner->clone(src->in_)) {}
204+
in_(ir_cloner->clone(src->in_)),
205+
rng_offset_(src->rng_offset_) {}
199206

200207
bool UnaryOp::sameAs(const Statement* other) const {
201208
if (this == other) {

torch/csrc/jit/codegen/cuda/ir_utils.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,11 @@ struct SubstituteInExpr : public OptInDispatch {
186186
auto out =
187187
reference_->sameAs(unary_expr->out()) ? substitute_ : unary_expr->out();
188188
expr_ = IrBuilder::create<UnaryOp>(
189-
unary_expr->container(), unary_expr->getUnaryOpType(), out, in);
189+
unary_expr->container(),
190+
unary_expr->getUnaryOpType(),
191+
out,
192+
in,
193+
unary_expr->getRNGOffset());
190194
}
191195

192196
void handle(BinaryOp* binary_expr) final {
@@ -887,7 +891,8 @@ struct ReplaceValInIndexVal : public OptInDispatch {
887891
auto inp = last_visited_val_;
888892
TORCH_INTERNAL_ASSERT(uop->out()->isA<Int>());
889893
auto out = IrBuilder::create<Int>(c10::nullopt);
890-
IrBuilder::create<UnaryOp>(uop->getUnaryOpType(), out, inp);
894+
IrBuilder::create<UnaryOp>(
895+
uop->getUnaryOpType(), out, inp, uop->getRNGOffset());
891896
last_visited_val_ = out;
892897
}
893898

0 commit comments

Comments
 (0)