1
1
from typing import Optional
2
2
3
+ import numpy as np
3
4
import tensorrt as trt
4
5
import torch
5
6
from torch .fx .node import Target
@@ -22,16 +23,6 @@ def where(
22
23
other : TRTTensor ,
23
24
condition : TRTTensor ,
24
25
) -> TRTTensor :
25
- input_dim = len (tuple (input .shape ))
26
- other_dim = len (tuple (other .shape ))
27
- condition_dim = len (tuple (condition .shape ))
28
-
29
- if type (input ) != TRTTensor :
30
- assert type (input ) is torch .Tensor , f"value { input } is not torch.Tensor!"
31
-
32
- if type (other ) != TRTTensor :
33
- assert type (other ) is torch .Tensor , f"value { other } is not torch.Tensor!"
34
-
35
26
if not (broadcastable (input , other )):
36
27
assert "The two torch tensors should be broadcastable"
37
28
@@ -48,33 +39,37 @@ def where(
48
39
x_shape = list (input .shape )
49
40
y_shape = list (other .shape )
50
41
condition_shape = list (condition .shape )
42
+
51
43
output_shape = list (torch .broadcast_shapes (condition_shape , x_shape , y_shape ))
52
44
53
45
# expand shape
54
- if type (condition ) != TRTTensor :
55
- assert condition .dtype == torch .bool , "condition dtype is not bool"
46
+ if not isinstance (condition , TRTTensor ) :
47
+ assert condition .dtype in ( torch .bool , np . bool_ ) , "condition dtype is not bool"
56
48
if condition_shape != output_shape :
57
- condition .expand (output_shape )
58
- condition = condition .to (torch .int32 )
59
- condition_const = get_trt_tensor (network , condition , f"{ name } _condition" )
60
- condition_layer = network .add_identity (condition_const )
61
- condition_layer .set_output_type (0 , trt .bool )
62
- set_layer_name (condition_layer , target , f"{ name } _condition" )
63
- condition_val = condition_layer .get_output (0 )
49
+ condition = (
50
+ condition .expand (output_shape )
51
+ if isinstance (condition , torch .Tensor )
52
+ else np .broadcast_to (condition , output_shape )
53
+ )
54
+ condition_val = get_trt_tensor (network , condition , f"{ name } _condition" )
64
55
else :
65
56
assert condition .dtype == trt .bool , "mask dtype is not bool!"
66
- if len ( condition_shape ) != condition_dim :
57
+ if condition_shape != output_shape :
67
58
condition_val = expand (
68
59
network , target , source_ir , f"{ name } _expand" , condition , output_shape
69
60
)
70
61
else :
71
62
condition_val = condition
72
63
73
- if type (input ) != TRTTensor :
64
+ if not isinstance (input , TRTTensor ) :
74
65
if x_shape != output_shape :
75
66
# special case where 1 element in input
76
67
if len (input .shape ) == 0 :
77
- input = input .unsqueeze (0 )
68
+ input = (
69
+ input .unsqueeze (0 )
70
+ if isinstance (input , torch .Tensor )
71
+ else np .expand_dims (input , axis = 0 )
72
+ )
78
73
input = input .expand (output_shape )
79
74
x_val = get_trt_tensor (network , input , f"{ name } _x" )
80
75
else :
@@ -84,11 +79,15 @@ def where(
84
79
network , target , source_ir , f"{ name } _x_expand" , input , output_shape
85
80
)
86
81
87
- if type (other ) != TRTTensor :
82
+ if not isinstance (other , TRTTensor ) :
88
83
if y_shape != output_shape :
89
84
# special case where 1 element in other
90
85
if len (other .shape ) == 0 :
91
- other = other .unsqueeze (0 )
86
+ other = (
87
+ other .unsqueeze (0 )
88
+ if isinstance (other , torch .Tensor )
89
+ else np .expand_dims (other , axis = 0 )
90
+ )
92
91
other = other .expand (output_shape )
93
92
y_val = get_trt_tensor (network , other , f"{ name } _y" )
94
93
else :
0 commit comments