Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 075a028

Browse files
committedJun 5, 2023
fix: Centralize FX conv impl, add feature
- Centralize convolution implementation in FX, similar across all source IRs, including aten, acc, nn - Enable pass-through of build errors in e2e tests to ensure errors are not being hidden - Allow conv layers to take bias inputs in FX, per new functionality from TRT
1 parent dd31c9a commit 075a028

File tree

6 files changed

+310
-314
lines changed

6 files changed

+310
-314
lines changed
 

‎py/torch_tensorrt/dynamo/test/test_dynamo_backend.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def test_resnet18(ir):
2727
"device": torchtrt.Device("cuda:0"),
2828
"enabled_precisions": {torch.float},
2929
"ir": ir,
30+
"pass_through_build_failures": True,
3031
}
3132

3233
trt_mod = torchtrt.compile(model, **compile_spec)
@@ -57,6 +58,7 @@ def test_mobilenet_v2(ir):
5758
"device": torchtrt.Device("cuda:0"),
5859
"enabled_precisions": {torch.float},
5960
"ir": ir,
61+
"pass_through_build_failures": True,
6062
}
6163

6264
trt_mod = torchtrt.compile(model, **compile_spec)
@@ -87,6 +89,7 @@ def test_efficientnet_b0(ir):
8789
"device": torchtrt.Device("cuda:0"),
8890
"enabled_precisions": {torch.float},
8991
"ir": ir,
92+
"pass_through_build_failures": True,
9093
}
9194

9295
trt_mod = torchtrt.compile(model, **compile_spec)
@@ -126,6 +129,7 @@ def test_bert_base_uncased(ir):
126129
"enabled_precisions": {torch.float},
127130
"truncate_long_and_double": True,
128131
"ir": ir,
132+
"pass_through_build_failures": True,
129133
}
130134
trt_mod = torchtrt.compile(model, **compile_spec)
131135

@@ -160,6 +164,7 @@ def test_resnet18_half(ir):
160164
"device": torchtrt.Device("cuda:0"),
161165
"enabled_precisions": {torch.half},
162166
"ir": ir,
167+
"pass_through_build_failures": True,
163168
}
164169

165170
trt_mod = torchtrt.compile(model, **compile_spec)

‎py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 44 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
trt_transposed_matmul,
2727
)
2828
from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous
29-
from torch_tensorrt.fx.converters.impl import activation
29+
from torch_tensorrt.fx.converters.impl import activation, convolution
3030

3131
_LOGGER: logging.Logger = logging.getLogger(__name__)
3232

