|
| 1 | +import operator |
| 2 | +import warnings |
| 3 | +from typing import Optional, cast |
| 4 | + |
| 5 | +import numpy as np |
| 6 | + |
| 7 | +import tensorrt as trt |
| 8 | +import torch |
| 9 | +from torch.fx.node import Target |
| 10 | + |
| 11 | +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor, Shape |
| 12 | +from torch_tensorrt.fx.converters.converter_utils import ( |
| 13 | + SourceIR, |
| 14 | + broadcast, |
| 15 | + broadcastable, |
| 16 | + get_trt_tensor, |
| 17 | + set_layer_name, |
| 18 | +) |
| 19 | +from torch_tensorrt.fx.converters.impl.slice import expand |
| 20 | + |
| 21 | + |
| 22 | +def where( |
| 23 | + network: TRTNetwork, |
| 24 | + target: Target, |
| 25 | + source_ir: Optional[SourceIR], |
| 26 | + name: str, |
| 27 | + input: TRTTensor, |
| 28 | + other: TRTTensor, |
| 29 | + condition: TRTTensor, |
| 30 | +) -> TRTTensor: |
| 31 | + input_dim = len(tuple(input.shape)) |
| 32 | + other_dim = len(tuple(other.shape)) |
| 33 | + condition_dim = len(tuple(condition.shape)) |
| 34 | + |
| 35 | + if type(input) != TRTTensor: |
| 36 | + assert type(input) is torch.Tensor, f"value {input} is not torch.Tensor!" |
| 37 | + |
| 38 | + if type(other) != TRTTensor: |
| 39 | + assert type(other) is torch.Tensor, f"value {other} is not torch.Tensor!" |
| 40 | + |
| 41 | + if not (broadcastable(input, other)): |
| 42 | + assert f"The two torch tensors should be broadcastable" |
| 43 | + |
| 44 | + # get output shape |
| 45 | + # purpose of this is to bring input and other rank same as |
| 46 | + # output_shape to input it to the add_expand operation |
| 47 | + # condition will have dimension of either input or other |
| 48 | + input, other = broadcast(network, input, other, f"{name}_x", f"{name}_y") |
| 49 | + if len(tuple(condition.shape)) != len(tuple(input.shape)): |
| 50 | + condition, input = broadcast( |
| 51 | + network, condition, input, f"{name}_condition", f"{name}_x" |
| 52 | + ) |
| 53 | + |
| 54 | + x_shape = list(input.shape) |
| 55 | + y_shape = list(other.shape) |
| 56 | + condition_shape = list(condition.shape) |
| 57 | + output_shape = list(torch.broadcast_shapes(condition_shape, x_shape, y_shape)) |
| 58 | + |
| 59 | + # expand shape |
| 60 | + if type(condition) != TRTTensor: |
| 61 | + assert condition.dtype == torch.bool, "condition dtype is not bool" |
| 62 | + if condition_shape != output_shape: |
| 63 | + condition.expand(output_shape) |
| 64 | + condition = condition.to(torch.int32) |
| 65 | + condition_const = get_trt_tensor(network, condition, f"{name}_condition") |
| 66 | + condition_layer = network.add_identity(condition_const) |
| 67 | + condition_layer.set_output_type(0, trt.bool) |
| 68 | + set_layer_name(condition_layer, target, f"{name}_condition") |
| 69 | + condition_val = condition_layer.get_output(0) |
| 70 | + else: |
| 71 | + assert condition.dtype == trt.bool, "mask dtype is not bool!" |
| 72 | + if condition_shape != condition_dim: |
| 73 | + condition_val = expand( |
| 74 | + network, target, source_ir, f"{name}_expand", condition, output_shape |
| 75 | + ) |
| 76 | + else: |
| 77 | + condition_val = condition |
| 78 | + |
| 79 | + if type(input) != TRTTensor: |
| 80 | + if x_shape != input_dim: |
| 81 | + # special case where 1 element in input |
| 82 | + if len(input.shape) == 0: |
| 83 | + input = input.unsqueeze(0) |
| 84 | + input = input.expand(output_shape) |
| 85 | + x_val = get_trt_tensor(network, input, f"{name}_x") |
| 86 | + else: |
| 87 | + x_val = input |
| 88 | + if x_shape != output_shape: |
| 89 | + x_val = expand( |
| 90 | + network, target, source_ir, f"{name}_x_expand", input, output_shape |
| 91 | + ) |
| 92 | + |
| 93 | + if type(other) != TRTTensor: |
| 94 | + if y_shape != output_shape: |
| 95 | + # special case where 1 element in other |
| 96 | + if len(other.shape) == 0: |
| 97 | + other = other.unsqueeze(0) |
| 98 | + other = other.expand(output_shape) |
| 99 | + y_val = get_trt_tensor(network, other, f"{name}_y") |
| 100 | + else: |
| 101 | + y_val = other |
| 102 | + if y_shape != other_dim: |
| 103 | + y_val = expand( |
| 104 | + network, target, source_ir, f"{name}_y_expand", y_val, output_shape |
| 105 | + ) |
| 106 | + |
| 107 | + select_layer = network.add_select(condition_val, x_val, y_val) |
| 108 | + |
| 109 | + set_layer_name(select_layer, target, f"{name}_select") |
| 110 | + |
| 111 | + return select_layer.get_output(0) |
0 commit comments