Commit f3b34ac

fix: Linting and formatting updates
1 parent 18c0680 commit f3b34ac

8 files changed, +46 -59 lines

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 1 addition & 1 deletion
@@ -10,10 +10,10 @@
     PASS_THROUGH_BUILD_FAILURES,
     PRECISION,
     TRUNCATE_LONG_AND_DOUBLE,
+    USE_FAST_PARTITIONER,
     USE_PYTHON_RUNTIME,
     VERSION_COMPATIBLE,
     WORKSPACE_SIZE,
-    USE_FAST_PARTITIONER,
 )

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 1 addition & 10 deletions
@@ -5,15 +5,7 @@
 import torch
 import torch._dynamo as td
 from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
-from torch_tensorrt.dynamo import CompilationSettings
-from torch_tensorrt.dynamo.lowering._decompositions import (
-    get_decompositions,
-)
-from torch_tensorrt.dynamo.lowering._pre_aot_lowering import (
-    pre_aot_substitutions,
-)
-from torch_tensorrt.dynamo import partitioning
-from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs
+from torch_tensorrt.dynamo import CompilationSettings, partitioning
 from torch_tensorrt.dynamo.conversion import (
     convert_module,
     repair_long_or_double_inputs,

@@ -138,7 +130,6 @@ def _compile_module(
     # Iterate over all components that can be accelerated
     # Generate the corresponding TRT Module for those
     for name, _ in partitioned_module.named_children():
-
         # Criteria for a module to be convertible to TRT
         if settings.use_fast_partitioner and "_run_on_acc" not in name:
             continue

py/torch_tensorrt/dynamo/compile.py

Lines changed: 1 addition & 1 deletion
@@ -19,10 +19,10 @@
     PASS_THROUGH_BUILD_FAILURES,
     PRECISION,
     TRUNCATE_LONG_AND_DOUBLE,
+    USE_FAST_PARTITIONER,
     USE_PYTHON_RUNTIME,
     VERSION_COMPATIBLE,
     WORKSPACE_SIZE,
-    USE_FAST_PARTITIONER,
 )
 from torch_tensorrt.dynamo.backend.backends import _compile_module
 from torch_tensorrt.dynamo.conversion import convert_module

Lines changed: 3 additions & 6 deletions
@@ -1,8 +1,5 @@
 from ._decompositions import get_decompositions  # noqa: F401
-from ._fusers import *  # noqa: F403
-from ._pre_aot_lowering import (
-    SUBSTITUTION_REGISTRY,  # noqa: F401
-    register_substitution,  # noqa: F401
-)
-from .substitutions import *  # noqa: F401
 from ._fusers import *  # noqa: F401
+from ._pre_aot_lowering import SUBSTITUTION_REGISTRY  # noqa: F401
+from ._pre_aot_lowering import register_substitution  # noqa: F401
+from .substitutions import *  # noqa: F401

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
-from .common import get_submod_inputs
 from ._adjacency_partitioner import partition as fast_partition
 from ._global_partitioner import partition as global_partition
+from .common import get_submod_inputs

py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py

Lines changed: 26 additions & 26 deletions
@@ -1,38 +1,35 @@
 import logging
-from typing import Dict, List, Optional, Sequence, Tuple
+from typing import Collection, Dict, List, Optional, Tuple

 import torch
-
+import torch.fx.passes.operator_support as ops
+from torch.fx.node import Target
 from torch.fx.passes.splitter_base import (
+    FxNetAccFusionsFinder,
+    FxNetAccNodesFinder,
     Subgraph,
     _SplitterBase,
     _SplitterSettingBase,
-    FxNetAccNodesFinder,
-    FxNetAccFusionsFinder,
 )
-import torch.fx.passes.operator_support as ops
-from torch.fx.passes.tools_common import NodeSet, CALLABLE_NODE_OPS
-from torch.fx.node import Target
-
-from torch_tensorrt.dynamo.conversion.converter_registry import ConverterRegistry
-from .common import DEFAULT_SINGLE_NODE_PARTITIONS
-from torch_tensorrt.dynamo._defaults import MIN_BLOCK_SIZE
-
+from torch.fx.passes.tools_common import CALLABLE_NODE_OPS, NodeSet
 from torch_tensorrt.dynamo import DYNAMO_CONVERTERS as CONVERTERS
+from torch_tensorrt.dynamo._defaults import MIN_BLOCK_SIZE
+from torch_tensorrt.dynamo.conversion.converter_registry import ConverterRegistry

+from .common import DEFAULT_SINGLE_NODE_PARTITIONS

 logger = logging.getLogger(__name__)


-class OpSupportTester(ops.OperatorSupportBase):
+class OpSupportTester(ops.OperatorSupportBase):  # type: ignore
     """Class to determine whether operators within a module are supported"""

-    def __init__(self, torch_executed_ops: Sequence[Target] = set()) -> None:
+    def __init__(self, torch_executed_ops: Collection[Target] = set()) -> None:
         super().__init__()

         # Initialize sets of supported/unsupported operators
-        self.supported_operators = {}
-        self.unsupported_operators = {}
+        self.supported_operators: Dict[str, int] = {}
+        self.unsupported_operators: Dict[str, int] = {}
         self.torch_executed_ops = torch_executed_ops

     def is_node_supported(
@@ -58,7 +55,7 @@ def is_node_supported(

         return False

-    def print_support_overview(self, num_trt_blocks: Optional[int] = None):
+    def print_support_overview(self, num_trt_blocks: Optional[int] = None) -> None:
         if num_trt_blocks is not None:
             logger.debug(
                 f"\nNumber of TensorRT-Accelerated Engines Generated: {num_trt_blocks}"
@@ -81,7 +78,7 @@ def print_support_overview(self, num_trt_blocks: Optional[int] = None):
             logger.debug("\nAll Nodes Supported\n")


-class TRTPartitioner(_SplitterBase):
+class TRTPartitioner(_SplitterBase):  # type: ignore
     """Partitioner to split an FX graph into subgraphs based on operator support

     Adapted from, and modified for the Torch-TensorRT Dynamo case:
@@ -102,7 +99,7 @@ def __init__(
         module: torch.fx.GraphModule,
         operator_support: ops.OperatorSupportBase,
         allowed_single_node_partition_ops: Optional[
-            Sequence[str]
+            Collection[str]
         ] = DEFAULT_SINGLE_NODE_PARTITIONS,
         min_block_size: int = MIN_BLOCK_SIZE,
     ):
@@ -141,7 +138,7 @@ def __init__(
         self.non_acc_submodule_name = "_run_on_gpu_"
         self._node_submodule_map: Dict[str, str] = {}

-        self.num_trt_accelerated_subgraphs = None
+        self.num_trt_accelerated_subgraphs: Optional[int] = None
         self.allowed_single_node_partition_ops = allowed_single_node_partition_ops

     def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]:
@@ -152,10 +149,13 @@ def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]:
         result: List[Subgraph] = []
         for subgraph in subgraphs:
             if subgraph.is_acc:
-                if len(subgraph.nodes) >= self.settings.min_acc_module_size or any(
-                    ConverterRegistry.qualified_name_or_str(node.target)
-                    in self.allowed_single_node_partition_ops
-                    for node in subgraph.nodes
+                if len(subgraph.nodes) >= self.settings.min_acc_module_size or (
+                    self.allowed_single_node_partition_ops is not None
+                    and any(
+                        ConverterRegistry.qualified_name_or_str(node.target)
+                        in self.allowed_single_node_partition_ops
+                        for node in subgraph.nodes
+                    )
                 ):
                     result.append(subgraph)
                 else:
@@ -214,7 +214,7 @@ def partition(
     gm: torch.fx.GraphModule,
     verbose: bool = True,
     min_block_size: int = MIN_BLOCK_SIZE,
-    torch_executed_ops: Sequence[Target] = set(),
+    torch_executed_ops: Collection[Target] = set(),
 ) -> torch.fx.GraphModule:
     """Partition an FX GraphModule with aten ops into TRT engines
     Partitioning is based on converter operator support
@@ -223,7 +223,7 @@
         gm: FX GraphModule to partition
         verbose: Bool representing whether to print operator support
         min_block_size: Minimum number of operators per TRT-Engine Block
-        torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
+        torch_executed_ops: Collection of operations to run in Torch, regardless of converter coverage
     Returns:
         torch.fx.GraphModule
     """

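The new guard in remove_small_acc_subgraphs is needed because allowed_single_node_partition_ops is typed Optional[Collection[str]]: testing membership against None would raise a TypeError at runtime and fail mypy. A minimal sketch of the same pattern follows; keep_subgraph and its parameters are illustrative names, not code from this commit.

from typing import Collection, Optional


def keep_subgraph(
    num_nodes: int,
    node_targets: Collection[str],
    min_block_size: int,
    allowed_single_node_partition_ops: Optional[Collection[str]],
) -> bool:
    # Keep a subgraph if it meets the size threshold, or if any of its ops is
    # explicitly allowed to form a single-node partition. The None check must
    # come first, since `x in None` raises TypeError.
    return num_nodes >= min_block_size or (
        allowed_single_node_partition_ops is not None
        and any(t in allowed_single_node_partition_ops for t in node_targets)
    )


# A one-node subgraph survives only because its op is allowlisted.
assert keep_subgraph(1, ["aten.convolution"], 5, {"aten.convolution"})
assert not keep_subgraph(1, ["aten.add"], 5, None)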
py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py

Lines changed: 4 additions & 7 deletions
@@ -1,20 +1,17 @@
 import logging
-from typing import Dict, List, Optional, Sequence
+from typing import Collection, Dict, List, Mapping, Optional, Sequence, Set

 import torch
-
-from torch_tensorrt.dynamo._defaults import MIN_BLOCK_SIZE
-from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
 from torch.fx.graph_module import GraphModule
-from .common import DEFAULT_SINGLE_NODE_PARTITIONS
 from torch.fx.node import _get_qualified_name
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
 from torch.fx.passes.operator_support import OperatorSupport, SupportDict
 from torch_tensorrt.dynamo._defaults import MIN_BLOCK_SIZE
 from torch_tensorrt.dynamo.conversion.converter_registry import (
     DYNAMO_CONVERTERS as CONVERTERS,
 )
-from torch_tensorrt.dynamo.lowering._pre_aot_lowering import SUBSTITUTION_REGISTRY
+
+from .common import DEFAULT_SINGLE_NODE_PARTITIONS

 logger = logging.getLogger(__name__)

@@ -40,7 +37,7 @@ def __init__(
         *,
         non_compute_ops: Optional[Sequence[str]] = None,
         allowed_single_node_partition_ops: Optional[
-            Sequence[str]
+            Collection[str]
         ] = DEFAULT_SINGLE_NODE_PARTITIONS,
         min_block_size: int = MIN_BLOCK_SIZE,
     ) -> None:

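The Sequence-to-Collection changes in both partitioners align the annotations with the actual default values: DEFAULT_SINGLE_NODE_PARTITIONS and the torch_executed_ops default are sets, and a set satisfies Collection (sized, iterable, supports membership tests) but not Sequence (no indexing). A small sketch of the distinction, not taken from the commit:

from typing import Collection, Sequence, Set

torch_executed_ops: Set[str] = {"torch.ops.aten.add.Tensor"}

ops_collection: Collection[str] = torch_executed_ops      # accepted: a set is a Collection
ops_sequence: Sequence[str] = sorted(torch_executed_ops)  # a Sequence needs an ordered, indexable type
# ops_sequence = torch_executed_ops would be rejected by mypy: a set is not a Sequence

# Collection still supports everything these parameters are used for:
assert "torch.ops.aten.add.Tensor" in ops_collection
assert len(ops_collection) == 1 and ops_sequence[0] in ops_collection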
py/torch_tensorrt/dynamo/partitioning/common.py

Lines changed: 9 additions & 7 deletions
@@ -1,13 +1,14 @@
-import torch
 import logging
-from typing import Sequence, Set
-from torch_tensorrt.dynamo.lowering import SUBSTITUTION_REGISTRY
+from typing import Any, Optional, Sequence, Set
+
+import torch
 from torch.fx.node import _get_qualified_name
+from torch_tensorrt.dynamo.lowering import SUBSTITUTION_REGISTRY

-DEFAULT_SINGLE_NODE_PARTITIONS: Set[str] = set(
+DEFAULT_SINGLE_NODE_PARTITIONS: Set[str] = {
     _get_qualified_name(to_replace.new_operator)
     for to_replace in SUBSTITUTION_REGISTRY.values()
-)
+}


 logger = logging.getLogger(__name__)
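DEFAULT_SINGLE_NODE_PARTITIONS switches from set() wrapped around a generator expression to a set comprehension. Both build the same Set[str]; the comprehension is simply the spelling that comprehension-focused lint rules prefer. A hedged sketch with a stand-in registry, since the real SUBSTITUTION_REGISTRY entries are not shown in this hunk:

from typing import Dict, Set

# Stand-in mapping from an original op to the qualified name of its replacement.
registry: Dict[str, str] = {"aten.einsum": "tensorrt.einsum", "aten.rsqrt": "tensorrt.rsqrt"}

# Before: set() around a generator expression.
partitions_old: Set[str] = set(name for name in registry.values())

# After: the equivalent set comprehension.
partitions_new: Set[str] = {name for name in registry.values()}

assert partitions_old == partitions_new == {"tensorrt.einsum", "tensorrt.rsqrt"}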
@@ -17,7 +18,7 @@ def get_submod_inputs(
     mod: torch.fx.GraphModule,
     submod: torch.fx.GraphModule,
     inputs: Sequence[torch.Tensor],
-) -> Sequence[torch.Tensor]:
+) -> Optional[Sequence[torch.Tensor]]:
     """Helper function to get inputs to a Torch submodule

     Args:
@@ -29,9 +30,10 @@
     """
     acc_inputs = None

-    def get_input(self, inputs):
+    def get_input(self: Any, inputs: Sequence[torch.Tensor]) -> None:
         nonlocal acc_inputs
         acc_inputs = inputs
+        return

     handle = submod.register_forward_pre_hook(get_input)
     mod(*inputs)

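get_submod_inputs relies on the forward-pre-hook mechanism shown above: a hook registered on the submodule fires just before that submodule's forward runs, so calling the parent with mod(*inputs) surfaces the concrete inputs the submodule receives, and the return type is Optional because the hook may never fire. A self-contained sketch of the same pattern; capture_inputs and the toy pipeline are illustrative, not code from this commit.

from typing import Any, Optional, Sequence

import torch


def capture_inputs(
    mod: torch.nn.Module,
    submod: torch.nn.Module,
    inputs: Sequence[torch.Tensor],
) -> Optional[Sequence[torch.Tensor]]:
    captured: Optional[Sequence[torch.Tensor]] = None

    def pre_hook(module: Any, hook_inputs: Sequence[torch.Tensor]) -> None:
        # Runs right before submod.forward; record its positional inputs.
        nonlocal captured
        captured = hook_inputs

    handle = submod.register_forward_pre_hook(pre_hook)
    mod(*inputs)  # If submod never executes, captured stays None.
    handle.remove()
    return captured


# Usage: capture the tensor seen by the second stage of a small pipeline.
stage1, stage2 = torch.nn.Linear(4, 8), torch.nn.ReLU()
pipeline = torch.nn.Sequential(stage1, stage2)
relu_inputs = capture_inputs(pipeline, stage2, [torch.randn(2, 4)])
assert relu_inputs is not None and relu_inputs[0].shape == (2, 8)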