Skip to content

Commit 68a6113

Browse files
IvanYashchukpytorchmergebot
authored andcommitted
Add nvFuser support for torch.native_batch_norm (#85562)
This PR adds nvFuser's implementation for batch_norm as there's no reference yet (#81191) and no in-place copy support (#84545). Pull Request resolved: #85562 Approved by: https://github.com/kevinstephano, https://github.com/ngimel
1 parent d28a882 commit 68a6113

File tree

11 files changed

+371
-4
lines changed

11 files changed

+371
-4
lines changed

functorch/test/test_ops.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,7 @@ def wrapped_fn(*args, **kwargs):
395395
skip('nn.functional.max_unpool1d'), # fails everywhere except on mac
396396
skip('nn.functional.max_unpool2d'), # fails everywhere except on windows
397397
skip('nn.functional.max_unpool3d'), # fails everywhere except on mac
398+
xfail("native_batch_norm"),
398399
399400
xfail('nn.functional.rrelu') # in-place test errors out with no formula implemented
400401
}))
@@ -643,6 +644,7 @@ def fn(inp, *args, **kwargs):
643644
xfail("nn.functional.batch_norm", 'without_cudnn'),
644645
# view doesn't work on sparse
645646
xfail("to_sparse"),
647+
xfail("native_batch_norm"),
646648
}))
647649
@ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
648650
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
@@ -725,6 +727,7 @@ def vjp_of_vjp(*args_and_cotangents):
725727
# ---------------------------- BUGS ------------------------------------
726728
# All of the following are bugs and need to be fixed
727729
skip('linalg.svdvals'), # # really annoying thing where it passes correctness check but not has_batch_rule
730+
skip("native_batch_norm"),
728731
xfail('__getitem__', ''), # dynamic error
729732
xfail('linalg.eig'), # Uses aten::allclose
730733
xfail('linalg.householder_product'), # needs select_scatter
@@ -833,6 +836,7 @@ def test_vmapvjp(self, device, dtype, op):
833836
# erroring because running_mean and running_var aren't differentiable
834837
xfail('nn.functional.batch_norm'),
835838
xfail('nn.functional.batch_norm', 'without_cudnn'),
839+
xfail("native_batch_norm"),
836840
# ----------------------------------------------------------------------
837841
}
838842

@@ -1030,6 +1034,7 @@ def test():
10301034
xfail('linalg.vecdot', ''),
10311035
xfail('segment_reduce', 'lengths'),
10321036
xfail('sparse.sampled_addmm', ''),
1037+
xfail("native_batch_norm"),
10331038
}))
10341039
def test_vmapvjp_has_batch_rule(self, device, dtype, op):
10351040
if not op.supports_autograd:
@@ -1095,6 +1100,7 @@ def test():
10951100
xfail('nn.functional.dropout3d', ''),
10961101
xfail('as_strided_scatter', ''),
10971102
xfail('sparse.sampled_addmm', ''),
1103+
xfail("native_batch_norm"),
10981104
}))
10991105
def test_vjpvmap(self, device, dtype, op):
11001106
# NB: there is no vjpvmap_has_batch_rule test because that is almost
@@ -1338,6 +1344,10 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents):
13381344
xfail('to'), # RuntimeError: required rank 4 tensor to use channels_last format
13391345
xfail('to_sparse'), # Forward AD not implemented and no decomposition
13401346
xfail('view_as_complex'), # RuntimeError: Tensor must have a last dimension with stride 1
1347+
# RuntimeError: Batch norm got a batched tensor as
1348+
# input while the running_mean or running_var, which will be updated in
1349+
# place, were not batched.
1350+
xfail("native_batch_norm"),
13411351
}))
13421352
@ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
13431353
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})

functorch/test/test_vmap.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3287,6 +3287,7 @@ def test():
32873287
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
32883288
@skipOps('TestVmapOperatorsOpInfo', 'test_vmap_exhaustive', vmap_fail.union({
32893289
xfail('cat'),
3290+
xfail('native_batch_norm'),
32903291
}))
32913292
def test_vmap_exhaustive(self, device, dtype, op):
32923293
# needs to be fixed
@@ -3306,6 +3307,7 @@ def test_vmap_exhaustive(self, device, dtype, op):
33063307
xfail('cat'),
33073308
xfail('complex'),
33083309
xfail('copysign'),
3310+
xfail('native_batch_norm'),
33093311
xfail('histogram'),
33103312
xfail('index_fill'),
33113313
xfail('nansum'),

test/test_prims.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,69 @@ def func(a):
548548
self.assertFalse(node.target == torch.ops.prims.add.default)
549549
self.assertFalse(node.target == torch.ops.aten.add.default)
550550

