@@ -678,7 +678,13 @@ def acc_ops_batch_norm(
 
 
 @tensorrt_converter(acc_ops.layer_norm)
-def acc_ops_layer_norm(network, target, args, kwargs, name):
+def acc_ops_layer_norm(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
     return add_layer_norm(network, target, kwargs, name)
 
 
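This hunk types the previously untyped converter and delegates the lowering to add_layer_norm. The helper module itself is not part of this diff, so the following is only a sketch of the assumed shared signature pattern that the other add_* helpers below appear to follow; the TRTNetwork and TRTTensor aliases are assumptions.

# Minimal sketch of the assumed shared helper signature (module not shown in this diff).
from typing import Dict, Sequence, Tuple, Union

import tensorrt as trt
from torch.fx.node import Argument, Target

TRTNetwork = trt.INetworkDefinition  # assumed alias
TRTTensor = trt.ITensor  # assumed alias


def add_layer_norm(
    network: TRTNetwork,
    target: Target,
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    # Assumed to own the layer_norm lowering; note the converter drops
    # `args` when delegating and passes only kwargs.
    ...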
@@ -690,37 +696,7 @@ def acc_ops_softmax(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    input_ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)  # type: ignore[union-attr]
-
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"softmax received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    # Used to get dim when dim is None. Copied from PyTorch softmax implementation.
-    def get_softmax_dim(ndim: int) -> int:
-        if ndim == 0 or ndim == 1 or ndim == 3:
-            ret = 0
-        else:
-            ret = 1
-        return ret
-
-    if kwargs["dim"] is None:
-        dim = get_softmax_dim(input_ranks)
-    else:
-        dim = cast(int, kwargs["dim"])
-
-    dim = get_positive_dim(dim, input_ranks)
-    if network.has_implicit_batch_dimension:
-        assert dim != 0, "Can't apply softmax on batch dimension when it's implicit."
-        dim -= 1
-
-    layer = network.add_softmax(input_val)
-    layer.axes = 1 << dim
-    set_layer_name(layer, target, name)
-    return layer.get_output(0)
+    return add_softmax(network, target, kwargs, name)
 
 
 @tensorrt_converter(acc_ops.tile)
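The removed body documents exactly what add_softmax must do. Below is a condensed sketch under the assumption that the helper carries the removed logic over unchanged, reusing the file's existing imports plus the get_positive_dim and set_layer_name utilities the original body already called.

# Sketch of add_softmax, assuming the removed body moved over intact.
def add_softmax(network, target, kwargs, name):
    input_val = kwargs["input"]
    if not isinstance(input_val, TRTTensor):
        raise RuntimeError(
            f"softmax received input {input_val} that is not part "
            "of the TensorRT region!"
        )
    ndim = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
    # PyTorch's default when dim is None: 0 for ranks 0/1/3, else 1.
    dim = kwargs["dim"] if kwargs["dim"] is not None else (0 if ndim in (0, 1, 3) else 1)
    dim = get_positive_dim(dim, ndim)
    if network.has_implicit_batch_dimension:
        assert dim != 0, "Can't apply softmax on batch dimension when it's implicit."
        dim -= 1
    layer = network.add_softmax(input_val)
    layer.axes = 1 << dim  # TensorRT expects a bitmask of reduction axes
    set_layer_name(layer, target, name)
    return layer.get_output(0)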
@@ -956,9 +932,7 @@ def acc_ops_sqrt(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    operation_type = trt.UnaryOperation.SQRT
-    return add_unary_layer(network, input_val, operation_type, target, name)
+    return add_sqrt(network, target, kwargs, name)
 
 
 @tensorrt_converter(acc_ops.reciprocal)
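add_sqrt presumably just wraps the three removed lines: a unary SQRT layer built through the existing add_unary_layer utility. A minimal sketch:

# Sketch of add_sqrt, assuming it wraps the removed body verbatim.
def add_sqrt(network, target, kwargs, name):
    return add_unary_layer(
        network, kwargs["input"], trt.UnaryOperation.SQRT, target, name
    )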
@@ -1619,40 +1593,7 @@ def acc_ops_squeeze(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"squeeze received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    dim = cast(Optional[int], kwargs["dim"] if "dim" in kwargs else None)
-    # Squeeze with dim=None would only work in explicit batch dim mode without any dynamic
-    # dim, which is a very rare case. For now we just claim not supporting dim=None.
-    assert dim is not None, "We don't support dim=None right now for squeeze."
-
-    dim = get_positive_dim(
-        dim, len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
-    )
-    if network.has_implicit_batch_dimension:
-        assert dim != 0, "We don't support squeeze batch dim when it's implicit."
-        dim -= 1
-
-    assert input_val.shape[dim] != -1, "We don't support squeeze dynamic dim."
-    assert (
-        len(get_dynamic_dims(input_val.shape)) <= 1
-    ), "Currently more than one dynamic dim for input to squeeze is not supported."
-
-    output_shape = []
-    for i, s in enumerate(input_val.shape):
-        if i == dim and s == 1:
-            continue
-        output_shape.append(s)
-    layer = network.add_shuffle(input_val)
-    layer.reshape_dims = tuple(output_shape)
-    set_layer_name(layer, target, name)
-    return layer.get_output(0)
+    return add_squeeze(network, target, kwargs, name)
 
 
 @tensorrt_converter(acc_ops.add)
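Assuming add_squeeze keeps the removed logic, squeeze lowers to an IShuffleLayer whose reshape_dims drop the single size-1 dimension; dim=None and squeezing a dynamic dimension remain unsupported. A sketch reusing the file's helpers:

# Sketch of add_squeeze under the assumption the removed logic carried over.
def add_squeeze(network, target, kwargs, name):
    input_val = kwargs["input"]
    dim = kwargs.get("dim")
    assert dim is not None, "We don't support dim=None right now for squeeze."
    dim = get_positive_dim(
        dim, len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
    )
    if network.has_implicit_batch_dimension:
        assert dim != 0, "We don't support squeeze batch dim when it's implicit."
        dim -= 1
    assert input_val.shape[dim] != -1, "We don't support squeeze dynamic dim."
    # Keep every dimension except the squeezed size-1 one.
    output_shape = [s for i, s in enumerate(input_val.shape) if not (i == dim and s == 1)]
    layer = network.add_shuffle(input_val)
    layer.reshape_dims = tuple(output_shape)
    set_layer_name(layer, target, name)
    return layer.get_output(0)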
@@ -2022,89 +1963,7 @@ def acc_ops_where(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-
-    condition_t = kwargs["condition"]
-    x_t = kwargs["x"]
-    y_t = kwargs["y"]
-
-    if type(x_t) != TRTTensor:
-        assert type(x_t) is torch.Tensor, f"value {x_t} is not torch.Tensor!"
-
-    if type(y_t) != TRTTensor:
-        assert type(y_t) is torch.Tensor, f"value {y_t} is not torch.Tensor!"
-
-    # get output shape
-
-    x_shape = list(x_t.shape)
-    y_shape = list(y_t.shape)
-    condition_shape = list(condition_t.shape)
-    output_shape = list(torch.broadcast_shapes(condition_shape, x_shape, y_shape))
-
-    # expand shape
-    if type(condition_t) != TRTTensor:
-        assert condition_t.dtype == torch.bool, "condition dtype is not bool"
-        if condition_shape != output_shape:
-            condition_t.expand(output_shape)
-        condition_t = condition_t.to(torch.int32)
-        condition_const = get_trt_tensor(network, condition_t, f"{name}_condition")
-        condition_layer = network.add_identity(condition_const)
-        condition_layer.set_output_type(0, trt.bool)
-        set_layer_name(condition_layer, target, f"{name}_condition")
-        condition_val = condition_layer.get_output(0)
-    else:
-        assert condition_t.dtype == trt.bool, "mask dtype is not bool!"
-        if condition_shape != output_shape:
-            condition_val = acc_ops_expand_tensor(
-                network,
-                target,
-                None,
-                {"input": condition_t, "sizes": output_shape},
-                name=f"{name}_expand",
-            )
-        else:
-            condition_val = condition_t
-
-    if type(x_t) != TRTTensor:
-        if x_shape != output_shape:
-            # special case where 1 element in x_t
-            if len(x_t.shape) == 0:
-                x_t = x_t.unsqueeze(0)
-            x_t = x_t.expand(output_shape)
-        x_val = get_trt_tensor(network, x_t, f"{name}_x")
-    else:
-        x_val = x_t
-        if x_shape != output_shape:
-            x_val = acc_ops_expand_tensor(
-                network,
-                target,
-                None,
-                {"input": x_val, "sizes": output_shape},
-                name=f"{name}_x_expand",
-            )
-
-    if type(y_t) != TRTTensor:
-        if y_shape != output_shape:
-            # special case where 1 element in y_t
-            if len(y_t.shape) == 0:
-                y_t = y_t.unsqueeze(0)
-            y_t = y_t.expand(output_shape)
-        y_val = get_trt_tensor(network, y_t, f"{name}_y")
-    else:
-        y_val = y_t
-        if y_shape != output_shape:
-            y_val = acc_ops_expand_tensor(
-                network,
-                target,
-                None,
-                {"input": y_val, "sizes": output_shape},
-                name=f"{name}_y_expand",
-            )
-
-    select_layer = network.add_select(condition_val, x_val, y_val)
-
-    set_layer_name(select_layer, target, f"{name}_select")
-
-    return select_layer.get_output(0)
+    return add_where(network, target, kwargs, name)
 
 
 @tensorrt_converter(acc_ops.masked_fill, no_implicit_batch_dim=True)
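The where lowering is the largest removed block: broadcast condition/x/y to a common shape, materialize torch.Tensor operands as TRT constants, then emit an ISelectLayer. The sketch below compresses the three near-identical expansion passes into a single hypothetical broadcast_to helper, which does not exist in the source; only the overall flow is grounded in the removed code.

# Condensed sketch of add_where, assuming the removed flow is preserved.
def add_where(network, target, kwargs, name):
    condition_t, x_t, y_t = kwargs["condition"], kwargs["x"], kwargs["y"]
    output_shape = list(
        torch.broadcast_shapes(list(condition_t.shape), list(x_t.shape), list(y_t.shape))
    )
    # Hypothetical helper standing in for the three removed passes: expand a
    # torch.Tensor and convert it with get_trt_tensor, or expand an existing
    # TRTTensor via the expand converter; the bool condition is round-tripped
    # through int32 plus an identity layer whose output type is set to trt.bool.
    condition_val = broadcast_to(network, target, condition_t, output_shape, f"{name}_condition")
    x_val = broadcast_to(network, target, x_t, output_shape, f"{name}_x")
    y_val = broadcast_to(network, target, y_t, output_shape, f"{name}_y")
    select_layer = network.add_select(condition_val, x_val, y_val)
    set_layer_name(select_layer, target, f"{name}_select")
    return select_layer.get_output(0)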
@@ -2721,62 +2580,7 @@ def acc_ops_chunk(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    chunks = cast(int, kwargs["chunks"])
-    dim = cast(int, kwargs["dim"])
-    input_dim_size = len(input_val.shape)  # type: ignore[union-attr]
-
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"chunk received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    dynamic_shape = has_dynamic_shape(input_val.shape)
-    if network.has_implicit_batch_dimension:
-        input_dim_size += 1
-        dim = get_positive_dim(dim, input_dim_size)
-        assert dim != 0, "Can't chunk on batch dim when it's implicit!"
-        dim -= 1
-    else:
-        if dynamic_shape:
-            assert input_val.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"
-        dim = get_positive_dim(dim, input_dim_size)
-
-    if chunks > input_val.shape[dim]:
-        warnings.warn(
-            f"Asked for {chunks} chunks along dimention "
-            f"{dim} on tensor with size {input_val.shape}, chunks "
-            f"will default to {input_val.shape[dim]}",
-            RuntimeWarning,
-        )
-        chunks = input_val.shape[dim]
-
-    start = [0] * len(input_val.shape)
-    stride = [1] * len(start)
-    offset = 0
-    split_size = (input_val.shape[dim] + chunks - 1) // chunks
-
-    max_offset = input_val.shape[dim]
-    # add slice layers
-    output = []
-    for i in range(chunks):
-        shape = list(input_val.shape)
-        shape[dim] = min(split_size, max_offset - offset)
-        if dynamic_shape:
-            shape = get_shape_with_dynamic_shape(
-                network, shape, input_val, target, f"{name}_{i}"
-            )
-        start[dim] = offset
-        layer = network.add_slice(
-            input_val, start=start, shape=[] if dynamic_shape else shape, stride=stride
-        )
-        if dynamic_shape:
-            layer.set_input(2, shape)
-        offset += split_size
-        set_layer_name(layer, target, f"{name}_{i}")
-        output.append(layer.get_output(0))
-    return output
+    return add_chunk(network, target, kwargs, name)
 
 
 @tensorrt_converter(acc_ops.cumsum, no_implicit_batch_dim=True)
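Assuming add_chunk preserves the removed implementation, chunk becomes a run of ISliceLayers, each taking a ceil-divided split_size of elements along dim (the last slice may be shorter). A sketch of the static-shape, explicit-batch path only, omitting the dynamic-shape and implicit-batch branches shown above:

# Sketch of add_chunk (static shapes, explicit batch), per the removed body.
def add_chunk(network, target, kwargs, name):
    input_val = kwargs["input"]
    chunks, dim = kwargs["chunks"], kwargs["dim"]
    dim_size = input_val.shape[dim]
    chunks = min(chunks, dim_size)  # the original warns, then clamps
    split_size = (dim_size + chunks - 1) // chunks  # ceil division
    start = [0] * len(input_val.shape)
    stride = [1] * len(start)
    output, offset = [], 0
    for i in range(chunks):
        shape = list(input_val.shape)
        shape[dim] = min(split_size, dim_size - offset)  # last slice may be short
        start[dim] = offset
        layer = network.add_slice(input_val, start=start, shape=shape, stride=stride)
        set_layer_name(layer, target, f"{name}_{i}")
        output.append(layer.get_output(0))
        offset += split_size
    return output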