@@ -20,7 +20,9 @@ using namespace torch::jit::fuser::cuda;
 static void setupRMSNorm_BWD(Fusion* fusion, DataType dtype) {
   FusionGuard fg(fusion);

-  TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half || dtype == DataType::BFloat16);
+  TORCH_INTERNAL_ASSERT(
+      dtype == DataType::Float || dtype == DataType::Half ||
+      dtype == DataType::BFloat16);

   const int kReductionAxis = 2;
   Double* eps_ptr = IrBuilder::create<Double>(1e-6);
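Between this hunk and the next, the fusion declares its tensor inputs and, for Half/BFloat16, upcasts them to Float before the backward math; the rstd = castOp(DataType::Float, rstd); context line opening the next hunk is the tail of that cast block. A minimal sketch of the elided stretch, assuming the usual nvFuser benchmark helpers rather than this file's exact code:

  // Sketch of the elided input declarations (assumed, not part of this diff).
  auto grad_out = makeContigTensor(3, dtype); // matches the 3-D input_shape below
  auto input = makeContigTensor(3, dtype);
  auto weight = makeContigTensor(1, dtype);
  auto rstd = TensorViewBuilder()
                  .shape({-1, -1, 1}) // broadcast along the reduction axis
                  .dtype(dtype)
                  .build();
  fusion->addInput(grad_out);
  fusion->addInput(input);
  fusion->addInput(weight);
  fusion->addInput(rstd);

  // Reduced-precision inputs are upcast so the backward math runs in fp32;
  // the context line at the top of the next hunk closes this block.
  if (dtype != DataType::Float) {
    grad_out = castOp(DataType::Float, grad_out);
    input = castOp(DataType::Float, input);
    weight = castOp(DataType::Float, weight);
    rstd = castOp(DataType::Float, rstd);
  }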
@@ -47,14 +49,12 @@ static void setupRMSNorm_BWD(Fusion* fusion, DataType dtype) {
     rstd = castOp(DataType::Float, rstd);
   }

-  auto rms_norm_results = rms_norm_backward(
-      grad_out, input, {1}, rstd, weight, {true, true, true});
+  auto rms_norm_results =
+      rms_norm_backward(grad_out, input, {1}, rstd, weight, {true, true, true});

-  if (dtype != DataType::Float) {
-    rms_norm_results.grad_input =
-        castOp(dtype, rms_norm_results.grad_input);
-    rms_norm_results.grad_weight =
-        castOp(dtype, rms_norm_results.grad_weight);
+  if (dtype != DataType::Float) {
+    rms_norm_results.grad_input = castOp(dtype, rms_norm_results.grad_input);
+    rms_norm_results.grad_weight = castOp(dtype, rms_norm_results.grad_weight);
   }

   fusion->addOutput(rms_norm_results.grad_input);
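rms_norm_backward returns a small struct of gradient TensorViews; the {true, true, true} mask requests all of them, and the two this benchmark consumes are cast back to the original dtype above. Presumably the function then closes by registering grad_weight next to the grad_input output shown as context:

  // Assumed closing lines of setupRMSNorm_BWD, mirroring the grad_input
  // output visible in the context line above.
  fusion->addOutput(rms_norm_results.grad_weight);
}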
@@ -65,10 +65,11 @@ static void NvFuserScheduler_RMSNorm_BWD(
     benchmark::State& benchmark_state,
     FusionExecutorCache* fusion_executor_cache,
     DataType dtype) {
-  TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half || dtype == DataType::BFloat16);
+  TORCH_INTERNAL_ASSERT(
+      dtype == DataType::Float || dtype == DataType::Half ||
+      dtype == DataType::BFloat16);

-  std::vector<int64_t> input_shape{
-      8, benchmark_state.range(0), 1024};
+  std::vector<int64_t> input_shape{8, benchmark_state.range(0), 1024};

   // inputs
   at::manual_seed(0);
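benchmark_state.range(0) fills the middle dimension, so each registered range value yields an 8 x N x 1024 problem while the batch and hidden sizes stay fixed. A hypothetical registration through the suite's NVFUSER_BENCHMARK_DEFINE/RUN macros and Google Benchmark's range API (the specific values here are assumed for illustration):

// Hypothetical wiring of the setup and scheduler functions defined above.
NVFUSER_BENCHMARK_DEFINE(
    NvFuserScheduler_RMSNorm_BWD_fp16,
    setupRMSNorm_BWD,
    NvFuserScheduler_RMSNorm_BWD,
    DataType::Half);

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_RMSNorm_BWD_fp16)
    ->RangeMultiplier(2)
    ->Ranges({{16, 64 * 1024}}) // swept middle dimension
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();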
@@ -79,15 +80,13 @@ static void NvFuserScheduler_RMSNorm_BWD(
   at::Tensor weight = at::randn({input_shape[2]}, options);
   at::Tensor rstd = at::randn({input_shape[0], input_shape[1], 1}, options);

-  std::vector<c10::IValue> aten_inputs(
-      {grad_out, input, weight, rstd});
+  std::vector<c10::IValue> aten_inputs({grad_out, input, weight, rstd});

   runBenchmarkIterations(benchmark_state, fusion_executor_cache, aten_inputs);

   benchmark_state.SetBytesProcessed(
       int64_t(benchmark_state.iterations()) *
-      (3 * input.numel() + weight.numel() +
-       rstd.numel()) *
+      (3 * input.numel() + weight.numel() + rstd.numel()) *
       int64_t(dataTypeSize(dtype)));
 }
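The SetBytesProcessed model counts three input-sized tensors, presumably grad_out and input read plus grad_input written, alongside the weight and rstd reads, all at the element width of dtype. A back-of-envelope check, with the shape assumed for illustration:

// Worked example for input_shape = {8, 1024, 1024}, dtype = Half (2 bytes):
//   3 * input.numel() = 3 * 8 * 1024 * 1024 = 25'165'824  // grad_out, input, grad_input
//   weight.numel()    = 1024
//   rstd.numel()      = 8 * 1024 * 1 = 8'192
//   total             = 25'175'040 elements * 2 bytes, about 50.4 MB per iteration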