Skip to content

Commit ae2706d

Browse files
authored
Merge pull request #1770 from pytorch/bose_fx2trt_converters_slice_select
fx2trt converters aten::slice,aten::select and aten::matmul
2 parents bde4860 + f721ed1 commit ae2706d

19 files changed

+1406
-247
lines changed

core/conversion/converters/impl/reduce.cpp

+54-22
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,36 @@ namespace converters {
99
namespace impl {
1010
namespace {
1111

12+
nvinfer1::ITensor* anyDimImplementation(
13+
ConversionCtx* ctx,
14+
const torch::jit::Node* n,
15+
nvinfer1::ITensor* in_tensor,
16+
int dim,
17+
bool keepdim) {
18+
auto in_dims = in_tensor->getDimensions();
19+
LOG_DEBUG("Dim to reduce (original): " << dim);
20+
dim = dim < 0 ? (in_dims.nbDims + dim) : dim;
21+
LOG_DEBUG("Dim to reduce (converted): " << dim);
22+
23+
uint32_t axis_mask = 1 << dim;
24+
LOG_DEBUG("Axis Mask: " << std::bitset<32>(axis_mask));
25+
LOG_DEBUG("Keep dims: " << keepdim);
26+
27+
// Reduce does not work on bool inputs
28+
if (in_tensor->getType() == nvinfer1::DataType::kBOOL) {
29+
in_tensor = castITensor(ctx, in_tensor, nvinfer1::DataType::kINT32, (util::node_info(n) + "_in").c_str());
30+
}
31+
auto sum_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kSUM, axis_mask, keepdim);
32+
33+
TORCHTRT_CHECK(sum_layer, "Unable to create sum layer from node: " << *n);
34+
35+
sum_layer->setName(util::node_info(n).c_str());
36+
auto out_tensor =
37+
castITensor(ctx, sum_layer->getOutput(0), nvinfer1::DataType::kBOOL, (util::node_info(n) + "_out").c_str());
38+
out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
39+
return out_tensor;
40+
}
41+
1242
auto reduce_registrations TORCHTRT_UNUSED =
1343
RegisterNodeConversionPatterns()
1444
.pattern(
@@ -224,33 +254,35 @@ auto reduce_registrations TORCHTRT_UNUSED =
224254
{"aten::any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor",
225255
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
226256
auto in_tensor = args[0].ITensorOrFreeze(ctx);
227-
auto in_dims = in_tensor->getDimensions();
228257
auto dim = args[1].unwrapToInt();
229-
LOG_DEBUG("Dim to reduce (original): " << dim);
230-
dim = dim < 0 ? (in_dims.nbDims + dim) : dim;
231-
LOG_DEBUG("Dim to reduce (converted): " << dim);
232-
233-
uint32_t axis_mask = 1 << dim;
234-
LOG_DEBUG("Axis Mask: " << std::bitset<32>(axis_mask));
235-
236258
auto keepdim = args[2].unwrapToBool();
237-
LOG_DEBUG("Keep dims: " << keepdim);
238-
239-
// Reduce does not work on bool inputs
240-
if (in_tensor->getType() == nvinfer1::DataType::kBOOL) {
241-
in_tensor =
242-
castITensor(ctx, in_tensor, nvinfer1::DataType::kINT32, (util::node_info(n) + "_in").c_str());
243-
}
244-
auto sum_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kSUM, axis_mask, keepdim);
245-
246-
TORCHTRT_CHECK(sum_layer, "Unable to create sum layer from node: " << *n);
247-
248-
sum_layer->setName(util::node_info(n).c_str());
249-
auto out_tensor = castITensor(
250-
ctx, sum_layer->getOutput(0), nvinfer1::DataType::kBOOL, (util::node_info(n) + "_out").c_str());
259+
auto out_tensor = anyDimImplementation(ctx, n, in_tensor, dim, keepdim);
251260
out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
252261
LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
253262
return true;
263+
}})
264+
.pattern(
265+
{"aten::all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor",
266+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
267+
// use Not(Any(Not(input))) to calculate all without a direct all reduction
268+
auto in_tensor = args[0].ITensorOrFreeze(ctx);
269+
auto dim = args[1].unwrapToInt();
270+
auto keepdim = args[2].unwrapToBool();
271+
if (in_tensor->getType() != nvinfer1::DataType::kBOOL) {
272+
// unary not layer only supports bool inputs
273+
in_tensor = castITensor(
274+
ctx, in_tensor, nvinfer1::DataType::kBOOL, (util::node_info(n) + "_in_to_bool").c_str());
275+
}
276+
auto not_input_layer = ctx->net->addUnary(*in_tensor, nvinfer1::UnaryOperation::kNOT);
277+
TORCHTRT_CHECK(not_input_layer, "Unable to create logical_not layer from node: " << *n);
278+
not_input_layer->setName((util::node_info(n) + "_not_in").c_str());
279+
auto not_in = not_input_layer->getOutput(0);
280+
auto any_out = anyDimImplementation(ctx, n, not_in, dim, keepdim);
281+
auto not_output_layer = ctx->net->addUnary(*any_out, nvinfer1::UnaryOperation::kNOT);
282+
TORCHTRT_CHECK(not_output_layer, "Unable to create logical_not layer from node: " << *n);
283+
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], not_output_layer->getOutput(0));
284+
LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
285+
return true;
254286
}});
255287
} // namespace
256288
} // namespace impl

