Skip to content

Commit 918e983

Browse files
committed
chore: ready to start review
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 15d7be1 commit 918e983

File tree

101 files changed

+524
-585
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

101 files changed

+524
-585
lines changed

.pre-commit-config.yaml

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ repos:
55
hooks:
66
- id: check-yaml
77
- id: trailing-whitespace
8+
exclude: ^docs
89
- id: check-added-large-files
910
args:
1011
- --maxkb=1000
@@ -13,11 +14,7 @@ repos:
1314
- id: mixed-line-ending
1415
args:
1516
- --fix=lf
16-
- repo: https://github.com/psf/black
17-
rev: 23.7.0
18-
hooks:
19-
- id: black
20-
exclude: ^examples/custom_converters/elu_converter/setup.py
17+
exclude: ^docs
2118
- repo: https://github.com/pre-commit/mirrors-clang-format
2219
rev: v16.0.6
2320
hooks:
@@ -30,21 +27,26 @@ repos:
3027
args:
3128
- --warnings=all
3229
- id: buildifier-lint
33-
- repo: https://github.com/astral-sh/ruff-pre-commit
34-
# Ruff version.
35-
rev: v0.0.278
36-
hooks:
37-
- id: ruff
3830
- repo: https://github.com/abravalheri/validate-pyproject
3931
rev: v0.13
4032
hooks:
4133
- id: validate-pyproject
34+
python_version: "3.11"
4235
- repo: https://github.com/pre-commit/mirrors-mypy
4336
rev: 'v1.4.1'
4437
hooks:
4538
- id: mypy
46-
exclude: "^py/torch_tensorrt/fx|^examples|^tests|^tools"
47-
python_version: "3.11"
39+
exclude: "^py/torch_tensorrt/fx|^examples|^tests|^tools|^docs|noxfile.py|setup.py|versions.py"
40+
- repo: https://github.com/astral-sh/ruff-pre-commit
41+
# Ruff version.
42+
rev: v0.0.278
43+
hooks:
44+
- id: ruff
45+
- repo: https://github.com/psf/black
46+
rev: 23.7.0
47+
hooks:
48+
- id: black
49+
exclude: ^examples/custom_converters/elu_converter/setup.py|^docs
4850
- repo: local
4951
hooks:
5052
- id: dont-commit-upstream

core/conversion/evaluators/aten.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,12 @@ DEFINE_ARITHMATIC_TWO_INPUT_EVALUATOR(
104104
"aten::pow.float_int(float a, int b) -> (float)",
105105
}));
106106