@@ -96,86 +96,20 @@ def acc_ops_conv1d(
9696
kwargs: Dict[str, Argument],
9797
name: str,
9898
) -> Union[TRTTensor, Sequence[TRTTensor]]:
99-
input_val = kwargs["input"]
100-
if not isinstance(input_val, TRTTensor):
101-
raise RuntimeError(
102-
f"Conv received input {input_val} that is not part "
103-
"of the TensorRT region!"
104-
)
105-
106-
# Process 1d input with unsqueeze -> conv2d -> squeeze to calculated conv1d
107-
unsqueeze_layer = network.add_shuffle(input=input_val)
108-
unsqueeze_layer.reshape_dims = tuple([*input_val.shape, 1])
109-
set_layer_name(unsqueeze_layer, target, name + "_unsqueeze")
110-
input_val = unsqueeze_layer.get_output(0)
111-
112-
if has_dynamic_shape(input_val.shape):
113-
assert input_val.shape[1] != -1, "Channel dim can't be dynamic for convolution."
114-
115-
# for now we'll assume bias is constant Tensor or None,
116-
# and bias being ITensor is not supported in TensorRT api
117-
# right now
118-
if kwargs["bias"] is not None and not isinstance(kwargs["bias"], torch.Tensor):
119-
raise RuntimeError(
120-
f"linear {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tensor]"
121-
)
122-
bias = to_numpy(kwargs["bias"]) # type: ignore[arg-type]
123-
if bias is not None:
124-
bias = bias[None]
125-
weight = kwargs["weight"]
126-
127-
if network.has_explicit_precision or isinstance(weight, TRTTensor):
128-
weight = get_trt_tensor(network, weight, f"{name}_weight")
129-
# Expand 1d weight with unsqueeze for calculation
130-
unsqueeze_weight_layer = network.add_shuffle(input=weight)
131-
unsqueeze_weight_layer.reshape_dims = tuple([*weight.shape, 1])
132-
set_layer_name(unsqueeze_layer, target, name + "_unsqueeze_weight")
133-
weight = unsqueeze_weight_layer.get_output(0)
134-
weight_shape = tuple(kwargs["weight"].shape) # type: ignore[union-attr]
135-
# will need to use uninitialized weight and set it later to support
136-
# ITensor weights
137-
dummy_weight = trt.Weights()
138-
layer = network.add_convolution_nd(
139-
input=input_val,
140-
num_output_maps=weight.shape[0],
141-
kernel_shape=weight.shape[2:],
142-
kernel=dummy_weight,
143-
bias=bias,
144-
)
145-
146-
layer.set_input(1, weight)
147-
else:
148-
if not isinstance(kwargs["weight"], torch.Tensor):
149-
raise RuntimeError(
150-
f"linear {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tensor]"
151-
)
152-
weight = to_numpy(weight)
153-
weight = np.expand_dims(weight, -1)
154-
layer = network.add_convolution_nd(
155-
input=input_val,
156-
num_output_maps=weight.shape[0],
157-
kernel_shape=weight.shape[2:],
158-
kernel=weight,
159-
bias=bias,
160-
)
161-
# expand params to 2d for computation
162-
padding = list(kwargs["padding"])
163-
padding.append(0)
164-
stride = extend_attr_to_tuple(kwargs["stride"], 2)
165-
dilation = extend_attr_to_tuple(kwargs["dilation"], 2)
166-
167-
set_layer_name(layer, target, name)
168-
layer.stride_nd = stride
169-
layer.padding_nd = padding
170-
layer.dilation_nd = dilation
171-
if kwargs["groups"] is not None:
172-
layer.num_groups = kwargs["groups"]
173-
174-
result = layer.get_output(0)
175-
squeeze_layer = network.add_shuffle(input=result)
176-
squeeze_layer.reshape_dims = tuple(result.shape[:-1])
177-
set_layer_name(squeeze_layer, target, name + "_squeeze")
178-
return squeeze_layer.get_output(0)
99+
return convolution.convNd(
100+
network,
101+
target,
102+
source_ir=SourceIR.ACC,
103+
name=name,
104+
is_conv1d=True,
105+
input_val=kwargs["input"],
106+
weight=kwargs["weight"],
107+
bias=kwargs["bias"],
108+
stride=kwargs["stride"],
109+
padding=kwargs["padding"],
110+
dilation=kwargs["dilation"],
111+
groups=kwargs["groups"],
112+
)
179113

180114

