
Commit 40184b2

jataylo authored and pytorchmergebot committed
[ROCm] enabling miopen_batch_norm lowering in inductor (pytorch#105740)
Enable the miopen_batch_norm lowering for inductor only. This avoids errors observed in some models, and the performance difference from the existing fallback is negligible in initial benchmarks:

```
LoweringException: RuntimeError: Expected contiguous tensor, but got non-contiguous tensor for argument #1 'input' (while checking arguments for miopen_batch_norm)
  target: aten.miopen_batch_norm.default
```

Pull Request resolved: pytorch#105740
Approved by: https://github.com/jithunnair-amd, https://github.com/malfet
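For context, a minimal, hypothetical repro sketch of the kind of compiled batch-norm call that previously hit this exception on ROCm builds (the module, shapes, and device below are illustrative, not taken from the PR):

```python
import torch

# Hypothetical repro: compile a batch-norm module with the inductor backend.
# On ROCm builds, aten.miopen_batch_norm previously hit an explicit inductor
# fallback, and the opaque MIOpen kernel could raise the contiguity error
# quoted above for some inputs.
model = torch.nn.BatchNorm2d(16).cuda()  # ROCm also uses the "cuda" device type
compiled = torch.compile(model, backend="inductor")
x = torch.randn(8, 16, 32, 32, device="cuda")
out = compiled(x)
```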
1 parent 7a3503d commit 40184b2

5 files changed (+33 −11 lines)

test/inductor/test_torchinductor.py

Lines changed: 0 additions & 1 deletion

```diff
@@ -2831,7 +2831,6 @@ def test_batch_norm_2d(self):
         )
 
     # From yolov3
-    @skipIfRocm
     def test_batch_norm_2d_2(self):
         if self.device == "cpu":
             raise unittest.SkipTest("requires CUDA")
```

test/inductor/test_torchinductor_codegen_dynamic_shapes.py

Lines changed: 0 additions & 4 deletions

```diff
@@ -11,7 +11,6 @@
     IS_CI,
     IS_WINDOWS,
     TEST_WITH_ASAN,
-    TEST_WITH_ROCM,
     TestCase,
 )
 from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
@@ -286,9 +285,6 @@ def run(*ex, **kwargs):
     "test_aliased_buffer_reuse_dynamic_shapes": TestFailure(("cpu",)),
 }
 
-if TEST_WITH_ROCM:
-    # aten.miopen_batch_norm is not registered for lowering
-    test_failures["test_batch_norm_2d_dynamic_shapes"] = TestFailure("cuda")
 
 DynamicShapesCodegenCommonTemplate = make_dynamic_cls(
     CommonTemplate, xfail_prop="_expected_failure_codegen_dynamic"
```

test/inductor/test_torchinductor_dynamic_shapes.py

Lines changed: 0 additions & 3 deletions

```diff
@@ -59,9 +59,6 @@
     test_failures["test_expanded_reduction_dynamic_shapes"] = TestFailure(
         ("cuda"), is_skip=True
     )
-    test_failures["test_batch_norm_2d_dynamic_shapes"] = TestFailure(
-        ("cuda"), is_skip=True
-    )
 
 
 def make_dynamic_cls(cls, xfail_prop="_expected_failure_dynamic"):
```

torch/_inductor/decomposition.py

Lines changed: 33 additions & 0 deletions

```diff
@@ -2,6 +2,7 @@
 import logging
 import math
 import numbers
+import typing
 
 import torch
 import torch._decomp as decomp
@@ -400,6 +401,38 @@ def _foreach_lerp_scalar(start_tensors, end_tensors, weight):
     )
 
 
+@aten.miopen_batch_norm.default.py_impl(torch._C.DispatchKey.Autograd)
+@register_decomposition(aten.miopen_batch_norm)
+def miopen_batch_norm(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias: typing.Optional[torch.Tensor],
+    running_mean: typing.Optional[torch.Tensor],
+    running_var: typing.Optional[torch.Tensor],
+    training: bool,
+    exponential_average_factor: float,
+    epsilon: float,
+):
+    a, b, c = aten.native_batch_norm(
+        input,
+        weight,
+        bias,
+        running_mean,
+        running_var,
+        training,
+        exponential_average_factor,
+        epsilon,
+    )
+
+    if training:
+        return (a, b, c)
+    return (
+        a,
+        weight.new_zeros((0,)),
+        weight.new_zeros((0,)),
+    )
+
+
 @functools.lru_cache(None)
 def fast_random_decomps():
     return {**decompositions, **extra_random_decomps}
```
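The decomposition re-expresses aten.miopen_batch_norm in terms of aten.native_batch_norm, which inductor already lowers, so the MIOpen kernel (and its contiguity requirement) is bypassed. As a rough sketch of the argument and return-value mapping the decomposition relies on (the tensor values below are illustrative):

```python
import torch

# native_batch_norm takes the same argument list the decomposition forwards:
# (input, weight, bias, running_mean, running_var, training, momentum, eps)
# and returns (output, save_mean, save_invstd).
x = torch.randn(4, 3, 8, 8)
weight, bias = torch.ones(3), torch.zeros(3)
running_mean, running_var = torch.zeros(3), torch.ones(3)

out, save_mean, save_invstd = torch.ops.aten.native_batch_norm(
    x, weight, bias, running_mean, running_var,
    False,  # training=False: saved statistics only matter in training mode
    0.1,    # exponential_average_factor (momentum)
    1e-5,   # epsilon
)
# In eval mode the decomposition discards save_mean/save_invstd and returns
# empty tensors (weight.new_zeros((0,))) in their place, matching the output
# convention of miopen_batch_norm's second and third results.
```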

torch/_inductor/lowering.py

Lines changed: 0 additions & 3 deletions

```diff
@@ -1872,9 +1872,6 @@ def apply_constraint(arg, fx_arg):
 # fails accuracy on test_torch.py, and explicit fallback required to avoid warn=True on implicit
 make_fallback(aten.exponential.default, warn=False)
 
-# ROCm specific fallback, perf issues are observed when registered
-make_fallback(aten.miopen_batch_norm, warn=False)
-
 
 # Register with type_promotion_kind None.
 # For example, fp16.copy_(fp32) should **not** promote the first input's dtype.
```
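Taken together with the decomposition added above: removing this explicit fallback is what lets the new `register_decomposition` entry take effect, so on ROCm inductor now decomposes `aten.miopen_batch_norm` into `native_batch_norm` instead of calling the opaque MIOpen kernel, which is also why the ROCm-specific test expectations in the two dynamic-shapes test files could be dropped.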
