Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit be84ba5

Browse files
committedSep 29, 2023
fix: Upgrade to_numpy to allow boolean constants
1 parent 2ff5466 commit be84ba5

File tree

9 files changed

+185
-128
lines changed

9 files changed

+185
-128
lines changed
 

‎py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
345345

346346
def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray:
347347
with _disable_current_modes():
348-
from torch_tensorrt.fx.converters import to_numpy
348+
from torch_tensorrt.dynamo.conversion.converter_utils import to_numpy
349349

350350
frozen_attr = self.fetch_attr(target)
351351

‎py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py‎

Lines changed: 99 additions & 99 deletions
Large diffs are not rendered by default.

‎py/torch_tensorrt/dynamo/conversion/converter_utils.py‎

Lines changed: 72 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from torch_tensorrt.fx.converters.converter_utils import (
1818
Frameworks,
1919
get_axes_for_reduce_op,
20-
to_numpy,
2120
unified_dtype_converter,
2221
)
2322
from torch_tensorrt.fx.types import TRTDataType, TRTTensor
@@ -270,9 +269,10 @@ def create_constant(
270269
Returns:
271270
A TensorRT ITensor that represents the given value.
272271
"""
272+
numpy_value = to_numpy(value, dtype)
273273
constant = ctx.net.add_constant(
274274
(1,) if isinstance(value, (int, float, bool)) else value.shape,
275-
to_numpy(value, dtype).copy(),
275+
numpy_value.copy() if isinstance(numpy_value, np.ndarray) else numpy_value,
276276
)
277277
constant.name = name
278278
return constant.get_output(0)
@@ -351,7 +351,10 @@ def enforce_tensor_types(
351351
and all((dtype in (TRTTensor, np.ndarray, torch.Tensor)) for dtype in val)
352352
)
353353
for val in type_dictionary.values()
354-
), "Invalid value(s) specified in type enforcement"
354+
), (
355+
"Invalid value(s) specified in type enforcement."
356+
"Note that torch.Tensor cannot be present as a type without np.ndarray."
357+
)
355358

356359
def wrapper(func: DynamoConverterImplSignature) -> DynamoConverterImplSignature:
357360
@functools.wraps(func)
@@ -384,20 +387,25 @@ def convert_with_type_enforcement(
384387
f"Detected argument at index {index} had type {type(candidate)} "
385388
f"which is not one of the approved types {approved_dtypes}"
386389
)
387-
# Numpy arrays are preferred in general - if approved, promote to Numpy first
388-
elif np.ndarray in approved_dtypes and not isinstance(
389-
candidate, TRTTensor
390-
):
391-
new_value = to_numpy(candidate)
392-
# As a fallback, freeze tensors to IConstantLayers if they cannot be handled as Numpy arrays
393-
elif TRTTensor in approved_dtypes:
394-
_LOGGER.debug(
395-
f"Freezing tensor {name}_constant_{index} to TRT IConstantLayer"
396-
)
397-
new_value = get_trt_tensor(
398-
ctx, candidate, name + f"_constant_{index}"
399-
)
400-
else:
390+
391+
# Type-promotion preference order depends on tuple order
392+
for dtype in approved_dtypes:
393+
# Currently, we do not cast to Torch tensor, due to issues with such casts
394+
# in FakeTensor contexts
395+
if dtype == np.ndarray and not isinstance(candidate, TRTTensor):
396+
new_value = to_numpy(candidate)
397+
break
398+
# As a fallback, freeze tensors to IConstantLayers if they cannot be handled as Numpy arrays
399+
elif dtype == TRTTensor:
400+
_LOGGER.debug(
401+
f"Freezing tensor {name}_constant_{index} to TRT IConstantLayer"
402+
)
403+
new_value = get_trt_tensor(
404+
ctx, candidate, name + f"_constant_{index}"
405+
)
406+
break
407+
408+
if new_value is None:
401409
raise AssertionError(
402410
f"Argument at index {index} was not able to be converted to one of "
403411
f"the following types: {approved_dtypes}"
@@ -414,3 +422,50 @@ def convert_with_type_enforcement(
414422
return convert_with_type_enforcement
415423

416424
return wrapper
425+
426+
427+
def to_numpy(
428+
value: Optional[Union[torch.Tensor, np.ndarray, int, float, bool]],
429+
dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType]] = None,
430+
) -> Optional[np.ndarray]:
431+
"""
432+
Convert a PyTorch Tensor, Numpy array, or scalar to a Numpy Array. If the tensor is
433+
quantized it will be dequantized first.
434+
Args:
435+
value (Optional[Union[torch.Tensor, np.ndarray, int, float, bool]]):
436+
A PyTorch tensor, Numpy array, int, float, or bool
437+
dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]):
438+
If a dtype is given, we will convert the type of the given `value` to this dtype.
439+
Returns:
440+
A Numpy array or None, if the input was None.
441+
"""
442+
output = None
443+
444+
if value is None or isinstance(value, np.ndarray):
445+
output = value
446+
447+
elif isinstance(value, torch.Tensor):
448+
if value.is_quantized:
449+
value = value.dequantize()
450+
451+
output = value.cpu().detach().contiguous().numpy()
452+
453+
elif isinstance(value, int):
454+
output = np.array([value], dtype=np.int32)
455+
456+
elif isinstance(value, float):
457+
output = np.array([value], dtype=np.float32)
458+
459+
elif isinstance(value, bool):
460+
output = np.array([value], dtype=np.bool_)
461+
462+
if isinstance(output, np.ndarray) or output is None:
463+
return (
464+
output
465+
if (dtype is None or output is None)
466+
else output.astype(unified_dtype_converter(dtype, Frameworks.NUMPY))
467+
)
468+
else:
469+
raise AssertionError(
470+
f"to_numpy can only be called on None, bool, int, float, np.ndarray, or torch.Tensor, got: {value}"
471+
)

‎py/torch_tensorrt/dynamo/conversion/impl/conv.py‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,16 @@
99
from torch_tensorrt.dynamo.conversion import impl
1010
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
1111
from torch_tensorrt.dynamo.conversion.converter_utils import (
12+
SourceIR,
1213
extend_attr_to_tuple,
1314
get_trt_tensor,
15+
to_numpy,
1416
)
1517
from torch_tensorrt.fx.converters.converter_utils import (
16-
SourceIR,
1718
get_dyn_range,
1819
has_dynamic_shape,
1920
mark_as_int8_layer,
2021
set_layer_name,
21-
to_numpy,
2222
)
2323
from torch_tensorrt.fx.types import TRTTensor
2424

‎py/torch_tensorrt/dynamo/conversion/impl/deconv.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,14 @@
1111
from torch_tensorrt.dynamo.conversion.converter_utils import (
1212
extend_attr_to_tuple,
1313
get_trt_tensor,
14+
to_numpy,
1415
)
1516
from torch_tensorrt.fx.converters.converter_utils import (
1617
SourceIR,
1718
get_dyn_range,
1819
has_dynamic_shape,
1920
mark_as_int8_layer,
2021
set_layer_name,
21-
to_numpy,
2222
)
2323
from torch_tensorrt.fx.types import TRTTensor
2424

‎py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from torch.fx.node import Target
88
from torch_tensorrt.dynamo._SourceIR import SourceIR
99
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
10+
from torch_tensorrt.dynamo.conversion.converter_utils import to_numpy
1011
from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
1112
convert_binary_elementwise,
1213
)
@@ -16,7 +17,6 @@
1617
get_trt_plugin,
1718
has_dynamic_shape,
1819
set_layer_name,
19-
to_numpy,
2020
)
2121
from torch_tensorrt.fx.types import TRTTensor
2222
from torch_tensorrt.fx.utils import get_dynamic_dims

‎py/torch_tensorrt/dynamo/conversion/impl/pool.py‎

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,17 @@
33
import tensorrt as trt
44
from torch.fx.node import Target
55
from torch_tensorrt.dynamo._SourceIR import SourceIR
6+
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
67
from torch_tensorrt.dynamo.conversion.converter_utils import extend_attr_to_tuple
78
from torch_tensorrt.fx.converters.converter_utils import (
89
has_dynamic_shape,
910
set_layer_name,
1011
)
11-
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
12+
from torch_tensorrt.fx.types import TRTTensor
1213

1314

1415
def avg_poolNd(
15-
network: TRTNetwork,
16+
ctx: ConversionContext,
1617
target: Union[Target, str],
1718
source_ir: Optional[SourceIR],
1819
name: str,
@@ -45,7 +46,7 @@ def avg_poolNd(
4546
padding = extend_attr_to_tuple(padding, dim)
4647

4748
# add average pooling layer
48-
pool_layer = network.add_pooling_nd(
49+
pool_layer = ctx.net.add_pooling_nd(
4950
input=input,
5051
type=trt.PoolingType.AVERAGE,
5152
window_size=kernel_size,
@@ -60,7 +61,7 @@ def avg_poolNd(
6061

6162

6263
def max_poolNd(
63-
network: TRTNetwork,
64+
ctx: ConversionContext,
6465
target: Union[Target, str],
6566
source_ir: Optional[SourceIR],
6667
name: str,
@@ -92,7 +93,7 @@ def max_poolNd(
9293
padding = extend_attr_to_tuple(padding, dim)
9394

9495
# add max pooling layer
95-
pool_layer = network.add_pooling_nd(
96+
pool_layer = ctx.net.add_pooling_nd(
9697
input=input,
9798
type=trt.PoolingType.MAX,
9899
window_size=kernel_size,

‎py/torch_tensorrt/dynamo/conversion/impl/select.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
from torch.fx.node import Target
55
from torch_tensorrt.dynamo._SourceIR import SourceIR
66
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
7+
from torch_tensorrt.dynamo.conversion.converter_utils import to_numpy
78
from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
89
from torch_tensorrt.fx.converters.converter_utils import (
910
get_positive_dim,
1011
has_dynamic_shape,
11-
to_numpy,
1212
)
1313
from torch_tensorrt.fx.types import Shape, TRTTensor
1414

‎py/torch_tensorrt/dynamo/conversion/impl/shape.py‎

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88
from torch.fx.node import Target
99
from torch_tensorrt.dynamo._SourceIR import SourceIR
1010
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
11+
from torch_tensorrt.dynamo.conversion.converter_utils import to_numpy
1112
from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
1213
convert_binary_elementwise,
1314
)
14-
from torch_tensorrt.fx.converters.converter_utils import set_layer_name, to_numpy
15+
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
1516
from torch_tensorrt.fx.types import TRTTensor
1617

1718

0 commit comments

Comments
 (0)
Please sign in to comment.