
Commit e0b34b1

apbose authored and gs-olive committed
layer_norm converter

Layer norm linting correction; ops file correction; fixing lint; acc_ops layer_norm correction
1 parent 59354e5 commit e0b34b1

File tree

4 files changed: +254 -163 lines


py/torch_tensorrt/fx/converters/acc_ops_converters.py (+7 -161)
@@ -30,6 +30,7 @@
 from torch_tensorrt.fx.converters.impl.elementwise import trunc_div
 from torch_tensorrt.fx.converters.impl.elementwise import fmod
 from torch_tensorrt.fx.converters.impl.normalization import batch_norm
+from torch_tensorrt.fx.converters.impl.normalization import layer_norm
 from torch_tensorrt.fx.converters.impl.unary import sign
 from torch_tensorrt.fx.converters.impl.elementwise.base import (
     convert_binary_elementwise,
@@ -649,171 +650,16 @@ def acc_ops_batch_norm(
 
 @tensorrt_converter(acc_ops.layer_norm)
 def acc_ops_layer_norm(network, target, args, kwargs, name):
-    input_val = kwargs["input"]
-
-    if not isinstance(input_val, trt.tensorrt.ITensor):
-        raise RuntimeError(
-            f"LayerNorm received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    gamma = kwargs["weight"].detach().cpu().float().numpy()
-    gamma_field = trt.PluginField("gamma", gamma, trt.PluginFieldType.FLOAT32)
-    beta = kwargs["bias"].detach().cpu().float().numpy()
-    beta_field = trt.PluginField("beta", beta, trt.PluginFieldType.FLOAT32)
-    eps_field = trt.PluginField(
-        "eps", np.array([kwargs["eps"]], dtype=np.float32), trt.PluginFieldType.FLOAT32
-    )
-    normalized_shape = kwargs["normalized_shape"]
-    try:
-        normalized_shape = np.array(normalized_shape, dtype=np.int32)
-    except TypeError:
-        _LOGGER.error(
-            f"Unable to convert normalized_shape with value {normalized_shape} to a field, fall back to []"
-        )
-        normalized_shape = np.array([], dtype=np.int32)
-
-    normalized_shape_filed = trt.PluginField(
-        "normalized_shape", normalized_shape, trt.PluginFieldType.INT32
-    )
-    field_collection = trt.PluginFieldCollection(
-        [gamma_field, beta_field, eps_field, normalized_shape_filed]
-    )
-
-    try:
-        if network.has_implicit_batch_dimension:
-            plugin = get_trt_plugin("layer_norm", field_collection, "1", "fx2trt")
-        else:
-            plugin = get_trt_plugin("LayerNormDynamic", field_collection, "1", "fx2trt")
-    except AssertionError:
-        _LOGGER.error(
-            "Unable to find layer norm plugin, fall back to TensorRT implementation."
-        )
-        return layer_norm(network, target, args, kwargs, name)
-    layer = network.add_plugin_v2([input_val], plugin)
-    layer.name = name
-    return layer.get_output(0)
-
-
-def layer_norm(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"LayerNorm received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    shape = kwargs["weight"].shape  # type: ignore[union-attr]
-    broadcasted_shape = (1,) * (len(input_val.shape) - len(shape)) + shape
-    gamma = to_numpy(kwargs["weight"].reshape(*shape))  # type: ignore[union-attr]
-    beta = to_numpy(kwargs["bias"].reshape(*shape))  # type: ignore[union-attr]
-    eps = kwargs["eps"]
-
-    axes = 0
-    for d in range(len(shape)):
-        axes |= 1 << (len(input_val.shape) - d - 1)
-
-    # E[x]
-    mean_expected_layer = network.add_reduce(
-        input_val, trt.ReduceOperation.AVG, axes, keep_dims=True
-    )
-    set_layer_name(mean_expected_layer, target, f"{name}_mean_expected")
-
-    # X-E[x]
-    sub_trt = convert_binary_elementwise(
-        network,
-        target,
-        SourceIR.ACC,
-        f"{name}_sub",
-        trt.ElementWiseOperation.SUB,
-        input_val,
-        mean_expected_layer.get_output(0),
-    )
-    # Variance = mean(pow(x_sub_mean,2))
-    pow_tensor = network.add_constant(
-        (1,) * len(input_val.shape),
-        trt.Weights(np.ascontiguousarray([2.0], dtype=np.float32)),
-    )
-    pow_tensor.name = f"{name}_power"
-    pow_var = convert_binary_elementwise(
-        network,
-        target,
-        SourceIR.ACC,
-        f"{name}_pow_var",
-        trt.ElementWiseOperation.POW,
-        sub_trt,
-        pow_tensor.get_output(0),
-    )
-    mean_trt_layer = network.add_reduce(
-        pow_var, trt.ReduceOperation.AVG, axes, keep_dims=True
-    )
-    set_layer_name(mean_trt_layer, target, f"{name}_mean")
-    # Variance + eps
-    eps_tensor = network.add_constant(
-        (1,) * len(input_val.shape),
-        trt.Weights(np.ascontiguousarray([eps], dtype=np.float32)),
-    )
-    eps_tensor.name = f"{name}_eps"
-    add_trt = convert_binary_elementwise(
-        network,
-        target,
-        SourceIR.ACC,
-        f"{name}_add",
-        trt.ElementWiseOperation.SUM,
-        mean_trt_layer.get_output(0),
-        eps_tensor.get_output(0),
-    )
-    # SQRT((Var + eps))
-    sqrt_trt = convert_unary(
-        network,
-        target,
-        SourceIR.ACC,
-        f"{name}_sqrt",
-        trt.UnaryOperation.SQRT,
-        add_trt,
-    )
-    # (x - E[x]) / sqrt((var + eps))
-    div_trt = convert_binary_elementwise(
-        network,
-        target,
-        SourceIR.ACC,
-        f"{name}_div_trt",
-        trt.ElementWiseOperation.DIV,
-        sub_trt,
-        sqrt_trt,
-    )
-
-    assert gamma is not None
-    gamma_tensor = network.add_constant(gamma.shape, trt.Weights(np.ascontiguousarray(gamma)))  # type: ignore[attr-defined]
-    gamma_tensor.name = f"{name}_gamma"
-    assert beta is not None
-    beta_tensor = network.add_constant(gamma.shape, trt.Weights(np.ascontiguousarray(beta)))  # type: ignore[attr-defined]
-    beta_tensor.name = f"{name}_beta"
-    # y * gamma + beta
-    scale_layer = convert_binary_elementwise(
-        network,
-        target,
-        SourceIR.ACC,
-        f"{name}_scale",
-        trt.ElementWiseOperation.PROD,
-        div_trt,
-        gamma_tensor.get_output(0),
-    )
-    return convert_binary_elementwise(
+    return layer_norm(
         network,
         target,
         SourceIR.ACC,
         name,
-        trt.ElementWiseOperation.SUM,
-        scale_layer,
-        beta_tensor.get_output(0),
+        kwargs["input"],
+        kwargs["normalized_shape"],
+        kwargs["weight"],
+        kwargs["bias"],
+        kwargs["eps"],
     )
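For reference, the TensorRT-primitive fallback removed above computed layer normalization step by step, as its inline comments spell out. In LaTeX, the formula it realized is

    y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta

with E[x] and Var[x] taken over the trailing normalized_shape dimensions and gamma/beta the affine weight and bias. Those trailing dimensions were handed to network.add_reduce as a bitmask; a self-contained sketch of that idiom (the function name here is made up for illustration):

def trailing_axes_bitmask(input_rank: int, normalized_rank: int) -> int:
    # Dimension d, counted from the end, corresponds to bit (input_rank - d - 1);
    # mirrors the `axes |= 1 << (len(input_val.shape) - d - 1)` loop removed above.
    axes = 0
    for d in range(normalized_rank):
        axes |= 1 << (input_rank - d - 1)
    return axes

# Example: a 4-D input normalized over its last two dims reduces dims 2 and 3.
assert trailing_axes_bitmask(4, 2) == 0b1100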

py/torch_tensorrt/fx/converters/aten_ops_converters.py (+22 -1)
@@ -26,6 +26,7 @@
 from torch_tensorrt.fx.converters.impl.elementwise import fmod
 from torch_tensorrt.fx.converters.impl.elementwise import rsub
 from torch_tensorrt.fx.converters.impl.normalization import batch_norm
+from torch_tensorrt.fx.converters.impl.normalization import layer_norm
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -263,10 +264,30 @@ def aten_ops_leaky_relu(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-
     return activation.leaky_relu(network, target, SourceIR.ATEN, name, args[0], args[1])
 
 
+@tensorrt_converter(torch.ops.aten.layer_norm.default)
+def aten_ops_layernorm(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return layer_norm(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+        args[2],
+        args[3],
+        args[4],
+    )
+
+
 @tensorrt_converter(torch.ops.aten.linear)
 def aten_ops_linear(
     network: TRTNetwork,
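The new aten converter forwards args[0] through args[4], which line up with the leading arguments of the aten::layer_norm schema (input, normalized_shape, weight, bias, eps); the trailing cudnn_enable flag is left unused. A minimal, illustrative module that dispatches to torch.ops.aten.layer_norm.default and would therefore hit this converter once lowered (shapes and names are arbitrary):

import torch

class Norm(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # normalized_shape [10] ends up as args[1] of aten.layer_norm
        self.ln = torch.nn.LayerNorm([10], eps=1e-5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # nn.LayerNorm dispatches to
        # aten::layer_norm(input, normalized_shape, weight, bias, eps, cudnn_enable)
        return self.ln(x)

mod = Norm()
out = mod(torch.randn(2, 10))  # normalizes over the last dimension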