 from torch_tensorrt.fx.converters.impl.elementwise import trunc_div
 from torch_tensorrt.fx.converters.impl.elementwise import fmod
 from torch_tensorrt.fx.converters.impl.normalization import batch_norm
+from torch_tensorrt.fx.converters.impl.normalization import layer_norm
 from torch_tensorrt.fx.converters.impl.unary import sign
 from torch_tensorrt.fx.converters.impl.elementwise.base import (
     convert_binary_elementwise,
@@ -649,171 +650,16 @@ def acc_ops_batch_norm(
 
 @tensorrt_converter(acc_ops.layer_norm)
 def acc_ops_layer_norm(network, target, args, kwargs, name):
-    input_val = kwargs["input"]
-
-    if not isinstance(input_val, trt.tensorrt.ITensor):
-        raise RuntimeError(
-            f"LayerNorm received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    gamma = kwargs["weight"].detach().cpu().float().numpy()
-    gamma_field = trt.PluginField("gamma", gamma, trt.PluginFieldType.FLOAT32)
-    beta = kwargs["bias"].detach().cpu().float().numpy()
-    beta_field = trt.PluginField("beta", beta, trt.PluginFieldType.FLOAT32)
-    eps_field = trt.PluginField(
-        "eps", np.array([kwargs["eps"]], dtype=np.float32), trt.PluginFieldType.FLOAT32
-    )
-    normalized_shape = kwargs["normalized_shape"]
-    try:
-        normalized_shape = np.array(normalized_shape, dtype=np.int32)
-    except TypeError:
-        _LOGGER.error(
-            f"Unable to convert normalized_shape with value {normalized_shape} to a field, fall back to []"
-        )
-        normalized_shape = np.array([], dtype=np.int32)
-
-    normalized_shape_filed = trt.PluginField(
-        "normalized_shape", normalized_shape, trt.PluginFieldType.INT32
-    )
-    field_collection = trt.PluginFieldCollection(
-        [gamma_field, beta_field, eps_field, normalized_shape_filed]
-    )
-
-    try:
-        if network.has_implicit_batch_dimension:
-            plugin = get_trt_plugin("layer_norm", field_collection, "1", "fx2trt")
-        else:
-            plugin = get_trt_plugin("LayerNormDynamic", field_collection, "1", "fx2trt")
-    except AssertionError:
-        _LOGGER.error(
-            "Unable to find layer norm plugin, fall back to TensorRT implementation."
-        )
-        return layer_norm(network, target, args, kwargs, name)
-    layer = network.add_plugin_v2([input_val], plugin)
-    layer.name = name
-    return layer.get_output(0)
-
-
-def layer_norm(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"LayerNorm received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    shape = kwargs["weight"].shape  # type: ignore[union-attr]
-    broadcasted_shape = (1,) * (len(input_val.shape) - len(shape)) + shape
-    gamma = to_numpy(kwargs["weight"].reshape(*shape))  # type: ignore[union-attr]
-    beta = to_numpy(kwargs["bias"].reshape(*shape))  # type: ignore[union-attr]
-    eps = kwargs["eps"]
-
-    axes = 0
-    for d in range(len(shape)):
-        axes |= 1 << (len(input_val.shape) - d - 1)
-
-    # E[x]
-    mean_expected_layer = network.add_reduce(
-        input_val, trt.ReduceOperation.AVG, axes, keep_dims=True
-    )
-    set_layer_name(mean_expected_layer, target, f"{name}_mean_expected")
-
-    # X-E[x]
-    sub_trt = convert_binary_elementwise(
-        network,
-        target,
-        SourceIR.ACC,
-        f"{name}_sub",
-        trt.ElementWiseOperation.SUB,
-        input_val,
-        mean_expected_layer.get_output(0),
-    )
-    # Variance = mean(pow(x_sub_mean,2))
-    pow_tensor = network.add_constant(
-        (1,) * len(input_val.shape),
-        trt.Weights(np.ascontiguousarray([2.0], dtype=np.float32)),
-    )
-    pow_tensor.name = f"{name}_power"
-    pow_var = convert_binary_elementwise(
-        network,
-        target,
-        SourceIR.ACC,
-        f"{name}_pow_var",
-        trt.ElementWiseOperation.POW,
-        sub_trt,
-        pow_tensor.get_output(0),
-    )
-    mean_trt_layer = network.add_reduce(
-        pow_var, trt.ReduceOperation.AVG, axes, keep_dims=True
-    )
-    set_layer_name(mean_trt_layer, target, f"{name}_mean")
-    # Variance + eps
-    eps_tensor = network.add_constant(
-        (1,) * len(input_val.shape),
-        trt.Weights(np.ascontiguousarray([eps], dtype=np.float32)),
-    )
-    eps_tensor.name = f"{name}_eps"
-    add_trt = convert_binary_elementwise(
-        network,
-        target,
-        SourceIR.ACC,
-        f"{name}_add",
-        trt.ElementWiseOperation.SUM,
-        mean_trt_layer.get_output(0),
-        eps_tensor.get_output(0),
-    )
-    # SQRT((Var + eps))
-    sqrt_trt = convert_unary(
-        network,
-        target,
-        SourceIR.ACC,
-        f"{name}_sqrt",
-        trt.UnaryOperation.SQRT,
-        add_trt,
-    )
-    # (x - E[x]) / sqrt((var + eps))
-    div_trt = convert_binary_elementwise(
-        network,
-        target,
-        SourceIR.ACC,
-        f"{name}_div_trt",
-        trt.ElementWiseOperation.DIV,
-        sub_trt,
-        sqrt_trt,
-    )
-
-    assert gamma is not None
-    gamma_tensor = network.add_constant(gamma.shape, trt.Weights(np.ascontiguousarray(gamma)))  # type: ignore[attr-defined]
-    gamma_tensor.name = f"{name}_gamma"
-    assert beta is not None
-    beta_tensor = network.add_constant(gamma.shape, trt.Weights(np.ascontiguousarray(beta)))  # type: ignore[attr-defined]
-    beta_tensor.name = f"{name}_beta"
-    # y * gamma + beta
-    scale_layer = convert_binary_elementwise(
-        network,
-        target,
-        SourceIR.ACC,
-        f"{name}_scale",
-        trt.ElementWiseOperation.PROD,
-        div_trt,
-        gamma_tensor.get_output(0),
-    )
-    return convert_binary_elementwise(
+    return layer_norm(
         network,
         target,
         SourceIR.ACC,
         name,
-        trt.ElementWiseOperation.SUM,
-        scale_layer,
-        beta_tensor.get_output(0),
+        kwargs["input"],
+        kwargs["normalized_shape"],
+        kwargs["weight"],
+        kwargs["bias"],
+        kwargs["eps"],
     )
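
Note: this change collapses both layer-norm paths in acc_ops_layer_norm (the plugin lookup and the inline elementwise fallback) into a single call to the shared impl.normalization.layer_norm. Both paths compute the textbook decomposition y = (x - E[x]) / sqrt(Var[x] + eps) * gamma + beta over the normalized (trailing) axes. As a sanity reference only, here is a minimal NumPy sketch of that decomposition; the helper name layer_norm_ref is hypothetical and not part of this diff:

import numpy as np

def layer_norm_ref(x, gamma, beta, eps=1e-5):
    # Hypothetical reference helper, not part of this PR. Reduces over the
    # trailing axes covered by gamma's shape, mirroring the axis bitmask the
    # deleted fallback built via `axes |= 1 << (len(input_val.shape) - d - 1)`.
    axes = tuple(range(x.ndim - gamma.ndim, x.ndim))
    mean = x.mean(axis=axes, keepdims=True)                 # E[x]
    var = ((x - mean) ** 2).mean(axis=axes, keepdims=True)  # Var[x] = E[(x - E[x])^2]
    return (x - mean) / np.sqrt(var + eps) * gamma + beta   # normalize, scale, shift

# Example: normalize the last dimension of a (2, 3, 4) tensor.
x = np.random.randn(2, 3, 4).astype(np.float32)
gamma = np.ones(4, dtype=np.float32)
beta = np.zeros(4, dtype=np.float32)
out = layer_norm_ref(x, gamma, beta)

Unpacking the kwargs at the call site (input, normalized_shape, weight, bias, eps) lets the shared implementation own this decomposition, presumably so the acc path and other front ends that pass SourceIR stay in sync rather than each carrying a private copy.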