
Commit 07b9266

narendasan authored and apbose committed
refactor: Moving elementwise and unary core to impl
Signed-off-by: Naren Dasan <[email protected]>

new file: ../converters/impl/unary/base.py
1 parent bd9c29a commit 07b9266

File tree

10 files changed: +885, -473 lines

py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 380 additions & 131 deletions
Large diffs are not rendered by default.

py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 3 additions & 4 deletions
@@ -20,7 +20,8 @@

 from .converter_utils import *  # noqa: F403
 import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
-from torch_tensorrt.fx.converters.impl import activation, convolution
+from torch_tensorrt.fx.converters.impl import activation
+from torch_tensorrt.fx.converters.impl.elementwise import trunc_div

 _LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -182,9 +183,7 @@ def aten_ops_div(
             network, target, None, kwargs_new, name
         )
     elif rounding_mode == "trunc":
-        return acc_ops_converters.acc_ops_trunc_div(
-            network, target, None, kwargs_new, name
-        )
+        return trunc_div(network, target, SourceIR.ATEN, name, args[0], args[1])
     else:
         raise RuntimeError(
             f"Target {target} does not support rounding mode {rounding_mode}"

py/torch_tensorrt/fx/converters/converter_utils.py

Lines changed: 4 additions & 338 deletions
@@ -28,6 +28,7 @@ class SourceIR(Enum):
     ACC = auto()
     ATEN = auto()
     PRIM = auto()
+    TORCHTRT_LOWERED = auto()
     UNKNOWN = auto()

     def __str__(self):
@@ -39,6 +40,8 @@ def __str__(self):
             return "aten"
         elif self == SourceIR.PRIM:
             return "prim"
+        elif self == SourceIR.TORCHTRT_LOWERED:
+            return "torchtrt_lowered"
         else:
             return "unknown_ir"
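
The string form of the new member feeds into TensorRT layer naming downstream. A condensed, standalone mirror of the enum above (members outside this hunk are omitted, and the one-line __str__ is an equivalent rewrite of the if/elif chain, not the actual source):

    from enum import Enum, auto

    class SourceIR(Enum):
        ACC = auto()
        ATEN = auto()
        PRIM = auto()
        TORCHTRT_LOWERED = auto()
        UNKNOWN = auto()

        def __str__(self):
            # Equivalent to the if/elif chain in the diff above.
            return "unknown_ir" if self == SourceIR.UNKNOWN else self.name.lower()

    print(SourceIR.TORCHTRT_LOWERED)  # torchtrt_lowered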

@@ -409,176 +412,7 @@ def broadcast(
     return a, b


-def get_shape_with_dynamic_shape(
-    network: TRTNetwork,
-    shape: Union[list, tuple, torch.Tensor],
-    input_val: TRTTensor,
-    target: Target,
-    name: str,
-) -> TRTTensor:
-    """
-    Prepare the real output tensor shape for dynamic shape mode tensor input.
-    How this function works:
-    Assuming input_val has actual shape [2048, 256, 512] and the expected reduce operation
-    output shape is [-1, 128, 256], this function should return [2048, 128, 256] as the actual
-    reduce operation output shape. The steps of the calculation are:
-        1. get the actual tensor shape of input_val via an add_shape layer;
-        2. create an all-zero tensor [0, 0, 0];
-        3. run an elementwise comparison of the [0, 0, 0] and [-1, 128, 256] tensors, getting a condition tensor [True, False, False];
-        4. use the condition tensor [True, False, False] to select between [2048, 256, 512] and [-1, 128, 256], replacing
-           all -1 dynamic shape dimensions with the actual batch_size value;
-        5. output the shape with the actual batch_size, [2048, 128, 256]
-
-    Args:
-        network (TRTNetwork): TensorRT network object.
-        shape: calculated shape of the expected output tensor
-        input_val (TRTTensor): A TensorRT ITensor.
-        target (Target): Target of fx node.
-        name (str): The name we want to assign to the created TensorRT layer.
-    Returns:
-        A TensorRT ITensor that represents the actual shape of input_val.
-    """
-    # Get real shape info for input_val
-    input_shape = network.add_shape(input_val).get_output(0)
-
-    scale_layer = network.add_constant(
-        input_shape.shape, np.ascontiguousarray(shape, dtype=np.int32)
-    )
-    set_layer_name(scale_layer, target, f"{name}_scale")
-    scale_res = scale_layer.get_output(0)
-
-    length = input_shape.shape[0]
-    zero_layer = network.add_constant(
-        input_shape.shape, to_numpy(torch.zeros((length), dtype=torch.int32))
-    )
-    set_layer_name(zero_layer, target, f"{name}_zeros")
-
-    condition_val = add_binary_elementwise_layer(
-        network,
-        scale_res,
-        zero_layer.get_output(0),
-        trt.ElementWiseOperation.LESS,
-        target,
-        f"{name}_shape",
-    )
-    select_layer = network.add_select(condition_val, input_shape, scale_res)
-    set_layer_name(select_layer, target, f"{name}_select")
-    return select_layer.get_output(0)
-
-
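The select step is easy to verify outside TensorRT. A small numpy sketch of steps 2-4 from the docstring, with numpy standing in for the constant, elementwise, and select layers:

    import numpy as np

    actual = np.array([2048, 256, 512], dtype=np.int32)    # shape of input_val
    expected = np.array([-1, 128, 256], dtype=np.int32)    # requested output shape

    condition = expected < 0                               # [ True False False]
    resolved = np.where(condition, actual, expected)
    print(resolved)                                        # [2048  128  256]
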
-def add_binary_elementwise_layer(
-    network: TRTNetwork,
-    lhs_val: Union[int, float, TRTTensor, torch.Tensor],
-    rhs_val: Union[int, float, TRTTensor, torch.Tensor],
-    op_type: trt.ElementWiseOperation,
-    target: Target,
-    name: str,
-) -> TRTTensor:
-    """
-    This function adds a TensorRT elementwise layer. We allow both operands to be
-    constant (not a trt tensor) because in implicit batch dimension mode, we could
-    introduce constants via the .size() op. Other scenarios should be const-folded first.
-    If any operand is not a trt tensor, we make it a trt constant layer while preserving
-    its dtype. Then we broadcast these two inputs to have the same number of dimensions.
-
-    Limitation:
-        If we are using implicit batch dim mode, the operand that is not a trt
-        tensor is not allowed to have a larger rank than the trt tensor operand.
-
-    Args:
-        network (TRTNetwork): TensorRT network object.
-        lhs_val (TRTTensor): Left operand of the binary operation. Could
-            be a TensorRT tensor, a PyTorch tensor or a simple value.
-        rhs_val (TRTTensor): Right operand of the binary operation. Similar
-            to lhs_val.
-        op_type (trt.ElementWiseOperation): Type of the TensorRT elementwise binary operation.
-        target (Target): Target of fx node.
-        name (str): The name we want to assign to the created TensorRT layer.
-
-    Returns:
-        The output of TensorRT Elementwise layer.
-    """
-    lhs_dtype = None
-    rhs_dtype = None
-    is_lhs_trt_tensor = False
-    is_rhs_trt_tensor = False
-
-    if isinstance(lhs_val, TRTTensor):
-        lhs_dtype = unified_dtype_converter(lhs_val.dtype, Frameworks.TORCH)
-        is_lhs_trt_tensor = True
-    if isinstance(rhs_val, TRTTensor):
-        rhs_dtype = unified_dtype_converter(rhs_val.dtype, Frameworks.TORCH)
-        is_rhs_trt_tensor = True
-
-    if not is_lhs_trt_tensor and not is_rhs_trt_tensor:
-        warnings.warn(
-            f"Both operands of the binary elementwise op {name} "
-            "are constant. In this case, please consider constant folding the model first."
-        )
-        return get_python_op_from_trt_elementwise_op(op_type)(lhs_val, rhs_val)
-
-    # If the following conditions are true:
-    # 1. the network has implicit batch dimension,
-    # 2. one operand has shape [] (real shape is [batch_size]),
-    # 3. another operand is a scalar,
-    # then the result should also have shape [] (real shape is [batch_size]).
-    #
-    # In such a case, we need to convert the scalar operand to a tensor, because
-    # this way the shape will become [1], and then will be properly squeezed
-    # into [], meaning that the result will have shape [], which is what we
-    # expect.
-    #
-    # Note that the dtype here is supposed to be the same as the scalar
-    # dtype but we don't have a way to detect whether it makes sense for the
-    # scalar to be float or half. Hence we go with the lhs dtype.
-    if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)):
-        rhs_val = np.array(
-            [rhs_val], dtype=unified_dtype_converter(lhs_val.dtype, Frameworks.NUMPY)
-        )
-    if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)):
-        lhs_val = np.array(
-            [lhs_val], dtype=unified_dtype_converter(rhs_val.dtype, Frameworks.NUMPY)
-        )
-
-    # When lhs is a scalar and rhs has shape [1,], the assert below would
-    # fail because the lhs shape has fewer dimensions than the rhs shape. This
-    # happens when using implicit batch dimension, when we removed the 1st
-    # dimension from the input tensor, causing it to have shape [] - a scalar. We
-    # fix it by reducing the rhs constant with a squeeze_left, so it becomes a
-    # scalar too. More generally, we squeeze_left on an input if it's a constant
-    # tensor. This is safe because broadcast will pad dimensions on the left
-    # (prepend) to make lhs and rhs shapes compatible.
-    if network.has_implicit_batch_dimension:
-        if isinstance(lhs_val, (torch.Tensor, np.ndarray)):
-            lhs_val = squeeze_left(lhs_val)
-        if isinstance(rhs_val, (torch.Tensor, np.ndarray)):
-            rhs_val = squeeze_left(rhs_val)
-
-    lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", lhs_dtype)
-    rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs", rhs_dtype)
-
-    # Check the limitation in the doc string.
-    if network.has_implicit_batch_dimension:
-        if is_lhs_trt_tensor and not is_rhs_trt_tensor:
-            assert len(lhs_val.shape) >= len(
-                rhs_val.shape
-            ), f"{lhs_val.shape} >= {rhs_val.shape}"
-        elif not is_lhs_trt_tensor and is_rhs_trt_tensor:
-            assert len(rhs_val.shape) >= len(
-                lhs_val.shape
-            ), f"{rhs_val.shape} >= {lhs_val.shape}"
-
-    lhs_val, rhs_val = broadcast(
-        network, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs"
-    )
-    layer = network.add_elementwise(lhs_val, rhs_val, op_type)
-    set_layer_name(layer, target, name)
-    output = layer.get_output(0)
-    output.name = output.name + "_" + target.__name__
-    return output
-
-
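The scalar-promotion step above is pure numpy and can be mimicked directly. A toy sketch of how a Python scalar operand inherits the TRT-tensor operand's dtype (float16 here is an arbitrary stand-in for whatever unified_dtype_converter returns):

    import numpy as np

    lhs_numpy_dtype = np.float16      # stand-in for unified_dtype_converter(lhs_val.dtype, ...)
    rhs_val = 2                       # Python scalar operand
    rhs_val = np.array([rhs_val], dtype=lhs_numpy_dtype)
    print(rhs_val, rhs_val.shape)     # [2.] (1,)

The resulting shape-[1] array is what later gets squeezed to a scalar and broadcast against the TRT tensor.
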
-def squeeze_left(const: Union[torch.Tensor, np.ndarray]):
+def squeeze_left(const: torch.Tensor):
     """
     Squeeze the size-1 dimensions on the left side of the shape tuple.
     PyTorch's `squeeze()` doesn't support passing multiple `dim`s at once, so
@@ -594,38 +428,6 @@ def squeeze_left(const: Union[torch.Tensor, np.ndarray]):
     return const

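For reference, squeeze_left stays behind in converter_utils.py (now typed for torch.Tensor only). Per its docstring, it loops because torch.squeeze() takes one dim at a time; a quick behavioral illustration (the loop is an assumed equivalent of the function body, which this view truncates):

    import torch

    const = torch.zeros(1, 1, 3)
    while len(const.shape) > 0 and const.shape[0] == 1:
        const = const.squeeze(0)      # peel leading size-1 dims one by one
    print(const.shape)                # torch.Size([3])
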
-def add_unary_layer(
-    network: TRTNetwork,
-    input_val: TRTTensor,
-    operation_type: trt.UnaryOperation,
-    target: Target,
-    name: str,
-) -> TRTTensor:
-    """
-    Add a TensorRT Unary layer to `network`.
-
-    Args:
-        network (TRTNetwork): TensorRT network object.
-        input_val (TRTTensor): Input to the unary op. Must be a TensorRT tensor.
-        operation_type (trt.UnaryOperation): Type of the TensorRT unary operation.
-        target (Target): Target of fx node.
-        name (str): The name we want to assign to the created TensorRT layer.
-
-    Returns:
-        The output of TensorRT Unary layer.
-    """
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"{operation_type} received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-    layer = network.add_unary(input_val, operation_type)
-    set_layer_name(layer, target, name)
-    output = layer.get_output(0)
-    output.name = output.name + "_" + target.__name__
-    return layer.get_output(0)
-
-
 def add_reduce_layer(
     network: TRTNetwork,
     target: Target,
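
Per the commit message, the unary core now lives in a new file, converters/impl/unary/base.py, which this view does not render. A hedged sketch of what the relocated helper plausibly looks like, following the pattern trunc_div uses above; the name convert_unary and the source_ir parameter are assumptions, not confirmed by this diff:

    import tensorrt as trt

    # Hypothetical sketch only; the actual impl/unary/base.py is not shown here.
    def convert_unary(network, target, source_ir, name, operation_type, input_val):
        if not isinstance(input_val, trt.ITensor):
            raise RuntimeError(
                f"{operation_type} received input {input_val} that is not part "
                "of the TensorRT region!"
            )
        layer = network.add_unary(input_val, operation_type)
        layer.name = f"[{source_ir}]-{name}"  # assumed naming; real code likely uses set_layer_name
        return layer.get_output(0)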
@@ -730,142 +532,6 @@ def get_inputs_from_args_and_kwargs(args, kwargs, input_names):
     return inputs


-def sign(
-    network: TRTNetwork, input_val: TRTTensor, target: Target, name: str
-) -> TRTTensor:
-    """
-    Sign is calculated as below:
-        x = input
-        sign = (exp(x) // exp(abs(x))) * 2 - 1
-    For a positive number and 0, (exp(x) // exp(abs(x))) yields 1; for a negative number it yields 0.
-    Multiplying by 2, the value becomes 2 (for pos and 0) and 0 (for neg).
-    Finally, subtracting 1, the value becomes 1 (for pos and 0) and -1 (for neg).
-
-    Args:
-        network (TRTNetwork): TensorRT network object.
-        input_val (TRTTensor): The input tensor.
-        target (Target): fx node target.
-        name (str): Name of the fx node with optional suffix.
-
-    Returns:
-        A TensorRT tensor representing the result of the sign operator.
-    """
-    input_exp_output = add_unary_layer(
-        network, input_val, trt.UnaryOperation.EXP, target, f"{name}_prod_exp"
-    )
-    input_abs_output = add_unary_layer(
-        network, input_val, trt.UnaryOperation.ABS, target, f"{name}_prod_abs"
-    )
-    input_abs_exp_output = add_unary_layer(
-        network,
-        input_abs_output,
-        trt.UnaryOperation.EXP,
-        target,
-        f"{name}_prod_abs_exp",
-    )
-    floor_div_output = add_binary_elementwise_layer(
-        network,
-        input_exp_output,
-        input_abs_exp_output,
-        trt.ElementWiseOperation.FLOOR_DIV,
-        target,
-        f"{name}_exp_floor_div",
-    )
-    double_floor_div_output = add_binary_elementwise_layer(
-        network,
-        floor_div_output,
-        2,
-        trt.ElementWiseOperation.PROD,
-        target,
-        f"{name}_floor_div*2",
-    )
-    return add_binary_elementwise_layer(
-        network,
-        double_floor_div_output,
-        1,
-        trt.ElementWiseOperation.SUB,
-        target,
-        f"{name}_sign",
-    )
-
-
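The identity is easy to sanity-check with plain floats before trusting the layer graph; a quick verification of the docstring's formula:

    import math

    def sign_via_exp(x: float) -> float:
        # exp(x) // exp(|x|) is 1 for x >= 0 (since exp(x) == exp(|x|))
        # and 0 for x < 0 (since exp(x) < exp(|x|)); *2 - 1 maps {1, 0} to {1, -1}.
        return (math.exp(x) // math.exp(abs(x))) * 2 - 1

    print([sign_via_exp(v) for v in (3.0, 0.0, -2.5)])  # [1.0, 1.0, -1.0]
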
-def trunc_div(
-    input: TRTTensor, other: TRTTensor, network: TRTNetwork, target: Target, name: str
-) -> TRTTensor:
-    """
-    Perform trunc divide on Tensor; the result of the divide is rounded toward zero.
-    This means floor rounding for positive numbers and ceil rounding for negative
-    numbers. Example: [2.1, 0.8, -3.2] -> [2, 0, -3].
-
-    Args:
-        input: dividend.
-        other: divisor.
-        network: INetworkDefinition.
-        target: node target.
-        name: namespace for the op
-
-    Returns:
-        A TensorRT tensor representing the result of trunc divide.
-    """
-    prod_output = add_binary_elementwise_layer(
-        network, input, other, trt.ElementWiseOperation.PROD, target, f"{name}_prod"
-    )
-    sign_output = sign(network, prod_output, target, name)
-
-    # Convert constant input into ITensor for UnaryOperation
-    if not isinstance(input, trt.tensorrt.ITensor):
-        input = get_trt_tensor(network, input, f"{name}_input")
-    if not isinstance(other, trt.tensorrt.ITensor):
-        other = get_trt_tensor(
-            network,
-            other,
-            f"{name}_other",
-            dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH),
-        )
-
-    abs_input_output = add_unary_layer(
-        network, input, trt.UnaryOperation.ABS, target, f"{name}_abs_input"
-    )
-    abs_other_output = add_unary_layer(
-        network, other, trt.UnaryOperation.ABS, target, f"{name}_abs_other"
-    )
-    abs_floor_output = add_binary_elementwise_layer(
-        network,
-        abs_input_output,
-        abs_other_output,
-        trt.ElementWiseOperation.FLOOR_DIV,
-        target,
-        f"{name}_floor_div",
-    )
-    output = add_binary_elementwise_layer(
-        network,
-        abs_floor_output,
-        sign_output,
-        trt.ElementWiseOperation.PROD,
-        target,
-        f"{name}_output",
-    )
-
-    return output
-
-
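The relocated implementation should match PyTorch's own truncating division; the docstring's example checks out directly:

    import torch

    x = torch.tensor([2.1, 0.8, -3.2])
    print(torch.div(x, torch.ones(3), rounding_mode="trunc"))
    # tensor([ 2.,  0., -3.])
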
-def get_python_op_from_trt_elementwise_op(
-    trt_op: TRTElementWiseOp,
-) -> Callable[[Any, Any], Any]:
-    if trt_op == trt.ElementWiseOperation.SUM:
-        return operator.add
-    elif trt_op == trt.ElementWiseOperation.PROD:
-        return operator.mul
-    elif trt_op == trt.ElementWiseOperation.SUB:
-        return operator.sub
-    elif trt_op == trt.ElementWiseOperation.DIV:
-        return operator.truediv
-    elif trt_op == trt.ElementWiseOperation.FLOOR_DIV:
-        return operator.floordiv
-    else:
-        raise RuntimeError(f"{trt_op} is not supported yet!")
-
-
 def dtype_uniform(
     network: TRTNetwork, target: Target, name: str, input: TRTTensor, other: TRTTensor
 ):
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+from .ops import *
