@@ -1,38 +1,35 @@
 import logging
-from typing import Dict, List, Optional, Sequence, Tuple
+from typing import Collection, Dict, List, Optional, Tuple

 import torch
-
+import torch.fx.passes.operator_support as ops
+from torch.fx.node import Target
 from torch.fx.passes.splitter_base import (
+    FxNetAccFusionsFinder,
+    FxNetAccNodesFinder,
     Subgraph,
     _SplitterBase,
     _SplitterSettingBase,
-    FxNetAccNodesFinder,
-    FxNetAccFusionsFinder,
 )
-import torch.fx.passes.operator_support as ops
-from torch.fx.passes.tools_common import NodeSet, CALLABLE_NODE_OPS
-from torch.fx.node import Target
-
-from torch_tensorrt.dynamo.conversion.converter_registry import ConverterRegistry
-from .common import DEFAULT_SINGLE_NODE_PARTITIONS
-from torch_tensorrt.dynamo._defaults import MIN_BLOCK_SIZE
-
+from torch.fx.passes.tools_common import CALLABLE_NODE_OPS, NodeSet
 from torch_tensorrt.dynamo import DYNAMO_CONVERTERS as CONVERTERS
+from torch_tensorrt.dynamo._defaults import MIN_BLOCK_SIZE
+from torch_tensorrt.dynamo.conversion.converter_registry import ConverterRegistry

+from .common import DEFAULT_SINGLE_NODE_PARTITIONS

 logger = logging.getLogger(__name__)


-class OpSupportTester(ops.OperatorSupportBase):
+class OpSupportTester(ops.OperatorSupportBase):  # type: ignore
     """Class to determine whether operators within a module are supported"""

-    def __init__(self, torch_executed_ops: Sequence[Target] = set()) -> None:
+    def __init__(self, torch_executed_ops: Collection[Target] = set()) -> None:
         super().__init__()

         # Initialize sets of supported/unsupported operators
-        self.supported_operators = {}
-        self.unsupported_operators = {}
+        self.supported_operators: Dict[str, int] = {}
+        self.unsupported_operators: Dict[str, int] = {}
         self.torch_executed_ops = torch_executed_ops

     def is_node_supported(
@@ -58,7 +55,7 @@ def is_node_supported(

             return False

-    def print_support_overview(self, num_trt_blocks: Optional[int] = None):
+    def print_support_overview(self, num_trt_blocks: Optional[int] = None) -> None:
         if num_trt_blocks is not None:
             logger.debug(
                 f"\nNumber of TensorRT-Accelerated Engines Generated: {num_trt_blocks}"
@@ -81,7 +78,7 @@ def print_support_overview(self, num_trt_blocks: Optional[int] = None):
             logger.debug("\nAll Nodes Supported\n")


-class TRTPartitioner(_SplitterBase):
+class TRTPartitioner(_SplitterBase):  # type: ignore
     """Partitioner to split an FX graph into subgraphs based on operator support

     Adapted from, and modified for the Torch-TensorRT Dynamo case:
@@ -102,7 +99,7 @@ def __init__(
         module: torch.fx.GraphModule,
         operator_support: ops.OperatorSupportBase,
         allowed_single_node_partition_ops: Optional[
-            Sequence[str]
+            Collection[str]
         ] = DEFAULT_SINGLE_NODE_PARTITIONS,
         min_block_size: int = MIN_BLOCK_SIZE,
     ):
@@ -141,7 +138,7 @@ def __init__(
         self.non_acc_submodule_name = "_run_on_gpu_"
         self._node_submodule_map: Dict[str, str] = {}

-        self.num_trt_accelerated_subgraphs = None
+        self.num_trt_accelerated_subgraphs: Optional[int] = None
         self.allowed_single_node_partition_ops = allowed_single_node_partition_ops

     def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]:
@@ -152,10 +149,13 @@ def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]:
         result: List[Subgraph] = []
         for subgraph in subgraphs:
             if subgraph.is_acc:
-                if len(subgraph.nodes) >= self.settings.min_acc_module_size or any(
-                    ConverterRegistry.qualified_name_or_str(node.target)
-                    in self.allowed_single_node_partition_ops
-                    for node in subgraph.nodes
+                if len(subgraph.nodes) >= self.settings.min_acc_module_size or (
+                    self.allowed_single_node_partition_ops is not None
+                    and any(
+                        ConverterRegistry.qualified_name_or_str(node.target)
+                        in self.allowed_single_node_partition_ops
+                        for node in subgraph.nodes
+                    )
                 ):
                     result.append(subgraph)
                 else:
@@ -214,7 +214,7 @@ def partition(
     gm: torch.fx.GraphModule,
     verbose: bool = True,
     min_block_size: int = MIN_BLOCK_SIZE,
-    torch_executed_ops: Sequence[Target] = set(),
+    torch_executed_ops: Collection[Target] = set(),
 ) -> torch.fx.GraphModule:
     """Partition an FX GraphModule with aten ops into TRT engines
     Partitioning is based on converter operator support
@@ -223,7 +223,7 @@ def partition(
         gm: FX GraphModule to partition
         verbose: Bool representing whether to print operator support
         min_block_size: Minimum number of operators per TRT-Engine Block
-        torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
+        torch_executed_ops: Collection of operations to run in Torch, regardless of converter coverage
     Returns:
         torch.fx.GraphModule
     """