107-
DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(and, "aten::__and__", a&& b, bool, std::set<std::string>({"aten::__and__(int a, int b) -> (bool)", "aten::__and__.bool(bool a, bool b) -> (bool)"}));
107+
DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(
108+
and,
109+
"aten::__and__",
110+
a&& b,
111+
bool,
112+
std::set<std::string>({"aten::__and__(int a, int b) -> (bool)", "aten::__and__.bool(bool a, bool b) -> (bool)"}));
108113
DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(or, "aten::__or__", a || b, bool, {"aten::__or__(int a, int b) -> (bool)"});
109114
DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(
110115
xor,

docs/_downloads/e1ef5a42560a98a132f56a79d0b66f79/dynamo_compile_advanced_usage.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
# %%
1818

19+
1920
# We begin by defining a model
2021
class Model(torch.nn.Module):
2122
def __init__(self) -> None:

docs/v1.4.0/_downloads/e1ef5a42560a98a132f56a79d0b66f79/dynamo_compile_advanced_usage.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
# %%
1818

19+
1920
# We begin by defining a model
2021
class Model(torch.nn.Module):
2122
def __init__(self) -> None:

py/torch_tensorrt/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,10 @@ def _find_lib(name: str, paths: List[str]) -> str:
8282

8383
import torch
8484
from torch_tensorrt._compile import * # noqa: F403
85+
from torch_tensorrt._Device import Device # noqa: F401
8586
from torch_tensorrt._enums import * # noqa: F403
86-
from torch_tensorrt._util import * # noqa: F403
87-
from torch_tensorrt._Input import Input # noqa: F401
88-
from torch_tensorrt._Device import Device # noqa: F401
87+
from torch_tensorrt._Input import Input # noqa: F401
88+
from torch_tensorrt._utils import * # noqa: F403
8989

9090
if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
9191
from torch_tensorrt import dynamo # noqa: F401

py/torch_tensorrt/_compile.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from enum import Enum
2-
from typing import Any, Callable, List, Optional, Set, TypeGuard
2+
from typing import Any, Callable, List, Optional, Set, TypeGuard, Sequence
33

44
import torch
55
import torch.fx
@@ -15,13 +15,13 @@
1515

1616

1717
def _non_fx_input_interface(
18-
inputs: List[Input | torch.Tensor | InputTensorSpec],
18+
inputs: Sequence[Input | torch.Tensor | InputTensorSpec],
1919
) -> TypeGuard[List[Input | torch.Tensor]]:
2020
return all(isinstance(i, torch.Tensor | Input) for i in inputs)
2121

2222

2323
def _fx_input_interface(
24-
inputs: List[Input | torch.Tensor | InputTensorSpec],
24+
inputs: Sequence[Input | torch.Tensor | InputTensorSpec],
2525
) -> TypeGuard[List[InputTensorSpec | torch.Tensor]]:
2626
return all(isinstance(i, torch.Tensor | InputTensorSpec) for i in inputs)
2727

@@ -97,7 +97,7 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
9797
def compile(
9898
module: Any,
9999
ir: str = "default",
100-
inputs: Optional[List[Input | torch.Tensor | InputTensorSpec]] = None,
100+
inputs: Optional[Sequence[Input | torch.Tensor | InputTensorSpec]] = None,
101101
enabled_precisions: Optional[Set[torch.dtype | dtype]] = None,
102102
**kwargs: Any,
103103
) -> (
@@ -201,7 +201,7 @@ def compile(
201201
raise RuntimeError("Module is an unknown format or the ir requested is unknown")
202202

203203

204-
def torch_compile(module: torch.nn.Module, **kwargs: Any) -> Callable[..., Any]:
204+
def torch_compile(module: torch.nn.Module, **kwargs: Any) -> Any:
205205
"""
206206
Returns a boxed model which is the output of torch.compile.
207207
This does not compile the model to TRT. Execute this model on
@@ -216,8 +216,8 @@ def torch_compile(module: torch.nn.Module, **kwargs: Any) -> Callable[..., Any]:
216216

217217
def convert_method_to_trt_engine(
218218
module: Any,
219-
inputs: List[Input | torch.Tensor],
220-
method_name: str,
219+
method_name: str = "forward",
220+
inputs: Optional[Sequence[Input | torch.Tensor]] = None,
221221
ir: str = "default",
222222
enabled_precisions: Optional[Set[torch.dtype | dtype]] = None,
223223
**kwargs: Any,

py/torch_tensorrt/_enums.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
1-
from torch_tensorrt._C import dtype, EngineCapability, TensorFormat # noqa: F401
1+
from torch_tensorrt._C import EngineCapability, TensorFormat, dtype # noqa: F401
2+
23
from tensorrt import DeviceType # noqa: F401

py/torch_tensorrt/_util.py renamed to py/torch_tensorrt/_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
from typing import Any
2+
13
import torch
2-
from torch_tensorrt import _C, __version__
4+
from torch_tensorrt._version import __version__
5+
from torch_tensorrt import _C
36

47

58
def dump_build_info() -> None:
@@ -30,7 +33,7 @@ def set_device(gpu_id: int) -> None:
3033
_C.set_device(gpu_id)
3134

3235

33-
def sanitized_torch_version() -> str:
36+
def sanitized_torch_version() -> Any:
3437
return (
3538
torch.__version__
3639
if ".nv" not in torch.__version__

py/torch_tensorrt/dynamo/_SourceIR.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ class SourceIR(Enum):
99
TORCHTRT_LOWERED = auto()
1010
UNKNOWN = auto()
1111

12-
def __str__(self):
12+
def __str__(self) -> str:
1313
if self == SourceIR.NN:
1414
return "nn"
1515
elif self == SourceIR.ACC:

py/torch_tensorrt/dynamo/__init__.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
from packaging import version
2-
from torch_tensorrt._util import sanitized_torch_version
2+
from torch_tensorrt._utils import sanitized_torch_version
33

44
if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
55
from ._settings import * # noqa: F403
6-
from .conversion import * # noqa: F403
7-
from .aten_tracer import trace # noqa: F403
8-
from .conversion.converter_registry import (
9-
DYNAMO_CONVERTERS, # noqa: F403
10-
dynamo_tensorrt_converter, # noqa: F403
11-
)
12-
from .compile import compile # noqa: F403
13-
from ._SourceIR import SourceIR # noqa: F403
6+
from ._SourceIR import SourceIR # noqa: F403
7+
from .aten_tracer import trace # noqa: F403
8+
from .compile import compile # noqa: F403
9+
from .conversion import * # noqa: F403
10+
from .conversion.converter_registry import DYNAMO_CONVERTERS # noqa: F403
11+
from .conversion.converter_registry import dynamo_tensorrt_converter # noqa: F403

0 commit comments

Comments
 (0)