551+
@onlyCUDA
552+
@skipCUDAIfRocm
553+
@dtypes(torch.float32, torch.float64)
554+
def test_native_batch_norm_nvprims(self, device, dtype):
555+
from torch._prims.context import TorchRefsNvfuserCapabilityMode
556+
from torch._prims.executor import execute
557+
558+
# This test verifies that native_batch_norm is translated into nvprims
559+
# and can be executed with nvFuser
560+
from torch.fx.experimental.proxy_tensor import make_fx
561+
from torch.testing._internal.common_methods_invocations import (
562+
sample_inputs_native_batch_norm,
563+
)
564+
565+
samples = sample_inputs_native_batch_norm(
566+
None, device, dtype, requires_grad=False
567+
)
568+
batch_norms = [
569+
torch.native_batch_norm,
570+
torch.ops.aten.native_batch_norm,
571+
torch.ops.aten.native_batch_norm.default,
572+
torch.ops.nvprims.native_batch_norm.default,
573+
]
574+
for sample, batch_norm in product(samples, batch_norms):
575+
if sample.input.numel() == 0:
576+
continue
577+
578+
def func(
579+
input, weight, bias, running_mean, running_var, training, momentum, eps
580+
):
581+
return batch_norm(
582+
input,
583+
weight,
584+
bias,
585+
running_mean,
586+
running_var,
587+
training,
588+
momentum,
589+
eps,
590+
)
591+
592+
with TorchRefsNvfuserCapabilityMode():
593+
gm = make_fx(func)(sample.input, *sample.args)
594+
595+
call_function_nodes = list(
596+
filter(lambda n: n.op == "call_function", gm.graph.nodes)
597+
)
598+
includes_aten_batch_norm = any(
599+
torch.ops.aten.native_batch_norm.default == node.target
600+
for node in call_function_nodes
601+
)
602+
self.assertFalse(includes_aten_batch_norm)
603+
604+
includes_nvprims_batch_norm = any(
605+
torch.ops.nvprims.native_batch_norm.default == node.target
606+
for node in call_function_nodes
607+
)
608+
self.assertTrue(includes_nvprims_batch_norm)
609+
610+
# Check that the graph can be executed with nvFuser
611+
out = execute(gm, sample.input, *sample.args, executor="strictly_nvfuser")
612+
self.assertEqual(out, gm(sample.input, *sample.args))
613+
551614
# decomposition of native_batch_norm_backward uses a casting, which prevents nvprim lowering on CPU build
552615
@onlyCUDA
553616
@dtypes(torch.float32, torch.float16)

torch/_prims/context.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,12 @@ def _is_var_mean(self, func):
265265
and "aten.var_mean" in str(func)
266266
)
267267