181115
@tensorrt_converter(acc_ops.conv3d)
@@ -187,63 +121,20 @@ def acc_ops_convnd(
187121
kwargs: Dict[str, Argument],
188122
name: str,
189123
) -> Union[TRTTensor, Sequence[TRTTensor]]:
190-
input_val = kwargs["input"]
191-
192-
if not isinstance(input_val, TRTTensor):
193-
raise RuntimeError(
194-
f"Conv received input {input_val} that is not part "
195-
"of the TensorRT region!"
196-
)
197-
198-
if has_dynamic_shape(input_val.shape):
199-
assert input_val.shape[1] != -1, "Channel dim can't be dynamic for convolution."
200-
201-
# for now we'll assume bias is constant Tensor or None,
202-
# and bias being ITensor is not supported in TensorRT api
203-
# right now
204-
if kwargs["bias"] is not None and not isinstance(kwargs["bias"], torch.Tensor):
205-
raise RuntimeError(
206-
f"linear {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tensor]"
207-
)
208-
bias = to_numpy(kwargs["bias"]) # type: ignore[arg-type]
209-
210-
if network.has_explicit_precision or isinstance(kwargs["weight"], TRTTensor):
211-
weight = get_trt_tensor(network, kwargs["weight"], f"{name}_weight")
212-
weight_shape = tuple(kwargs["weight"].shape) # type: ignore[union-attr]
213-
# will need to use uninitialized weight and set it later to support
214-
# ITensor weights
215-
dummy_weight = trt.Weights()
216-
layer = network.add_convolution_nd(
217-
input=input_val,
218-
num_output_maps=weight.shape[0],
219-
kernel_shape=weight.shape[2:],
220-
kernel=dummy_weight,
221-
bias=bias,
222-
)
223-
224-
layer.set_input(1, weight)
225-
else:
226-
if not isinstance(kwargs["weight"], torch.Tensor):
227-
raise RuntimeError(
228-
f"linear {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tensor]"
229-
)
230-
weight = to_numpy(kwargs["weight"])
231-
layer = network.add_convolution_nd(
232-
input=input_val,
233-
num_output_maps=weight.shape[0],
234-
kernel_shape=weight.shape[2:],
235-
kernel=weight,
236-
bias=bias,
237-
)
238-
239-
set_layer_name(layer, target, name)
240-
layer.stride_nd = kwargs["stride"]
241-
layer.padding_nd = kwargs["padding"]
242-
layer.dilation_nd = kwargs["dilation"]
243-
if kwargs["groups"] is not None:
244-
layer.num_groups = kwargs["groups"]
245-
246-
return layer.get_output(0)
124+
return convolution.convNd(
125+
network,
126+
target,
127+
source_ir=SourceIR.ACC,
128+
name=name,
129+
is_conv1d=False,
130+
input_val=kwargs["input"],
131+
weight=kwargs["weight"],
132+
bias=kwargs["bias"],
133+
stride=kwargs["stride"],
134+
padding=kwargs["padding"],
135+
dilation=kwargs["dilation"],
136+
groups=kwargs["groups"],
137+
)
247138

248139

249140
@tensorrt_converter(acc_ops.conv_transpose2d)
@@ -268,32 +159,36 @@ def acc_ops_conv_transposend(
268159
input_val.shape[1] != -1
269160
), "Channel dim can't be dynamic for transpose convolution."
270161

271-
# for now we'll assume bias is constant Tensor or None,
272-
# and bias being ITensor is not supported in TensorRT api
273-
# right now
274-
if kwargs["bias"] is not None and not isinstance(kwargs["bias"], torch.Tensor):
275-
raise RuntimeError(
276-
f"ConvTranspose {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tensor]"
277-
)
278-
bias = to_numpy(kwargs["bias"]) # type: ignore[arg-type]
162+
if not isinstance(kwargs["bias"], TRTTensor):
163+
if kwargs["bias"] is not None and not isinstance(kwargs["bias"], torch.Tensor):
164+
raise RuntimeError(
165+
f"linear {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tensor]"
166+
)
167+
bias = to_numpy(kwargs["bias"]) # type: ignore[arg-type]
168+
else:
169+
bias = kwargs["bias"]
279170

280171
if network.has_explicit_precision or isinstance(kwargs["weight"], TRTTensor):
281172
weight = get_trt_tensor(network, kwargs["weight"], f"{name}_weight")
282173
weight_shape = tuple(kwargs["weight"].shape) # type: ignore[union-attr]
283174
# will need to use uninitialized weight and set it later to support
284175
# ITensor weights
285-
dummy_weight = trt.Weights()
286176

287177
# nn.ConvTranspose2d/3d weight size is (in_channels, out_channels/groups, kernel_0, kernel_1, [kernel_2])
288178
layer = network.add_deconvolution_nd(
289179
input=input_val,
290180
num_output_maps=weight.shape[1] * kwargs["groups"],
291181
kernel_shape=weight.shape[2:],
292-
kernel=dummy_weight,
293-
bias=bias,
182+
kernel=trt.Weights(),
183+
bias=trt.Weights() if isinstance(bias, TRTTensor) else bias,
294184
)
295185

