feat: Transition export workflows to use torch._export APIs #2195
Merged
Commits (56)
f1f202e feat: Move tracing to use aot export apis (peri044)
abaf047 chore: minor changes (peri044)
bb1f3cf chore: minor changes (peri044)
3d05b4d chore: Rebase with main (peri044)
8d99be5 chore: rebase (peri044)
0aad214 chore: minor logging updates (peri044)
8899735 feat: Add preliminary support for freezing tensors in Dynamo (gs-olive)
8af2627 fix: Refactor tensor freezing in Dynamo (gs-olive)
f6969be Key op fixes for failing tests (gs-olive)
bad1594 fix: Add constant folding utility to freezing (gs-olive)
db56dd6 chore: Move to new export APIs (peri044)
bf961f5 chore: rebase with dynamo_tensor_freeze branch (peri044)
b13aa82 feat: Add preliminary support for freezing tensors in Dynamo (gs-olive)
dd95620 fix: Refactor tensor freezing in Dynamo (gs-olive)
6bd3c64 Key op fixes for failing tests (gs-olive)
248073f fix: Add constant folding utility to freezing (gs-olive)
3e5f434 feat: Add preliminary support for freezing tensors in Dynamo (gs-olive)
6bf6945 fix: Refactor tensor freezing in Dynamo (gs-olive)
3b6e1e7 Key op fixes for failing tests (gs-olive)
2107d8e fix: Add constant folding utility to freezing (gs-olive)
fd5a41e chore: add BERT test case (peri044)
f047651 chore: remove pdb (peri044)
ab76c0d chore: rebase (peri044)
e4df382 feat: Add preliminary support for freezing tensors in Dynamo (gs-olive)
d022f4a fix: Refactor tensor freezing in Dynamo (gs-olive)
9610ba7 Key op fixes for failing tests (gs-olive)
e19aae7 fix: Add constant folding utility to freezing (gs-olive)
2860be6 Merge remote-tracking branch 'origin/dynamo_tensor_freeze' into expor… (peri044)
51266db feat: Add preliminary support for freezing tensors in Dynamo (gs-olive)
2005db7 fix: Add constant folding utility to freezing (gs-olive)
a8cb1fe fix: Move tracer code into try/except (gs-olive)
7ff9309 Custom implementation of AOT for compile (gs-olive)
692921e Move fixes into Dynamo directory (gs-olive)
e926724 chore: rebase (peri044)
27681c2 feat: Add preliminary support for freezing tensors in Dynamo (gs-olive)
056cbf3 fix: Add constant folding utility to freezing (gs-olive)
ece276c fix: Move tracer code into try/except (gs-olive)
73a0bce Custom implementation of AOT for compile (gs-olive)
890ba72 Move fixes into Dynamo directory (gs-olive)
980dc1c chore: rebase (peri044)
dfc4899 Move fixes into Dynamo directory (gs-olive)
09b099a Merge remote-tracking branch 'origin/dynamo_tensor_freeze' into expor… (peri044)
157bb2d chore: updates (peri044)
0005a31 Move fixes into Dynamo directory (gs-olive)
5526bca Merge remote-tracking branch 'origin/dynamo_tensor_freeze' into expor… (peri044)
3420fb0 chore: updates (peri044)
399f929 feat: Add preliminary support for freezing tensors in Dynamo (gs-olive)
4b44ff2 fix: Add constant folding utility to freezing (gs-olive)
a94a075 fix: Move tracer code into try/except (gs-olive)
4e308f1 Custom implementation of AOT for compile (gs-olive)
95d3f98 Move fixes into Dynamo directory (gs-olive)
529262a Merge remote-tracking branch 'origin/dynamo_tensor_freeze' into expor… (peri044)
aee529b chore: rebase (peri044)
e6d2d8d chore: address review comments (peri044)
6cd2bab Merge branch 'main' into export_prototype (peri044)
c7b2f3c chore: fix imports (peri044)
Diff view
@@ -1,160 +1,34 @@
 from __future__ import annotations

-import copy
 import logging
-import sys
-from contextlib import contextmanager
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
+import unittest.mock
+from typing import Any, Tuple

 import torch
-import torch._dynamo as torchdynamo
-from torch.fx.passes.infra.pass_base import PassResult
-from torch_tensorrt.dynamo.utils import req_torch_version
-from torch_tensorrt.fx.passes.lower_basic_pass_aten import (
-    compose_bmm,
-    compose_chunk,
-    compose_getitem_slice,
-    remove_ops,
-    replace_aten_op_with_indices,
-    replace_aten_reshape_alias_with_replace,
-    replace_builtin_ops,
-    replace_inplace_ops,
-    replace_native_layernorm_with_layernorm,
-    replace_transpose_mm_op_with_linear,
-    run_const_fold,
-)
-from typing_extensions import TypeAlias
-
-Value: TypeAlias = Union[Tuple["Value", ...], List["Value"], Dict[str, "Value"]]
+from torch._export import export
+from torch_tensorrt.dynamo.backend.backends import constant_fold
+from torch_tensorrt.dynamo.lowering import get_decompositions
+from torch_tensorrt.dynamo.utils import set_log_level

 logger = logging.getLogger(__name__)


-class DynamoConfig:
-    """
-    Manage Exir-specific configurations of Dynamo.
-    """
-
-    def __init__(
-        self,
-        capture_scalar_outputs: bool = True,
-        guard_nn_modules: bool = True,
-        dynamic_shapes: bool = True,
-        specialize_int: bool = True,
-        verbose: bool = True,
-    ) -> None:
-        self.capture_scalar_outputs = capture_scalar_outputs
-        self.guard_nn_modules = guard_nn_modules
-        self.dynamic_shapes = dynamic_shapes
-        self.specialize_int = specialize_int
-        self.verbose = verbose
-
-    def activate(self) -> None:
-        torchdynamo.config.capture_scalar_outputs = self.capture_scalar_outputs
-        torchdynamo.config.guard_nn_modules = self.guard_nn_modules
-        torchdynamo.config.dynamic_shapes = self.dynamic_shapes
-        torchdynamo.config.specialize_int = self.specialize_int
-        torchdynamo.config.verbose = self.verbose
-
-    def deactivate(self) -> None:
-        torchdynamo.config.capture_scalar_outputs = True
-        torchdynamo.config.guard_nn_modules = True
-        torchdynamo.config.dynamic_shapes = True
-        torchdynamo.config.specialize_int = True
-        torchdynamo.config.verbose = True
-
-
-@contextmanager
-def using_config(config: DynamoConfig) -> Generator[DynamoConfig, None, None]:
-    config.activate()
-    try:
-        yield config
-    finally:
-        config.deactivate()
-
-
-@contextmanager
-def setting_python_recursive_limit(limit: int = 10000) -> Generator[None, None, None]:
-    """
-    Temporarily increase the python interpreter stack recursion limit.
-    This is mostly used for pickling large scale modules.
-    """
-    default = sys.getrecursionlimit()
-    if limit > default:
-        sys.setrecursionlimit(limit)
-    try:
-        yield
-    finally:
-        sys.setrecursionlimit(default)
-
-
-@req_torch_version("2.dev")
-def dynamo_trace(
-    f: Callable[..., Value],
-    # pyre-ignore
-    args: Tuple[Any, ...],
-    aten_graph: bool,
-    tracing_mode: str = "real",
-    dynamo_config: Optional[DynamoConfig] = None,
-) -> Any:  # Tuple[torch.fx.GraphModule, Set[_guards.Guard]]:
-    """
-    TODO: Once we fully migrate to torchdynamo frontend, we will remove
-    this config option alltogether. For now, it helps with quick
-    experiments with playing around with TorchDynamo
-    """
-    if dynamo_config is None:
-        dynamo_config = DynamoConfig()
-    with using_config(dynamo_config), setting_python_recursive_limit(2000):
-        torchdynamo.reset()
-        try:
-            return torchdynamo.export(
-                f,
-                *copy.deepcopy(args),
-                aten_graph=aten_graph,
-                tracing_mode=tracing_mode,
-            )
-        except torchdynamo.exc.Unsupported as exc:
-            raise RuntimeError(
-                "The user code is using a feature we don't support. "
-                "Please try torchdynamo.explain() to get possible the reasons",
-            ) from exc
-        except Exception as exc:
-            raise RuntimeError(
-                "torchdynamo internal error occured. Please see above stacktrace"
-            ) from exc
-
-
-@req_torch_version("2.dev")
 def trace(
     model: torch.nn.Module | torch.fx.GraphModule,
     inputs: Tuple[Any, ...],
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
-    """
-    Optimized trace with necessary passes which re-compose some ops or replace some ops
-    These passes should be general and functional purpose
-    """
-    passes_list = [
-        compose_bmm,
-        compose_chunk,
-        compose_getitem_slice,
-        replace_aten_reshape_alias_with_replace,
-        replace_aten_op_with_indices,
-        replace_transpose_mm_op_with_linear,  # after compose_bmm
-        replace_native_layernorm_with_layernorm,
-        remove_ops,
-        replace_builtin_ops,  # after replace_native_layernorm_with_layernorm
-        replace_inplace_ops,  # remove it once functionalization is enabled
-    ]
-
-    fx_module, __package__ = dynamo_trace(model, inputs, True, "symbolic")
-
-    for passes in passes_list:
-        pr: PassResult = passes(fx_module)
-        fx_module = pr.graph_module
-
-    fx_module(*inputs)
-
-    fx_module = run_const_fold(fx_module)
-    logger.info("Post export graph : %s\n", fx_module.graph)
-    return fx_module
+    # Set log level at the top of compilation (torch_tensorrt.dynamo)
+    if "debug" in kwargs and kwargs["debug"]:
+        set_log_level(logger.parent, logging.DEBUG)
+
+    experimental_decompositions = kwargs.get(
+        "enable_experimental_decompositions", False
+    )
+    with unittest.mock.patch(
+        "torch._export.DECOMP_TABLE", get_decompositions(experimental_decompositions)
+    ):
+        graph_module = export(model, tuple(inputs)).module()
+    constant_fold(graph_module)
+    logger.debug("Post export graph: " + str(graph_module.graph))
+    return graph_module
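
For reference, here is a minimal usage sketch of the new trace() entrypoint shown in the diff above. The import path torch_tensorrt.dynamo.aten_tracer and the toy model are assumptions made for illustration (the changed file's path is not visible in this view); the debug and enable_experimental_decompositions kwargs are the ones the new implementation actually reads.

    import torch

    # Assumed import path for the export-based tracing entrypoint in the diff.
    from torch_tensorrt.dynamo.aten_tracer import trace


    class ToyModel(torch.nn.Module):
        """Illustrative stand-in for a user model."""

        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(16, 4)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))


    model = ToyModel().eval()
    inputs = (torch.randn(2, 16),)

    # debug=True raises the torch_tensorrt.dynamo logger to DEBUG;
    # enable_experimental_decompositions selects the extended decomposition
    # table that gets patched into torch._export.DECOMP_TABLE for the export.
    graph_module = trace(
        model, inputs, debug=True, enable_experimental_decompositions=False
    )
    print(graph_module.graph)

Scoping the decomposition table with unittest.mock.patch is a notable design choice here: the patch applies only for the duration of the export call, so torch._export.DECOMP_TABLE is restored afterwards rather than being mutated for the whole process.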