Commit 59354e5

apbose authored and gs-olive committed

Converter reorg batch norm

Batch norm error fix and linting issue fix

1 parent db15d27 commit 59354e5

File tree

4 files changed: 124 additions & 63 deletions

py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 14 additions & 51 deletions

@@ -29,6 +29,7 @@
 from torch_tensorrt.fx.converters.impl import activation
 from torch_tensorrt.fx.converters.impl.elementwise import trunc_div
 from torch_tensorrt.fx.converters.impl.elementwise import fmod
+from torch_tensorrt.fx.converters.impl.normalization import batch_norm
 from torch_tensorrt.fx.converters.impl.unary import sign
 from torch_tensorrt.fx.converters.impl.elementwise.base import (
     convert_binary_elementwise,
@@ -630,58 +631,20 @@ def acc_ops_batch_norm(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"BatchNorm2d received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    if has_dynamic_shape(input_val.shape):
-        assert input_val.shape[1] != -1, "Channel dim can't be dynamic for batch norm."
-
-    scale = cast(
-        torch.Tensor, to_numpy(cast(torch.Tensor, kwargs["weight"]))
-    ) / np.sqrt(
-        cast(torch.Tensor, to_numpy(cast(torch.Tensor, kwargs["running_var"])))
-        + cast(float, kwargs["eps"])
-    )
-
-    bias = (
-        to_numpy(cast(torch.Tensor, kwargs["bias"]))
-        - to_numpy(cast(torch.Tensor, kwargs["running_mean"])) * scale
+    return batch_norm(
+        network,
+        target,
+        SourceIR.ACC,
+        name,
+        kwargs["input"],
+        kwargs["weight"],
+        kwargs["bias"],
+        kwargs["running_mean"],
+        kwargs["running_var"],
+        kwargs["training"],
+        kwargs["momentum"],
+        kwargs["eps"],
     )
-    power = np.ones_like(scale)
-
-    # For BatchNorm1d, reshape 1d to 2d
-    output_shape = input_val.shape
-    if not network.has_implicit_batch_dimension and len(input_val.shape) < 4:
-        assert (
-            len(get_dynamic_dims(input_val.shape)) <= 1
-        ), "BatchNorm1D with more than one dynamic dims is not currently supported."
-        reshape_layer = network.add_shuffle(input_val)
-        if len(input_val.shape) == 2:
-            reshape_layer.reshape_dims = (input_val.shape[0], input_val.shape[1], 1, 1)
-        else:  # len(input_val.shape) == 3
-            reshape_layer.reshape_dims = (
-                input_val.shape[0],
-                input_val.shape[1],
-                input_val.shape[2],
-                1,
-            )
-        set_layer_name(reshape_layer, target, f"{name}_reshape_2d")
-        input_val = reshape_layer.get_output(0)
-    layer = network.add_scale(input_val, trt.ScaleMode.CHANNEL, bias, scale, power)
-    set_layer_name(layer, target, name)
-
-    # For BatchNorm1d, reshape output back to 1d
-    if not network.has_implicit_batch_dimension and len(output_shape) < 4:
-        reshape_output_layer = network.add_shuffle(layer.get_output(0))
-        reshape_output_layer.reshape_dims = tuple(output_shape)
-        set_layer_name(reshape_output_layer, target, f"{name}_reshape_1d")
-        layer = reshape_output_layer
-    return layer.get_output(0)
 
 
 @tensorrt_converter(acc_ops.layer_norm)
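
Both frontends now funnel into one implementation that folds eval-mode batch norm into a single per-channel scale layer: scale = weight / sqrt(running_var + eps), shift = bias - running_mean * scale, with power fixed at 1. A standalone sanity check of that folding (illustrative values only, not part of the commit):

import numpy as np
import torch

# Illustrative check: eval-mode batch norm reduces to an affine
# per-channel transform, which is what the shared batch_norm impl
# hands to TensorRT's add_scale with ScaleMode.CHANNEL.
c = 4
x = torch.randn(2, c, 8, 8)
weight, bias = torch.randn(c), torch.randn(c)
running_mean, running_var = torch.randn(c), torch.rand(c) + 0.5
eps = 1e-5

# Same folding as the converter: y = x * scale + shift, power == 1.
scale = weight.numpy() / np.sqrt(running_var.numpy() + eps)
shift = bias.numpy() - running_mean.numpy() * scale

reference = torch.nn.functional.batch_norm(
    x, running_mean, running_var, weight, bias, training=False, eps=eps
)
folded = x * torch.from_numpy(scale).reshape(1, c, 1, 1) + torch.from_numpy(
    shift
).reshape(1, c, 1, 1)
assert torch.allclose(reference, folded, atol=1e-5)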

py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 14 additions & 12 deletions

@@ -25,6 +25,7 @@
 from torch_tensorrt.fx.converters.impl.elementwise import rsqrt
 from torch_tensorrt.fx.converters.impl.elementwise import fmod
 from torch_tensorrt.fx.converters.impl.elementwise import rsub
+from torch_tensorrt.fx.converters.impl.normalization import batch_norm
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -93,18 +94,19 @@ def aten_ops_batch_norm(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    kwargs_new = {
-        "input": args[0],
-        "weight": args[1],
-        "bias": args[2],
-        "running_mean": args[3],
-        "running_var": args[4],
-        "training": args[5],
-        "momentum": args[6],
-        "eps": args[7],
-    }
-    return acc_ops_converters.acc_ops_batch_norm(
-        network, target, None, kwargs_new, name
+    return batch_norm(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+        args[2],
+        args[3],
+        args[4],
+        args[5],
+        args[6],
+        args[7],
     )
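
The aten converter now forwards the schema arguments positionally instead of repacking them into a kwargs dict. The mapping it assumes is exactly the one the deleted dict recorded; the tuple name below is illustrative only:

# Positional layout of the aten batch_norm args the converter unpacks
# (names taken from the kwargs dict the removed forwarding code built):
ATEN_BATCH_NORM_ARGS = (
    "input",         # args[0]
    "weight",        # args[1]
    "bias",          # args[2]
    "running_mean",  # args[3]
    "running_var",   # args[4]
    "training",      # args[5]
    "momentum",      # args[6]
    "eps",           # args[7]
)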

py/torch_tensorrt/fx/converters/impl/normalization/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+from .ops import *
py/torch_tensorrt/fx/converters/impl/normalization/ops.py

Lines changed: 95 additions & 0 deletions

@@ -0,0 +1,95 @@
+import operator
+import warnings
+from typing import cast, Union, Callable, Any, Optional, Sequence
+import logging
+
+import numpy as np
+
+# @manual=//deeplearning/trt/python:py_tensorrt
+import tensorrt as trt
+import torch
+from torch.fx.node import Target
+
+from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+from torch_tensorrt.fx.utils import get_dynamic_dims
+
+from torch_tensorrt.fx.converters.converter_utils import (
+    SourceIR,
+    set_layer_name,
+    has_dynamic_shape,
+    to_numpy,
+)
+
+from torch_tensorrt.fx.converters.impl.unary.base import (
+    convert_unary,
+)
+
+from torch_tensorrt.fx.converters.impl.elementwise.base import (
+    convert_binary_elementwise,
+)
+
+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+def batch_norm(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    running_mean: torch.Tensor,
+    running_var: torch.Tensor,
+    training: bool,
+    momentum: float,
+    eps: float,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+
+    if not isinstance(input, TRTTensor):
+        raise RuntimeError(
+            f"BatchNorm2d received input {input} that is not part "
+            "of the TensorRT region!"
+        )
+
+    if has_dynamic_shape(input.shape):
+        assert input.shape[1] != -1, "Channel dim can't be dynamic for batch norm."
+
+    scale = cast(torch.Tensor, to_numpy(cast(torch.Tensor, weight))) / np.sqrt(
+        cast(torch.Tensor, to_numpy(cast(torch.Tensor, running_var))) + cast(float, eps)
+    )
+
+    bias = (
+        to_numpy(cast(torch.Tensor, bias))
+        - to_numpy(cast(torch.Tensor, running_mean)) * scale
+    )
+    power = np.ones_like(scale)
+
+    # For BatchNorm1d, reshape 1d to 2d
+    output_shape = input.shape
+    if not network.has_implicit_batch_dimension and len(input.shape) < 4:
+        assert (
+            len(get_dynamic_dims(input.shape)) <= 1
+        ), "BatchNorm1D with more than one dynamic dims is not currently supported."
+        reshape_layer = network.add_shuffle(input)
+        if len(input.shape) == 2:
+            reshape_layer.reshape_dims = (input.shape[0], input.shape[1], 1, 1)
+        else:  # len(input.shape) == 3
+            reshape_layer.reshape_dims = (
+                input.shape[0],
+                input.shape[1],
+                input.shape[2],
+                1,
+            )
+        set_layer_name(reshape_layer, target, f"{name}_reshape_2d")
+        input = reshape_layer.get_output(0)
+    layer = network.add_scale(input, trt.ScaleMode.CHANNEL, bias, scale, power)
+    set_layer_name(layer, target, name)
+
+    # For BatchNorm1d, reshape output back to 1d
+    if not network.has_implicit_batch_dimension and len(output_shape) < 4:
+        reshape_output_layer = network.add_shuffle(layer.get_output(0))
+        reshape_output_layer.reshape_dims = tuple(output_shape)
+        set_layer_name(reshape_output_layer, target, f"{name}_reshape_1d")
+        layer = reshape_output_layer
+    return layer.get_output(0)
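
The new module carries over the BatchNorm1d handling from the old converter: in explicit-batch networks, inputs below 4-D are padded with trailing 1s before add_scale and restored to their original shape afterwards. A minimal sketch of that shape logic (hypothetical helper mirroring the reshape_dims assignments above):

def padded_shape(shape):
    # Mirrors the reshape batch_norm applies before add_scale:
    # (N, C) -> (N, C, 1, 1) and (N, C, L) -> (N, C, L, 1);
    # 4-D and higher inputs pass through unchanged.
    if len(shape) == 2:
        return (shape[0], shape[1], 1, 1)
    if len(shape) == 3:
        return (shape[0], shape[1], shape[2], 1)
    return tuple(shape)

assert padded_shape((8, 32)) == (8, 32, 1, 1)
assert padded_shape((8, 32, 100)) == (8, 32, 100, 1)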

0 commit comments