296186
layer.set_input(1, weight)
187+
188+
# If the bias is a TRTTensor, set it as an input of the layer
189+
if isinstance(bias, TRTTensor):
190+
bias = get_trt_tensor(network, bias, f"{name}_bias")
191+
layer.set_input(2, bias)
297192
else:
298193
if not isinstance(kwargs["weight"], torch.Tensor):
299194
raise RuntimeError(

‎py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from .converter_utils import * # noqa: F403
2424
import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
25-
from torch_tensorrt.fx.converters.impl import activation
25+
from torch_tensorrt.fx.converters.impl import activation, convolution
2626

2727
_LOGGER: logging.Logger = logging.getLogger(__name__)
2828

@@ -129,13 +129,36 @@ def aten_ops_convolution(
129129
# we do not handle output_padding.
130130
if args[7] not in ([0], [0, 0], [0, 0, 0]):
131131
raise RuntimeError(f"Target {target} has non-0 output_padding")
132+
132133
if len(kwargs_new["stride"]) == 1:
133-
return acc_ops_converters.acc_ops_conv1d(
134-
network, target, None, kwargs_new, name
134+
return convolution.convNd(
135+
network,
136+
target,
137+
source_ir=SourceIR.ATEN,
138+
name=name,
139+
is_conv1d=True,
140+
input_val=kwargs_new["input"],
141+
weight=kwargs_new["weight"],
142+
bias=kwargs_new["bias"],
143+
stride=kwargs_new["stride"],
144+
padding=kwargs_new["padding"],
145+
dilation=kwargs_new["dilation"],
146+
groups=kwargs_new["groups"],
135147
)
136148
else:
137-
return acc_ops_converters.acc_ops_convnd(
138-
network, target, None, kwargs_new, name
149+
return convolution.convNd(
150+
network,
151+
target,
152+
source_ir=SourceIR.ATEN,
153+
name=name,
154+
is_conv1d=False,
155+
input_val=kwargs_new["input"],
156+
weight=kwargs_new["weight"],
157+
bias=kwargs_new["bias"],
158+
stride=kwargs_new["stride"],
159+
padding=kwargs_new["padding"],
160+
dilation=kwargs_new["dilation"],
161+
groups=kwargs_new["groups"],
139162
)
140163

141164

‎py/torch_tensorrt/fx/converters/converter_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,17 @@ def get_positive_dim(dim: int, dim_size: int) -> int:
9999

100100

101101
def set_layer_name(
102-
layer: TRTLayer, target: Target, name: str, source_ir: Optional[SourceIR] = None
102+
layer: TRTLayer,
103+
target: Union[Target, torch.nn.Module, str],
104+
name: str,
105+
source_ir: Optional[SourceIR] = None,
103106
) -> None:
104107
"""
105108
Set the TensorRT layer name to "[TensorRT Layer Type]_[Original Op Name]_[FX Node Name with Suffix]"
106109
107110
Args:
108111
layer (TRTLayer): A TensorRT layer of which we want to set the name.
109-
target (Target): A fx node.target. For call_function node, it's the function that
112+
target (Target): A fx node.target or submodule. For call_function node, it's the function that
110113
the node represents.
111114
name (str): Consists of fx node.name with optional suffix.
112115
source_ir: (Optional[SourceIR]): The IR producing the op.
Lines changed: 69 additions & 158 deletions
Original file line numberDiff line numberDiff line change
@@ -1,212 +1,123 @@
11
# @manual=//deeplearning/trt/python:py_tensorrt
22
import logging
33

4-
import numpy as np
5-
import tensorrt as trt
64
import torch
75

86
from ..converter_registry import tensorrt_converter
97

108
from .converter_utils import (
11-
extend_mod_attr_to_tuple,
12-
get_dyn_range,
13-
mark_as_int8_layer,
14-
to_numpy,
9+
SourceIR,
1510
)
1611

17-
logger = logging.getLogger(__name__)
18-
12+
from torch_tensorrt.fx.converters.impl import convolution, activation
1913

20-
def common_conv(network, mod, dimension, input_val, layer_name, is_quantized):
21-
if mod.padding_mode != "zeros":
22-
raise RuntimeError(f"Only support padding mode: zeros, got {mod.padding_mode}.")
23-
24-
kernel_size = extend_mod_attr_to_tuple(mod, "kernel_size", dimension)
25-
stride = extend_mod_attr_to_tuple(mod, "stride", dimension)
26-
padding = extend_mod_attr_to_tuple(mod, "padding", dimension)
27-
dilation = extend_mod_attr_to_tuple(mod, "dilation", dimension)
28-
29-
kernel = to_numpy(mod.weight() if is_quantized else mod.weight)
30-
bias = to_numpy(mod.bias() if is_quantized else mod.bias)
31-
32-
if dimension == 1:
33-
# Append unsqueeze before conv2d to calculate conv1d
34-
unsqueeze_layer = network.add_shuffle(input=input_val)
35-
unsqueeze_layer.reshape_dims = (*input_val.shape, 1)
36-
unsqueeze_layer.name = f"{layer_name}_unsqueeze"
37-
input_val = unsqueeze_layer.get_output(0)
38-
39-
kernel = np.expand_dims(kernel, -1)
40-
kernel_size = kernel.shape[2:]
41-
if bias is not None:
42-
bias = bias[None]
43-
stride = (stride[0], 1)
44-
padding = (padding[0], 0)
45-
dilation = (dilation[0], 1)
46-
layer = network.add_convolution_nd(
47-
input=input_val,
48-
num_output_maps=mod.out_channels,
49-
kernel_shape=kernel_size,
50-
kernel=kernel,
51-
bias=bias,
52-
)
53-
layer.name = layer_name
54-
layer.stride_nd = stride
55-
layer.padding_nd = padding
56-
layer.dilation_nd = dilation
57-
layer.num_groups = mod.groups
58-
59-
if is_quantized:
60-
# Assume the dtype of activation is torch.quint8
61-
mark_as_int8_layer(
62-
layer, get_dyn_range(mod.scale, mod.zero_point, torch.quint8)
63-
)
64-
65-
result = layer.get_output(0)
66-
if dimension == 1:
67-
# Append squeeze after conv2d to calculate conv1d
68-
squeeze_layer = network.add_shuffle(input=result)
69-
squeeze_layer.reshape_dims = tuple(result.shape[:-1])
70-
squeeze_layer.name = f"{layer_name}_squeeze"
71-
result = squeeze_layer.get_output(0)
72-
73-
return result
74-
75-
76-
def common_conv_relu(network, mod, dimension, input_val, layer_name, is_quantized):
77-
conv_output = common_conv(
78-
network,
79-
mod,
80-
dimension=2,
81-
input_val=input_val,
82-
layer_name=f"{layer_name}_conv",
83-
is_quantized=is_quantized,
84-
)
85-
86-
layer = network.add_activation(input=conv_output, type=trt.ActivationType.RELU)
87-
layer.name = f"{layer_name}_relu"
88-
89-
if is_quantized:
90-
mark_as_int8_layer(layer, conv_output.dynamic_range)
91-
92-
return layer.get_output(0)
14+
logger = logging.getLogger(__name__)
9315

9416

9517
@tensorrt_converter(torch.nn.modules.conv.Conv1d)
9618
def conv1d(network, submod, args, kwargs, layer_name):
9719
# args/kwargs should have already been normalized to kwargs
9820
assert len(args) == 0
99-
input_val = kwargs["input"]
100-
101-
if not isinstance(input_val, trt.tensorrt.ITensor):
102-
raise RuntimeError(
103-
f"Conv1d received input {input_val} that is not part "
104-
"of the TensorRT region!"
105-
)
10621

10722
if layer_name is None:
10823
raise RuntimeError("layer name is none")
109-
return common_conv(
24+
return convolution.convNd(
11025
network,
111-
submod,
112-
dimension=1,
113-
input_val=input_val,
114-
layer_name=layer_name,
115-
is_quantized=False,
26+
submod._get_name(),
27+
source_ir=SourceIR.NN,
28+
name=layer_name,
29+
is_conv1d=True,
30+
input_val=kwargs["input"],
31+
weight=submod.weight,
32+
bias=submod.bias,
33+
stride=getattr(submod, "stride"),
34+
padding=getattr(submod, "padding"),
35+
dilation=getattr(submod, "dilation"),
36+
groups=submod.groups,
11637
)
11738

11839

11940
@tensorrt_converter(torch.nn.modules.conv.Conv2d)
12041
def conv2d(network, submod, args, kwargs, layer_name):
12142
# args/kwargs should have already been normalized to kwargs
12243
assert len(args) == 0
123-
input_val = kwargs["input"]
124-
125-
if not isinstance(input_val, trt.tensorrt.ITensor):
126-
raise RuntimeError(
127-
f"Conv2d received input {input_val} that is not part "
128-
"of the TensorRT region!"
129-
)
130-
131-
return common_conv(
44+
return convolution.convNd(
13245
network,
133-
submod,
134-
dimension=2,
135-
input_val=input_val,
136-
layer_name=layer_name,
137-
is_quantized=False,
46+
submod._get_name(),
47+
source_ir=SourceIR.NN,
48+
name=layer_name,
49+
is_conv1d=False,
50+
input_val=kwargs["input"],
51+
weight=submod.weight,
52+
bias=submod.bias,
53+
stride=getattr(submod, "stride"),
54+
padding=getattr(submod, "padding"),
55+
dilation=getattr(submod, "dilation"),
56+
groups=submod.groups,
13857
)
13958

14059

14160
@tensorrt_converter(torch.nn.modules.conv.Conv3d)
14261
def conv3d(network, submod, args, kwargs, layer_name):
14362
# args/kwargs should have already been normalized to kwargs
14463
assert len(args) == 0
145-
input_val = kwargs["input"]
146-
# TODO: Remove this warning when https://github.com/pytorch/TensorRT/issues/1445 is fixed
147-
kernel = to_numpy(submod.weight)
148-
kernel_size_one = True
149-
if len(kernel.shape) == 5:
150-
for filter_size in kernel.shape[2:]:
151-
if filter_size != 1:
152-
kernel_size_one = False
153-
if kernel_size_one:
154-
logger.warn(
155-
"Conv3d layer with kernel size = 1 configuration incurs a failure with TensorRT tactic optimizer in some cases. \
156-
Github issue: https://github.com/pytorch/TensorRT/issues/1445. Other conv variants do not have this issue."
157-
)
158-
159-
if not isinstance(input_val, trt.tensorrt.ITensor):
160-
raise RuntimeError(
161-
f"Conv3d received input {input_val} that is not part "
162-
"of the TensorRT region!"
163-
)
164-
165-
return common_conv(
64+
return convolution.convNd(
16665
network,
167-
submod,
168-
dimension=3,
169-
input_val=input_val,
170-
layer_name=layer_name,
171-
is_quantized=False,
66+
submod._get_name(),
67+
source_ir=SourceIR.NN,
68+
name=layer_name,
69+
is_conv1d=False,
70+
input_val=kwargs["input"],
71+
weight=submod.weight,
72+
bias=submod.bias,
73+
stride=getattr(submod, "stride"),
74+
padding=getattr(submod, "padding"),
75+
dilation=getattr(submod, "dilation"),
76+
groups=submod.groups,
17277
)
17378

17479

17580
@tensorrt_converter(torch.nn.quantized.modules.conv.Conv2d)
17681
def quantized_conv2d(network, submod, args, kwargs, layer_name):
17782
input_val = args[0]
178-
179-
if not isinstance(input_val, trt.tensorrt.ITensor):
180-
raise RuntimeError(
181-
f"Quantized Conv2d received input {input_val} that is not part "
182-
"of the TensorRT region!"
183-
)
184-
185-
return common_conv(
83+
return convolution.convNd(
18684
network,
187-
submod,
188-
dimension=2,
85+
submod._get_name(),
86+
source_ir=SourceIR.NN,
87+
name=layer_name,
88+
is_conv1d=False,
18989
input_val=input_val,
190-
layer_name=layer_name,
191-
is_quantized=True,
90+
weight=submod.weight(),
91+
bias=submod.bias(),
92+
stride=getattr(submod, "stride"),
93+
padding=getattr(submod, "padding"),
94+
dilation=getattr(submod, "dilation"),
95+
groups=submod.groups,
96+
scale=submod.scale,
97+
zero_point=submod.zero_point,
19298
)
19399

194100

195101
@tensorrt_converter(torch.nn.intrinsic.quantized.modules.ConvReLU2d)
196102
def quantized_conv_relu2d(network, submod, args, kwargs, layer_name):
197103
input_val = args[0]
198-
199-
if not isinstance(input_val, trt.tensorrt.ITensor):
200-
raise RuntimeError(
201-
f"Quantized ConvReLU2d received input {input_val} that is not part "
202-
"of the TensorRT region!"
203-
)
204-
205-
return common_conv_relu(
104+
conv_out = convolution.convNd(
206105
network,
207-
submod,
208-
dimension=2,
106+
submod._get_name(),
107+
source_ir=SourceIR.NN,
108+
name=layer_name,
109+
is_conv1d=False,
209110
input_val=input_val,
210-
layer_name=f"{layer_name}_conv",
211-
is_quantized=True,
111+
weight=submod.weight(),
112+
bias=submod.bias(),
113+
stride=getattr(submod, "stride"),
114+
padding=getattr(submod, "padding"),
115+
dilation=getattr(submod, "dilation"),
116+
groups=submod.groups,
117+
scale=submod.scale,
118+
zero_point=submod.zero_point,
119+
)
120+
121+
return activation.relu(
122+
network, submod._get_name(), SourceIR.NN, layer_name + "_relu", conv_out
212123
)
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
import numpy as np
2+
from typing import Any, Callable, Optional, Sequence, Union
3+
4+
# @manual=//deeplearning/trt/python:py_tensorrt
5+
import tensorrt as trt
6+
import torch
7+
from torch.fx.node import Target
8+
9+
from torch_tensorrt.fx.converters.converter_utils import (
10+
SourceIR,
11+
extend_attr_to_tuple,
12+
get_dyn_range,
13+
mark_as_int8_layer,
14+
set_layer_name,
15+
has_dynamic_shape,
16+
to_numpy,
17+
get_trt_tensor,
18+
)
19+
from torch_tensorrt.fx.converters.acc_ops_converters import (
20+
acc_ops_unsqueeze,
21+
acc_ops_squeeze,
22+
)
23+
24+
from torch_tensorrt.fx.types import (
25+
TRTNetwork,
26+
TRTTensor,
27+
)
28+
29+
30+
def convNd(
31+
network: TRTNetwork,
32+
target: Union[Target, str],
33+
source_ir: Optional[SourceIR],
34+
name: str,
35+
is_conv1d: bool,
36+
input_val: TRTTensor,
37+
weight: Union[TRTTensor, torch.Tensor],
38+
bias: Optional[Union[TRTTensor, torch.Tensor]],
39+
stride: Optional[Union[int, Sequence[int]]],
40+
padding: Optional[Union[int, Sequence[int]]],
41+
dilation: Optional[Union[int, Sequence[int]]],
42+
groups: Optional[int],
43+
scale: Optional[Union[torch.Tensor, float]] = None,
44+
zero_point: Optional[Union[torch.Tensor, float]] = None,
45+
) -> TRTTensor:
46+
47+
if has_dynamic_shape(input_val.shape):
48+
assert input_val.shape[1] != -1, "Channel dim can't be dynamic for convolution."
49+
50+
if is_conv1d:
51+
# Apply an unsqueeze operation to transform the conv1d problem into conv2d
52+
kwargs = {
53+
"input": input_val,
54+
"dim": -1,
55+
}
56+
input_val = acc_ops_unsqueeze(
57+
network, target, tuple(), kwargs, name + "_unsqueeze"
58+
)
59+
60+
# Process bias terms
61+
if isinstance(bias, torch.Tensor):
62+
# Transform the bias constant into a Numpy array
63+
bias = to_numpy(bias)
64+
65+
# Prepend new dimension (unsqueeze) if the convolution is 1d
66+
if is_conv1d:
67+
bias = np.expand_dims(bias, 0)
68+
69+
elif isinstance(bias, TRTTensor):
70+
bias = get_trt_tensor(network, bias, f"{name}_bias")
71+
# Prepend new dimension (unsqueeze) if the convolution is 1d
72+
if is_conv1d:
73+
kwargs = {
74+
"input": bias,
75+
"dim": 0,
76+
}
77+
bias = acc_ops_unsqueeze(
78+
network, target, tuple(), kwargs, name + "_unsqueeze_bias"
79+
)
80+
81+
elif bias is not None:
82+
raise RuntimeError(
83+
f"Convolution {name} has bias of type {type(bias)}, Expected Torch Tensor or TRT Tensor"
84+
)
85+
86+
# Process weight terms
87+
if network.has_explicit_precision or isinstance(weight, TRTTensor):
88+
weight = get_trt_tensor(network, weight, f"{name}_weight")
89+
# Append new dimension (unsqueeze) if the convolution is 1d
90+
if is_conv1d:
91+
kwargs = {
92+
"input": weight,
93+
"dim": -1,
94+
}
95+
weight = acc_ops_unsqueeze(
96+
network, target, tuple(), kwargs, name + "_unsqueeze_weight"
97+
)
98+
99+
elif isinstance(weight, torch.Tensor):
100+
# Transform the weight constant into a Numpy array
101+
weight = to_numpy(weight)
102+
103+
# Append new dimension (unsqueeze) if the convolution is 1d
104+
if is_conv1d:
105+
weight = np.expand_dims(weight, -1)
106+
107+
else:
108+
raise RuntimeError(
109+
f"Convolution {name} has weight of type {type(weight)}, Expect Optional[Tensor]"
110+
)
111+
112+
conv_layer = network.add_convolution_nd(
113+
input=input_val,
114+
num_output_maps=weight.shape[0],
115+
kernel_shape=weight.shape[2:],
116+
kernel=trt.Weights() if isinstance(weight, TRTTensor) else weight,
117+
bias=trt.Weights() if isinstance(bias, TRTTensor) else bias,
118+
)
119+
120+
# If the weight is a TRTTensor, set it as an input of the layer
121+
if isinstance(weight, TRTTensor):
122+
conv_layer.set_input(1, weight)
123+
124+
# If the bias is a TRTTensor, set it as an input of the layer
125+
if isinstance(bias, TRTTensor):
126+
conv_layer.set_input(2, bias)
127+
128+
# Expand parameters manually for Conv1D computations
129+
if is_conv1d:
130+
padding = tuple(padding) + (0,)
131+
stride = extend_attr_to_tuple(stride, 2)
132+
dilation = extend_attr_to_tuple(dilation, 2)
133+
134+
set_layer_name(conv_layer, target, name, source_ir)
135+
136+
# Set relevant attributes of convolution layer
137+
conv_layer.padding_nd = padding
138+
conv_layer.stride_nd = stride
139+
conv_layer.dilation_nd = dilation
140+
141+
if groups is not None:
142+
conv_layer.num_groups = groups
143+
144+
# Handle quantization cases
145+
if scale is not None and zero_point is not None:
146+
# Assume the dtype of activation is torch.quint8
147+
mark_as_int8_layer(conv_layer, get_dyn_range(scale, zero_point, torch.quint8))
148+
149+
result = conv_layer.get_output(0)
150+
151+
if is_conv1d:
152+
# Apply a squeeze operation to transform the conv2d problem back into conv1d
153+
kwargs = {
154+
"input": result,
155+
"dim": -1,
156+
}
157+
result = acc_ops_squeeze(network, target, tuple(), kwargs, name + "_squeeze")
158+
159+
return result

0 commit comments

Comments
 (0)
Please sign in to comment.