268+
def _is_native_batch_norm(self, func):
269+
return "torch.native_batch_norm" == torch.overrides.resolve_name(func) or (
270+
func == torch.ops.aten.native_batch_norm.default
271+
or func == torch.ops.aten.native_batch_norm
272+
)
273+
268274
def _is_rand_like(self, func):
269275
result = "torch.rand_like" == torch.overrides.resolve_name(func) or (
270276
func == torch.ops.aten.rand_like or func == torch.ops.aten.rand_like.default
@@ -283,9 +289,14 @@ def __torch_function__(
283289
# First we intercept calls for nvfuser-specific prims bypassing generic torch._refs
284290
if self._is_var_mean(orig_func):
285291
return torch.ops.nvprims.var_mean(*args, **kwargs)
292+
293+
if self._is_native_batch_norm(orig_func):
294+
return torch.ops.nvprims.native_batch_norm(*args, **kwargs)
295+
286296
if self._is_rand_like(orig_func):
287297
if len(kwargs) > 0:
288298
warn("rand_like has ignored kwars!")
289299
return torch.ops.nvprims.rand_like(*args)
300+
290301
# Then we use TorchRefsMode to interpret the rest
291302
return super().__torch_function__(orig_func, types, args, kwargs)

torch/_prims/nvfuser_executor.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,18 @@ def run_node(self, node):
136136
args, kwargs = self.fetch_args_kwargs_from_env(node)
137137
args = [args[0], original_shape, args[1]]
138138
return self.call_function(node.target, args, node.kwargs)
139+
140+
if node.target in [
141+
torch.ops.nvprims.native_batch_norm,
142+
torch.ops.nvprims.native_batch_norm.default,
143+
]:
144+
args, kwargs = self.fetch_args_kwargs_from_env(node)
145+
assert len(args) == 8
146+
training = args[5]
147+
args6_end = tuple(map(_to_nvfuser_constant, args[6:]))
148+
args = args[:5] + (training,) + args6_end
149+
return node.target.impl_nvfuser(fd, *args, **kwargs)
150+
139151
return super().run_node(node)
140152

141153
def call_function(self, target, args, kwargs):

torch/_prims/nvfuser_prims.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,29 @@ def _{fname}_nvfuser(fd, a, b, c):
210210
)
211211

212212

213+
def _native_batch_norm_nvfuser(
214+
fd, input, weight, bias, running_mean, running_var, training, momentum, eps
215+
):
216+
if weight is None:
217+
weight = fd.define_null_tensor()
218+
if bias is None:
219+
bias = fd.define_null_tensor()
220+
if running_mean is None:
221+
running_mean = fd.define_null_tensor()
222+
if running_var is None:
223+
running_var = fd.define_null_tensor()
224+
return fd.ops.batch_norm(
225+
input,
226+
weight,
227+
bias,
228+
running_mean,
229+
running_var,
230+
training,
231+
momentum,
232+
eps,
233+
)
234+
235+
213236
def _broadcast_in_dim_nvfuser(
214237
fd: Any,
215238
a: TensorLikeType,
@@ -299,6 +322,7 @@ def _amin_nvfuser(
299322
return fd.ops.min(a, dims, keep_dims)
300323

301324

325+
_nvfuser_impls["native_batch_norm"] = _native_batch_norm_nvfuser
302326
_nvfuser_impls["broadcast_in_dim"] = _broadcast_in_dim_nvfuser
303327
_nvfuser_impls["convert_element_type"] = _convert_element_type_nvfuser
304328
_nvfuser_impls["transpose"] = _transpose_nvfuser
@@ -312,6 +336,36 @@ def _amin_nvfuser(
312336
_nvfuser_impls["amin"] = _amin_nvfuser
313337

314338

339+
def register_native_batch_norm():
340+
"""This function is used to register the native_batch_norm function in torch.ops.nvprims module."""
341+
name = "native_batch_norm"
342+
343+
nvprim.define(
344+
f"{name}(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, "
345+
+ "bool training, float momentum, float eps)"
346+
+ " -> (Tensor, Tensor, Tensor)"
347+
)
348+
349+
def _prim_impl(
350+
input, weight, bias, running_mean, running_var, training, momentum, eps
351+
):
352+
return torch.native_batch_norm(
353+
input, weight, bias, running_mean, running_var, training, momentum, eps
354+
)
355+
356+
nvprim_impl.impl(name, _prim_impl)
357+
nvprim_autograd_impl.impl(
358+
name, backwards_not_supported(torch.ops.nvprims.native_batch_norm.default)
359+
)
360+
361+
prim_packet = torch.ops.nvprims.native_batch_norm
362+
prim = prim_packet.default
363+
for p in (prim_packet, prim):
364+
p.__doc__ = "Computes batch normalization."
365+
p.impl_nvfuser = _nvfuser_impls["native_batch_norm"]
366+
p.return_type = torch._prims_common.RETURN_TYPE.NEW # type: ignore[attr-defined]
367+
368+
315369
def register_rand_like():
316370
name = "rand_like"
317371

@@ -471,6 +525,7 @@ def _var_mean_autograd(
471525
def register_nvprims():
472526
"""Registers all nvFuser primitives in the torch.ops.nvprims module."""
473527
register_var_mean()
528+
register_native_batch_norm()
474529
register_rand_like()
475530

476531
for name in nvprim_names:

torch/csrc/jit/codegen/cuda/ops/normalization.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -587,8 +587,11 @@ ForwardNormResult batch_norm(
587587
auto invstd_bcast = broadcast(unbiased_invstd, broadcast_mask);
588588

589589
// During inference, mean/invstd output are empty tensors
590-
mean = TensorViewBuilder().shape(std::vector<int64_t>{0}).build();
591-
invstd = TensorViewBuilder().shape(std::vector<int64_t>{0}).build();
590+
// on CPU, but not on CUDA. We need to make sure we have the same
591+
// behavior as with eager mode on CUDA.
592+
mean = set(running_mean); // use set to avoid "trivial input forwarding NOT
593+
// IMPLEMENTED" error
594+
invstd = unbiased_invstd;
592595
y = mul(x_sub_mean, invstd_bcast);
593596
}
594597

torch/csrc/jit/codegen/cuda/python_frontend/fusion_interface.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,12 @@ void FusionInterface::addOutput(Nvf::Val* output) const {
3232

3333
std::vector<at::Tensor> FusionInterface::execute(
3434
const at::ArrayRef<c10::IValue>& inputs) const {
35-
return fusionExecutorCachePtr()->runFusionWithInputs(inputs);
35+
// aliasOutputToInput always adds Tensors as outputs that we don't want
36+
// to return to the user. We need to remove them.
37+
auto count_output_aliases = fusionPtr()->getOutputAliasIndices().size();
38+
auto result = fusionExecutorCachePtr()->runFusionWithInputs(inputs);
39+
result.erase(result.begin(), result.begin() + count_output_aliases);
40+
return result;
3641
}
3742

3843
Nvf::FusionGuard FusionInterface::guard() const {

0 commit comments

Comments
 (0)