py/torch_tensorrt/fx/converters/acc_ops_converters.py

+12-208
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,13 @@ def acc_ops_batch_norm(
678678

679679

680680
@tensorrt_converter(acc_ops.layer_norm)
681-
def acc_ops_layer_norm(network, target, args, kwargs, name):
681+
def acc_ops_layer_norm(
682+
network: TRTNetwork,
683+
target: Target,
684+
args: Tuple[Argument, ...],
685+
kwargs: Dict[str, Argument],
686+
name: str,
687+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
682688
return add_layer_norm(network, target, kwargs, name)
683689

684690

@@ -690,37 +696,7 @@ def acc_ops_softmax(
690696
kwargs: Dict[str, Argument],
691697
name: str,
692698
) -> Union[TRTTensor, Sequence[TRTTensor]]:
693-
input_val = kwargs["input"]
694-
input_ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0) # type: ignore[union-attr]
695-
696-
if not isinstance(input_val, TRTTensor):
697-
raise RuntimeError(
698-
f"softmax received input {input_val} that is not part "
699-
"of the TensorRT region!"
700-
)
701-
702-
# Used to get dim when dim is None. Copied from PyTorch softmax implementation.
703-
def get_softmax_dim(ndim: int) -> int:
704-
if ndim == 0 or ndim == 1 or ndim == 3:
705-
ret = 0
706-
else:
707-
ret = 1
708-
return ret
709-
710-
if kwargs["dim"] is None:
711-
dim = get_softmax_dim(input_ranks)
712-
else:
713-
dim = cast(int, kwargs["dim"])
714-
715-
dim = get_positive_dim(dim, input_ranks)
716-
if network.has_implicit_batch_dimension:
717-
assert dim != 0, "Can't apply softmax on batch dimension when it's implicit."
718-
dim -= 1
719-
720-
layer = network.add_softmax(input_val)
721-
layer.axes = 1 << dim
722-
set_layer_name(layer, target, name)
723-
return layer.get_output(0)
699+
return add_softmax(network, target, kwargs, name)
724700

725701

726702
@tensorrt_converter(acc_ops.tile)
@@ -956,9 +932,7 @@ def acc_ops_sqrt(
956932
kwargs: Dict[str, Argument],
957933
name: str,
958934
) -> Union[TRTTensor, Sequence[TRTTensor]]:
959-
input_val = kwargs["input"]
960-
operation_type = trt.UnaryOperation.SQRT
961-
return add_unary_layer(network, input_val, operation_type, target, name)
935+
return add_sqrt(network, target, kwargs, name)
962936

963937

964938
@tensorrt_converter(acc_ops.reciprocal)
@@ -1619,40 +1593,7 @@ def acc_ops_squeeze(
16191593
kwargs: Dict[str, Argument],
16201594
name: str,
16211595
) -> Union[TRTTensor, Sequence[TRTTensor]]:
1622-
input_val = kwargs["input"]
1623-
1624-
if not isinstance(input_val, TRTTensor):
1625-
raise RuntimeError(
1626-
f"squeeze received input {input_val} that is not part "
1627-
"of the TensorRT region!"
1628-
)
1629-
1630-
dim = cast(Optional[int], kwargs["dim"] if "dim" in kwargs else None)
1631-
# Squeeze with dim=None would only work in explicit batch dim mode without any dynamic
1632-
# dim, which is a very rare case. For now we just claim not supporting dim=None.
1633-
assert dim is not None, "We don't support dim=None right now for squeeze."
1634-
1635-
dim = get_positive_dim(
1636-
dim, len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
1637-
)
1638-
if network.has_implicit_batch_dimension:
1639-
assert dim != 0, "We don't support squeeze batch dim when it's implicit."
1640-
dim -= 1
1641-
1642-
assert input_val.shape[dim] != -1, "We don't support squeeze dynamic dim."
1643-
assert (
1644-
len(get_dynamic_dims(input_val.shape)) <= 1
1645-
), "Currently more than one dynamic dim for input to squeeze is not supported."
1646-
1647-
output_shape = []
1648-
for i, s in enumerate(input_val.shape):
1649-
if i == dim and s == 1:
1650-
continue
1651-
output_shape.append(s)
1652-
layer = network.add_shuffle(input_val)
1653-
layer.reshape_dims = tuple(output_shape)
1654-
set_layer_name(layer, target, name)
1655-
return layer.get_output(0)
1596+
return add_squeeze(network, target, kwargs, name)
16561597

16571598

16581599
@tensorrt_converter(acc_ops.add)
@@ -2022,89 +1963,7 @@ def acc_ops_where(
20221963
kwargs: Dict[str, Argument],
20231964
name: str,
20241965
) -> Union[TRTTensor, Sequence[TRTTensor]]:
2025-
2026-
condition_t = kwargs["condition"]
2027-
x_t = kwargs["x"]
2028-
y_t = kwargs["y"]
2029-
2030-
if type(x_t) != TRTTensor:
2031-
assert type(x_t) is torch.Tensor, f"value {x_t} is not torch.Tensor!"
2032-
2033-
if type(y_t) != TRTTensor:
2034-
assert type(y_t) is torch.Tensor, f"value {y_t} is not torch.Tensor!"
2035-
2036-
# get output shape
2037-
2038-
x_shape = list(x_t.shape)
2039-
y_shape = list(y_t.shape)
2040-
condition_shape = list(condition_t.shape)
2041-
output_shape = list(torch.broadcast_shapes(condition_shape, x_shape, y_shape))
2042-
2043-
# expand shape
2044-
if type(condition_t) != TRTTensor:
2045-
assert condition_t.dtype == torch.bool, "condition dtype is not bool"
2046-
if condition_shape != output_shape:
2047-
condition_t.expand(output_shape)
2048-
condition_t = condition_t.to(torch.int32)
2049-
condition_const = get_trt_tensor(network, condition_t, f"{name}_condition")
2050-
condition_layer = network.add_identity(condition_const)
2051-
condition_layer.set_output_type(0, trt.bool)
2052-
set_layer_name(condition_layer, target, f"{name}_condition")
2053-
condition_val = condition_layer.get_output(0)
2054-
else:
2055-
assert condition_t.dtype == trt.bool, "mask dtype is not bool!"
2056-
if condition_shape != output_shape:
2057-
condition_val = acc_ops_expand_tensor(
2058-
network,
2059-
target,
2060-
None,
2061-
{"input": condition_t, "sizes": output_shape},
2062-
name=f"{name}_expand",
2063-
)
2064-
else:
2065-
condition_val = condition_t
2066-
2067-
if type(x_t) != TRTTensor:
2068-
if x_shape != output_shape:
2069-
# special case where 1 element in x_t
2070-
if len(x_t.shape) == 0:
2071-
x_t = x_t.unsqueeze(0)
2072-
x_t = x_t.expand(output_shape)
2073-
x_val = get_trt_tensor(network, x_t, f"{name}_x")
2074-
else:
2075-
x_val = x_t
2076-
if x_shape != output_shape:
2077-
x_val = acc_ops_expand_tensor(
2078-
network,
2079-
target,
2080-
None,
2081-
{"input": x_val, "sizes": output_shape},
2082-
name=f"{name}_x_expand",
2083-
)
2084-
2085-
if type(y_t) != TRTTensor:
2086-
if y_shape != output_shape:
2087-
# special case where 1 element in y_t
2088-
if len(y_t.shape) == 0:
2089-
y_t = y_t.unsqueeze(0)
2090-
y_t = y_t.expand(output_shape)
2091-
y_val = get_trt_tensor(network, y_t, f"{name}_y")
2092-
else:
2093-
y_val = y_t
2094-
if y_shape != output_shape:
2095-
y_val = acc_ops_expand_tensor(
2096-
network,
2097-
target,
2098-
None,
2099-
{"input": y_val, "sizes": output_shape},
2100-
name=f"{name}_y_expand",
2101-
)
2102-
2103-
select_layer = network.add_select(condition_val, x_val, y_val)
2104-
2105-
set_layer_name(select_layer, target, f"{name}_select")
2106-
2107-
return select_layer.get_output(0)
1966+
return add_where(network, target, kwargs, name)
21081967

21091968

21101969
@tensorrt_converter(acc_ops.masked_fill, no_implicit_batch_dim=True)
@@ -2721,62 +2580,7 @@ def acc_ops_chunk(
27212580
kwargs: Dict[str, Argument],
27222581
name: str,
27232582
) -> Union[TRTTensor, Sequence[TRTTensor]]:
2724-
input_val = kwargs["input"]
2725-
chunks = cast(int, kwargs["chunks"])
2726-
dim = cast(int, kwargs["dim"])
2727-
input_dim_size = len(input_val.shape) # type: ignore[union-attr]
2728-
2729-
if not isinstance(input_val, TRTTensor):
2730-
raise RuntimeError(
2731-
f"chunk received input {input_val} that is not part "
2732-
"of the TensorRT region!"
2733-
)
2734-
2735-
dynamic_shape = has_dynamic_shape(input_val.shape)
2736-
if network.has_implicit_batch_dimension:
2737-
input_dim_size += 1
2738-
dim = get_positive_dim(dim, input_dim_size)
2739-
assert dim != 0, "Can't chunk on batch dim when it's implicit!"
2740-
dim -= 1
2741-
else:
2742-
if dynamic_shape:
2743-
assert input_val.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"
2744-
dim = get_positive_dim(dim, input_dim_size)
2745-
2746-
if chunks > input_val.shape[dim]:
2747-
warnings.warn(
2748-
f"Asked for {chunks} chunks along dimention "
2749-
f"{dim} on tensor with size {input_val.shape}, chunks "
2750-
f"will default to {input_val.shape[dim]}",
2751-
RuntimeWarning,
2752-
)
2753-
chunks = input_val.shape[dim]
2754-
2755-
start = [0] * len(input_val.shape)
2756-
stride = [1] * len(start)
2757-
offset = 0
2758-
split_size = (input_val.shape[dim] + chunks - 1) // chunks
2759-
2760-
max_offset = input_val.shape[dim]
2761-
# add slice layers
2762-
output = []
2763-
for i in range(chunks):
2764-
shape = list(input_val.shape)
2765-
shape[dim] = min(split_size, max_offset - offset)
2766-
if dynamic_shape:
2767-
shape = get_shape_with_dynamic_shape(
2768-
network, shape, input_val, target, f"{name}_{i}"
2769-
)
2770-
start[dim] = offset
2771-
layer = network.add_slice(
2772-
input_val, start=start, shape=[] if dynamic_shape else shape, stride=stride
2773-
)
2774-
if dynamic_shape:
2775-
layer.set_input(2, shape)
2776-
offset += split_size
2777-
set_layer_name(layer, target, f"{name}_{i}")
2778-
output.append(layer.get_output(0))
2779-
return output
2583+
return add_chunk(network, target, kwargs, name)
27802584

27812585

27822586
@tensorrt_converter(acc_ops.cumsum, no_implicit_batch_dim=True)

0 commit comments

Comments
 (0)