From f1098f2520f30a6082597c93895c350905c8245d Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Mon, 20 Mar 2023 14:45:35 -0700 Subject: [PATCH 01/45] feat: Add sample torch.compile backend for tensorrt aten path - Add backend adapted from previous `fx2trt_compiler` provided by Dynamo - Currently, the TRTSplitter needs work to fully support the `aten` path - Additionally, the existing `aten` pass was reworked to exclude the `torch._dynamo.export` call, which may be necessary here --- .../fx/tracer/dispatch_tracer/aten_tracer.py | 8 +- .../tensorrt_dynamo_backend.py | 107 ++++++++++++++++++ 2 files changed, 113 insertions(+), 2 deletions(-) create mode 100644 py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py index e60c8f8d13..356ddc978e 100644 --- a/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py +++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py @@ -130,7 +130,7 @@ def trace(f, args, *rest): @req_torch_version("2.dev") -def opt_trace(f, args, *rest): +def opt_trace(f, args, perform_trace=True, *rest): """ Optimized trace with necessary passes which re-compose some ops or replace some ops These passes should be general and functional purpose @@ -148,7 +148,11 @@ def opt_trace(f, args, *rest): replace_inplace_ops, # remove it once functionalization is enabled ] - fx_module, _ = trace(f, args) + if perform_trace: + fx_module, _ = trace(f, args) + else: + fx_module = f + print(fx_module.graph) for passes in passes_list: pr: PassResult = passes(fx_module) diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py new file mode 100644 index 0000000000..bb6e68b0b5 --- /dev/null +++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py @@ -0,0 +1,107 @@ +import torch +import traceback +import torch._dynamo as td + +from torch_tensorrt.fx.fx2trt import ( + InputTensorSpec, + TRTInterpreter, +) +import tensorrt as trt +from torch_tensorrt.fx.tools.trt_splitter import ( + TRTSplitter, + TRTSplitterSetting, +) +from torch_tensorrt.fx.tracer.dispatch_tracer import aten_tracer +from torch_tensorrt.fx.trt_module import TRTModule +from torch_tensorrt.fx.utils import LowerPrecision + +from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler + +MAX_SPLITS_THRESHOLD = 10 + + +def tensorrt_backend(gm, sample_inputs): + # Invoke AOTAutograd to compile model + return aot_module_simplified( + gm, + sample_inputs, + fw_compiler=make_boxed_compiler(fx2trt_compiler), + ) + + +def fx2trt(gm: torch.fx.GraphModule, example_inputs, **kwargs): + model = gm + inputs = example_inputs + + # Perform lowering pass on model + model = aten_tracer.opt_trace(model, inputs, perform_trace=False) + + # Split out unsupported ops --> Needs rewrite/revision for ATEN + splitter_setting = TRTSplitterSetting() + splitter_setting.use_implicit_batch_dim = False + splitter = TRTSplitter(model, inputs, settings=splitter_setting) + + splitter.node_support_preview() + split_mod = splitter() + num_piece = 0 + + for name, _ in split_mod.named_children(): + print(f"Graph is split into {name}") + num_pieces += 1 + + # Select threshold above which segmentation is not beneficial and run graph in Torch + if num_pieces > MAX_SPLITS_THRESHOLD: + raise AssertionError( + f"The graph module is 
split into {num_piece} which is large than the \ + threshold={MAX_SPLITS_THRESHOLD}. Falling back to non-TRT module." + ) + + precision = LowerPrecision.FP32 + + def get_submod_inputs(mod, submod, inputs): + acc_inputs = None + + def get_input(self, inputs): + nonlocal acc_inputs + acc_inputs = inputs + + handle = submod.register_forward_pre_hook(get_input) + mod(*inputs) + handle.remove() + return acc_inputs + + for name, _ in split_mod.named_children(): + if "_run_on_acc" in name: + submod = getattr(split_mod, name) + acc_inputs = get_submod_inputs(split_mod, submod, inputs) + + interp = TRTInterpreter( + submod, + InputTensorSpec.from_tensors(acc_inputs), + explicit_batch_dimension=True, + logger_level=trt.Logger.VERBOSE, + ) + r = interp.run( + max_workspace_size=20 << 30, + lower_precision=precision, + profiling_verbosity=trt.ProfilingVerbosity.VERBOSE, + ) + + trt_mod = TRTModule(*r) + + setattr(split_mod, name, trt_mod) + + return split_mod + + +@td.register_backend +def fx2trt_compiler(gm: torch.fx.GraphModule, example_inputs): + try: + trt_compiled = fx2trt(gm, example_inputs) + return trt_compiled + except Exception: + traceback.print_exc() + print( + "FX2TRT conversion failed on the subgraph. See trace above. Returning GraphModule forward instead" + ) + return gm.forward From 243bf9bc340e27837a33c3d6fc3c0998381aff0a Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Tue, 21 Mar 2023 16:17:51 -0700 Subject: [PATCH 02/45] Add decompositions to aot call --- .../fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py index bb6e68b0b5..a76162b93b 100644 --- a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py +++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py @@ -17,6 +17,9 @@ from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler +from torch._inductor.decomposition import decompositions + +DECOMPOSITIONS = decompositions.copy() MAX_SPLITS_THRESHOLD = 10 @@ -26,6 +29,7 @@ def tensorrt_backend(gm, sample_inputs): gm, sample_inputs, fw_compiler=make_boxed_compiler(fx2trt_compiler), + decompositions=DECOMPOSITIONS, ) From 76fd3c8207bdf017af294f1883863a755045b1a8 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Mon, 27 Mar 2023 15:31:22 -0700 Subject: [PATCH 03/45] Mark FX2TRT converter as fake tensor unsupported --- .../fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py index a76162b93b..20cea4ffd5 100644 --- a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py +++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py @@ -15,6 +15,8 @@ from torch_tensorrt.fx.trt_module import TRTModule from torch_tensorrt.fx.utils import LowerPrecision +from torch._dynamo.backends.common import fake_tensor_unsupported + from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler from torch._inductor.decomposition import decompositions @@ -99,6 +101,7 @@ def get_input(self, inputs): @td.register_backend +@fake_tensor_unsupported def fx2trt_compiler(gm: torch.fx.GraphModule, example_inputs): try: 
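        # Try a full FX -> TensorRT conversion; any failure is caught below and
        # we fall back to returning the original module's forward.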
trt_compiled = fx2trt(gm, example_inputs) From 6a8102c14f3c0fa7a200222979888e9d213d0d84 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Tue, 28 Mar 2023 18:52:12 -0700 Subject: [PATCH 04/45] Minor naming bugfix --- .../fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py index 20cea4ffd5..55c5e2df33 100644 --- a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py +++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py @@ -49,7 +49,7 @@ def fx2trt(gm: torch.fx.GraphModule, example_inputs, **kwargs): splitter.node_support_preview() split_mod = splitter() - num_piece = 0 + num_pieces = 0 for name, _ in split_mod.named_children(): print(f"Graph is split into {name}") @@ -58,7 +58,7 @@ def fx2trt(gm: torch.fx.GraphModule, example_inputs, **kwargs): # Select threshold above which segmentation is not beneficial and run graph in Torch if num_pieces > MAX_SPLITS_THRESHOLD: raise AssertionError( - f"The graph module is split into {num_piece} which is large than the \ + f"The graph module is split into {num_pieces} which is large than the \ threshold={MAX_SPLITS_THRESHOLD}. Falling back to non-TRT module." ) From 5dd1a5002d3ee2c0a6d521641aa85cced7572040 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 4 Apr 2023 17:35:12 -0700 Subject: [PATCH 05/45] feat: Initial refactoring of fx2trt in dynamo namespace Signed-off-by: Dheeraj Peri --- py/torch_tensorrt/dynamo/README.md | 21 + py/torch_tensorrt/dynamo/__init__.py | 15 + py/torch_tensorrt/dynamo/diagnostics.py | 287 ++++++++ py/torch_tensorrt/dynamo/fx2trt.py | 377 +++++++++++ py/torch_tensorrt/dynamo/input_tensor_spec.py | 180 +++++ py/torch_tensorrt/dynamo/lower.py | 306 +++++++++ py/torch_tensorrt/dynamo/lower_setting.py | 92 +++ py/torch_tensorrt/dynamo/observer.py | 194 ++++++ py/torch_tensorrt/dynamo/passes/__init__.py | 0 py/torch_tensorrt/dynamo/passes/graph_opts.py | 74 ++ .../dynamo/passes/lower_basic_pass.py | 632 ++++++++++++++++++ .../dynamo/passes/lower_basic_pass_aten.py | 525 +++++++++++++++ .../passes/lower_pass_manager_builder.py | 325 +++++++++ py/torch_tensorrt/dynamo/passes/pass_utils.py | 296 ++++++++ .../passes/remove_duplicate_output_args.py | 140 ++++ py/torch_tensorrt/dynamo/tools/__init__.py | 1 + .../dynamo/tools/common_fx2trt.py | 445 ++++++++++++ .../dynamo/tools/engine_layer_visualize.py | 217 ++++++ py/torch_tensorrt/dynamo/tools/graph_util.py | 78 +++ .../dynamo/tools/model_packager.py | 126 ++++ .../dynamo/tools/node_profiler.py | 53 ++ py/torch_tensorrt/dynamo/tools/tensor_prop.py | 33 + .../dynamo/tools/timing_cache_utils.py | 39 ++ .../dynamo/tools/trt_minimizer.py | 101 +++ .../dynamo/tools/trt_profiler_sorted.py | 58 ++ .../dynamo/tools/trt_splitter.py | 138 ++++ py/torch_tensorrt/dynamo/trt_module.py | 239 +++++++ py/torch_tensorrt/dynamo/types.py | 24 + py/torch_tensorrt/dynamo/utils.py | 140 ++++ 29 files changed, 5156 insertions(+) create mode 100644 py/torch_tensorrt/dynamo/README.md create mode 100644 py/torch_tensorrt/dynamo/__init__.py create mode 100644 py/torch_tensorrt/dynamo/diagnostics.py create mode 100644 py/torch_tensorrt/dynamo/fx2trt.py create mode 100644 py/torch_tensorrt/dynamo/input_tensor_spec.py create mode 100644 py/torch_tensorrt/dynamo/lower.py create mode 100644 
py/torch_tensorrt/dynamo/lower_setting.py
 create mode 100644 py/torch_tensorrt/dynamo/observer.py
 create mode 100644 py/torch_tensorrt/dynamo/passes/__init__.py
 create mode 100644 py/torch_tensorrt/dynamo/passes/graph_opts.py
 create mode 100644 py/torch_tensorrt/dynamo/passes/lower_basic_pass.py
 create mode 100644 py/torch_tensorrt/dynamo/passes/lower_basic_pass_aten.py
 create mode 100644 py/torch_tensorrt/dynamo/passes/lower_pass_manager_builder.py
 create mode 100644 py/torch_tensorrt/dynamo/passes/pass_utils.py
 create mode 100644 py/torch_tensorrt/dynamo/passes/remove_duplicate_output_args.py
 create mode 100644 py/torch_tensorrt/dynamo/tools/__init__.py
 create mode 100644 py/torch_tensorrt/dynamo/tools/common_fx2trt.py
 create mode 100644 py/torch_tensorrt/dynamo/tools/engine_layer_visualize.py
 create mode 100644 py/torch_tensorrt/dynamo/tools/graph_util.py
 create mode 100644 py/torch_tensorrt/dynamo/tools/model_packager.py
 create mode 100644 py/torch_tensorrt/dynamo/tools/node_profiler.py
 create mode 100644 py/torch_tensorrt/dynamo/tools/tensor_prop.py
 create mode 100644 py/torch_tensorrt/dynamo/tools/timing_cache_utils.py
 create mode 100644 py/torch_tensorrt/dynamo/tools/trt_minimizer.py
 create mode 100644 py/torch_tensorrt/dynamo/tools/trt_profiler_sorted.py
 create mode 100644 py/torch_tensorrt/dynamo/tools/trt_splitter.py
 create mode 100644 py/torch_tensorrt/dynamo/trt_module.py
 create mode 100644 py/torch_tensorrt/dynamo/types.py
 create mode 100644 py/torch_tensorrt/dynamo/utils.py

diff --git a/py/torch_tensorrt/dynamo/README.md b/py/torch_tensorrt/dynamo/README.md
new file mode 100644
index 0000000000..d53f43a1d4
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/README.md
@@ -0,0 +1,21 @@
+FX2TRT is merged as an FX module in Torch-TensorRT
+
+- The user guide is in [link](../../../docsrc/tutorials/getting_started_with_fx_path.rst#installation)
+- The examples are moved to [link](../../../examples/fx)
+
+* Method 1. Follow the instructions for Torch-TensorRT
+* Method 2. To install the FX path only (Python path), avoiding the C++ build for the TorchScript path
+```
+  $ conda create --name python_env python=3.8
+  $ conda activate python_env
+  # We recommend PyTorch 1.12 or later
+  $ conda install pytorch torchvision torchtext cudatoolkit=11.3 -c pytorch-nightly
+  # Install the TensorRT python package
+  $ pip3 install nvidia-pyindex
+  $ pip3 install tensorrt==8.5.1.7
+  $ git clone https://github.com/pytorch/TensorRT.git
+  $ cd TensorRT/py && python setup.py install --fx-only && cd ..
+  $ python -c "import torch_tensorrt.fx"
+  # Test an example:
+  $ python py/torch_tensorrt/fx/example/lower_example.py
+```

diff --git a/py/torch_tensorrt/dynamo/__init__.py b/py/torch_tensorrt/dynamo/__init__.py
new file mode 100644
index 0000000000..8c40ecac76
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/__init__.py
@@ -0,0 +1,15 @@
+import logging
+
+from torch_tensorrt.fx.converter_registry import (  # noqa
+    CONVERTERS,
+    NO_EXPLICIT_BATCH_DIM_SUPPORT,
+    NO_IMPLICIT_BATCH_DIM_SUPPORT,
+    tensorrt_converter,
+)
+from .fx2trt import TRTInterpreter, TRTInterpreterResult  # noqa
+from .input_tensor_spec import InputTensorSpec  # noqa
+from .lower_setting import LowerSetting  # noqa
+from .trt_module import TRTModule  # noqa
+from .lower import compile  # usort: skip #noqa
+
+logging.basicConfig(level=logging.INFO)

diff --git a/py/torch_tensorrt/dynamo/diagnostics.py b/py/torch_tensorrt/dynamo/diagnostics.py
new file mode 100644
index 0000000000..0ba2a30652
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/diagnostics.py
@@ -0,0 +1,287 @@
+import contextlib
+import inspect
+import logging
+import os
+import os.path
+import shutil
+import tempfile
+import time
+import traceback
+import typing as t
+from contextvars import ContextVar
+from dataclasses import dataclass
+
+TWrite = t.Union[str, bytes]
+WriteObj = t.Union[TWrite, t.Callable[[], TWrite]]
+
+_CURRENT_WRITER: ContextVar["DiagnosticsWriter"] = ContextVar("_CURRENT_WRITER")
+_CURRENT_COLLECTOR: ContextVar["DiagnosticsCollector"] = ContextVar(
+    "_CURRENT_COLLECTOR"
+)
+# Allows a collector to indicate subsequent collections should be suppressed to
+# avoid duplicate collections.
+_SUBSEQUENT_COLLECT_SUPPRESSED_BY: ContextVar[object] = ContextVar(
+    "_SUBSEQUENT_COLLECT_SUPPRESSED_BY"
+)
+# Indicates the current execution context is within a context manager created
+# by `collect_when`. Only when it's set do we actually write diagnostics.
+_IS_IN_COLLECT_CONTEXT: ContextVar[bool] = ContextVar("_IS_IN_COLLECT_CONTEXT")
+_LOGGER = logging.getLogger(__name__)
+
+
+@dataclass
+class CollectionConditionContext:
+    exception: t.Optional[Exception]
+
+
+CollectionCondition = t.Callable[[CollectionConditionContext], bool]
+
+
+def collect_when(
+    condition: "CollectionCondition", supress_subsequent_collect: bool = True
+):
+    """See `DiagnosticsCollector.collect_when`"""
+    return get_current_collector().collect_when(condition, supress_subsequent_collect)
+
+
+def collect():
+    return collect_when(CollectionConditions.always())
+
+
+def collect_when_fail():
+    return collect_when(CollectionConditions.when_fail())
+
+
+def write(file_name: str, text: WriteObj):
+    return get_current_writer().write(file_name, text)
+
+
+def get_current_writer() -> "DiagnosticsWriter":
+    """Get the writer for the current execution context.
+
+    Lazily instantiates and registers one if not already done.
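+
+    For example, a module-level call like
+    `write("fx_graph.txt", lambda: str(gm.graph))` (for some `fx.GraphModule`
+    `gm`) is routed through the writer returned here.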
+ """ + current_writer = _CURRENT_WRITER.get(None) + if not current_writer: + current_writer = DiagnosticsWriter() + _CURRENT_WRITER.set(current_writer) + return current_writer + + +def get_current_collector() -> "DiagnosticsCollector": + current_collector = _CURRENT_COLLECTOR.get(None) + if not current_collector: + current_collector = DiagnosticsCollector() + _CURRENT_COLLECTOR.set(current_collector) + return current_collector + + +def set_current_collector(collector: "DiagnosticsCollector"): + _CURRENT_COLLECTOR.set(collector) + + +class DiagnosticsWriter: + + # the root dir in which the diagnostics will be written + _root_dir: str + + def __init__(self): + self._root_dir = tempfile.mkdtemp(prefix="fx2trt.") + _LOGGER.info(f"Initializing DiagnosticsWriter with root_dir: {self._root_dir}") + + def write(self, file_name: str, data: WriteObj): + """ + TODO: Can be disabled by regex on file_name + """ + # Only write if we are inside a collect_when() context. + if not _IS_IN_COLLECT_CONTEXT.get(False): + return + + try: + res, err = _res_or_err(data) + if err: + to_write = err.encode("utf-8") + else: + if isinstance(res, str): + to_write = res.encode("utf-8") + elif isinstance(res, bytes): + to_write = res + else: + raise TypeError(f"Unknown data type: {type(res)}") + self._write(file_name, to_write) + except Exception as e: + # Log the error and swallow the exception, as this should not + # propagated into business logic + _LOGGER.warning(f"Error writing diagnostics: {e}") + + def root_dir(self) -> str: + return self._root_dir + + def _write(self, file_name: str, to_write: bytes): + # ms granularity - no naming collash, otherwise file will be + # overwritten. + ts = int(time.time() * 1000) + file_name = f"{file_name}.{ts}" + fn = os.path.join(self.root_dir(), file_name) + with open(fn, "wb") as f: + f.write(to_write) + + +class CollectionConditions: + @classmethod + def any(cls, *conditions: "CollectionCondition") -> "CollectionCondition": + return lambda ctx: any(cond(ctx) for cond in conditions) + + @classmethod + def all(cls, *conditions: "CollectionCondition") -> "CollectionCondition": + return lambda ctx: all(cond(ctx) for cond in conditions) + + @classmethod + def not_(cls, condition: "CollectionCondition") -> "CollectionCondition": + return lambda ctx: not condition(ctx) + + @classmethod + def always(cls) -> "CollectionCondition": + """Always collect""" + return lambda ctx: True + + @classmethod + def never(cls) -> "CollectionCondition": + """Never collect""" + return lambda ctx: False + + @classmethod + def when_fail(cls) -> "CollectionCondition": + """Collect when failed""" + ctx: CollectionConditionContext + return lambda ctx: ctx.exception is not None + + @classmethod + def when_called_by_function( + cls, func_name: str, match_prefix: bool = False + ) -> "CollectionCondition": + def _when_called_by_function(ctx: CollectionConditionContext) -> bool: + frames = inspect.stack() + for frame in frames: + if match_prefix: + if frame[3].startswith(func_name): + return True + else: + if frame[3] == func_name: + return True + return False + + return _when_called_by_function + + @classmethod + def when_not_in_tests(cls) -> CollectionCondition: + return CollectionConditions.not_( + CollectionConditions.when_called_by_function("test_", match_prefix=True) + ) + + +class DiagnosticsCollector: + @contextlib.contextmanager + def collect_when( + self, condition: "CollectionCondition", supress_subsequent_collect: bool = True + ): + """ + Context manager to collect diagnostics when the enclosed code 
completes + and *any* of the given condition is met. + + Args: + condition: + the condition only when met should the collection be done + supress_subsequent_collect: + When true, suppress any collections registered by this function + call. This is to ensure duplicate collections registered across + the callstack by different components. In this case, only the + outermost component will collect. + + When false, always collect (subject to given condition) regardless + of earlier collection registration's suppression. + + Returns: + a context manager that handles the collection when its enclosed + code finished run. + """ + this_collection_handle = object() + suppressed_by = _SUBSEQUENT_COLLECT_SUPPRESSED_BY.get(None) + reset_suppressed_by = False + if supress_subsequent_collect: + if suppressed_by and suppressed_by != this_collection_handle: + # Disable this collection since it's suppressed by a previously + # installed collection + condition = CollectionConditions.never() + else: + suppressed_by = this_collection_handle + _SUBSEQUENT_COLLECT_SUPPRESSED_BY.set(suppressed_by) + # don't forget to reset it in `finanlly` + reset_suppressed_by = True + + is_in_collect_context_tok = _IS_IN_COLLECT_CONTEXT.set(True) + exception: t.Optional[Exception] = None + try: + yield + except Exception as e: + exception = e + raise + finally: + if reset_suppressed_by: + _SUBSEQUENT_COLLECT_SUPPRESSED_BY.set(None) + if self._test_condition(condition, CollectionConditionContext(exception)): + try: + self.collect() + except Exception as e: + _LOGGER.warning( + f"Error while collecting diagnostics (THIS EXCEPTION IS HANDLED):\n" + f"{e}\n" + f"{traceback.format_exc()}" + ) + _IS_IN_COLLECT_CONTEXT.reset(is_in_collect_context_tok) + + def collect(self) -> str: + """Collect the diagnostics. 
Overridable in sub-classes.""" + return "" + + @classmethod + def _test_condition( + cls, cond: CollectionCondition, ctx: CollectionConditionContext + ) -> bool: + try: + return cond(ctx) + except Exception as e: + _LOGGER.warning(f"Error while testing condition: {e}") + return False + + +class ZipDiagnosticsCollector(DiagnosticsCollector): + _write: DiagnosticsWriter + _last_zip_path_for_test: str = "" # for test purpose only + + def __init__(self, writer: DiagnosticsWriter): + self._write = writer + + def collect(self) -> str: + _, fp = tempfile.mkstemp() + try: + zip_path = shutil.make_archive(fp, "zip", self._write.root_dir()) + self._last_zip_path_for_test = zip_path + return zip_path + finally: + os.remove(fp) + + +def _res_or_err(data: WriteObj) -> t.Tuple[TWrite, str]: + if isinstance(data, (str, bytes)): + return data, "" + if not callable(data): + raise TypeError( + f"data must be a callable that returns actual data to" + f"write, but got {type(data)}" + ) + try: + return data(), "" + except Exception as e: + _LOGGER.warning(f"Error getting data to write: {e}") + return "", str(e) diff --git a/py/torch_tensorrt/dynamo/fx2trt.py b/py/torch_tensorrt/dynamo/fx2trt.py new file mode 100644 index 0000000000..4140a344f0 --- /dev/null +++ b/py/torch_tensorrt/dynamo/fx2trt.py @@ -0,0 +1,377 @@ +import logging +import warnings +from datetime import datetime +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence + +import numpy + +# @manual=//deeplearning/trt/python:py_tensorrt +import tensorrt as trt +import torch +import torch.fx +from torch._ops import OpOverload +from torch.fx.node import _get_qualified_name +from torch.fx.passes.shape_prop import TensorMetadata + +from torch_tensorrt.dynamo import CONVERTERS +from .input_tensor_spec import InputTensorSpec +from .observer import Observer +from .utils import get_dynamic_dims, LowerPrecision, torch_dtype_to_trt + +_LOGGER: logging.Logger = logging.getLogger(__name__) + +TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[ + Callable[[torch.fx.GraphModule], None] +] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER") + + +class TRTInterpreterResult(NamedTuple): + engine: Any + input_names: Sequence[str] + output_names: Sequence[str] + serialized_cache: bytearray + + +class TRTInterpreter(torch.fx.Interpreter): + def __init__( + self, + module: torch.fx.GraphModule, + input_specs: List[InputTensorSpec], + explicit_batch_dimension: bool = False, + explicit_precision: bool = False, + logger_level=None, + ): + super().__init__(module) + + self.logger = trt.Logger(logger_level or trt.Logger.WARNING) + self.builder = trt.Builder(self.logger) + + flag = 0 + if explicit_batch_dimension: + EXPLICIT_BATCH = 1 << (int)( + trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH + ) + flag |= EXPLICIT_BATCH + + if explicit_precision: + EXPLICIT_PRECISION = 1 << (int)( + trt.NetworkDefinitionCreationFlag.EXPLICIT_PRECISION + ) + flag |= EXPLICIT_PRECISION + self.network = self.builder.create_network(flag) + + missing_ops = self.validate_conversion() + if missing_ops: + warnings.warn( + "Interpretation will fail due to missing operations \n" + + "\n".join(f"{i}" for i in missing_ops) + ) + + self.optimization_profiles: Optional[List] = None + self.input_specs = input_specs + self.input_specs_iter = 0 + self.validate_input_specs() + self._cur_node_name: Optional[str] = None + self._input_names: List[str] = [] + self._output_names: List[str] = [] + self._itensor_to_tensor_meta: Dict[ + trt.tensorrt.ITensor, TensorMetadata + ] = dict() + + def 
validate_input_specs(self): + for shape, _, _, shape_ranges, has_batch_dim in self.input_specs: + if not self.network.has_implicit_batch_dimension: + assert ( + has_batch_dim + ), "It's required to specify batch dimension when it's explicit in TensorRT network." + + dynamic_dims = get_dynamic_dims(shape) + if len(dynamic_dims): + assert not self.network.has_implicit_batch_dimension, ( + "Can't have dynamic dim when " + f"batch dim is implicit, got {shape}." + ) + assert len( + shape_ranges + ), "shape_ranges must be provided when shape has dynamic dim." + + if self.optimization_profiles: + assert len(shape_ranges) == len(self.optimization_profiles), ( + "Number of optimization " + f"profiles {len(self.optimization_profiles)} doesn't match with the number of shape_range" + f" {len(shape_ranges)} provided." + ) + else: + self.optimization_profiles = [ + self.builder.create_optimization_profile() + for _ in range(len(shape_ranges)) + ] + + for shape_range in shape_ranges: + assert ( + len(shape_range) == 3 + ), f"Expect three elements in shape_range, got {len(shape_range)}" + assert all(len(s) == len(shape) for s in shape_range), ( + "Expect elements in shape_range" + f" {shape_range} have the same number of dimension as the provided shape {len(shape)}" + ) + + for i in range(len(shape)): + if i in dynamic_dims: + assert all( + shape_range[j][i] <= shape_range[j + 1][i] + for j in range(2) + ), ( + "Expect dynamic dim" + f" {i} to have incremental value for shapes in shape_range {shape_range}." + ) + else: + assert all(s[i] == shape[i] for s in shape_range), ( + f"Expect non dynamic dim {i} to be the same" + f" for all shapes in shape_range {shape_range}." + ) + else: + assert ( + len(shape_ranges) == 0 + ), "shape_ranges are provided for input that doesn't have dynamic dim." + + def validate_conversion(self): + missing_converter = set() + + for node in self.module.graph.nodes: + if node.op == "call_function" and not CONVERTERS.get(node.target): + missing_converter.add(f"{node.op} {_get_qualified_name(node.target)}") + elif node.op == "call_method" and not CONVERTERS.get(node.target): + missing_converter.add(f"{node.op} torch.Tensor.{node.target}") + elif node.op == "call_module": + submod = self.fetch_attr(node.target) + submod_type = getattr(submod, "_base_class_origin", type(submod)) + if not CONVERTERS.get(submod_type): + missing_converter.add(f"{node.op} {torch.typename(submod_type)}") + + return missing_converter + + def run( + self, + max_workspace_size=1 << 25, + lower_precision=LowerPrecision.FP16, + sparse_weights=False, + force_fp32_output=False, + strict_type_constraints=False, + algorithm_selector=None, + timing_cache=None, + profiling_verbosity=None, + tactic_sources=None, + ) -> TRTInterpreterResult: + """ + Build TensorRT engine with some configs. + Args: + max_workspace_size: set to the maximum size we can afford for temporary buffer + lower_precision: the precision model layers are running on (TensorRT will choose the best perforamnce precision). + sparse_weights: allow the builder to examine weights and use optimized functions when weights have suitable sparsity + force_fp32_output: force output to be fp32 + strict_type_constraints: Usually we should set it to False unless we want to control the precision of certain layer for numeric reasons. 
+ algorithm_selector: set up algorithm selection for certain layer + timing_cache: enable timing cache for TensorRT + profiling_verbosity: TensorRT logging level + Return: + TRTInterpreterResult + """ + TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module) + + # For float outputs, we set their dtype to fp16 only if lower_precision == LowerPrecision.FP16 and + # force_fp32_output=False. + self.output_fp16 = ( + not force_fp32_output and lower_precision == LowerPrecision.FP16 + ) + + if ( + lower_precision == LowerPrecision.INT8 + and not self.builder.platform_has_fast_int8 + ): + raise RuntimeError("Current platform doesn't support fast native int8!") + + if ( + lower_precision == LowerPrecision.FP16 + and not self.builder.platform_has_fast_fp16 + ): + warnings.warn("Current platform doesn't support fast native fp16!") + + self.input_specs_iter = 0 + run_module_start_time = datetime.now() + super().run() + _LOGGER.info( + f"TRT INetwork construction elapsed time: {datetime.now() - run_module_start_time}" + ) + build_engine_start_time = datetime.now() + + builder_config = self.builder.create_builder_config() + builder_config.max_workspace_size = max_workspace_size + + cache = None + if timing_cache: + cache_file = numpy.array(timing_cache) + cache = builder_config.create_timing_cache(cache_file.tobytes()) + else: + cache = builder_config.create_timing_cache(b"") + builder_config.set_timing_cache(cache, False) + + if trt.__version__ >= "8.2": + builder_config.profiling_verbosity = ( + profiling_verbosity + if profiling_verbosity + else trt.ProfilingVerbosity.LAYER_NAMES_ONLY + ) + if lower_precision == LowerPrecision.FP16: + builder_config.set_flag(trt.BuilderFlag.FP16) + + if lower_precision == LowerPrecision.INT8: + builder_config.set_flag(trt.BuilderFlag.INT8) + + if sparse_weights: + builder_config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS) + + if strict_type_constraints: + builder_config.set_flag(trt.BuilderFlag.STRICT_TYPES) + + if self.optimization_profiles: + for optimization_profile in self.optimization_profiles: + builder_config.add_optimization_profile(optimization_profile) + + if algorithm_selector: + builder_config.set_flag(trt.BuilderFlag.DISABLE_TIMING_CACHE) + builder_config.algorithm_selector = algorithm_selector + + if tactic_sources is not None: + builder_config.set_tactic_sources(tactic_sources=tactic_sources) + + engine = self.builder.build_engine(self.network, builder_config) + assert engine + + serialized_cache = ( + bytearray(cache.serialize()) + if builder_config.get_timing_cache() + else bytearray() + ) + _LOGGER.info( + f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}" + ) + + return TRTInterpreterResult( + engine, self._input_names, self._output_names, serialized_cache + ) + + def run_node(self, n): + self._cur_node_name = str(n) + # add "_itensor_to_tensor_meta" + kwargs = dict(n.kwargs) + kwargs["_itensor_to_tensor_meta"] = self._itensor_to_tensor_meta + n.kwargs = kwargs + + # run the node + trt_node = super().run_node(n) + + # remove "_itensor_to_tensor_meta" + kwargs = dict(n.kwargs) + del kwargs["_itensor_to_tensor_meta"] + n.kwargs = kwargs + + if isinstance(trt_node, trt.tensorrt.ITensor): + self._itensor_to_tensor_meta[trt_node] = n.meta.get("tensor_meta") + + return trt_node + + def placeholder(self, target, args, kwargs): + self._input_names.append(target) + shape, dtype, _, shape_ranges, has_batch_dim = self.input_specs[ + self.input_specs_iter + ] + self.input_specs_iter += 1 + + if self.network.has_implicit_batch_dimension: + 
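+            # Implicit batch mode: TensorRT supplies the batch dimension
+            # itself, so strip it from the declared input shape.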
if has_batch_dim: + shape = shape[1:] + else: + for i, shape_range in enumerate(shape_ranges): + assert self.optimization_profiles + self.optimization_profiles[i].set_shape(target, *shape_range) + + return self.network.add_input( + name=target, shape=tuple(shape), dtype=torch_dtype_to_trt(dtype) + ) + + def call_module(self, target, args, kwargs): + assert isinstance(target, str) + submod = self.fetch_attr(target) + submod_type = getattr(submod, "_base_class_origin", type(submod)) + converter = CONVERTERS.get(submod_type) + + if not converter: + raise RuntimeError( + f"Conversion of module of type {submod_type} not currently supported!" + ) + + assert self._cur_node_name is not None + return converter(self.network, submod, args, kwargs, self._cur_node_name) + + def call_function(self, target, args, kwargs): + converter = CONVERTERS.get(target) + if not converter: + raise RuntimeError( + f"Conversion of function {torch.typename(target)} not currently supported!" + ) + + assert self._cur_node_name is not None + return converter(self.network, target, args, kwargs, self._cur_node_name) + + def call_method(self, target, args, kwargs): + assert isinstance(target, str) + converter = CONVERTERS.get(target) + + if not converter: + raise RuntimeError( + f"Conversion of method {target} not currently supported!" + ) + + assert self._cur_node_name is not None + return converter(self.network, target, args, kwargs, self._cur_node_name) + + def output(self, target, args, kwargs): + assert len(args) == 1 + if isinstance(args[0], tuple): + outputs = args[0] + elif isinstance(args[0], list): + outputs = tuple(args[0]) + else: + outputs = (args[0],) + + if not all(isinstance(output, trt.tensorrt.ITensor) for output in outputs): + raise RuntimeError("TensorRT requires all outputs to be Tensor!") + + for i, output in enumerate(outputs): + if any( + op_name in output.name.split("_") + for op_name in ( + "eq", + "gt", + "lt", + "or", + "xor", + "and", + "not", + "ne", + "isinf", + "any", + ) + ): + output_bool = True + else: + output_bool = False + name = f"output{i}" + output.name = name + self.network.mark_output(output) + if output_bool: + output.dtype = trt.bool + elif self.output_fp16 and output.dtype == trt.float32: + output.dtype = trt.float16 + self._output_names.append(name) diff --git a/py/torch_tensorrt/dynamo/input_tensor_spec.py b/py/torch_tensorrt/dynamo/input_tensor_spec.py new file mode 100644 index 0000000000..1e64c31c59 --- /dev/null +++ b/py/torch_tensorrt/dynamo/input_tensor_spec.py @@ -0,0 +1,180 @@ +from typing import Iterable, List, NamedTuple, Optional, Sequence, Tuple + +import torch + +from .types import Shape, ShapeRange +from .utils import get_dynamic_dims +from torch_tensorrt._Input import Input + +class InputTensorSpec(NamedTuple): + """ + This class contains the information of a input tensor. + + shape: shape of the tensor. + + dtype: dtyep of the tensor. + + device: device of the tensor. This is only used to generate inputs to the given model + in order to run shape prop. For TensorRT engine, inputs have to be on cuda device. + + shape_ranges: If dynamic shape is needed (shape has dimensions of -1), then this field + has to be provided (default is empty list). Every shape_range is a tuple of three + tuples ((min_input_shape), (optimized_input_shape), (max_input_shape)). Each shape_range + is used to populate a TensorRT optimization profile. + e.g. 
If the input shape varies from (1, 224) to (100, 224) and we want to optimize
+        for (25, 224) because it's the most common input shape, then we set shape_ranges to
+        ((1, 224), (25, 224), (100, 224)).
+
+    has_batch_dim: Whether the shape includes batch dimension. Batch dimension has to be provided
+        if the engine wants to run with dynamic shape.
+    """
+
+    shape: Shape
+    dtype: torch.dtype
+    device: torch.device = torch.device("cpu")
+    shape_ranges: List[ShapeRange] = []
+    has_batch_dim: bool = True
+
+    @classmethod
+    def from_tensor(cls, tensor: torch.Tensor) -> "InputTensorSpec":
+        """
+        Produce an InputTensorSpec named tuple which contains the
+        information of the given PyTorch tensor.
+
+        Args:
+            tensor (torch.Tensor): A PyTorch tensor.
+
+        Returns:
+            An InputTensorSpec named tuple.
+        """
+        return cls(tensor.shape, tensor.dtype, tensor.device)
+
+    @classmethod
+    def from_tensors(cls, tensors: Sequence[torch.Tensor]) -> List["InputTensorSpec"]:
+        """
+        Produce a list of InputTensorSpec named tuples which contain
+        the information of all the given PyTorch tensors.
+
+        Args:
+            tensors (Iterable[torch.Tensor]): A list of PyTorch tensors.
+
+        Returns:
+            A list of InputTensorSpec named tuples.
+        """
+        assert isinstance(tensors, (list, tuple))
+        return [cls.from_tensor(t) for t in tensors]
+
+    @classmethod
+    def from_input(cls, input_obj: Input) -> "InputTensorSpec":
+        """
+        Produce an InputTensorSpec named tuple from a Torch-TensorRT Input
+        object, expanding a min/opt/max shape dict into a dynamic shape with
+        the corresponding shape_ranges.
+
+        Args:
+            input_obj (torch_tensorrt.Input): A Torch-TensorRT Input object.
+
+        Returns:
+            An InputTensorSpec named tuple.
+        """
+        assert isinstance(input_obj, Input)
+        input_spec = None
+        if isinstance(input_obj.shape, dict):
+            min_shape = input_obj.shape["min_shape"]
+            opt_shape = input_obj.shape["opt_shape"]
+            max_shape = input_obj.shape["max_shape"]
+            dyn_shape = []
+            for min, opt, max in zip(min_shape, opt_shape, max_shape):
+                if min == opt == max:
+                    dyn_shape.append(min)
+                else:
+                    dyn_shape.append(-1)
+            dtype = input_obj.torch_dtype
+            input_spec = cls(
+                shape=dyn_shape,
+                dtype=dtype,
+                shape_ranges=[(min_shape, opt_shape, max_shape)],
+            )
+        else:
+            shape = input_obj.shape
+            dtype = input_obj.torch_dtype
+            input_spec = cls(shape=shape, dtype=dtype)
+
+        return input_spec
+
+    @classmethod
+    def from_tensors_with_dynamic_batch_size(
+        cls,
+        tensors: Sequence[torch.Tensor],
+        batch_size_range: Tuple[int, int, int],
+        opt_profile_replica: int = 1,
+        batch_dims: Optional[List[int]] = None,
+    ) -> List["InputTensorSpec"]:
+        """
+        Produce a list of InputTensorSpec named tuples which would contain
+        the information of all the given PyTorch tensors. The produced input
+        tensor specs will treat all tensors' first dimension as batch dimension
+        and mark them as dynamic.
+
+        Args:
+            tensors (Sequence[torch.Tensor]): A list of PyTorch tensors.
+            batch_size_range (Tuple[int, int, int]): The first integer indicates
+                the smallest batch size allowed. The second integer indicates
+                the batch size that we'll optimize for. The third integer indicates
+                the largest batch size allowed.
+            opt_profile_replica (int): If dynamic shape is enabled, each execution
+                context requires a different optimization profile. This arg determines
+                how many optimization profile replicas we want to produce.
+            batch_dims (Optional[List[int]]): The batch dim might not be the leading dim;
+                this arg lets the user specify the batch dim of each tensor explicitly.
+                By default we treat dim 0 as the batch dim.
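+
+        Example (hypothetical tensor; batch dim 0 is marked dynamic over (1, 8, 32)):
+            >>> specs = InputTensorSpec.from_tensors_with_dynamic_batch_size(
+            ...     [torch.randn(8, 3, 224, 224)], batch_size_range=(1, 8, 32)
+            ... )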
+ + Returns: + A list of InputTensorSpec named tuples with dynamic ranges. + """ + if batch_dims is None: + batch_dims = [0] * len(tensors) + + input_specs = [] + batch_size = tensors[0].size(batch_dims[0]) + + for i, tensor in enumerate(tensors): + batch_dim = batch_dims[i] + assert batch_size == tensor.size( + batch_dim + ), f"The {i}th tensor (shape: {tensor.shape}) doesn't have the correct batch size: {batch_size}." + shape = list(tensor.shape) + shape[batch_dim] = -1 + shape_ranges: List[ShapeRange] = [tuple(tuple(shape[0:batch_dim] + [bs] + shape[batch_dim + 1 :]) for bs in batch_size_range)] * opt_profile_replica # type: ignore[list-item] + input_specs.append( + cls(tuple(shape), tensor.dtype, tensor.device, shape_ranges) + ) + + return input_specs + + def to_random_tensor(self, id=1): + shape = tuple(self.shape) + if len(get_dynamic_dims(shape)): + # id=0 -> min shape + # id=1 -> optimal shape + # id=2 -> max shape + shape = tuple(self.shape_ranges[0][id]) + elif not self.has_batch_dim: + shape = (1,) + tuple(shape) + + return torch.randn(shape).to(dtype=self.dtype, device=self.device) + + @staticmethod + def create_inputs_from_specs(input_specs: Iterable["InputTensorSpec"]): + inputs = [] + for spec in input_specs: + inputs.append(spec.to_random_tensor()) + + return inputs + + @staticmethod + def create_inputs_from_max_specs(input_specs: Iterable["InputTensorSpec"]): + inputs = [] + for spec in input_specs: + inputs.append(spec.to_random_tensor(2)) + + return inputs diff --git a/py/torch_tensorrt/dynamo/lower.py b/py/torch_tensorrt/dynamo/lower.py new file mode 100644 index 0000000000..75e0de9fe8 --- /dev/null +++ b/py/torch_tensorrt/dynamo/lower.py @@ -0,0 +1,306 @@ +import dataclasses as dc +import logging +from typing import Any, Callable, Optional, Sequence + +# @manual=//deeplearning/trt/python:py_tensorrt +import tensorrt as trt +import torch +import torch.fx as fx +import torch.nn as nn +import torch_tensorrt.fx.tracer.dispatch_tracer.aten_tracer as aten_tracer +from torch.fx.passes.splitter_base import SplitResult + +from .fx2trt import TRTInterpreter, TRTInterpreterResult +from .lower_setting import LowerSetting +from .passes.lower_pass_manager_builder import LowerPassManagerBuilder +from .passes.pass_utils import PassFunc, validate_inference +from .tools.timing_cache_utils import TimingCacheManager +from .tools.trt_splitter import TRTSplitter, TRTSplitterSetting + +from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer +from .trt_module import TRTModule +from .utils import LowerPrecision + +logger = logging.getLogger(__name__) + +Input = Sequence[Any] + + +def compile( + module: nn.Module, + inputs, + min_block_size: int = 10, + max_workspace_size=1 << 25, + lower_precision=LowerPrecision.FP16, + verbose_log=False, + timing_cache_prefix="", + save_timing_cache=False, + cuda_graph_batch_size=-1, + is_aten=False, + use_experimental_fx_rt=False, +) -> nn.Module: + """ + Takes in original module, input and lowering setting, run lowering workflow to turn module + into lowered module, or so called TRTModule. + + Args: + module: Original module for lowering. + input: Input for module. + min_block_size: Minimal number of nodes for an accelerated submodule + max_workspace_size: Maximum size of workspace given to TensorRT. + lower_precision: lower_precision config given to TRTModule. + verbose_log: Enable verbose log for TensorRT if set True. + timing_cache_prefix: Timing cache file name for timing cache used by fx2trt. 
+ save_timing_cache: Update timing cache with current timing cache data if set to True. + cuda_graph_batch_size: Cuda graph batch size, default to be -1. + use_experimental_fx_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++). + Returns: + A torch.nn.Module lowered by TensorRT. + """ + if use_experimental_fx_rt and not explicit_batch_dimension: + raise ValueError( + "The experimental unifed runtime only supports explicit batch. Please make sure to set explicit_batch_dimension=True when use_experimental_fx_rt=True" + ) + + lower_setting = LowerSetting( + min_block_size=min_block_size, + max_workspace_size=max_workspace_size, + lower_precision=lower_precision, + verbose_log=verbose_log, + timing_cache_prefix=timing_cache_prefix, + save_timing_cache=save_timing_cache, + cuda_graph_batch_size=cuda_graph_batch_size, + is_aten=is_aten, + use_experimental_rt=use_experimental_fx_rt, + ) + lowerer = Lowerer.create(lower_setting=lower_setting) + return lowerer(module, inputs) + + +@dc.dataclass +class LowerTrtInterpreter: + lower_setting: LowerSetting + timing_cache_manager: TimingCacheManager + + @classmethod + def create(cls, lower_setting): + timing_cache_manager = TimingCacheManager( + lower_setting.timing_cache_prefix, lower_setting.save_timing_cache + ) + return LowerTrtInterpreter(lower_setting, timing_cache_manager) + + def __call__(self, mod, input, split_name) -> TRTInterpreterResult: + assert self.lower_setting.input_specs, "Can't find input specs for lowering!" + logger.info( + f"split_name={split_name}, input_specs={self.lower_setting.input_specs}" + ) + + # Prepare algorithm selector and timing_cache for TRTInterpreter + algo_selector = None + if self.lower_setting.algo_selector: + algo_selector = self.lower_setting.algo_selector(f"{split_name}.json") + cache_data = None + if self.timing_cache_manager: + try: + cache_data = self.timing_cache_manager.get_timing_cache_trt(split_name) + logger.info("Timing cache is used!") + except Exception as e: + logger.warning(f"Cannot load timing cache for {split_name}: {str(e)}") + cache_data = None + + interpreter = TRTInterpreter( + mod, + input_specs=self.lower_setting.input_specs, + explicit_batch_dimension=self.lower_setting.explicit_batch_dimension, + explicit_precision=self.lower_setting.explicit_precision, + logger_level=trt.Logger.VERBOSE + if self.lower_setting.verbose_log + else trt.Logger.WARNING, + ) + + interp_result: TRTInterpreterResult = interpreter.run( + max_workspace_size=self.lower_setting.max_workspace_size, + lower_precision=self.lower_setting.lower_precision, + strict_type_constraints=self.lower_setting.strict_type_constraints, + algorithm_selector=algo_selector, + timing_cache=cache_data, + profiling_verbosity=trt.ProfilingVerbosity.DETAILED + if self.lower_setting.verbose_profile + else trt.ProfilingVerbosity.LAYER_NAMES_ONLY, + tactic_sources=self.lower_setting.tactic_sources, + ) + + # Update timing cache file if needed + timing_cache = interp_result.serialized_cache + if timing_cache and self.timing_cache_manager: + self.timing_cache_manager.update_timing_cache(split_name, timing_cache) + + return interp_result + + +def default_split_function( + model: fx.GraphModule, inputs: Input, lower_setting: LowerSetting +) -> SplitResult: + splitter_setting = TRTSplitterSetting() + splitter_setting.use_implicit_batch_dim = not lower_setting.explicit_batch_dimension + splitter_setting.min_block_size = lower_setting.min_block_size + 
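    # Carry the experimental-runtime choice (TRTModuleNext vs. the Python
    # TRTModule) through to the splitter settings.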
splitter_setting.use_experimental_rt = lower_setting.use_experimental_rt + splitter = TRTSplitter(model, inputs, settings=splitter_setting) + splitter.node_support_preview() + return splitter.generate_split_results() + + +def create_lower_trt_interpreter(lower_setting: LowerSetting) -> LowerTrtInterpreter: + return LowerTrtInterpreter.create(lower_setting) + + +def default_lower_pass( + create_trt_interpreter: Callable[[LowerSetting], LowerTrtInterpreter], +) -> PassFunc: + def lower_pass( + mod: nn.Module, input: Input, lower_setting: LowerSetting, module_name: str + ) -> nn.Module: + """ + Create a module transformation pass which lowers an `fx.GraphModule` into a + `TRTModule` + """ + interpreter = create_trt_interpreter(lower_setting) + interp_res: TRTInterpreterResult = interpreter(mod, input, module_name) + if lower_setting.use_experimental_rt: + import io + + from torch_tensorrt._Device import Device + from torch_tensorrt._TRTModuleNext import TRTModuleNext + + with io.BytesIO() as engine_bytes: + engine_bytes.write(interp_res.engine.serialize()) + engine_str = engine_bytes.getvalue() + + trt_module = TRTModuleNext( + engine_str, + name=module_name, + input_binding_names=interp_res.input_names, + output_binding_names=interp_res.output_names, + target_device=Device(f"cuda:{torch.cuda.current_device()}"), + # cuda_graph_batch_size=lower_setting.cuda_graph_batch_size, # NOTE: Not sure what this is supposed to do + ) + return trt_module + + else: + trt_module = TRTModule( + engine=interp_res.engine, + input_names=interp_res.input_names, + output_names=interp_res.output_names, + cuda_graph_batch_size=lower_setting.cuda_graph_batch_size, + ) + return trt_module + + return lower_pass + + +@dc.dataclass(frozen=True) +class Lowerer: + """Lowers a module using fx2trt. + + This is a composable class to facilitate fx2trt. A normal fx2trt process + composes of the following passes to transform an `fx.GraphModule`: + + 1. trace - use torch.fx to trace the module so we can get the graph + representation of the model. + 2. split - the graph module is split into several submodules, + running either via TensorRT, or via regular CUDA. + + For each split that need to run via TRT, the following passes are + invoked: + + 3. `TRTInterpreter` - build the TRT engine for the submodule that + can be supported through `TRTInterpreter`. + 4. Wraps the executable TRT engine into `TRTModule`, which is an `nn.Module`. + 5. 
The converted submodule is then set back onto the top-level module + + """ + + lower_pass_manager_builder: LowerPassManagerBuilder + + @classmethod + def create( + cls, + lower_setting: LowerSetting, + interpreter_builder: Callable = create_lower_trt_interpreter, + split_func: Callable = default_split_function, + ) -> "Lowerer": + """Instantiate a `Lowerer` instance.""" + if not lower_setting.is_aten: + return cls( + lower_pass_manager_builder=LowerPassManagerBuilder( + lower_setting=lower_setting, + trace_func=lambda module, inputs: acc_tracer.trace( + module, + inputs, # type: ignore[arg-type] + ast_rewriter_allow_list=lower_setting.ast_rewriter_allow_list, + leaf_module_list=lower_setting.leaf_module_list, + ), + split_func=split_func, + lower_func=default_lower_pass(interpreter_builder), + ) + ) + # proxytensor_trace + else: + return cls( + lower_pass_manager_builder=LowerPassManagerBuilder( + lower_setting=lower_setting, + trace_func=lambda module, inputs: aten_tracer.opt_trace( + module, inputs + ), + split_func=split_func, + lower_func=default_lower_pass(interpreter_builder), + ) + ) + + def __call__( + self, + module: nn.Module, + inputs: Input, + additional_inputs: Optional[Input] = None, + fp16_conversion_fn: Optional[Callable[[Input], Input]] = None, + ) -> nn.Module: + lower_setting = self.lower_pass_manager_builder.lower_setting + atol = lower_setting.correctness_atol + rtol = lower_setting.correctness_rtol + + @validate_inference( + atol=atol, + rtol=rtol, + ) + def do_lower(module: nn.Module, inputs: Input) -> nn.Module: + module.eval() + if ( + self.lower_pass_manager_builder.lower_setting.lower_precision + == LowerPrecision.FP16 + ): + module.half() + # A custom conversion function can be passed to the lowerer to + # handle inputs with custom types. By default, just handle + # tensors and NoneType. + if fp16_conversion_fn is None: + conversion_fn = ( + lambda x: x.half() + if x is not None and x.dtype == torch.float32 + else x + ) + else: + conversion_fn = fp16_conversion_fn + + inputs = tuple(conversion_fn(x) for x in inputs) + if lower_setting.is_aten: + pm = self.lower_pass_manager_builder.build_aten2trt_lower_pipeline( + inputs, additional_inputs + ) + else: + pm = self.lower_pass_manager_builder.build_trt_lower_pipeline( + inputs, additional_inputs + ) + lower_result = pm(module) + return lower_result + + return do_lower(module, inputs) diff --git a/py/torch_tensorrt/dynamo/lower_setting.py b/py/torch_tensorrt/dynamo/lower_setting.py new file mode 100644 index 0000000000..d72709c75b --- /dev/null +++ b/py/torch_tensorrt/dynamo/lower_setting.py @@ -0,0 +1,92 @@ +import dataclasses as dc +from typing import List, Optional, Set, Type + +from torch import nn +from torch.fx.passes.pass_manager import PassManager + +from .input_tensor_spec import InputTensorSpec +from .passes.lower_basic_pass import fuse_permute_linear, fuse_permute_matmul +from .utils import LowerPrecision + + +@dc.dataclass +class LowerSettingBasic: + """ + Basic class for lowering. + lower_precision: lower precision dtype during lowering. + min_block_size(int): The minimum number of contiguous TensorRT convertable nodes in order to run them in TensorRT + ast_rewriter_allow_list (Optional[Set[nn.Module]]): Optional allow list of + modules that need AST rewriting. This is aiming to eliminate input variable involve in + exception checking control flow. + leaf_module_list (Optional[Set[nn.Module]]): Optional leaf module list where + modules will not be traced into. 
+ verbose_profile (bool): verbosity of profiler, default to False. + """ + + lower_precision: LowerPrecision = LowerPrecision.FP32 + min_block_size: int = 3 + ast_rewriter_allow_list: Optional[Set[Type[nn.Module]]] = None + leaf_module_list: Optional[Set[Type[nn.Module]]] = None + verbose_profile: bool = False + is_aten: bool = False + + +@dc.dataclass +class LowerSetting(LowerSettingBasic): + """ + Basic configuration for lowering stack. + Args: + input_specs: Specs for inputs to engine, can either be a single size or a + range defined by Min, Optimal, Max sizes. + explicit_precision: Use explicit precision during lowering. + max_workspace_size: The maximum workspace size. The maximum GPU temporary + memory which the TensorRT engine can use at execution time. + strict_type_constraints: Require TensorRT engine to strictly follow data type + setting at execution time. + customized_fuse_pass: List of custmozied pass to apply during lowering process. + lower_basic_fuse_pass: Enable basic pass fuse duirng lowering, i.e. fuse multiple operations + as (a->b->c->d)=>(e). Current basic fuse patterns are: + permute->linear + permute->matmul + verbose_log: Enable TensorRT engine verbose log mode. + algo_selector: Enable TensorRT algorithm selector at execution time. + timing_cache_prefix: TensorRT timing cache file path. TensorRT engine will use timing + cache file at execution time if valid timing cache file is provided. + save_timing_cache: Save updated timing cache data into timing cache file if the timing + cache file is provided. + cuda_graph_batch_size (int): Cuda graph batch size, default to be -1. + preset_lowerer (str): when specified, use a preset logic to build the + instance of Lowerer. + only used by explicit batch dim with dynamic shape mode. In general, we use 2 GPU setting with + 2 stream on each. Set total number to 8 as a safe default value. + tactic_sources: tactic sources for TensorRT kernel selection. Default to None, + meaning all possible tactic sources. + correctness_atol: absolute tolerance for correctness check + correctness_rtol: relative tolerance for correctness check + use_experimental_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++). 
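+
+    Example (a minimal, hypothetical configuration):
+        >>> setting = LowerSetting(lower_precision=LowerPrecision.FP16, min_block_size=5)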
+ """ + + input_specs: List[InputTensorSpec] = dc.field(default_factory=list) + explicit_batch_dimension: bool = True + explicit_precision: bool = False + max_workspace_size: int = 1 << 30 + strict_type_constraints: bool = False + customized_fuse_pass: PassManager = dc.field( + default_factory=lambda: PassManager.build_from_passlist([]) + ) + lower_basic_fuse_pass: PassManager = dc.field( + default_factory=lambda: PassManager.build_from_passlist( + [fuse_permute_matmul, fuse_permute_linear] + ) + ) + verbose_log: bool = False + algo_selector = None + timing_cache_prefix: str = "" + save_timing_cache: bool = False + cuda_graph_batch_size: int = -1 + preset_lowerer: str = "" + opt_profile_replica: int = 8 + tactic_sources: Optional[int] = None + correctness_atol: float = 0.1 + correctness_rtol: float = 0.1 + use_experimental_rt: bool = False diff --git a/py/torch_tensorrt/dynamo/observer.py b/py/torch_tensorrt/dynamo/observer.py new file mode 100644 index 0000000000..3742bd2840 --- /dev/null +++ b/py/torch_tensorrt/dynamo/observer.py @@ -0,0 +1,194 @@ +import contextlib +import functools +import logging +import traceback +import typing as t +from contextvars import ContextVar +from dataclasses import dataclass, field + +_LOGGER = logging.getLogger(__name__) + +# A context variable to hold registered callbacks for all the observers for the +# current execution context. The callbacks list could have been a member +# variable on the observer instance, however, contextvars document advice +# against creating context variables not at module-global level. +# https://docs.python.org/3/library/contextvars.html#contextvars.ContextVar +_CALLBACKS: ContextVar[t.Dict["Observer", t.List[t.Callable]]] = ContextVar( + "_CALLBACKS", default=None +) + +TObserverCallback = t.TypeVar("TObserverCallback", bound=t.Callable[..., t.Any]) + +# Whether to rethrow the exception caught while calling observer callbacks. +# Default to False. True is only used during tests. +RETHROW_CALLBACK_EXCEPTION: bool = False + + +@dataclass(frozen=True) +class Observer(t.Generic[TObserverCallback]): + """ + Usage: + + >>> some_observer: Observer = ... + >>> with some_observer.add(callback_func): + >>> # do stuff, and when some_observer.observe() is called, + >>> # it will execute callback_func() + >>> ... + + """ + + name: str = "" + # Ensure each Observer instance is considered a distinct key when stored in + # the `_CALLBACKS` dictionary. + unique_id: object = field(default_factory=lambda: object()) + + def add(self, callback: TObserverCallback) -> t.ContextManager: + self._get_callbacks().append(callback) + + # Cannot decorate the outer `add` directly with `contextmanager`, + # because if it were not used with a `with` statement, its body won't + # be executed. + @contextlib.contextmanager + def _add(): + try: + yield + finally: + try: + self._get_callbacks().remove(callback) + except ValueError: + # Callback should be in the callbacks list. I'm just being + # extra cautious here. I don't want it to throw and affect + # business logic. + pass + + return _add() + + def observe(self, *args, **kwargs) -> None: + for callback in self._get_callbacks(): + with _log_error( + "Error calling observer callback", rethrow=RETHROW_CALLBACK_EXCEPTION + ): + callback(*args, **kwargs) + + def _get_callbacks(self) -> t.List[t.Callable]: + """ + Gets the callbacks registered in current execution context. Any code + that manipulates the returned list (add, remove, iterate) is + concurrency safe. 
+        """
+        callbacks_dict = _CALLBACKS.get()
+        if callbacks_dict is None:
+            callbacks_dict = {}
+            _CALLBACKS.set(callbacks_dict)
+
+        if self not in callbacks_dict:
+            callbacks_dict[self] = []
+
+        return callbacks_dict[self]
+
+
+@dataclass(frozen=True)
+class ObserveContext:
+    """
+    Passed to the registered callables that observe any function decorated by
+    `observable`. See `observable` for details.
+
+    Attributes:
+        callable: the observed callable object
+        args: the args passed to the callable
+        kwargs: the kwargs passed to the callable
+        return_value: the return value returned by the callable, only available
+            when observing the callable after its invocation (via
+            `CallableObservers.post`)
+    """
+
+    callable: t.Callable
+    args: t.List[t.Any]
+    kwargs: t.Mapping[str, t.Any]
+    return_value: t.Any = None
+
+
+def observable():
+    """
+    A decorator that turns a function into an observable function.
+
+    Example:
+
+    >>> @observable()
+    >>> def func_to_observe(x, y) -> int:
+    >>>     ...
+    >>>
+    >>> def log(ctx: ObserveContext):
+    >>>     print(
+    >>>         f"called {ctx.callable.__name__} with {ctx.args} {ctx.kwargs}"
+    >>>     )
+    >>>
+    >>> # register:
+    >>> with func_to_observe.observers.pre.add(log):
+    >>>     func_to_observe(1, 2)
+    >>>     # prints "called func_to_observe with (1, 2) {}"
+    >>> # outside the `with` block, the callback is no longer invoked
+    """
+
+    def decorator(observed_func: callable) -> ObservedCallable:
+        wrapped_func = _make_observable(orig_func=observed_func)
+        return functools.wraps(observed_func)(wrapped_func)
+
+    return decorator
+
+
+@dataclass(frozen=True)
+class CallableObservers:
+    pre: Observer[t.Callable[[ObserveContext], None]]
+    post: Observer[t.Callable[[ObserveContext], None]]
+
+
+class ObservedCallable:
+    """
+    Interface for an observed callable
+    """
+
+    observers: CallableObservers
+    orig_func: callable
+
+    def __call__(self, *args, **kwargs) -> t.Any:
+        raise NotImplementedError()
+
+
+def _make_observable(orig_func: t.Callable) -> ObservedCallable:
+    """
+    A wrapper for a callable which is to be observed.
+    """
+
+    observers = CallableObservers(
+        pre=Observer(),
+        post=Observer(),
+    )
+
+    @functools.wraps(orig_func)
+    def observed_func(*args, **kwargs):
+        observers.pre.observe(ObserveContext(orig_func, args, kwargs))
+        return_value = None
+        try:
+            return_value = orig_func(*args, **kwargs)
+            return return_value
+        finally:
+            observers.post.observe(
+                ObserveContext(orig_func, args, kwargs, return_value)
+            )
+
+    observed_func.orig_func = orig_func
+    observed_func.observers = observers
+
+    return observed_func
+
+
+@contextlib.contextmanager
+def _log_error(msg: str, rethrow: bool = False) -> t.ContextManager:
+    try:
+        yield
+    except Exception as e:
+        _e = e  # noqa: F841
+        _LOGGER.info(f"{msg} (This error is handled): {traceback.format_exc()}")
+        if rethrow:
+            raise
diff --git a/py/torch_tensorrt/dynamo/passes/__init__.py b/py/torch_tensorrt/dynamo/passes/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/py/torch_tensorrt/dynamo/passes/graph_opts.py b/py/torch_tensorrt/dynamo/passes/graph_opts.py
new file mode 100644
index 0000000000..2adc5c7fe3
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/passes/graph_opts.py
@@ -0,0 +1,74 @@
+from collections.abc import Sequence
+
+import torch
+import torch.fx
+
+
+def common_subexpression_elimination(graph_module: torch.fx.GraphModule) -> bool:
+    """
+    Eliminate common subexpressions: deduplicate pure nodes that share the
+    same target, args, and kwargs.
+
+    Args:
+        graph_module(torch.fx.GraphModule): target module to be optimized
+
+    Returns:
+        Graph changed or not.
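+
+    Example (illustrative): two `call_function` nodes with the same target,
+    args, and kwargs (say, two identical `add(x, 1)` calls) are merged; the
+    duplicate is erased and its uses are rewired to the surviving node.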
+ """ + + def seq_hashable(seq): + if seq is None: + return None + + items = [] + for old in seq: + if isinstance(old, Sequence) and not isinstance(old, str): + new = seq_hashable(old) + elif isinstance(old, dict): + new = dict_hashable(old) + elif isinstance(old, slice): + new = old.__reduce__() + else: + new = old + + items.append(new) + + return tuple(items) + + def dict_hashable(d): + if d is None: + return None + + items = [] + for k, old_v in d.items(): + if isinstance(old_v, Sequence): + new_v = seq_hashable(old_v) + elif isinstance(old_v, dict): + new_v = dict_hashable(old_v) + elif isinstance(old_v, slice): + new_v = old_v.__reduce__() + else: + new_v = old_v + + items.append((k, new_v)) + return tuple(sorted(items)) + + changed = False + env = {} + for n in graph_module.graph.nodes: + # do not CSE away impure ops + if n.op not in {"call_function", "call_method"} or n.is_impure(): + continue + + # hash target, args, kwargs + hash_val = (n.target, seq_hashable(n.args), dict_hashable(n.kwargs)) + + # check if a node has a substitute and can be eliminated + if hash_val in env: + n.replace_all_uses_with(env[hash_val]) + graph_module.graph.erase_node(n) + changed = True + continue + + env[hash_val] = n + + return changed diff --git a/py/torch_tensorrt/dynamo/passes/lower_basic_pass.py b/py/torch_tensorrt/dynamo/passes/lower_basic_pass.py new file mode 100644 index 0000000000..3fa4f69bc5 --- /dev/null +++ b/py/torch_tensorrt/dynamo/passes/lower_basic_pass.py @@ -0,0 +1,632 @@ +import copy +import logging +import operator +import warnings +from typing import Any, Optional + +import torch +import torch.fx +import torch.fx as fx +import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils +from torch.fx.experimental.const_fold import split_const_subgraphs + +from ..observer import observable + +from torch_tensorrt.fx.tracer.acc_tracer import acc_ops +from torch_tensorrt.fx.tracer.acc_tracer.acc_utils import get_attr +from .pass_utils import log_before_after, validate_inference + +_LOGGER = logging.getLogger(__name__) + +# Create an alias for module input type to avoid littering pyre-ignore for Any +# throughout the file. +Input = Any + + +def replace_mutable_op(module: torch.fx.GraphModule) -> torch.fx.GraphModule: + if not isinstance(module, torch.fx.GraphModule): + return module + + # Before any lowering pass, replace mutable ops like torch.fill_ + # Because fx cannot deal with inplace ops + for n in module.graph.nodes: + # TODO: add more mutable ops + if (n.op == "call_method" and n.target == "fill_") or ( + n.op == "call_function" and n.target == torch.fill_ + ): + # Replace mutable op only if the modified variable + # is used by the rest of the graph + # only through this op + if set(n.args[0].users.keys()) == {n}: + with module.graph.inserting_after(n): + + # TODO: move this outside? + def fill_with_mul_zero_and_add(*args): + return args[0].mul(0.0).add(args[1]) + + new_node = module.graph.create_node( + "call_function", fill_with_mul_zero_and_add, args=n.args + ) + n.replace_all_uses_with(new_node) + module.graph.erase_node(n) + module.recompile() + return module + + +def run_const_fold(traced_mod: torch.fx.GraphModule) -> torch.fx.GraphModule: + # Now we do constant folding on traced module. We want to skip pattern like + # weights -> quant -> dequant -> op during constant folding when the model is + # a quantized int8 model. 
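+    # The predicate below returns True (i.e. skip folding) for any
+    # quantize_per_tensor node that feeds a dequantize node, keeping the
+    # quant/dequant pair intact in the graph.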
+    def skip_folding_quant_dequant(node: torch.fx.Node):
+        if node.target != acc_ops.quantize_per_tensor:
+            return False
+        # If quantize_per_tensor -> dequantize, then skip folding.
+        for user in node.users:
+            if user.target == acc_ops.dequantize:
+                return True
+        return False
+
+    const_split_mod = split_const_subgraphs(traced_mod, skip_folding_quant_dequant)
+    const_split_mod.run_folding()
+    return const_split_mod
+
+
+def replace_op_with_indices(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    for n in module.graph.nodes:
+        if n.op == "call_function" and n.target in (
+            torch.ops.aten.max_pool2d_with_indices.default,
+            torch.ops.aten.max_pool3d_with_indices.default,
+            torch.ops.aten.native_batch_norm.default,
+        ):
+            if len(n.users) != 1:
+                raise RuntimeError(
+                    f"{n.target} has users={len(n.users)}. We can only handle it with 1 user"
+                )
+            if n.target == torch.ops.aten.max_pool2d_with_indices.default:
+                new_op = torch.ops.aten.max_pool2d
+                new_args = n.args
+            elif n.target == torch.ops.aten.max_pool3d_with_indices.default:
+                new_op = torch.ops.aten.max_pool3d
+                new_args = n.args
+            elif n.target == torch.ops.aten.native_batch_norm.default:
+                new_op = torch.ops.aten.batch_norm
+                new_args = list(n.args)
+                new_args.append(False)
+                new_args = tuple(new_args)
+
+            getitem_node = next(iter(n.users))
+            with module.graph.inserting_after(getitem_node):
+                new_node = module.graph.create_node(
+                    "call_function",
+                    new_op,
+                    args=new_args,
+                    kwargs=n.kwargs,
+                )
+                getitem_node.replace_all_uses_with(new_node)
+                module.graph.erase_node(getitem_node)
+    module.graph.eliminate_dead_code()
+    module.recompile()
+    return module
+
+
+@log_before_after
+@validate_inference(atol=1e-3, rtol=1e-2)
+def fuse_sparse_matmul_add(gm: torch.fx.GraphModule, input: Input):
+    """
+    Replace acc_ops.matmul + acc_ops.add with acc_ops.linear.
+    TRT 8.2 can take advantage of structured sparsity (2:4), but the graph needs
+    to contain a single FC layer. Later versions of TRT should work with matmul.
+
+    Example before:
+        def forward(self, x):
+            a = self.a
+            b = self.b
+            addmm_mm = torch_tensorrt.fx.tracer.acc_tracer.acc_ops.matmul(input = a, other = b); a = b = None
+            addmm_add = torch_tensorrt.fx.tracer.acc_tracer.acc_ops.add(input = addmm_mm, other = x); addmm_mm = x = None
+            return addmm_add
+
+    After:
+        def forward(self, x):
+            a = self.a
+            b = self.b
+            linear_1 = torch_tensorrt.fx.tracer.acc_tracer.acc_ops.linear(input = a, weight = b, bias = x); a = b = x = None
+            return linear_1
+    """
+    counter = 0
+    for node in gm.graph.nodes:
+        if node.target != acc_ops.add:
+            continue
+        add_node = node
+        bias = add_node.kwargs["other"]
+
+        if bias.op != "get_attr":
+            continue
+        # test that bias tensor is one-dimensional, should correspond to shape (out_features)
+        if get_attr(bias).dim() > 1:
+            continue
+
+        node = add_node.kwargs["input"]
+        if node.target != acc_ops.matmul:
+            continue
+        matmul_node = node
+        a = matmul_node.kwargs["input"]
+
+        node = matmul_node.kwargs["other"]
+        if node.op != "get_attr":
+            continue
+
+        get_attr_node = node
+        weight = get_attr(get_attr_node)
+        # TODO: verify that the weights comply with TRT structured sparsity requirements:
+        # for each output channel and for each spatial pixel in the kernel weights,
+        # every 4 input channels must have at least 2 zeros.
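+        # (Illustrative: in a 2:4 sparse row chunk such as [0.5, 0.0, 0.0, -1.2],
+        # two of every four consecutive input-channel weights are zero.)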
+
+        # test that weight tensor is two-dimensional, should correspond to shape (out_features, in_features)
+        if weight.dim() != 2:
+            continue
+
+        weight_t = weight.transpose(0, 1)
+        weight_t_name = "weight_t_tensor_" + str(counter)
+        gm.register_buffer(weight_t_name, weight_t)
+        counter += 1
+
+        with gm.graph.inserting_before(add_node):
+            weight_t_attr = gm.graph.get_attr(weight_t_name)
+            fused_node = gm.graph.call_function(
+                acc_ops.linear,
+                kwargs={"input": a, "weight": weight_t_attr, "bias": bias},
+            )
+        add_node.replace_all_uses_with(fused_node)
+
+    gm.graph.eliminate_dead_code()
+    gm.graph.lint()
+    gm.recompile()
+    return gm
+
+
+def trt_transposed_matmul(
+    lhs: torch.Tensor, rhs: torch.Tensor, lhs_transposed: bool, rhs_transposed: bool
+):
+    if lhs_transposed:
+        lhs = lhs.transpose(-1, -2)
+    if rhs_transposed:
+        rhs = rhs.transpose(-1, -2)
+    return torch.matmul(lhs, rhs)
+
+
+def trt_transposed_linear(
+    input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
+):
+    return torch.matmul(input.transpose(-1, -2), weight.t()) + bias
+
+
+def check_permute(node: torch.fx.Node):
+    ranks = len(node.meta["tensor_meta"].shape)
+    permutation = list(i % ranks for i in node.kwargs["permutation"])  # type: ignore[union-attr]
+    allowed_permutation = list(i for i in range(ranks))
+    allowed_permutation[-1] = ranks - 2
+    allowed_permutation[-2] = ranks - 1
+    return permutation == allowed_permutation
+
+
+@observable()
+@log_before_after
+@validate_inference(atol=1e-3, rtol=1e-2)
+def fuse_permute_linear(gm: torch.fx.GraphModule, input: Input):
+    """
+    Fuse pattern like permute + linear if the permute is transposing the last two dimensions.
+    """
+    for node in gm.graph.nodes:
+        if node.target == acc_ops.linear:
+            inp = node.kwargs["input"]
+            if inp.target == acc_ops.permute and check_permute(inp):
+                inp = inp.kwargs["input"]
+                weight = node.kwargs["weight"]
+                bias = node.kwargs["bias"]
+                with gm.graph.inserting_before(node):
+                    fused_node = gm.graph.call_function(
+                        trt_transposed_linear, args=(inp, weight, bias)
+                    )
+                    node.replace_all_uses_with(fused_node)
+
+    gm.graph.eliminate_dead_code()
+    gm.graph.lint()
+    gm.recompile()
+    return gm
+
+
+@observable()
+@log_before_after
+@validate_inference(atol=1e-3, rtol=1e-2)
+def fuse_permute_matmul(gm: torch.fx.GraphModule, input: Input):
+    """
+    Fuse pattern like permute + matmul if the permute is transposing the last two dimensions.
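+
+    Example (illustrative): matmul(permute(a), b), where the permute swaps the
+    last two dims, is rewritten to
+    trt_transposed_matmul(a, b, lhs_transposed=True, rhs_transposed=False),
+    so the transpose is absorbed into the fused op.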
+    """
+    for node in gm.graph.nodes:
+        if node.target == acc_ops.matmul:
+            lhs, rhs = node.kwargs["input"], node.kwargs["other"]
+            lhs_transposed = rhs_transposed = False
+
+            if lhs.target == acc_ops.permute and check_permute(lhs):
+                lhs_transposed = True
+                lhs = lhs.kwargs["input"]
+
+            if rhs.target == acc_ops.permute and check_permute(rhs):
+                rhs_transposed = True
+                rhs = rhs.kwargs["input"]
+
+            if lhs_transposed or rhs_transposed:
+                with gm.graph.inserting_before(node):
+                    fused_node = gm.graph.call_function(
+                        trt_transposed_matmul,
+                        args=(lhs, rhs, lhs_transposed, rhs_transposed),
+                    )
+                node.replace_all_uses_with(fused_node)
+
+    gm.graph.eliminate_dead_code()
+    gm.graph.lint()
+    gm.recompile()
+    return gm
+
+
+def slice_list(sli: slice, dim: int, size: int):
+    slice_all = slice(None, None, None)
+    if size == 1:
+        return [sli]
+    elif size == 2:
+        if dim == 0:
+            return [sli, slice_all]
+        elif dim == 1:
+            return [slice_all, sli]
+    elif size == 3:
+        if dim == 0:
+            return [sli, slice_all, slice_all]
+        elif dim == 1:
+            return [slice_all, sli, slice_all]
+        elif dim == 2:
+            return [slice_all, slice_all, sli]
+    elif size == 4:
+        if dim == 0:
+            return [sli, slice_all, slice_all, slice_all]
+        elif dim == 1:
+            return [slice_all, sli, slice_all, slice_all]
+        elif dim == 2:
+            return [slice_all, slice_all, sli, slice_all]
+        elif dim == 3:
+            return [slice_all, slice_all, slice_all, sli]
+
+
+def split_across(
+    gm: torch.fx.GraphModule, sli: slice, input_node: torch.fx.Node, dim: int, size: int
+):
+    start_node = end_node = mid_node = None
+    if sli.start is None and sli.stop is None:
+        return (start_node, input_node, end_node)
+    if sli.start is not None:
+        st_sli = slice(0, sli.start, None)
+        slice_list_gen = slice_list(st_sli, dim, size)
+        start_node = gm.graph.call_function(
+            operator.getitem, args=(input_node, slice_list_gen)
+        )
+    if sli.stop is not None:
+        end_sli = slice(sli.stop, None, None)
+        slice_list_gen = slice_list(end_sli, dim, size)
+        end_node = gm.graph.call_function(
+            operator.getitem, args=(input_node, slice_list_gen)
+        )
+    if dim != size - 1:
+        mid_sli = slice(sli.start, sli.stop, None)
+        slice_list_gen = slice_list(mid_sli, dim, size)
+        mid_node = gm.graph.call_function(
+            operator.getitem, args=(input_node, slice_list_gen)
+        )
+    return (start_node, mid_node, end_node)
+
+
+def list_gen(
+    start_node: torch.fx.Node,
+    end_node: torch.fx.Node,
+    input_node: torch.fx.Node,
+    gm: torch.fx.GraphModule,
+    dim: int,
+):
+    if start_node:
+        if end_node:
+            concat_list = [start_node, input_node, end_node]
+        else:
+            concat_list = [start_node, input_node]
+    else:
+        if end_node:
+            concat_list = [input_node, end_node]
+        else:
+            concat_list = [input_node]
+    if len(concat_list) > 1:
+        concat_node = gm.graph.call_function(torch.cat, args=(concat_list, dim))
+    else:
+        concat_node = concat_list[0]
+    return concat_node
+
+
+def transform_setitem(gm: torch.fx.GraphModule, input: Input):
+    """
+    Setitem is not traceable in fx or the acc tracer, but it does appear in
+    dynamo traces, so this pass applies to dynamo traces only. The
+    implementation decomposes the setitem into a few getitem ops and
+    reassembles the result through concat. The major reason is that TRT does
+    not support in-place copy and memory reference.
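+
+    Example (illustrative), in one dimension: x[2:5] = y is rebuilt as
+    torch.cat([x[0:2], y, x[5:]]), so the assignment becomes a pure
+    re-composition instead of an in-place write.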
+ """ + map_replace = {} + for node in gm.graph.nodes: + for old_node in map_replace: + node.replace_input_with(old_node, map_replace[old_node]) + + if node.target == operator.setitem: + input_node = node.args[0] + sli = node.args[1] + inp = node.args[2] + + inp_flag = False + if type(inp) == torch.fx.node.Node and inp.target == operator.getitem: + new_args = list(copy.deepcopy(inp.args[1])) + for ind, val in enumerate(new_args): + if type(val) == int: + inp_flag = True + if val == -1: + new_args[ind] = slice(-1, None, None) + else: + new_args[ind] = slice(val, val + 1, None) + + if inp_flag: + with gm.graph.inserting_before(inp): + new_node = gm.graph.call_function( + operator.getitem, args=(inp.args[0], new_args) + ) + inp.replace_all_uses_with(new_node) + inp = new_node + + if type(sli) is not tuple: + sli = [sli] + + tmp_sli = [] + for x in sli: + if type(x) == int: + if x == -1: + tmp_sli.append(slice(-1, None, None)) + else: + tmp_sli.append(slice(x, x + 1, None)) + else: + tmp_sli.append(x) + sli = tmp_sli + + dimension = len(sli) + with gm.graph.inserting_before(node): + if dimension == 1: + start_node_0, _, end_node_0 = split_across( + gm, sli[0], input_node, dim=0, size=1 + ) + concat_node_0 = list_gen(start_node_0, end_node_0, inp, gm, 0) + elif dimension == 2: + start_node_0, mid_node_0, end_node_0 = split_across( + gm, sli[0], input_node, dim=0, size=2 + ) + start_node_1, _, end_node_1 = split_across( + gm, sli[1], mid_node_0, dim=1, size=2 + ) + concat_node_1 = list_gen(start_node_1, end_node_1, inp, gm, 1) + concat_node_0 = list_gen( + start_node_0, end_node_0, concat_node_1, gm, 0 + ) + elif dimension == 3: + start_node_0, mid_node_0, end_node_0 = split_across( + gm, sli[0], input_node, dim=0, size=3 + ) + start_node_1, mid_node_1, end_node_1 = split_across( + gm, sli[1], mid_node_0, dim=1, size=3 + ) + start_node_2, _, end_node_2 = split_across( + gm, sli[2], mid_node_1, dim=2, size=3 + ) + concat_node_2 = list_gen(start_node_2, end_node_2, inp, gm, 2) + concat_node_1 = list_gen( + start_node_1, end_node_1, concat_node_2, gm, 1 + ) + concat_node_0 = list_gen( + start_node_0, end_node_0, concat_node_1, gm, 0 + ) + elif dimension == 4: + start_node_0, mid_node_0, end_node_0 = split_across( + gm, sli[0], input_node, dim=0, size=4 + ) + start_node_1, mid_node_1, end_node_1 = split_across( + gm, sli[1], mid_node_0, dim=1, size=4 + ) + start_node_2, mid_node_2, end_node_2 = split_across( + gm, sli[2], mid_node_1, dim=2, size=4 + ) + start_node_3, _, end_node_3 = split_across( + gm, sli[3], mid_node_2, dim=3, size=4 + ) + concat_node_3 = list_gen(start_node_3, end_node_3, inp, gm, 3) + concat_node_2 = list_gen( + start_node_2, end_node_2, concat_node_3, gm, 2 + ) + concat_node_1 = list_gen( + start_node_1, end_node_1, concat_node_2, gm, 1 + ) + concat_node_0 = list_gen( + start_node_0, end_node_0, concat_node_1, gm, 0 + ) + else: + warnings.warn(f"setitem does not support dimension={dimension}") + continue + node.replace_all_uses_with(concat_node_0) + map_replace[input_node] = concat_node_0 + gm.graph.erase_node(node) + + gm.graph.lint() + gm.recompile() + return gm + + +def fix_reshape_batch_dim(mod: fx.GraphModule) -> fx.GraphModule: + """\ + TRT cannot reason about shape patterns like x.reshape(y.size(0), -1, 256), + since the dynamic shape of the reshape comes from the dynamic shape of + another node (y). The compilation will fail with various memory related + errors, depending on the size of the input tensor. 
+ + This pass fixes the issue by finding this reshape pattern, checking that: + + x.size(0) == y.size(0) + + And then replaces reshape's batch size from y.size(0) to x.size(0). + """ + + def get_reshape_batch_size_as_node(maybe_reshape: fx.Node) -> Optional[fx.Node]: + """\ + Try to find the reshape op's batch size as an input node. + + Match below graph structure and return `node_y`: + node_x.reshape({"acc_out_ty": {"shape": (node_y, ...)}}) + """ + if ( + maybe_reshape.op != "call_function" + or maybe_reshape.target != acc_ops.reshape + ): + return None + shape = getattr(maybe_reshape.kwargs["acc_out_ty"], "shape", None) + if not shape: + return None + batch_size = shape[0] + if isinstance(batch_size, fx.Node): + return batch_size + return None + + def get_reshape_batch_size_inferred_source( + batch_size_node: fx.Node, + ) -> Optional[fx.Node]: + """\ + Given a node representing the batch size used for reshape op, we want + to know if it is coming from below pattern: + + batch_size_node = src.size()[0] + + or in IR graph: + + src -> size(input=_) -> getitem(input=_, idx=0) + ^ ~~~ batch_size_node + + If so, return `src`. Otherwise, return `None`. + """ + if ( + batch_size_node.op != "call_function" + or batch_size_node.target != acc_ops.getitem + or batch_size_node.kwargs["idx"] != 0 + ): + return None + maybe_size: fx.Node = batch_size_node.all_input_nodes[0] + if maybe_size.op != "call_function" or maybe_size.target != acc_ops.size: + return None + return maybe_size.all_input_nodes[0] + + maybe_reshape: fx.Node + for maybe_reshape in mod.graph.nodes: + reshape_batch_size: Optional[fx.Node] = get_reshape_batch_size_as_node( + maybe_reshape + ) + if not reshape_batch_size: + continue + reshape_batch_size_inferred_source: Optional[ + fx.Node + ] = get_reshape_batch_size_inferred_source(reshape_batch_size) + if not reshape_batch_size_inferred_source: + continue + + reshape_input: fx.Node = maybe_reshape.kwargs["input"] + if reshape_input == reshape_batch_size_inferred_source: + continue + + if not _is_batch_size_equal(reshape_input, reshape_batch_size_inferred_source): + continue + + _LOGGER.info( + f"{fix_reshape_batch_dim}: Found bad pattern: y.reshape((x, ...)). Reshape node: {maybe_reshape}" + ) + + # Step 1: create a node to compute batch size, using the tensor which + # is being reshaped: reshape_input.size()[0]. This batch size is now + # derived from reshape_input, the same node as the reshape op's input. 
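+        # (Illustrative: x.reshape(y.size(0), -1, 256) effectively becomes
+        # x.reshape(x.size(0), -1, 256), which TRT can reason about because
+        # the dynamic dimension now derives from the reshaped tensor itself.)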
+ with mod.graph.inserting_before(maybe_reshape): + reshape_batch_size_2: fx.Node = maybe_reshape.graph.call_function( + acc_ops.getitem, + kwargs={ + "idx": 0, + "input": maybe_reshape.graph.call_function( + acc_ops.size, + kwargs={ + "input": reshape_input, + }, + ), + }, + ) + + # Step 2: update `maybe_reshape`'s shape argument to be + # (reshape_batch_size_2, *DONT_CARE_JUST_COPY_OVER) + maybe_reshape.kwargs = { + **maybe_reshape.kwargs, + "acc_out_ty": acc_utils.build_raw_tensor_meta( + shape=( + reshape_batch_size_2, + *(maybe_reshape.kwargs["acc_out_ty"].shape[1:]), + ) + ), + } + + mod.graph.eliminate_dead_code() + mod.recompile() + return mod + + +def _is_batch_size_equal(x: fx.Node, y: fx.Node) -> bool: + """\ + Check that x.size(0) == y.size(0) + """ + x_size, y_size = _get_shape(x), _get_shape(y) + return ( + x_size + and y_size + # now both are non-empty + and x_size[0] == y_size[0] + ) + + +def _get_shape(node: fx.Node) -> Optional[torch.Size]: + if ( + not getattr(node, "meta", None) + or not node.meta.get("tensor_meta", None) + or not getattr(node.meta["tensor_meta"], "shape", None) + ): + # shape info not available + return None + return node.meta["tensor_meta"].shape + + +@log_before_after +@validate_inference(atol=1e-3, rtol=1e-2) +def fix_clamp_numerical_limits_to_fp16( + mod: torch.fx.GraphModule, input: Input +) -> torch.fx.GraphModule: + MIN_FP16 = -65504.0 + MAX_FP16 = 65504.0 + for node in mod.graph.nodes: + if node.op == "call_function" and "clamp" in str(node.target): + input_kwargs = node.kwargs + if input_kwargs["min"] < MIN_FP16 and input_kwargs["max"] > MAX_FP16: + new_kwargs = { + "input": input_kwargs["input"], + "min": MIN_FP16, + "max": MAX_FP16, + } + node.kwargs = new_kwargs + + mod.recompile() + return mod diff --git a/py/torch_tensorrt/dynamo/passes/lower_basic_pass_aten.py b/py/torch_tensorrt/dynamo/passes/lower_basic_pass_aten.py new file mode 100644 index 0000000000..00063c3e21 --- /dev/null +++ b/py/torch_tensorrt/dynamo/passes/lower_basic_pass_aten.py @@ -0,0 +1,525 @@ +import logging +import operator +from typing import Any + +import torch +import torch.fx +from torch.fx.experimental.const_fold import split_const_subgraphs +from torch.fx.passes.infra.pass_base import PassResult + +_LOGGER = logging.getLogger(__name__) + +# Create an alias for module input type to avoid littering pyre-ignore for Any +# throughout the file. +Input = Any + + +def run_const_fold(traced_mod: torch.fx.GraphModule) -> torch.fx.GraphModule: + # Now we do constant folding on traced module. 
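+    # Symbolic shape queries must remain in the graph, so the predicate
+    # below excludes aten.sym_size nodes from folding.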
+ def skip_folding(node: torch.fx.Node): + if node.target == torch.ops.aten.sym_size: + return True + + const_split_mod = split_const_subgraphs( + traced_mod, skip_folding_node_fn=skip_folding + ) + const_split_mod.run_folding() + return const_split_mod + + +def replace_inplace_ops( + module: torch.fx.GraphModule, +) -> torch.fx.GraphModule: + """ + Remove this func after functionalization is workable + """ + modified = False + map_func = { + torch.ops.aten.relu_.default: torch.ops.aten.relu.default, + torch.ops.aten.hardtanh_.default: torch.ops.aten.hardtanh.default, + torch.ops.aten.add_.Tensor: torch.ops.aten.add.Tensor, + } + for n in module.graph.nodes: + if n.op == "call_function" and n.target in map_func.keys(): + modified = True + node = n + with module.graph.inserting_after(node): + new_args = node.args + new_node = module.graph.create_node( + "call_function", + map_func[node.target], + args=new_args, + kwargs=None, + ) + node.replace_all_uses_with(new_node) + module.graph.erase_node(node) + module.graph.eliminate_dead_code() + module.recompile() + return PassResult(module, modified) + + +def replace_native_layernorm_with_layernorm( + module: torch.fx.GraphModule, +) -> torch.fx.GraphModule: + modified = False + for n in module.graph.nodes: + if ( + n.op == "call_function" + and n.target == torch.ops.aten.native_layer_norm.default + ): + for v in n.users: + if v.op == "call_function" and v.target == operator.getitem: + if v.args[1] != 0: + raise RuntimeError( + f"Got args[{v.args[1]}]!!\n" + "layernorm can only generate output (args[0]), " + "not mean (args[1]) or std (args[2])!" + ) + new_op = torch.ops.aten.layer_norm.default + new_args = (*n.args, True) # cudnn_enable=True + modified = True + else: + continue + + with module.graph.inserting_after(v): + new_node = module.graph.create_node( + "call_function", + new_op, + args=new_args, + kwargs=v.kwargs, + ) + v.replace_all_uses_with(new_node) + + module.graph.eliminate_dead_code() + module.recompile() + return PassResult(module, modified) + + +def replace_transpose_mm_op_with_linear( + module: torch.fx.GraphModule, +) -> torch.fx.GraphModule: + modified = False + for n in module.graph.nodes: + if n.op == "call_function" and n.target == torch.ops.aten.t.default: + to_erase = [] + for v in n.users: + if v.op == "call_function" and v.target == torch.ops.aten.addmm.default: + new_op = torch.ops.aten.linear + bias, inp, _ = list(v.args) + weight = list(n.args)[0] + new_args = (inp, weight, bias) + modified = True + elif v.op == "call_function" and v.target == torch.ops.aten.mm.default: + new_op = torch.ops.aten.linear + inp, _ = list(v.args) + weight = list(n.args)[0] + new_args = (inp, weight, None) + modified = True + # this pass should be after `compose_bmm` + elif v.op == "call_function" and v.target == aten_compose_bmm_2d: + new_op = torch.ops.aten.linear + inp, _ = list(v.args) + weight = list(n.args)[0] + new_args = (inp, weight, None) + modified = True + else: + continue + + with module.graph.inserting_after(v): + new_node = module.graph.create_node( + "call_function", + new_op, + args=new_args, + kwargs=v.kwargs, + ) + v.replace_all_uses_with(new_node) + to_erase.append(v) + for v in to_erase: + module.graph.erase_node(v) + module.graph.eliminate_dead_code() + module.recompile() + # handle the linear with multiple dim, remove the extra reshape + for n in module.graph.nodes: + if n.op == "call_function" and n.target == torch.ops.aten.linear: + before = n.args[0] + after = next(iter(n.users)) + if (len(n.users) == 1 and 
after.target == torch.ops.aten.view.default) and ( + before.target == torch.ops.aten.view.default and len(before.users) == 1 + ): + real_input = before.args[0] + new_args = list(n.args) + new_args[0] = real_input + n.args = tuple(new_args) + after.replace_all_uses_with(n) + module.graph.eliminate_dead_code() + module.recompile() + + return PassResult(module, modified) + + +def replace_aten_op_with_indices(module: torch.fx.GraphModule) -> torch.fx.GraphModule: + modified = False + for n in module.graph.nodes: + if n.op == "call_function" and n.target in ( + torch.ops.aten.max_pool2d_with_indices.default, + torch.ops.aten.max_pool3d_with_indices.default, + torch.ops.aten.native_batch_norm.default, + torch.ops.aten._native_batch_norm_legit.default, + torch.ops.aten._native_batch_norm_legit_no_training.default, + ): + modified = True + if len(n.users) != 1: + raise RuntimeError( + f"{n.target} has users={len(n.users)}. We can only handle it with 1 user" + ) + if n.target == torch.ops.aten.max_pool2d_with_indices.default: + new_op = torch.ops.aten.max_pool2d + new_args = n.args + elif n.target == torch.ops.aten.max_pool3d_with_indices.default: + new_op = torch.ops.aten.max_pool3d + new_args = n.args + elif ( + n.target == torch.ops.aten.native_batch_norm.default + or n.target == torch.ops.aten._native_batch_norm_legit.default + ): + new_op = torch.ops.aten.batch_norm + new_args = list(n.args) + new_args.append(False) + new_args = tuple(new_args) + elif ( + n.target == torch.ops.aten._native_batch_norm_legit_no_training.default + ): + new_op = torch.ops.aten.batch_norm + new_args = list(n.args) + new_args.append(False) + # _native_batch_norm_legit_no_training doesn't take in a training arg (assumed to be false) + # but batchnorm takes in a training arg at position 5. + new_args.insert(5, False) + new_args = tuple(new_args) + + getitem_node = next(iter(n.users)) + with module.graph.inserting_after(getitem_node): + new_node = module.graph.create_node( + "call_function", + new_op, + args=new_args, + kwargs=n.kwargs, + ) + getitem_node.replace_all_uses_with(new_node) + module.graph.erase_node(getitem_node) + module.graph.eliminate_dead_code() + module.recompile() + return PassResult(module, modified) + + +def replace_aten_reshape_alias_with_replace( + module: torch.fx.GraphModule, +) -> torch.fx.GraphModule: + # The stride parameter is not used. Replace with reshape without stride + modified = False + for n in module.graph.nodes: + if n.op == "call_function" and n.target in ( + torch.ops.aten._reshape_alias.default, + ): + modified = True + node = n + with module.graph.inserting_after(node): + new_args = (node.args[0], node.args[1]) + new_node = module.graph.create_node( + "call_function", + torch.ops.aten.reshape, + args=new_args, + kwargs=None, + ) + node.replace_all_uses_with(new_node) + module.graph.erase_node(node) + break + module.graph.eliminate_dead_code() + module.recompile() + return PassResult(module, modified) + + +def remove_ops( + module: torch.fx.GraphModule, +) -> torch.fx.GraphModule: + """ + 1. Remove clone, _unsafe_view node. #TODO Remove this func after functionalization is workable + 2. 
Remove inefficient op getitem(index=slice) P561572458 + """ + modified = False + for n in module.graph.nodes: + if n.op == "call_function" and n.target in (torch.ops.aten.clone.default,): + modified = True + node = n + input_n = node.all_input_nodes[0] + node.replace_all_uses_with(input_n) + module.graph.eliminate_dead_code() + module.recompile() + for n in module.graph.nodes: + if n.op == "call_function" and n.target in ( + torch.ops.aten._unsafe_view.default, + ): + modified = True + node = n + with module.graph.inserting_after(node): + new_node = module.graph.create_node( + "call_function", + torch.ops.aten.reshape, + args=node.args, + kwargs=node.kwargs, + ) + node.replace_all_uses_with(new_node) + module.graph.erase_node(node) + module.graph.eliminate_dead_code() + module.recompile() + return PassResult(module, modified) + + +def aten_operator_getitem(*args): + return operator.getitem(*args) + + +def replace_builtin_ops( + module: torch.fx.GraphModule, +) -> torch.fx.GraphModule: + """ + To differential the same op in fx2ait as they are registered in the same dictionary + """ + + modified = False + for n in module.graph.nodes: + if n.op == "call_function" and n.target in (operator.getitem,): + modified = True + n.target = aten_operator_getitem + module.graph.eliminate_dead_code() + module.recompile() + + module.graph.eliminate_dead_code() + module.recompile() + return PassResult(module, modified) + + +############### +""" +Trace compose. For some ops, we do not want to decompose further but want coarse granularity +For ex: +1. bmm +2. chunk +3. getitem(input, idx=(slice(),slice()...)) +""" + + +def aten_compose_getitem_slice(input, list_args): + for _, args in enumerate(list_args): + input = torch.ops.aten.slice.Tensor(input, *args) + return input + + +def compose_getitem_slice( + module: torch.fx.GraphModule, +) -> torch.fx.GraphModule: + """ + combine decomposed getitem(input, idx=(slice(),slice()...)) + """ + + def match_pattern(module, node): + if node.op == "call_function" and node.target == torch.ops.aten.slice.Tensor: + holder = [] + holder.append(node) + while ( + len(node.users.keys()) == 1 + and next(iter(node.users)).target == torch.ops.aten.slice.Tensor + and node.args[1] + 1 == next(iter(node.users)).args[1] + ): + node = next(iter(node.users)) + holder.append(node) + if len(holder) == 1: + return (False,) + else: + return (True, holder) + return (False,) + + modified = False + for node in module.graph.nodes: + res = match_pattern(module, node) + if res[0]: + modified = True + holder = res[1] + input_n = holder[0].args[0] + last_n = holder[-1] + list_args = [] + for h_n in holder: + list_args.append(h_n.args[1:]) + + with module.graph.inserting_after(last_n): + new_args = (input_n, list_args) + new_node = module.graph.create_node( + "call_function", + aten_compose_getitem_slice, + args=new_args, + kwargs=None, + ) + last_n.replace_all_uses_with(new_node) + module.graph.eliminate_dead_code() + module.recompile() + return PassResult(module, modified) + + +def aten_compose_bmm_2d(flat_args_1, flat_args_2): + sym_size = torch.ops.aten.sym_size(flat_args_1, 0) + sym_size_1 = torch.ops.aten.sym_size(flat_args_1, 1) + sym_size_2 = torch.ops.aten.sym_size(flat_args_1, 2) + expand = torch.ops.aten.expand.default( + flat_args_1, [sym_size, sym_size_1, sym_size_2] + ) + view = torch.ops.aten.view.default(expand, [sym_size, sym_size_1, sym_size_2]) + sym_size_3 = torch.ops.aten.sym_size(flat_args_2, 0) + sym_size_4 = torch.ops.aten.sym_size(flat_args_2, 1) + expand_1 = 
torch.ops.aten.expand.default( + flat_args_2, [sym_size, sym_size_3, sym_size_4] + ) + view_1 = torch.ops.aten.view.default(expand_1, [sym_size, sym_size_3, sym_size_4]) + bmm = torch.ops.aten.bmm.default(view, view_1) + view_2 = torch.ops.aten.view.default(bmm, [sym_size, sym_size_1, sym_size_4]) + return view_2 + + +def aten_compose_bmm_3d(flat_args_1, flat_args_2): + sym_size = torch.ops.aten.sym_size(flat_args_1, 0) + sym_size_1 = torch.ops.aten.sym_size(flat_args_1, 1) + sym_size_2 = torch.ops.aten.sym_size(flat_args_1, 2) + expand = torch.ops.aten.expand.default( + flat_args_1, [sym_size, sym_size_1, sym_size_2] + ) + view = torch.ops.aten.view.default(expand, [sym_size, sym_size_1, sym_size_2]) + sym_size_3 = torch.ops.aten.sym_size(flat_args_2, 1) + sym_size_4 = torch.ops.aten.sym_size(flat_args_2, 2) + expand_1 = torch.ops.aten.expand.default( + flat_args_2, [sym_size, sym_size_3, sym_size_4] + ) + view_1 = torch.ops.aten.view.default(expand_1, [sym_size, sym_size_3, sym_size_4]) + bmm = torch.ops.aten.bmm.default(view, view_1) + view_2 = torch.ops.aten.view.default(bmm, [sym_size, sym_size_1, sym_size_4]) + return view_2 + + +def compose_bmm( + module: torch.fx.GraphModule, +) -> torch.fx.GraphModule: + """ + combine decomposed bmm (matmul) + """ + modified = False + for n in module.graph.nodes: + if n.op == "call_function" and n.target in (torch.ops.aten.bmm.default,): + modified = True + node = n + input_n = node.all_input_nodes[0] + other_n = node.all_input_nodes[1] + output = next(iter(node.users)) + input_input_n = input_n.all_input_nodes[0] + if ( + input_input_n.target != torch.ops.aten.expand.default + and input_n.target != torch.ops.aten.view.default + ): + raise RuntimeError( + "Bmm is addressed in fixed pattern. A new pattern is met!" + ) + real_input = input_input_n.all_input_nodes[0] + input_other_n = other_n.all_input_nodes[0] + if ( + input_other_n.target != torch.ops.aten.expand.default + and other_n.target != torch.ops.aten.view.default + ): + raise RuntimeError( + "Bmm is addressed in fixed pattern. A new pattern is met!" 
+ ) + real_other = input_other_n.all_input_nodes[0] + if len(real_other.meta["val"].size()) == 2: + new_func = aten_compose_bmm_2d + if len(real_other.meta["val"].size()) == 3: + new_func = aten_compose_bmm_3d + + with module.graph.inserting_after(node): + new_args = (real_input, real_other) + new_node = module.graph.create_node( + "call_function", + new_func, + args=new_args, + kwargs=None, + ) + output.replace_all_uses_with(new_node) + + module.graph.eliminate_dead_code() + module.recompile() + return PassResult(module, modified) + + +def aten_compose_chunk(flat_args_1, chunk, dim): + sym_size = torch.ops.aten.sym_size(flat_args_1, dim) + add = operator.add(sym_size, chunk) + sub = operator.sub(add, 1) + floordiv = operator.floordiv(sub, chunk) + split = torch.ops.aten.split.Tensor(flat_args_1, floordiv, dim) + return split + + +def compose_chunk( + module: torch.fx.GraphModule, +) -> torch.fx.GraphModule: + """ + combine decomposed chunk + """ + + def match_pattern(module, node): + if node.op == "call_function" and node.target in (torch.ops.aten.split.Tensor,): + div = node.args[1] + input = node.args[0] + if isinstance(div, int): + return (False,) + if div.target != operator.floordiv: + return (False,) + else: + div_const = div.args[1] + sub = div.args[0] + if sub.target != operator.sub: + return (False,) + else: + add = sub.args[0] + if add.target != operator.add: + return (False,) + else: + add_const = add.args[1] + if add_const != div_const: + return (False,) + symsize = add.args[0] + if symsize.target != torch.ops.aten.sym_size: + return (False,) + else: + symsize_input = symsize.args[0] + dim = symsize.args[1] + if symsize_input != input: + return (False,) + + return (True, div_const, dim) + else: + return (False,) + + modified = False + for node in module.graph.nodes: + res = match_pattern(module, node) + if res[0]: + modified = True + with module.graph.inserting_after(node): + new_args = (node.args[0], res[1], res[2]) + new_node = module.graph.create_node( + "call_function", + aten_compose_chunk, + args=new_args, + kwargs=None, + ) + node.replace_all_uses_with(new_node) + + module.graph.eliminate_dead_code() + module.recompile() + return PassResult(module, modified) diff --git a/py/torch_tensorrt/dynamo/passes/lower_pass_manager_builder.py b/py/torch_tensorrt/dynamo/passes/lower_pass_manager_builder.py new file mode 100644 index 0000000000..d79c0e77b4 --- /dev/null +++ b/py/torch_tensorrt/dynamo/passes/lower_pass_manager_builder.py @@ -0,0 +1,325 @@ +import datetime +import logging +from functools import partial, wraps +from typing import Any, Callable, Optional, Sequence + +import torch +from torch import nn +from torch.fx.passes.pass_manager import inplace_wrapper, PassManager +from torch.fx.passes.shape_prop import ShapeProp +from torch.fx.passes.splitter_base import generate_inputs_for_submodules, SplitResult +from torch_tensorrt.dynamo.utils import LowerPrecision +from torch_tensorrt import _Input +from ..input_tensor_spec import InputTensorSpec + +from ..lower_setting import LowerSetting +from ..observer import Observer +from ..passes.remove_duplicate_output_args import remove_duplicate_output_args +from .graph_opts import common_subexpression_elimination +from .pass_utils import extract_example_tensors_from_input + +from .lower_basic_pass import ( # noqa + fix_clamp_numerical_limits_to_fp16, + fix_reshape_batch_dim, + replace_mutable_op, + replace_op_with_indices, + run_const_fold, +) + + +_LOGGER: logging.Logger = logging.getLogger(__name__) + + +Input = Sequence[Any] + 
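+# `Input` is the loose input type threaded through the lowering pipeline: a
+# sequence of example tensors (or torch_tensorrt Inputs) used for shape
+# propagation and pass validation.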
+ +# ---------------------------------------------------------------------- +# OBSERVERS +# ---------------------------------------------------------------------- +# List of observers. We can subscribe to them by calling its `add(callback)` +# function from anywhere in code: +# +# >>> from torch_tensorrt.fx.lower import FUSE_PASSES_POST_OBSERVER +# >>> with FUSE_PASSES_POST_OBSERVER.add(print_module_and_input): +# >>> # print_module_and_input will be called right after the fuse passes +# >>> lower(module, sample_input) + +# Observer for the model after the fuse passes. +FUSE_PASSES_POST_OBSERVER: Observer[Callable[[nn.Module, Input], None]] = Observer( + "FUSE_PASSES_POST_OBSERVER" +) + +# Observer for the TRT split submodules before lowering +LOWER_SPLIT_PRE_OBSERVER: Observer[Callable[[str, nn.Module, Input], None]] = Observer( + "LOWER_SPLIT_PRE_OBSERVER" +) + +# Observer for the TRT split submodules after lowering +LOWER_SPLIT_POST_OBSERVER: Observer[Callable[[str, nn.Module, Input], None]] = Observer( + "LOWER_SPLIT_POST_OBSERVER" +) +# ---------------------------------------------------------------------- + + +def wrapper(fn: Callable, input) -> Callable: + @wraps(fn) + def wrapped_fn(gm): + if isinstance(gm, torch.fx.GraphModule): + ShapeProp(gm).propagate(*input) + return fn(gm, input) + + return wrapped_fn + + +class LowerPassManagerBuilder: + """ + Build PassManager for lowering. + + Attributes: + lower_setting: Setting that will be used during process of lowering, see lower_setting.py for the details. + _trace_func: fx trace function for TRT conversion. + _split_func: the fx2trt split function. + _lower_func: function to create and run `TRTInterpreter` to convert `fx.GraphModule` + into a TensorRT engine. + + """ + + def __init__( + self, + lower_setting: LowerSetting, + trace_func: Callable, + split_func: Callable, + lower_func: Callable, + ): + self.lower_setting = lower_setting + self._trace_func = trace_func + self._split_func = split_func + self._lower_func = lower_func + + def _const_fold_pass(self) -> PassManager: + passes = [ + wrapper(self._trace_func, self._input), + run_const_fold, + ] + return PassManager.build_from_passlist(passes) + + def graph_optimization_pass(self) -> PassManager: + passes = [ + wrapper(self._trace_func, self._input), + ] + for p in self.lower_setting.customized_fuse_pass.passes: + passes.append(wrapper(p, self._input)) + for p in self.lower_setting.lower_basic_fuse_pass.passes: + passes.append(wrapper(p, self._input)) + if ( + hasattr(self.lower_setting, "lower_precision") + and self.lower_setting.lower_precision is LowerPrecision.FP16 + ) or ( + hasattr(self.lower_setting, "precision") + and self.lower_setting.precision is LowerPrecision.FP16 + ): + passes.append(wrapper(fix_clamp_numerical_limits_to_fp16, self._input)) + + passes.append(inplace_wrapper(common_subexpression_elimination)) + passes.append( + inplace_wrapper(lambda m: FUSE_PASSES_POST_OBSERVER.observe(m, self._input)) + ) + passes.append(fix_reshape_batch_dim) + + return PassManager.build_from_passlist(passes) + + def graph_optimization_pass_aten(self) -> PassManager: + passes = [] + + for p in self.lower_setting.customized_fuse_pass.passes: + passes.append(wrapper(p, self._input)) + for p in self.lower_setting.lower_basic_fuse_pass.passes: + passes.append(wrapper(p, self._input)) + # TODO fix this pass for aten graph + # if ( + # hasattr(self.lower_setting, "lower_precision") + # and self.lower_setting.lower_precision is LowerPrecision.FP16 + # ) or ( + # 
hasattr(self.lower_setting, "precision") + # and self.lower_setting.precision is LowerPrecision.FP16 + # ): + # passes.append(wrapper(fix_clamp_numerical_limits_to_fp16, self._input)) + + passes.append( + inplace_wrapper(lambda m: FUSE_PASSES_POST_OBSERVER.observe(m, self._input)) + ) + # TODO we most likely do not need it for aten + # passes.append(fix_reshape_batch_dim) + + return PassManager.build_from_passlist(passes) + + def _split_pass(self) -> PassManager: + passes = [ + partial( + self._split_func, inputs=self._input, lower_setting=self.lower_setting + ) + ] + passes.append( + inplace_wrapper( + lambda split_result: remove_duplicate_output_args( + split_result.split_module, split_result.submodule_inputs.keys() + ) + ) + ) + + return PassManager.build_from_passlist(passes) + + def _trt_lower_pass(self) -> PassManager: + def lower_func(split_result: SplitResult) -> nn.Module: + if ( + hasattr(self.lower_setting, "explicit_batch_dimension") + and self.lower_setting.explicit_batch_dimension + and self._additional_input + ): + additional_submodule_inputs = generate_inputs_for_submodules( + split_result.split_module, + self._additional_input, + list(split_result.submodule_inputs.keys()), + ) + else: + additional_submodule_inputs = None + + for submod_name, submod_inputs in split_result.submodule_inputs.items(): + submod = getattr(split_result.split_module, submod_name) + + LOWER_SPLIT_PRE_OBSERVER.observe(submod_name, submod, submod_inputs) + + # Only acc submodules will be lowered. + if not submod_name.startswith(split_result.non_acc_submodule_prefix): + _LOGGER.info(f"Now lowering submodule {submod_name}") + lowering_start_time = datetime.datetime.now() + + self.lower_setting.input_specs = self._trt_input + + lowered_module = self._lower_func( + submod, submod_inputs, self.lower_setting, submod_name + ) + setattr(split_result.split_module, submod_name, lowered_module) + LOWER_SPLIT_POST_OBSERVER.observe( + submod_name, lowered_module, submod_inputs + ) + _LOGGER.info( + f"Lowering submodule {submod_name} elapsed time {datetime.datetime.now() - lowering_start_time}" + ) + + return split_result.split_module + + return PassManager.build_from_passlist([lower_func]) + + def _default_lower_pass(self) -> PassManager: + def lower_func(split_result: SplitResult) -> nn.Module: + if self._additional_input: + additional_submodule_inputs = generate_inputs_for_submodules( + split_result.split_module, + self._additional_input, + list(split_result.submodule_inputs.keys()), + ) + else: + additional_submodule_inputs = None + + for submod_name, submod_inputs in split_result.submodule_inputs.items(): + submod = getattr(split_result.split_module, submod_name) + + LOWER_SPLIT_PRE_OBSERVER.observe(submod_name, submod, submod_inputs) + + # Only acc submodules will be lowered. 
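+                # Submodules whose names carry the non-acc prefix are skipped
+                # here and keep running in eager PyTorch.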
+                if not submod_name.startswith(split_result.non_acc_submodule_prefix):
+                    _LOGGER.info(f"Now lowering submodule {submod_name}")
+                    lowering_start_time = datetime.datetime.now()
+
+                    self.lower_setting.additional_inputs = (
+                        additional_submodule_inputs[submod_name]
+                        if additional_submodule_inputs
+                        else None
+                    )
+
+                    lowered_module = self._lower_func(
+                        submod, submod_inputs, self.lower_setting, submod_name
+                    )
+                    setattr(split_result.split_module, submod_name, lowered_module)
+                    LOWER_SPLIT_POST_OBSERVER.observe(
+                        submod_name, lowered_module, submod_inputs
+                    )
+                    _LOGGER.info(
+                        f"Lowering submodule {submod_name} elapsed time {datetime.datetime.now() - lowering_start_time}"
+                    )
+
+            return split_result.split_module
+
+        return PassManager.build_from_passlist([lower_func])
+
+    def _default_replace_mutable_op_pass(self) -> PassManager:
+        return PassManager.build_from_passlist([replace_mutable_op])
+
+    def build_trt_lower_pipeline(
+        self, input: Input, additional_input: Optional[Input] = None
+    ) -> PassManager:
+
+        self._input = extract_example_tensors_from_input(input)
+        self._trt_input = []
+        for input_obj in input:
+            if isinstance(input_obj, _Input.Input):
+                self._trt_input.append(InputTensorSpec.from_input(input_obj))
+            elif isinstance(input_obj, torch.Tensor):
+                self._trt_input.append(InputTensorSpec.from_tensor(input_obj))
+            else:
+                raise ValueError(
+                    "Invalid input type provided in the FX lowering. "
+                    "Expected type: torch_tensorrt.Input or torch.Tensor"
+                )
+
+        self._additional_input = additional_input
+        passes = []
+
+        passes.append(self._default_replace_mutable_op_pass())
+        passes.append(self._const_fold_pass())
+        passes.append(self.graph_optimization_pass())
+        passes.append(self._split_pass())
+        passes.append(self._trt_lower_pass())
+
+        pm = PassManager.build_from_passlist(passes)
+        return pm
+
+    def build_aten2trt_lower_pipeline(
+        self, input: Input, additional_input: Optional[Input] = None
+    ) -> PassManager:
+
+        self._input = extract_example_tensors_from_input(input)
+        self._trt_input = []
+        for input_obj in input:
+            if isinstance(input_obj, _Input.Input):
+                self._trt_input.append(InputTensorSpec.from_input(input_obj))
+            elif isinstance(input_obj, torch.Tensor):
+                self._trt_input.append(InputTensorSpec.from_tensor(input_obj))
+            else:
+                raise ValueError(
+                    "Invalid input type provided in the FX lowering. "
+                    "Expected type: torch_tensorrt.Input or torch.Tensor"
+                )
+
+        self._additional_input = additional_input
+        passes = []
+        passes.append(
+            wrapper(self._trace_func, self._input),
+        )
+        passes.append(self.graph_optimization_pass_aten())
+        passes.append(self._split_pass())
+        passes.append(self._trt_lower_pass())
+
+        pm = PassManager.build_from_passlist(passes)
+        return pm
+
+    def build_default_lower_pipeline(
+        self, input: Input, additional_input: Optional[Input] = None
+    ) -> PassManager:
+        self._input = input
+        self._additional_input = additional_input
+        passes = []
+
+        passes.append(self._default_replace_mutable_op_pass())
+        passes.append(self._const_fold_pass())
+        passes.append(self.graph_optimization_pass())
+        passes.append(self._split_pass())
+        passes.append(self._default_lower_pass())
+
+        pm = PassManager.build_from_passlist(passes)
+        return pm
diff --git a/py/torch_tensorrt/dynamo/passes/pass_utils.py b/py/torch_tensorrt/dynamo/passes/pass_utils.py
new file mode 100644
index 0000000000..3fdd4c7541
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/passes/pass_utils.py
@@ -0,0 +1,296 @@
+import io
+import logging
+import tempfile
+from datetime import datetime
+from functools import wraps
+from typing import Any, Callable, List, Optional
+
+import torch
+from torch import fx
+from torch.fx.passes.shape_prop import ShapeProp
+from torch_tensorrt import _Input
+
+# Create an alias for module input type to avoid littering pyre-ignore for Any
+# throughout the file.
+Input = Any
+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
+PassFunc = Callable[[fx.GraphModule, Input], fx.GraphModule]
+
+RELAX_ACCURACY_FAILURE: bool = False
+FINAL_CHECK_ATOL_MULTIPLIER: float = 10
+FINAL_CHECK_RTOL_MULTIPLIER: float = 10
+
+
+def extract_example_tensors_from_input(
+    inputs: Any, device: torch.device = torch.device("cuda")
+):
+    input_tensors = []
+    for input_obj in inputs:
+        if isinstance(input_obj, _Input.Input):
+            if isinstance(input_obj.shape, dict):
+                input_tensors.append(
+                    input_obj.example_tensor(optimization_profile_field="opt_shape").to(
+                        device
+                    )
+                )
+            else:
+                input_tensors.append(input_obj.example_tensor().to(device))
+        elif isinstance(input_obj, torch.Tensor):
+            input_tensors.append(input_obj)
+        else:
+            raise ValueError(
+                "Invalid input type provided in the FX lowering. "
+                "Expected type: torch_tensorrt.Input or torch.Tensor"
+            )
+
+    return input_tensors
+
+
+class RelaxAccuracyCheckMode:
+    """
+    A context manager that controls the global variables governing the
+    accuracy check mode. Use it like
+        with RelaxAccuracyCheckMode(True):
+            fx2trt()
+    """
+
+    def __init__(
+        self,
+        mode: bool,
+        final_atol_multiplier: Optional[float] = None,
+        final_rtol_multiplier: Optional[float] = None,
+    ):
+        """
+        Arguments:
+            mode: whether we relax the immediate accuracy check failure or not. If yes, we will do an
+                extra accuracy check with the tolerances raised by the given multipliers and only raise
+                an error if that fails. This is to avoid catastrophic errors.
+            final_atol_multiplier [optional]: sets FINAL_CHECK_ATOL_MULTIPLIER if specified.
+            final_rtol_multiplier [optional]: sets FINAL_CHECK_RTOL_MULTIPLIER if specified.
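+
+        Example (illustrative, reusing the class docstring's `fx2trt()` call):
+            with RelaxAccuracyCheckMode(True, final_atol_multiplier=100):
+                fx2trt()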
+ """ + global RELAX_ACCURACY_FAILURE + global FINAL_CHECK_ATOL_MULTIPLIER + global FINAL_CHECK_RTOL_MULTIPLIER + self._old_mode = ( + RELAX_ACCURACY_FAILURE, + FINAL_CHECK_ATOL_MULTIPLIER, + FINAL_CHECK_RTOL_MULTIPLIER, + ) + RELAX_ACCURACY_FAILURE = mode + FINAL_CHECK_ATOL_MULTIPLIER = ( + final_atol_multiplier + if final_atol_multiplier + else FINAL_CHECK_ATOL_MULTIPLIER + ) + FINAL_CHECK_RTOL_MULTIPLIER = ( + final_rtol_multiplier + if final_rtol_multiplier + else FINAL_CHECK_RTOL_MULTIPLIER + ) + _LOGGER.info( + f"Set new relaxed accuracy check mode: {RELAX_ACCURACY_FAILURE=}, {FINAL_CHECK_ATOL_MULTIPLIER=}, {FINAL_CHECK_RTOL_MULTIPLIER=}" + ) + + def __enter__(self): + pass + + def __exit__(self, type, value, traceback): + global RELAX_ACCURACY_FAILURE + global FINAL_CHECK_ATOL_MULTIPLIER + global FINAL_CHECK_RTOL_MULTIPLIER + ( + RELAX_ACCURACY_FAILURE, + FINAL_CHECK_ATOL_MULTIPLIER, + FINAL_CHECK_RTOL_MULTIPLIER, + ) = self._old_mode + _LOGGER.info( + f"Restored old relaxed accuracy check mode: {RELAX_ACCURACY_FAILURE=}, {FINAL_CHECK_ATOL_MULTIPLIER=}, {FINAL_CHECK_RTOL_MULTIPLIER=}" + ) + + +def chain_passes(*passes: PassFunc) -> PassFunc: + """ + Chains a sequence of pass functions to form a single pass function + """ + + def parent_pass(module: fx.GraphModule, input: Input) -> fx.GraphModule: + for pass_ in passes: + if isinstance(module, torch.fx.GraphModule): + ShapeProp(module).propagate(*input) + module = pass_(module, input) + return module + + return parent_pass + + +# (TODO(shirongwu): Add exception notification for fblearner flow when available, notify oncall +# on pass that failed accuracy check. +def validate_inference(rtol=None, atol=None): + def _validate_inference(pass_: PassFunc) -> PassFunc: + """ + Wraps a pass function to validate that its inference results before and + after the pass run should be `close`. + """ + + @wraps(pass_) + def pass_with_validation( + module: fx.GraphModule, + input: Input, + *args, + **kwargs, + ) -> fx.GraphModule: + input_tensors = extract_example_tensors_from_input(input) + res0 = module(*input_tensors) + processed_module = pass_(module, input, *args, **kwargs) + res1 = processed_module(*input_tensors) + tensor_res_0 = _collect_tensors(res0) + tensor_res_1 = _collect_tensors(res1) + relax_accuracy_check_failure = RELAX_ACCURACY_FAILURE + + for kk, (x, y) in enumerate(zip(tensor_res_0, tensor_res_1)): + kwargs2 = {"equal_nan": True} + if rtol: + kwargs2["rtol"] = rtol + if atol: + kwargs2["atol"] = atol + kwargs2[ + "msg" + ] = ( + lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}" + ) + # If tensors are on different devices, make sure to compare + # their copies that are on the same device. 
+ if x.get_device() != y.get_device(): + x = x.cpu() + y = y.cpu() + try: + torch.testing.assert_close(x, y, **kwargs2) + except Exception as e: + if relax_accuracy_check_failure: + _LOGGER.error(f"{e}") + kwargs2["rtol"] *= FINAL_CHECK_RTOL_MULTIPLIER + kwargs2["atol"] *= FINAL_CHECK_ATOL_MULTIPLIER + new_atol = kwargs2["atol"] + new_rtol = kwargs2["rtol"] + _LOGGER.info( + f"Do a sanity check to see whether things are completely wrong with {new_atol=}, {new_rtol=}" + ) + torch.testing.assert_close(x, y, **kwargs2) + return processed_module + else: + raise e + + return processed_module + + return pass_with_validation + + return _validate_inference + + +Decorator = Callable[[Callable], Callable] + + +def decorate_method(dec_for_function: Decorator) -> Decorator: + def dec_for_method(unbounded_method) -> Callable: + def decorated_unbounded_method(self, *args, **kwargs): + @dec_for_function + def bounded_method(*args, **kwargs): + return unbounded_method(self, *args, **kwargs) + + return bounded_method(*args, **kwargs) + + return decorated_unbounded_method + + return dec_for_method + + +def log_perf_before_after(pass_: PassFunc) -> PassFunc: + """ + Wraps a pass function to log perf of the module before and after the pass + """ + + @wraps(pass_) + def check_perf_with_before_after_log( + module: fx.GraphModule, input: Input + ) -> fx.GraphModule: + def benchmark_torch_function(iters: int, f, *args) -> float: + """Estimates the average time duration for a single inference call in second + + If the input is batched, then the estimation is for the batches inference call. + + Args: + iters: number of inference iterations to run + f: a function to perform a single inference call + + Returns: + estimated average time duration in second for a single inference call + """ + with torch.inference_mode(): + f(*args) + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + # print("== Start benchmark iterations") + with torch.inference_mode(): + start_event.record() + for _ in range(iters): + f(*args) + end_event.record() + torch.cuda.synchronize() + # print("== End benchmark iterations") + return (start_event.elapsed_time(end_event) * 1.0e-3) / iters + + time_before = benchmark_torch_function(100, lambda: module(*input)) + _LOGGER.info(f"[{pass_}] Perf Before(eager mode): {time_before}") + + module = pass_(module, input) + time_after = benchmark_torch_function(100, lambda: module(*input)) + _LOGGER.info(f"[{pass_}] Perf After(eager mode): {time_after}") + return module + + return check_perf_with_before_after_log + + +def log_before_after(pass_: PassFunc) -> PassFunc: + """ + Wraps a pass function to log the module graph before and after the pass + """ + + @wraps(pass_) + def pass_with_before_after_log( + module: fx.GraphModule, input: Input + ) -> fx.GraphModule: + before_io = io.StringIO() + after_io = io.StringIO() + with tempfile.NamedTemporaryFile( + mode="w", + encoding="utf-8", + delete=False, + ) as f: + print(f"[{pass_}] Before:\n{module.graph}", file=f) + print(module.graph, file=before_io) + start_time = datetime.now() + module = pass_(module, input) + t_elapsed = datetime.now() - start_time + print(f"[{pass_}] After:\n{module.graph}", file=f) + print(module.graph, file=after_io) + t = before_io.getvalue() == after_io.getvalue() + _LOGGER.info( + f"== Log pass {pass_} before/after graph to {f.name}, before/after are the same = {t}, time elapsed = {t_elapsed}" + ) + return module + + return pass_with_before_after_log + + 
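+# Usage sketch for the decorators above (illustrative; `my_pass` is a
+# hypothetical pass name):
+#
+#     @log_before_after
+#     @validate_inference(atol=1e-3, rtol=1e-2)
+#     def my_pass(module: fx.GraphModule, input: Input) -> fx.GraphModule:
+#         ...
+#
+# Stacked this way (as in lower_basic_pass.py), a pass gets its graph logged
+# before and after, and its outputs checked for numerical parity on the
+# example inputs.
+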
+def _collect_tensors(arg: fx.node.Argument) -> List[torch.Tensor]:
+    """Collects all the tensors found in a nested container object"""
+    res: List[torch.Tensor] = []
+
+    def collect(x: fx.node.Argument) -> fx.node.Argument:
+        if isinstance(x, torch.Tensor):
+            res.append(x)
+        return x
+
+    fx.node.map_aggregate(arg, collect)
+    return res
diff --git a/py/torch_tensorrt/dynamo/passes/remove_duplicate_output_args.py b/py/torch_tensorrt/dynamo/passes/remove_duplicate_output_args.py
new file mode 100644
index 0000000000..84a522a3f0
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/passes/remove_duplicate_output_args.py
@@ -0,0 +1,140 @@
+#!/usr/bin/env python3
+
+import dataclasses as dc
+import logging
+import operator
+import typing as t
+
+import torch.fx as fx
+
+_LOGGER = logging.getLogger(__name__)
+
+RemoveDuplicateOutputArgsFunc = t.Callable[
+    [
+        fx.GraphModule,
+        t.Collection[str],
+    ],
+    t.Mapping[str, "RemoveDuplicateResult"],
+]
+
+
+def remove_duplicate_output_args(
+    top_level: fx.GraphModule, target_subnets: t.Collection[str]
+) -> t.Mapping[str, "RemoveDuplicateResult"]:
+    """Removes duplicate output args.
+
+    This pass removes duplicate output args from the target subnets and fixes
+    their uses in the top level module where the subnets are called. This pass
+    must be called after acc split on the top-level net and subsequent calls to
+    the acc trace on the subnets.
+
+    This pass will change both the subnets and the top level module.
+
+    Returns:
+        a mapping from each target subnet name to its dedup result
+    """
+
+    processed_subnets = {}
+    for node in top_level.graph.nodes:  # type: fx.Node
+        if node.op == "call_module" and node.name in target_subnets:
+            assert isinstance(node.target, str)
+            sub_gm = top_level.get_submodule(node.target)
+            assert isinstance(sub_gm, fx.GraphModule)
+
+            replace_res = _remove_duplicate_output_args(sub_gm)
+            processed_subnets[node.name] = replace_res
+            if replace_res.replacement_map is None:
+                continue
+            sub_gm.recompile()
+
+            needs_recompile = False
+            # iterate over a copy, since we will be changing elements of node.users
+            for user in list(node.users):
+                idx = _ensure_proper_output_use(user, node)
+                idx_new = replace_res.replacement_map[idx]
+                if idx_new != idx:
+                    user.args = (user.args[0], idx_new)
+                    needs_recompile = True
+
+            if needs_recompile:
+                top_level.recompile()
+    return processed_subnets
+
+
+@dc.dataclass(frozen=True)
+class RemoveDuplicateResult:
+    replacement_map: t.Optional[t.List[int]]
+    module: fx.GraphModule
+
+
+def _ensure_proper_output_use(user: fx.Node, target_node: fx.Node) -> int:
+    """
+    Ensures the node is a proper use of the output of an fx2trt splitter
+    sub-net. Specifically:
+
+    1. op is call_function, target: operator.getitem
+    2. args is a 2-element tuple
+    3. args[0] is the name of the subnet's output
+    4. args[1] is the index into the subnet output tuple
+
+    E.g.:
+
+    %getitem_4 : [#users=1] = call_function[target=operator.getitem](args = (%_run_on_acc_1, 4), kwargs = {})
+
+    Returns the index into the subnet output tuple.
+    """
+    assert (
+        user.op == "call_function"
+        and user.target == operator.getitem
+        and len(user.args) == 2
+        and isinstance(user.args[0], fx.Node)
+        and user.args[0].name == target_node.name
+        and isinstance(user.args[1], int)
+    ), f"Node is not a proper user of splitter output: {user.format_node()}"
+
+    return user.args[1]
+
+
+def _remove_duplicate_output_args(gm: fx.GraphModule) -> RemoveDuplicateResult:
+    output_nodes = [n for n in gm.graph.nodes if n.op == "output"]
+    assert (
+        len(output_nodes) == 1
+    ), f"Expecting exactly one `output` node, but got {len(output_nodes)}"
+
+    changed = False
+    # arg node name to its index in the new output args tuple
+    name_to_idx: t.Dict[str, int] = {}
+    output_node = output_nodes[0]
+
+    # Output op only uses its `args[0]`, and it does not have `kwargs`.
+    # https://pytorch.org/docs/stable/fx.html#torch.fx.Node
+    args: t.Sequence[t.Any] = output_node.args[0]
+
+    # We only concern ourselves with the case where the args are an iterable of
+    # fx.Node. Other return forms (e.g., a single value) are possible, but this
+    # pass does not handle them.
+    if not (isinstance(args, t.Iterable) and all(isinstance(a, fx.Node) for a in args)):
+        return RemoveDuplicateResult(replacement_map=None, module=gm)
+
+    # Maps the old index of each arg node to the index of the node that remains,
+    # initialized to the identity `i => i`
+    replacement_map: t.List[int] = list(range(len(args)))
+    args_new = []
+    for idx, a in enumerate(args):
+        assert isinstance(a, fx.Node), f"Expecting fx.Node instance, but got: {type(a)}"
+
+        if a.name not in name_to_idx:
+            args_new.append(a)
+            name_to_idx[a.name] = len(args_new) - 1
+        else:
+            changed = True
+            _LOGGER.warning(
+                f"Replaced duplicate output arg '{a.name}': "
+                f"{idx} -> {name_to_idx[a.name]}"
+            )
+            replacement_map[idx] = name_to_idx[a.name]
+
+    output_node.args = (tuple(args_new),)
+    if changed:
+        gm.recompile()
+    return RemoveDuplicateResult(replacement_map, module=gm)
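A quick illustration of the dedup logic in isolation, exercising the private helper on a toy module (in real use the top-level driver above is applied to split subnets):

    import torch
    import torch.fx as fx

    class Dup(torch.nn.Module):
        def forward(self, x):
            y = x + 1
            return y, y  # the second output duplicates the first

    gm = fx.symbolic_trace(Dup())
    res = _remove_duplicate_output_args(gm)
    print(res.replacement_map)  # [0, 0]: old output index 1 now maps to index 0
    print(gm(torch.zeros(2)))   # the recompiled module now returns a 1-tuple
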
diff --git a/py/torch_tensorrt/dynamo/tools/__init__.py b/py/torch_tensorrt/dynamo/tools/__init__.py
new file mode 100644
index 0000000000..6423aa65ea
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/tools/__init__.py
@@ -0,0 +1 @@
+from .trt_minimizer import *  # noqa: F401 F403
diff --git a/py/torch_tensorrt/dynamo/tools/common_fx2trt.py b/py/torch_tensorrt/dynamo/tools/common_fx2trt.py
new file mode 100644
index 0000000000..be99562455
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/tools/common_fx2trt.py
@@ -0,0 +1,445 @@
+import logging
+import time
+import unittest
+from typing import Callable, List, Optional, Set, Tuple
+
+import torch
+import torch.fx
+
+import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
+import torch_tensorrt.fx.tracer.dispatch_tracer.aten_tracer as aten_tracer
+from torch.fx.experimental.normalize import NormalizeArgs
+from torch.fx.passes import shape_prop
+from torch.fx.passes.infra.pass_base import PassResult
+from torch.testing._internal.common_utils import TestCase
+from torch_tensorrt.dynamo import InputTensorSpec, TRTInterpreter, TRTModule
+from torch_tensorrt.dynamo.passes.lower_basic_pass_aten import (
+    compose_bmm,
+    compose_chunk,
+    compose_getitem_slice,
+    remove_ops,
+    replace_aten_op_with_indices,
+    replace_aten_reshape_alias_with_replace,
+    replace_builtin_ops,
+    replace_native_layernorm_with_layernorm,
+    replace_transpose_mm_op_with_linear,
+    run_const_fold,
+)
+from torch_tensorrt.dynamo.passes.pass_utils import chain_passes
+from torch_tensorrt.dynamo.utils import LowerPrecision, proxytensor_trace
+
+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+def fetch_attr(mod, target):
+    """
+    Fetch an attribute from the ``Module`` hierarchy of ``mod``.
+
+    Args:
+        target (str): The fully-qualified name of the attribute to fetch
+
+    Return:
+        Any: The value of the attribute.
+    """
+    target_atoms = target.split(".")
+    attr_itr = mod
+    for i, atom in enumerate(target_atoms):
+        if not hasattr(attr_itr, atom):
+            raise RuntimeError(
+                f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
+            )
+        attr_itr = getattr(attr_itr, atom)
+    return attr_itr
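For example, `fetch_attr` resolves a dotted path against the module hierarchy:

    import torch

    seq = torch.nn.Sequential(torch.nn.Linear(2, 2))
    # "0.weight" walks seq -> seq[0] -> weight, raising if any step is missing.
    assert fetch_attr(seq, "0.weight") is seq[0].weight
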
+
+
+@unittest.skipIf(not torch.cuda.is_available(), "Skip because CUDA is not available")
+class TRTTestCase(TestCase):
+    def setUp(self):
+        super().setUp()
+        torch.manual_seed(3)
+
+    def run_test(
+        self,
+        mod,
+        inputs,
+        expected_ops,
+        unexpected_ops,
+        interpreter,
+        rtol,
+        atol,
+        precision=LowerPrecision.FP32,
+    ):
+        with torch.no_grad():
+            cuda_inputs = []
+            for i in inputs:
+                cuda_inputs.append(i.cuda())
+
+            mod.eval()
+            if len(expected_ops):
+                self.assert_has_op(mod, expected_ops)
+            if unexpected_ops:
+                self.assert_unexpected_op(mod, unexpected_ops)
+            start = time.perf_counter()
+            interpreter_result = interpreter.run(lower_precision=precision)
+            sec = time.perf_counter() - start
+            _LOGGER.info(f"Interpreter run time (s): {sec}")
+            trt_mod = TRTModule(
+                interpreter_result.engine,
+                interpreter_result.input_names,
+                interpreter_result.output_names,
+            )
+
+            ref_outputs = mod(*inputs)
+
+            torch.cuda.synchronize()
+            start_event = torch.cuda.Event(enable_timing=True)
+            end_event = torch.cuda.Event(enable_timing=True)
+            start_event.record()
+            outputs = trt_mod(*cuda_inputs)
+            end_event.record()
+            torch.cuda.synchronize()
+            _LOGGER.info(
+                f"TRT run time (s): {(start_event.elapsed_time(end_event) * 1.0e-3)}"
+            )
+
+            if type(outputs) not in (list, tuple):
+                outputs = [outputs]
+            if type(ref_outputs) not in (
+                list,
+                tuple,
+                torch.return_types.max,
+                torch.return_types.min,
+            ):
+                ref_outputs = [ref_outputs]
+            for out, ref in zip(outputs, ref_outputs):
+                if not isinstance(ref, torch.Tensor):
+                    ref = torch.tensor([ref])
+                ref = ref.cpu()  # to_dtype test has cases with gpu output
+                if ref.dtype == torch.int64:
+                    ref = ref.int()  # convert torch.max's index output tensor to int32
+                torch.testing.assert_close(
+                    out.cpu(), ref, rtol=rtol, atol=atol, equal_nan=True
+                )
+
+    def run_test_custom_compare_results(
+        self,
+        mod,
+        inputs,
+        expected_ops,
+        interpreter,
+        comparators: List[Tuple[Callable, List]],
+        fp16_mode=False,
+    ):
+        """
+        Runs the test and compares the results using the provided comparators.
+        The size of comparators must be equal to the number of outputs from 'mod'.
+
+        mod          - a model to run.
+        inputs       - a list of the model inputs.
+        expected_ops - a list of ops that should be verified.
+        interpreter  - used for converting the model to TRT.
+        comparators  - a list of (func, args) pairs corresponding to each of
+                       the module outputs. usage: func(x, y, *args)
+        """
+        with torch.no_grad():
+            cuda_inputs = []
+            for i in inputs:
+                cuda_inputs.append(i.cuda())
+
+            mod.eval()
+            if len(expected_ops):
+                self.assert_has_op(mod, expected_ops)
+
+            interpreter_result = interpreter.run(
+                lower_precision=LowerPrecision.FP16
+                if fp16_mode
+                else LowerPrecision.FP32
+            )
+            trt_mod = TRTModule(
+                interpreter_result.engine,
+                interpreter_result.input_names,
+                interpreter_result.output_names,
+            )
+            res_trt = trt_mod(*cuda_inputs).cpu()
+            res_cpu = mod(*inputs)
+            assert len(res_trt) == len(res_cpu)
+            assert len(res_cpu) == len(comparators)
+            for output_trt, output_cpu, comparator in zip(
+                res_trt, res_cpu, comparators
+            ):
+                comp_func = comparator[0]
+                args = comparator[1]
+                self.assertTrue(comp_func(output_trt, output_cpu, *args))
+
+    def run_test_with_error(self, mod, inputs, interpreter, expect_error):
+        with self.assertRaises(expect_error):
+            with torch.no_grad():
+                cuda_inputs = []
+                for i in inputs:
+                    cuda_inputs.append(i.cuda())
+
+                mod.eval()
+                interpreter.run(lower_precision=LowerPrecision.FP32)
+
+    def assert_has_op(self, mod, ops):
+        ops_in_mod = set()
+
+        for node in mod.graph.nodes:
+            if node.op == "call_module":
+                ops_in_mod.add(type(fetch_attr(mod, node.target)))
+            elif node.op in {"call_function", "call_method"}:
+                ops_in_mod.add(node.target)
+
+        self.assertTrue(
+            ops_in_mod >= ops, f"expected ops {ops}, actual ops {ops_in_mod}"
+        )
+
+    def assert_unexpected_op(self, mod, ops):
+        for node in mod.graph.nodes:
+            if node.op == "call_module":
+                if type(fetch_attr(mod, node.target)) in ops:
+                    return False
+            elif node.op in {"call_function", "call_method"}:
+                if node.target in ops:
+                    return False
+        return True
+
+
+class VanillaTestCase(TRTTestCase):
+    def run_test(self, mod, inputs, expected_ops, rtol=1e-03, atol=1e-03):
+        mod = torch.fx.symbolic_trace(mod)
+        shape_prop.ShapeProp(mod).propagate(*inputs)
+        mod = NormalizeArgs(mod).transform()
+        interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
+        super().run_test(mod, inputs, expected_ops, None, interp, rtol, atol)
+
+    def run_test_custom_compare_results(
+        self,
+        mod,
+        inputs,
+        expected_ops,
+        interpreter,
+        comparators: List[Tuple[Callable, List]],
+        fp16_mode=False,
+    ):
+        # interpreter is ignored; we do not need it for Vanilla tests.
+        # Note this differs from the internal version; the test case needs fixing
+        # once the internal callsites are refactored to use this file.
+        mod = torch.fx.symbolic_trace(mod)
+        shape_prop.ShapeProp(mod).propagate(*inputs)
+        mod = NormalizeArgs(mod).transform()
+        interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
+        super().run_test_custom_compare_results(
+            mod, inputs, expected_ops, interp, comparators, fp16_mode=fp16_mode
+        )
+
+
+class AccTestCase(TRTTestCase):
+    def run_test(
+        self,
+        mod,
+        inputs,
+        expected_ops,
+        unexpected_ops=None,
+        apply_passes=None,
+        test_explicit_batch_dim=True,
+        test_implicit_batch_dim=True,
+        test_explicit_precision=False,
+        rtol=1e-03,
+        atol=1e-03,
+        precision=LowerPrecision.FP32,
+    ):
+        mod.eval()
+        mod = acc_tracer.trace(mod, inputs)
+
+        if apply_passes is not None:
+            pass_tracer = chain_passes(*apply_passes)
+            mod = pass_tracer(mod, inputs)
+
+        if test_implicit_batch_dim:
+            interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
+            super().run_test(
+                mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision
+            )
+
+        if test_explicit_batch_dim:
+            interp = TRTInterpreter(
+                mod, InputTensorSpec.from_tensors(inputs), 
explicit_batch_dimension=True + ) + super().run_test( + mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision + ) + + if test_explicit_precision: + interp = TRTInterpreter( + mod, + InputTensorSpec.from_tensors(inputs), + explicit_precision=test_explicit_precision, + ) + super().run_test( + mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol + ) + + interp = TRTInterpreter( + mod, + InputTensorSpec.from_tensors(inputs), + explicit_batch_dimension=True, + explicit_precision=test_explicit_precision, + ) + super().run_test( + mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision + ) + + def run_test_with_assert_error( + self, + mod, + inputs, + expect_error, + test_explicit_batch_dim=True, + test_implicit_batch_dim=True, + ): + mod.eval() + mod = acc_tracer.trace(mod, inputs) + + if test_implicit_batch_dim: + interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs)) + super().run_test_with_error(mod, inputs, interp, expect_error) + + if test_explicit_batch_dim: + interp = TRTInterpreter( + mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True + ) + super().run_test_with_error(mod, inputs, interp, expect_error) + + def run_test_with_dynamic_shape( + self, + mod, + input_specs, + expected_ops, + unexpected_ops=None, + rtol=1e-03, + atol=1e-03, + ): + mod.eval() + inputs = InputTensorSpec.create_inputs_from_specs(input_specs) + mod = acc_tracer.trace(mod, inputs) + interp = TRTInterpreter(mod, input_specs, explicit_batch_dimension=True) + super().run_test(mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol) + + +class DispatchTestCase(TRTTestCase): + def generate_graph( + self, + mod: torch.nn.Module, + original_inputs: List[torch.Tensor], + expected_ops: Set[Callable], + unexpected_ops: Optional[Set[Callable]] = None, + customized_passes: List[Callable] = None, + ): + # Torchdynamo+aot proxytensor tracer + # Below are common passes + passes_list = [ + compose_bmm, + compose_chunk, + compose_getitem_slice, + replace_aten_reshape_alias_with_replace, + replace_aten_op_with_indices, + replace_transpose_mm_op_with_linear, # after compose_bmm + replace_native_layernorm_with_layernorm, + remove_ops, + replace_builtin_ops, # after replace_native_layernorm_with_layernorm + ] + # Combine with customized passes specific to any model + if customized_passes: + passes_list.extend(customized_passes) + fx_module, _ = aten_tracer.trace(mod, original_inputs) + for passes in passes_list: + pr: PassResult = passes(fx_module) + fx_module = pr.graph_module + fx_module(*original_inputs) + + fx_module = run_const_fold(fx_module) + _LOGGER.info(f"FX graph= {fx_module.graph}") + + if len(expected_ops): + self.assert_has_op(fx_module, expected_ops) + if unexpected_ops: + self.assert_unexpected_op(fx_module, unexpected_ops) + + return fx_module + + def run_test( + self, + mod, + inputs, + expected_ops, + unexpected_ops=None, + apply_passes=None, + test_explicit_batch_dim=True, + test_explicit_precision=False, + rtol=1e-03, + atol=1e-03, + precision=LowerPrecision.FP32, + ): + mod.eval() + mod = self.generate_graph(mod, inputs, expected_ops, unexpected_ops, None) + + if apply_passes is not None: + pass_tracer = chain_passes(*apply_passes) + mod = pass_tracer(mod, inputs) + + if test_explicit_batch_dim: + interp = TRTInterpreter( + mod, + InputTensorSpec.from_tensors(inputs), + explicit_batch_dimension=True, + ) + super().run_test( + mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision + ) + + if 
test_explicit_precision:
+            interp = TRTInterpreter(
+                mod,
+                InputTensorSpec.from_tensors(inputs),
+                explicit_precision=test_explicit_precision,
+            )
+            super().run_test(
+                mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol
+            )
+
+            interp = TRTInterpreter(
+                mod,
+                InputTensorSpec.from_tensors(inputs),
+                explicit_batch_dimension=True,
+                explicit_precision=test_explicit_precision,
+            )
+            super().run_test(
+                mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision
+            )
+
+    def run_test_with_dynamic_shape(
+        self,
+        mod,
+        input_specs,
+        expected_ops,
+        unexpected_ops=None,
+        rtol=1e-03,
+        atol=1e-03,
+    ):
+        mod.eval()
+        inputs = InputTensorSpec.create_inputs_from_specs(input_specs)
+        mod = self.generate_graph(mod, inputs, expected_ops, unexpected_ops, None)
+
+        interp = TRTInterpreter(
+            mod,
+            input_specs,
+            explicit_batch_dimension=True,
+        )
+        # Since the lowering is based on the optimal shape, we also need to test
+        # with a different shape (e.g., the max shape) to exercise dynamic shapes.
+        inputs_max = InputTensorSpec.create_inputs_from_max_specs(input_specs)
+        super().run_test(
+            mod, inputs_max, expected_ops, unexpected_ops, interp, rtol, atol
+        )
diff --git a/py/torch_tensorrt/dynamo/tools/engine_layer_visualize.py b/py/torch_tensorrt/dynamo/tools/engine_layer_visualize.py
new file mode 100644
index 0000000000..cecd1ecb20
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/tools/engine_layer_visualize.py
@@ -0,0 +1,217 @@
+import argparse
+import logging
+import re
+from typing import Any, Dict, List, NamedTuple, Optional, Tuple
+
+import pydot
+
+_LOGGER: logging.Logger = logging.getLogger(__name__)
+"""
+log_file is generated by the tensorrt verbose logger while building the engine.
+profile_file is generated by the tensorrt profiler.
+
+Currently we support processing multiple logs in one log_file, which generates
+multiple dot graphs. However, multiple engine profiles are not supported.
+
+Usage:
+    python py/torch_tensorrt/dynamo/tools/engine_layer_visualize.py --log_file aaa --profile_file bbb
+"""
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--log_file",
+    type=str,
+    default="",
+    help="TensorRT VERBOSE logging when building engines.",
+)
+parser.add_argument(
+    "--profile_file",
+    type=str,
+    default="",
+    help="TensorRT execution context profiler output.",
+)
+args = parser.parse_args()
+
+
+class LayerInfo(NamedTuple):
+    kernel_name: str
+    layer_name: str
+    tactic: str
+    input_names: Optional[List[str]]
+    input_types: Optional[List[str]]
+    output_name: str
+    output_type: str
+    time: str
+
+    @classmethod
+    def from_string(cls, string, tactic_names, layer_times=None):
+        input_names = []
+        input_types = []
+        kernel_name, layer_name, tactic, inputs, output_name, output_type = re.findall(
+            "Layer\\((.+)\\): (.+), Tactic: (-?\\d+), (.+)? 
-> (.+)\\[(.+)\\]", string + )[0] + + if kernel_name != "Constant": + inputs = re.findall( + "[, ]*(.+?)\\[([Half|Float|Int8]+\\(\\d[,\\d]*\\))\\]", inputs + ) + for input_name, input_type in inputs: + input_names.append(input_name) + input_types.append(input_type) + + if layer_name in tactic_names: + kernel_name = tactic_names[layer_name] + else: + input_names = input_types = None # type:ignore[assignment] + + return cls( + kernel_name, + layer_name, + tactic, + input_names, + input_types, + output_name, + output_type, + layer_times[layer_name] if layer_times else "NA", + ) + + +def build_node(layer): + layer_name = layer.layer_name.replace("|", "\\|") + label = f"{{{layer_name}|kernel: {layer.kernel_name}\\l|tactic: {layer.tactic}\\l|time: {layer.time}\\l}}" + label = label.replace(">", "\\>") + return pydot.Node(layer.layer_name, label=label, **style) + + +def build_edge(layer, graph, reformat_layers, output_name2node, layer_name2node): + if layer.input_names is None: + return + + for input_name, input_type in zip(layer.input_names, layer.input_types): + if input_name not in output_name2node: + if input_name in reformat_layers: + from_node = pydot.Node( + input_name, + label="{reformatter|kernel: Reformat\\l|tactic: 0\\l}", + **style, + ) + graph.add_node(from_node) + if reformat_layers[input_name][0] in output_name2node: + graph.add_edge( + pydot.Edge( + output_name2node[reformat_layers[input_name][0]], + from_node, + label=f"{reformat_layers[input_name][0]}\\l{reformat_layers[input_name][1]}\\l", + ) + ) + else: + _LOGGER.info(f"Missing node {input_name}") + from_node = input_name + else: + from_node = output_name2node[input_name] + + edge_name = input_name.replace(">", "\\>") + graph.add_edge( + pydot.Edge( + from_node, + layer_name2node[layer.layer_name], + label=f"{edge_name}\\l{input_type}\\l", + ) + ) + + +if args.profile_file != "": + layer_times = {} + with open(args.profile_file) as f: + times = f.readlines() + + for t in times: + t = t.strip("\n").split(": ") # type: ignore[assignment] + layer_times[": ".join(t[:-1])] = t[-1] +else: + layer_times = None # type: ignore[assignment] + +if args.log_file != "": + with open(args.log_file) as f: + lines = f.readlines() + + graphs = [] + layers = [] + reformat_layers: Dict[str, Tuple[str, str]] = {} + tactic_names: Dict[str, str] = {} + layer_info_start = False + tactic_name_start = False + + for line in lines: + line = line.strip("\n") + + if layer_info_start: + if "Layer(" in line: + layers.append(LayerInfo.from_string(line, tactic_names, layer_times)) + else: + layer_info_start = False + graphs.append((layers, reformat_layers)) + layers = [] + reformat_layers = {} + tactic_names = {} + + if tactic_name_start and "Set Tactic Name:" in line: + layer_name, kernel_name, _ = re.findall( + "VERBOSE: (.*) Set Tactic Name: (.*) Tactic: (.*)$", line + )[0] + tactic_names[layer_name] = kernel_name + + # Some reformat layers aren't displayed in Engine Layer Information + if "Adding reformat layer" in line: + output_name, input_name, from_type, to_type = re.findall( + "reformat layer: (.+) \\((.+)\\) from (.+) to (.+)", line + )[0] + reformat_layers[output_name] = (input_name, from_type) + + if "Total Activation Memory:" in line: + tactic_name_start = True + + if "Engine Layer Information" in line: + layer_info_start = True + tactic_name_start = False + + style = { + "shape": "record", + "fillcolor": "Salmon", + "style": '"filled,rounded"', + "fontcolor": "#000000", + } + + dot_graphs: List[Any] = [] + i = 0 + for layers, reformat_layers in 
graphs: + output_name2node = {} + layer_name2node = {} + dot_graph = pydot.Dot("Layer Graph") + + for layer in layers: + node = build_node(layer) + dot_graph.add_node(node) + output_name2node[layer.output_name] = node + layer_name2node[layer.layer_name] = node + + for layer in layers: + build_edge( + layer, dot_graph, reformat_layers, output_name2node, layer_name2node + ) + + dot_graph.write_raw(f"EngineLayers_{i}.dot") + i += 1 + +if args.profile_file != "": + est_reformat_time = 0.0 + est_total_time = 0.0 + + for layer in layers: + if layer.kernel_name == "Reformat": + est_reformat_time += float(layer.time[:-2]) + est_total_time += float(layer.time[:-2]) + + _LOGGER.info(f"Time Cost on Reformatting: {est_reformat_time} ms") + _LOGGER.info(f"Total Time Cost: {est_total_time} ms") diff --git a/py/torch_tensorrt/dynamo/tools/graph_util.py b/py/torch_tensorrt/dynamo/tools/graph_util.py new file mode 100644 index 0000000000..5d07f76641 --- /dev/null +++ b/py/torch_tensorrt/dynamo/tools/graph_util.py @@ -0,0 +1,78 @@ +import graphviz # type: ignore[import] + + +def get_layer_name_type(layer): + return "\n".join(f"{i}" for i in [layer.name, layer.type]) + + +def trt_network_to_dot_graph(network): + dot = graphviz.Digraph(comment="Network") + + # add nodes (layers) + for i in range(network.num_layers): + layer = network.get_layer(i) + dot.node(get_layer_name_type(layer)) + + # add nodes (inputs) + for i in range(network.num_inputs): + dot.node(network.get_input(i).name) + + # add nodes (outputs) + for i in range(network.num_outputs): + dot.node(network.get_output(i).name) + + # add layer->layer edges + for a in range(network.num_layers): + layer_a = network.get_layer(a) + + for b in range(network.num_layers): + layer_b = network.get_layer(b) + + for i in range(layer_a.num_outputs): + output_i = layer_a.get_output(i) + + for j in range(layer_b.num_inputs): + input_j = layer_b.get_input(j) + + if output_i == input_j: + dot.edge( + get_layer_name_type(layer_a), + get_layer_name_type(layer_b), + label=str(input_j.shape), + ) + + # add input->layer edges + for i in range(network.num_inputs): + input_i = network.get_input(i) + + for b in range(network.num_layers): + layer_b = network.get_layer(b) + + for j in range(layer_b.num_inputs): + input_j = layer_b.get_input(j) + + if input_i == input_j: + dot.edge( + input_i.name, + get_layer_name_type(layer_b), + label=str(input_j.shape), + ) + + # add layer->output edges + for i in range(network.num_outputs): + input_i = network.get_output(i) + + for b in range(network.num_layers): + layer_b = network.get_layer(b) + + for j in range(layer_b.num_outputs): + input_j = layer_b.get_output(j) + + if input_i == input_j: + dot.edge( + get_layer_name_type(layer_b), + input_i.name, + label=str(input_j.shape), + ) + + return dot diff --git a/py/torch_tensorrt/dynamo/tools/model_packager.py b/py/torch_tensorrt/dynamo/tools/model_packager.py new file mode 100644 index 0000000000..0ef0ff05a4 --- /dev/null +++ b/py/torch_tensorrt/dynamo/tools/model_packager.py @@ -0,0 +1,126 @@ +from pathlib import Path +from typing import BinaryIO, Sequence, TextIO, Union + +import torch +from torch.fx.passes.split_utils import getattr_recursive +from torch.package import PackageExporter + +""" +A tool to package acc submodule as a torch package. The packaged model can be loaded +with just PyTorch library. 
+""" + + +def flatten_model(model: torch.fx.GraphModule) -> torch.fx.GraphModule: + """ + Remove all original modules with an attr holder module so that all original modules + and names are not present. + """ + holder_module = torch.nn.Module() + model._holder = holder_module + attr_id = 0 + + for node in model.graph.nodes: + assert node.op != "call_module" + if node.op == "get_attr": + attr = getattr_recursive(model, node.target) + setattr(holder_module, f"_attr_{attr_id}", attr) + with model.graph.inserting_before(node): + new_node = model.graph.get_attr(f"_holder._attr_{attr_id}") + node.replace_all_uses_with(new_node) + attr_id += 1 + + model.graph.eliminate_dead_code() + model.recompile() + model.delete_all_unused_submodules() + return model + + +def generate_standalone_repro( + model: torch.fx.GraphModule, output: Union[str, Path, TextIO], prelude: str = "" +) -> None: + """ + Generate a standalone python file for the model where weights are randomized + and the model flattened. + This only works if leaf nodes are only torch.nn modules. + """ + model = flatten_model(model) + + INDENT = " " + lines = [ + "", + "import torch", + "from torch import nn", + "", + "", + "class ExportedModule(nn.Module):", + f"{INDENT}def __init__(self):", + f"{INDENT * 2}super().__init__()", + ] + for k, v in model._holder.named_parameters(): + shape = ", ".join([str(i) for i in v.shape]) + rand_func = "randn" if torch.is_floating_point(v) else "randint" + int_range = "" if torch.is_floating_point(v) else "0, 5, " + lines.append( + f"{INDENT * 2}self.{k} = nn.Parameter(torch.{rand_func}({int_range}{shape}, dtype={v.dtype}))" + ) + code = str(model.code) + + def dump(f): + f.write(prelude) + f.write("\n".join(lines)) + f.write( + "\n".join( + [ + INDENT + line.replace("self._holder.", "self.") + for line in code.split("\n") + ] + ) + ) + f.write("\n") + + if isinstance(output, (Path, str)): + with open(str(output), "w") as f: + dump(f) + else: + dump(output) + + +class ModelPackager: + @classmethod + def set_extern_modules(cls, pe: PackageExporter) -> None: + pe.extern( + [ + "builtins", + "sys", + "torch.**", + ] + ) + + @classmethod + def set_mocked_modules(cls, pe: PackageExporter): + pe.mock( + "**", + exclude=[ + "torch_tensorrt.fx.tracer.acc_tracer.acc_ops", + "torch_tensorrt.fx.tracer.acc_tracer.acc_normalizer", + "torch_tensorrt.fx.tracer.acc_tracer.acc_op_properties", + ], + ) + + @classmethod + def package_model( + cls, + model: torch.nn.Module, + model_inputs: Sequence[torch.Tensor], + output: Union[str, Path, BinaryIO], + preserve_model_structure: bool = False, + ) -> None: + if not preserve_model_structure: + model = flatten_model(model) + with PackageExporter(output) as pe: + cls.set_extern_modules(pe) + cls.set_mocked_modules(pe) + pe.intern("**") + pe.save_pickle("repro", "model", model) + pe.save_pickle("repro", "inputs", model_inputs) diff --git a/py/torch_tensorrt/dynamo/tools/node_profiler.py b/py/torch_tensorrt/dynamo/tools/node_profiler.py new file mode 100644 index 0000000000..1a37c27197 --- /dev/null +++ b/py/torch_tensorrt/dynamo/tools/node_profiler.py @@ -0,0 +1,53 @@ +from typing import Any + +import torch +from torch import fx + + +class NodeProfiler(fx.Interpreter): + """ + This is basically a variant of shape prop in + https://github.com/pytorch/pytorch/blob/74849d9188de30d93f7c523d4eeceeef044147a9/torch/fx/passes/shape_prop.py#L65. + Instead of propagating just the shape, we record all the intermediate node Tensor values. 
diff --git a/py/torch_tensorrt/dynamo/tools/node_profiler.py b/py/torch_tensorrt/dynamo/tools/node_profiler.py
new file mode 100644
index 0000000000..1a37c27197
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/tools/node_profiler.py
@@ -0,0 +1,53 @@
+from typing import Any
+
+import torch
+from torch import fx
+
+
+class NodeProfiler(fx.Interpreter):
+    """
+    This is basically a variant of shape prop in
+    https://github.com/pytorch/pytorch/blob/74849d9188de30d93f7c523d4eeceeef044147a9/torch/fx/passes/shape_prop.py#L65.
+    Instead of propagating just the shape, we record the average execution time of
+    every call_function, call_method, and call_module node. This is useful for
+    debugging lowering pass issues where we want to find out which nodes are slow.
+    """
+
+    def __init__(self, module: fx.GraphModule):
+        super().__init__(module)
+        self.execution_time = {}
+        self.node_map = {}
+        self.iter = 100
+
+    def run_node(self, n: fx.Node) -> Any:
+        result = super().run_node(n)
+        if n.op not in {"call_function", "call_method", "call_module"}:
+            return result
+
+        torch.cuda.synchronize()
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+        start_event.record()
+
+        for _ in range(self.iter):
+            result = super().run_node(n)
+
+        end_event.record()
+        torch.cuda.synchronize()
+
+        self.execution_time[f"{n.name}"] = (
+            start_event.elapsed_time(end_event) / self.iter
+        )
+        self.node_map[n.name] = n
+        return result
+
+    def propagate(self, *args):
+        """
+        Run `module` via interpretation, return the result, and record the
+        average execution time of each node.
+        Args:
+            *args (Tensor): the sample input.
+        Returns:
+            Any: The value returned from executing the Module
+        """
+        return super().run(*args)
diff --git a/py/torch_tensorrt/dynamo/tools/tensor_prop.py b/py/torch_tensorrt/dynamo/tools/tensor_prop.py
new file mode 100644
index 0000000000..a52e0a3929
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/tools/tensor_prop.py
@@ -0,0 +1,33 @@
+from typing import Any
+
+from torch import fx
+
+
+class TensorProp(fx.Interpreter):
+    """
+    This is basically a variant of shape prop in
+    https://github.com/pytorch/pytorch/blob/74849d9188de30d93f7c523d4eeceeef044147a9/torch/fx/passes/shape_prop.py#L65.
+    Instead of propagating just the shape, we record all the intermediate node
+    Tensor values. This is useful for debugging lowering pass issues where we want
+    to check a specific tensor value. Note that the output value can be a
+    tuple(Tensor) as well as a Tensor.
+    """
+
+    def __init__(self, module: fx.GraphModule):
+        super().__init__(module)
+        self.tensor_map = {}
+
+    def run_node(self, n: fx.Node) -> Any:
+        result = super().run_node(n)
+        self.tensor_map[n.name] = result
+        return result
+
+    def propagate(self, *args):
+        """
+        Run `module` via interpretation, return the result, and record the
+        output value of each node.
+        Args:
+            *args (Tensor): the sample input.
+        Returns:
+            Any: The value returned from executing the Module
+        """
+        return super().run(*args)
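A usage sketch for the profiler (CUDA is required, since the timing uses CUDA events; the function and shapes are illustrative):

    import torch
    import torch.fx

    def f(x):
        return torch.relu(x) + 1

    gm = torch.fx.symbolic_trace(f)
    profiler = NodeProfiler(gm)
    profiler.propagate(torch.randn(1024, 1024, device="cuda"))

    # Average per-node execution time in milliseconds, slowest first.
    for node_name, ms in sorted(profiler.execution_time.items(), key=lambda kv: -kv[1]):
        print(f"{node_name}: {ms:.4f} ms")
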
diff --git a/py/torch_tensorrt/dynamo/tools/timing_cache_utils.py b/py/torch_tensorrt/dynamo/tools/timing_cache_utils.py
new file mode 100644
index 0000000000..4580843e98
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/tools/timing_cache_utils.py
@@ -0,0 +1,39 @@
+import logging
+import os
+
+logger = logging.getLogger(__name__)
+
+
+class TimingCacheManager:
+    def __init__(self, timing_cache_prefix: str = "", save_timing_cache=False):
+        # Configure the timing cache prefix for TRTInterpreter
+        tc = os.environ.get("TRT_TIMING_CACHE_PREFIX", "")
+        timing_cache_prefix_name = timing_cache_prefix
+        if not timing_cache_prefix and tc:
+            timing_cache_prefix_name = tc
+
+        self.timing_cache_prefix_name = timing_cache_prefix_name
+        self.save_timing_cache = save_timing_cache
+
+    def get_file_full_name(self, name: str):
+        return f"{self.timing_cache_prefix_name}_{name}.npy"
+
+    def get_timing_cache_trt(self, timing_cache_file: str) -> bytearray:
+        timing_cache_file = self.get_file_full_name(timing_cache_file)
+        try:
+            with open(timing_cache_file, "rb") as raw_cache:
+                cache_data = raw_cache.read()
+            return bytearray(cache_data)
+        except Exception:
+            return None
+
+    def update_timing_cache(
+        self, timing_cache_file: str, serialized_cache: bytearray
+    ) -> None:
+        if not self.save_timing_cache:
+            return
+        timing_cache_file = self.get_file_full_name(timing_cache_file)
+        with open(timing_cache_file, "wb") as local_cache:
+            local_cache.seek(0)
+            local_cache.write(serialized_cache)
+            local_cache.truncate()
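A sketch of the intended flow. The serialized-cache bytes would normally come out of a TensorRT engine build; the stand-in value and prefix below are illustrative:

    mgr = TimingCacheManager(timing_cache_prefix="resnet50", save_timing_cache=True)

    # Returns the previously saved cache as a bytearray, or None on a first run.
    cache = mgr.get_timing_cache_trt("fp16")  # reads resnet50_fp16.npy if present

    serialized_cache = bytearray(b"\x00")  # stand-in for engine-build output

    # Written back to resnet50_fp16.npy only because save_timing_cache=True.
    mgr.update_timing_cache("fp16", serialized_cache)
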
diff --git a/py/torch_tensorrt/dynamo/tools/trt_minimizer.py b/py/torch_tensorrt/dynamo/tools/trt_minimizer.py
new file mode 100644
index 0000000000..78b2f252bb
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/tools/trt_minimizer.py
@@ -0,0 +1,101 @@
+import logging
+from typing import Any, Callable, Tuple
+
+import torch
+import torch.fx.passes.net_min_base as net_min_base
+from torch.fx.passes.tools_common import Tensors
+
+from .. import InputTensorSpec, TRTInterpreter, TRTModule
+
+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+def lower_mod_default(
+    mod: torch.fx.GraphModule,
+    inputs: Tensors,
+    use_experimental_rt: bool = False,
+) -> TRTModule:
+    interp = TRTInterpreter(
+        mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True
+    )
+    interpreter_result = interp.run()
+    if use_experimental_rt:
+        import io
+
+        from torch_tensorrt._Device import Device
+        from torch_tensorrt._TRTModuleNext import TRTModuleNext
+
+        with io.BytesIO() as engine_bytes:
+            engine_bytes.write(interpreter_result.engine.serialize())
+            engine_str = engine_bytes.getvalue()
+
+        res_mod = TRTModuleNext(
+            engine_str,
+            name=str(type(mod)),
+            input_binding_names=interpreter_result.input_names,
+            output_binding_names=interpreter_result.output_names,
+            target_device=Device(f"cuda:{torch.cuda.current_device()}"),
+            # cuda_graph_batch_size=lower_setting.cuda_graph_batch_size,  # NOTE: Not sure what this is supposed to do
+        )
+    else:
+        res_mod = TRTModule(
+            interpreter_result.engine,
+            interpreter_result.input_names,
+            interpreter_result.output_names,
+        )
+    return res_mod
+
+
+class TensorRTMinimizerSetting(net_min_base._MinimizerSettingBase):
+    def __init__(
+        self, explicit_batch_dimension: Any = True, use_experimental_rt: bool = False
+    ):
+        if use_experimental_rt and not explicit_batch_dimension:
+            raise ValueError(
+                "The experimental unified runtime only supports explicit batch. Please make sure to set explicit_batch_dimension=True when use_experimental_rt=True"
+            )
+
+        self.explicit_batch_dimension = explicit_batch_dimension
+        self.use_experimental_rt = use_experimental_rt
+        super(TensorRTMinimizerSetting, self).__init__()
+
+
+class TensorRTMinimizer(net_min_base._MinimizerBase):
+    def __init__(
+        self,
+        module: torch.fx.GraphModule,
+        sample_input: Tensors,
+        compare_fn: Callable[[Any, Any, Any], Tuple[float, bool]],
+        settings: TensorRTMinimizerSetting = TensorRTMinimizerSetting(),
+        lower_fn: Callable[
+            [torch.fx.GraphModule, Tensors, bool], TRTModule
+        ] = lower_mod_default,
+    ):
+        self.lower_fn = lower_fn
+        self.use_experimental_rt = settings.use_experimental_rt
+        super().__init__(module, sample_input, compare_fn, settings)
+
+    def run_a(self, mod, inputs):
+        mod.eval()
+        with torch.no_grad():
+            return mod(*inputs)
+
+    def run_b(self, mod, inputs):
+        mod.eval()
+        try:
+            mod = self.lower_fn(mod, inputs, self.use_experimental_rt)
+            output = mod(*inputs)
+        except RuntimeError as e:
+            raise net_min_base.FxNetMinimizerRunFuncError(
+                f"Encountered an error while processing \n{mod.graph}\n {e}"
+            )
+        else:
+            return output
+
+    def get_nodes(self, start=None, end=None, enable_print=False):
+        nodes = self._collect_nodes(start, end)
+        if enable_print:
+            _LOGGER.info(f"Nodes fetched from start {start} to end {end} as: {nodes}")
+        return nodes
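A sketch of driving the minimizer. Per compare_fn's signature, the compare function receives a pair of outputs plus a name and returns a (score, pass) tuple; the model, tolerance, and input shapes here are illustrative:

    import torch

    class MyModel(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x) * 2

    def compare(a, b, name):
        # Numeric distance plus a pass/fail verdict at a 1e-2 tolerance.
        err = float((a - b).abs().max())
        return err, err < 1e-2

    mod = torch.fx.symbolic_trace(MyModel())  # acc-traced in real use
    minimizer = TensorRTMinimizer(mod, [torch.randn(2, 4).cuda()], compare)
    culprits = minimizer.minimize()  # nodes whose lowering breaks accuracy
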
diff --git a/py/torch_tensorrt/dynamo/tools/trt_profiler_sorted.py b/py/torch_tensorrt/dynamo/tools/trt_profiler_sorted.py
new file mode 100644
index 0000000000..48293773c4
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/tools/trt_profiler_sorted.py
@@ -0,0 +1,58 @@
+import json
+import logging
+import operator
+from typing import List, Mapping, Optional
+
+import torch
+
+from tensorrt import tensorrt as trt
+
+from .. import TRTModule
+
+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+class SortedTRTProfiler(trt.IProfiler):
+    def __init__(self):
+        super().__init__()
+        self.layers = {}
+
+    def report_layer_time(self, layer_name: str, ms: int) -> None:
+        self.layers[layer_name] = ms
+
+    def print_sorted_profile(
+        self, additional_info: Optional[Mapping[str, str]]
+    ) -> None:
+        additional_info = {} if additional_info is None else additional_info
+        for k, v in sorted(self.layers.items(), key=operator.itemgetter(1)):
+            additional_str = additional_info.get(k, "")
+            _LOGGER.info(f"{k} {additional_str}: {v}ms")
+
+
+def profile_trt_module(
+    name: str, trt_mod: TRTModule, mod_input: List[torch.Tensor]
+) -> None:
+    """
+    Provide per-layer timing and shape information.
+    """
+    layer_info = json.loads(trt_mod.get_layer_info())  # pyre-ignore[29]
+    shape_map = {}
+    for layer in layer_info["Layers"]:
+        # If the entry is a str, verbose_profile was off in interpreter.run().
+        # Theoretically we could still print profiling information without shape
+        # information, but we choose not to, so verbose_profile controls both.
+        if type(layer) is str:
+            return
+        layer_name = layer["Name"]
+        input_str = ", ".join(
+            [str(x.get("Dimensions", "[]")) for x in layer.get("Inputs", [])]
+        )
+        output_str = ", ".join(
+            [str(x.get("Dimensions", "[]")) for x in layer.get("Outputs", [])]
+        )
+        shape_map[layer_name] = f"({input_str}) -> ({output_str})"
+
+    trt_mod.enable_profiling(profiler=SortedTRTProfiler())  # pyre-ignore[29]
+    _ = trt_mod(*mod_input)
+    trt_mod.context.profiler.print_sorted_profile(shape_map)  # pyre-ignore[16]
+    trt_mod.disable_profiling()  # pyre-ignore[29]
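Typical use, assuming `trt_mod` is a TRTModule whose engine was built with verbose profiling enabled in interpreter.run() (so get_layer_info returns per-layer entries) and `mod_inputs` are its sample inputs:

    import torch

    # Logs each layer with its (input shapes) -> (output shapes), sorted by time.
    profile_trt_module("my_engine", trt_mod, [torch.randn(1, 3, 224, 224).cuda()])

    # Or attach the sorted profiler by hand:
    profiler = SortedTRTProfiler()
    trt_mod.enable_profiling(profiler=profiler)
    _ = trt_mod(*mod_inputs)
    profiler.print_sorted_profile(None)
    trt_mod.disable_profiling()
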
diff --git a/py/torch_tensorrt/dynamo/tools/trt_splitter.py b/py/torch_tensorrt/dynamo/tools/trt_splitter.py
new file mode 100644
index 0000000000..bea925453f
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/tools/trt_splitter.py
@@ -0,0 +1,138 @@
+from typing import Any, Dict, Iterable, Sequence
+
+import torch
+import torch.fx.passes.operator_support as ops
+import torch.fx.passes.splitter_base as splitter_base
+from torch.fx.passes.tools_common import get_acc_ops_name, Tensors
+
+from .. import (
+    CONVERTERS,
+    InputTensorSpec,
+    NO_EXPLICIT_BATCH_DIM_SUPPORT,
+    NO_IMPLICIT_BATCH_DIM_SUPPORT,
+    TRTInterpreter,
+    TRTModule,
+)
+from ..tools.trt_minimizer import TensorRTMinimizer
+
+
+def create_trt_operator_support(
+    use_implicit_batch_dim=True,
+    exclude_support_node_name: set = (),
+) -> ops.OperatorSupportBase:
+    """Creates an `OperatorSupportBase` instance used for TRT splitting purposes."""
+    # Create an `OperatorSupport` that declares a node supported if it
+    # finds a registered TRT converter.
+    support_dict: Dict[str, None] = {}
+    for k in CONVERTERS.keys():
+        if use_implicit_batch_dim:
+            if k not in NO_IMPLICIT_BATCH_DIM_SUPPORT.keys():
+                support_dict[get_acc_ops_name(k)] = None
+        elif k not in NO_EXPLICIT_BATCH_DIM_SUPPORT.keys():
+            support_dict[get_acc_ops_name(k)] = None
+    supported_if_converter_registered = ops.OperatorSupport(support_dict=support_dict)
+
+    return ops.chain(
+        ops.OpSupports.decline_if_node_in_names(exclude_support_node_name),
+        # 1. Node is not supported if it has args with int64 dtype:
+        ops.OpSupports.decline_if_input_dtype(torch.int64),
+        # 2. Node is supported if it has a TRT converter:
+        supported_if_converter_registered,
+    )
+
+
+class TRTSplitterSetting(splitter_base._SplitterSettingBase):
+    def __init__(self):
+        super().__init__()
+
+        # Determines what batch mode we'll use for lowering.
+        # During split, we'll split out the operators that
+        # don't support the batch dim.
+        self.use_implicit_batch_dim: bool = True
+        self.exclude_support_node_name: set = set()
+        self.use_experimental_rt: bool = False
+
+        if self.use_experimental_rt and self.use_implicit_batch_dim:
+            raise ValueError(
+                "The experimental unified runtime only supports explicit batch. Please make sure to set use_implicit_batch_dim=False when use_experimental_rt=True"
+            )
+
+
+class TRTSplitter(splitter_base._SplitterBase):
+    def __init__(
+        self,
+        module: torch.fx.GraphModule,
+        sample_input: Sequence[Any],
+        operator_support: ops.OperatorSupportBase = None,
+        settings: TRTSplitterSetting = None,
+    ):
+        if not settings:
+            settings = TRTSplitterSetting()
+        if not operator_support:
+            operator_support = create_trt_operator_support(
+                settings.use_implicit_batch_dim, settings.exclude_support_node_name
+            )
+        super().__init__(
+            module,
+            sample_input,
+            operator_support,
+            settings,
+            non_acc_submodule_name="_run_on_gpu_",
+        )
+
+    def _lower_model_to_backend(
+        self, mod: torch.fx.GraphModule, inputs: Iterable[torch.Tensor]
+    ):
+        """
+        Lower a GraphModule `mod` to TensorRT with `inputs`.
+        """
+        # The current lowering code is a placeholder, subject to future change
+        # based on the feeds model's actual status
+        interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
+        interpreter_result = interp.run(*inputs)
+        if self.settings.use_experimental_rt:
+            import io
+
+            from torch_tensorrt._Device import Device
+            from torch_tensorrt._TRTModuleNext import TRTModuleNext
+
+            with io.BytesIO() as engine_bytes:
+                engine_bytes.write(interpreter_result.engine.serialize())
+                engine_str = engine_bytes.getvalue()
+
+            return TRTModuleNext(
+                engine_str,
+                name=str(type(mod)),
+                input_binding_names=interpreter_result.input_names,
+                output_binding_names=interpreter_result.output_names,
+                target_device=Device(f"cuda:{torch.cuda.current_device()}"),
+                # cuda_graph_batch_size=lower_setting.cuda_graph_batch_size,  # NOTE: Not sure what this is supposed to do
+            )
+        else:
+            return TRTModule(
+                interpreter_result.engine,
+                interpreter_result.input_names,
+                interpreter_result.output_names,
+            )
+
+    def _find_culprit(self, mod: torch.fx.GraphModule, inputs: Tensors):
+        """
+        This function serves the preview functionality in Splitter. When previewing
+        the splitting result, if something goes wrong while lowering the model to
+        TensorRT or running the TensorRT model, this function is called to find
+        any culprit that is responsible for the error.
+        """
+        # Since we don't care about accuracy here, we pass in a dummy compare function. 
+ minimizer = TensorRTMinimizer(mod, inputs, lambda a, b, c: (1, True)) + minimizer.settings.traverse_method = "sequential" + minimizer.settings.find_all = True + culprits = minimizer.minimize() + + if len(culprits) == 0: + reports = "Unable to find a culprit!\n" + else: + reports = "Found some problematic nodes:\n" + for node in culprits: + reports += f"{node.format_node()}\n" + + return reports diff --git a/py/torch_tensorrt/dynamo/trt_module.py b/py/torch_tensorrt/dynamo/trt_module.py new file mode 100644 index 0000000000..099bbfcdc9 --- /dev/null +++ b/py/torch_tensorrt/dynamo/trt_module.py @@ -0,0 +1,239 @@ +from typing import Any, List, Sequence + +# @manual=//deeplearning/trt/python:py_tensorrt +import tensorrt as trt +import torch + +from .utils import torch_dtype_from_trt + + +class TRTModule(torch.nn.Module): + def __init__( + self, engine=None, input_names=None, output_names=None, cuda_graph_batch_size=-1 + ): + super(TRTModule, self).__init__() + self._register_state_dict_hook(TRTModule._on_state_dict) + self.engine = engine + self.input_names = input_names + self.output_names = output_names + self.cuda_graph_batch_size = cuda_graph_batch_size + self.initialized = False + + if engine: + self._initialize() + + def _initialize(self): + self.initialized = True + self.context = self.engine.create_execution_context() + + # Indices of inputs/outputs in the trt engine bindings, in the order + # as they are in the original PyTorch model. + self.input_binding_indices_in_order: Sequence[int] = [ + self.engine.get_binding_index(name) for name in self.input_names + ] + self.output_binding_indices_in_order: Sequence[int] = [ + self.engine.get_binding_index(name) for name in self.output_names + ] + primary_input_outputs = set() + primary_input_outputs.update(self.input_binding_indices_in_order) + primary_input_outputs.update(self.output_binding_indices_in_order) + self.hidden_output_binding_indices_in_order: Sequence[int] = [] + self.hidden_output_names: Sequence[str] = [] + for i in range( + self.engine.num_bindings // self.engine.num_optimization_profiles + ): + if i not in primary_input_outputs: + self.hidden_output_binding_indices_in_order.append(i) + self.hidden_output_names.append(self.engine.get_binding_name(i)) + + assert (self.engine.num_bindings // self.engine.num_optimization_profiles) == ( + len(self.input_names) + + len(self.output_names) + + len(self.hidden_output_names) + ) + + self.input_dtypes: Sequence[torch.dtype] = [ + torch_dtype_from_trt(self.engine.get_binding_dtype(idx)) + for idx in self.input_binding_indices_in_order + ] + self.input_shapes: Sequence[Sequence[int]] = [ + tuple(self.engine.get_binding_shape(idx)) + for idx in self.input_binding_indices_in_order + ] + self.output_dtypes: Sequence[torch.dtype] = [ + torch_dtype_from_trt(self.engine.get_binding_dtype(idx)) + for idx in self.output_binding_indices_in_order + ] + self.output_shapes = [ + tuple(self.engine.get_binding_shape(idx)) + if self.engine.has_implicit_batch_dimension + else tuple() + for idx in self.output_binding_indices_in_order + ] + self.hidden_output_dtypes: Sequence[torch.dtype] = [ + torch_dtype_from_trt(self.engine.get_binding_dtype(idx)) + for idx in self.hidden_output_binding_indices_in_order + ] + self.hidden_output_shapes = [ + tuple(self.engine.get_binding_shape(idx)) + if self.engine.has_implicit_batch_dimension + else tuple() + for idx in self.hidden_output_binding_indices_in_order + ] + + def _check_initialized(self): + if not self.initialized: + raise RuntimeError("TRTModule is not 
initialized.") + + def _on_state_dict(self, state_dict, prefix, local_metadata): + self._check_initialized() + state_dict[prefix + "engine"] = bytearray(self.engine.serialize()) + state_dict[prefix + "input_names"] = self.input_names + state_dict[prefix + "output_names"] = self.output_names + state_dict[prefix + "cuda_graph_batch_size"] = self.cuda_graph_batch_size + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + engine_bytes = state_dict[prefix + "engine"] + + logger = trt.Logger() + runtime = trt.Runtime(logger) + self.engine = runtime.deserialize_cuda_engine(engine_bytes) + + self.input_names = state_dict[prefix + "input_names"] + self.output_names = state_dict[prefix + "output_names"] + self._initialize() + + def __getstate__(self): + state = self.__dict__.copy() + state["engine"] = bytearray(self.engine.serialize()) + state.pop("context", None) + return state + + def __setstate__(self, state): + logger = trt.Logger() + runtime = trt.Runtime(logger) + state["engine"] = runtime.deserialize_cuda_engine(state["engine"]) + self.__dict__.update(state) + if self.engine: + self.context = self.engine.create_execution_context() + + def forward(self, *inputs): + with torch.autograd.profiler.record_function("TRTModule:Forward"): + self._check_initialized() + + with torch.autograd.profiler.record_function("TRTModule:ProcessInputs"): + assert len(inputs) == len( + self.input_names + ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}." + + # This is only used when the trt engine is using implicit batch dim. + batch_size = inputs[0].shape[0] + contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs] + bindings: List[Any] = [None] * ( + len(self.input_names) + + len(self.output_names) + + len(self.hidden_output_names) + ) + + for i, input_name in enumerate(self.input_names): + assert inputs[ + i + ].is_cuda, f"{i}th input({input_name}) is not on cuda device." + assert ( + inputs[i].dtype == self.input_dtypes[i] + ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {inputs[i].dtype}." + + idx = self.input_binding_indices_in_order[i] + bindings[idx] = contiguous_inputs[i].data_ptr() + + if not self.engine.has_implicit_batch_dimension: + self.context.set_binding_shape( + idx, tuple(contiguous_inputs[i].shape) + ) + else: + assert inputs[i].size()[1:] == self.input_shapes[i], ( + f"Shape mismatch for {i}th input({input_name}). " + f"Expect {self.input_shapes[i]}, got {inputs[i].size()[1:]}." 
+ ) + + with torch.autograd.profiler.record_function("TRTModule:ProcessOutputs"): + # create output tensors + outputs: List[torch.Tensor] = [] + + for i, idx in enumerate(self.output_binding_indices_in_order): + if self.engine.has_implicit_batch_dimension: + shape = (batch_size,) + self.output_shapes[i] + else: + shape = tuple(self.context.get_binding_shape(idx)) + + output = torch.empty( # type: ignore[call-overload] + size=shape, + dtype=self.output_dtypes[i], + device=torch.cuda.current_device(), + ) + outputs.append(output) + bindings[idx] = output.data_ptr() + + for i, idx in enumerate(self.hidden_output_binding_indices_in_order): + if self.engine.has_implicit_batch_dimension: + shape = (batch_size,) + self.hidden_output_shapes[i] + else: + shape = tuple(self.context.get_binding_shape(idx)) + + output = torch.empty( # type: ignore[call-overload] + size=shape, + dtype=self.hidden_output_dtypes[i], + device=torch.cuda.current_device(), + ) + bindings[idx] = output.data_ptr() + + with torch.autograd.profiler.record_function("TRTModule:TensorRTRuntime"): + if self.engine.has_implicit_batch_dimension: + self.context.execute_async( + batch_size, bindings, torch.cuda.current_stream().cuda_stream + ) + else: + self.context.execute_async_v2( + bindings, torch.cuda.current_stream().cuda_stream + ) + + if len(outputs) == 1: + return outputs[0] + + return tuple(outputs) + + def enable_profiling(self, profiler: "trt.IProfiler" = None): + """ + Enable TensorRT profiling. After calling this function, TensorRT will report + time spent on each layer in stdout for each forward run. + """ + self._check_initialized() + + if not self.context.profiler: + self.context.profiler = trt.Profiler() if profiler is None else profiler + + def disable_profiling(self): + """ + Disable TensorRT profiling. + """ + self._check_initialized() + + torch.cuda.synchronize() + del self.context + self.context = self.engine.create_execution_context() + + def get_layer_info(self) -> str: + """ + Get layer info of the engine. Only support for TRT > 8.2. 
+ """ + inspector = self.engine.create_engine_inspector() + return inspector.get_engine_information(trt.LayerInformationFormat.JSON) diff --git a/py/torch_tensorrt/dynamo/types.py b/py/torch_tensorrt/dynamo/types.py new file mode 100644 index 0000000000..f233f8dd9c --- /dev/null +++ b/py/torch_tensorrt/dynamo/types.py @@ -0,0 +1,24 @@ +from typing import Sequence, Tuple + +# @manual=//deeplearning/trt/python:py_tensorrt +import tensorrt as trt + +if hasattr(trt, "__version__"): + TRTNetwork = trt.INetworkDefinition + TRTTensor = trt.tensorrt.ITensor + TRTLayer = trt.ILayer + TRTPluginFieldCollection = trt.PluginFieldCollection + TRTPlugin = trt.IPluginV2 + TRTDataType = trt.DataType + TRTElementWiseOp = trt.ElementWiseOperation +else: + TRTNetwork = "trt.INetworkDefinition" + TRTTensor = "trt.tensorrt.ITensor" + TRTLayer = "trt.ILayer" + TRTPluginFieldCollection = "trt.PluginFieldCollection" + TRTPlugin = "trt.IPluginV2" + TRTDataType = "trt.DataType" + TRTElementWiseOp = "trt.ElementWiseOperation" + +Shape = Sequence[int] +ShapeRange = Tuple[Shape, Shape, Shape] diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py new file mode 100644 index 0000000000..52daf6f09e --- /dev/null +++ b/py/torch_tensorrt/dynamo/utils.py @@ -0,0 +1,140 @@ +from enum import Enum +from typing import List, Callable +from packaging import version + +# @manual=//deeplearning/trt/python:py_tensorrt +import tensorrt as trt +import torch +from functorch import make_fx +from functorch.experimental import functionalize +from torch_tensorrt.dynamo.passes.lower_basic_pass import ( + replace_op_with_indices, + run_const_fold, +) + +from .types import Shape, TRTDataType + + +class LowerPrecision(Enum): + FP32 = "fp32" + FP16 = "fp16" + INT8 = "int8" + + +def torch_dtype_to_trt(dtype: torch.dtype) -> TRTDataType: + """ + Convert PyTorch data types to TensorRT data types. + + Args: + dtype (torch.dtype): A PyTorch data type. + + Returns: + The equivalent TensorRT data type. + """ + if trt.__version__ >= "7.0" and dtype == torch.bool: + return trt.bool + elif dtype == torch.int8: + return trt.int8 + elif dtype == torch.int32: + return trt.int32 + elif dtype == torch.float16: + return trt.float16 + elif dtype == torch.float32: + return trt.float32 + else: + raise TypeError("%s is not supported by tensorrt" % dtype) + + +def torch_dtype_from_trt(dtype: TRTDataType) -> torch.dtype: + """ + Convert TensorRT data types to PyTorch data types. + + Args: + dtype (TRTDataType): A TensorRT data type. + + Returns: + The equivalent PyTorch data type. + """ + if dtype == trt.int8: + return torch.int8 + elif trt.__version__ >= "7.0" and dtype == trt.bool: + return torch.bool + elif dtype == trt.int32: + return torch.int32 + elif dtype == trt.float16: + return torch.float16 + elif dtype == trt.float32: + return torch.float32 + else: + raise TypeError("%s is not supported by torch" % dtype) + + +def get_dynamic_dims(shape: Shape) -> List[int]: + """ + This function finds the dynamic dimensions in the given + shape. A dimension is dynamic if it's -1. + + Args: + shape (Shape): A sequence of integer that represents + the shape of a tensor. + + Returns: + A list of integers contains all the dynamic dimensions + in the given shape + """ + dynamic_dims = [] + + for i, s in enumerate(shape): + if s == -1: + dynamic_dims.append(i) + + return dynamic_dims + + +def proxytensor_trace(mod, inputs): + + mod.eval() + + def f(*inp): + return mod(*inp) + + mod = make_fx(functionalize(f))(*inputs) + + # Remove const operation. 
For ex, nn.Linear has transpose operation on weight + mod.graph.eliminate_dead_code() + mod = run_const_fold(mod) + mod = replace_op_with_indices(mod) + return mod + + +def req_torch_version(min_torch_version: str = "2.dev"): + """ + Create a decorator which verifies the Torch version installed + against a specified version range + + Args: + min_torch_version (str): The minimum required Torch version + for the decorated function to work properly + + Returns: + A decorator which raises a descriptive error message if + an unsupported Torch version is used + """ + + def nested_decorator(f: Callable): + def function_wrapper(*args, **kwargs): + # Parse minimum and current Torch versions + min_version = version.parse(min_torch_version) + current_version = version.parse(torch.__version__) + + if current_version < min_version: + raise AssertionError( + f"Expected Torch version {min_torch_version} or greater, " + + f"when calling {f}. Detected version {torch.__version__}" + ) + else: + return f(*args, **kwargs) + + return function_wrapper + + return nested_decorator From 9ce574610211e51913d08fbd3fad5ca96dea2c1a Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 4 Apr 2023 17:35:53 -0700 Subject: [PATCH 06/45] chore: Initial refactoring of FX tests in dynamo namespace Signed-off-by: Dheeraj Peri --- .../acc_op/test_adaptive_avgpool.py | 99 + .../dynamo/test/converters/acc_op/test_any.py | 91 + .../test/converters/acc_op/test_as_strided.py | 59 + .../test/converters/acc_op/test_avgpool.py | 265 ++ .../test/converters/acc_op/test_batchnorm.py | 66 + .../test/converters/acc_op/test_binary_ops.py | 212 ++ .../dynamo/test/converters/acc_op/test_cat.py | 101 + .../test/converters/acc_op/test_chunk.py | 79 + .../test/converters/acc_op/test_clamp.py | 71 + .../converters/acc_op/test_convolution.py | 199 ++ .../test/converters/acc_op/test_dequantize.py | 68 + .../test/converters/acc_op/test_einsum.py | 66 + .../dynamo/test/converters/acc_op/test_elu.py | 52 + .../test/converters/acc_op/test_embedding.py | 107 + .../dynamo/test/converters/acc_op/test_eq.py | 289 ++ .../test/converters/acc_op/test_expand.py | 34 + .../test/converters/acc_op/test_flatten.py | 70 + .../test/converters/acc_op/test_gelu.py | 95 + .../test/converters/acc_op/test_getitem.py | 198 ++ .../dynamo/test/converters/acc_op/test_gt.py | 276 ++ .../converters/acc_op/test_hard_sigmoid.py | 59 + .../test/converters/acc_op/test_hardtanh.py | 57 + .../converters/acc_op/test_interpolate.py | 150 + .../test/converters/acc_op/test_isinf.py | 63 + .../test/converters/acc_op/test_leaky_relu.py | 52 + .../test/converters/acc_op/test_linear.py | 60 + .../converters/acc_op/test_logical_and.py | 230 ++ .../test/converters/acc_op/test_logical_or.py | 201 ++ .../converters/acc_op/test_logical_xor.py | 201 ++ .../dynamo/test/converters/acc_op/test_lt.py | 274 ++ .../converters/acc_op/test_masked_fill.py | 72 + .../test/converters/acc_op/test_matmul.py | 117 + .../dynamo/test/converters/acc_op/test_max.py | 160 + .../test/converters/acc_op/test_maximum.py | 82 + .../test/converters/acc_op/test_maxpool.py | 379 +++ .../dynamo/test/converters/acc_op/test_min.py | 159 + .../test/converters/acc_op/test_minimum.py | 82 + .../test/converters/acc_op/test_narrow.py | 55 + .../dynamo/test/converters/acc_op/test_ne.py | 304 ++ .../test/converters/acc_op/test_new_ones.py | 73 + .../test/converters/acc_op/test_numel.py | 41 + .../dynamo/test/converters/acc_op/test_pad.py | 102 + .../test/converters/acc_op/test_permute.py | 87 + 
.../test/converters/acc_op/test_prod.py | 118 + .../acc_op/test_quantize_per_tensor.py | 65 + .../test/converters/acc_op/test_reduce_ops.py | 108 + .../test/converters/acc_op/test_relu.py | 52 + .../acc_op/test_repeat_interleave.py | 76 + .../test/converters/acc_op/test_reshape.py | 138 + .../test/converters/acc_op/test_selu.py | 52 + .../test/converters/acc_op/test_sigmoid.py | 35 + .../test/converters/acc_op/test_silu.py | 52 + .../test/converters/acc_op/test_size.py | 71 + .../test/converters/acc_op/test_softmax.py | 81 + .../test/converters/acc_op/test_softsign.py | 52 + .../test/converters/acc_op/test_split.py | 107 + .../test/converters/acc_op/test_squeeze.py | 41 + .../dynamo/test/converters/acc_op/test_std.py | 117 + .../test/converters/acc_op/test_tanh.py | 52 + .../test/converters/acc_op/test_tile.py | 145 + .../test/converters/acc_op/test_to_dtype.py | 319 ++ .../test/converters/acc_op/test_topk.py | 84 + .../acc_op/test_transpose_convolution.py | 137 + .../test/converters/acc_op/test_type_as.py | 150 + .../test/converters/acc_op/test_unary_ops.py | 165 + .../test/converters/acc_op/test_unsqueeze.py | 60 + .../test/converters/acc_op/test_where.py | 114 + .../aten_op/test_adaptive_avgpool_aten.py | 127 + .../converters/aten_op/test_batchnorm_aten.py | 65 + .../aten_op/test_binary_ops_aten.py | 205 ++ .../test/converters/aten_op/test_cat_aten.py | 58 + .../aten_op/test_convolution_aten.py | 203 ++ .../converters/aten_op/test_expand_aten.py | 31 + .../converters/aten_op/test_flatten_aten.py | 70 + .../converters/aten_op/test_linear_aten.py | 71 + .../converters/aten_op/test_maxpool_aten.py | 245 ++ .../test/converters/aten_op/test_relu_aten.py | 51 + .../converters/aten_op/test_reshape_aten.py | 102 + .../converters/vanilla/test_add_vanilla.py | 28 + .../vanilla/test_convolution_vanilla.py | 113 + .../dynamo/test/core/test_import_fx2trt.py | 18 + .../dynamo/test/core/test_input.py | 88 + .../test/core/test_input_tensor_spec.py | 93 + .../dynamo/test/core/test_trt_module.py | 147 + ...test_fix_clamp_numerical_limits_to_fp16.py | 72 + .../test/passes/test_fix_reshape_batch_dim.py | 51 + .../passes/test_fuse_permute_linear_trt.py | 88 + .../passes/test_fuse_permute_matmul_trt.py | 142 + .../dynamo/test/passes/test_graph_opts.py | 187 ++ .../dynamo/test/passes/test_multi_fuse_trt.py | 66 + .../test_remove_duplicate_output_args.py | 73 + .../dynamo/test/passes/test_setitem_trt.py | 600 ++++ .../dynamo/test/quant/test_quant_trt.py | 907 ++++++ .../dynamo/test/tools/test_model_packager.py | 56 + .../dynamo/test/tracer/test_acc_shape_prop.py | 98 + .../dynamo/test/tracer/test_acc_tracer.py | 2801 +++++++++++++++++ .../test/tracer/test_dispatch_tracer.py | 245 ++ .../dynamo/test/tracer/test_resnet.py | 86 + .../dynamo/test/trt_lower/test_diagnostics.py | 200 ++ .../test/trt_lower/test_fx2trt_lower.py | 104 + .../dynamo/test/trt_lower/test_observer.py | 128 + .../test/trt_lower/test_observer_gpu.py | 53 + .../trt_lower/trt_operator_supported_test.py | 80 + .../test/trt_lower/trt_splitter_test.py | 1176 +++++++ 104 files changed, 16876 insertions(+) create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_adaptive_avgpool.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_any.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_as_strided.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_avgpool.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_batchnorm.py create mode 100644 
py/torch_tensorrt/dynamo/test/converters/acc_op/test_binary_ops.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_cat.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_chunk.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_clamp.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_convolution.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_dequantize.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_einsum.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_elu.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_embedding.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_eq.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_expand.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_flatten.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_gelu.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_getitem.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_gt.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_hard_sigmoid.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_hardtanh.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_interpolate.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_isinf.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_leaky_relu.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_linear.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_logical_and.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_logical_or.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_logical_xor.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_lt.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_masked_fill.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_matmul.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_max.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_maximum.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_maxpool.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_min.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_minimum.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_narrow.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_ne.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_new_ones.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_numel.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_pad.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_permute.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_prod.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_quantize_per_tensor.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_reduce_ops.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_relu.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_repeat_interleave.py 
create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_reshape.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_selu.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_sigmoid.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_silu.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_size.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_softmax.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_softsign.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_split.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_squeeze.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_std.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_tanh.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_tile.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_to_dtype.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_topk.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_transpose_convolution.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_type_as.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_unary_ops.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_unsqueeze.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/acc_op/test_where.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/aten_op/test_adaptive_avgpool_aten.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/aten_op/test_batchnorm_aten.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/aten_op/test_binary_ops_aten.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/aten_op/test_cat_aten.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/aten_op/test_convolution_aten.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/aten_op/test_expand_aten.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/aten_op/test_flatten_aten.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/aten_op/test_linear_aten.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/aten_op/test_maxpool_aten.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/aten_op/test_relu_aten.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/aten_op/test_reshape_aten.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/vanilla/test_add_vanilla.py create mode 100644 py/torch_tensorrt/dynamo/test/converters/vanilla/test_convolution_vanilla.py create mode 100644 py/torch_tensorrt/dynamo/test/core/test_import_fx2trt.py create mode 100644 py/torch_tensorrt/dynamo/test/core/test_input.py create mode 100644 py/torch_tensorrt/dynamo/test/core/test_input_tensor_spec.py create mode 100644 py/torch_tensorrt/dynamo/test/core/test_trt_module.py create mode 100644 py/torch_tensorrt/dynamo/test/passes/test_fix_clamp_numerical_limits_to_fp16.py create mode 100644 py/torch_tensorrt/dynamo/test/passes/test_fix_reshape_batch_dim.py create mode 100644 py/torch_tensorrt/dynamo/test/passes/test_fuse_permute_linear_trt.py create mode 100644 py/torch_tensorrt/dynamo/test/passes/test_fuse_permute_matmul_trt.py create mode 100644 py/torch_tensorrt/dynamo/test/passes/test_graph_opts.py create mode 100644 py/torch_tensorrt/dynamo/test/passes/test_multi_fuse_trt.py 
create mode 100644 py/torch_tensorrt/dynamo/test/passes/test_remove_duplicate_output_args.py create mode 100644 py/torch_tensorrt/dynamo/test/passes/test_setitem_trt.py create mode 100644 py/torch_tensorrt/dynamo/test/quant/test_quant_trt.py create mode 100644 py/torch_tensorrt/dynamo/test/tools/test_model_packager.py create mode 100644 py/torch_tensorrt/dynamo/test/tracer/test_acc_shape_prop.py create mode 100644 py/torch_tensorrt/dynamo/test/tracer/test_acc_tracer.py create mode 100644 py/torch_tensorrt/dynamo/test/tracer/test_dispatch_tracer.py create mode 100644 py/torch_tensorrt/dynamo/test/tracer/test_resnet.py create mode 100644 py/torch_tensorrt/dynamo/test/trt_lower/test_diagnostics.py create mode 100644 py/torch_tensorrt/dynamo/test/trt_lower/test_fx2trt_lower.py create mode 100644 py/torch_tensorrt/dynamo/test/trt_lower/test_observer.py create mode 100644 py/torch_tensorrt/dynamo/test/trt_lower/test_observer_gpu.py create mode 100644 py/torch_tensorrt/dynamo/test/trt_lower/trt_operator_supported_test.py create mode 100644 py/torch_tensorrt/dynamo/test/trt_lower/trt_splitter_test.py diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_adaptive_avgpool.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_adaptive_avgpool.py new file mode 100644 index 0000000000..0b194e4c77 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_adaptive_avgpool.py @@ -0,0 +1,99 @@ +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestAdaptiveAvgPoolConverter(AccTestCase): + @parameterized.expand( + [ + ((64, 64),), + ((128, 64),), + (64,), + ] + ) + def test_adaptive_avgpool( + self, + output_size, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool2d(output_size) + + def forward(self, x): + return self.pool(x) + + inputs = [torch.randn(1, 3, 256, 256)] + self.run_test(TestModule(), inputs, expected_ops={acc_ops.adaptive_avg_pool2d}) + + def test_adaptive_avgpool_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool2d((64, 64)) + + def forward(self, x): + return self.pool(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, 256, 256), + dtype=torch.float32, + shape_ranges=[((1, 1, 256, 256), (3, 3, 256, 256), (5, 5, 256, 256))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.adaptive_avg_pool2d} + ) + + @parameterized.expand( + [ + ((16, 16, 16),), + ((32, 16, 4),), + (32,), + ] + ) + def test_adaptive_avgpool3d( + self, + output_size, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool3d(output_size) + + def forward(self, x): + return self.pool(x) + + inputs = [torch.randn(1, 3, 32, 64, 64)] + self.run_test(TestModule(), inputs, expected_ops={acc_ops.adaptive_avg_pool3d}) + + def test_adaptive_avgpool3d_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool3d((16, 16, 16)) + + def forward(self, x): + return self.pool(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, 32, 64, 64), + dtype=torch.float32, + shape_ranges=[ + ((1, 1, 32, 64, 64), (3, 3, 32, 64, 64), 
(5, 5, 32, 64, 64)) + ], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.adaptive_avg_pool3d} + ) + + # Testing with shape (-1, -1, -1, -1) results in the error: "AdaptiveAvgPool2d and AdaptiveAvgPool3d currently doesn't support dynamic shapes for last two dims." + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_any.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_any.py new file mode 100644 index 0000000000..1e46e3cff1 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_any.py @@ -0,0 +1,91 @@ +import torch +import torch.nn as nn +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase + +# from torch_tensorrt.dynamo.tools.common_fx2trt import InputTensorSpec + + +class TestAnyConverters(AccTestCase): + @parameterized.expand( + [ + ("bool", torch.bool), + ("int", torch.int), + ("float", torch.float), + ] + ) + def test_ops(self, _, input_dtype): + class TestModule(nn.Module): + def forward(self, x): + return torch.any(x) + + inputs = [torch.randn(2, 3).to(input_dtype)] + self.run_test( + TestModule(), + inputs, + expected_ops={acc_ops.any}, + test_implicit_batch_dim=False, + ) + + @parameterized.expand( + [ + ("bool", torch.bool, 0), + ("int", torch.int, 1), + ("float", torch.float, 0), + ] + ) + def test_ops_dim(self, _, input_dtype, dim): + class TestModule(nn.Module): + def forward(self, x): + return torch.any(x, dim, keepdim=True) + + inputs = [torch.randn(2, 3).to(input_dtype)] + self.run_test( + TestModule(), inputs, expected_ops={}, test_implicit_batch_dim=False + ) + + @parameterized.expand( + [ + ("bool", torch.bool), + ("int", torch.int), + ("float", torch.float), + ] + ) + def test_ops_method(self, _, input_dtype): + class TestModule(nn.Module): + def forward(self, x): + return x.any() + + inputs = [torch.randn(2, 3).to(input_dtype)] + self.run_test( + TestModule(), + inputs, + expected_ops={acc_ops.any}, + test_implicit_batch_dim=False, + ) + + # Testing with shape (-1, -1, -1, -1) results in the error: torch.zeros(tuple([*input_t.shape])).
Trying to create tensor with negative dimension -1: [-1, -1, -1, -1] + """ + def test_ops_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return torch.any(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 256, 256), (3, 3, 256, 256), (5, 5, 256, 256))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.any} + ) + """ + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_as_strided.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_as_strided.py new file mode 100644 index 0000000000..72eecb5810 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_as_strided.py @@ -0,0 +1,59 @@ +import torch +import torch.nn as nn +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase + +# from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestConverter(AccTestCase): + @parameterized.expand( + [ + ("2d_dim_v1", (5, 5), (2, 3), (1, 2), 0), + ("2d_dim_v2", (5, 5), (2, 3), (2, 2), 1), + ("3d_dim_v1", (20, 20), (2, 3, 2), (2, 2, 2), 0), + # takes a long time on large dimensions; we do not have a better implementation yet + # ("4d_dim_v1", (200, 200, 200, 200), (9, 9, 3, 2), (2, 2, 2, 3), 0), + # ("4d_dim_v2", (200, 200, 200, 200), (1, 15, 512, 1), (4096, 256, 1, 1), 0), + ] + ) + def test_as_strided(self, _, x_size, size, stride, offset): + class Stride(nn.Module): + def forward(self, x): + return torch.as_strided(x, size, stride, offset) + + inputs = [torch.randn(*x_size)] + self.run_test( + Stride(), + inputs, + expected_ops={acc_ops.as_strided}, + test_implicit_batch_dim=False, + ) + + # Testing with shape (-1, 3) results in the error: + # RuntimeError: setStorage: sizes [2, 3], strides [1, 2], storage offset 0, and itemsize 8 requiring a storage size of 48 are out of bounds for storage of size 16 + + """ + def test_as_strided_with_dynamic_shape_four_dimensions(self): + class Stride(nn.Module): + def forward(self, x): + return torch.as_strided(torch.tensor([5, 5]), (2, 3), (1, 2), 0) + + input_specs = [ + InputTensorSpec( + shape=(-1, 3), + dtype=torch.float32, + shape_ranges=[((1, 3), (2, 3), (2, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + Stride(), input_specs, expected_ops={acc_ops.as_strided} + ) + """ + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_avgpool.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_avgpool.py new file mode 100644 index 0000000000..88f55c58a9 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_avgpool.py @@ -0,0 +1,265 @@ +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import param, parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestAvgPoolConverter(AccTestCase): + @parameterized.expand( + [ + ("default", 1), + ("kernel_size", 3), + ("stride", 1, 2), + ("tuple_parameters", 2, (1,), (1,)), + param("padding", 2, padding=1), + param("ceil_mode", 1, ceil_mode=True), + param("include_pad", 2, padding=1, count_include_pad=False), + ] + ) + def test_avg_pool1d(
self, + test_name, + kernel_size, + stride=1, + padding=0, + ceil_mode=False, + count_include_pad=True, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.avg_pool = torch.nn.AvgPool1d( + kernel_size, stride, padding, ceil_mode, count_include_pad + ) + + def forward(self, x): + return self.avg_pool(x) + + inputs = [torch.randn(1, 3, 224)] + self.run_test(TestModule(), inputs, expected_ops={acc_ops.avg_pool1d}) + + @parameterized.expand( + [ + ("default", 1), + ("kernel_size", 3), + ("stride", 1, 2), + ("tuple_parameters", 2, (1,), (1,)), + param("padding", 2, padding=1), + param("ceil_mode", 1, ceil_mode=True), + param("include_pad", 2, padding=1, count_include_pad=False), + ] + ) + def test_avg_pool1d_with_dynamic_shape( + self, + test_name="default", + kernel_size=1, + stride=1, + padding=0, + ceil_mode=False, + count_include_pad=True, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.avg_pool = torch.nn.AvgPool1d( + kernel_size, stride, padding, ceil_mode, count_include_pad + ) + + def forward(self, x): + return self.avg_pool(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, 3, 3), + dtype=torch.float32, + shape_ranges=[((1, 3, 3), (3, 3, 3), (3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.avg_pool1d} + ) + + def test_avg_pool2d_with_dynamic_shape_four_dimensions( + self, + test_name="default", + kernel_size=1, + stride=1, + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.avg_pool = torch.nn.AvgPool2d( + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + + def forward(self, x): + return self.avg_pool(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.avg_pool2d} + ) + + @parameterized.expand( + [ + ("default", 1), + ("kernel_size", 3), + ("stride", 1, 2), + ("tuple_parameters", 2, (1, 1), (1, 1)), + param("padding", 2, padding=1), + param("ceil_mode", 1, ceil_mode=True), + param("include_pad", 2, padding=1, count_include_pad=False), + ] + ) + def test_avg_pool2d( + self, + test_name, + kernel_size, + stride=1, + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.avg_pool = torch.nn.AvgPool2d( + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + + def forward(self, x): + return self.avg_pool(x) + + inputs = [torch.randn(1, 3, 224, 224)] + self.run_test(TestModule(), inputs, expected_ops={acc_ops.avg_pool2d}) + + @parameterized.expand( + [ + ("kernel_size", 1), + param("stride", 2, stride=()), + ] + ) + def test_stride_none_avg_pool1d( + self, + test_name, + kernel_size, + stride=None, + padding=0, + ceil_mode=False, + count_include_pad=True, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.nn.functional.avg_pool1d( + x, + kernel_size, + stride=stride, + padding=padding, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + ) + + inputs = [torch.randn(1, 3, 224)] + self.run_test(TestModule(), inputs,
expected_ops={acc_ops.avg_pool1d}) + + @parameterized.expand( + [ + ("kernel_size", 2), + param("stride", 2, stride=()), + ] + ) + def test_stride_none_avg_pool2d( + self, + test_name, + kernel_size, + stride=None, + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.nn.functional.avg_pool2d( + x, + kernel_size, + stride=stride, + padding=padding, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + divisor_override=divisor_override, + ) + + inputs = [torch.randn(1, 3, 224, 224)] + self.run_test(TestModule(), inputs, expected_ops={acc_ops.avg_pool2d}) + + def test_stride_none_avg_pool2d_with_dynamic_shape_four_dimensions( + self, + test_name="default", + kernel_size=1, + stride=None, + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.nn.functional.avg_pool2d( + x, + kernel_size, + stride=stride, + padding=padding, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + divisor_override=divisor_override, + ) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.avg_pool2d} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_batchnorm.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_batchnorm.py new file mode 100644 index 0000000000..965bbcd729 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_batchnorm.py @@ -0,0 +1,66 @@ +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestBatchNormConverter(AccTestCase): + def test_batchnorm(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.bn = torch.nn.BatchNorm2d(3) + + def forward(self, x): + return self.bn(x) + + inputs = [torch.randn(1, 3, 224, 224)] + self.run_test(TestModule(), inputs, expected_ops={acc_ops.batch_norm}) + + def test_batchnorm1d_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.bn = torch.nn.BatchNorm1d(3) + + def forward(self, x): + return self.bn(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, 3, 5), + dtype=torch.float32, + shape_ranges=[((2, 3, 5), (6, 3, 5), (10, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.batch_norm} + ) + + def test_batchnorm_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.bn = torch.nn.BatchNorm2d(3) + + def forward(self, x): + return self.bn(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, 3, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 3, 1, 1), (1, 3, 5, 5), (2, 3, 10, 10))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.batch_norm} + ) + + # Testing with shape=(-1, -1, -1, -1) results in AssertionError: Channel dim can't be dynamic for batch norm.
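+ # Illustration only (hypothetical spec, not executed), mirroring the + # disabled examples elsewhere in this suite: a fully dynamic spec such as + # the one below would trip that assertion, since the converter presumably + # needs a static channel count (dim 1) to build its per-channel scale + # weights at engine-build time. + """ + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (1, 3, 5, 5), (2, 3, 10, 10))], + ), + ] + """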
+ + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_binary_ops.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_binary_ops.py new file mode 100644 index 0000000000..f2a9fb1620 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_binary_ops.py @@ -0,0 +1,212 @@ +from typing import Callable + +import torch +import torch.nn as nn + +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + +NEED_TEST_BOTH_CONSTANTS_CASE = True + +elementwise_ops = [ + ((lambda x, y: x + y), acc_ops.add, NEED_TEST_BOTH_CONSTANTS_CASE), + ((lambda x, y: x - y), acc_ops.sub, NEED_TEST_BOTH_CONSTANTS_CASE), + ((lambda x, y: torch.sub(x, y)), acc_ops.sub, False), + ((lambda x, y: x.sub(y)), acc_ops.sub, False), + ((lambda x, y: x / y), acc_ops.div, NEED_TEST_BOTH_CONSTANTS_CASE), + ((lambda x, y: x // y), acc_ops.floor_div, NEED_TEST_BOTH_CONSTANTS_CASE), + ( + (lambda x, y: torch.div(x, y, rounding_mode="trunc")), + acc_ops.trunc_div, + not NEED_TEST_BOTH_CONSTANTS_CASE, + ), + ( + (lambda x, y: torch.div(x, y, rounding_mode="floor")), + acc_ops.floor_div, + NEED_TEST_BOTH_CONSTANTS_CASE, + ), + ((lambda x, y: torch.div(x, y)), acc_ops.div, NEED_TEST_BOTH_CONSTANTS_CASE), + ((lambda x, y: torch.fmod(x, y)), acc_ops.fmod, not NEED_TEST_BOTH_CONSTANTS_CASE), + # torch.floor_divide rounds result toward zero, rather than -Inf. + # https://github.com/pytorch/pytorch/issues/43874 + ( + (lambda x, y: torch.floor_divide(x, y)), + acc_ops.trunc_div, + not NEED_TEST_BOTH_CONSTANTS_CASE, + ), + ((lambda x, y: x * y), acc_ops.mul, NEED_TEST_BOTH_CONSTANTS_CASE), + (torch.pow, acc_ops.pow, not NEED_TEST_BOTH_CONSTANTS_CASE), +] + + +class TestBinaryOpConverters(AccTestCase): + @parameterized.expand([(op[1].__name__, op[0], op[1]) for op in elementwise_ops]) + def test_elementwise_ops(self, name, orig_op: Callable, expected_op): + class TestModule(nn.Module): + def __init__(self, orig_op): + super().__init__() + self.orig_op = orig_op + + def forward(self, x): + return self.orig_op(x, x) + + m = TestModule(orig_op) + # Avoid dividing by 0. 
+ inputs = [torch.rand(1, 1) + 1] + self.run_test(m, inputs, expected_ops={expected_op}) + + @parameterized.expand([(op[1].__name__, op[0], op[1]) for op in elementwise_ops]) + def test_elementwise_ops_with_one_constant( + self, name, orig_op: Callable, expected_op + ): + class TestModule(nn.Module): + def __init__(self, orig_op): + super().__init__() + self.constant = torch.randn(1) + self.orig_op = orig_op + + def forward(self, x): + x = self.orig_op(x, self.constant) + return self.orig_op(x, -2) + + m = TestModule(orig_op) + inputs = [torch.randn(2, 2)] + self.run_test(m, inputs, expected_ops={expected_op}) + + @parameterized.expand( + [(op[1].__name__, op[0], op[1]) for op in elementwise_ops if op[2]] + ) + def test_elementwise_op_with_both_constants( + self, name, orig_op: Callable, expected_op + ): + class TestModule(nn.Module): + def __init__(self, orig_op): + super().__init__() + self.constant0 = torch.nn.Parameter(torch.randn(1)) + self.constant1 = torch.nn.Parameter(torch.randn(1)) + self.orig_op = orig_op + + def forward(self, x): + const = self.orig_op(self.constant0, self.constant1) + return self.orig_op(x, const) + + m = TestModule(orig_op) + inputs = [torch.randn(2, 2)] + self.run_test(m, inputs, expected_ops={expected_op}) + + @parameterized.expand( + [ + ( + f"no_broadcast_{op[1].__name__}", + (-1, -1), + ((1, 1), (2, 2), (3, 3)), + (-1, -1), + ((1, 1), (2, 2), (3, 3)), + op[0], + op[1], + ) + for op in elementwise_ops + ] + + [ + ( + f"broadcast_{op[1].__name__}", + (-1, -1, -1), + ((1, 1, 1), (2, 2, 2), (3, 3, 3)), + (-1, -1), + ((1, 1), (2, 2), (3, 3)), + op[0], + op[1], + ) + for op in elementwise_ops + ] + ) + def test_elementwise_op_with_dynamic_shape( + self, _, x_shape, x_shape_ranges, y_shape, y_shape_ranges, orig_op, expected_op + ): + class Op(nn.Module): + def forward(self, x, y): + return orig_op(x, y) + + input_specs = [ + InputTensorSpec( + shape=x_shape, + dtype=torch.float32, + shape_ranges=[x_shape_ranges], + ), + InputTensorSpec( + shape=y_shape, + dtype=torch.float32, + shape_ranges=[y_shape_ranges], + ), + ] + + self.run_test_with_dynamic_shape(Op(), input_specs, expected_ops={expected_op}) + + @parameterized.expand( + [ + ( + f"no_broadcast_{op[1].__name__}", + op[0], + op[1], + ) + for op in elementwise_ops + ] + + [ + ( + f"broadcast_{op[1].__name__}", + op[0], + op[1], + ) + for op in elementwise_ops + ] + ) + def test_elementwise_op_with_dynamic_shape_four_dimensions( + self, _, orig_op, expected_op + ): + class Op(nn.Module): + def forward(self, x, y): + return orig_op(x, y) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))], + ), + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))], + ), + ] + + self.run_test_with_dynamic_shape(Op(), input_specs, expected_ops={expected_op}) + + def test_elementwise_ops_with_scalar_lhs(self): + def orig_op(x, y): + return x + y + + class TestModule(nn.Module): + def __init__(self, orig_op): + super().__init__() + self.constant = torch.randn(1) + self.orig_op = orig_op + + def forward(self, x): + return self.orig_op(x, self.constant) + + m = TestModule(orig_op) + inputs = [torch.randn(10)] + self.run_test( + m, + inputs, + expected_ops={acc_ops.add}, + test_explicit_batch_dim=False, + test_implicit_batch_dim=True, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_cat.py 
b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_cat.py new file mode 100644 index 0000000000..e9232c35d7 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_cat.py @@ -0,0 +1,101 @@ +import torch +import torch.nn as nn +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import param, parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestCatConverter(AccTestCase): + @parameterized.expand( + [ + param("cat", torch.cat), + param("concat", torch.concat), + ] + ) + def test_cat(self, _, op): + class Cat(nn.Module): + def forward(self, x, y, z): + return op((x, y, z), 1) + + inputs = [torch.randn(1, 2, 3), torch.randn(1, 1, 3), torch.randn(1, 3, 3)] + self.run_test(Cat(), inputs, expected_ops={acc_ops.cat}) + + @parameterized.expand( + [ + param("cat", torch.cat), + param("concat", torch.concat), + ] + ) + def test_cat_neg(self, _, op): + class Cat(nn.Module): + def forward(self, x, y, z): + return op((x, y, z), -1) + + inputs = [torch.randn(1, 2, 3), torch.randn(1, 2, 3), torch.randn(1, 2, 2)] + self.run_test(Cat(), inputs, expected_ops={acc_ops.cat}) + + @parameterized.expand( + [ + param("cat", torch.cat), + param("concat", torch.concat), + ] + ) + def test_cat_with_dynamic_shape(self, _, op): + class Cat(nn.Module): + def forward(self, x, y): + x = x + y + return op((x, y), 0) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (2, 3, 4), (2, 3, 10))], + ), + InputTensorSpec( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (2, 3, 4), (2, 3, 10))], + ), + ] + self.run_test_with_dynamic_shape(Cat(), input_specs, expected_ops={acc_ops.cat}) + + @parameterized.expand( + [ + param("cat", torch.cat), + param("concat", torch.concat), + ] + ) + def test_cat_with_dynamic_shape_four_dimensions(self, _, op): + class Cat(nn.Module): + def forward(self, x, y): + x = x + y + return op((x, y), 0) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 4), (2, 3, 10, 10))], + ), + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 4), (2, 3, 10, 10))], + ), + ] + + self.run_test_with_dynamic_shape(Cat(), input_specs, expected_ops={acc_ops.cat}) + + def test_concat(self): + class Cat(nn.Module): + def forward(self, x, y, z): + return torch.concat((x, y, z), 1) + + inputs = [torch.randn(1, 2, 3), torch.randn(1, 1, 3), torch.randn(1, 3, 3)] + self.run_test(Cat(), inputs, expected_ops={acc_ops.cat}) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_chunk.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_chunk.py new file mode 100644 index 0000000000..49fb8cff5b --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_chunk.py @@ -0,0 +1,79 @@ +import torch +import torch.nn as nn +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestChunkConverter(AccTestCase): + @parameterized.expand( + [ + ("chunk", 3, 1), + ("chunk", 2000, 2), + ("chunk", 3, -2), + ] + ) + def test_chunk(self, _, chunk, dim): + class Chunk(nn.Module): + def 
forward(self, x): + return x.chunk(chunk, dim)[0] + + inputs = [torch.randn(3, 10, 20)] + self.run_test( + Chunk(), + inputs, + expected_ops={acc_ops.chunk}, + ) + + @parameterized.expand( + [ + ("chunk", 3, 1), + ("chunk", 2000, 1), + ("chunk", 3, -2), + ] + ) + def test_chunk_with_dynamic_shape(self, _, chunk, dim): + class Chunk(nn.Module): + def forward(self, x): + return x.chunk(chunk, dim)[0] + + input_specs = [ + InputTensorSpec( + shape=(-1, 10, -1), + dtype=torch.float32, + shape_ranges=[((1, 10, 20), (5, 10, 20), (10, 10, 20))], + ), + ] + self.run_test_with_dynamic_shape( + Chunk(), input_specs, expected_ops={acc_ops.chunk} + ) + + # Testing with (-1, -1, -1, -1) results in Error: AssertionError: Can't chunk on dynamic shape dimension! + @parameterized.expand( + [ + ("chunk", 3, 1), + ("chunk", 2000, 1), + ("chunk", 3, -2), + ] + ) + def test_chunk_with_dynamic_shape_four_dimensions(self, _, chunk, dim): + class Chunk(nn.Module): + def forward(self, x): + return x.chunk(chunk, dim)[0] + + input_specs = [ + InputTensorSpec( + shape=(-1, 1, 3, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 3, 5), (3, 1, 3, 5), (5, 1, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + Chunk(), input_specs, expected_ops={acc_ops.chunk} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_clamp.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_clamp.py new file mode 100644 index 0000000000..96e611626c --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_clamp.py @@ -0,0 +1,71 @@ +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import param, parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestClampConverter(AccTestCase): + @parameterized.expand( + [ + param("default", min=-1, max=0), + param("min", min=0.5), + param("max", max=0.5), + param("minBiggerThanMax", min=1, max=0), + param("float32Boundary", min=-3.4028234663852886e38), + ] + ) + def test_clamp( + self, + test_name, + min=None, + max=None, + ): + class TestModule(torch.nn.Module): + def forward(self, x): + return torch.clamp(x, min, max) + + inputs = [torch.randn(3, 4)] + self.run_test(TestModule(), inputs, expected_ops={acc_ops.clamp}) + + @parameterized.expand( + [ + param("default", min=-1, max=0), + param("min", min=0.5), + param("max", max=0.5), + param("minBiggerThanMax", min=1, max=0), + ] + ) + def test_clamp_with_dynamic_shape_four_dimensions( + self, + test_name, + min=None, + max=None, + ): + class TestModule(torch.nn.Module): + def forward(self, x): + return torch.clamp(x, min, max) + + class TestScalarModule(torch.nn.Module): + def forward(self, x): + y = torch.sum(x) + return torch.clamp(y, min, max) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, 3, 3), + dtype=torch.float32, + shape_ranges=[((1, 1, 3, 3), (3, 3, 3, 3), (5, 5, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.clamp} + ) + self.run_test_with_dynamic_shape( + TestScalarModule(), input_specs, expected_ops={acc_ops.clamp} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_convolution.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_convolution.py new file mode 100644 index 0000000000..bedc75f194 --- /dev/null +++ 
b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_convolution.py @@ -0,0 +1,199 @@ +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import param, parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestConvolutionConverter(AccTestCase): + @parameterized.expand( + [ + ("default", 1), + param("no_bias", 1, bias=False), + ("tuple_parameters", 1, (1,), (1,)), + param("non_zero_padding", 1, padding=1), + param("dilation", 1, dilation=2), + param("groups", 1, groups=3), + ] + ) + def test_conv1d( + self, + _, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv1d( + 3, 6, kernel_size, stride, padding, dilation, groups, bias + ) + + def forward(self, x): + return self.conv(x) + + inputs = [torch.randn(1, 3, 32)] + self.run_test( + TestModule(), + inputs, + expected_ops={acc_ops.conv1d}, + test_explicit_precision=True, + ) + + @parameterized.expand( + [ + ("default", 1), + ] + ) + def test_conv1d_with_dynamic_shape( + self, + _, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv1d( + 3, 6, kernel_size, stride, padding, dilation, groups, bias + ) + + def forward(self, x): + return self.conv(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, 3, 3), + dtype=torch.float32, + shape_ranges=[((1, 3, 3), (3, 3, 3), (5, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.conv1d} + ) + + @parameterized.expand( + [ + ("default", 1), + param("no_bias", 1, bias=False), + ("tuple_parameters", 1, (1, 1), (1, 1)), + param("non_zero_padding", 1, padding=1), + param("dilation", 1, dilation=2), + param("groups", 1, groups=3), + ] + ) + def test_conv2d( + self, + _, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d( + 3, 6, kernel_size, stride, padding, dilation, groups, bias + ) + + def forward(self, x): + return self.conv(x) + + inputs = [torch.randn(1, 3, 32, 32)] + self.run_test(TestModule(), inputs, expected_ops={acc_ops.conv2d}) + + # Testing with (-1, -1, -1, -1) results in the error: + # AssertionError: Channel dim can't be dynamic for convolution. + + def test_conv2d_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 6, 1) + + def forward(self, x): + return self.conv(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, 3, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 3, 1, 1), (1, 3, 4, 4), (32, 3, 128, 128))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.conv2d} + ) + + @parameterized.expand( + [ + ("default", 1), + param("no_bias", 1, bias=False), + ("tuple_parameters", 1, (1, 1, 1), (1, 1, 1)), + param("non_zero_padding", 1, padding=1), + param("dilation", 1, dilation=2), + # TODO TRT 8.4.1 will trigger an issue with this test.
T127981773 + # param("groups", 1, groups=3), + ] + ) + def test_conv3d( + self, + _, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv3d( + 3, 6, kernel_size, stride, padding, dilation, groups, bias + ) + + def forward(self, x): + return self.conv(x) + + inputs = [torch.randn(1, 3, 32, 32, 32)] + self.run_test(TestModule(), inputs, expected_ops={acc_ops.conv3d}) + + # Testing with (-1, -1, -1, -1, -1) results in the error: + # AssertionError: Channel dim can't be dynamic for convolution. + + def test_conv3d_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv3d(3, 6, 1) + + def forward(self, x): + return self.conv(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, 3, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 3, 1, 1, 1), (1, 3, 4, 4, 4), (8, 3, 32, 32, 32))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.conv3d} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_dequantize.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_dequantize.py new file mode 100644 index 0000000000..212a77ec63 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_dequantize.py @@ -0,0 +1,68 @@ +import unittest + +import tensorrt as trt +import torch.fx +import torch.nn as nn + +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +@unittest.skip( + """ + Tests related to quantize have issues creating the engine; disabled for now.
+ """ +) +@unittest.skipIf( + trt.__version__ < "8.0", + "Explicit quantization only supported in TensorRT 8.0 and later", +) +class TestDequantizeConverter(AccTestCase): + def test_dequantize(self): + class TestModule(nn.Module): + def forward(self, x): + x = torch.quantize_per_tensor(x, 1, 0, torch.quint8) + return x.dequantize() + + inputs = [torch.randn(1, 10)] + self.run_test(TestModule(), inputs, expected_ops={acc_ops.dequantize}) + + def test_dequantize_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + x = torch.quantize_per_tensor(x, 1, 0, torch.quint8) + return x.dequantize() + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.dequantize} + ) + + def test_dequantize_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + x = torch.quantize_per_tensor(x, 1, 0, torch.quint8) + return x.dequantize() + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.dequantize} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_einsum.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_einsum.py new file mode 100644 index 0000000000..88a7e5fae7 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_einsum.py @@ -0,0 +1,66 @@ +import torch +import torch.nn as nn +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestConverter(AccTestCase): + @parameterized.expand( + [ + ("2d_dim", "ij,jk->ik", (2, 3), (3, 4)), + ("2d_dim_ext", "ij,kj->ik", (2, 3), (4, 3)), + ("3d_dim", "cxd,cyd->cxy", (3, 4, 5), (3, 6, 5)), + ("4d_dim", "bcwd,bcdh->bcwh", (2, 3, 4, 5), (2, 3, 5, 6)), + ("4d_dim_ext", "bcxd,bcyd->bcxy", (2, 3, 4, 5), (2, 3, 6, 5)), + # TRT does not support ellipsis or diagonal operations + ] + ) + def test_einsum(self, _, equation, x_size, y_size): + class Einsum(nn.Module): + def forward(self, x, y): + return torch.einsum(equation, x, y) + + inputs = [torch.randn(*x_size), torch.randn(*y_size)] + self.run_test( + Einsum(), + inputs, + expected_ops={acc_ops.einsum}, + test_implicit_batch_dim=False, + ) + + @parameterized.expand( + [ + ("4d_dim", "bcwd,bcdh->bcwh", (2, 3, 4, 5), (2, 3, 5, 6)), + ("4d_dim_ext", "bcxd,bcyd->bcxy", (2, 3, 4, 5), (2, 3, 6, 5)), + # TRT does not support ellipsis or diagonal operations + ] + ) + def test_einsum_with_dynamic_shape_four_dimensions( + self, _, equation, x_size, y_size + ): + class Einsum(nn.Module): + def forward(self, x, y): + return torch.einsum(equation, x, y) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 3, 3), (1, 2, 3, 3), (3, 3, 3, 3))], + ), + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 3, 3), (1, 2, 3, 3), (3, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + Einsum(), input_specs, expected_ops={acc_ops.einsum} + ) + + +if __name__ == "__main__": + run_tests() diff --git 
a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_elu.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_elu.py new file mode 100644 index 0000000000..313d8ec022 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_elu.py @@ -0,0 +1,52 @@ +import torch +import torch.nn as nn +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestELUConverter(AccTestCase): + def test_elu(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.elu(x, alpha=1.5) + + inputs = [torch.randn(1, 10)] + self.run_test(TestModule(), inputs, expected_ops={acc_ops.elu}) + + def test_elu_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.elu(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.elu} + ) + + def test_elu_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.elu(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.elu} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_embedding.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_embedding.py new file mode 100644 index 0000000000..ecfa171f0c --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_embedding.py @@ -0,0 +1,107 @@ +import unittest + +import torch + +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import param, parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +@unittest.skip( + "Current implementation is limited. All implementations in hf use int64. 
T113156424" +) +class TestEmbeddingConverter(AccTestCase): + @parameterized.expand( + [ + param( + test_name="1d_indices", + indices_tensor=torch.tensor([3, 1, 2]), + weights_tensor=torch.randn(5, 10), + ), + param( + test_name="2d_indices", + indices_tensor=torch.tensor([[3, 1, 2], [4, 1, 3]]), + weights_tensor=torch.randn(5, 10), + ), + param( + test_name="3d_indices", + indices_tensor=torch.tensor([[[0, 1], [2, 3]], [[3, 4], [4, 0]]]), + weights_tensor=torch.randn(5, 10), + ), + ] + ) + def test_embedding( + self, + test_name, + indices_tensor, + weights_tensor, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, + ): + class TestEmbedding(torch.nn.Module): + def forward(self, indices, weights): + return torch.nn.functional.embedding( + input=indices, + weight=weights, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse, + ) + + self.run_test( + TestEmbedding(), + inputs=[indices_tensor.int(), weights_tensor.float()], + expected_ops={acc_ops.embedding}, + test_implicit_batch_dim=False, + test_explicit_batch_dim=True, + ) + + def test_embedding_with_dynamic_shape_four_dimensions( + self, + test_name, + indices_tensor, + weights_tensor, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, + ): + class TestEmbedding(torch.nn.Module): + def forward(self, indices, weights): + return torch.nn.functional.embedding( + input=indices, + weight=weights, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse, + ) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], + ), + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], + ), + ] + + self.run_test_with_dynamic_shape( + TestEmbedding(), input_specs, expected_ops={acc_ops.embedding} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_eq.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_eq.py new file mode 100644 index 0000000000..befe675232 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_eq.py @@ -0,0 +1,289 @@ +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestEqConverter(AccTestCase): + @parameterized.expand( + [ + ("rand_2d", torch.randn(3, 4), torch.randn(3, 4)), + ("rand_3d", torch.randn(3, 4, 5), torch.randn(3, 4, 5)), + ("rand_4d", torch.randn(3, 4, 5, 6), torch.randn(3, 4, 5, 6)), + ( + "rand_2d_int_bool", + torch.randn(3, 4).to(torch.int), + torch.zeros(3, 4).to(torch.bool), + ), + ( + "rand_2d_float_bool", + torch.randn(3, 4).to(torch.float), + torch.zeros(3, 4).to(torch.bool), + ), + ( + "rand_2d_float_int", + torch.randn(3, 4).to(torch.float), + torch.zeros(3, 4).to(torch.int), + ), + ( + "rand_4d_float_bool_dim", + torch.randn(3, 4, 5, 6).to(torch.float), + torch.randn(3, 1, 1, 6).to(torch.bool), + ), + ] + ) + def test_eq(self, _, input, other): + class Eq(torch.nn.Module): + def forward(self, x, y): + mask = torch.eq(x, y) + return x.masked_fill(mask, 5) + + inputs = [ + input, + other, + ] + self.run_test( + 
Eq(), inputs, expected_ops={acc_ops.eq}, test_implicit_batch_dim=False + ) + + +class TestEqMethodConverter(AccTestCase): + @parameterized.expand( + [ + ("rand_2d", torch.randn(3, 4), torch.randn(3, 4)), + ("rand_3d", torch.randn(3, 4, 5), torch.randn(3, 4, 5)), + ("rand_4d", torch.randn(3, 4, 5, 6), torch.randn(3, 4, 5, 6)), + ( + "rand_2d_int_bool", + torch.randn(3, 4).to(torch.int), + torch.zeros(3, 4).to(torch.bool), + ), + ( + "rand_2d_float_bool", + torch.randn(3, 4).to(torch.float), + torch.zeros(3, 4).to(torch.bool), + ), + ( + "rand_2d_float_int", + torch.randn(3, 4).to(torch.float), + torch.zeros(3, 4).to(torch.int), + ), + ( + "rand_4d_float_bool_dim", + torch.randn(3, 4, 5, 6).to(torch.float), + torch.randn(3, 1, 1, 6).to(torch.bool), + ), + ] + ) + def test_eq(self, _, input, other): + class Eq(torch.nn.Module): + def forward(self, x, y): + mask = x.eq(y) + return x.masked_fill(mask, 5) + + inputs = [ + input, + other, + ] + self.run_test( + Eq(), inputs, expected_ops={acc_ops.eq}, test_implicit_batch_dim=False + ) + + +class TestEqOperatorConverter(AccTestCase): + @parameterized.expand( + [ + ("rand_2d", torch.randn(3, 4), torch.randn(3, 4)), + ("rand_3d", torch.randn(3, 4, 5), torch.randn(3, 4, 5)), + ("rand_4d", torch.randn(3, 4, 5, 6), torch.randn(3, 4, 5, 6)), + ( + "rand_2d_int_bool", + torch.randn(3, 4).to(torch.int), + torch.zeros(3, 4).to(torch.bool), + ), + ( + "rand_2d_float_bool", + torch.randn(3, 4).to(torch.float), + torch.zeros(3, 4).to(torch.bool), + ), + ( + "rand_2d_float_int", + torch.randn(3, 4).to(torch.float), + torch.zeros(3, 4).to(torch.int), + ), + ( + "rand_4d_float_bool_dim", + torch.randn(3, 4, 5, 6).to(torch.float), + torch.randn(3, 1, 1, 6).to(torch.bool), + ), + ] + ) + def test_eq(self, _, input, other): + class Eq(torch.nn.Module): + def forward(self, x, y): + mask = x == y + return x.masked_fill(mask, 5) + + inputs = [ + input, + other, + ] + self.run_test( + Eq(), inputs, expected_ops={acc_ops.eq}, test_implicit_batch_dim=False + ) + + +class TestEqOperatorSimpleConverter(AccTestCase): + @parameterized.expand( + [ + ("rand_2d_float_bool", torch.randn(3, 4), torch.randn(3, 4).to(torch.bool)), + ( + "rand_2d_int_bool", + torch.randn(3, 4).to(torch.int), + torch.randn(3, 4).to(torch.bool), + ), + ( + "rand_2d_bool_bool", + torch.randn(3, 4).to(torch.bool), + torch.randn(3, 4).to(torch.bool), + ), + ( + "rand_2d_float_int", + torch.randn(3, 4).to(torch.float), + torch.randn(3, 4).to(torch.int), + ), + ( + "rand_2d_float_single_bool", + torch.randn(3, 4), + torch.tensor(0).to(torch.bool), + ), + ( + "rand_2d_int_single_bool", + torch.randn(3, 4).to(torch.int), + torch.tensor(0).to(torch.bool), + ), + ( + "rand_2d_bool_single_bool", + torch.randn(3, 4).to(torch.bool), + torch.tensor(0).to(torch.bool), + ), + ] + ) + def test_eq(self, _, input, other): + class Eq(torch.nn.Module): + def forward(self, x, y): + return x == y + + inputs = [ + input, + other, + ] + self.run_test( + Eq(), inputs, expected_ops={acc_ops.eq}, test_implicit_batch_dim=False + ) + + +class TestEqOperatorSimpleConverterWithDynamicShape(AccTestCase): + def test_eq(self): + class Eq(torch.nn.Module): + def forward(self, x, y): + return x == y + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], + ), + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], + ), + ] + + self.run_test_with_dynamic_shape(Eq(), 
input_specs, expected_ops={acc_ops.eq}) + + +class TestEqOperatorConstantConverter(AccTestCase): + @parameterized.expand( + [ + ("rand_2d_float_bool", torch.randn(3, 4), torch.randn(3, 4).to(torch.bool)), + ( + "rand_2d_int_bool", + torch.randn(3, 4).to(torch.int), + torch.randn(3, 4).to(torch.bool), + ), + ( + "rand_2d_bool_bool", + torch.randn(3, 4).to(torch.bool), + torch.randn(3, 4).to(torch.bool), + ), + ( + "rand_2d_float_int", + torch.randn(3, 4).to(torch.float), + torch.randn(3, 4).to(torch.int), + ), + ("rand_2d_float_single_bool", torch.randn(3, 4), False), + ("rand_2d_int_single_bool", torch.randn(3, 4).to(torch.int), False), + ("rand_2d_bool_single_bool", torch.randn(3, 4).to(torch.bool), False), + ] + ) + def test_eq(self, _, input, other): + class Eq(torch.nn.Module): + def __init__(self): + super().__init__() + self.other = other + + def forward(self, x): + return x == self.other + + inputs = [ + input, + ] + self.run_test( + Eq(), inputs, expected_ops={acc_ops.eq}, test_implicit_batch_dim=False + ) + + +class TestConstInputConverter(AccTestCase): + def test_eq(self): + class Eq(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x.shape[0] == 4 + + input = torch.randn(3, 4) + inputs = [ + input, + ] + self.run_test( + Eq(), inputs, expected_ops={acc_ops.eq}, test_implicit_batch_dim=False + ) + + +class TestConstInputConverterWithDynamicShape(AccTestCase): + def test_eq(self): + class Eq(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x.shape[0] == 4 + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape(Eq(), input_specs, expected_ops={acc_ops.eq}) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_expand.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_expand.py new file mode 100644 index 0000000000..fd369459f3 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_expand.py @@ -0,0 +1,34 @@ +import torch +import torch.nn as nn +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase + + +class TestExpandConverter(AccTestCase): + @parameterized.expand( + [ + ("2d_dim", (2, 3), (2, 1)), + ("3d_dim", (2, 3, 4), (2, 1, 1)), + ("4d_dim", (2, 3, 4, 5), (2, 1, 1, 1)), + ("keep_dim", (2, 3, -1, -1), (2, 1, 5, 5)), + ] + ) + def test_expand(self, _, sizes, init_size): + class Expand(nn.Module): + def forward(self, x): + return x.expand(*sizes) + + inputs = [torch.randn(*init_size)] + self.run_test( + Expand(), + inputs, + expected_ops={acc_ops.expand}, + ) + + # Dynamic shape is not suitable for the expand operation. 
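+    # A rough sketch of why (assumed behavior, values illustrative only): a
+    # dynamic-shape run would need a wildcard spec such as the one below, but
+    # expand needs concrete target sizes when the TensorRT layer is built, so
+    # there is no single shape to expand to.
+    #
+    #     input_specs = [
+    #         InputTensorSpec(
+    #             shape=(-1, 1),
+    #             dtype=torch.float32,
+    #             shape_ranges=[((1, 1), (2, 1), (3, 1))],
+    #         ),
+    #     ]
+    #
+    # (InputTensorSpec is not imported in this file; the sketch is kept
+    # commented out on purpose.)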
+ + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_flatten.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_flatten.py new file mode 100644 index 0000000000..346669d695 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_flatten.py @@ -0,0 +1,70 @@ +import torch +import torch.nn as nn +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestFlattenConverter(AccTestCase): + @parameterized.expand( + [ + ("flatten_middle_dims", 1, 2), + ("flatten_last_3_dims", 1, 3), + ("flatten_last_1", 3, 3), + ("flatten_all", 0, 3), + ] + ) + def test_flatten(self, _, start_dim, end_dim): + class Flatten(nn.Module): + def __init__(self, start, end): + super().__init__() + self.start = start + self.end = end + + def forward(self, x): + return torch.flatten(x, self.start, self.end) + + inputs = [torch.randn(1, 2, 3, 1)] + self.run_test( + Flatten(start_dim, end_dim), + inputs, + expected_ops={acc_ops.flatten}, + test_implicit_batch_dim=(start_dim != 0), + ) + + @parameterized.expand( + [ + ("flatten_middle_dims", 1, 2), + ("flatten_last_3_dims", 2, 4), + ("flatten_last_1", 4, 4), + ("flatten_first_2", 0, 1), + ("flatten_all", 0, 4), + ] + ) + def test_flatten_with_dynamic_shape(self, _, start_dim, end_dim): + class Flatten(nn.Module): + def __init__(self, start, end): + super().__init__() + self.start = start + self.end = end + + def forward(self, x): + return torch.flatten(x, self.start, self.end) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1, 1), (1, 2, 3, 2, 1), (3, 3, 3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + Flatten(start_dim, end_dim), + input_specs, + expected_ops={acc_ops.flatten}, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_gelu.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_gelu.py new file mode 100644 index 0000000000..e7b7bd806d --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_gelu.py @@ -0,0 +1,95 @@ +import unittest + +import torch +import torch.nn as nn +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +@unittest.skip( + reason="Could not find CustomGeluPluginDynamic. 
Enable it once we upgrade TRT to 8.4" +) +class TestGELU(AccTestCase): + def test_gelu(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.gelu(x) + + inputs = [torch.randn(3, 10, 20)] + self.run_test( + TestModule(), + inputs, + expected_ops={acc_ops.gelu}, + test_implicit_batch_dim=False, + ) + + def test_gelu_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.gelu(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.gelu} + ) + + def test_gelu_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.gelu(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.gelu} + ) + + def test_gelu_module(self): + class TestModule(nn.Module): + def __init__(self): + super().__init__() + self.gelu = torch.nn.GELU() + + def forward(self, x): + return self.gelu(x) + + inputs = [torch.randn(3, 10, 20)] + self.run_test( + TestModule(), + inputs, + expected_ops={acc_ops.gelu}, + test_implicit_batch_dim=False, + ) + + def test_gelu_module_throw(self): + class TestModule(nn.Module): + def __init__(self): + super().__init__() + self.gelu = torch.nn.GELU(approximate="tanh") + + def forward(self, x): + return self.gelu(x) + + inputs = [torch.randn(3, 10, 20)] + self.run_test_with_assert_error( + TestModule(), + inputs, + expect_error=RuntimeError, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_getitem.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_getitem.py new file mode 100644 index 0000000000..9cc68ad87e --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_getitem.py @@ -0,0 +1,198 @@ +import torch +import torch.nn as nn +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestGetitemConverter(AccTestCase): + @parameterized.expand( + [ + ("slice_batch_dim", slice(None, None, None)), + ("slice_basic", (slice(None, None, None), slice(0, 3, 2))), + ("slice_full", (slice(None, None, None), slice(0, 10, 3))), + ("ellipsis", (slice(None, None, None), ..., slice(0, 3, 2))), + ( + "slice_all_none", + (slice(None, None, None), slice(None, None, None)), + ), + ( + "slice_start_none", + (slice(None, None, None), slice(None, 2, 1)), + ), + ("slice_end_none", (slice(None, None, None), slice(1, None, 1))), + ( + "slice_step_none", + (slice(None, None, None), slice(0, 3, None)), + ), + ("slice_neg_idx", (slice(None, None, None), -1)), + ("slice_neg_slice", (slice(None, None, None), slice(-8, -2, 3))), + ("multi_dim", (slice(None, None, None), 0, 1)), + ( + "slice_multi_dim", + (slice(None, None, None), slice(0, 3, 2), slice(1, -1, 3)), + ), + ( + "none", + (slice(None, None, None), None, slice(1, -1, 3), 1), + ), + ( + "slice_zero_slice", + (slice(None, None, None), slice(None, None, None), slice(0, 0, None)), + ), + ] + ) + def test_getitem(self, _, idx): + class Getitem(nn.Module): + def __init__(self, 
idx):
+                super().__init__()
+                self.idx = idx
+
+            def forward(self, x):
+                x = x + x
+                return x[self.idx]
+
+        inputs = [torch.randn(2, 10, 10, 10)]
+        self.run_test(Getitem(idx), inputs, expected_ops={acc_ops.getitem})
+
+    @parameterized.expand(
+        [
+            ("slice_batch_dim", slice(None, None, None)),
+            ("ellipsis", (slice(None, None, None), ..., slice(0, -3, 2))),
+            (
+                "slice_all_none",
+                (slice(None, None, None), slice(None, None, None)),
+            ),
+            (
+                "slice_end_none",
+                (slice(None, None, None), slice(None, None, None), slice(1, None, 1)),
+            ),
+            (
+                "slice_step_none",
+                (slice(None, None, None), slice(None, None, None), slice(0, 3, None)),
+            ),
+            ("slice_neg_idx", (slice(None, None, None), -1, slice(None, None, None))),
+            (
+                "slice_neg_slice",
+                (slice(None, None, None), slice(None, None, None), slice(-8, -2, 3)),
+            ),
+            ("multi_dim", (slice(None, None, None), 0, 1)),
+            (
+                "slice_multi_dim",
+                (slice(None, None, None), slice(0, 3, 2), slice(1, -1, 3)),
+            ),
+            (
+                "none",
+                (slice(None, None, None), None, slice(1, -1, 3)),
+            ),
+        ]
+    )
+    def test_getitem_with_dynamic_shape(self, _, idx):
+        class Getitem(nn.Module):
+            def __init__(self, idx):
+                super().__init__()
+                self.idx = idx
+
+            def forward(self, x):
+                x = x + x
+                return x[self.idx]
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, 256, 256),
+                dtype=torch.float32,
+                shape_ranges=[((1, 256, 256), (3, 256, 256), (5, 256, 256))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Getitem(idx), input_specs, expected_ops={acc_ops.getitem}
+        )
+
+    @parameterized.expand(
+        [
+            ("slice_batch_dim", slice(None, None, None)),
+            ("ellipsis", (slice(None, None, None), ..., slice(0, -3, 2))),
+            (
+                "slice_all_none",
+                (slice(None, None, None), slice(None, None, None)),
+            ),
+            (
+                "slice_end_none",
+                (slice(None, None, None), slice(None, None, None), slice(1, None, 1)),
+            ),
+            (
+                "slice_step_none",
+                (slice(None, None, None), slice(None, None, None), slice(0, 3, None)),
+            ),
+        ]
+    )
+    def test_getitem_with_multi_dynamic_shape(self, _, idx):
+        class Getitem(nn.Module):
+            def __init__(self, idx):
+                super().__init__()
+                self.idx = idx
+
+            def forward(self, x):
+                x = x + x
+                return x[self.idx]
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, 256),
+                dtype=torch.float32,
+                shape_ranges=[((1, 128, 256), (3, 192, 256), (5, 256, 256))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Getitem(idx), input_specs, expected_ops={acc_ops.getitem}
+        )
+
+    # Testing with the following parameters results in an error:
+    # AssertionError: We don't support slicing tensor on dynamic shape. 
+ """ + ("ellipsis", (slice(None, None, None), ..., slice(0, -3, 2))), + ( + "slice_end_none", + (slice(None, None, None), slice(None, None, None), slice(1, None, 1)), + ), + ( + "slice_step_none", + (slice(None, None, None), slice(None, None, None), slice(0, 3, None)), + ), + """ + + @parameterized.expand( + [ + ("slice_batch_dim", slice(None, None, None)), + ( + "slice_all_none", + (slice(None, None, None), slice(None, None, None)), + ), + ] + ) + def test_getitem_with_dynamic_shape_four_dimensions(self, _, idx): + class Getitem(nn.Module): + def __init__(self, idx): + super().__init__() + self.idx = idx + + def forward(self, x): + x = x + x + return x[self.idx] + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + Getitem(idx), input_specs, expected_ops={acc_ops.getitem} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_gt.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_gt.py new file mode 100644 index 0000000000..b6e8e602d7 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_gt.py @@ -0,0 +1,276 @@ +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestGtConverter(AccTestCase): + @parameterized.expand( + [ + ("rand_2d", torch.randn(3, 4), torch.randn(3, 4)), + ("rand_3d", torch.randn(3, 4, 5), torch.randn(3, 4, 5)), + ("rand_4d", torch.randn(3, 4, 5, 6), torch.randn(3, 4, 5, 6)), + ( + "rand_2d_int_bool", + torch.randn(3, 4).to(torch.int), + torch.zeros(3, 4).to(torch.bool), + ), + ( + "rand_2d_float_bool", + torch.randn(3, 4).to(torch.float), + torch.zeros(3, 4).to(torch.bool), + ), + ( + "rand_2d_float_int", + torch.randn(3, 4).to(torch.float), + torch.zeros(3, 4).to(torch.int), + ), + ] + ) + def test_gt(self, _, input, other): + class Gt(torch.nn.Module): + def forward(self, x, y): + mask = torch.gt(x, y) + return x.masked_fill(mask, 5) + + inputs = [ + input, + other, + ] + self.run_test( + Gt(), inputs, expected_ops={acc_ops.gt}, test_implicit_batch_dim=False + ) + + +class TestGtMethodConverter(AccTestCase): + @parameterized.expand( + [ + ("rand_2d", torch.randn(3, 4), torch.randn(3, 4)), + ("rand_3d", torch.randn(3, 4, 5), torch.randn(3, 4, 5)), + ("rand_4d", torch.randn(3, 4, 5, 6), torch.randn(3, 4, 5, 6)), + ( + "rand_2d_int_bool", + torch.randn(3, 4).to(torch.int), + torch.zeros(3, 4).to(torch.bool), + ), + ( + "rand_2d_float_bool", + torch.randn(3, 4).to(torch.float), + torch.zeros(3, 4).to(torch.bool), + ), + ( + "rand_2d_float_int", + torch.randn(3, 4).to(torch.float), + torch.zeros(3, 4).to(torch.int), + ), + ] + ) + def test_gt(self, _, input, other): + class Gt(torch.nn.Module): + def forward(self, x, y): + mask = x.gt(y) + return x.masked_fill(mask, 5) + + inputs = [ + input, + other, + ] + self.run_test( + Gt(), inputs, expected_ops={acc_ops.gt}, test_implicit_batch_dim=False + ) + + +class TestGtOperatorConverter(AccTestCase): + @parameterized.expand( + [ + ("rand_2d", torch.randn(3, 4), torch.randn(3, 4)), + ("rand_3d", torch.randn(3, 4, 5), torch.randn(3, 4, 5)), + ("rand_4d", torch.randn(3, 4, 5, 6), torch.randn(3, 4, 5, 6)), + ( + "rand_2d_int_bool", + torch.randn(3, 4).to(torch.int), + torch.zeros(3, 
4).to(torch.bool), + ), + ( + "rand_2d_float_bool", + torch.randn(3, 4).to(torch.float), + torch.zeros(3, 4).to(torch.bool), + ), + ( + "rand_2d_float_int", + torch.randn(3, 4).to(torch.float), + torch.zeros(3, 4).to(torch.int), + ), + ] + ) + def test_gt(self, _, input, other): + class Gt(torch.nn.Module): + def forward(self, x, y): + mask = x > y + return x.masked_fill(mask, 5) + + inputs = [ + input, + other, + ] + self.run_test( + Gt(), inputs, expected_ops={acc_ops.gt}, test_implicit_batch_dim=False + ) + + +class TestEqOperatorSimpleConverter(AccTestCase): + @parameterized.expand( + [ + ("rand_2d_float_bool", torch.randn(3, 4), torch.randn(3, 4).to(torch.bool)), + ( + "rand_2d_int_bool", + torch.randn(3, 4).to(torch.int), + torch.randn(3, 4).to(torch.bool), + ), + ( + "rand_2d_bool_bool", + torch.randn(3, 4).to(torch.bool), + torch.randn(3, 4).to(torch.bool), + ), + ( + "rand_2d_float_int", + torch.randn(3, 4).to(torch.float), + torch.randn(3, 4).to(torch.int), + ), + ( + "rand_2d_float_single_bool", + torch.randn(3, 4), + torch.tensor(0).to(torch.bool), + ), + ( + "rand_2d_int_single_bool", + torch.randn(3, 4).to(torch.int), + torch.tensor(0).to(torch.bool), + ), + ( + "rand_2d_bool_single_bool", + torch.randn(3, 4).to(torch.bool), + torch.tensor(0).to(torch.bool), + ), + ] + ) + def test_eq(self, _, input, other): + class Eq(torch.nn.Module): + def forward(self, x, y): + return x > y + + inputs = [ + input, + other, + ] + self.run_test( + Eq(), inputs, expected_ops={acc_ops.gt}, test_implicit_batch_dim=False + ) + + +class TestEqOperatorSimpleConverterWithDynamicShape(AccTestCase): + def test_eq( + self, + ): + class Eq(torch.nn.Module): + def forward(self, x, y): + return x > y + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], + ), + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], + ), + ] + + self.run_test_with_dynamic_shape(Eq(), input_specs, expected_ops={acc_ops.gt}) + + +class TestEqOperatorConstantConverter(AccTestCase): + @parameterized.expand( + [ + ("rand_2d_float_bool", torch.randn(3, 4), torch.randn(3, 4).to(torch.bool)), + ( + "rand_2d_int_bool", + torch.randn(3, 4).to(torch.int), + torch.randn(3, 4).to(torch.bool), + ), + ( + "rand_2d_bool_bool", + torch.randn(3, 4).to(torch.bool), + torch.randn(3, 4).to(torch.bool), + ), + ( + "rand_2d_float_int", + torch.randn(3, 4).to(torch.float), + torch.randn(3, 4).to(torch.int), + ), + ("rand_2d_float_single_bool", torch.randn(3, 4), False), + ("rand_2d_int_single_bool", torch.randn(3, 4).to(torch.int), False), + ("rand_2d_bool_single_bool", torch.randn(3, 4).to(torch.bool), False), + ] + ) + def test_eq(self, _, input, other): + class Eq(torch.nn.Module): + def __init__(self): + super().__init__() + self.other = other + + def forward(self, x): + return x > self.other + + inputs = [ + input, + ] + self.run_test( + Eq(), inputs, expected_ops={acc_ops.gt}, test_implicit_batch_dim=False + ) + + +class TestConstInputConverter(AccTestCase): + def test_gt(self): + class Gt(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x.shape[0] > 4 + + input = torch.randn(3, 4) + inputs = [ + input, + ] + self.run_test( + Gt(), inputs, expected_ops={acc_ops.gt}, test_implicit_batch_dim=False + ) + + +class TestConstInputConverterWithDynamicShape(AccTestCase): + def test_gt(self): + class Gt(torch.nn.Module): + def 
__init__(self): + super().__init__() + + def forward(self, x): + return x.shape[0] > 4 + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))], + ), + ] + + self.run_test_with_dynamic_shape(Gt(), input_specs, expected_ops={acc_ops.gt}) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_hard_sigmoid.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_hard_sigmoid.py new file mode 100644 index 0000000000..86d5b1a099 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_hard_sigmoid.py @@ -0,0 +1,59 @@ +import torch +from parameterized import parameterized +from torch import nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.fx.tracer.acc_tracer import acc_ops + + +class TestHardSigmoid(AccTestCase): + @parameterized.expand( + [ + ("3", 3), + ("0", 0), + ("-3", -4), + ] + ) + def test_hardsigmoid(self, _, pad): + class Hardsigmoid(nn.Module): + def forward(self, x): + return torch.nn.functional.hardsigmoid(x) + + inputs = [torch.randn(1, 2, 3) + pad] + self.run_test(Hardsigmoid(), inputs, expected_ops={acc_ops.hardsigmoid}) + + def test_hardsigmoid_with_dynamic_shape(self): + class Hardsigmoid(nn.Module): + def forward(self, x): + return torch.nn.functional.hardsigmoid(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + Hardsigmoid(), input_specs, expected_ops={acc_ops.hardsigmoid} + ) + + def test_hardsigmoid_with_dynamic_shape_four_dimensions(self): + class Hardsigmoid(nn.Module): + def forward(self, x): + return torch.nn.functional.hardsigmoid(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + Hardsigmoid(), input_specs, expected_ops={acc_ops.hardsigmoid} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_hardtanh.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_hardtanh.py new file mode 100644 index 0000000000..97d326851a --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_hardtanh.py @@ -0,0 +1,57 @@ +import torch +import torch.nn as nn +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestHardtanhConverter(AccTestCase): + @parameterized.expand( + [ + (-2.0, 6), + (0, 1), + (0.5, 7), + ] + ) + def test_hardtanh(self, test_min_value, test_max_value): + class Hardtanh(nn.Module): + def forward(self, x): + return nn.functional.hardtanh( + x, min_val=test_min_value, max_val=test_max_value + ) + + inputs = [torch.randn(2, 10, 10, 10)] + self.run_test(Hardtanh(), inputs, expected_ops={acc_ops.hardtanh}) + + +class TestHardtanhConverterWithDynamicShape(AccTestCase): + @parameterized.expand( + [ + (-2.0, 6), + (0, 1), + (0.5, 7), + ] + ) + def test_hardtanh(self, test_min_value, test_max_value): + class Hardtanh(nn.Module): + def forward(self, x): + return nn.functional.hardtanh( + x, min_val=test_min_value, max_val=test_max_value 
+ ) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + Hardtanh(), input_specs, expected_ops={acc_ops.hardtanh} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_interpolate.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_interpolate.py new file mode 100644 index 0000000000..c3e10f96ee --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_interpolate.py @@ -0,0 +1,150 @@ +import torch +import torch.nn as nn +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestInterpolateConverter(AccTestCase): + @parameterized.expand( + [ + # 3D + ("3d_dim_scale", (2, 3, 4), (None), (2), ("nearest"), (None)), + ("3d_dim_scale_seq", (2, 3, 4), (None), (2,), ("nearest"), (None)), + ("3d_dim_size", (2, 3, 4), (2), (None), ("nearest"), (None)), + ("3d_dim_size_seq", (2, 3, 4), (8,), (None), ("nearest"), (None)), + ( + "3d_dim_scale_linear", + (2, 3, 4), + (None), + (2), + ("linear"), + (None), + ), # linear for 3D only + ( + "3d_dim_scale_align", + (2, 3, 4), + (None), + (2), + ("linear"), + (True), + ), # align_corners for linear,bilinear,trilinear,bicubic only + # 4D + ("4d_dim_scale", (2, 3, 4, 5), (None), (2), ("nearest"), (None)), + ("4d_dim_scale_seq", (2, 3, 4, 5), (None), (2, 2), ("nearest"), (None)), + ("4d_dim_size", (2, 3, 4, 5), (2), (None), ("nearest"), (None)), + ("4d_dim_size_seq", (2, 3, 4, 5), (8, 10), (None), ("nearest"), (None)), + ( + "4d_dim_scale_bilinear", + (2, 3, 4, 5), + (None), + (2), + ("bilinear"), + (None), + ), # linear for 4D only + ( + "4d_dim_scale_bilinear_align_corners_bool", + (2, 3, 4, 5), + (None), + (2), + ("bilinear"), + (False), + ), # linear for 4D only + ( + "4d_dim_scale_align", + (2, 3, 4, 5), + (None), + (2), + ("bilinear"), + (True), + ), # align_corners for linear,bilinear,trilinear,bicubic only + # 5D + ("5d_dim_scale", (2, 3, 4, 5, 6), (None), (2), ("nearest"), (None)), + ( + "5d_dim_scale_seq", + (2, 3, 4, 5, 6), + (None), + (2, 2, 2), + ("nearest"), + (None), + ), + ("5d_dim_size", (2, 3, 4, 5, 6), (2), (None), ("nearest"), (None)), + ( + "5d_dim_size_seq", + (2, 3, 4, 5, 6), + (8, 10, 12), + (None), + ("nearest"), + (None), + ), + ( + "5d_dim_scale_trilinear", + (2, 3, 4, 5, 6), + (None), + (2), + ("trilinear"), + (None), + ), # trilinear for 5D only + ( + "5d_dim_scale_align", + (2, 3, 4, 5, 6), + (None), + (2), + ("trilinear"), + (True), + ), # align_corners for linear,bilinear,trilinear,bicubic only + ] + ) + def test_interpolate(self, _, init_size, size, scale_factor, mode, align_corners): + class Interpolate(nn.Module): + def forward(self, x): + return torch.nn.functional.interpolate( + x, + size=size, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + ) # only one of size or scale_factor should be defined + + inputs = [torch.randn(*init_size)] + self.run_test( + Interpolate(), + inputs, + expected_ops={acc_ops.interpolate}, + ) + + @parameterized.expand( + [ + # 4D + ("4d_dim_scale", (2, 3, 4, 5), (None), (2), ("nearest"), (None)), + ] + ) + def test_interpolate_with_dynamic_shape_four_dimensions( + self, _, init_size, size, scale_factor, mode, align_corners + ): + class 
Interpolate(nn.Module): + def forward(self, x): + return torch.nn.functional.interpolate( + x, + size=size, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + ) # only one of size or scale_factor should be defined + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + Interpolate(), input_specs, expected_ops={acc_ops.interpolate} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_isinf.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_isinf.py new file mode 100644 index 0000000000..9717eb52c1 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_isinf.py @@ -0,0 +1,63 @@ +import unittest + +import torch + +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +@unittest.skip("Implementation is commented out due to accuracy issue T113156424") +class TestInfConverter(AccTestCase): + def test_isinf(self): + class Test(torch.nn.Module): + def forward(self, x): + return torch.isinf(x) + + input = torch.randn(2, 3) + input[0][0] = float("inf") + input[0][1] = float("-inf") + input.cuda() + inputs = [ + input, + ] + self.run_test( + Test(), inputs, expected_ops={acc_ops.isinf}, test_implicit_batch_dim=False + ) + + def test_isinf_large(self): + class Test(torch.nn.Module): + def forward(self, x): + return torch.isinf(x) + + input = torch.randn(2, 3, 4, 5) + input[0][0][0][:] = float("inf") + input[0][0][1][:] = float("-inf") + input.cuda() + inputs = [ + input, + ] + self.run_test( + Test(), inputs, expected_ops={acc_ops.isinf}, test_implicit_batch_dim=False + ) + + def test_isinf_large_with_dynamic_shape_four_dimensions(self): + class Test(torch.nn.Module): + def forward(self, x): + return torch.isinf(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + Test(), input_specs, expected_ops={acc_ops.isinf} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_leaky_relu.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_leaky_relu.py new file mode 100644 index 0000000000..601aa7ee91 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_leaky_relu.py @@ -0,0 +1,52 @@ +import torch +import torch.nn as nn +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestLeakyReLUConverter(AccTestCase): + def test_leaky_relu(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.leaky_relu(x, negative_slope=0.05) + + inputs = [torch.randn(1, 10)] + self.run_test(TestModule(), inputs, expected_ops={acc_ops.leaky_relu}) + + def test_leaky_relu_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.leaky_relu(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, 
expected_ops={acc_ops.leaky_relu}
+        )
+
+    def test_leaky_relu_with_dynamic_shape_four_dimensions(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.leaky_relu(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 5))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={acc_ops.leaky_relu}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_linear.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_linear.py
new file mode 100644
index 0000000000..361a25fa04
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_linear.py
@@ -0,0 +1,60 @@
+import torch
+import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+
+
+class TestLinearConverter(AccTestCase):
+    @parameterized.expand(
+        [
+            ("default", [1, 512]),
+            ("matrix", [32, 512]),
+            ("no_bias", [1, 512], False),
+        ]
+    )
+    def test_linear(
+        self,
+        test_name,
+        shape,
+        bias=True,
+    ):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(512, 256, bias)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        inputs = [torch.randn(shape)]
+        self.run_test(TestModule(), inputs, expected_ops={acc_ops.linear})
+
+    def test_linear_with_dynamic_shape(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(512, 256)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, 3, 512),
+                dtype=torch.float32,
+                shape_ranges=[((1, 3, 512), (3, 3, 512), (4, 3, 512))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(),
+            input_specs,
+            expected_ops={acc_ops.linear},
+        )
+
+    # Testing with (-1, -1, 512) results in the following error:
+    # AssertionError: Currently we only support one dynmaic dim for linear and it can't be the last dim. 
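+    # For illustration only (assumed values, kept commented out): the
+    # unsupported spec described above would look like the sketch below, with
+    # two dynamic dims while the last dim stays static at 512:
+    #
+    #     input_specs = [
+    #         InputTensorSpec(
+    #             shape=(-1, -1, 512),
+    #             dtype=torch.float32,
+    #             shape_ranges=[((1, 1, 512), (3, 3, 512), (4, 4, 512))],
+    #         ),
+    #     ]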
+ + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_logical_and.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_logical_and.py new file mode 100644 index 0000000000..85f18ea3f3 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_logical_and.py @@ -0,0 +1,230 @@ +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestAndMethodSimpleConverter(AccTestCase): + @parameterized.expand( + [ + ("rand_2d_float_bool", torch.randn(3, 4), torch.randn(3, 4).to(torch.bool)), + ( + "rand_2d_int_bool", + torch.randn(3, 4).to(torch.int), + torch.randn(3, 4).to(torch.bool), + ), + ( + "rand_2d_bool_bool", + torch.randn(3, 4).to(torch.bool), + torch.randn(3, 4).to(torch.bool), + ), + ( + "rand_2d_float_int", + torch.randn(3, 4).to(torch.float), + torch.randn(3, 4).to(torch.int), + ), + ( + "rand_2d_float_single_bool", + torch.randn(3, 4), + torch.tensor(0).to(torch.bool), + ), + ( + "rand_2d_int_single_bool", + torch.randn(3, 4).to(torch.int), + torch.tensor(0).to(torch.bool), + ), + ( + "rand_2d_bool_single_bool", + torch.randn(3, 4).to(torch.bool), + torch.tensor(0).to(torch.bool), + ), + ] + ) + def test_and(self, _, input, other): + class And(torch.nn.Module): + def forward(self, x, y): + return x.logical_and(y) + + inputs = [ + input, + other, + ] + self.run_test( + And(), + inputs, + expected_ops={acc_ops.logical_and}, + test_implicit_batch_dim=False, + ) + + +class TestAndMethodSimpleConverterWithDynamicShape(AccTestCase): + def test_and(self): + class And(torch.nn.Module): + def forward(self, x, y): + return x.logical_and(y) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], + ), + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], + ), + ] + + self.run_test_with_dynamic_shape( + And(), input_specs, expected_ops={acc_ops.logical_and} + ) + + +class TestAndFunctionSimpleConverter(AccTestCase): + @parameterized.expand( + [ + ("rand_2d_float_bool", torch.randn(3, 4), torch.randn(3, 4).to(torch.bool)), + ( + "rand_2d_int_bool", + torch.randn(3, 4).to(torch.int), + torch.randn(3, 4).to(torch.bool), + ), + ( + "rand_2d_bool_bool", + torch.randn(3, 4).to(torch.bool), + torch.randn(3, 4).to(torch.bool), + ), + ( + "rand_2d_float_int", + torch.randn(3, 4).to(torch.float), + torch.randn(3, 4).to(torch.int), + ), + ( + "rand_2d_float_single_bool", + torch.randn(3, 4), + torch.tensor(0).to(torch.bool), + ), + ( + "rand_2d_int_single_bool", + torch.randn(3, 4).to(torch.int), + torch.tensor(0).to(torch.bool), + ), + ( + "rand_2d_bool_single_bool", + torch.randn(3, 4).to(torch.bool), + torch.tensor(0).to(torch.bool), + ), + ] + ) + def test_and(self, _, input, other): + class And(torch.nn.Module): + def forward(self, x, y): + return torch.logical_and(x, y) + + inputs = [ + input, + other, + ] + self.run_test( + And(), + inputs, + expected_ops={acc_ops.logical_and}, + test_implicit_batch_dim=False, + ) + + +class TestAndOperatorSimpleConverter(AccTestCase): + @parameterized.expand( + [ + ( + "rand_2d_bool_bool", + torch.randn(3, 4).to(torch.bool), + torch.randn(3, 4).to(torch.bool), + ), + ( + "rand_2d_bool_single_bool", + 
torch.randn(3, 4).to(torch.bool), + torch.tensor(0).to(torch.bool), + ), + ] + ) + def test_and(self, _, input, other): + class And(torch.nn.Module): + def forward(self, x, y): + return x & y + + inputs = [ + input, + other, + ] + self.run_test( + And(), + inputs, + expected_ops={acc_ops.bitwise_and}, + test_implicit_batch_dim=False, + ) + + +class TestAndOperatorConstantConverter(AccTestCase): + @parameterized.expand( + [ + ( + "rand_2d_bool_bool", + torch.randn(3, 4).to(torch.bool), + torch.randn(3, 4).to(torch.bool), + ), + ( + "rand_2d_bool_single_bool", + torch.randn(3, 4).to(torch.bool), + torch.tensor(0).to(torch.bool), + ), + ] + ) + def test_and(self, _, input, other): + class And(torch.nn.Module): + def __init__(self): + super().__init__() + self.other = other + + def forward(self, x): + return x & self.other + + inputs = [ + input, + ] + self.run_test( + And(), + inputs, + expected_ops={acc_ops.bitwise_and}, + test_implicit_batch_dim=False, + ) + + +class TestAndFunctionSimpleConverterWithDynamicShape(AccTestCase): + def test_and(self): + class And(torch.nn.Module): + def forward(self, x, y): + return torch.logical_and(x, y) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.bool, + shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))], + ), + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.bool, + shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + And(), input_specs, expected_ops={acc_ops.logical_and} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_logical_or.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_logical_or.py new file mode 100644 index 0000000000..265f5735eb --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_logical_or.py @@ -0,0 +1,201 @@ +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestLogicalOrMethodSimpleConverter(AccTestCase): + @parameterized.expand( + [ + ("rand_2d_bool_bool", torch.randn(3, 4) > 0, torch.randn(3, 4) > 0), + ("rand_3d_bool_bool", torch.randn(3, 4, 5) > 0, torch.randn(3, 4, 5) > 0), + ( + "rand_4d_bool_bool", + torch.randn(3, 4, 5, 6) > 0, + torch.randn(3, 4, 5, 6) > 0, + ), + ("rand_2d_bool_single_bool", torch.randn(3, 4) > 0, torch.tensor(0) > 0), + ( + "rand_2d_int_bool", + torch.randn(3, 4).to(torch.int), + torch.randn(3, 4) > 0, + ), + ( + "rand_2d_int_single_bool", + torch.randn(3, 4).to(torch.int), + torch.tensor(0) > 0, + ), + ] + ) + def test_logical_or(self, _, input, other): + class LogicalOr(torch.nn.Module): + def forward(self, x, y): + return x.logical_or(y) + + inputs = [ + input, + other, + ] + self.run_test( + LogicalOr(), + inputs, + expected_ops={acc_ops.logical_or}, + test_implicit_batch_dim=False, + ) + + +class TestLogicalOrMethodSimpleConverterWithDynamicShape(AccTestCase): + def test_logical_or(self): + class LogicalOr(torch.nn.Module): + def forward(self, x, y): + return x.logical_or(y) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], + ), + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], + ), + ] + + 
self.run_test_with_dynamic_shape( + LogicalOr(), input_specs, expected_ops={acc_ops.logical_or} + ) + + +class TestLogicalOrFunctionSimpleConverter(AccTestCase): + @parameterized.expand( + [ + ("rand_2d_bool_bool", torch.randn(3, 4) > 0, torch.randn(3, 4) > 0), + ("rand_3d_bool_bool", torch.randn(3, 4, 5) > 0, torch.randn(3, 4, 5) > 0), + ( + "rand_4d_bool_bool", + torch.randn(3, 4, 5, 6) > 0, + torch.randn(3, 4, 5, 6) > 0, + ), + ("rand_2d_bool_single_bool", torch.randn(3, 4) > 0, torch.tensor(0) > 0), + ( + "rand_2d_int_bool", + torch.randn(3, 4).to(torch.int), + torch.randn(3, 4) > 0, + ), + ( + "rand_2d_int_single_bool", + torch.randn(3, 4).to(torch.int), + torch.tensor(0) > 0, + ), + ] + ) + def test_logical_or(self, _, input, other): + class LogicalOr(torch.nn.Module): + def forward(self, x, y): + return torch.logical_or(x, y) + + inputs = [ + input, + other, + ] + self.run_test( + LogicalOr(), + inputs, + expected_ops={acc_ops.logical_or}, + test_implicit_batch_dim=False, + ) + + +class TestLogicalOrFunctionSimpleConverterWithDynamicShape(AccTestCase): + def test_logical_or(self): + class LogicalOr(torch.nn.Module): + def forward(self, x, y): + return torch.logical_or(x, y) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], + ), + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], + ), + ] + + self.run_test_with_dynamic_shape( + LogicalOr(), input_specs, expected_ops={acc_ops.logical_or} + ) + + +class TestLogicalOrOperatorSimpleConverter(AccTestCase): + @parameterized.expand( + [ + ("rand_2d_bool_bool", torch.randn(3, 4) > 0, torch.randn(3, 4) > 0), + ("rand_3d_bool_bool", torch.randn(3, 4, 5) > 0, torch.randn(3, 4, 5) > 0), + ( + "rand_4d_bool_bool", + torch.randn(3, 4, 5, 6) > 0, + torch.randn(3, 4, 5, 6) > 0, + ), + ("rand_2d_bool_single_bool", torch.randn(3, 4) > 0, torch.tensor(0) > 0), + ( + "rand_2d_int_bool", + torch.randn(3, 4).to(torch.int), + torch.randn(3, 4) > 0, + ), + ( + "rand_2d_int_single_bool", + torch.randn(3, 4).to(torch.int), + torch.tensor(0) > 0, + ), + ] + ) + def test_logical_or(self, _, input, other): + class LogicalOr(torch.nn.Module): + def forward(self, x, y): + return x | y + + inputs = [ + input, + other, + ] + self.run_test( + LogicalOr(), + inputs, + expected_ops={acc_ops.logical_or}, + test_implicit_batch_dim=False, + ) + + +class TestLogicalOrOperatorSimpleConverterWithDynamicShape(AccTestCase): + def test_logical_or(self): + class LogicalOr(torch.nn.Module): + def forward(self, x, y): + return x | y + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.bool, + shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))], + ), + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.bool, + shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + LogicalOr(), input_specs, expected_ops={acc_ops.logical_or} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_logical_xor.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_logical_xor.py new file mode 100644 index 0000000000..0cd6174950 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_logical_xor.py @@ -0,0 +1,201 @@ +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import parameterized +from 
torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestLogicalXorMethodSimpleConverter(AccTestCase): + @parameterized.expand( + [ + ("rand_2d_bool_bool", torch.randn(3, 4) > 0, torch.randn(3, 4) > 0), + ("rand_3d_bool_bool", torch.randn(3, 4, 5) > 0, torch.randn(3, 4, 5) > 0), + ( + "rand_4d_bool_bool", + torch.randn(3, 4, 5, 6) > 0, + torch.randn(3, 4, 5, 6) > 0, + ), + ("rand_2d_bool_single_bool", torch.randn(3, 4) > 0, torch.tensor(0) > 0), + ( + "rand_2d_int_bool", + torch.randn(3, 4).to(torch.int), + torch.randn(3, 4) > 0, + ), + ( + "rand_2d_int_single_bool", + torch.randn(3, 4).to(torch.int), + torch.tensor(0) > 0, + ), + ] + ) + def test_logical_xor(self, _, input, other): + class LogicalXor(torch.nn.Module): + def forward(self, x, y): + return x.logical_xor(y) + + inputs = [ + input, + other, + ] + self.run_test( + LogicalXor(), + inputs, + expected_ops={acc_ops.logical_xor}, + test_implicit_batch_dim=False, + ) + + +class TestLogicalXorMethodSimpleConverterWithDynamicShape(AccTestCase): + def test_logical_xor(self): + class LogicalXor(torch.nn.Module): + def forward(self, x, y): + return x.logical_xor(y) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], + ), + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], + ), + ] + + self.run_test_with_dynamic_shape( + LogicalXor(), input_specs, expected_ops={acc_ops.logical_xor} + ) + + +class TestLogicalXorFunctionSimpleConverter(AccTestCase): + @parameterized.expand( + [ + ("rand_2d_bool_bool", torch.randn(3, 4) > 0, torch.randn(3, 4) > 0), + ("rand_3d_bool_bool", torch.randn(3, 4, 5) > 0, torch.randn(3, 4, 5) > 0), + ( + "rand_4d_bool_bool", + torch.randn(3, 4, 5, 6) > 0, + torch.randn(3, 4, 5, 6) > 0, + ), + ("rand_2d_bool_single_bool", torch.randn(3, 4) > 0, torch.tensor(0) > 0), + ( + "rand_2d_int_bool", + torch.randn(3, 4).to(torch.int), + torch.randn(3, 4) > 0, + ), + ( + "rand_2d_int_single_bool", + torch.randn(3, 4).to(torch.int), + torch.tensor(0) > 0, + ), + ] + ) + def test_logical_xor(self, _, input, other): + class LogicalXor(torch.nn.Module): + def forward(self, x, y): + return torch.logical_xor(x, y) + + inputs = [ + input, + other, + ] + self.run_test( + LogicalXor(), + inputs, + expected_ops={acc_ops.logical_xor}, + test_implicit_batch_dim=False, + ) + + +class TestLogicalXorFunctionSimpleConverterWithDynamicShape(AccTestCase): + def test_logical_xor(self): + class LogicalXor(torch.nn.Module): + def forward(self, x, y): + return torch.logical_xor(x, y) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], + ), + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], + ), + ] + + self.run_test_with_dynamic_shape( + LogicalXor(), input_specs, expected_ops={acc_ops.logical_xor} + ) + + +class TestLogicalXorOperatorSimpleConverter(AccTestCase): + @parameterized.expand( + [ + ("rand_2d_bool_bool", torch.randn(3, 4) > 0, torch.randn(3, 4) > 0), + ("rand_3d_bool_bool", torch.randn(3, 4, 5) > 0, torch.randn(3, 4, 5) > 0), + ( + "rand_4d_bool_bool", + torch.randn(3, 4, 5, 6) > 0, + torch.randn(3, 4, 5, 6) > 0, + ), + ("rand_2d_bool_single_bool", torch.randn(3, 4) > 0, torch.tensor(0) > 0), 
+ ( + "rand_2d_int_bool", + torch.randn(3, 4).to(torch.int), + torch.randn(3, 4) > 0, + ), + ( + "rand_2d_int_single_bool", + torch.randn(3, 4).to(torch.int), + torch.tensor(0) > 0, + ), + ] + ) + def test_logical_xor(self, _, input, other): + class LogicalXor(torch.nn.Module): + def forward(self, x, y): + return x ^ y + + inputs = [ + input, + other, + ] + self.run_test( + LogicalXor(), + inputs, + expected_ops={acc_ops.logical_xor}, + test_implicit_batch_dim=False, + ) + + +class TestLogicalXorOperatorSimpleConverterWithDynamicShape(AccTestCase): + def test_logical_xor(self): + class LogicalXor(torch.nn.Module): + def forward(self, x, y): + return x ^ y + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.bool, + shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))], + ), + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.bool, + shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + LogicalXor(), input_specs, expected_ops={acc_ops.logical_xor} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_lt.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_lt.py new file mode 100644 index 0000000000..df51e6bf58 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_lt.py @@ -0,0 +1,274 @@ +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestLtConverter(AccTestCase): + @parameterized.expand( + [ + ("rand_2d", torch.randn(3, 4), torch.randn(3, 4)), + ("rand_3d", torch.randn(3, 4, 5), torch.randn(3, 4, 5)), + ("rand_4d", torch.randn(3, 4, 5, 6), torch.randn(3, 4, 5, 6)), + ( + "rand_2d_int_bool", + (torch.randn(3, 4)).to(torch.int), + torch.zeros(3, 4).to(torch.bool), + ), + ( + "rand_2d_float_bool", + torch.randn(3, 4).to(torch.float), + torch.zeros(3, 4).to(torch.bool), + ), + ( + "rand_2d_float_int", + torch.randn(3, 4).to(torch.float), + torch.zeros(3, 4).to(torch.int), + ), + ] + ) + def test_lt(self, _, input, other): + class Lt(torch.nn.Module): + def forward(self, x, y): + mask = torch.lt(x, y) + return x.masked_fill(mask, 5) + + inputs = [ + input, + other, + ] + self.run_test( + Lt(), inputs, expected_ops={acc_ops.lt}, test_implicit_batch_dim=False + ) + + +class TestLtMethodConverter(AccTestCase): + @parameterized.expand( + [ + ("rand_2d", torch.randn(3, 4), torch.randn(3, 4)), + ("rand_3d", torch.randn(3, 4, 5), torch.randn(3, 4, 5)), + ("rand_4d", torch.randn(3, 4, 5, 6), torch.randn(3, 4, 5, 6)), + ( + "rand_2d_int_bool", + torch.randn(3, 4).to(torch.int), + torch.zeros(3, 4).to(torch.bool), + ), + ( + "rand_2d_float_bool", + torch.randn(3, 4).to(torch.float), + torch.zeros(3, 4).to(torch.bool), + ), + ( + "rand_2d_float_int", + torch.randn(3, 4).to(torch.float), + torch.zeros(3, 4).to(torch.int), + ), + ] + ) + def test_lt(self, _, input, other): + class Lt(torch.nn.Module): + def forward(self, x, y): + mask = x.lt(y) + return x.masked_fill(mask, 5) + + inputs = [ + input, + other, + ] + self.run_test( + Lt(), inputs, expected_ops={acc_ops.lt}, test_implicit_batch_dim=False + ) + + +class TestLtOperatorConverter(AccTestCase): + @parameterized.expand( + [ + ("rand_2d", torch.randn(3, 4), torch.randn(3, 4)), + ("rand_3d", torch.randn(3, 4, 5), torch.randn(3, 4, 5)), + ("rand_4d", 
torch.randn(3, 4, 5, 6), torch.randn(3, 4, 5, 6)), + ( + "rand_2d_int_bool", + torch.randn(3, 4).to(torch.int), + torch.zeros(3, 4).to(torch.bool), + ), + ( + "rand_2d_float_bool", + torch.randn(3, 4).to(torch.float), + torch.zeros(3, 4).to(torch.bool), + ), + ( + "rand_2d_float_int", + torch.randn(3, 4).to(torch.float), + torch.zeros(3, 4).to(torch.int), + ), + ] + ) + def test_lt(self, _, input, other): + class Lt(torch.nn.Module): + def forward(self, x, y): + mask = x < y + return x.masked_fill(mask, 5) + + inputs = [ + input, + other, + ] + self.run_test( + Lt(), inputs, expected_ops={acc_ops.lt}, test_implicit_batch_dim=False + ) + + +class TestEqOperatorSimpleConverter(AccTestCase): + @parameterized.expand( + [ + ("rand_2d_float_bool", torch.randn(3, 4), torch.randn(3, 4).to(torch.bool)), + ( + "rand_2d_int_bool", + torch.randn(3, 4).to(torch.int), + torch.randn(3, 4).to(torch.bool), + ), + ( + "rand_2d_bool_bool", + torch.randn(3, 4).to(torch.bool), + torch.randn(3, 4).to(torch.bool), + ), + ( + "rand_2d_float_int", + torch.randn(3, 4).to(torch.float), + torch.randn(3, 4).to(torch.int), + ), + ( + "rand_2d_float_single_bool", + torch.randn(3, 4), + torch.tensor(0).to(torch.bool), + ), + ( + "rand_2d_int_single_bool", + torch.randn(3, 4).to(torch.int), + torch.tensor(0).to(torch.bool), + ), + ( + "rand_2d_bool_single_bool", + torch.randn(3, 4).to(torch.bool), + torch.tensor(0).to(torch.bool), + ), + ] + ) + def test_eq(self, _, input, other): + class Eq(torch.nn.Module): + def forward(self, x, y): + return x < y + + inputs = [ + input, + other, + ] + self.run_test( + Eq(), inputs, expected_ops={acc_ops.lt}, test_implicit_batch_dim=False + ) + + +class TestEqOperatorSimpleConverterWithDynamicShape(AccTestCase): + def test_eq(self): + class Eq(torch.nn.Module): + def forward(self, x, y): + return x < y + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], + ), + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], + ), + ] + + self.run_test_with_dynamic_shape(Eq(), input_specs, expected_ops={acc_ops.lt}) + + +class TestEqOperatorConstantConverter(AccTestCase): + @parameterized.expand( + [ + ("rand_2d_float_bool", torch.randn(3, 4), torch.randn(3, 4).to(torch.bool)), + ( + "rand_2d_int_bool", + torch.randn(3, 4).to(torch.int), + torch.randn(3, 4).to(torch.bool), + ), + ( + "rand_2d_bool_bool", + torch.randn(3, 4).to(torch.bool), + torch.randn(3, 4).to(torch.bool), + ), + ( + "rand_2d_float_int", + torch.randn(3, 4).to(torch.float), + torch.randn(3, 4).to(torch.int), + ), + ("rand_2d_float_single_bool", torch.randn(3, 4), False), + ("rand_2d_int_single_bool", torch.randn(3, 4).to(torch.int), False), + ("rand_2d_bool_single_bool", torch.randn(3, 4).to(torch.bool), False), + ] + ) + def test_eq(self, _, input, other): + class Eq(torch.nn.Module): + def __init__(self): + super().__init__() + self.other = other + + def forward(self, x): + return x < self.other + + inputs = [ + input, + ] + self.run_test( + Eq(), inputs, expected_ops={acc_ops.lt}, test_implicit_batch_dim=False + ) + + +class TestConstInputConverter(AccTestCase): + def test_lt(self): + class Lt(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x.shape[0] < 4 + + input = torch.randn(3, 4) + inputs = [ + input, + ] + self.run_test( + Lt(), inputs, expected_ops={acc_ops.lt}, test_implicit_batch_dim=False + ) + + 
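+# Note: x.shape[0] < 4 in the test above compares a traced shape value, and the
+# expected_ops assertion checks that it still lowers to acc_ops.lt; the class
+# below exercises the same pattern with every dimension left dynamic.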
+
+class TestConstInputConverterWithDynamicShape(AccTestCase):
+    def test_lt(self):
+        class Lt(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return x.shape[0] < 4
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))],
+            )
+        ]
+
+        self.run_test_with_dynamic_shape(Lt(), input_specs, expected_ops={acc_ops.lt})
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_masked_fill.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_masked_fill.py
new file mode 100644
index 0000000000..9e3ca83015
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_masked_fill.py
@@ -0,0 +1,72 @@
+import torch
+import torch.nn as nn
+import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase
+
+
+class TestMaskedFill(AccTestCase):
+    @parameterized.expand(
+        [
+            ("same_dims", (2, 3), 5),
+            ("same_dims_tensor", (2, 3), torch.tensor(5)),
+            ("not_same_dims", (2, 1), 5),
+            ("not_same_dims_tensor", (2, 1), torch.tensor(5)),
+        ]
+    )
+    def test_masked_fill(self, _, input_shape, value):
+        class MaskedFill(nn.Module):
+            def __init__(self, input_shape):
+                super().__init__()
+                self.mask = torch.zeros(input_shape)
+                self.mask[0, 0] = 1
+                self.mask = self.mask.to(torch.bool)
+                self.value = value
+
+            def forward(self, x):
+                return x.masked_fill(self.mask, self.value)
+
+        inputs = [torch.ones(*input_shape)]
+        self.run_test(
+            MaskedFill(input_shape),
+            inputs,
+            expected_ops={acc_ops.masked_fill},
+            test_implicit_batch_dim=False,
+        )
+
+    # Testing with (-1, -1, -1, -1) results in the following error:
+    # RuntimeError: Trying to create tensor with negative dimension -1: [-1, -1, -1, -1]
+
+    @parameterized.expand(
+        [
+            ("same_dims", (2, 3), (2, 3), 5),
+            ("expand_first_dims", (2, 3), (1, 3), 5),
+            ("expand_second_dims", (2, 3), (2, 1), 5),
+            ("expand_third_dims", (2, 3, 4), (2, 3, 1), 5),
+        ]
+    )
+    def test_masked_fill_expand(self, _, input_shape, mask_shape, value):
+        class MaskedFill(nn.Module):
+            def __init__(self, input_shape):
+                super().__init__()
+                self.value = value
+
+            def forward(self, x, mask_input):
+                return x.masked_fill(mask_input, self.value)
+
+        mask_input = torch.zeros(*mask_shape)
+        # Index with a tuple of zeros so a single element, not a whole slice,
+        # is set to True
+        index = (0,) * len(mask_shape)
+        mask_input[index] = 1
+        mask_input = mask_input.to(torch.bool)
+        inputs = [torch.ones(*input_shape), mask_input]
+        self.run_test(
+            MaskedFill(input_shape),
+            inputs,
+            expected_ops={acc_ops.masked_fill},
+            test_implicit_batch_dim=False,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_matmul.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_matmul.py
new file mode 100644
index 0000000000..7e7456c437
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_matmul.py
@@ -0,0 +1,117 @@
+import torch
+import torch.nn as nn
+import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+
+
+class TestMatMulConverter(AccTestCase):
+    @parameterized.expand(
+        [
+            ("2_2", (2, 3), (3, 2)),
+            ("2_1", (2, 3), (3,)),
+            ("4_2", (1, 2, 2, 3), (3, 2)),
+            ("1_2", (3,), 
(3, 2)),
+        ]
+    )
+    def test_matmul_other_constant(self, _, input_shape, other_shape):
+        class MatMul(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.other = nn.Parameter(torch.randn(*other_shape))
+
+            def forward(self, input):
+                return torch.matmul(input, self.other)
+
+        inputs = [torch.randn(*input_shape)]
+        self.run_test(
+            MatMul(),
+            inputs,
+            expected_ops={acc_ops.matmul},
+            test_implicit_batch_dim=(len(input_shape) > 1),
+        )
+
+    @parameterized.expand(
+        [
+            ("2_2", (2, 3), (3, 2)),
+            ("1_2", (3,), (3, 2)),
+            ("3_4", (2, 2, 3), (3, 1, 3, 3)),
+        ]
+    )
+    def test_matmul_input_constant(self, _, input_shape, other_shape):
+        class MatMul(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.input = nn.Parameter(torch.randn(*input_shape))
+
+            def forward(self, other):
+                return torch.matmul(self.input, other)
+
+        inputs = [torch.randn(*other_shape)]
+        self.run_test(
+            MatMul(),
+            inputs,
+            expected_ops={acc_ops.matmul},
+            test_implicit_batch_dim=(len(other_shape) > 2),
+        )
+
+    @parameterized.expand(
+        [
+            ("4_4", (2, 2, 2, 3), (2, 1, 3, 2)),
+            ("4_2", (2, 1, 2, 3), (3, 2)),
+            ("2_3", (2, 3), (2, 3, 4)),
+            ("2_2", (2, 3), (3, 2)),
+            ("2_1", (2, 3), (3,)),
+            ("1_2", (3,), (3, 2)),
+            ("1_1", (3,), (3,)),
+        ]
+    )
+    def test_matmul(self, _, input_shape, other_shape):
+        class MatMul(nn.Module):
+            def forward(self, input, other):
+                return torch.matmul(input, other)
+
+        inputs = [torch.randn(*input_shape), torch.randn(*other_shape)]
+        test_implicit_batch_dim = (
+            input_shape[0] == other_shape[0]
+            and len(input_shape) > 2
+            and len(other_shape) > 2
+        )
+        self.run_test(
+            MatMul(),
+            inputs,
+            expected_ops={acc_ops.matmul},
+            test_implicit_batch_dim=test_implicit_batch_dim,
+        )
+
+    def test_matmul_dynamic_shape(
+        self,
+    ):
+        class Matmul(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input, other):
+                return torch.matmul(input, other)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, 1, 2, 3),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 2, 3), (9, 1, 2, 3), (9, 1, 2, 3))],
+            ),
+            InputTensorSpec(
+                shape=(-1, -1, 3, 3),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 3, 3), (9, 4, 3, 3), (9, 4, 3, 3))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            Matmul(), input_specs, expected_ops={acc_ops.matmul}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_max.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_max.py
new file mode 100644
index 0000000000..c2cf7d252d
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_max.py
@@ -0,0 +1,160 @@
+import torch
+import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+
+
+class TestMaxConverter(AccTestCase):
+    @parameterized.expand(
+        [
+            ("dim0_keepdim", 0, True, torch.randn(2, 2, 3)),
+            ("dim1_keepdim", 1, True, torch.randn(2, 2, 3)),
+            ("dim2_keepdim", 2, True, torch.randn(2, 2, 3)),
+            ("dim3_keepdim", 3, True, torch.randn(2, 2, 3, 3)),
+            ("dim2_no_keepdim", 2, False, torch.randn(2, 2, 3)),
+            ("dim1_no_keepdim", 1, False, torch.randn(2, 2, 3)),
+            ("dim0_no_keepdim", 0, False, torch.randn(2, 2, 3)),
+        ]
+    )
+    def test_max_dim_reduce(self, test_name, dim, keepdim, input):
+        class MaxDimReduce(torch.nn.Module):
+            def __init__(self, dim, keepdim):
+                super().__init__()
+                self.dim = dim
+                self.keepdim = keepdim
+
+            def forward(self, x):
+                return torch.max(x, self.dim, self.keepdim)
+
+        inputs = [input]
+        self.run_test(
+            MaxDimReduce(dim, keepdim),
+            inputs,
+            expected_ops={acc_ops.max_dim_reduce},
+            test_implicit_batch_dim=(dim != 0),
+        )
+
+    @parameterized.expand(
+        [
+            ("no_dim_no_keepdim"),
+        ]
+    )
+    def test_max_full_reduce(
+        self,
+        test_name,
+    ):
+        class MaxFullReduce(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.max(x)
+
+        inputs = [torch.randn(3, 2, 3, 3)]
+        self.run_test(
+            MaxFullReduce(),
+            inputs,
+            expected_ops={acc_ops.max_full_reduce},
+            # We can't do a full reduce over the batch dimension
+            test_implicit_batch_dim=False,
+        )
+
+    @parameterized.expand(
+        [
+            ("max_method_no_dim_no_keepdim"),
+        ]
+    )
+    def test_max_method(self, test_name):
+        class MaxMethod(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input, other):
+                return input.max(other)
+
+        inputs = [torch.randn(3, 4), torch.randn(3, 4)]
+        self.run_test(MaxMethod(), inputs, expected_ops={acc_ops.maximum})
+
+
+class TestMaxConverterWithDynamicShape(AccTestCase):
+    @parameterized.expand(
+        [
+            # keepdim can not be False for dynamic shape
+            ("dim0_keepdim", 0, True),
+            ("dim1_keepdim", 1, True),
+            ("dim2_keepdim", 2, True),
+            ("dim3_keepdim", 3, True),
+        ]
+    )
+    def test_max_dim_reduce(self, _, dim, keepdim):
+        class MaxDimReduce(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.max(x, dim, keepdim=keepdim)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            MaxDimReduce(), input_specs, expected_ops={acc_ops.max_dim_reduce}
+        )
+
+    def test_max_full_reduce(
+        self,
+    ):
+        class MaxFullReduce(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.max(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            MaxFullReduce(), input_specs, expected_ops={acc_ops.max_full_reduce}
+        )
+
+    def test_max_method(self):
+        class MaxMethod(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input, other):
+                return input.max(other)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))],
+            ),
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            MaxMethod(), input_specs, expected_ops={acc_ops.maximum}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_maximum.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_maximum.py
new file mode 100644
index 0000000000..e0bec6f15d
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_maximum.py
@@ -0,0 +1,82 @@
+import torch
+import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+
+
+class TestMaximumConverter(AccTestCase):
+    def test_maximum(self):
+        class Maximum(torch.nn.Module):
+            def forward(self, x, y):
+                return 
torch.maximum(x, y) + + inputs = [ + torch.randn(3, 4), + torch.randn(3, 4), + ] + self.run_test(Maximum(), inputs, expected_ops={acc_ops.maximum}) + + +class TestMaximumConverterWithDynamicShape(AccTestCase): + def test_maximum(self): + class Maximum(torch.nn.Module): + def forward(self, x, y): + return torch.maximum(x, y) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], + ), + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], + ), + ] + + self.run_test_with_dynamic_shape( + Maximum(), input_specs, expected_ops={acc_ops.maximum} + ) + + +class TestMaximumMethodConverter(AccTestCase): + def test_maximum(self): + class Maximum(torch.nn.Module): + def forward(self, x, y): + return x.maximum(y) + + inputs = [ + torch.randn(3, 4), + torch.randn(3, 4), + ] + self.run_test(Maximum(), inputs, expected_ops={acc_ops.maximum}) + + +class TestMaximumMethodConverterWithDynamicShape(AccTestCase): + def test_maximum(self): + class Maximum(torch.nn.Module): + def forward(self, x, y): + return x.maximum(y) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], + ), + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], + ), + ] + + self.run_test_with_dynamic_shape( + Maximum(), input_specs, expected_ops={acc_ops.maximum} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_maxpool.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_maxpool.py new file mode 100644 index 0000000000..ddb48b4b69 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_maxpool.py @@ -0,0 +1,379 @@ +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import param, parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestMaxPoolConverter(AccTestCase): + @parameterized.expand( + [ + ("default", 1), + ("kernel_3", 3), + ("stride", 1, 2), + param("padding", 2, padding=1), + param("padding_even", 5, padding=2), + param("ceil_mode", 1, ceil_mode=True), + ] + ) + def test_max_pool1d( + self, + test_name, + kernel_size, + stride=1, + padding=0, + dilation=1, + ceil_mode=False, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.max_pool = torch.nn.MaxPool1d( + kernel_size, stride, padding, ceil_mode=ceil_mode, dilation=dilation + ) + + def forward(self, x): + return self.max_pool(x) + + inputs = [torch.randn(1, 3, 224)] + self.run_test( + TestModule(), + inputs, + expected_ops={acc_ops.max_pool1d}, + ) + + def test_max_pool1d_with_dynamic_shape( + self, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.max_pool = torch.nn.MaxPool1d(1) + + def forward(self, x): + return self.max_pool(x) + + # shape is not set to (-1, -1, -1) as reshape dimension with + # more than one -1 wildcard is not allowed while adding unsqueeze layer + input_specs = [ + InputTensorSpec( + shape=(1, 1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 1, 4), (1, 1, 4))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), + input_specs, + expected_ops={acc_ops.max_pool1d}, + ) + 
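+    # NOTE (illustrative sketch, not part of the original test suite): each
+    # entry in shape_ranges above is a (min, opt, max) triple of concrete
+    # shapes used to build the TensorRT optimization profile; only dimensions
+    # marked -1 in `shape` may vary across the triple. A hypothetical spec
+    # giving this 1d pooling module a dynamic batch dimension would look like:
+    #
+    #     InputTensorSpec(
+    #         shape=(-1, 3, 224),
+    #         dtype=torch.float32,
+    #         shape_ranges=[((1, 3, 224), (8, 3, 224), (32, 3, 224))],
+    #     )
+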
+ @parameterized.expand( + [ + ("default", 1), + ("stride", 1, 2), + ("tuple_parameters", 2, (1, 1), (1, 1)), + param("padding", 2, padding=1), + param("ceil_mode", 1, ceil_mode=True), + ] + ) + def test_max_pool2d( + self, + test_name, + kernel_size, + stride=1, + padding=0, + ceil_mode=False, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.max_pool = torch.nn.MaxPool2d( + kernel_size, stride, padding, ceil_mode=ceil_mode + ) + + def forward(self, x): + return self.max_pool(x) + + inputs = [torch.randn(1, 3, 224, 224)] + self.run_test(TestModule(), inputs, expected_ops={acc_ops.max_pool2d}) + + def test_max_pool2d_with_dynamic_shape( + self, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.max_pool = torch.nn.MaxPool2d(1, 1) + + def forward(self, x): + return self.max_pool(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (1, 2, 4, 4), (2, 4, 4, 4))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.max_pool2d} + ) + + @parameterized.expand( + [ + ("default", 1), + ("stride", 1, 2), + ("tuple_parameters", 2, (1, 1, 1), (1, 1, 1)), + param("padding", 2, padding=1), + param("ceil_mode", 1, ceil_mode=True), + ] + ) + def test_max_pool3d( + self, + test_name, + kernel_size, + stride=1, + padding=0, + ceil_mode=False, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.max_pool = torch.nn.MaxPool3d( + kernel_size, stride, padding, ceil_mode=ceil_mode + ) + + def forward(self, x): + return self.max_pool(x) + + inputs = [torch.randn(1, 3, 32, 32, 32)] + self.run_test(TestModule(), inputs, expected_ops={acc_ops.max_pool3d}) + + def test_max_pool3d_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.max_pool = torch.nn.MaxPool3d(1, 1) + + def forward(self, x): + return self.max_pool(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1, 1), (1, 2, 4, 4, 4), (2, 4, 4, 4, 4))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.max_pool3d} + ) + + @parameterized.expand( + [ + ("default", 1), + param("stride", 2, stride=()), + ] + ) + def test_stride_none_max_pool1d( + self, + test_name, + kernel_size, + stride=None, + padding=0, + dilation=1, + ceil_mode=False, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.nn.functional.max_pool1d( + x, + kernel_size, + stride=stride, + padding=padding, + ceil_mode=ceil_mode, + dilation=dilation, + ) + + inputs = [torch.randn(1, 3, 224)] + self.run_test( + TestModule(), + inputs, + expected_ops={acc_ops.max_pool1d}, + test_explicit_batch_dim=False, + ) + + @parameterized.expand( + [ + ("default", 1), + param("stride", 2, stride=()), + ] + ) + def test_stride_none_max_pool2d( + self, + test_name, + kernel_size, + stride=None, + padding=0, + ceil_mode=False, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.nn.functional.max_pool2d( + x, kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode + ) + + inputs = [torch.randn(1, 3, 224, 224)] + self.run_test(TestModule(), inputs, expected_ops={acc_ops.max_pool2d}) + + @parameterized.expand( + [ + ("default", 1), + param("stride", 2, stride=()), + 
] + ) + def test_stride_none_max_pool3d( + self, + test_name, + kernel_size, + stride=None, + padding=0, + ceil_mode=False, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.nn.functional.max_pool3d( + x, kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode + ) + + inputs = [torch.randn(1, 3, 32, 32, 32)] + self.run_test(TestModule(), inputs, expected_ops={acc_ops.max_pool3d}) + + @parameterized.expand( + [ + ("default", 1), + param("stride", 2, stride=()), + ] + ) + def test_stride_none_max_pool1d_with_dynamic_shape( + self, + test_name, + kernel_size, + stride=None, + padding=0, + dilation=1, + ceil_mode=False, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.nn.functional.max_pool1d( + x, + kernel_size, + stride=stride, + padding=padding, + ceil_mode=ceil_mode, + dilation=dilation, + ) + + # shape is not set to (-1, -1, -1) as reshape dimension with + # more than one -1 wildcard is not allowed while adding unsqueeze layer + input_specs = [ + InputTensorSpec( + shape=(1, 1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 1, 4), (1, 1, 4))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), + input_specs, + expected_ops={acc_ops.max_pool1d}, + ) + + @parameterized.expand( + [ + ("default", 1), + param("stride", 2, stride=()), + ] + ) + def test_stride_none_max_pool2d_with_dynamic_shape( + self, + test_name, + kernel_size, + stride=None, + padding=0, + ceil_mode=False, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.nn.functional.max_pool2d( + x, kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode + ) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (1, 2, 4, 4), (2, 4, 4, 4))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.max_pool2d} + ) + + @parameterized.expand( + [ + ("default", 1), + param("stride", 2, stride=()), + ] + ) + def test_stride_none_max_pool3d_with_dynamic_shape( + self, + test_name, + kernel_size, + stride=None, + padding=0, + ceil_mode=False, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.nn.functional.max_pool3d( + x, kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode + ) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1, 1), (1, 2, 4, 4, 4), (2, 4, 4, 4, 4))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.max_pool3d} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_min.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_min.py new file mode 100644 index 0000000000..9f37238240 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_min.py @@ -0,0 +1,159 @@ +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestMinConverter(AccTestCase): + @parameterized.expand( + [ + ("dim0_keepdim", 0, True, torch.randn(2, 2, 3)), + ("dim1_keepdim", 1, True, 
torch.randn(2, 2, 3)), + ("dim2_keepdim", 2, True, torch.randn(2, 2, 3)), + ("dim3_keepdim", 3, True, torch.randn(2, 2, 3, 3)), + ("dim2_no_keepdim", 2, False, torch.randn(2, 2, 3)), + ("dim1_no_keepdim", 1, False, torch.randn(2, 2, 3)), + ("dim0_no_keepdim", 0, False, torch.randn(2, 2, 3)), + ] + ) + def test_min_dim_reduce(self, test_name, dim, keepdim, input): + class MinDimReduce(torch.nn.Module): + def __init__(self, dim, keepdim): + super().__init__() + self.dim = dim + self.keepdim = keepdim + + def forward(self, x): + return torch.min(x, self.dim, self.keepdim) + + inputs = [input] + self.run_test( + MinDimReduce(dim, keepdim), + inputs, + expected_ops={acc_ops.min_dim_reduce}, + test_implicit_batch_dim=(dim != 0), + ) + + @parameterized.expand( + [ + ("no_dim_no_keepdim"), + ] + ) + def test_min_full_reduce( + self, + test_name, + ): + class MinFullReduce(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.min(x) + + inputs = [torch.randn(3, 2, 3, 3)] + self.run_test( + MinFullReduce(), + inputs, + expected_ops={acc_ops.min_full_reduce}, + # We can't do a full reduce over the batch dimension + test_implicit_batch_dim=False, + ) + + @parameterized.expand( + [ + ("min_method_no_dim_no_keepdim"), + ("min_method_no_dim_no_keepdim"), + ] + ) + def test_min_method(self, test_name): + class MinMethod(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input, other): + return input.min(other) + + inputs = [torch.randn(3, 4), torch.randn(3, 4)] + self.run_test(MinMethod(), inputs, expected_ops={acc_ops.minimum}) + + +class TestMinConverterWithDynamicShape(AccTestCase): + @parameterized.expand( + [ + ("dim0_keepdim", 0, True), + ("dim1_keepdim", 1, True), + ("dim2_keepdim", 2, True), + ("dim3_keepdim", 3, True), + ] + ) + def test_min_dim_reduce(self, test_name, dim, keepdim): + class MinDimReduce(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.min(x, dim, keepdim) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + MinDimReduce(), input_specs, expected_ops={acc_ops.min_dim_reduce} + ) + + def test_min_full_reduce( + self, + ): + class MinFullReduce(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.min(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + MinFullReduce(), input_specs, expected_ops={acc_ops.min_full_reduce} + ) + + def test_min_method(self): + class MinMethod(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input, other): + return input.min(other) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))], + ), + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + MinMethod(), input_specs, expected_ops={acc_ops.minimum} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_minimum.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_minimum.py new file mode 100644 index 0000000000..a4b605cf66 --- 
/dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_minimum.py @@ -0,0 +1,82 @@ +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestMinimumConverter(AccTestCase): + def test_minimum(self): + class Minimum(torch.nn.Module): + def forward(self, x, y): + return torch.minimum(x, y) + + inputs = [ + torch.randn(3, 4), + torch.randn(3, 4), + ] + self.run_test(Minimum(), inputs, expected_ops={acc_ops.minimum}) + + +class TestMinimumMethodConverter(AccTestCase): + def test_minimum(self): + class Minimum(torch.nn.Module): + def forward(self, x, y): + return x.minimum(y) + + inputs = [ + torch.randn(3, 4), + torch.randn(3, 4), + ] + self.run_test(Minimum(), inputs, expected_ops={acc_ops.minimum}) + + +class TestMinimumConverterWithDynamicShape(AccTestCase): + def test_minimum(self): + class Minimum(torch.nn.Module): + def forward(self, x, y): + return torch.minimum(x, y) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], + ), + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], + ), + ] + + self.run_test_with_dynamic_shape( + Minimum(), input_specs, expected_ops={acc_ops.minimum} + ) + + +class TestMinimumMethodConverterWithDynamicShape(AccTestCase): + def test_minimum(self): + class Minimum(torch.nn.Module): + def forward(self, x, y): + return x.minimum(y) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], + ), + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], + ), + ] + + self.run_test_with_dynamic_shape( + Minimum(), input_specs, expected_ops={acc_ops.minimum} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_narrow.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_narrow.py new file mode 100644 index 0000000000..93cf4ea523 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_narrow.py @@ -0,0 +1,55 @@ +import torch +import torch.nn as nn +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestNarrowConverterWithDynamicShape(AccTestCase): + @parameterized.expand( + [ + ("positive_dim", 1, 0, 1), + ] + ) + def test_narrow(self, _, dim, start, length): + class Narrow(nn.Module): + def forward(self, x): + return x.narrow(dim, start, length) + + input_specs = [ + InputTensorSpec( + shape=(-1, 3, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 3, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + Narrow(), input_specs, expected_ops={acc_ops.slice_tensor} + ) + + +class TestNarrowConverter(AccTestCase): + @parameterized.expand( + [ + ("positive_dim", 1, 0, 1), + ("negative_dim", -1, 1, 2), + ] + ) + def test_narrow(self, _, dim, start, length): + class Narrow(nn.Module): + def forward(self, x): + return x.narrow(dim, start, length) + + inputs = [torch.randn(1, 2, 3, 4)] + self.run_test( + Narrow(), + inputs, + 
expected_ops={acc_ops.slice_tensor}, + test_explicit_batch_dim=False, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_ne.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_ne.py new file mode 100644 index 0000000000..affbc57aae --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_ne.py @@ -0,0 +1,304 @@ +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestNeFunctionConverter(AccTestCase): + @parameterized.expand( + [ + ("rand_2d_float_bool", torch.randn(3, 4), torch.randn(3, 4).to(torch.bool)), + ( + "rand_2d_int_bool", + torch.randn(3, 4).to(torch.int), + torch.randn(3, 4).to(torch.bool), + ), + ( + "rand_2d_bool_bool", + torch.randn(3, 4).to(torch.bool), + torch.randn(3, 4).to(torch.bool), + ), + ( + "rand_2d_float_int", + torch.randn(3, 4).to(torch.float), + torch.randn(3, 4).to(torch.int), + ), + ( + "rand_2d_float_single_bool", + torch.randn(3, 4), + torch.tensor(0).to(torch.bool), + ), + ( + "rand_2d_int_single_bool", + torch.randn(3, 4).to(torch.int), + torch.tensor(0).to(torch.bool), + ), + ( + "rand_2d_bool_single_bool", + torch.randn(3, 4).to(torch.bool), + torch.tensor(0).to(torch.bool), + ), + ] + ) + def test_ne(self, _, input, other): + class Ne(torch.nn.Module): + def forward(self, x, y): + return torch.ne(x, y) + + inputs = [ + input, + other, + ] + self.run_test( + Ne(), inputs, expected_ops={acc_ops.ne}, test_implicit_batch_dim=False + ) + + +class TestNeFunctionConverterWithDynamicShape(AccTestCase): + def test_ne(self): + class Ne(torch.nn.Module): + def forward(self, x, y): + return torch.ne(x, y) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], + ), + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], + ), + ] + + self.run_test_with_dynamic_shape(Ne(), input_specs, expected_ops={acc_ops.ne}) + + +class TestNeMethodConverter(AccTestCase): + @parameterized.expand( + [ + ("rand_2d_float_bool", torch.randn(3, 4), torch.randn(3, 4).to(torch.bool)), + ( + "rand_2d_int_bool", + torch.randn(3, 4).to(torch.int), + torch.randn(3, 4).to(torch.bool), + ), + ( + "rand_2d_bool_bool", + torch.randn(3, 4).to(torch.bool), + torch.randn(3, 4).to(torch.bool), + ), + ( + "rand_2d_float_int", + torch.randn(3, 4).to(torch.float), + torch.randn(3, 4).to(torch.int), + ), + ( + "rand_2d_float_single_bool", + torch.randn(3, 4), + torch.tensor(0).to(torch.bool), + ), + ( + "rand_2d_int_single_bool", + torch.randn(3, 4).to(torch.int), + torch.tensor(0).to(torch.bool), + ), + ( + "rand_2d_bool_single_bool", + torch.randn(3, 4).to(torch.bool), + torch.tensor(0).to(torch.bool), + ), + ] + ) + def test_ne(self, _, input, other): + class Ne(torch.nn.Module): + def forward(self, x, y): + return x.ne(y) + + inputs = [ + input, + other, + ] + self.run_test( + Ne(), inputs, expected_ops={acc_ops.ne}, test_implicit_batch_dim=False + ) + + +class TestNeMethodConverterWithDynamicShape(AccTestCase): + def test_ne(self): + class Ne(torch.nn.Module): + def forward(self, x, y): + return x.ne(y) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 
5), (2, 3, 10, 10))], + ), + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], + ), + ] + + self.run_test_with_dynamic_shape(Ne(), input_specs, expected_ops={acc_ops.ne}) + + +class TestNeOperatorConverter(AccTestCase): + @parameterized.expand( + [ + ("rand_2d_float_bool", torch.randn(3, 4), torch.randn(3, 4).to(torch.bool)), + ( + "rand_2d_int_bool", + torch.randn(3, 4).to(torch.int), + torch.randn(3, 4).to(torch.bool), + ), + ( + "rand_2d_bool_bool", + torch.randn(3, 4).to(torch.bool), + torch.randn(3, 4).to(torch.bool), + ), + ( + "rand_2d_float_int", + torch.randn(3, 4).to(torch.float), + torch.randn(3, 4).to(torch.int), + ), + ( + "rand_2d_float_single_bool", + torch.randn(3, 4), + torch.tensor(0).to(torch.bool), + ), + ( + "rand_2d_int_single_bool", + torch.randn(3, 4).to(torch.int), + torch.tensor(0).to(torch.bool), + ), + ( + "rand_2d_bool_single_bool", + torch.randn(3, 4).to(torch.bool), + torch.tensor(0).to(torch.bool), + ), + ] + ) + def test_ne(self, _, input, other): + class Ne(torch.nn.Module): + def forward(self, x, y): + return x != y + + inputs = [ + input, + other, + ] + self.run_test( + Ne(), inputs, expected_ops={acc_ops.ne}, test_implicit_batch_dim=False + ) + + +class TestNeOperatorConverterWithDynamicShape(AccTestCase): + def test_ne(self): + class Ne(torch.nn.Module): + def forward(self, x, y): + return x != y + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], + ), + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], + ), + ] + + self.run_test_with_dynamic_shape(Ne(), input_specs, expected_ops={acc_ops.ne}) + + +class TestNeOperatorConstantConverter(AccTestCase): + @parameterized.expand( + [ + ("rand_2d_float_bool", torch.randn(3, 4), torch.randn(3, 4).to(torch.bool)), + ( + "rand_2d_int_bool", + torch.randn(3, 4).to(torch.int), + torch.randn(3, 4).to(torch.bool), + ), + ( + "rand_2d_bool_bool", + torch.randn(3, 4).to(torch.bool), + torch.randn(3, 4).to(torch.bool), + ), + ( + "rand_2d_float_int", + torch.randn(3, 4).to(torch.float), + torch.randn(3, 4).to(torch.int), + ), + ("rand_2d_float_single_bool", torch.randn(3, 4), False), + ("rand_2d_int_single_bool", torch.randn(3, 4).to(torch.int), False), + ("rand_2d_bool_single_bool", torch.randn(3, 4).to(torch.bool), False), + ] + ) + def test_ne(self, _, input, other): + class Ne(torch.nn.Module): + def __init__(self): + super().__init__() + self.other = other + + def forward(self, x): + return x != self.other + + inputs = [ + input, + ] + self.run_test( + Ne(), inputs, expected_ops={acc_ops.ne}, test_implicit_batch_dim=False + ) + + +class TestConstInputConverter(AccTestCase): + def test_ne(self): + class Ne(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x.shape[0] != 4 + + input = torch.randn(3, 4) + inputs = [ + input, + ] + self.run_test( + Ne(), inputs, expected_ops={acc_ops.ne}, test_implicit_batch_dim=False + ) + + +class TestConstInputConverterWithDynamicShape(AccTestCase): + def test_ne(self): + class Ne(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x.shape[0] != 4 + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))], + ), + ] + + 
self.run_test_with_dynamic_shape(Ne(), input_specs, expected_ops={acc_ops.ne})
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_new_ones.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_new_ones.py
new file mode 100644
index 0000000000..dabfedb139
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_new_ones.py
@@ -0,0 +1,73 @@
+import torch
+import torch.nn as nn
+import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+
+
+class TestNewOnesConverter(AccTestCase):
+    def test_newone_no_dtype(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return x.new_ones((3, 5))
+
+        inputs = [torch.randn(1, 10)]
+        self.run_test(
+            TestModule(),
+            inputs,
+            expected_ops={acc_ops.new_ones},
+            test_implicit_batch_dim=False,
+        )
+
+    def test_newone_device(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return x.new_ones((3, 5), device="cuda")
+
+        inputs = [torch.randn(1, 10)]
+        self.run_test(
+            TestModule(),
+            inputs,
+            expected_ops={acc_ops.new_ones},
+            test_implicit_batch_dim=False,
+        )
+
+
+class TestNewOnesConverterWithDynamicShape(AccTestCase):
+    def test_newone_no_dtype(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return x.new_ones((3, 5))
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={acc_ops.new_ones}
+        )
+
+    def test_newone_device(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return x.new_ones((3, 5), device="cuda")
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={acc_ops.new_ones}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_numel.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_numel.py
new file mode 100644
index 0000000000..a2eafd4fdc
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_numel.py
@@ -0,0 +1,41 @@
+import torch
+import torch.nn as nn
+import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase
+
+
+class TestNumelConverter(AccTestCase):
+    def test_numel(self):
+        class Numel(nn.Module):
+            def forward(self, x):
+                return torch.numel(x) * x
+
+        inputs = [torch.ones(1, 2, 3, 4)]
+        self.run_test(Numel(), inputs, expected_ops={acc_ops.numel})
+
+
+# Testing with (-1, -1, -1, -1) results in the following error:
+# RuntimeError: numel does not support dynamic shapes.
+
+"""
+class TestNumelConverterWithDynamicShape(AccTestCase):
+    def test_numel(self):
+        class Numel(nn.Module):
+            def forward(self, x):
+                return torch.numel(x) * x
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            Numel(), input_specs, expected_ops={acc_ops.numel}
+        )
+"""
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_pad.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_pad.py
new file mode 100644
index 0000000000..e850268fde
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_pad.py
@@ -0,0 +1,102 @@
+import unittest
+
+import tensorrt as trt
+import torch
+import torch.nn as nn
+
+import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase
+
+# from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec
+
+
+class TestPadConverter(AccTestCase):
+    @parameterized.expand(
+        [
+            ("1d", (1, 2), 9),
+            ("2d", (2, 0, 0, 1), 10),
+        ]
+    )
+    def test_pad_value(self, _, pad, value):
+        class Pad(nn.Module):
+            def forward(self, x):
+                return torch.nn.functional.pad(x, pad, value=value)
+
+        inputs = [torch.randn(1, 2, 3, 4)]
+        self.run_test(
+            Pad(),
+            inputs,
+            expected_ops={acc_ops.pad},
+            # enabling `value` does not work with implicit batch
+            test_implicit_batch_dim=False,
+        )
+
+    @parameterized.expand(
+        [
+            ("1d", (1, 2)),
+            ("2d", (2, 0, 0, 1)),
+        ]
+    )
+    def test_pad(self, _, pad):
+        class Pad(nn.Module):
+            def forward(self, x):
+                return torch.nn.functional.pad(x, pad)
+
+        inputs = [torch.randn(1, 2, 3, 4)]
+        self.run_test(
+            Pad(),
+            inputs,
+            expected_ops={acc_ops.pad},
+            # enabling `value` does not work with implicit batch
+            test_implicit_batch_dim=False,
+        )
+
+    # Testing with (-1, 3, 3, 3) results in the following error:
+    # test_pad_with_dynamic_shape_four_dimensions_0_2d (deeplearning.trt.torch_tensorrt.py.torch_tensorrt.fx.test.converters.acc_op.test_pad.TestPadConverter) ... [07/15/2022-09:23:18] [TRT] [E] 2: [intInterval.cpp::max::26] Error Code 2: Internal Error (Assertion !empty() failed. 
) + # Segmentation fault (core dumped) + + """ + def test_pad_with_dynamic_shape_four_dimensions(self): + class Pad(nn.Module): + def forward(self, x): + return torch.nn.functional.pad(x, (1, 1)) + + input_specs = [ + InputTensorSpec( + shape=(-1, 3, 3, 3), + dtype=torch.float32, + shape_ranges=[((1, 3, 3, 3), (2, 3, 3, 3), (2, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape(Pad(), input_specs, expected_ops={acc_ops.pad}) + """ + + @parameterized.expand( + [ + ("3d", (2, 2, 3, 1, 2, 2)), + ] + ) + @unittest.skipIf( + trt.__version__ < "8.2", + "Padding 3d only supported in TensorRT 8.2 and later", + ) + def test_pad_3d(self, _, pad): + class Pad(nn.Module): + def forward(self, x): + return torch.nn.functional.pad(x, pad) + + inputs = [torch.randn(1, 2, 3, 4)] + self.run_test( + Pad(), + inputs, + expected_ops={acc_ops.pad}, + # enable value will not work with implicit batch + test_implicit_batch_dim=False, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_permute.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_permute.py new file mode 100644 index 0000000000..9e4ebc9cf4 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_permute.py @@ -0,0 +1,87 @@ +import torch +import torch.nn as nn +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestPermuteConverter(AccTestCase): + @parameterized.expand( + [ + ("positive", [0, 2, 1]), + ("negative", [0, -1, -2]), + ] + ) + def test_permute_list(self, _, permutation): + class Permute(nn.Module): + def forward(self, x): + return x.permute(permutation) + + inputs = [torch.randn(1, 3, 2)] + self.run_test(Permute(), inputs, expected_ops={acc_ops.permute}) + + @parameterized.expand( + [ + ("positive", [0, 2, 1]), + ("negative", [0, -1, -2]), + ] + ) + def test_permute(self, _, permutation): + class Permute(nn.Module): + def forward(self, x): + return x.permute(*permutation) + + inputs = [torch.randn(1, 3, 2)] + self.run_test(Permute(), inputs, expected_ops={acc_ops.permute}) + + @parameterized.expand( + [ + ("positive", (1, 2)), + ("negative", (-1, -2)), + ] + ) + def test_transpose(self, _, dims): + class Transpose(nn.Module): + def forward(self, x): + return x.transpose(*dims) + + inputs = [torch.randn(1, 2, 3)] + self.run_test(Transpose(), inputs, expected_ops={acc_ops.permute}) + + def test_permute_with_dynamic_shape(self): + class Permute(nn.Module): + def forward(self, x): + return x.permute(1, 2, 0) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + Permute(), input_specs, expected_ops={acc_ops.permute} + ) + + def test_permute_with_dynamic_shape_four_dimensions(self): + class Permute(nn.Module): + def forward(self, x): + return x.permute(1, 2, 3, 0) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + Permute(), input_specs, expected_ops={acc_ops.permute} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_prod.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_prod.py new file mode 100644 index 
0000000000..0d6c16b98e
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_prod.py
@@ -0,0 +1,118 @@
+import torch
+import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+
+# NOTE: torch.prod only accepts a single dim, unlike other reduce ops, which accept tuples
+
+
+class TestProdConverter(AccTestCase):
+    @parameterized.expand(
+        [
+            (
+                f"{acc_ops.prod.__name__}_dim0_keepdim",
+                0,
+                True,
+                torch.prod,
+                acc_ops.prod,
+            ),
+            (
+                f"{acc_ops.prod.__name__}_dim0_no_keepdim",
+                0,
+                False,
+                torch.prod,
+                acc_ops.prod,
+            ),
+            (
+                f"{acc_ops.prod.__name__}_dim1_keepdim",
+                1,
+                True,
+                torch.prod,
+                acc_ops.prod,
+            ),
+            (
+                f"{acc_ops.prod.__name__}_dim1_no_keepdim",
+                1,
+                False,
+                torch.prod,
+                acc_ops.prod,
+            ),
+            (
+                f"{acc_ops.prod.__name__}_dim2_keepdim",
+                2,
+                True,
+                torch.prod,
+                acc_ops.prod,
+            ),
+            (
+                f"{acc_ops.prod.__name__}_dim2_no_keepdim",
+                2,
+                False,
+                torch.prod,
+                acc_ops.prod,
+            ),
+        ]
+    )
+    def test_prod(self, test_name, dim, keepdim, op, expected_acc_op):
+        class Prod(torch.nn.Module):
+            def __init__(self, dim, keepdim):
+                super().__init__()
+                self.dim = dim
+                self.keepdim = keepdim
+
+            def forward(self, x):
+                return op(x, dim=self.dim, keepdim=self.keepdim)
+
+        inputs = [torch.randn(1, 2, 3, 4)]
+        self.run_test(
+            Prod(dim, keepdim),
+            inputs,
+            expected_ops={expected_acc_op},
+            test_implicit_batch_dim=(dim != 0),
+        )
+
+    @parameterized.expand(
+        [(f"{acc_ops.prod.__name__}_no_dim_no_keepdim", torch.prod, acc_ops.prod)]
+    )
+    def test_prod_all_dims(
+        self,
+        test_name,
+        op,
+        expected_acc_op,
+    ):
+        class Prod(torch.nn.Module):
+            def forward(self, x):
+                return op(x)
+
+        inputs = [torch.randn(1, 2, 3, 4)]
+        self.run_test(
+            Prod(),
+            inputs,
+            expected_ops={expected_acc_op},
+            test_implicit_batch_dim=False,
+        )
+
+    def test_prod_all_dims_with_dynamic_shape(
+        self,
+        op=torch.prod,
+    ):
+        class Prod(torch.nn.Module):
+            def forward(self, x):
+                return op(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            Prod(), input_specs, expected_ops={acc_ops.prod}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_quantize_per_tensor.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_quantize_per_tensor.py
new file mode 100644
index 0000000000..5830a3e463
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_quantize_per_tensor.py
@@ -0,0 +1,65 @@
+import unittest
+
+import tensorrt as trt
+import torch.fx
+import torch.nn as nn
+
+import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+
+
+@unittest.skip(
+    """
+    Tests related to quantize have an issue creating the engine; disabled for now.
+ """ +) +@unittest.skipIf( + trt.__version__ < "8.0", + "Explicit quantization only supported in TensorRT 8.0 and later", +) +class TestQuantizePerTensorConverter(AccTestCase): + def test_quantize_per_tensor(self): + class TestModule(nn.Module): + def forward(self, x): + return torch.quantize_per_tensor(x, 1, 0, torch.quint8) + + inputs = [torch.randn(1, 10)] + self.run_test(TestModule(), inputs, expected_ops={acc_ops.quantize_per_tensor}) + + def test_quantize_per_tensor_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return torch.quantize_per_tensor(x, 1, 0, torch.quint8) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.quantize_per_tensor} + ) + + def test_quantize_per_tensor_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return torch.quantize_per_tensor(x, 1, 0, torch.quint8) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.quantize_per_tensor} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_reduce_ops.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_reduce_ops.py new file mode 100644 index 0000000000..988eb7b477 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_reduce_ops.py @@ -0,0 +1,108 @@ +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + +reduce_ops = [(torch.sum, acc_ops.sum), (torch.mean, acc_ops.mean)] + + +class TestReduceConverter(AccTestCase): + @parameterized.expand( + case + for op, acc_op in reduce_ops + for case in [ + (f"{acc_op.__name__}_single_dim_no_keepdim", 1, False, op, acc_op), + (f"{acc_op.__name__}_single_dim_keepdim", 1, True, op, acc_op), + (f"{acc_op.__name__}_two_dim_no_keepdim", (1, 2), False, op, acc_op), + (f"{acc_op.__name__}_two_dim_keepdim", (1, 2), True, op, acc_op), + (f"{acc_op.__name__}_three_dim_no_keepdim", (1, 2, 3), False, op, acc_op), + (f"{acc_op.__name__}_three_dim_keepdim", (1, 2, 3), True, op, acc_op), + (f"{acc_op.__name__}_dim0_keepdim", 0, True, op, acc_op), + (f"{acc_op.__name__}_dim0_no_keepdim", 0, False, op, acc_op), + (f"{acc_op.__name__}_neg_single_dim_no_keepdim", -1, False, op, acc_op), + (f"{acc_op.__name__}_neg_single_dim_keepdim", -1, True, op, acc_op), + (f"{acc_op.__name__}_neg_two_dim_no_keepdim", (-1, -2), False, op, acc_op), + (f"{acc_op.__name__}_neg_two_dim_keepdim", (-1, -2), True, op, acc_op), + ( + f"{acc_op.__name__}_neg_pos_two_dim_no_keepdim", + (-1, 1), + False, + op, + acc_op, + ), + (f"{acc_op.__name__}_neg_pos_two_dim_keepdim", (-1, 1), True, op, acc_op), + ] + ) + def test_reduce(self, test_name, dim, keepdim, op, expected_acc_op): + class Reduce(torch.nn.Module): + def __init__(self, dim, keepdim): + super().__init__() + self.dim = dim + self.keepdim = keepdim + + def forward(self, x): + return op(x, dim=self.dim, keepdim=self.keepdim) + + inputs = [torch.randn(1, 2, 3, 4)] + self.run_test( + Reduce(dim, keepdim), + inputs, + 
expected_ops={expected_acc_op}, + test_implicit_batch_dim=(dim != 0), + ) + + @parameterized.expand( + [ + (f"{acc_op.__name__}_no_dim_no_keepdim", op, acc_op) + for op, acc_op in reduce_ops + ] + ) + def test_reduce_all_dims( + self, + test_name, + op, + expected_acc_op, + ): + class Reduce(torch.nn.Module): + def forward(self, x): + return op(x) + + inputs = [torch.randn(1, 2, 3, 4)] + self.run_test( + Reduce(), + inputs, + expected_ops={expected_acc_op}, + test_implicit_batch_dim=False, + ) + + @parameterized.expand( + [ + (f"{acc_op.__name__}_no_dim_no_keepdim", op, acc_op) + for op, acc_op in reduce_ops + ] + ) + def test_reduce_all_dims_with_dynamic_shape_four_dimensions( + self, + test_name, + op, + expected_acc_op, + ): + class Reduce(torch.nn.Module): + def forward(self, x): + return op(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + Reduce(), input_specs, expected_ops={expected_acc_op} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_relu.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_relu.py new file mode 100644 index 0000000000..e520b742c1 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_relu.py @@ -0,0 +1,52 @@ +import torch +import torch.nn as nn +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestReLUConverter(AccTestCase): + def test_relu(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.relu(x) + + inputs = [torch.randn(1, 10)] + self.run_test(TestModule(), inputs, expected_ops={acc_ops.relu}) + + def test_relu_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.relu(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.relu} + ) + + def test_relu_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.relu(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.relu} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_repeat_interleave.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_repeat_interleave.py new file mode 100644 index 0000000000..efb1f80c0f --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_repeat_interleave.py @@ -0,0 +1,76 @@ +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import parameterized +from torch import nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestRepeatInterLeave(AccTestCase): + @parameterized.expand( + [ + ("none_dim", (2, 3, 4), 3, None), + ("dim_0", (2, 3, 4), 3, 0), + ("dim_1", (2, 3, 4), 3, 1), + ("dim_2", (2, 3, 4), 3, 2), + ] + ) + def 
test_repeat_interleave(self, _, input_shape, repeat, dim): + class RepeatInterleave(nn.Module): + def __init__(self, dim): + super().__init__() + self.repeat = repeat + self.dim = dim + + def forward(self, x): + return torch.repeat_interleave(x, self.repeat, self.dim) + + inputs = [torch.randn(*input_shape)] + expected_ops = {acc_ops.tile, acc_ops.unsqueeze, acc_ops.reshape} + if dim is not None: + expected_ops.update({acc_ops.getitem, acc_ops.size}) + self.run_test( + RepeatInterleave(dim), + inputs, + expected_ops=expected_ops, + test_implicit_batch_dim=dim is not None and dim != 0, + ) + + @parameterized.expand( + [ + ("none_dim", (-1, 2, 3), 3, None), + ("dim_0", (-1, 2, 3), 3, 0), + ("dim_1", (-1, 2, 3), 3, 1), + ("dim_2", (-1, 3, 2), 3, 2), + ] + ) + def test_repeat_interleave_with_dynamic_shape(self, _, input_shape, repeat, dim): + class RepeatInterleave(nn.Module): + def __init__(self, dim): + super().__init__() + self.repeat = repeat + self.dim = dim + + def forward(self, x): + return torch.repeat_interleave(x, self.repeat, self.dim) + + input_specs = [ + InputTensorSpec( + shape=input_shape, + dtype=torch.float32, + shape_ranges=[ + ( + tuple(i if i != -1 else 1 for i in input_shape), + tuple(i if i != -1 else 2 for i in input_shape), + tuple(i if i != -1 else 3 for i in input_shape), + ) + ], + ), + ] + self.run_test_with_dynamic_shape( + RepeatInterleave(dim), input_specs, expected_ops={acc_ops.tile} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_reshape.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_reshape.py new file mode 100644 index 0000000000..b5b1dc8f6c --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_reshape.py @@ -0,0 +1,138 @@ +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestReshapeConverter(AccTestCase): + @parameterized.expand( + [ + ((1, 20),), + ((1, 10, -1),), + ] + ) + def test_reshape(self, target_shape): + class TestModule(torch.nn.Module): + def __init__(self, target_shape): + super().__init__() + self.target_shape = target_shape + + def forward(self, x): + return torch.reshape(x, self.target_shape) + + inputs = [torch.randn(1, 2, 10)] + self.run_test(TestModule(target_shape), inputs, expected_ops={acc_ops.reshape}) + + @parameterized.expand( + [ + ((-1, 2),), + ((1, 2, -1),), + ] + ) + def test_reshape_with_dynamic_shape(self, target_shape): + class TestModule(torch.nn.Module): + def __init__(self, target_shape): + super().__init__() + self.target_shape = target_shape + + def forward(self, x): + return torch.reshape(x, self.target_shape) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(target_shape), input_specs, expected_ops={acc_ops.reshape} + ) + + @parameterized.expand( + [ + ((-1, 2),), + ((1, 2, -1),), + ] + ) + def test_reshape_with_dynamic_shape_with_four_dimensions(self, target_shape): + class TestModule(torch.nn.Module): + def __init__(self, target_shape): + super().__init__() + self.target_shape = target_shape + + def forward(self, x): + return torch.reshape(x, self.target_shape) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + 
shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(target_shape), input_specs, expected_ops={acc_ops.reshape} + ) + + def test_reshape_with_dynamic_shape_size(self): + class TestModule(torch.nn.Module): + def forward(self, x, y): + shape_y = y.shape + t = shape_y[1] + return torch.reshape(x, [-1, t, 3]) + + input_specs = [ + InputTensorSpec( + shape=(-1, 5, 6), + dtype=torch.float32, + shape_ranges=[((1, 5, 6), (2, 5, 6), (3, 5, 6))], + ), + InputTensorSpec( + shape=(-1, 5), + dtype=torch.float32, + shape_ranges=[((1, 5), (1, 5), (3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.reshape} + ) + + def test_reshape_with_dynamic_shape_mul(self): + class TestModule(torch.nn.Module): + def forward(self, x, y, z): + t = 8000 + a = torch.reshape(x, [-1, t, 64]) + b = torch.reshape(y, [-1, t, 64]) + c = torch.reshape(z, [-1, t, 64]) + d = a + b + c + return d + + input_specs = [ + InputTensorSpec( + shape=(-1, 42, 512), + dtype=torch.float32, + shape_ranges=[((1, 42, 512), (1000, 42, 512), (1000, 42, 512))], + ), + InputTensorSpec( + shape=(-1, 42, 512), + dtype=torch.float32, + shape_ranges=[((1, 42, 512), (1000, 42, 512), (1000, 42, 512))], + ), + InputTensorSpec( + shape=(-1, 42, 512), + dtype=torch.float32, + shape_ranges=[((1, 42, 512), (1000, 42, 512), (1000, 42, 512))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.reshape} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_selu.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_selu.py new file mode 100644 index 0000000000..4a89f364ee --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_selu.py @@ -0,0 +1,52 @@ +import torch +import torch.nn as nn +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestSeLUConverter(AccTestCase): + def test_selu(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.selu(x) + + inputs = [torch.randn(1, 10)] + self.run_test(TestModule(), inputs, expected_ops={acc_ops.selu}) + + def test_selu_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.selu(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.selu} + ) + + def test_selu_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.selu(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.selu} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_sigmoid.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_sigmoid.py new file mode 100644 index 0000000000..d8abf37707 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_sigmoid.py @@ -0,0 +1,35 @@ +import torch +import torch.nn as nn +import 
torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+
+
+class TestSigmoid(AccTestCase):
+    def test_sigmoid(self):
+        class Sigmoid(nn.Module):
+            def forward(self, x):
+                return torch.sigmoid(x)
+
+        inputs = [torch.randn(1, 2, 3)]
+        self.run_test(Sigmoid(), inputs, expected_ops={acc_ops.sigmoid})
+
+    def test_sigmoid_with_dynamic_shape_four_dimensions(self):
+        class Sigmoid(nn.Module):
+            def forward(self, x):
+                return torch.sigmoid(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Sigmoid(), input_specs, expected_ops={acc_ops.sigmoid}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_silu.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_silu.py
new file mode 100644
index 0000000000..684b2247e8
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_silu.py
@@ -0,0 +1,52 @@
+import torch
+import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
+from torch import nn
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+
+
+class TestSilu(AccTestCase):
+    def test_silu(self):
+        class Silu(nn.Module):
+            def forward(self, x):
+                return torch.nn.functional.silu(x)
+
+        inputs = [torch.randn(1, 2, 3)]
+        self.run_test(Silu(), inputs, expected_ops={acc_ops.sigmoid, acc_ops.mul})
+
+    def test_silu_with_dynamic_shape(self):
+        class Silu(nn.Module):
+            def forward(self, x):
+                return torch.nn.functional.silu(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Silu(), input_specs, expected_ops={acc_ops.sigmoid, acc_ops.mul}
+        )
+
+    def test_silu_with_dynamic_shape_four_dimensions(self):
+        class Silu(nn.Module):
+            def forward(self, x):
+                return torch.nn.functional.silu(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            Silu(), input_specs, expected_ops={acc_ops.sigmoid, acc_ops.mul}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_size.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_size.py
new file mode 100644
index 0000000000..9fd1e45015
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_size.py
@@ -0,0 +1,71 @@
+import torch
+import torch.nn as nn
+import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+
+
+class TestSizeConverter(AccTestCase):
+    def test_size(self):
+        class Size(nn.Module):
+            def forward(self, x):
+                bs = x.size(0)
+                return x.view(bs, -1)
+
+        inputs = [torch.randn(1, 2, 3, 4)]
+        self.run_test(Size(), inputs, expected_ops={acc_ops.size})
+
+    def test_size_param(self):
+        class Size(nn.Module):
+            def __init__(self, x):
+                super().__init__()
+                self.param = torch.nn.Parameter(x)
+
+            def forward(self, y):
+                bs = self.param.size(0)
+                return y.view(bs, -1)
+
+        self.run_test(
+            Size(torch.randn(1, 2, 3, 
4)), + [torch.randn(1, 2, 3, 4)], + expected_ops={acc_ops.size}, + ) + + def test_size_dynamic_shape(self): + class Size(nn.Module): + def forward(self, x): + bs = x.size(0) + return x.view(bs, -1) + + input_specs = [ + InputTensorSpec( + shape=(-1, 12, 32), + dtype=torch.float32, + shape_ranges=[((1, 12, 32), (3, 12, 32), (100, 12, 32))], + ), + ] + self.run_test_with_dynamic_shape( + Size(), input_specs, expected_ops={acc_ops.size} + ) + + def test_size_dynamic_shape_four_dimensions(self): + class Size(nn.Module): + def forward(self, x): + bs = x.size(0) + return x.view(bs, -1) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 12, 32, 3), (3, 12, 32, 3), (100, 12, 32, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + Size(), input_specs, expected_ops={acc_ops.size} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_softmax.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_softmax.py new file mode 100644 index 0000000000..5c8b9ed58b --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_softmax.py @@ -0,0 +1,81 @@ +import torch +import torch.nn as nn +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestSoftmaxConverter(AccTestCase): + @parameterized.expand( + [("none_dim", None), ("basic", 1), ("batch_dim", 0), ("negative_dim", -2)] + ) + def test_softmax(self, _, dim): + class Softmax(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + return nn.functional.softmax(x, dim=self.dim) + + inputs = [torch.randn(1, 2, 3, 4)] + self.run_test( + Softmax(dim), + inputs, + expected_ops={acc_ops.softmax}, + test_implicit_batch_dim=(dim is None or dim % len(inputs[0].shape) != 0), + ) + + def test_softmax_with_dynamic_shape(self): + class Softmax(nn.Module): + def forward(self, x): + return nn.functional.softmax(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + Softmax(), input_specs, expected_ops={acc_ops.softmax} + ) + + def test_softmax_with_dynamic_shape_four_dimensions(self): + class Softmax(nn.Module): + def forward(self, x): + return nn.functional.softmax(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + Softmax(), input_specs, expected_ops={acc_ops.softmax} + ) + + def test_softmax_with_implicit_batch_dim0_fail(self): + class Softmax(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return nn.functional.softmax(x, dim=0) + + inputs = [torch.randn(1, 2, 3, 4)] + self.run_test_with_assert_error( + Softmax(), + inputs, + expect_error=AssertionError, + test_explicit_batch_dim=False, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_softsign.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_softsign.py new file mode 100644 index 0000000000..47241685fb --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_softsign.py @@ -0,0 +1,52 @@ +import torch +import torch.nn as nn +import 
torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestSoftsignConverter(AccTestCase): + def test_softsign(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.softsign(x) + + inputs = [torch.randn(1, 10)] + self.run_test(TestModule(), inputs, expected_ops={acc_ops.softsign}) + + def test_softsign_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.softsign(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.softsign} + ) + + def test_softsign_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.softsign(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.softsign} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_split.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_split.py new file mode 100644 index 0000000000..4861fecc34 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_split.py @@ -0,0 +1,107 @@ +import torch +import torch.nn as nn +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestSplitConverter(AccTestCase): + @parameterized.expand( + [ + ("split_size", 3, 1), + ("sections", [5, 2, 3], 1), + ] + ) + def test_split(self, _, split_size_or_sections, dim): + class Split(nn.Module): + def forward(self, x): + return x.split(split_size_or_sections, dim)[0] + + inputs = [torch.randn(1, 10)] + self.run_test( + Split(), + inputs, + expected_ops={ + acc_ops.split + if isinstance(split_size_or_sections, int) + else acc_ops.slice_tensor + }, + test_explicit_batch_dim=False, + ) + + @parameterized.expand( + [ + ("split_with_size", [2, 3, 5], 1), + ] + ) + def test_split_with_size(self, _, split_size, dim): + class Split(nn.Module): + def forward(self, x): + return x.split_with_sizes(split_size, dim) + + inputs = [torch.randn(1, 10)] + self.run_test( + Split(), + inputs, + expected_ops={acc_ops.slice_tensor}, + test_explicit_batch_dim=False, + ) + + @parameterized.expand( + [ + ("split_size", 3, 1), + ("sections", [5, 2, 3], 1), + ] + ) + def test_split_with_dynamic_shape(self, _, split_size_or_sections, dim): + class Split(nn.Module): + def forward(self, x): + return x.split(split_size_or_sections, dim)[0] + + input_specs = [ + InputTensorSpec( + shape=(-1, 10, -1), + dtype=torch.float32, + shape_ranges=[((1, 10, 10), (5, 10, 15), (10, 10, 20))], + ), + ] + self.run_test_with_dynamic_shape( + Split(), + input_specs, + expected_ops={ + acc_ops.split + if isinstance(split_size_or_sections, int) + else acc_ops.slice_tensor + }, + ) + + # Testing with (-1, -1, -1) results into following error: + # AssertionError: Can't chunk on dynamic shape dimension! 
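+    # Illustrative sketch of the failing configuration (hypothetical values,
+    # kept commented out): when the split dimension itself is dynamic, the
+    # converter cannot compute static chunk boundaries and asserts.
+    #
+    # input_specs = [
+    #     InputTensorSpec(
+    #         shape=(-1, -1, -1),  # dim=1 dynamic -> "Can't chunk" assertion
+    #         dtype=torch.float32,
+    #         shape_ranges=[((1, 10, 10), (5, 10, 15), (10, 10, 20))],
+    #     ),
+    # ]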
+ + @parameterized.expand( + [ + ("split_with_size", [2, 3, 5], 1), + ] + ) + def test_split_with_size_dynamic_shape(self, _, split_size, dim): + class Split(nn.Module): + def forward(self, x): + return x.split_with_sizes(split_size, dim) + + input_specs = [ + InputTensorSpec( + shape=(-1, 10, -1), + dtype=torch.float32, + shape_ranges=[((1, 10, 20), (5, 10, 20), (10, 10, 20))], + ), + ] + self.run_test_with_dynamic_shape( + Split(), + input_specs, + expected_ops={acc_ops.slice_tensor}, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_squeeze.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_squeeze.py new file mode 100644 index 0000000000..bc65e010e4 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_squeeze.py @@ -0,0 +1,41 @@ +import torch +import torch.nn as nn +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestSqueeze(AccTestCase): + def test_squeeze(self): + class Squeeze(nn.Module): + def forward(self, x): + return x.squeeze(2) + + inputs = [torch.randn(1, 2, 1)] + self.run_test(Squeeze(), inputs, expected_ops={acc_ops.squeeze}) + + # Testing with shape=(-1, -1, -1, -1) results in error: + # AssertionError: We don't support squeeze dynamic dim. + + # Testing with more than one dynamic dim results in error: + # AssertionError: Currently more than one dynamic dim for input to squeeze is not supported. + + def test_squeeze_with_dynamic_shape(self): + class Squeeze(nn.Module): + def forward(self, x): + return x.squeeze(0) + + input_specs = [ + InputTensorSpec( + shape=(1, -1, 2), + dtype=torch.float32, + shape_ranges=[((1, 1, 2), (1, 2, 2), (1, 3, 2))], + ), + ] + self.run_test_with_dynamic_shape( + Squeeze(), input_specs, expected_ops={acc_ops.squeeze} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_std.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_std.py new file mode 100644 index 0000000000..cd38314295 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_std.py @@ -0,0 +1,117 @@ +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestMinConverter(AccTestCase): + @parameterized.expand( + [ + ("norm_1d", (-1), False), + ("norm_1d", (-1), True), + ("norm_2d", (2, 3), False), + ("norm_2d", (2, 3), True), + ] + ) + def test_std(self, _, dim, unbiased): + class Std(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.std(x, dim, unbiased=unbiased, keepdim=True) + + inputs = [torch.randn(2, 3, 4, 5)] + self.run_test( + Std(), + inputs, + expected_ops={acc_ops.mean, acc_ops.sub, acc_ops.pow, acc_ops.sqrt}, + ) + + @parameterized.expand( + [ + ("norm_1d", (-1), False), + ("norm_1d", (-1), True), + ("norm_2d", (2, 3), False), + ("norm_2d", (2, 3), True), + ] + ) + def test_std_with_dynamic_shape_four_dimensions(self, _, dim, unbiased): + class Std(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.std(x, dim, unbiased=unbiased, keepdim=True) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + 
dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + Std(), + input_specs, + expected_ops={acc_ops.mean, acc_ops.sub, acc_ops.pow, acc_ops.sqrt}, + ) + + @parameterized.expand( + [ + ("norm_1d", (-1), True), + ("norm_1d", (-1), False), + ("norm_2d", (2, 3), True), + ("norm_2d", (2, 3), False), + ] + ) + def test_std_method(self, _, dim, unbiased): + class Std(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x.std(dim, unbiased=unbiased, keepdim=True) + + inputs = [torch.randn(2, 3, 4, 5)] + self.run_test( + Std(), + inputs, + expected_ops={acc_ops.mean, acc_ops.sub, acc_ops.pow, acc_ops.sqrt}, + ) + + @parameterized.expand( + [ + ("norm_1d", (-1), True), + ("norm_1d", (-1), False), + ("norm_2d", (2, 3), True), + ("norm_2d", (2, 3), False), + ] + ) + def test_std_method_with_dynamic_shape_four_dimensions(self, _, dim, unbiased): + class Std(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x.std(dim, unbiased=unbiased, keepdim=True) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + Std(), + input_specs, + expected_ops={acc_ops.mean, acc_ops.sub, acc_ops.pow, acc_ops.sqrt}, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_tanh.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_tanh.py new file mode 100644 index 0000000000..94c442a4ed --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_tanh.py @@ -0,0 +1,52 @@ +import torch +import torch.nn as nn +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestTanh(AccTestCase): + def test_tanh(self): + class Tanh(nn.Module): + def forward(self, x): + return torch.tanh(x) + + inputs = [torch.randn(1, 2, 3)] + self.run_test(Tanh(), inputs, expected_ops={acc_ops.tanh}) + + def test_tanh_with_dynamic_shape(self): + class Tanh(nn.Module): + def forward(self, x): + return torch.tanh(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + Tanh(), input_specs, expected_ops={acc_ops.tanh} + ) + + def test_tanh_with_dynamic_shape_four_dimensions(self): + class Tanh(nn.Module): + def forward(self, x): + return torch.tanh(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 3), (1, 2, 3, 3), (3, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + Tanh(), input_specs, expected_ops={acc_ops.tanh} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_tile.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_tile.py new file mode 100644 index 0000000000..1d14987adc --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_tile.py @@ -0,0 +1,145 @@ +import torch +import torch.nn as nn +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, 
InputTensorSpec + + +class TestTile(AccTestCase): + @parameterized.expand( + [ + ("same_num_dims", (2, 2, 3), (1, 2, 2)), + ("less_dims", (2, 2, 3), (2,)), + ("more_dims", (2, 3), (1, 2, 2, 1)), + ] + ) + def test_tile(self, _, input_shape, dims): + class Tile(nn.Module): + def __init__(self, dims): + super().__init__() + self.dims = dims + + def forward(self, x): + return torch.tile(x, self.dims) + + inputs = [torch.randn(*input_shape)] + self.run_test( + Tile(dims), + inputs, + expected_ops={acc_ops.tile}, + test_implicit_batch_dim=( + len(input_shape) > len(dims) + or (len(input_shape) == len(dims) and dims[0] == 1) + ), + ) + + @parameterized.expand( + [ + ("same_num_dims", (-1, 2, 3), (1, 2, 2)), + ("less_dims", (-1, 2, 3), (2,)), + ("more_dims", (-1, 3), (1, 2, 2, 1)), + ("all_dynamic_dim", (-1, -1), (1, 2, 2, 1)), + ] + ) + def test_tile_with_dynamic_shape(self, _, shape, dims): + class Tile(nn.Module): + def __init__(self, dims): + super().__init__() + self.dims = dims + + def forward(self, x): + return torch.tile(x, self.dims) + + input_specs = [ + InputTensorSpec( + shape=shape, + dtype=torch.float32, + shape_ranges=[ + ( + tuple(i if i != -1 else 1 for i in shape), + tuple(i if i != -1 else 2 for i in shape), + tuple(i if i != -1 else 3 for i in shape), + ) + ], + ), + ] + self.run_test_with_dynamic_shape( + Tile(dims), input_specs, expected_ops={acc_ops.tile} + ) + + @parameterized.expand( + [ + ("all_dynamic_dim", (-1, -1), (1, 2, 2, 1)), + ] + ) + def test_tile_with_dynamic_shape_four_dimensions(self, _, shape, dims): + class Tile(nn.Module): + def __init__(self, dims): + super().__init__() + self.dims = dims + + def forward(self, x): + return torch.tile(x, self.dims) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 3), (3, 3, 3, 3), (3, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + Tile(dims), input_specs, expected_ops={acc_ops.tile} + ) + + def test_tile_non_int_dims(self): + class Tile(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + y = y * 2 + return torch.tile(x, (1, y.shape[1], y.shape[1])) + + inputs = [torch.randn(2, 2, 3), torch.randn(2, 2, 3)] + batch_size_range = (1, 2, 3) + input_specs = InputTensorSpec.from_tensors_with_dynamic_batch_size( + inputs, batch_size_range + ) + self.run_test_with_dynamic_shape( + Tile(), + input_specs, + expected_ops={acc_ops.tile}, + ) + + def test_tile_non_int_dims_with_dynamic_shape_four_dimensions(self): + class Tile(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + y = y * 2 + return torch.tile(x, (1, y.shape[1], y.shape[1])) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 3), (3, 3, 3, 3), (3, 3, 3, 3))], + ), + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 3), (3, 3, 3, 3), (3, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + Tile(), input_specs, expected_ops={acc_ops.tile} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_to_dtype.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_to_dtype.py new file mode 100644 index 0000000000..c057088c77 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_to_dtype.py @@ -0,0 +1,319 @@ +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from torch.testing._internal.common_utils import run_tests 
+from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.utils import LowerPrecision + + +class TestToConverter(AccTestCase): + def test_fp16(self): + class To(torch.nn.Module): + def forward(self, x): + return x.to(torch.float16) + + input = torch.randn(2, 2) + inputs = [ + input, + ] + self.run_test( + To(), + inputs, + expected_ops={acc_ops.to_dtype}, + test_implicit_batch_dim=False, + precision=LowerPrecision.FP16, + ) + + # Testing with shape shape=(-1, -1, -1, -1) results into following error: + # Error: assert engine + """ + def test_fp16_with_dynamic_shape_four_dimension(self): + class To(torch.nn.Module): + def forward(self, x): + return x.to(torch.float16) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float16, + shape_ranges=[((1, 1, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3))], + ).cuda(), + ] + + self.run_test_with_dynamic_shape( + To(), input_specs, expected_ops={acc_ops.to_dtype} + ) + """ + + def test_fp32(self): + class To(torch.nn.Module): + def forward(self, x): + return x.to(torch.float32) + + input = torch.randn(2, 2).to(torch.float16) + inputs = [ + input, + ] + self.run_test( + To(), inputs, expected_ops={acc_ops.to_dtype}, test_implicit_batch_dim=False + ) + + def test_cuda_fp16(self): + class To(torch.nn.Module): + def forward(self, x): + return x.to(torch.device("cuda:0"), torch.float16) + + input = torch.randn(2, 2) + inputs = [ + input, + ] + self.run_test( + To(), + inputs, + expected_ops={acc_ops.to_dtype}, + test_implicit_batch_dim=False, + precision=LowerPrecision.FP16, + ) + + def test_cuda(self): + class To(torch.nn.Module): + def forward(self, x): + x = x.to(torch.device("cuda")) + # append extra layer since to(device) is skipped in TRT + return x + torch.randn(2, 2).cuda() + + input = torch.randn(2, 2) + inputs = [ + input, + ] + self.run_test( + To(), + inputs, + expected_ops={acc_ops.to_dtype, acc_ops.add}, + test_implicit_batch_dim=False, + precision=LowerPrecision.FP32, + ) + + def test_cuda_with_dynamic_shape_four_dimensions(self): + class To(torch.nn.Module): + def forward(self, x): + x = x.to(torch.device("cuda")) + # append extra layer since to(device) is skipped in TRT + return x + torch.randn(3, 3, 3, 3).cuda() + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float16, + shape_ranges=[((1, 1, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + To(), input_specs, expected_ops={acc_ops.to_dtype, acc_ops.add} + ) + + def test_device(self): + class To(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = torch.randn(2, 2) + + def forward(self, x): + idevice = x.device + a = self.a.to(idevice) + return x + a + + input = torch.randn(2, 2).cuda() + inputs = [ + input, + ] + self.run_test( + To(), + inputs, + expected_ops={acc_ops.to_dtype}, + test_implicit_batch_dim=False, + precision=LowerPrecision.FP32, + ) + + def test_device_with_dynamic_shape_four_dimensions(self): + class To(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = torch.randn(3, 3, 3, 3) + + def forward(self, x): + idevice = x.device + a = self.a.to(idevice) + return x + a + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float16, + shape_ranges=[((1, 1, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + To(), input_specs, expected_ops={acc_ops.to_dtype, acc_ops.add} + ) + + def test_device_fp16(self): + class To(torch.nn.Module): + def 
__init__(self): + super().__init__() + self.a = torch.randn(2, 2) + + def forward(self, x): + idevice = x.device + idtype = x.dtype + a = self.a.to(idevice) + # fx tracer could not handle "to(idevice, torch.float16)" + # TypeError: to() received an invalid combination of arguments - got (Attribute, torch.dtype) + return a.to(idtype) + + input = torch.randn(2, 2).half().cuda() + inputs = [ + input, + ] + self.run_test( + To(), + inputs, + expected_ops={acc_ops.to_dtype}, + test_implicit_batch_dim=False, + precision=LowerPrecision.FP16, + ) + + # Testing with shape shape=(-1, -1, -1, -1) results into following error: + # Error: assert engine + """ + def test_device_fp16_with_dynamic_shape_four_dimensions(self): + class To(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = torch.randn(2, 2) + + def forward(self, x): + idevice = x.device + idtype = x.dtype + a = self.a.to(idevice) + # fx tracer could not handle "to(idevice, torch.float16)" + # TypeError: to() received an invalid combination of arguments - got (Attribute, torch.dtype) + return a.to(idtype) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float16, + shape_ranges=[((2, 2, 2, 2), (4, 4, 4, 4), (4, 4, 4, 4))], + ), + ] + + self.run_test_with_dynamic_shape( + To(), input_specs, expected_ops={acc_ops.to_dtype} + ) + """ + + # tensor.float() + def test_float(self): + class To(torch.nn.Module): + def forward(self, x): + return x.float() + + input = torch.randn(2, 2).half() + inputs = [ + input, + ] + self.run_test( + To(), + inputs, + expected_ops={acc_ops.to_dtype}, + test_implicit_batch_dim=False, + precision=LowerPrecision.FP32, + ) + + # tensor.float() + def test_float_with_dynamic_shape_four_dimensions(self): + class To(torch.nn.Module): + def forward(self, x): + return x.float() + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + To(), input_specs, expected_ops={acc_ops.to_dtype} + ) + + # Half is not suitable for dynamic shape + # Error: assert engine + + # tensor.half() + def test_half(self): + class To(torch.nn.Module): + def forward(self, x): + return x.half() + + input = torch.randn(2, 2) + inputs = [ + input, + ] + self.run_test( + To(), + inputs, + expected_ops={acc_ops.to_dtype}, + test_implicit_batch_dim=False, + precision=LowerPrecision.FP16, + ) + + # TODO Open in future. TRT 8.5 does not work for this test + # The test is a rare case. We need to remove it in graph maybe. 
+ # def test_int(self): + # class To(torch.nn.Module): + # def forward(self, x): + # x = x.int() + # # we do not expect int to be output type, so add an extra layer + # x = x.float() + # return x + + # input = torch.randn(2, 2) + # inputs = [ + # input, + # ] + # self.run_test( + # To(), + # inputs, + # expected_ops={acc_ops.to_dtype}, + # test_implicit_batch_dim=False, + # precision=LowerPrecision.FP32, + # ) + + # # tensor.int() + # def test_int_with_dynamic_shape_four_dimensions(self): + # class To(torch.nn.Module): + # def forward(self, x): + # x = x.int() + # # we do not expect int to be output type, so add an extra layer + # x = x.float() + # return x + + # input_specs = [ + # InputTensorSpec( + # shape=(-1, -1, -1, -1), + # dtype=torch.int, + # shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))], + # ), + # ] + + # self.run_test_with_dynamic_shape( + # To(), input_specs, expected_ops={acc_ops.to_dtype} + # ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_topk.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_topk.py new file mode 100644 index 0000000000..7790857f5a --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_topk.py @@ -0,0 +1,84 @@ +import torch +import torch.nn as nn +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestTopKConverter(AccTestCase): + @parameterized.expand( + [ + ("top1", 1, -1), + ("top2", 2, -1), + ("none_dim", 1, None), + ("smallest", 1, -1, False), + ("top1_dim0", 1, 0, False), + ] + ) + def test_topk(self, _, k, dim, largest=True): + class TopK(nn.Module): + def __init__(self, k, dim): + super().__init__() + self.k = k + self.dim = dim + self.largest = largest + + def forward(self, x): + if self.dim is not None: + out = torch.topk( + x, k=self.k, dim=self.dim, largest=self.largest, sorted=False + ) + else: + out = torch.topk(x, k=self.k, largest=self.largest, sorted=False) + return out[0], out[1] + + inputs = [torch.randn(1, 2, 3, 4)] + self.run_test( + TopK(k, dim), + inputs, + expected_ops={acc_ops.topk}, + test_implicit_batch_dim=(dim != 0), + ) + + @parameterized.expand( + [ + ("top1", 1, -1), + ("top2", 2, -1), + ("none_dim", 1, None), + ("smallest", 1, -1, False), + ("top1_dim0", 1, 0, False), + ] + ) + def test_topk_with_dynamic_shape_four_dimensions(self, _, k, dim, largest=True): + class TopK(nn.Module): + def __init__(self, k, dim): + super().__init__() + self.k = k + self.dim = dim + self.largest = largest + + def forward(self, x): + if self.dim is not None: + out = torch.topk( + x, k=self.k, dim=self.dim, largest=self.largest, sorted=False + ) + else: + out = torch.topk(x, k=self.k, largest=self.largest, sorted=False) + return out[0], out[1] + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + TopK(k, dim), input_specs, expected_ops={acc_ops.topk} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_transpose_convolution.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_transpose_convolution.py new file mode 100644 index 0000000000..934b4c0d81 --- /dev/null +++ 
b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_transpose_convolution.py @@ -0,0 +1,137 @@ +# Owner(s): ["oncall: gpu_enablement"] + +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import param, parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestTransposeConvolutionConverter(AccTestCase): + @parameterized.expand( + [ + ("default", 1), + param("no_bias", 1, bias=False), + ("tuple_parameters", 1, (1, 1), (1, 1)), + param("non_zero_padding", 1, padding=1), + param("dilation", 1, dilation=2), + param("groups", 1, groups=3), + ] + ) + def test_conv_transpose2d( + self, + _, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv_transpose = torch.nn.ConvTranspose2d( + in_channels=3, + out_channels=6, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + bias=bias, + dilation=dilation, + ) + + def forward(self, x): + return self.conv_transpose(x) + + inputs = [torch.randn(1, 3, 224, 224)] + self.run_test(TestModule(), inputs, expected_ops={acc_ops.conv_transpose2d}) + + def test_conv_transpose2d_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv_transpose = torch.nn.ConvTranspose2d(3, 3, 1) + + def forward(self, x): + return self.conv_transpose(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, 3, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 3, 1, 1), (1, 3, 4, 4), (32, 3, 128, 128))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.conv_transpose2d} + ) + + @parameterized.expand( + [ + ("default", 1), + param("no_bias", 1, bias=False), + ("tuple_parameters", 1, (1, 1, 1), (1, 1, 1)), + param("non_zero_padding", 1, padding=1), + param("dilation", 1, dilation=2), + param("groups", 1, groups=3), + ] + ) + def test_conv_transpose3d( + self, + _, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv_transpose = torch.nn.ConvTranspose3d( + in_channels=3, + out_channels=6, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + bias=bias, + dilation=dilation, + ) + + def forward(self, x): + return self.conv_transpose(x) + + inputs = [torch.randn(1, 3, 32, 32, 32)] + self.run_test(TestModule(), inputs, expected_ops={acc_ops.conv_transpose3d}) + + def test_conv_transpose3d_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv_transpose = torch.nn.ConvTranspose3d(3, 6, 1) + + def forward(self, x): + return self.conv_transpose(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, 3, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 3, 1, 1, 1), (1, 3, 4, 4, 4), (8, 3, 32, 32, 32))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.conv_transpose3d} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_type_as.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_type_as.py new file mode 100644 index 0000000000..24f99b5bff --- 
/dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_type_as.py @@ -0,0 +1,150 @@ +import unittest + +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.utils import LowerPrecision + + +class TestTypeAsConverter(AccTestCase): + def test_device_fp32(self): + class Type_as(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = torch.randn(2, 2) + + def forward(self, x): + b = self.a.type_as(x) + return b + + # self.a = self.a.type_as(x) # error is throw + # return self.a + + input = torch.randn(2, 2).cuda() + inputs = [ + input, + ] + self.run_test( + Type_as(), + inputs, + expected_ops={acc_ops.to_dtype, acc_ops.device, acc_ops.dtype}, + test_implicit_batch_dim=False, + ) + + def test_device_fp16(self): + class Type_as(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = torch.randn(2, 2) + + def forward(self, x): + return self.a.type_as(x) + + input = torch.randn(2, 2).half().cuda() + inputs = [ + input, + ] + self.run_test( + Type_as(), + inputs, + expected_ops={acc_ops.to_dtype, acc_ops.device, acc_ops.dtype}, + test_implicit_batch_dim=False, + precision=LowerPrecision.FP16, + ) + + def test_device_fp32_tensor(self): + class Type_as(torch.nn.Module): + def forward(self, input, other): + return other.type_as(input) + + input = torch.randn(2, 2).cuda() + other = torch.randn(2, 2) + inputs = [ + input, + other, + ] + self.run_test( + Type_as(), + inputs, + expected_ops={acc_ops.to_dtype, acc_ops.device, acc_ops.dtype}, + ) + + def test_device_fp16_tensor(self): + class Type_as(torch.nn.Module): + def forward(self, input, other): + return other.type_as(input) + + input = torch.randn(2, 2).half().cuda() + other = torch.randn(2, 2) + inputs = [ + input, + other, + ] + self.run_test( + Type_as(), + inputs, + expected_ops={acc_ops.to_dtype, acc_ops.device, acc_ops.dtype}, + precision=LowerPrecision.FP16, + ) + + def test_type_tensor(self): + class Type_as(torch.nn.Module): + def forward(self, input): + return input.type(dtype=torch.float16) + + input = torch.randn(2, 2) + + inputs = [ + input, + ] + self.run_test( + Type_as(), + inputs, + expected_ops={acc_ops.to_dtype}, + precision=LowerPrecision.FP16, + ) + + @unittest.skip("Does not pass in TRT 8.4.1 T127981773") + def test_type_tensor_with_dynamic_shape_four_dimensions(self): + class Type_as(torch.nn.Module): + def forward(self, input): + return input.type(dtype=torch.float32) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.int, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + Type_as(), + input_specs, + expected_ops={acc_ops.to_dtype}, + ) + + def test_type_tensor_ext(self): + class Type_as(torch.nn.Module): + def forward(self, input, other): + t = input.type() + return other.type(t) + + input = torch.randn(2, 2).to(dtype=torch.float16) + other = torch.randn(2, 2) + + inputs = [ + input, + other, + ] + self.run_test( + Type_as(), + inputs, + expected_ops={acc_ops.to_dtype, acc_ops.dtype}, + precision=LowerPrecision.FP16, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_unary_ops.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_unary_ops.py new file mode 100644 index 0000000000..f88c07c97a --- /dev/null +++ 
b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_unary_ops.py @@ -0,0 +1,165 @@ +from typing import Callable + +import torch +import torch.nn as nn + +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + +unary_ops = [ + (torch.sin, acc_ops.sin, False), + (torch.cos, acc_ops.cos, False), + (torch.tan, acc_ops.tan, False), + (torch.sinh, acc_ops.sinh, False), + (torch.cosh, acc_ops.cosh, False), + (torch.asin, acc_ops.asin, True), + (torch.acos, acc_ops.acos, True), + (torch.atan, acc_ops.atan, True), + (torch.abs, acc_ops.abs, False), + (torch.neg, acc_ops.neg, False), + (torch.reciprocal, acc_ops.reciprocal, False), + (torch.sqrt, acc_ops.sqrt, False), + (torch.log, acc_ops.log, False), + (torch.exp, acc_ops.exp, False), + (torch.floor, acc_ops.floor, False), + (torch.ceil, acc_ops.ceil, False), + (torch.sign, acc_ops.sign, False), +] + + +class TestUnaryOpConverters(AccTestCase): + @parameterized.expand([(op[1].__name__, op[0], op[1], op[2]) for op in unary_ops]) + def test_unary_ops( + self, name, orig_op: Callable, expected_op: Callable, range_req: bool + ): + class TestModule(nn.Module): + def __init__(self, orig_op): + super().__init__() + self.orig_op = orig_op + + def forward(self, x): + return self.orig_op(x) + + m = TestModule(orig_op) + inputs = ( + [torch.distributions.uniform.Uniform(-1, 1).sample([2, 2, 3])] + if range_req + else [torch.randn(2, 2, 3)] + ) + self.run_test(m, inputs, expected_ops={expected_op}) + + +class TestUnaryVOpConvertersWithDynamicShapeFourDimensions(AccTestCase): + @parameterized.expand([(op[1].__name__, op[0], op[1]) for op in unary_ops]) + def test_unary_ops(self, name, orig_op: Callable, expected_op): + class TestModule(nn.Module): + def __init__(self, orig_op): + super().__init__() + self.orig_op = orig_op + + def forward(self, x): + return self.orig_op(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(orig_op), input_specs, expected_ops={expected_op} + ) + + +class TestUnaryOpNotConverters(AccTestCase): + @parameterized.expand( + [ + ("not_bool", torch.logical_not, acc_ops.logical_not, torch.bool), + ("not_float", torch.logical_not, acc_ops.logical_not, torch.float), + ("not_int", torch.logical_not, acc_ops.logical_not, torch.int), + ] + ) + def test_unary_ops(self, name, orig_op: Callable, expected_op, input_dtype): + class TestModule(nn.Module): + def __init__(self, orig_op): + super().__init__() + self.orig_op = orig_op + + def forward(self, x): + x = self.orig_op(x) + return self.orig_op(x) + + m = TestModule(orig_op) + inputs = [torch.randn(2, 2, 3).to(input_dtype)] + self.run_test( + m, inputs, expected_ops={expected_op}, test_implicit_batch_dim=False + ) + + +class TestUnaryOpNotConvertersWithDynamicShapeFourDimensions(AccTestCase): + @parameterized.expand( + [ + ("not_bool", torch.logical_not, acc_ops.logical_not, torch.bool), + ("not_float", torch.logical_not, acc_ops.logical_not, torch.float), + ("not_int", torch.logical_not, acc_ops.logical_not, torch.int), + ] + ) + def test_unary_ops(self, name, orig_op: Callable, expected_op, input_dtype): + class TestModule(nn.Module): + def __init__(self, orig_op): + super().__init__() + self.orig_op = orig_op + + def forward(self, x): + x = 
self.orig_op(x) + return self.orig_op(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(orig_op), input_specs, expected_ops={expected_op} + ) + + +class TestUnaryRSQRTConverters(AccTestCase): + def test_unary_ops(self): + class TestModule(nn.Module): + def forward(self, x): + return torch.rsqrt(x) + + m = TestModule() + inputs = [torch.randn(2, 2, 3)] + self.run_test(m, inputs, expected_ops={acc_ops.sqrt, acc_ops.reciprocal}) + + +class TestUnaryRSQRTConvertersWithDynamicShapeFourDimensions(AccTestCase): + def test_unary_ops(self): + class TestModule(nn.Module): + def forward(self, x): + return torch.rsqrt(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.sqrt, acc_ops.reciprocal} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_unsqueeze.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_unsqueeze.py new file mode 100644 index 0000000000..a422f1b6fe --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_unsqueeze.py @@ -0,0 +1,60 @@ +import torch +import torch.fx +import torch.nn as nn +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec + + +class TestUnsqueeze(AccTestCase): + @parameterized.expand( + [ + ("negative_dim", -2), + ("positive_dim", 2), + ] + ) + def test_unsqueeze(self, _, dim): + class Unsqueeze(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + return torch.unsqueeze(x, self.dim) + + inputs = [torch.randn(1, 2, 3)] + self.run_test(Unsqueeze(dim), inputs, expected_ops={acc_ops.unsqueeze}) + + # Testing with more than one dynamic dims results in following error: + # AssertionError: Currently we don't support unsqueeze with more than one dynamic dims. 
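+    # Illustrative sketch of a spec with two dynamic dims that is expected to
+    # trigger the assertion above (hypothetical values, kept commented out):
+    #
+    # input_specs = [
+    #     InputTensorSpec(
+    #         shape=(-1, -1, 3),  # two dynamic dims -> converter asserts
+    #         dtype=torch.float32,
+    #         shape_ranges=[((1, 1, 3), (2, 2, 3), (3, 3, 3))],
+    #     ),
+    # ]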
+ + @parameterized.expand( + [ + ("negative_dim_dynamic", -4), + ("positive_dim_dynamic", 1), + ] + ) + def test_unsqueeze_with_dynamic_shape(self, _, dim): + class Unsqueeze(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + return torch.unsqueeze(x, self.dim) + + input_specs = [ + InputTensorSpec( + shape=(-1, 2, 3), + dtype=torch.float32, + shape_ranges=[((1, 2, 3), (2, 2, 3), (3, 2, 3))], + ), + ] + self.run_test_with_dynamic_shape( + Unsqueeze(dim), input_specs, expected_ops={acc_ops.unsqueeze} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_where.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_where.py new file mode 100644 index 0000000000..2985042f6b --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_where.py @@ -0,0 +1,114 @@ +import torch +import torch.nn as nn +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase + + +class TestWhere(AccTestCase): + @parameterized.expand( + [ + ("same_shape", (1, 3, 2), (1, 3, 2), (1, 3, 2)), + ("broadcast_shape", (1, 3, 2), (1, 1, 1), (1, 1, 1)), + ("broadcast_shape", (1, 3, 2), (1, 1, 1), (1, 1, 2)), + ] + ) + def test_where(self, _, condition_shape, x_shape, y_shape): + class Where(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, condition, x, y): + return torch.where(condition, x, y) + + inputs = [ + (torch.randn(condition_shape) > 0), + torch.randn(x_shape), + torch.ones(y_shape), + ] + self.run_test( + Where(), + inputs, + expected_ops={acc_ops.where}, + test_implicit_batch_dim=False, + ) + + @parameterized.expand( + [ + ("same_shape", (1, 3, 2), (1, 3, 2), (1, 3, 2)), + ("broadcast_shape", (1, 3, 2), (1, 1, 1), (1, 1, 1)), + ("broadcast_shape", (1, 3, 2), (1, 1, 1), (1, 1, 2)), + ] + ) + def test_where_attribute_condition(self, _, condition_shape, x_shape, y_shape): + class Where(nn.Module): + def __init__(self, condition_shape): + super().__init__() + self.condition = torch.randn(condition_shape) > 0 + + def forward(self, x, y): + return torch.where(self.condition, x, y) + + inputs = [torch.randn(x_shape), torch.ones(y_shape)] + self.run_test( + Where(condition_shape), + inputs, + expected_ops={acc_ops.where}, + test_implicit_batch_dim=False, + ) + + @parameterized.expand( + [ + ("same_shape", (1, 3, 2), (1, 3, 2), (1, 3, 2)), + ("broadcast_shape", (1, 3, 2), (1, 1, 1), (1, 1, 1)), + ("broadcast_shape", (1, 3, 2), (1, 1, 1), (1, 1, 2)), + ] + ) + def test_where_attribute_condition_x(self, _, condition_shape, x_shape, y_shape): + class Where(nn.Module): + def __init__(self, condition_shape, x_shape): + super().__init__() + self.condition = torch.randn(condition_shape) > 0 + self.x = torch.randn(x_shape) + + def forward(self, y): + return torch.where(self.condition, self.x, y) + + inputs = [torch.ones(y_shape)] + self.run_test( + Where(condition_shape, x_shape), + inputs, + expected_ops={acc_ops.where}, + test_implicit_batch_dim=False, + ) + + @parameterized.expand( + [ + ("same_shape", (1, 3, 2), (1, 3, 2), (1, 3, 2)), + ("broadcast_shape", (1, 3, 2), (1, 1, 1), (1, 1, 1)), + ("broadcast_shape", (1, 3, 2), (1, 1, 1), (1, 1, 2)), + ] + ) + def test_where_attribute_x_y(self, _, condition_shape, x_shape, y_shape): + class Where(nn.Module): + def __init__(self, x_shape, y_shape): + 
super().__init__() + + self.x = torch.randn(x_shape) + self.y = torch.ones(y_shape) + + def forward(self, condition): + return torch.where(condition, self.x, self.y) + + inputs = [(torch.randn(condition_shape) > 0)] + self.run_test( + Where(x_shape, y_shape), + inputs, + expected_ops={acc_ops.where}, + test_implicit_batch_dim=False, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/aten_op/test_adaptive_avgpool_aten.py b/py/torch_tensorrt/dynamo/test/converters/aten_op/test_adaptive_avgpool_aten.py new file mode 100644 index 0000000000..b51c9a8f9a --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/aten_op/test_adaptive_avgpool_aten.py @@ -0,0 +1,127 @@ +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestAdaptiveAvgPoolConverter(DispatchTestCase): + def test_adaptive_avgpool_mean(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool2d((1, 1)) + + def forward(self, x): + return self.pool(x) + + inputs = [torch.randn(1, 3, 256, 256)] + self.run_test( + TestModule(), + inputs, + expected_ops={torch.ops.aten.mean.dim}, + ) + + @parameterized.expand( + [ + ((64, 64),), + ((128, 64),), + (64,), + ] + ) + def test_adaptive_avgpool( + self, + output_size, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool2d(output_size) + + def forward(self, x): + return self.pool(x) + + inputs = [torch.randn(1, 3, 256, 256)] + self.run_test( + TestModule(), + inputs, + expected_ops={torch.ops.aten._adaptive_avg_pool2d.default}, + ) + + def test_adaptive_avgpool_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool2d((64, 64)) + + def forward(self, x): + return self.pool(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, 256, 256), + dtype=torch.float32, + shape_ranges=[((1, 1, 256, 256), (3, 3, 256, 256), (5, 5, 256, 256))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), + input_specs, + expected_ops={torch.ops.aten._adaptive_avg_pool2d.default}, + ) + + @parameterized.expand( + [ + ((16, 16, 16),), + ((32, 16, 4),), + (32,), + ] + ) + def test_adaptive_avgpool3d( + self, + output_size, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool3d(output_size) + + def forward(self, x): + return self.pool(x) + + inputs = [torch.randn(1, 3, 32, 64, 64)] + self.run_test( + TestModule(), + inputs, + expected_ops={torch.ops.aten._adaptive_avg_pool3d.default}, + ) + + def test_adaptive_avgpool3d_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool3d((16, 16, 16)) + + def forward(self, x): + return self.pool(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, 32, 64, 64), + dtype=torch.float32, + shape_ranges=[ + ((1, 1, 32, 64, 64), (3, 3, 32, 64, 64), (5, 5, 32, 64, 64)) + ], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), + input_specs, + expected_ops={torch.ops.aten._adaptive_avg_pool3d.default}, + ) + + # Testing with shape(-1, -1, -1, -1) results into error: "AdaptiveAvgPool2d and AdaptiveAvgPool3d 
currently doesn't support dynamic shapes for last two dims." + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/aten_op/test_batchnorm_aten.py b/py/torch_tensorrt/dynamo/test/converters/aten_op/test_batchnorm_aten.py new file mode 100644 index 0000000000..aed68ba35f --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/aten_op/test_batchnorm_aten.py @@ -0,0 +1,65 @@ +import torch +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestBatchNormConverter(DispatchTestCase): + def test_batchnorm(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.bn = torch.nn.BatchNorm2d(3) + + def forward(self, x): + return self.bn(x) + + inputs = [torch.randn(1, 3, 224, 224)] + self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.batch_norm}) + + def test_batchnorm1d_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.bn = torch.nn.BatchNorm1d(3) + + def forward(self, x): + return self.bn(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, 3, 5), + dtype=torch.float32, + shape_ranges=[((2, 3, 5), (6, 3, 5), (10, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.batch_norm} + ) + + def test_batchnorm_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.bn = torch.nn.BatchNorm2d(3) + + def forward(self, x): + return self.bn(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, 3, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 3, 1, 1), (1, 3, 5, 5), (2, 3, 10, 10))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.batch_norm} + ) + + # Testing with shape=(-1, -1, -1, -1) results in AssertionError: Channel dim can't be dynamic for batch norm. 
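+    # Illustrative sketch (hypothetical values, kept commented out): declaring
+    # dim 1 dynamic is what trips that assertion, since batch norm needs a
+    # static channel count to match its per-channel weights.
+    #
+    # input_specs = [
+    #     InputTensorSpec(
+    #         shape=(-1, -1, -1, -1),  # channel dim dynamic -> asserts
+    #         dtype=torch.float32,
+    #         shape_ranges=[((1, 3, 1, 1), (1, 3, 5, 5), (2, 3, 10, 10))],
+    #     ),
+    # ]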
+ + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/aten_op/test_binary_ops_aten.py b/py/torch_tensorrt/dynamo/test/converters/aten_op/test_binary_ops_aten.py new file mode 100644 index 0000000000..b80cd514c1 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/aten_op/test_binary_ops_aten.py @@ -0,0 +1,205 @@ +from typing import Callable + +import torch +import torch.nn as nn + +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + +NEED_TEST_BOTH_CONSTANTS_CASE = True + +elementwise_ops = [ + ((lambda x, y: x + y), torch.ops.aten.add.Tensor, NEED_TEST_BOTH_CONSTANTS_CASE), + ( + (lambda x, y: torch.add(x, y)), + torch.ops.aten.add.Tensor, + NEED_TEST_BOTH_CONSTANTS_CASE, + ), + ((lambda x, y: x.add(y)), torch.ops.aten.add.Tensor, NEED_TEST_BOTH_CONSTANTS_CASE), + ((lambda x, y: x - y), torch.ops.aten.sub.Tensor, NEED_TEST_BOTH_CONSTANTS_CASE), + ((lambda x, y: torch.sub(x, y)), torch.ops.aten.sub.Tensor, False), + ((lambda x, y: x.sub(y)), torch.ops.aten.sub.Tensor, False), + ((lambda x, y: x / y), torch.ops.aten.div.Tensor, NEED_TEST_BOTH_CONSTANTS_CASE), + ( + (lambda x, y: x // y), + torch.ops.aten.floor_divide.default, + NEED_TEST_BOTH_CONSTANTS_CASE, + ), + ( + (lambda x, y: torch.div(x, y, rounding_mode="trunc")), + torch.ops.aten.div.Tensor_mode, + not NEED_TEST_BOTH_CONSTANTS_CASE, + ), + ( + (lambda x, y: torch.div(x, y, rounding_mode="floor")), + torch.ops.aten.div.Tensor_mode, + NEED_TEST_BOTH_CONSTANTS_CASE, + ), + ( + (lambda x, y: torch.div(x, y)), + torch.ops.aten.div.Tensor, + NEED_TEST_BOTH_CONSTANTS_CASE, + ), + ( + (lambda x, y: torch.fmod(x, y)), + torch.ops.aten.fmod.Tensor, + not NEED_TEST_BOTH_CONSTANTS_CASE, + ), + ## torch.floor_divide rounds result toward zero, rather than -Inf. + ## https://github.com/pytorch/pytorch/issues/43874 + ( + (lambda x, y: torch.floor_divide(x, y)), + torch.ops.aten.floor_divide.default, + not NEED_TEST_BOTH_CONSTANTS_CASE, + ), + ((lambda x, y: x * y), torch.ops.aten.mul.Tensor, NEED_TEST_BOTH_CONSTANTS_CASE), + (torch.pow, torch.ops.aten.pow.Tensor_Tensor, not NEED_TEST_BOTH_CONSTANTS_CASE), +] + + +class TestBinaryOpConverters(DispatchTestCase): + @parameterized.expand([(op[1].__name__, op[0], op[1]) for op in elementwise_ops]) + def test_elementwise_ops(self, name, orig_op: Callable, expected_op): + class TestModule(nn.Module): + def __init__(self, orig_op): + super().__init__() + self.orig_op = orig_op + + def forward(self, x): + return self.orig_op(x, x) + + m = TestModule(orig_op) + # Avoid dividing by 0. 
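+        # (torch.rand draws from [0, 1), so adding 1 keeps every element in
+        # [1, 2) and safely away from zero for the division-based ops.)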
+ inputs = [torch.rand(1, 1) + 1] + self.run_test(m, inputs, expected_ops={expected_op}) + + @parameterized.expand([(op[1].__name__, op[0], op[1]) for op in elementwise_ops]) + def test_elementwise_ops_with_one_constant( + self, name, orig_op: Callable, expected_op + ): + class TestModule(nn.Module): + def __init__(self, orig_op): + super().__init__() + self.constant = torch.randn(1) + self.orig_op = orig_op + + def forward(self, x): + x = self.orig_op(x, self.constant) + return self.orig_op(x, -2) + + m = TestModule(orig_op) + inputs = [torch.randn(2, 2)] + self.run_test(m, inputs, expected_ops={expected_op}) + + @parameterized.expand( + [(op[1].__name__, op[0], op[1]) for op in elementwise_ops if op[2]] + ) + def test_elementwise_op_with_both_constants( + self, name, orig_op: Callable, expected_op + ): + class TestModule(nn.Module): + def __init__(self, orig_op): + super().__init__() + self.constant0 = torch.nn.Parameter(torch.randn(1)) + self.constant1 = torch.nn.Parameter(torch.randn(1)) + self.orig_op = orig_op + + def forward(self, x): + const = self.orig_op(self.constant0, self.constant1) + return self.orig_op(x, const) + + m = TestModule(orig_op) + inputs = [torch.randn(2, 2)] + self.run_test(m, inputs, expected_ops={expected_op}) + + # Dynamic shape test + @parameterized.expand( + [ + ( + f"no_broadcast_{op[1].__name__}", + (-1, -1), + ((1, 1), (2, 2), (3, 3)), + (-1, -1), + ((1, 1), (2, 2), (3, 3)), + op[0], + op[1], + ) + for op in elementwise_ops + ] + + [ + ( + f"broadcast_{op[1].__name__}", + (-1, -1, -1), + ((1, 1, 1), (2, 2, 2), (3, 3, 3)), + (-1, -1), + ((1, 1), (2, 2), (3, 3)), + op[0], + op[1], + ) + for op in elementwise_ops + ] + ) + def test_elementwise_op_with_dynamic_shape( + self, _, x_shape, x_shape_ranges, y_shape, y_shape_ranges, orig_op, expected_op + ): + class Op(nn.Module): + def forward(self, x, y): + return orig_op(x, y) + + input_specs = [ + InputTensorSpec( + shape=x_shape, + dtype=torch.float32, + shape_ranges=[x_shape_ranges], + ), + InputTensorSpec( + shape=y_shape, + dtype=torch.float32, + shape_ranges=[y_shape_ranges], + ), + ] + self.run_test_with_dynamic_shape(Op(), input_specs, expected_ops={expected_op}) + + @parameterized.expand( + [ + ( + f"no_broadcast_{op[1].__name__}", + op[0], + op[1], + ) + for op in elementwise_ops + ] + + [ + ( + f"broadcast_{op[1].__name__}", + op[0], + op[1], + ) + for op in elementwise_ops + ] + ) + def test_elementwise_op_with_dynamic_shape_four_dimensions( + self, _, orig_op, expected_op + ): + class Op(nn.Module): + def forward(self, x, y): + return orig_op(x, y) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))], + ), + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))], + ), + ] + self.run_test_with_dynamic_shape(Op(), input_specs, expected_ops={expected_op}) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/aten_op/test_cat_aten.py b/py/torch_tensorrt/dynamo/test/converters/aten_op/test_cat_aten.py new file mode 100644 index 0000000000..1d181c0442 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/aten_op/test_cat_aten.py @@ -0,0 +1,58 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class 
TestCatConverter(DispatchTestCase): + @parameterized.expand( + [ + ("pos", 1), + # ("neg", -2), #Dynamo tracer issue + ] + ) + def test_cat(self, _, dim): + class Cat(nn.Module): + def forward(self, x, y, z): + return torch.cat((x, y, z), dim) + + inputs = [torch.randn(1, 2, 3), torch.randn(1, 1, 3), torch.randn(1, 3, 3)] + self.run_test( + Cat(), + inputs, + expected_ops={torch.ops.aten.cat.default}, + ) + + @parameterized.expand( + [ + ("pos", 1), + # ("neg", -2), #Dynamo tracer issue + ] + ) + def test_cat_dynamic_shape(self, _, dim): + class Cat(nn.Module): + def forward(self, x, y): + return torch.cat((x, y), dim) + + input_specs = [ + InputTensorSpec( + shape=(16, -1, 3), + dtype=torch.float32, + shape_ranges=[((16, 2, 3), (16, 3, 3), (16, 32, 3))], + ), + InputTensorSpec( + shape=(16, -1, 3), + dtype=torch.float32, + shape_ranges=[((16, 2, 3), (16, 16, 3), (16, 32, 3))], + ), + ] + self.run_test_with_dynamic_shape( + Cat(), + input_specs, + expected_ops={torch.ops.aten.cat.default}, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/aten_op/test_convolution_aten.py b/py/torch_tensorrt/dynamo/test/converters/aten_op/test_convolution_aten.py new file mode 100644 index 0000000000..60971038fa --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/aten_op/test_convolution_aten.py @@ -0,0 +1,203 @@ +import torch +from parameterized import param, parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestConvolutionConverter(DispatchTestCase): + @parameterized.expand( + [ + ("default", 1), + param("no_bias", 1, bias=False), + ("tuple_parameters", 1, (1), (1)), + param("non_zero_padding", 1, padding=1), + param("dilation", 1, dilation=2), + param("groups", 1, groups=3), + ] + ) + def test_conv1d( + self, + _, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv1d( + 3, 6, kernel_size, stride, padding, dilation, groups, bias + ) + + def forward(self, x): + return self.conv(x) + + inputs = [torch.randn(1, 3, 32)] + self.run_test( + TestModule(), + inputs, + expected_ops={torch.ops.aten.convolution.default}, + test_explicit_precision=True, + ) + + def test_conv1d_with_dynamic_shape( + self, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv1d( + 3, 6, kernel_size, stride, padding, dilation, groups, bias + ) + + def forward(self, x): + return self.conv(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, 3, 3), + dtype=torch.float32, + shape_ranges=[((1, 3, 3), (3, 3, 3), (5, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.convolution.default} + ) + + @parameterized.expand( + [ + ("default", 1), + param("no_bias", 1, bias=False), + ("tuple_parameters", 1, (1, 1), (1, 1)), + param("non_zero_padding", 1, padding=1), + param("dilation", 1, dilation=2), + param("groups", 1, groups=3), + ] + ) + def test_conv2d( + self, + _, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d( + 3, + 6, + kernel_size, + stride, + padding, + dilation, + 
groups, + bias, + ) + + def forward(self, x): + return self.conv(x) + + inputs = [torch.randn(1, 3, 32, 32)] + self.run_test( + TestModule(), inputs, expected_ops={torch.ops.aten.convolution.default} + ) + + # Testing with (-1, -1, -1, -1) results into Error: + # AssertionError: Channel dim can't be dynamic for convolution. + + def test_conv2d_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 6, 1) + + def forward(self, x): + return self.conv(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, 3, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 3, 1, 1), (1, 3, 4, 4), (32, 3, 128, 128))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.convolution.default} + ) + + @parameterized.expand( + [ + ("default", 1), + param("no_bias", 1, bias=False), + ("tuple_parameters", 1, (1, 1, 1), (1, 1, 1)), + param("non_zero_padding", 1, padding=1), + param("dilation", 1, dilation=2), + ## TODO TRT 8.4.1 will trigger issue with this test. T127981773 + # param("groups", 1, groups=3), + ] + ) + def test_conv3d( + self, + _, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv3d( + 3, 6, kernel_size, stride, padding, dilation, groups, bias + ) + + def forward(self, x): + return self.conv(x) + + inputs = [torch.randn(1, 3, 32, 32, 32)] + self.run_test( + TestModule(), inputs, expected_ops={torch.ops.aten.convolution.default} + ) + + # Testing with (-1, -1, -1, -1, -1) results into Error: + # AssertionError: Channel dim can't be dynamic for convolution. + + def test_conv3d_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv3d(3, 6, 1) + + def forward(self, x): + return self.conv(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, 3, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 3, 1, 1, 1), (1, 3, 4, 4, 4), (8, 3, 32, 32, 32))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.convolution.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/aten_op/test_expand_aten.py b/py/torch_tensorrt/dynamo/test/converters/aten_op/test_expand_aten.py new file mode 100644 index 0000000000..380cdc4db3 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/aten_op/test_expand_aten.py @@ -0,0 +1,31 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import DispatchTestCase + + +class TestExpandConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d_dim", (2, 3), (2, 1)), + ("3d_dim", (2, 3, 4), (2, 1, 1)), + ("4d_dim", (2, 3, 4, 5), (2, 1, 1, 1)), + ("keep_dim", (2, 3, -1, -1), (2, 1, 5, 5)), + ] + ) + def test_expand(self, _, sizes, init_size): + class Expand(nn.Module): + def forward(self, x): + return x.expand(*sizes) + + inputs = [torch.randn(*init_size)] + self.run_test( + Expand(), + inputs, + expected_ops={torch.ops.aten.expand.default}, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/aten_op/test_flatten_aten.py b/py/torch_tensorrt/dynamo/test/converters/aten_op/test_flatten_aten.py new file mode 100644 index 
0000000000..b1b0b584f0
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/test/converters/aten_op/test_flatten_aten.py
@@ -0,0 +1,70 @@
+import unittest
+
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.dynamo.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestFlattenConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("flatten_middle_dims", 1, 2),
+            ("flatten_last_3_dims", 1, 3),
+            ("flatten_all", 0, 3),
+        ]
+    )
+    @unittest.skip("Not supported yet")
+    def test_flatten(self, _, start_dim, end_dim):
+        class Flatten(nn.Module):
+            def __init__(self, start, end):
+                super().__init__()
+                self.start = start
+                self.end = end
+
+            def forward(self, x):
+                return torch.flatten(x, self.start, self.end)
+
+        inputs = [torch.randn(1, 2, 3, 1)]
+        self.run_test(
+            Flatten(start_dim, end_dim),
+            inputs,
+            expected_ops={torch.ops.aten.view.default},
+        )
+
+    ## Dynamic shape does not work: flatten is converted to reshape during tracing, so the batch (or any dynamic) dimension is baked in as a fixed integer and the dynamism is lost.
+    ## For example, flattening (1, 512, 1, 1) with start_dim=1, end_dim=-1 becomes a reshape with output size (1, 512), which is incorrect because dim=0 should stay -1 (dynamic).
+    ## This could be solved with dynamic shape propagation: knowing that dim=0 is dynamic, the converter could emit -1 for it.
+
+    # @parameterized.expand(
+    #     [
+    #         ("flatten_middle_dims", 1, 2),
+    #     ]
+    # )
+    # def test_flatten_with_dynamic_shape(self, _, start_dim, end_dim):
+    #     class Flatten(nn.Module):
+    #         def __init__(self, start, end):
+    #             super().__init__()
+    #             self.start = start
+    #             self.end = end

+    #         def forward(self, x):
+    #             return torch.flatten(x, self.start, self.end)

+    #     input_specs = [
+    #         InputTensorSpec(
+    #             shape=(-1, -1, -1, -1, -1),
+    #             dtype=torch.float32,
+    #             shape_ranges=[((1, 1, 1, 1, 1), (1, 2, 3, 2, 1), (3, 3, 3, 3, 3))],
+    #         ),
+    #     ]
+    #     self.run_test_with_dynamic_shape(
+    #         Flatten(start_dim, end_dim),
+    #         input_specs,
+    #         expected_ops={torch.ops.aten._reshape_alias.default},
+    #     )


+if __name__ == "__main__":
+    run_tests()
diff --git a/py/torch_tensorrt/dynamo/test/converters/aten_op/test_linear_aten.py b/py/torch_tensorrt/dynamo/test/converters/aten_op/test_linear_aten.py
new file mode 100644
index 0000000000..5c06035e37
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/test/converters/aten_op/test_linear_aten.py
@@ -0,0 +1,71 @@
+import torch
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.dynamo.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestLinearConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("default", [1, 512], True, torch.ops.aten.linear),
+            ("matrix", [5, 512], True, torch.ops.aten.linear),
+            ("no_bias", [1, 512], False, torch.ops.aten.linear),
+            (
+                "multi_dim_matrix",
+                [4, 5, 512],
+                True,
+                torch.ops.aten.linear,
+            ),
+            (
+                "multi_dim_matrix",
+                [4, 5, 512],
+                False,
+                torch.ops.aten.linear,
+            ),
+        ]
+    )
+    def test_linear(self, test_name, shape, bias, op):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(512, 256, bias)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        inputs = [torch.randn(shape)]
+        self.run_test(TestModule(), inputs, expected_ops={op})
+
+    # linear will be decomposed (see P531484488) and view (reshape) cannot handle a reshape
+    # pattern like (2, 3, n) -> (6, n) in implicit batch mode, which is similar to the dynamic shape test below.
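The limitation described in the comment above can be reproduced outside the converter. A minimal, hedged illustration in plain PyTorch (not the converter's actual code path): a shape-specializing trace records the view's target shape as literal integers, so the dimension that was meant to stay dynamic is frozen.

    import torch

    # Illustration only: the decomposed linear records literal sizes, as in the
    # [3, 3, 512] -> [9, 512] view mentioned below, so the 9 is baked in.
    x = torch.randn(3, 3, 512)
    frozen = x.reshape(9, 512)    # what a shape-specializing trace records
    dynamic = x.reshape(-1, 512)  # the form a dynamic dim=0 would require
    assert torch.equal(frozen, dynamic)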
+
+    # The input is flattened through view [3, 3, 512] -> [9, 512], so the converter no longer knows that dim=0 is dynamic.
+
+    # def test_linear_with_dynamic_shape(self):
+    #     class TestModule(torch.nn.Module):
+    #         def __init__(self):
+    #             super().__init__()
+    #             self.linear = torch.nn.Linear(512, 256)

+    #         def forward(self, x):
+    #             return self.linear(x)

+    #     input_specs = [
+    #         InputTensorSpec(
+    #             shape=(-1, 3, 512),
+    #             dtype=torch.float32,
+    #             shape_ranges=[((1, 3, 512), (3, 3, 512), (4, 3, 512))],
+    #         ),
+    #     ]
+    #     self.run_test_with_dynamic_shape(
+    #         TestModule(),
+    #         input_specs,
+    #         expected_ops={torch.ops.aten.addmm.default},
+    #     )
+
+    ## Testing with (-1, -1, 512) results in the following error:
+    ## AssertionError: Currently we only support one dynmaic dim for linear and it can't be the last dim.
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/py/torch_tensorrt/dynamo/test/converters/aten_op/test_maxpool_aten.py b/py/torch_tensorrt/dynamo/test/converters/aten_op/test_maxpool_aten.py
new file mode 100644
index 0000000000..5a121f0e07
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/test/converters/aten_op/test_maxpool_aten.py
@@ -0,0 +1,245 @@
+import unittest
+
+import torch
+import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
+from parameterized import param, parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.dynamo.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestMaxPoolConverter(DispatchTestCase):
+    # TODO: max_pool1d. It needs support for squeeze and unsqueeze.
+
+    @parameterized.expand(
+        [
+            ("default", 1),
+            ("stride", 1, 2),
+            ("tuple_parameters", 2, (1, 1), (1, 1)),
+            param("padding", 2, padding=1),
+            param("ceil_mode", 1, ceil_mode=True),
+        ]
+    )
+    def test_max_pool2d(
+        self,
+        test_name,
+        kernel_size,
+        stride=1,
+        padding=0,
+        ceil_mode=False,
+    ):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.max_pool = torch.nn.MaxPool2d(
+                    kernel_size, stride, padding, ceil_mode=ceil_mode
+                )
+
+            def forward(self, x):
+                return self.max_pool(x)
+
+        inputs = [torch.randn(1, 3, 224, 224)]
+        self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.max_pool2d})
+
+    def test_max_pool2d_with_dynamic_shape(
+        self,
+    ):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.max_pool = torch.nn.MaxPool2d(1, 1)
+
+            def forward(self, x):
+                return self.max_pool(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1, 1), (1, 2, 4, 4), (2, 4, 4, 4))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(),
+            input_specs,
+            expected_ops={torch.ops.aten.max_pool2d},
+        )
+
+    @parameterized.expand(
+        [
+            ("default", 1),
+            # ("stride", 1, 2),
+            # ("tuple_parameters", 2, (1, 1, 1), (1, 1, 1)),
+            # param("padding", 2, padding=1),
+            # param("ceil_mode", 1, ceil_mode=True),
+        ]
+    )
+    @unittest.skip("PT2 tracer issue")
+    def test_max_pool3d(
+        self,
+        test_name,
+        kernel_size,
+        stride=1,
+        padding=0,
+        ceil_mode=False,
+    ):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.max_pool = torch.nn.MaxPool3d(
+                    kernel_size, stride, padding, ceil_mode=ceil_mode
+                )
+
+            def forward(self, x):
+                return self.max_pool(x)
+
+        inputs = [torch.randn(1, 3, 32, 32, 32)]
+        self.run_test(TestModule(), inputs, expected_ops={})
+
+    @unittest.skip("PT2 tracer issue")
+    def test_max_pool3d_with_dynamic_shape(self):
+        class 
TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.max_pool = torch.nn.MaxPool3d(1, 1) + + def forward(self, x): + return self.max_pool(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1, 1), (1, 2, 4, 4, 4), (2, 4, 4, 4, 4))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.max_pool3d} + ) + + @parameterized.expand( + [ + ("default", 1), + # param("stride", 2, stride=()), #PT2 tracer issue + ] + ) + def test_stride_none_max_pool2d( + self, + test_name, + kernel_size, + stride=None, + padding=0, + ceil_mode=False, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.nn.functional.max_pool2d( + x, kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode + ) + + inputs = [torch.randn(1, 3, 224, 224)] + self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.max_pool2d}) + + @parameterized.expand( + [ + ("default", 1), + param("stride", 2, stride=()), + ] + ) + @unittest.skip("PT2 tracer issue") + def test_stride_none_max_pool3d( + self, + test_name, + kernel_size, + stride=None, + padding=0, + ceil_mode=False, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.nn.functional.max_pool3d( + x, kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode + ) + + inputs = [torch.randn(1, 3, 32, 32, 32)] + self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.max_pool3d}) + + @parameterized.expand( + [ + ("default", 1), + param("stride", 2, stride=()), + ] + ) + def test_stride_none_max_pool2d_with_dynamic_shape( + self, + test_name, + kernel_size, + stride=None, + padding=0, + ceil_mode=False, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.nn.functional.max_pool2d( + x, kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode + ) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (1, 2, 4, 4), (2, 4, 4, 4))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.max_pool2d} + ) + + @parameterized.expand( + [ + ("default", 1), + param("stride", 2, stride=()), + ] + ) + @unittest.skip("PT2 tracer issue") + def test_stride_none_max_pool3d_with_dynamic_shape( + self, + test_name, + kernel_size, + stride=None, + padding=0, + ceil_mode=False, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.nn.functional.max_pool3d( + x, kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode + ) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1, 1), (1, 2, 4, 4, 4), (2, 4, 4, 4, 4))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.max_pool3d} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/aten_op/test_relu_aten.py b/py/torch_tensorrt/dynamo/test/converters/aten_op/test_relu_aten.py new file mode 100644 index 0000000000..fb7fe2f509 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/aten_op/test_relu_aten.py @@ -0,0 +1,51 @@ +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import 
run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestReLUConverter(DispatchTestCase): + def test_relu(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.relu(x) + + inputs = [torch.randn(1, 10)] + self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.relu.default}) + + def test_relu_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.relu(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.relu.default} + ) + + def test_relu_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.relu(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.relu.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/aten_op/test_reshape_aten.py b/py/torch_tensorrt/dynamo/test/converters/aten_op/test_reshape_aten.py new file mode 100644 index 0000000000..09dcb65ab1 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/aten_op/test_reshape_aten.py @@ -0,0 +1,102 @@ +import unittest + +import tensorrt as trt +import torch +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestReshapeConverter(DispatchTestCase): + @parameterized.expand( + [ + ((1, 20),), + ((1, 10, -1),), + ] + ) + @unittest.skipIf( + trt.__version__ < "8.5", + "Shape tensor supported well in TensorRT 8.5 and later", + ) + def test_reshape(self, target_shape): + class TestModule(torch.nn.Module): + def __init__(self, target_shape): + super().__init__() + self.target_shape = target_shape + + def forward(self, x): + return torch.reshape(x, self.target_shape) + + inputs = [torch.randn(1, 2, 10)] + self.run_test( + TestModule(target_shape), + inputs, + expected_ops={torch.ops.aten.view.default}, + ) + + @parameterized.expand( + [ + ((-1, 10),), + ((-1, 5),), + ((2, 2, -1),), + ] + ) + @unittest.skipIf( + trt.__version__ < "8.5", + "Shape tensor supported well in TensorRT 8.5 and later", + ) + def test_reshape_with_dynamic_shape(self, target_shape): + class TestModule(torch.nn.Module): + def __init__(self, target_shape): + super().__init__() + self.target_shape = target_shape + + def forward(self, x): + return torch.reshape(x, self.target_shape) + + input_specs = [ + InputTensorSpec( + shape=(-1, 2, 5), + dtype=torch.float32, + shape_ranges=[((1, 2, 5), (10, 2, 5), (10, 2, 5))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(target_shape), + input_specs, + expected_ops={torch.ops.aten.view.default}, + ) + + @unittest.skipIf( + trt.__version__ < "8.5", + "Shape tensor supported well in TensorRT 8.5 and later", + ) + def test_reshape_with_dynamic_shape_size(self): + class TestModule(torch.nn.Module): + def forward(self, x, y): + shape_y = y.shape + t = shape_y[1] + return torch.reshape(x, [-1, t, 3]) + + input_specs = [ + InputTensorSpec( + shape=(-1, 5, 6), + dtype=torch.float32, + shape_ranges=[((1, 5, 6), (3, 5, 6), (3, 5, 
6))], + ), + InputTensorSpec( + shape=(-1, 5), + dtype=torch.float32, + shape_ranges=[((1, 5), (3, 5), (3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), + input_specs, + expected_ops={torch.ops.aten.view.default}, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/vanilla/test_add_vanilla.py b/py/torch_tensorrt/dynamo/test/converters/vanilla/test_add_vanilla.py new file mode 100644 index 0000000000..6f805421f4 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/vanilla/test_add_vanilla.py @@ -0,0 +1,28 @@ +# Owner(s): ["oncall: gpu_enablement"] + +import operator + +import torch +import torch.fx +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import VanillaTestCase + + +class TestAddConverter(VanillaTestCase): + def test_operator_add(self): + def add(x): + return x + x + + inputs = [torch.randn(1, 1)] + self.run_test(add, inputs, expected_ops={operator.add}) + + def test_torch_add(self): + def add(x): + return torch.add(x, x) + + inputs = [torch.randn(1, 1)] + self.run_test(add, inputs, expected_ops={torch.add}) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/converters/vanilla/test_convolution_vanilla.py b/py/torch_tensorrt/dynamo/test/converters/vanilla/test_convolution_vanilla.py new file mode 100644 index 0000000000..c73c30a30e --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/converters/vanilla/test_convolution_vanilla.py @@ -0,0 +1,113 @@ +# Owner(s): ["oncall: gpu_enablement"] + +import torch +import torch.fx +from parameterized import param, parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.tools.common_fx2trt import VanillaTestCase + + +class TestConvolutionConverter(VanillaTestCase): + @parameterized.expand( + [ + ("default", 1), + param("no_bias", 1, bias=False), + ("tuple_parameters", 1, (1), (0)), + param("non_zero_padding", 1, padding=1), + param("dilation", 1, dilation=2), + param("groups", 1, groups=3), + ] + ) + def test_conv1d( + self, + test_name, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv1d( + 3, 6, kernel_size, stride, padding, dilation, groups, bias + ) + + def forward(self, x): + return self.conv(x) + + inputs = [torch.randn(1, 3, 224)] + self.run_test(TestModule(), inputs, expected_ops={torch.nn.modules.conv.Conv1d}) + + @parameterized.expand( + [ + ("default", 1), + param("no_bias", 1, bias=False), + ("tuple_parameters", 1, (1, 1), (0, 0)), + param("non_zero_padding", 1, padding=1), + param("dilation", 1, dilation=2), + param("groups", 1, groups=3), + ] + ) + def test_conv2d( + self, + test_name, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d( + 3, 6, kernel_size, stride, padding, dilation, groups, bias + ) + + def forward(self, x): + return self.conv(x) + + inputs = [torch.randn(1, 3, 224, 224)] + self.run_test(TestModule(), inputs, expected_ops={torch.nn.modules.conv.Conv2d}) + + @parameterized.expand( + [ + ("default", 1), + param("no_bias", 1, bias=False), + ("tuple_parameters", 1, (1, 1, 1), (0, 0, 0)), + param("non_zero_padding", 1, padding=1), + param("dilation", 1, dilation=2), + # TODO: Enable this when TRT fixes 
https://github.com/pytorch/TensorRT/issues/1445 + # param("groups", 1, groups=3), + ] + ) + def test_conv3d( + self, + test_name, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv3d( + 3, 6, kernel_size, stride, padding, dilation, groups, bias + ) + + def forward(self, x): + return self.conv(x) + + inputs = [torch.randn(1, 3, 32, 32, 32)] + self.run_test(TestModule(), inputs, expected_ops={torch.nn.modules.conv.Conv3d}) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/core/test_import_fx2trt.py b/py/torch_tensorrt/dynamo/test/core/test_import_fx2trt.py new file mode 100644 index 0000000000..7373ddc4fe --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/core/test_import_fx2trt.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python3 +# (c) Facebook, Inc. and its affiliates. Confidential and proprietary. + +# Owner(s): ["oncall: gpu_enablement"] + +# Test that this import should not trigger any error when run +# in non-GPU hosts, or in any build mode. +import torch_tensorrt.dynamo.lower as fxl # noqa: F401 +from torch.testing._internal.common_utils import run_tests, TestCase + + +class MainTests(TestCase): + def test_1(self): + pass + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/core/test_input.py b/py/torch_tensorrt/dynamo/test/core/test_input.py new file mode 100644 index 0000000000..efe323c691 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/core/test_input.py @@ -0,0 +1,88 @@ +# Owner(s): ["oncall: gpu_enablement"] + +import io +import os + +import torch +import torch_tensorrt +from torch.testing._internal.common_utils import run_tests, TestCase + +class TestInput(TestCase): + def test_add_model(self): + class TestModule(torch.nn.Module): + def forward(self, x): + return x + x + + inputs = [torch_tensorrt.Input(shape=(1, 3, 3, 4), dtype=torch.float32)] + rand_inputs = [torch.randn((1, 3, 3, 4), dtype=torch.float32).cuda()] + mod = TestModule().cuda().eval() + ref_output = mod(*rand_inputs) + + trt_mod = torch_tensorrt.compile( + mod, + ir="dynamo", + inputs=inputs, + min_block_size=1, + ) + trt_output = trt_mod(*rand_inputs) + + torch.testing.assert_close(trt_output, ref_output, rtol=1e-04, atol=1e-04) + + def test_conv_model(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 6, 1, 1, 1, 1, 1, True) + + def forward(self, x): + return self.conv(x) + + inputs = [torch_tensorrt.Input(shape=(1, 3, 32, 32), dtype=torch.float32)] + rand_inputs = [torch.randn((1, 3, 32, 32), dtype=torch.float32).cuda()] + mod = TestModule().cuda().eval() + ref_output = mod(*rand_inputs) + + trt_mod = torch_tensorrt.compile( + mod, + ir="dynamo", + inputs=inputs, + min_block_size=1, + ) + trt_output = trt_mod(*rand_inputs) + + torch.testing.assert_close(trt_output, ref_output, rtol=1e-04, atol=1e-04) + + def test_conv_model_with_dyn_shapes(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 6, 1, 1, 1, 1, 1, True) + + def forward(self, x): + return self.conv(x) + + inputs = [ + torch_tensorrt.Input( + min_shape=(1, 3, 32, 32), + opt_shape=(8, 3, 32, 32), + max_shape=(16, 3, 32, 32), + dtype=torch.float32, + ) + ] + rand_inputs = [torch.randn((4, 3, 32, 32), dtype=torch.float32).cuda()] + mod = TestModule().cuda().eval() + ref_output = mod(*rand_inputs) + + trt_mod = 
torch_tensorrt.compile( + mod, + ir="dynamo", + inputs=inputs, + min_block_size=1, + ) + trt_output = trt_mod(*rand_inputs) + + torch.testing.assert_close(trt_output, ref_output, rtol=1e-04, atol=1e-04) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/core/test_input_tensor_spec.py b/py/torch_tensorrt/dynamo/test/core/test_input_tensor_spec.py new file mode 100644 index 0000000000..89fbafe82b --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/core/test_input_tensor_spec.py @@ -0,0 +1,93 @@ +# Owner(s): ["oncall: gpu_enablement"] + +from typing import List, Optional + +import torch +from torch.testing._internal.common_utils import run_tests, TestCase +from torch_tensorrt.dynamo import generate_input_specs, InputTensorSpec, LowerSetting + + +class TestTRTModule(TestCase): + def _validate_spec( + self, + spec: InputTensorSpec, + tensor: torch.Tensor, + dynamic_dims: Optional[List[int]] = None, + ): + expected_shape = list(tensor.shape) + if dynamic_dims: + for dim in dynamic_dims: + expected_shape[dim] = -1 + self.assertSequenceEqual(spec.shape, expected_shape) + self.assertEqual(spec.dtype, tensor.dtype) + self.assertEqual(spec.device, tensor.device) + self.assertTrue(spec.has_batch_dim) + + def test_from_tensor(self): + tensor = torch.randn(1, 2, 3) + spec = InputTensorSpec.from_tensor(tensor) + self._validate_spec(spec, tensor) + + def test_from_tensors(self): + tensors = [torch.randn(1, 2, 3), torch.randn(2, 4)] + specs = InputTensorSpec.from_tensors(tensors) + for spec, tensor in zip(specs, tensors): + self._validate_spec(spec, tensor) + + def test_from_tensors_with_dynamic_batch_size(self): + tensors = [torch.randn(1, 2, 3), torch.randn(1, 4)] + batch_size_range = [2, 3, 4] + specs = InputTensorSpec.from_tensors_with_dynamic_batch_size( + tensors, batch_size_range + ) + for spec, tensor in zip(specs, tensors): + self._validate_spec(spec, tensor, dynamic_dims=[0]) + + for batch_size, shape in zip(batch_size_range, spec.shape_ranges[0]): + self.assertEqual(batch_size, shape[0]) + self.assertSequenceEqual(tensor.shape[1:], shape[1:]) + + def test_from_tensors_with_dynamic_batch_size_different_batch_dims(self): + tensors = [torch.randn(1, 2, 3), torch.randn(2, 1, 4)] + batch_size_range = [2, 3, 4] + specs = InputTensorSpec.from_tensors_with_dynamic_batch_size( + tensors, batch_size_range, batch_dims=[0, 1] + ) + for i, spec_and_tensor in enumerate(zip(specs, tensors)): + spec, tensor = spec_and_tensor + self._validate_spec(spec, tensor, dynamic_dims=[i]) + + for batch_size, shape in zip(batch_size_range, spec.shape_ranges[0]): + self.assertEqual(batch_size, shape[i]) + tensor_shape = list(tensor.shape) + tensor_shape[i] = batch_size + self.assertSequenceEqual(tensor_shape, shape) + + def test_generate_input_specs(self): + lower_setting = LowerSetting( + explicit_batch_dimension=False, opt_profile_replica=2 + ) + + # Implicit batch dim. + inputs = [torch.randn(1, 2, 3)] + specs = generate_input_specs(inputs, lower_setting) + for spec, tensor in zip(specs, inputs): + self._validate_spec(spec, tensor) + + # Explicit batch dim without additional inputs. + lower_setting.explicit_batch_dimension = True + specs = generate_input_specs(inputs, lower_setting) + for spec, tensor in zip(specs, inputs): + self._validate_spec(spec, tensor, dynamic_dims=[0]) + self.assertEqual(len(spec.shape_ranges), lower_setting.opt_profile_replica) + + # Explicit batch dim with additional inputs. 
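The check right below passes additional_inputs whose dim 1 differs from the sample inputs, and expects spec generation to mark exactly that axis dynamic. A hedged sketch of the underlying idea (an illustrative helper, not the library's implementation):

    import torch

    # Axes where a sample input and an additional input disagree in size must
    # be dynamic; all other axes can stay static. Illustration only.
    def infer_dynamic_dims(sample: torch.Tensor, extra: torch.Tensor) -> list:
        assert sample.dim() == extra.dim()
        return [i for i, (m, n) in enumerate(zip(sample.shape, extra.shape)) if m != n]

    assert infer_dynamic_dims(torch.randn(1, 2, 3), torch.randn(1, 1, 3)) == [1]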
+ additional_inputs = [torch.randn(1, 1, 3)] + specs = generate_input_specs(inputs, lower_setting, additional_inputs) + for spec, tensor in zip(specs, inputs): + self._validate_spec(spec, tensor, dynamic_dims=[1]) + self.assertEqual(len(spec.shape_ranges), lower_setting.opt_profile_replica) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/core/test_trt_module.py b/py/torch_tensorrt/dynamo/test/core/test_trt_module.py new file mode 100644 index 0000000000..195c5ad65c --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/core/test_trt_module.py @@ -0,0 +1,147 @@ +# Owner(s): ["oncall: gpu_enablement"] + +import io +import os + +import torch +import torch.fx + +import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer +from torch.testing._internal.common_utils import run_tests, TestCase +from torch_tensorrt.dynamo import InputTensorSpec, TRTInterpreter, TRTModule + +# from torch_tensorrt import TRTModuleNext +# from torch_tensorrt import Device +from torch_tensorrt.dynamo.utils import LowerPrecision + + +class TestTRTModule(TestCase): + def test_save_and_load_trt_module(self): + class TestModule(torch.nn.Module): + def forward(self, x): + return x + x + + inputs = [torch.randn(1, 1)] + mod = TestModule().eval() + ref_output = mod(*inputs) + + mod = acc_tracer.trace(mod, inputs) + interp = TRTInterpreter(mod, input_specs=InputTensorSpec.from_tensors(inputs)) + trt_mod = TRTModule(*interp.run(lower_precision=LowerPrecision.FP32)) + torch.save(trt_mod, "trt.pt") + reload_trt_mod = torch.load("trt.pt") + + torch.testing.assert_close( + reload_trt_mod(inputs[0].cuda()).cpu(), ref_output, rtol=1e-04, atol=1e-04 + ) + os.remove(f"{os.getcwd()}/trt.pt") + + def test_save_and_load_state_dict(self): + class TestModule(torch.nn.Module): + def forward(self, x): + return x + x + + inputs = [torch.randn(1, 1)] + mod = TestModule().eval() + ref_output = mod(*inputs) + + mod = acc_tracer.trace(mod, inputs) + interp = TRTInterpreter(mod, input_specs=InputTensorSpec.from_tensors(inputs)) + trt_mod = TRTModule(*interp.run(lower_precision=LowerPrecision.FP32)) + st = trt_mod.state_dict() + + new_trt_mod = TRTModule() + new_trt_mod.load_state_dict(st) + + torch.testing.assert_close( + new_trt_mod(inputs[0].cuda()).cpu(), ref_output, rtol=1e-04, atol=1e-04 + ) + + +# TODO add unittest.skip later +# class TestTRTModuleNext(TestCase): +# def test_save_and_load_trt_module(self): +# class TestModule(torch.nn.Module): +# def forward(self, x): +# return x + x + +# inputs = [torch.randn(1, 1)] +# mod = TestModule().eval() +# ref_output = mod(*inputs) + +# mod = acc_tracer.trace(mod, inputs) + +# interp = TRTInterpreter( +# mod, +# input_specs=InputTensorSpec.from_tensors(inputs), +# explicit_batch_dimension=True, +# ) +# interp_res = interp.run(lower_precision=LowerPrecision.FP32) + +# with io.BytesIO() as engine_bytes: +# engine_bytes.write(interp_res.engine.serialize()) +# engine_str = engine_bytes.getvalue() + +# trt_mod = TRTModuleNext( +# name="TestModule", +# serialized_engine=engine_str, +# input_binding_names=interp_res.input_names, +# output_binding_names=interp_res.output_names, +# target_device=Device(f"cuda:{torch.cuda.current_device()}"), +# ) + +# torch.save(trt_mod, "trt.pt") +# reload_trt_mod = torch.load("trt.pt") + +# torch.testing.assert_allclose( +# reload_trt_mod(inputs[0].cuda()).cpu().reshape_as(ref_output), +# ref_output, +# rtol=1e-04, +# atol=1e-04, +# ) +# os.remove(f"{os.getcwd()}/trt.pt") + +# def test_save_and_load_state_dict(self): +# class 
TestModule(torch.nn.Module): +# def forward(self, x): +# return x + x + +# inputs = [torch.randn(1, 1)] +# mod = TestModule().eval() +# ref_output = mod(*inputs) + +# mod = acc_tracer.trace(mod, inputs) +# interp = TRTInterpreter( +# mod, +# input_specs=InputTensorSpec.from_tensors(inputs), +# explicit_batch_dimension=True, +# ) +# interp_res = interp.run(lower_precision=LowerPrecision.FP32) + +# with io.BytesIO() as engine_bytes: +# engine_bytes.write(interp_res.engine.serialize()) +# engine_str = engine_bytes.getvalue() + +# trt_mod = TRTModuleNext( +# name="TestModule", +# serialized_engine=engine_str, +# input_binding_names=interp_res.input_names, +# output_binding_names=interp_res.output_names, +# target_device=Device(f"cuda:{torch.cuda.current_device()}"), +# ) + +# st = trt_mod.state_dict() + +# new_trt_mod = TRTModuleNext() +# new_trt_mod.load_state_dict(st) + +# torch.testing.assert_allclose( +# new_trt_mod(inputs[0].cuda()).cpu().reshape_as(ref_output), +# ref_output, +# rtol=1e-04, +# atol=1e-04, +# ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/passes/test_fix_clamp_numerical_limits_to_fp16.py b/py/torch_tensorrt/dynamo/test/passes/test_fix_clamp_numerical_limits_to_fp16.py new file mode 100644 index 0000000000..91d21b7fd0 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/passes/test_fix_clamp_numerical_limits_to_fp16.py @@ -0,0 +1,72 @@ +import logging +import unittest + +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer +from torch_tensorrt.dynamo.passes.lower_basic_pass import fix_clamp_numerical_limits_to_fp16 + + +_LOGGER: logging.Logger = logging.getLogger(__name__) + + +def debug_print_graph_module(mod_graph: torch.fx.GraphModule) -> None: + """ + Helper func to print model's graph in plain and tabular format, also print code. 
+ """ + _LOGGER.info(mod_graph.graph) + mod_graph.graph.print_tabular() + _LOGGER.info(mod_graph.code) + + +class ClampNumericalLimitsTest(unittest.TestCase): + def setUp(self): + torch.manual_seed(0) + + def test_clamp_numerical_limits_to_fp16(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + y = torch.clamp(x + x, min=-1e8, max=1e8) + return y + + module = TestModule() + inputs = [torch.rand(3, 2, 1)] + + module.eval() + + # Before Opt + before_results = module(*inputs) + mod_traced = acc_tracer.trace(module, inputs) + before_node_list = list(mod_traced.graph.nodes) + clamp_node_before = [node for node in before_node_list if "clamp" in str(node)] + min_val_before = clamp_node_before[0].kwargs["min"] + max_val_before = clamp_node_before[0].kwargs["max"] + _LOGGER.info("Model before opt.") + debug_print_graph_module(mod_traced) + + # Apply Opt + module_after_pass = fix_clamp_numerical_limits_to_fp16(mod_traced, inputs) + + # After Opt + after_results = module_after_pass(*inputs) + after_node_list = list(mod_traced.graph.nodes) + clamp_node_after = [node for node in after_node_list if "clamp" in str(node)] + min_val_after = clamp_node_after[0].kwargs["min"] + max_val_after = clamp_node_after[0].kwargs["max"] + _LOGGER.info("Model after opt.") + mod_traced.recompile() + debug_print_graph_module(mod_traced) + + # Tests + # * Numerics + tol_args = {"rtol": 1e-2, "atol": 1e-2} + torch.testing.assert_close(before_results, after_results, **tol_args) + + # graph should not change + self.assertTrue(before_node_list == after_node_list) + + # values of clamp node changed + self.assertTrue(min_val_before != min_val_after) + self.assertTrue(max_val_before != max_val_after) diff --git a/py/torch_tensorrt/dynamo/test/passes/test_fix_reshape_batch_dim.py b/py/torch_tensorrt/dynamo/test/passes/test_fix_reshape_batch_dim.py new file mode 100644 index 0000000000..27ea6d038c --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/passes/test_fix_reshape_batch_dim.py @@ -0,0 +1,51 @@ +# Owner(s): ["oncall: gpu_enablement"] + +import logging +from copy import deepcopy + +import torch +import torch.fx as fx +import torch.nn as nn + +from torch.testing._internal.common_utils import run_tests, TestCase +from torch_tensorrt.dynamo.passes.lower_basic_pass import fix_reshape_batch_dim +from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer + +_LOGGER = logging.getLogger(__name__) + + +class TestFixReshapeBatchDim(TestCase): + def test_fix_reshape_batch_dim(self): + class Repro(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return y.view(x.size(0), -1, 3) + + mod = Repro() + modt = fx.symbolic_trace(mod) + inp = [ + torch.rand([10, 60]), + torch.rand([10, 60]), + ] + mod(*inp) + mod_acc_traced = acc_tracer.trace(modt, inp) + mod_fixed = fix_reshape_batch_dim(deepcopy(mod_acc_traced)) + + expected_graph = r""" +graph(): + %x : [#users=0] = placeholder[target=x] + %y : [#users=2] = placeholder[target=y] + %size : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.size](args = (), kwargs = {input: %y}) + %getitem_1 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.getitem](args = (), kwargs = {idx: 0, input: %size}) + %reshape : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.reshape](args = (), kwargs = {input: %y, acc_out_ty: ((%getitem_1, -1, 3), None, None, None, None, None, None)}) + return reshape +""" + assert ( + 
str(mod_fixed.graph).strip() == expected_graph.strip() + ), f"Unexpected fixed graph. \nActual: {str(mod_fixed.graph)} \nExpected: {expected_graph}" + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/passes/test_fuse_permute_linear_trt.py b/py/torch_tensorrt/dynamo/test/passes/test_fuse_permute_linear_trt.py new file mode 100644 index 0000000000..f2a4a89b69 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/passes/test_fuse_permute_linear_trt.py @@ -0,0 +1,88 @@ +# Owner(s): ["oncall: gpu_enablement"] + +import unittest + +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.passes.lower_basic_pass import ( + fuse_permute_linear, + trt_transposed_linear, +) +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase + + +class TestFusePermuteLinear(AccTestCase): + def test_fuse_permute_linear(self): + class TestModule(torch.nn.Module): + def __init__(self, in_features, out_features): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features) + + def forward(self, x): + return self.linear(x.permute(0, 2, 1)) + + inputs = [torch.randn(6, 10, 20)] + a = TestModule(10, 30) + self.run_test( + TestModule(10, 30), + inputs, + {trt_transposed_linear}, + apply_passes=[fuse_permute_linear], + ) + + def test_fuse_permute_linear_keep_permute(self): + """ + Fusion while keep permute node since permute has more than one consumers + """ + + class TestModule(torch.nn.Module): + def __init__(self, in_features, out_features): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features) + + def forward(self, x): + y = x.permute(0, 2, 1) + return self.linear(y), y + + inputs = [torch.randn(6, 10, 20)] + a = TestModule(10, 30) + self.run_test( + TestModule(10, 30), + inputs, + {acc_ops.permute, trt_transposed_linear}, + apply_passes=[fuse_permute_linear], + ) + + # TODO: The following test has been disabled due to a bug in TRT 8.5.1.7 + # with self.linear2. 
Issue : https://github.com/pytorch/TensorRT/issues/1444 + @unittest.skip( + reason="test_multi_fuse_permute_linear has been disabled due to a bug in TRT 8.5.1.7 https://github.com/pytorch/TensorRT/issues/1444" + ) + def test_multi_fuse_permute_linear(self): + """ + Fusion when permute output is shared by multiple linears + """ + + class TestModule(torch.nn.Module): + def __init__(self, in_features, out_features): + super().__init__() + self.linear1 = torch.nn.Linear(in_features, out_features) + self.linear2 = torch.nn.Linear(in_features, out_features) + + def forward(self, x): + y = x.permute(0, 2, 1) + return self.linear1(y) + self.linear2(y) + + inputs = [torch.randn(8, 10, 20)] + a = TestModule(10, 30) + self.run_test( + TestModule(10, 30), + inputs, + {trt_transposed_linear}, + apply_passes=[fuse_permute_linear], + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/passes/test_fuse_permute_matmul_trt.py b/py/torch_tensorrt/dynamo/test/passes/test_fuse_permute_matmul_trt.py new file mode 100644 index 0000000000..f48c759be7 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/passes/test_fuse_permute_matmul_trt.py @@ -0,0 +1,142 @@ +# Owner(s): ["oncall: gpu_enablement"] + +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import param, parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.passes.lower_basic_pass import ( + fuse_permute_matmul, + trt_transposed_matmul, +) +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase + + +def tranpose_last_two_dims(x): + return x.transpose(-1, -2) + + +def permute021(x): + return x.permute(0, 2, 1) + + +class TestFusePermuteMatmul(AccTestCase): + @parameterized.expand( + [ + ("transpose_lhs_bmm", (3, 3, 2), (3, 3, 4), tranpose_last_two_dims), + param( + "transpose_rhs_bmm", (3, 2, 3), (3, 4, 3), rhs_op=tranpose_last_two_dims + ), + ("permute_lhs_bmm", (3, 3, 2), (3, 3, 4), permute021), + param("permute_rhs_bmm", (3, 2, 3), (3, 4, 3), rhs_op=permute021), + ("permute_both_bmm", (3, 3, 2), (3, 4, 3), permute021, permute021), + ( + "permute_both_matmul", + (3, 2, 3, 2), + (3, 2, 4, 3), + lambda x: x.permute(0, 1, 3, 2), + lambda x: x.permute(0, 1, 3, 2), + torch.matmul, + ), + param( + "transpose_lhs_bmm_broadcast", + (3, 2), + (3, 3, 4), + tranpose_last_two_dims, + op=torch.matmul, + ), + param( + "transpose_rhs_bmm_broadcast", + (3, 3, 4), + (3, 4), + rhs_op=tranpose_last_two_dims, + op=torch.matmul, + ), + ] + ) + def test_fuse_permute_matmul( + self, + _, + lhs_shape, + rhs_shape, + lhs_op=lambda x: x, + rhs_op=lambda x: x, + op=torch.bmm, + ): + class TestModule(torch.nn.Module): + def forward(self, x, y): + return op(lhs_op(x), rhs_op(y)) + + inputs = [torch.randn(*lhs_shape), torch.randn(*rhs_shape)] + self.run_test( + TestModule(), + inputs, + {trt_transposed_matmul}, + apply_passes=[fuse_permute_matmul], + test_implicit_batch_dim=(len(lhs_shape) == len(rhs_shape)), + ) + + @parameterized.expand( + [ + ("permute_both_bmm", (3, 3, 2), (3, 4, 3), permute021, permute021), + ] + ) + def test_fuse_permute_matmul_keep_permute( + self, + _, + lhs_shape, + rhs_shape, + lhs_op=lambda x: x, + rhs_op=lambda x: x, + op=torch.bmm, + ): + """ + Fusion permute while keep permute node which has more than one consumers + """ + + class TestModule(torch.nn.Module): + def forward(self, x, y): + z = lhs_op(x) + return op(z, rhs_op(y)), z + + inputs = [torch.randn(*lhs_shape), torch.randn(*rhs_shape)] + self.run_test( + 
TestModule(), + inputs, + {trt_transposed_matmul, acc_ops.permute}, + apply_passes=[fuse_permute_matmul], + ) + + @parameterized.expand( + [ + ("permute_both_bmm", (3, 3, 2), (3, 4, 3), (3, 4, 3)), + ] + ) + def test_multifuse_permute_matmul( + self, + _, + x_shape, + y_shape, + z_shape, + ): + """ + Test cases when we have multiple bmm users of one permute + """ + + class TestModule(torch.nn.Module): + def forward(self, x, y, z): + x = permute021(x) + y = permute021(y) + z = permute021(z) + return torch.bmm(x, y) + torch.bmm(x, z) + + inputs = [torch.randn(*x_shape), torch.randn(*y_shape), torch.randn(*z_shape)] + self.run_test( + TestModule(), + inputs, + {trt_transposed_matmul}, + apply_passes=[fuse_permute_matmul], + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/passes/test_graph_opts.py b/py/torch_tensorrt/dynamo/test/passes/test_graph_opts.py new file mode 100644 index 0000000000..8e75fbd17e --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/passes/test_graph_opts.py @@ -0,0 +1,187 @@ +import logging +import unittest +from collections import Counter +from typing import Callable, Dict, List + +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer +from torch_tensorrt.dynamo.passes.graph_opts import common_subexpression_elimination + + +_LOGGER: logging.Logger = logging.getLogger(__name__) + + +def debug_print_graph_module(mod_graph: torch.fx.GraphModule) -> None: + """ + Helper func to print model's graph in plain and tabular format, also print code. + """ + _LOGGER.info(mod_graph.graph) + mod_graph.graph.print_tabular() + _LOGGER.info(mod_graph.code) + + +@torch.fx.wrap +def _test_op(keys, value): + return value + + +class GraphOptsTest(unittest.TestCase): + def setUp(self): + torch.manual_seed(0) + + def _test_opt_with_module( + self, + module: torch.nn.Module, + inputs: List, + opt: Callable, + should_change_graph: bool, + deleted_ops: Dict = None, + created_ops: Dict = None, + rtol: float = None, + atol: float = None, + ): + assert should_change_graph or not bool(deleted_ops or created_ops) + deleted_ops = deleted_ops or {} + created_ops = created_ops or {} + module.eval() + + # Before Opt + before_results = module(*inputs) + mod_traced = acc_tracer.trace(module, inputs) + before_node_list = list(mod_traced.graph.nodes) + _LOGGER.info("Model before opt.") + debug_print_graph_module(mod_traced) + + # Apply Opt + graph_changed = bool(opt(mod_traced)) + + # After Opt + after_results = mod_traced(*inputs) + after_node_list = list(mod_traced.graph.nodes) + _LOGGER.info("Model after opt.") + mod_traced.recompile() + debug_print_graph_module(mod_traced) + + # Tests + # * Numerics + tol_args = {} + if rtol is not None: + tol_args["rtol"] = rtol + if atol is not None: + tol_args["atol"] = atol + torch.testing.assert_close(before_results, after_results, **tol_args) + + # * opt changes graph + self.assertEqual(graph_changed, before_node_list != after_node_list) + self.assertEqual(should_change_graph, graph_changed) + + # * modified nodes + before_node_set = set(before_node_list) + after_node_set = set(after_node_list) + self.assertEqual( + dict(Counter([node.target for node in before_node_set - after_node_set])), + deleted_ops, + ) + self.assertEqual( + dict(Counter([node.target for node in after_node_set - before_node_set])), + created_ops, + ) + + return mod_traced + + def test_common_subexpression_elimination(self): + class TestModule(torch.nn.Module): + 
def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                xx = x + x
+                xx2 = x + x
+                return xx * xx2 - x
+
+        self._test_opt_with_module(
+            module=TestModule(),
+            inputs=[torch.rand(3, 2, 1)],
+            opt=common_subexpression_elimination,
+            should_change_graph=True,
+            deleted_ops={acc_ops.add: 1},
+        )
+
+    def test_common_subexpression_elimination2(self):
+        class TestModule2(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return x + x
+
+        self._test_opt_with_module(
+            module=TestModule2(),
+            inputs=[torch.rand(3, 2, 1)],
+            opt=common_subexpression_elimination,
+            should_change_graph=False,
+        )
+
+    def test_common_subexpression_elimination3(self):
+        class TestModule3(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, a, b, c):
+                x = a * b
+                y = b - c
+                z = a * b
+                xy = x + y
+                zy = z + y
+                return xy - zy
+
+        self._test_opt_with_module(
+            module=TestModule3(),
+            inputs=[
+                torch.rand(3, 2, 1),
+                torch.rand(3, 2, 1),
+                torch.rand(3, 2, 1),
+            ],
+            opt=common_subexpression_elimination,
+            should_change_graph=True,
+            deleted_ops={acc_ops.add: 1, acc_ops.mul: 1},
+        )
+
+    def test_common_subexpression_elimination4(self):
+        class TestModule3(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, a, b, c):
+                x = torch.cat([a, b, c])
+                y = torch.cat([a, b, c])
+                z = torch.cat([c, b, a])
+                return x + y + z
+
+        self._test_opt_with_module(
+            module=TestModule3(),
+            inputs=[
+                torch.rand(3, 2, 1),
+                torch.rand(3, 2, 1),
+                torch.rand(3, 2, 1),
+            ],
+            opt=common_subexpression_elimination,
+            should_change_graph=True,
+            deleted_ops={acc_ops.cat: 1},
+        )
+
+    def test_common_subexpression_elimination_string_arg(self):
+        class TestModule(torch.nn.Module):
+            def forward(self, a):
+                x = _test_op(["foo", "bar"], a)
+                return x
+
+        self._test_opt_with_module(
+            module=TestModule(),
+            inputs=[
+                torch.rand(3, 2, 1),
+            ],
+            opt=common_subexpression_elimination,
+            should_change_graph=False,
+        )
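The tests above assert how many duplicate ops common_subexpression_elimination removes. As a rough mental model only (a toy sketch, not the torch_tensorrt pass), CSE over an FX graph can key call nodes on (target, args, kwargs) and rewire duplicates to the first occurrence:

    import torch
    import torch.fx

    def toy_cse(gm: torch.fx.GraphModule) -> bool:
        """Toy sketch: merge call_function nodes with identical signatures."""
        seen, changed = {}, False
        for node in list(gm.graph.nodes):
            if node.op != "call_function":
                continue
            key = (node.target, node.args, tuple(sorted(node.kwargs.items())))
            if key in seen:
                node.replace_all_uses_with(seen[key])  # reuse the earlier result
                gm.graph.erase_node(node)
                changed = True
            else:
                seen[key] = node
        gm.recompile()
        return changed

    class M(torch.nn.Module):
        def forward(self, x):
            return (x + x) * (x + x) - x  # two structurally identical adds

    gm = torch.fx.symbolic_trace(M())
    assert toy_cse(gm)  # one add is eliminated
    assert torch.equal(gm(torch.ones(2)), torch.full((2,), 3.0))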
+ """ + + class TestModule(torch.nn.Module): + def __init__(self, in_features, out_features): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features) + + def forward(self, x, y): + z = lhs_op(x) + bmm = op(z, rhs_op(y)) + linear = self.linear(z) + return (bmm, linear) + + inputs = [torch.randn(*lhs_shape), torch.randn(*rhs_shape)] + self.run_test( + TestModule(3, 6), + inputs, + {trt_transposed_matmul, trt_transposed_linear}, + {acc_ops.permute}, + apply_passes=[fuse_permute_matmul, fuse_permute_linear], + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/passes/test_remove_duplicate_output_args.py b/py/torch_tensorrt/dynamo/test/passes/test_remove_duplicate_output_args.py new file mode 100644 index 0000000000..2ab06be627 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/passes/test_remove_duplicate_output_args.py @@ -0,0 +1,73 @@ +# Owner(s): ["oncall: gpu_enablement"] + +import logging + +import torch.fx as fx +import torch.nn as nn + +import torch_tensorrt.dynamo.passes.remove_duplicate_output_args as dedup +from torch.testing._internal.common_utils import run_tests, TestCase + +_LOGGER = logging.getLogger(__name__) + + +class TestFx2TrtPasses(TestCase): + def test_remove_duplicate_output_args(self): + class Sub(nn.Module): + def forward(self, x): + return (x, x) + + class Top(nn.Module): + def __init__(self): + super().__init__() + self.a = Sub() + + def forward(self, x): + a_res = self.a(x) + return a_res[0] + a_res[1] + + class Tracer(fx.Tracer): + def is_leaf_module(self, m, qn): + if isinstance(m, Sub): # don't trace into + return True + return False + + top = Top() + ttop = fx.GraphModule(top, Tracer().trace(top), "top") + ttop.a = fx.symbolic_trace(ttop.a) + + name_to_processed_subnet = dedup.remove_duplicate_output_args(ttop, ["a"]) + + ttop(1) # run inference should work + + processed_a = name_to_processed_subnet["a"] + *_, a_output = processed_a.module.graph.nodes + a_output: fx.Node + + ttop_graph_actual = str(ttop.graph).strip() + ttop_graph_expected = """ +graph(): + %x : [#users=1] = placeholder[target=x] + %a : [#users=2] = call_module[target=a](args = (%x,), kwargs = {}) + %getitem : [#users=1] = call_function[target=operator.getitem](args = (%a, 0), kwargs = {}) + %getitem_1 : [#users=1] = call_function[target=operator.getitem](args = (%a, 0), kwargs = {}) + %add : [#users=1] = call_function[target=operator.add](args = (%getitem, %getitem_1), kwargs = {}) + return add +""".strip() + assert ( + ttop_graph_expected == ttop_graph_actual + ), f"Unexpected ttop graph: {ttop_graph_actual}" + + ttop_a_graph_actual = str(ttop.a.graph).strip() + ttop_a_graph_expected = """ +graph(): + %x : [#users=1] = placeholder[target=x] + return (x,) +""".strip() + assert ( + ttop_a_graph_expected == ttop_a_graph_actual + ), f"Unexpected ttop.a graph: {ttop_a_graph_actual}" + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/passes/test_setitem_trt.py b/py/torch_tensorrt/dynamo/test/passes/test_setitem_trt.py new file mode 100644 index 0000000000..d5fce3778d --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/passes/test_setitem_trt.py @@ -0,0 +1,600 @@ +import torch +import torch._dynamo as torchdynamo +from parameterized import parameterized +from torch._dynamo.optimizations import backends +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.passes.lower_basic_pass import transform_setitem +from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase 
+ + +class TestTransformSetitem(AccTestCase): + def test_setitem1d(self): + class TestModule(torch.nn.Module): + def forward(self, x, y): + y[0:2] = x + return y + + inputs = [torch.randn(2), torch.randn(3)] + m = TestModule() + + inputs = [i.cuda() for i in inputs] + m.cuda() + + def transform_fx(gm, example_inputs): + gm = transform_setitem(gm, example_inputs) + return gm + + optimize_mod = torchdynamo.optimize( + transform_fx, + nopython=True, + )(m) + + optimize_mod(*inputs) + + def test_setitem1d_c2(self): + class TestModule(torch.nn.Module): + def forward(self, x, y): + y[:-1] = x + y[1:] = x + return y + + inputs = [torch.randn(2), torch.randn(3)] + m = TestModule() + + inputs = [i.cuda() for i in inputs] + m.cuda() + + def transform_fx(gm, example_inputs): + gm = transform_setitem(gm, example_inputs) + return gm + + optimize_mod = torchdynamo.optimize( + transform_fx, + nopython=True, + )(m) + + optimize_mod(*inputs) + + def test_setitem1d_c3(self): + class TestModule(torch.nn.Module): + def forward(self, x, y): + y[1] = x + return y + + inputs = [torch.randn(2), torch.randn(3)] + m = TestModule() + + inputs = [i.cuda() for i in inputs] + m.cuda() + + def transform_fx(gm, example_inputs): + gm = transform_setitem(gm, example_inputs) + return gm + + optimize_mod = torchdynamo.optimize( + transform_fx, + nopython=True, + )(m) + + optimize_mod(*inputs) + + @parameterized.expand( + [ + ("c1", (4, 2), (4, 5), 0, 2), + ("c2", (4, 2), (4, 5), 1, 3), + ] + ) + def test_setitem2d_1v(self, name, x_shape, y_shape, y_start, y_end): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + y[:, y_start:y_end] = x + return y + + inputs = [torch.randn(x_shape), torch.randn(y_shape)] + m = TestModule() + + inputs = [i.cuda() for i in inputs] + m.cuda() + + def transform_fx(gm, example_inputs): + gm = transform_setitem(gm, example_inputs) + return gm + + optimize_mod = torchdynamo.optimize( + transform_fx, + nopython=True, + )(m) + + optimize_mod(*inputs) + + @parameterized.expand( + [ + ("c1", (4, 2), (8, 2), 0, 2), + ("c2", (4, 2), (8, 2), 1, 3), + ] + ) + def test_setitem2d_1v_ex(self, name, x_shape, y_shape, y_start, y_end): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + y[y_start:y_end, :] = x + return y + + inputs = [torch.randn(x_shape), torch.randn(y_shape)] + m = TestModule() + + inputs = [i.cuda() for i in inputs] + m.cuda() + + def transform_fx(gm, example_inputs): + gm = transform_setitem(gm, example_inputs) + return gm + + optimize_mod = torchdynamo.optimize( + transform_fx, + nopython=True, + )(m) + + optimize_mod(*inputs) + + @parameterized.expand( + [ + ("c1", (4, 2), (4, 2), 0, 1), + ] + ) + def test_setitem2d_1v_ex2(self, name, x_shape, y_shape, y_start, y_end): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + y[:, y_start:y_end] = x[:, 0] + return y + + inputs = [torch.randn(x_shape), torch.randn(y_shape)] + m = TestModule() + + inputs = [i.cuda() for i in inputs] + m.cuda() + + def transform_fx(gm, example_inputs): + gm = transform_setitem(gm, example_inputs) + return gm + + optimize_mod = torchdynamo.optimize( + transform_fx, + nopython=True, + )(m) + + optimize_mod(*inputs) + + @parameterized.expand( + [ + ("c1", (3, 2), (4, 5), 0, 3, 0, 2), + ("c2", (3, 2), (4, 5), 1, 4, 1, 3), + ] + ) + def test_setitem2d_2v(self, name, x_shape, y_shape, x_start, x_end, y_start, y_end): + class 
TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + y[x_start:x_end, y_start:y_end] = x + y = y + 3 + return y + + inputs = [torch.randn(x_shape), torch.randn(y_shape)] + m = TestModule() + + inputs = [i.cuda() for i in inputs] + m.cuda() + + def transform_fx(gm, example_inputs): + gm = transform_setitem(gm, example_inputs) + return gm + + optimize_mod = torchdynamo.optimize( + transform_fx, + nopython=True, + )(m) + + optimize_mod(*inputs) + + @parameterized.expand( + [ + ("c1", (2, 3, 4), (2, 5, 6), 0, 3, 0, 4), + ("c2", (2, 3, 4), (2, 5, 6), 1, 4, 1, 5), + ] + ) + def test_setitem3d_2v(self, name, x_shape, y_shape, start_1, end_1, start_2, end_2): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + y[:, start_1:end_1, start_2:end_2] = x + y = y + 3 + return y + + inputs = [torch.randn(x_shape), torch.randn(y_shape)] + m = TestModule() + + inputs = [i.cuda() for i in inputs] + m.cuda() + + def transform_fx(gm, example_inputs): + gm = transform_setitem(gm, example_inputs) + return gm + + optimize_mod = torchdynamo.optimize( + transform_fx, + nopython=True, + )(m) + + optimize_mod(*inputs) + + @parameterized.expand( + [ + ("c1", (3, 2, 4), (5, 2, 6), 0, 3, 0, 4), + ("c2", (3, 2, 4), (5, 2, 6), 1, 4, 1, 5), + ] + ) + def test_setitem3d_2v_ext( + self, name, x_shape, y_shape, start_0, end_0, start_2, end_2 + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + y[start_0:end_0, :, start_2:end_2] = x + y = y + 3 + return y + + inputs = [torch.randn(x_shape), torch.randn(y_shape)] + m = TestModule() + + inputs = [i.cuda() for i in inputs] + m.cuda() + + def transform_fx(gm, example_inputs): + gm = transform_setitem(gm, example_inputs) + return gm + + optimize_mod = torchdynamo.optimize( + transform_fx, + nopython=True, + )(m) + + optimize_mod(*inputs) + + @parameterized.expand( + [ + ("c1", (2, 3, 4), (4, 5, 6), 0, 2, 0, 3, 0, 4), + ("c2", (2, 3, 4), (4, 5, 6), 1, 3, 1, 4, 1, 5), + ] + ) + def test_setitem3d_3v( + self, name, x_shape, y_shape, start_0, end_0, start_1, end_1, start_2, end_2 + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + y[start_0:end_0, start_1:end_1, start_2:end_2] = x + y = y + 3 + x = y[start_0:end_0, start_1:end_1, start_2:end_2] + return x + + inputs = [torch.randn(x_shape), torch.randn(y_shape)] + m = TestModule() + + inputs = [i.cuda() for i in inputs] + m.cuda() + + def transform_fx(gm, example_inputs): + gm = transform_setitem(gm, example_inputs) + return gm + + optimize_mod = torchdynamo.optimize( + transform_fx, + nopython=True, + )(m) + + optimize_mod(*inputs) + + @parameterized.expand( + [ + ("c1", (2, 3, 4, 5), (2, 3, 6, 7), 0, 4, 0, 5), + ("c2", (2, 3, 4, 5), (2, 3, 6, 7), 1, 5, 1, 6), + ] + ) + def test_setitem4d_2v(self, name, x_shape, y_shape, start_2, end_2, start_3, end_3): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + y[:, :, start_2:end_2, start_3:end_3] = x + y = y + 3 + return y + + inputs = [torch.randn(x_shape), torch.randn(y_shape)] + m = TestModule() + + inputs = [i.cuda() for i in inputs] + m.cuda() + + def transform_fx(gm, example_inputs): + gm = transform_setitem(gm, example_inputs) + return gm + + optimize_mod = torchdynamo.optimize( + transform_fx, + nopython=True, + )(m) + + optimize_mod(*inputs) + + @parameterized.expand( + [ + ("c1", (2, 3, 4, 
5), (2, 5, 4, 7), 0, 3, 0, 5), + ("c2", (2, 3, 4, 5), (2, 5, 4, 7), 1, 4, 1, 6), + ] + ) + def test_setitem4d_2v_ext( + self, name, x_shape, y_shape, start_1, end_1, start_3, end_3 + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + y[:, start_1:end_1, :, start_3:end_3] = x + y = y + 3 + return y + + inputs = [torch.randn(x_shape), torch.randn(y_shape)] + m = TestModule() + + inputs = [i.cuda() for i in inputs] + m.cuda() + + def transform_fx(gm, example_inputs): + gm = transform_setitem(gm, example_inputs) + return gm + + optimize_mod = torchdynamo.optimize( + transform_fx, + nopython=True, + )(m) + + optimize_mod(*inputs) + + @parameterized.expand( + [ + ("c1", (2, 3, 4, 5), (2, 5, 6, 7), 0, 3, 0, 4, 0, 5), + ("c2", (2, 3, 4, 5), (2, 5, 6, 7), 1, 4, 1, 5, 1, 6), + ] + ) + def test_setitem4d_3v( + self, name, x_shape, y_shape, start_1, end_1, start_2, end_2, start_3, end_3 + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + y[:, start_1:end_1, start_2:end_2, start_3:end_3] = x + y = y + 3 + return y + + inputs = [torch.randn(x_shape), torch.randn(y_shape)] + m = TestModule() + + inputs = [i.cuda() for i in inputs] + m.cuda() + + def transform_fx(gm, example_inputs): + gm = transform_setitem(gm, example_inputs) + return gm + + optimize_mod = torchdynamo.optimize( + transform_fx, + nopython=True, + )(m) + + optimize_mod(*inputs) + + @parameterized.expand( + [ + ("c1", (2, 3, 4, 5), (4, 5, 6, 7), 0, 2, 0, 3, 0, 4, 0, 5), + ("c2", (2, 3, 4, 5), (4, 5, 6, 7), 1, 3, 1, 4, 1, 5, 1, 6), + ] + ) + def test_setitem4d_4v( + self, + name, + x_shape, + y_shape, + start_0, + end_0, + start_1, + end_1, + start_2, + end_2, + start_3, + end_3, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + y[start_0:end_0, start_1:end_1, start_2:end_2, start_3:end_3] = x + y = y + 3 + x = y[start_0:end_0, start_1:end_1, start_2:end_2, start_3:end_3] + return x + + inputs = [torch.randn(x_shape), torch.randn(y_shape)] + m = TestModule() + + inputs = [i.cuda() for i in inputs] + m.cuda() + + def transform_fx(gm, example_inputs): + gm = transform_setitem(gm, example_inputs) + return gm + + optimize_mod = torchdynamo.optimize( + transform_fx, + nopython=True, + )(m) + + optimize_mod(*inputs) + + @parameterized.expand( + [ + ("c1", (2, 3, 4, 5, 6), (4, 5, 6, 7, 6), 0, 2, 0, 3, 0, 4, 0, 5), + ] + ) + def test_setitem5d_warning( + self, + name, + x_shape, + y_shape, + start_0, + end_0, + start_1, + end_1, + start_2, + end_2, + start_3, + end_3, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + y[start_0:end_0, start_1:end_1, start_2:end_2, start_3:end_3, :] = x + y = y + 3 + x = y[start_0:end_0, start_1:end_1, start_2:end_2, start_3:end_3] + return x + + inputs = [torch.randn(x_shape), torch.randn(y_shape)] + m = TestModule() + + inputs = [i.cuda() for i in inputs] + m.cuda() + + def transform_fx(gm, example_inputs): + gm = transform_setitem(gm, example_inputs) + return gm + + optimize_mod = torchdynamo.optimize( + transform_fx, + nopython=True, + )(m) + + optimize_mod(*inputs) + + # test with torchdynamo + def test_setitem1d_trt(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + y[1] = x + return y + + inputs = [torch.randn(1), torch.randn(3)] + m = TestModule() + + inputs = [i.cuda() for i in inputs] + 
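+        # The TRT-backed tests below share one pattern: compute a reference
+        # output eagerly, compile the module through torchdynamo with the
+        # fx2trt_compiler backend (nopython=True so graph breaks fail loudly),
+        # and check the compiled output against the reference with allclose.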
m.cuda() + ref_output = m(*inputs) + + optimize_mod = torchdynamo.optimize(backends.fx2trt_compiler, nopython=True)(m) + + output = optimize_mod(*inputs) + self.assertTrue(torch.allclose(ref_output, output)) + + @parameterized.expand( + [ + ("c1", (4, 2), (4, 5), 0, 2), + ("c2", (4, 2), (4, 5), 1, 3), + ] + ) + def test_setitem2d_1v_trt(self, name, x_shape, y_shape, y_start, y_end): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + y[:, y_start:y_end] = x + return y + + inputs = [torch.randn(x_shape), torch.randn(y_shape)] + m = TestModule() + + inputs = [i.cuda() for i in inputs] + m.cuda() + + ref_output = m(*inputs) + optimize_mod = torchdynamo.optimize(backends.fx2trt_compiler, nopython=True)(m) + output = optimize_mod(*inputs) + self.assertTrue(torch.allclose(ref_output, output)) + + @parameterized.expand( + [ + ("c1", (2, 3, 4, 5), (4, 5, 6, 7), 0, 2, 0, 3, 0, 4, 0, 5), + ("c2", (2, 3, 4, 5), (4, 5, 6, 7), 1, 3, 1, 4, 1, 5, 1, 6), + ] + ) + def test_setitem4d_4v_trt( + self, + name, + x_shape, + y_shape, + start_0, + end_0, + start_1, + end_1, + start_2, + end_2, + start_3, + end_3, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + y[start_0:end_0, start_1:end_1, start_2:end_2, start_3:end_3] = x + y = y + 3 + x = y[start_0:end_0, start_1:end_1, start_2:end_2, start_3:end_3] + return x + + inputs = [torch.randn(x_shape), torch.randn(y_shape)] + m = TestModule() + + inputs = [i.cuda() for i in inputs] + m.cuda() + + ref_output = m(*inputs) + optimize_mod = torchdynamo.optimize(backends.fx2trt_compiler, nopython=True)(m) + output = optimize_mod(*inputs) + self.assertTrue(torch.allclose(ref_output, output)) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/quant/test_quant_trt.py b/py/torch_tensorrt/dynamo/test/quant/test_quant_trt.py new file mode 100644 index 0000000000..146f5a6932 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/quant/test_quant_trt.py @@ -0,0 +1,907 @@ +# Owner(s): ["oncall: quantization"] + +import copy +import itertools +import operator +import unittest + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.quantized._reference as nnqr + +import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer +from torch.ao.quantization import default_qconfig +from torch.ao.quantization.backend_config import ( + get_tensorrt_backend_config_dict, + ObservationType, +) +from torch.ao.quantization.fx.match_utils import MatchAllNode +from torch.ao.quantization.quantize_fx import ( + convert_to_reference_fx, + prepare_fx, + prepare_qat_fx, +) +from torch.testing._internal.common_cuda import TEST_CUDA +from torch.testing._internal.common_quantization import ( + NodeSpec as ns, + QuantizationTestCase, +) +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo import InputTensorSpec, TRTInterpreter, TRTModule +from torch_tensorrt.dynamo.passes.lower_basic_pass import run_const_fold +from torch_tensorrt.fx.tracer.acc_tracer import acc_ops +from torch_tensorrt.dynamo.utils import LowerPrecision + + +def lower_to_trt(model, inputs, shape_ranges): + """Lower a quantized model to TensorRT""" + assert len(inputs) == 1, "lower_to_trt only works for one input currently" + model = acc_tracer.trace(model, inputs) # type: ignore[attr-defined] + # TODO: test multiple inputs setting and enable multiple inputs + input_specs = [ + InputTensorSpec( + 
torch.Size([-1, *inputs[0].shape[1:]]), + torch.float, + shape_ranges=shape_ranges, + has_batch_dim=True, + ) + ] + + interp = TRTInterpreter( + model, input_specs, explicit_batch_dimension=True, explicit_precision=True + ) + result = interp.run(lower_precision=LowerPrecision.INT8) + trt_mod = TRTModule(result.engine, result.input_names, result.output_names) + return trt_mod + + +class TestConvertFxDoNotUse(QuantizationTestCase): + def setUp(self): + super().setUp() + self.trt_qconfig = torch.ao.quantization.QConfig( + activation=torch.ao.quantization.observer.HistogramObserver.with_args( + qscheme=torch.per_tensor_symmetric, dtype=torch.qint8 + ), + weight=torch.ao.quantization.default_weight_observer, + ) + self.trt_backend_config_dict = get_tensorrt_backend_config_dict() + + def _test_quantized_inputs_outputs( + self, prepare_custom_config_dict, prepare_count_check, convert_count_check + ): + """ + Test the option to have inputs and outputs of the graph quantized + """ + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(1, 1, 1) + self.conv2 = torch.nn.Conv2d(1, 1, 1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + + # quantized input, quantized output + m = M() + m.eval() + qconfig_dict = {"": torch.ao.quantization.default_qconfig} + example_inputs = (torch.rand(1, 1, 3, 3),) + mp = torch.ao.quantization.quantize_fx.prepare_fx( + m, + qconfig_dict, + example_inputs, + prepare_custom_config=prepare_custom_config_dict, + ) + self.checkGraphModuleNodes(mp, expected_node_occurrence=prepare_count_check) + mp(torch.randn(1, 1, 4, 4)) + mq = convert_to_reference_fx(mp, backend_config=self.trt_backend_config_dict) + self.checkGraphModuleNodes(mq, expected_node_occurrence=convert_count_check) + + def test_quantized_input_quantized_output(self): + prepare_custom_config_dict = { + "input_quantized_idxs": [0], + "output_quantized_idxs": [0], + } + prepare_count_check = { + ns.call_module(torch.ao.quantization.MinMaxObserver): 2, + } + convert_count_check = { + # output of ref conv1 and output of ref conv2 + ns.call_function(torch.quantize_per_tensor): 2, + # input of ref conv1 and input of ref conv2 + ns.call_method("dequantize"): 2, + } + self._test_quantized_inputs_outputs( + prepare_custom_config_dict, prepare_count_check, convert_count_check + ) + + def test_fp32_input_quantized_output(self): + prepare_custom_config_dict = {"output_quantized_idxs": [0]} + prepare_count_check = { + ns.call_module(torch.ao.quantization.MinMaxObserver): 3, + } + convert_count_check = { + # input, output of conv1 and output of conv2 + ns.call_function(torch.quantize_per_tensor): 3, + # input of conv1, conv2 + ns.call_method("dequantize"): 2, + } + self._test_quantized_inputs_outputs( + prepare_custom_config_dict, prepare_count_check, convert_count_check + ) + + def test_quantized_input_fp32_output(self): + prepare_custom_config_dict = {"input_quantized_idxs": [0]} + prepare_count_check = { + ns.call_module(torch.ao.quantization.MinMaxObserver): 2, + } + convert_count_check = { + # output of conv1, conv2 + ns.call_function(torch.quantize_per_tensor): 2, + # input of ref conv1, input of ref conv2, final output + ns.call_method("dequantize"): 3, + } + self._test_quantized_inputs_outputs( + prepare_custom_config_dict, prepare_count_check, convert_count_check + ) + + def test_fp32_input_fp32_output(self): + prepare_custom_config_dict = {} + prepare_count_check = { + ns.call_module(torch.ao.quantization.MinMaxObserver): 3, + } + 
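+        # With fp32 graph input and output, observers sit at the input of
+        # conv1, between the convs, and after conv2 (3 total), so the
+        # reference model quantizes/dequantizes around each conv:
+        #   x -> q -> dq -> conv1 -> q -> dq -> conv2 -> q -> dq -> out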
convert_count_check = { + ns.call_function(torch.quantize_per_tensor): 3, + ns.call_method("dequantize"): 3, + } + self._test_quantized_inputs_outputs( + prepare_custom_config_dict, prepare_count_check, convert_count_check + ) + + def _test_standalone_module( + self, + interface_config, + prepare_count_check, + standalone_prepare_count_check, + convert_count_check, + standalone_convert_count_check, + qconfig=None, + backend_config_dict=None, + ): + """Test standalone module with different quantized input/quantized output + configurations + """ + + class StandaloneModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(1, 1, 1) + + def forward(self, x): + return self.conv(x) + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(1, 1, 1) + self.standalone = StandaloneModule() + + def forward(self, x): + x = self.conv(x) + x = self.standalone(x) + return x + + class RefM(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(1, 1, 1) + self.conv2 = torch.nn.Conv2d(1, 1, 1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + + if backend_config_dict is None: + backend_config_dict = self.trt_backend_config_dict + if qconfig is None: + qconfig = self.trt_qconfig + + data = torch.randn(1, 1, 1, 1) + # instantiate M and RefM and align the parameters + original_m = M().eval() + original_ref_m = RefM().eval() + original_ref_m.conv1.weight = torch.nn.Parameter( + original_m.conv.weight.detach() + ) + original_ref_m.conv1.bias = torch.nn.Parameter(original_m.conv.bias.detach()) + original_ref_m.conv2.weight = torch.nn.Parameter( + original_m.standalone.conv.weight.detach() + ) + original_ref_m.conv2.bias = torch.nn.Parameter( + original_m.standalone.conv.bias.detach() + ) + + sm_example_inputs = (data,) + prepare_config = { + "standalone_module_name": [ + ( + "standalone", + None, + sm_example_inputs, + interface_config, + backend_config_dict, + ) + ] + } + + original_m_copy = copy.deepcopy(original_m) + original_ref_m_copy = copy.deepcopy(original_ref_m) + + qconfig_dict = {"": qconfig} + example_inputs = (data,) + # check prepared model + m = prepare_fx( + original_m_copy, + qconfig_dict, + example_inputs, + prepare_custom_config=prepare_config, + backend_config=backend_config_dict, + ) + # calibration + m(data) + self.checkGraphModuleNodes(m, expected_node_occurrence=prepare_count_check) + self.checkGraphModuleNodes( + m.standalone, expected_node_occurrence=standalone_prepare_count_check + ) + + # check converted/quantized model + m = convert_to_reference_fx(m, backend_config=backend_config_dict) + self.checkGraphModuleNodes(m, expected_node_occurrence=convert_count_check) + self.checkGraphModuleNodes( + m.standalone, expected_node_occurrence=standalone_convert_count_check + ) + res = m(data) + + # quantize the reference model + ref_m = prepare_fx( + original_ref_m_copy, + qconfig_dict, + example_inputs, + backend_config=backend_config_dict, + ) + ref_m(data) + ref_m = convert_to_reference_fx(ref_m, backend_config=backend_config_dict) + ref_res = ref_m(data) + self.assertEqual(res, ref_res) + + def test_standalone_module_float_interface(self): + float_interface_config = { + "input_quantized_idxs": [], # float input + "output_quantized_idxs": [], # float output + } + interface_config = float_interface_config + # input and output of first conv, observer for standalone module + # will be inserted in the standalone module itself + prepare_count_check 
= {
+            ns.call_module(torch.ao.quantization.HistogramObserver): 2
+        }
+        # for input and output of conv in the standalone module
+        standalone_prepare_count_check = {
+            ns.call_module(torch.ao.quantization.HistogramObserver): 2
+        }
+        convert_count_check = {
+            # input and output of reference conv
+            ns.call_function(torch.quantize_per_tensor): 2,
+            ns.call_module(nnqr.Conv2d): 1,
+            ns.call_method("dequantize"): 2,
+        }
+        standalone_convert_count_check = {
+            # standalone module will take float as input and output
+            # so we'll see quantize and dequantize in the module
+            ns.call_function(torch.quantize_per_tensor): 2,
+            ns.call_module(nnqr.Conv2d): 1,
+            ns.call_method("dequantize"): 2,
+        }
+        self._test_standalone_module(
+            interface_config,
+            prepare_count_check,
+            standalone_prepare_count_check,
+            convert_count_check,
+            standalone_convert_count_check,
+        )
+
+    def test_standalone_module_quantized_interface(self):
+        quantized_interface_config = {
+            "input_quantized_idxs": [0],  # quantized input
+            "output_quantized_idxs": [0],  # quantized output
+        }
+        interface_config = quantized_interface_config
+        # TODO: input_quantized_idxs only supports quint8; we can remove this
+        # custom_backend_config_dict once `input_quantized_idxs` supports more
+        # complicated configurations. As a first step, we can change it to use
+        # a dictionary from index to dtype.
+        qconfig = torch.ao.quantization.QConfig(
+            activation=torch.ao.quantization.observer.HistogramObserver.with_args(
+                qscheme=torch.per_tensor_symmetric, dtype=torch.quint8
+            ),
+            weight=torch.ao.quantization.default_weight_observer,
+        )
+        weighted_op_quint8_dtype_config = {
+            # optional, input activation dtype
+            "input_dtype": torch.quint8,
+            # optional, weight dtype
+            "weight_dtype": torch.qint8,
+            # optional, bias dtype
+            "bias_dtype": torch.float,
+            # optional, output activation dtype
+            "output_dtype": torch.quint8,
+        }
+        conv_module_config = {
+            "pattern": torch.nn.Conv2d,
+            "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
+            "dtype_configs": [
+                weighted_op_quint8_dtype_config,
+            ],
+            "root_module": torch.nn.Conv2d,
+            "reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d,
+        }
+        custom_backend_config_dict = {"configs": [conv_module_config]}
+        # observer for input and output of first conv
+        prepare_count_check = {
+            ns.call_module(torch.ao.quantization.HistogramObserver): 2
+        }
+        # for output of conv in the standalone module
+        standalone_prepare_count_check = {
+            ns.call_module(torch.ao.quantization.HistogramObserver): 1
+        }
+        convert_count_check = {
+            # quantizing input/output for reference conv
+            ns.call_function(torch.quantize_per_tensor): 2,
+            ns.call_module(nnqr.Conv2d): 1,
+            # dequantize the input of reference conv and
+            # dequantize the output of the standalone module
+            ns.call_method("dequantize"): 2,
+        }
+        standalone_convert_count_check = {
+            # quantization of input happens in parent module
+            # quantization of output happens in the standalone module
+            ns.call_function(torch.quantize_per_tensor): 1,
+            ns.call_module(nnqr.Conv2d): 1,
+            # dequantization of input happens in the standalone module
+            # dequantization of output happens in parent module
+            ns.call_method("dequantize"): 1,
+        }
+        self._test_standalone_module(
+            interface_config,
+            prepare_count_check,
+            standalone_prepare_count_check,
+            convert_count_check,
+            standalone_convert_count_check,
+            qconfig=qconfig,
+            backend_config_dict=custom_backend_config_dict,
+        )
+
+
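+# The tests below exercise the full prepare_fx -> convert_to_reference_fx ->
+# TensorRT lowering pipeline, so they are skipped when no GPU is available.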
+@unittest.skipIf(not TEST_CUDA, "gpu is not available.")
+class TestQuantizeFxTRTOps(QuantizationTestCase):
+    """Test TensorRT operator support"""
+
+    def setUp(self):
+        super().setUp()
+        self.trt_qconfig = torch.ao.quantization.QConfig(
+            activation=torch.ao.quantization.observer.HistogramObserver.with_args(
+                qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
+            ),
+            weight=torch.ao.quantization.default_weight_observer,
+        )
+        self.trt_backend_config_dict = get_tensorrt_backend_config_dict()
+
+    def _test_module(
+        self, m, inputs, shape_ranges, no_prepare=None, no_convert=None, is_qat=False
+    ):
+        """
+        Args:
+            m: the float module we want to test
+            inputs: list of inputs for the module
+            shape_ranges: a list of shape_range, where every shape_range is a tuple of
+                three tuples
+                ((min_input_shape), (optimized_input_shape), (max_input_shape)).
+                Each shape_range is used to populate a TensorRT optimization profile.
+                e.g. If the input shape varies from (1, 224) to (100, 224) and we want to optimize
+                for (25, 224) because it's the most common input shape, then we set shape_ranges to
+                ((1, 224), (25, 224), (100, 224))
+            no_prepare: node occurrence after prepare
+            no_convert: node occurrence after convert
+        """
+        if is_qat:
+            m = m.train()
+            prepare = prepare_qat_fx
+        else:
+            m = m.eval()
+            prepare = prepare_fx
+        example_inputs = tuple(inputs)
+        prepared = prepare(
+            m,
+            {"": self.trt_qconfig},
+            example_inputs,
+            backend_config=self.trt_backend_config_dict,
+        )
+        self.checkGraphModuleNodes(prepared, expected_node_occurrence=no_prepare)
+        # calibration
+        prepared(*inputs)
+        quantized = convert_to_reference_fx(
+            prepared,
+            backend_config=self.trt_backend_config_dict,
+        )
+        self.checkGraphModuleNodes(quantized, expected_node_occurrence=no_convert)
+        # lower to trt
+        trt_mod = lower_to_trt(quantized, inputs, shape_ranges)
+        inputs_cuda = [i.cuda() for i in inputs]
+        # make sure it runs
+        trt_mod(*inputs_cuda)
+
+    def test_conv_relu_module(self):
+        conv_module = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
+
+        conv1d_input = torch.rand(1, 3, 10)
+        conv2d_input = torch.rand(1, 3, 10, 10)
+        conv3d_input = torch.rand(1, 3, 10, 10, 10)
+        conv_input = {1: conv1d_input, 2: conv2d_input, 3: conv3d_input}
+
+        class ConvNdModule(torch.nn.Module):
+            def __init__(self, dim, has_relu=False, f_relu=False):
+                super().__init__()
+                self.conv = conv_module[dim](3, 3, 3).float()
+                if has_relu:
+                    if f_relu:
+                        self.relu = F.relu
+                    else:
+                        self.relu = torch.nn.ReLU()
+                else:
+                    self.relu = torch.nn.Identity()
+
+            def forward(self, x):
+                return self.relu(self.conv(x))
+
+        # test conv1d and conv2d; conv3d is not supported in fx2trt
+        for dim, has_relu, f_relu, is_qat in itertools.product(
+            [1, 2], [True, False], [True, False], [True, False]
+        ):
+            # when has_relu=False, we have torch.nn.Identity, which would introduce
+            # an extra quant-dequant pair
+            no_convert = {
+                ns.call_function(torch.quantize_per_tensor): 2 + int(not has_relu),
+                ns.call_method("dequantize"): 2 + int(not has_relu),
+            }
+            self._test_module(
+                ConvNdModule(dim, has_relu, f_relu),
+                [conv_input[dim]],
+                [
+                    (
+                        (1, *conv_input[dim].shape[1:]),
+                        (5, *conv_input[dim].shape[1:]),
+                        (10, *conv_input[dim].shape[1:]),
+                    )
+                ],
+                no_convert=no_convert,
+                is_qat=is_qat,
+            )
+
+    def test_linear_relu_module(self):
+        class LinearModule(torch.nn.Module):
+            def __init__(self, has_relu=False, f_relu=False):
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 10).float()
+                if has_relu:
+                    if f_relu:
+                        self.relu = F.relu
+                    else:
+                        self.relu = torch.nn.ReLU()
+                else:
+                    self.relu = 
torch.nn.Identity()
+
+            def forward(self, x):
+                return self.relu(self.linear(x))
+
+        linear_input = torch.rand(8, 5)
+
+        shape_ranges = [((1, 5), (5, 5), (10, 5))]
+        for has_relu, f_relu, is_qat in itertools.product(
+            [True, False], [True, False], [True, False]
+        ):
+            # when has_relu=False, we have torch.nn.Identity, which would introduce
+            # an extra quant-dequant pair
+            no_convert = {
+                ns.call_function(torch.quantize_per_tensor): 2 + int(not has_relu),
+                ns.call_method("dequantize"): 2 + int(not has_relu),
+            }
+            self._test_module(
+                LinearModule(has_relu, f_relu),
+                [linear_input],
+                shape_ranges,
+                no_convert=no_convert,
+                is_qat=is_qat,
+            )
+
+    def test_ops(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv2d(3, 3, 3)
+                self.linear = torch.nn.Linear(5, 5)
+                self.relu = torch.nn.ReLU()
+
+            def forward(self, x):
+                x = self.conv(x)
+                x = self.linear(x)
+                x = x + 3
+                x = self.relu(x)
+                x = x + 6
+                return x
+
+        m = M().eval()
+        example_inputs = (torch.rand(1, 3, 5, 5),)
+        m = prepare_fx(
+            m,
+            {"": self.trt_qconfig},
+            example_inputs,
+            backend_config=self.trt_backend_config_dict,
+        )
+        m = convert_to_reference_fx(m, backend_config=self.trt_backend_config_dict)
+        expected_occurrence = {
+            ns.call_function(torch.quantize_per_tensor): 5,
+            ns.call_method("dequantize"): 5,
+            ns.call_module(torch.nn.quantized._reference.Linear): 1,
+            ns.call_module(torch.nn.quantized._reference.Conv2d): 1,
+        }
+        self.checkGraphModuleNodes(m, expected_node_occurrence=expected_occurrence)
+
+    def test_unsupported_qconfig(self):
+        """Check that we won't quantize the model if the qconfig is not supported"""
+
+        class LinearModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 10)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        linear_module_input = torch.rand(8, 5)
+
+        m = LinearModule().eval()
+        trt_unsupported_qconfig = default_qconfig
+        example_inputs = (torch.rand(1, 5),)
+        prepared = prepare_fx(
+            m,
+            {"": trt_unsupported_qconfig},
+            example_inputs=example_inputs,
+            backend_config=self.trt_backend_config_dict,
+        )
+        # calibration
+        prepared(linear_module_input)
+        quantized = convert_to_reference_fx(
+            prepared,
+            backend_config=self.trt_backend_config_dict,
+        )
+        node_occurrence = {
+            ns.call_function(torch.quantize_per_tensor): 0,
+            ns.call_method("dequantize"): 0,
+            ns.call_module(torch.nn.Linear): 1,
+            ns.call_module(torch.nn.quantized._reference.Linear): 0,
+        }
+        # check model is not quantized
+        self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence)
+
+    def test_cat(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.cat([x, x], 1)
+
+        m = M().eval()
+        example_inputs = (torch.rand(2, 2),)
+        prepared = prepare_fx(
+            m,
+            {"": self.trt_qconfig},
+            example_inputs,
+            backend_config=self.trt_backend_config_dict,
+        )
+        self.assertTrue(len(dict(prepared.named_children())) == 1)
+        quantized = convert_to_reference_fx(
+            prepared,
+            backend_config=self.trt_backend_config_dict,
+        )
+        node_occurrence = {
+            ns.call_function(torch.quantize_per_tensor): 2,
+            ns.call_function(torch.cat): 1,
+            ns.call_method("dequantize"): 2,
+        }
+        self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence)
+
+    def test_addmm(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.weight = torch.randn(5, 5)
+                self.bias = torch.randn(5)
+
+            def forward(self, x):
+                return torch.addmm(self.bias, 
x, self.weight) + + m = M().eval() + example_inputs = (torch.rand(1, 5),) + prepared = prepare_fx( + m, + {"": self.trt_qconfig}, + example_inputs, + backend_config=self.trt_backend_config_dict, + ) + node_occurrence = { + # weight + ns.call_module(torch.ao.quantization.MinMaxObserver): 1, + # activation + ns.call_module(torch.ao.quantization.HistogramObserver): 2, + } + self.checkGraphModuleNodes(prepared, expected_node_occurrence=node_occurrence) + quantized = convert_to_reference_fx( + prepared, + backend_config=self.trt_backend_config_dict, + ) + node_occurrence = { + # input activation, output activation and weight + ns.call_function(torch.quantize_per_tensor): 3, + ns.call_function(torch.addmm): 1, + ns.call_method("dequantize"): 3, + } + self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence) + + @unittest.skip( + "This is not supported yet, we can enable the test after it's supported" + ) + def test_conv_add(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 3, 3) + + def forward(self, x, y): + return self.conv(x) + y + + weighted_op_qint8_dtype_config = { + # optional, input activation dtype + "input_dtype": torch.qint8, + # optional, weight dtype + "weight_dtype": torch.qint8, + # optional, bias dtype + "bias_dtype": torch.float, + # optional, output activation dtype + "output_dtype": torch.qint8, + } + + def conv_add_root_node_getter(pattern): + (_, conv, _) = pattern + return conv + + def conv_add_extra_inputs_getter(pattern): + _, _, extra_input = pattern + return [extra_input] + + conv_add_config = { + "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, + "dtype_configs": [ + weighted_op_qint8_dtype_config, + ], + "root_node_getter": conv_add_root_node_getter, + "extra_inputs_getter": conv_add_extra_inputs_getter, + "root_module": torch.nn.Conv2d, + "reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d, + } + + if torch.__version__.startswith("1"): + conv_add_config["pattern"] = (operator.add, torch.nn.Conv2d, MatchAllNode) + else: + conv_add_config["pattern_complex_format"] = ( + operator.add, + torch.nn.Conv2d, + MatchAllNode, + ) + + m = M().eval() + modified_backend_config_dict = copy.deepcopy(self.trt_backend_config_dict) + modified_backend_config_dict["configs"].insert(0, conv_add_config) + example_inputs = (torch.rand(1, 3, 3, 3), torch.rand(1, 3, 1, 1)) + m = prepare_fx( + m, + {"": self.trt_qconfig}, + example_inputs, + backend_config=modified_backend_config_dict, + ) + node_occurrence = { + ns.call_module(torch.ao.quantization.HistogramObserver): 3, + } + self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) + m = convert_to_reference_fx(m, backend_config=modified_backend_config_dict) + node_occurrence = { + ns.call_function(torch.quantize_per_tensor): 3, + ns.call_method("dequantize"): 3, + } + self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) + + def test_conv_add_standalone_module(self): + class Standalone(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 3, 3) + self.relu = torch.nn.ReLU() + + def forward(self, x, y): + return self.relu(self.conv(x) + y) + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 3, 3) + self.standalone = Standalone() + + def forward(self, x): + y = self.conv(x) + return self.standalone(x, y) + + from torch.ao.quantization.backend_config import ObservationType + + 
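+        # Each dtype config entry below tells the backend_config machinery
+        # which input/weight/bias/output dtypes the fused conv-add(-relu)
+        # pattern supports; prepare_fx only inserts observers for patterns
+        # whose qconfig is compatible with one of these entries.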
weighted_op_quint8_dtype_config = { + # optional, input activation dtype + # TODO: change back to torch.qint8 after input_quantized_idxs and output_quantized_idxs + # are more flexible + "input_dtype": torch.quint8, + # optional, weight dtype + "weight_dtype": torch.qint8, + # optional, bias dtype + "bias_dtype": torch.float, + # optional, output activation dtype + "output_dtype": torch.quint8, + } + + conv_add_config = { + "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, + "dtype_configs": [ + weighted_op_quint8_dtype_config, + ], + "root_module": torch.nn.Conv2d, + # "reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d, + } + + if torch.__version__.startswith("1"): + conv_add_config["pattern"] = ( + torch.nn.ReLU, + (operator.add, torch.nn.Conv2d, MatchAllNode), + ) + else: + conv_add_config["pattern_complex_format"] = ( + torch.nn.ReLU, + (operator.add, torch.nn.Conv2d, MatchAllNode), + ) + + conv_config = { + "pattern": torch.nn.Conv2d, + "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, + "dtype_configs": [ + weighted_op_quint8_dtype_config, + ], + "root_module": torch.nn.Conv2d, + # "reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d, + } + + m = M().eval() + backend_config_dict = { + "configs": [ + conv_add_config, + conv_config, + ] + } + sm_example_inputs = (torch.rand(1, 3, 3, 3), torch.rand(1, 3, 1, 1)) + prepare_custom_config_dict = { + "standalone_module_name": [ + ( + "standalone", + None, + sm_example_inputs, + {"input_quantized_idxs": [0, 1]}, + None, + ) + ] + } + # TODO: use self.trt_qconfig after input_quantized_idxs and output_quantized_idxs + # are more flexible + qconfig = torch.ao.quantization.QConfig( + activation=torch.ao.quantization.observer.HistogramObserver.with_args( + qscheme=torch.per_tensor_symmetric, dtype=torch.quint8 + ), + weight=torch.ao.quantization.default_weight_observer, + ) + example_inputs = (torch.rand(1, 3, 5, 5),) + m = prepare_fx( + m, + {"": qconfig}, + example_inputs, + prepare_custom_config=prepare_custom_config_dict, + backend_config=backend_config_dict, + ) + node_occurrence = { + # for input and output of conv, where input is used twice, once in conv and + # once in standalone module + ns.call_module(torch.ao.quantization.HistogramObserver): 2, + } + self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) + standalone_node_occurrence = { + # output of the standalone module + ns.call_module(torch.ao.quantization.HistogramObserver): 1, + } + self.checkGraphModuleNodes( + m.standalone, expected_node_occurrence=standalone_node_occurrence + ) + m = convert_to_reference_fx(m, backend_config=backend_config_dict) + node_occurrence = { + # two inputs for standalone module + ns.call_function(torch.quantize_per_tensor): 2, + ns.call_module(nn.Conv2d): 1, + ns.call_method("dequantize"): 1, + } + self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) + standalone_node_occurrence = { + # output for the pattern in standalone module + ns.call_function(torch.quantize_per_tensor): 1, + ns.call_module(nn.Conv2d): 1, + ns.call_module(torch.nn.ReLU): 1, + # two input and one output for the pattern in standalone module + ns.call_method("dequantize"): 3, + } + self.checkGraphModuleNodes( + m.standalone, expected_node_occurrence=standalone_node_occurrence + ) + + def test_quant_dequant_not_fold(self): + class LinearModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(5, 10).float() + 
self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.relu(self.linear(x)) + + model = LinearModule().eval() + inputs = [torch.rand(8, 5)] + example_inputs = tuple(inputs) + prepared = prepare_fx( + model, + {"": self.trt_qconfig}, + example_inputs, + backend_config=self.trt_backend_config_dict, + ) + quantized = convert_to_reference_fx( + prepared, + backend_config=self.trt_backend_config_dict, + ) + + model = acc_tracer.trace(quantized, inputs) + model = run_const_fold(model) + + no_const = { + ns.call_function(acc_ops.quantize_per_tensor): 3, + ns.call_function(acc_ops.dequantize): 3, + } + self.checkGraphModuleNodes(model, expected_node_occurrence=no_const) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/test/tools/test_model_packager.py b/py/torch_tensorrt/dynamo/test/tools/test_model_packager.py new file mode 100644 index 0000000000..00293ccadb --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/tools/test_model_packager.py @@ -0,0 +1,56 @@ +import io +import unittest + +import torch +import torch.fx +from torch import nn +from torch.package import PackageImporter +from torch_tensorrt.dynamo.tools.model_packager import ( + generate_standalone_repro, + ModelPackager, +) + + +class TestModel(nn.Module): + def __init__(self): + super().__init__() + self.a = torch.nn.Module() + self.b = torch.nn.Module() + self.a.weights = torch.nn.Parameter(torch.randn(1, 2)) + self.b.weights = torch.nn.Parameter( + torch.randn( + 1, + ) + ) + + def forward(self, x): + return x + self.a.weights + self.b.weights + + +class ModelPackagerTest(unittest.TestCase): + def test_text_repro_gen(self): + model = torch.fx.symbolic_trace(TestModel().eval()) + inputs = [torch.randn(1)] + _ = model(*inputs) + + string_io = io.StringIO() + generate_standalone_repro(model, string_io, "\n# hello") + string_io.seek(0) + exec(string_io.read()) + exported_model = locals()["ExportedModule"]() + _ = exported_model(*inputs) + + def test_package_model(self): + model = torch.fx.symbolic_trace(TestModel().eval()) + inputs = [torch.randn(1)] + _ = model(*inputs) + bytesIO = io.BytesIO() + ModelPackager.package_model(model, inputs, bytesIO) + bytesIO.seek(0) + pi = PackageImporter(bytesIO) + reload_model = pi.load_pickle("repro", "model") + reload_inputs = pi.load_pickle("repro", "inputs") + + torch.testing.assert_close(model(*inputs), reload_model(*reload_inputs)) + keys = dict(reload_model.named_children()).keys() + self.assertEqual(keys, {"_holder"}) diff --git a/py/torch_tensorrt/dynamo/test/tracer/test_acc_shape_prop.py b/py/torch_tensorrt/dynamo/test/tracer/test_acc_shape_prop.py new file mode 100644 index 0000000000..a2f842b722 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/tracer/test_acc_shape_prop.py @@ -0,0 +1,98 @@ +# Owner(s): ["oncall: fx"] + +import operator +import unittest + +import torch + +import torch_tensorrt.fx.tracer.acc_tracer.acc_shape_prop as acc_shape_prop +import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer +from parameterized import param, parameterized + +torch.manual_seed(0) + + +class AccShapePropTest(unittest.TestCase): + @parameterized.expand( + [ + param("fp32", dtype=torch.float32), + param("fp16", dtype=torch.float16), + ] + ) + def test_basic(self, _, dtype): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.attr = torch.nn.Parameter(torch.randn(3, 4)) + self.submod = torch.nn.Linear(4, 4) + + def forward(self, x): + return torch.neg(self.submod(x.relu() + self.attr)) + + m = TestModule() + 
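+        # For the fp16 case, convert the module's parameters with .half() so
+        # that AccShapeProp records float16 tensor metadata on every node.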
+        if dtype == torch.float16:
+            m.half()
+        gm = acc_tracer.rewriter_base_trace(m, None, None)
+        inp = torch.rand(3, 4, dtype=dtype)
+        acc_shape_prop.AccShapeProp(gm).propagate(inp)
+
+        for node in gm.graph.nodes:
+            self.assertEqual(node.meta["tensor_meta"].dtype, dtype)
+
+    def test_multi_dtype(self):
+        class TestModule(torch.nn.Module):
+            def forward(self, x, y):
+                return torch.relu(x * 2), torch.sigmoid(y + y)
+
+        m = TestModule()
+        gm = acc_tracer.rewriter_base_trace(m, None, None)
+        # Note: One input is fp32, the other fp16.
+        x, y = torch.rand(3, 4), torch.rand(3, 4, dtype=torch.float16)
+        acc_shape_prop.AccShapeProp(gm).propagate(x, y)
+
+        for node in gm.graph.nodes:
+            if (node.op == "placeholder" and node.target == "x") or (
+                node.op == "call_function" and node.target in {operator.mul, torch.relu}
+            ):
+                self.assertEqual(node.meta["tensor_meta"].dtype, torch.float32)
+            elif node.op != "output":
+                self.assertEqual(node.meta["tensor_meta"].dtype, torch.float16)
+            else:
+                self.assertEqual(node.meta["tensor_meta"][0].dtype, torch.float32)
+                self.assertEqual(node.meta["tensor_meta"][1].dtype, torch.float16)
+
+    def test_to_dtype(self):
+        class TestModule(torch.nn.Module):
+            def forward(self, x):
+                return x.to(dtype=torch.float32).to(dtype=torch.float16)
+
+        m = TestModule()
+        gm = acc_tracer.rewriter_base_trace(m, None, None)
+        x = torch.rand(3, 4, dtype=torch.float16)
+        acc_shape_prop.AccShapeProp(gm).propagate(x)
+        ph = None
+        for node in gm.graph.nodes:
+            if node.op == "placeholder":
+                ph = node
+                self.assertEqual(node.meta["tensor_meta"].dtype, torch.float16)
+            elif node.all_input_nodes == [ph]:
+                self.assertEqual(node.meta["tensor_meta"].dtype, torch.float32)
+            else:
+                self.assertEqual(node.meta["tensor_meta"].dtype, torch.float16)
+
+    def test_split(self):
+        class TestModule(torch.nn.Module):
+            def forward(self, x):
+                s = torch.tensor_split(x, 2)
+                return s[0].relu(), s[1].sigmoid()
+
+        m = TestModule()
+        gm = acc_tracer.rewriter_base_trace(m, None, None)
+        x = torch.rand(2, 4, dtype=torch.float16)
+        acc_shape_prop.AccShapeProp(gm).propagate(x)
+        for node in gm.graph.nodes:
+            if node.target == torch.tensor_split or node.op == "output":
+                self.assertEqual(node.meta["tensor_meta"][0].dtype, torch.float16)
+                self.assertEqual(node.meta["tensor_meta"][1].dtype, torch.float16)
+            else:
+                self.assertEqual(node.meta["tensor_meta"].dtype, torch.float16)
diff --git a/py/torch_tensorrt/dynamo/test/tracer/test_acc_tracer.py b/py/torch_tensorrt/dynamo/test/tracer/test_acc_tracer.py
new file mode 100644
index 0000000000..633359127f
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/test/tracer/test_acc_tracer.py
@@ -0,0 +1,2801 @@
+# Owner(s): ["oncall: fx"]
+import logging
+import operator
+import unittest
+from typing import Callable, Dict, List, NamedTuple, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+import torch_tensorrt.fx.tracer.acc_tracer.acc_normalizer as acc_normalizer
+import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
+import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
+import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
+import torchvision
+from parameterized import param, parameterized

torch.manual_seed(0)

_LOGGER: logging.Logger = logging.getLogger(__name__)

torch.fx.wrap("len")


class AccTracerTest(unittest.TestCase):
    def _make_model_unit_test(
        self,
        model,
        *args,
        input_shape=None,
        enable_allclose=False,
        **kwargs,
    ):
        """
        Test that the model can be traced correctly and produces the correct
        result.
        """
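+        # Trace once and compare against eager execution, then trace the
+        # traced module again to verify tracing is stable (idempotent).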
+        if input_shape is None:
+            input_shape = [1, 3, 224, 224]
+        input = torch.randn(input_shape)
+        traced = acc_tracer.trace(model, [input])
+        if enable_allclose:
+            torch.testing.assert_close(model(input), traced(input))
+        else:
+            self.assertTrue(torch.equal(model(input), traced(input)))
+        traced_again = acc_tracer.trace(traced, [input])
+        if enable_allclose:
+            torch.testing.assert_close(model(input), traced_again(input))
+        else:
+            self.assertTrue(torch.equal(model(input), traced_again(input)))
+
+    def _make_acc_op_function_test(
+        self,
+        acc_op: Callable,
+        torch_op,
+        *args,
+        input_shape=(2, 3),
+        validate_same_kwargs=True,
+        enable_allclose=False,
+        **kwargs,
+    ):
+        """
+        Test that torch_op is traced into the expected acc_op.
+        """
+
+        class TestModule(torch.nn.Module):
+            def __init__(self, torch_op, args, kwargs):
+                super().__init__()
+                self._torch_op = torch_op
+                self._args = args
+                self._kwargs = kwargs
+
+            def forward(self, a: torch.Tensor) -> torch.Tensor:
+                return self._torch_op(a, *self._args, **self._kwargs)
+
+        m = TestModule(torch_op, args, kwargs)
+        m.eval()
+        a = torch.randn(*input_shape)
+        traced = acc_tracer.trace(m, [a])
+        ph_a = acc_op_node = None
+        for node in traced.graph.nodes:
+            if node.op == "placeholder":
+                if str(node.target) == "a":
+                    ph_a = node
+            elif node.op == "call_function":
+                self.assertEqual(node.target, acc_op)
+                self.assertEqual(node.kwargs["input"], ph_a)
+                if validate_same_kwargs:
+                    for key, value in kwargs.items():
+                        self.assertEqual(node.kwargs[key], value)
+                acc_op_node = node
+            elif node.op == "output":
+                if acc_op is None:
+                    # If we expect no new acc_op after graph building and the
+                    # traced graph contains only the output node, skip the check.
+                    continue
+                self.assertEqual(acc_op_node, node.args[0])
+            else:
+                self.fail(f"Unexpected node: {node.format_node()}")
+
+        ref_outputs = m(a)
+        outputs = traced(a)
+        traced_again = acc_tracer.trace(traced, [a])
+        outputs_again = traced_again(a)
+        if isinstance(ref_outputs, torch.Tensor):
+            ref_outputs = [ref_outputs]
+            outputs = [outputs]
+            outputs_again = [outputs_again]
+
+        for ref_output, output, output_again in zip(
+            ref_outputs, outputs, outputs_again
+        ):
+            if enable_allclose:
+                torch.testing.assert_close(
+                    torch.nan_to_num(ref_output), torch.nan_to_num(output)
+                )
+                torch.testing.assert_close(
+                    torch.nan_to_num(ref_output), torch.nan_to_num(output_again)
+                )
+            else:
+                self.assertTrue(
+                    torch.equal(torch.nan_to_num(ref_output), torch.nan_to_num(output))
+                )
+                self.assertTrue(
+                    torch.equal(
+                        torch.nan_to_num(ref_output), torch.nan_to_num(output_again)
+                    )
+                )
+
+    def test_sum(self):
+        self._make_acc_op_function_test(acc_ops.sum, torch.sum)
+        self._make_acc_op_function_test(acc_ops.sum, torch.sum, dim=(1,), keepdim=True)
+
+    def test_prod(self):
+        self._make_acc_op_function_test(acc_ops.prod, torch.prod)
+        self._make_acc_op_function_test(acc_ops.prod, torch.prod, dim=1, keepdim=True)
+
+    def test_mean(self):
+        self._make_acc_op_function_test(acc_ops.mean, torch.mean)
+        self._make_acc_op_function_test(
+            acc_ops.mean, torch.mean, dim=(1,), keepdim=True
+        )
+
+    def test_pad(self):
+        self._make_acc_op_function_test(
+            acc_ops.pad, torch.nn.functional.pad, pad=(2, 0)
+        )
+
+    def test_max(self):
+        def torch_max(x, *args, **kwargs):
+            return x.max(*args, **kwargs)
+
+        self._make_acc_op_function_test(acc_ops.max_full_reduce, torch_max)
+        self._make_acc_op_function_test(
+            acc_ops.max_dim_reduce, torch_max, dim=1, keepdim=True
+        )
+        self._make_acc_op_function_test(
+            acc_ops.max_dim_reduce, torch_max, input_shape=(1, 4), 
dim=1, keepdim=True + ) + self._make_acc_op_function_test( + acc_ops.max_dim_reduce, torch_max, input_shape=(3, 4, 3), dim=2 + ) + + @parameterized.expand( + [ + param("max_maximum", orig_op=torch.max, expected_op=acc_ops.maximum), + param( + "maximum_maximum", orig_op=torch.maximum, expected_op=acc_ops.maximum + ), + param("min_minimum", orig_op=torch.min, expected_op=acc_ops.minimum), + param( + "minimum_minimum", orig_op=torch.minimum, expected_op=acc_ops.minimum + ), + ] + ) + def test_maximum_minimum(self, _: str, orig_op, expected_op): + class TestModule(torch.nn.Module): + def __init__(self, orig_op): + super().__init__() + self.orig_op = orig_op + + def forward(self, input: torch.Tensor, other: torch.Tensor) -> torch.Tensor: + return self.orig_op(input, other) + + m = TestModule(orig_op) + input, other = torch.randn(2, 2), torch.randn(2, 2) + traced = acc_tracer.trace(m, [input, other]) + + ph_in = ph_oth = mxm = None + for node in traced.graph.nodes: + if node.op == "placeholder": + if str(node.target) == "other": + ph_oth = node + else: + self.assertTrue(str(node.target) == "input") + ph_in = node + elif node.op == "call_function": + if node.target == expected_op: + self.assertEqual(node.kwargs["input"], ph_in) + self.assertEqual(node.kwargs["other"], ph_oth) + mxm = node + elif node.op == "output": + self.assertEqual(mxm, node.args[0]) + else: + self.fail(f"Unexpected node: {node.format_node()}") + + self.assertTrue(torch.equal(m(input, other), traced(input, other))) + + def test_conv(self): + """ + Test that a conv is traced as expected. + """ + + class TestModule(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(8, 7, 3, stride=2) + + def forward(self, a: torch.Tensor) -> torch.Tensor: + return self.conv(a) + + m = TestModule() + input = torch.randn(3, 8, 10, 10) + traced = acc_tracer.trace(m, [input]) + + ph = weight_attr = bias_attr = conv = None + for node in traced.graph.nodes: + if node.op == "placeholder": + self.assertEqual(str(node.target), "a") + ph = node + elif node.op == "get_attr" and node.target == "conv.weight": + weight_attr = node + elif node.op == "get_attr" and node.target == "conv.bias": + bias_attr = node + elif node.op == "call_function": + self.assertEqual(node.target, acc_ops.conv2d) + self.assertEqual(node.kwargs["input"], ph) + self.assertEqual(node.kwargs["weight"], weight_attr) + self.assertEqual(node.kwargs["bias"], bias_attr) + self.assertEqual(node.kwargs["stride"], (2, 2)) + self.assertEqual(node.kwargs["padding"], (0, 0)) + self.assertEqual(node.kwargs["dilation"], (1, 1)) + self.assertEqual(node.kwargs["groups"], 1) + conv = node + elif node.op == "output": + self.assertEqual(conv, node.args[0]) + else: + self.fail(f"Unexpected node: {node.format_node()}") + + self.assertTrue(torch.equal(m(input), traced(input))) + + def test_quantized_conv2d(self): + class TestModule(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.quantized.Conv2d(3, 3, 1) + + def forward(self, a: torch.Tensor) -> torch.Tensor: + return self.conv(a) + + m = TestModule() + input = torch.quantize_per_tensor( + torch.randn(1, 3, 1, 1), scale=0.01, zero_point=3, dtype=torch.quint8 + ) + traced = acc_tracer.trace(m, [input]) + _LOGGER.info(traced.graph) + ph = weight_attr = bias_attr = conv = None + for node in traced.graph.nodes: + if node.op == "placeholder": + self.assertEqual(str(node.target), "a") + ph = node + elif node.op == "get_attr" and node.target == "conv_weight": + weight_attr = node + elif node.op == "get_attr" and 
node.target == "conv_bias": + bias_attr = node + elif node.op == "call_function": + self.assertEqual(node.target, acc_ops.quantized_conv2d) + self.assertEqual(node.kwargs["input"], ph) + self.assertEqual(node.kwargs["weight"], weight_attr) + self.assertEqual(node.kwargs["bias"], bias_attr) + conv = node + elif node.op == "output": + self.assertEqual(conv, node.args[0]) + else: + self.fail(f"Unexpected node: {node.format_node()}") + + self.assertTrue(torch.equal(m(input), traced(input))) + + def test_quantized_convrelu2d(self): + class TestModule(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.intrinsic.quantized.ConvReLU2d(3, 3, 1) + + def forward(self, a: torch.Tensor) -> torch.Tensor: + return self.conv(a) + + m = TestModule() + input = torch.quantize_per_tensor( + torch.randn(1, 3, 1, 1), scale=0.01, zero_point=3, dtype=torch.quint8 + ) + traced = acc_tracer.trace(m, [input]) + ph = weight_attr = bias_attr = conv = relu = None + for node in traced.graph.nodes: + if node.op == "placeholder": + self.assertEqual(str(node.target), "a") + ph = node + elif node.op == "get_attr" and node.target == "conv_weight": + weight_attr = node + elif node.op == "get_attr" and node.target == "conv_bias": + bias_attr = node + elif node.op == "call_function" and node.target == acc_ops.quantized_conv2d: + self.assertEqual(node.target, acc_ops.quantized_conv2d) + self.assertEqual(node.kwargs["input"], ph) + self.assertEqual(node.kwargs["weight"], weight_attr) + self.assertEqual(node.kwargs["bias"], bias_attr) + conv = node + elif node.op == "call_function" and node.target == acc_ops.relu: + self.assertEqual(node.target, acc_ops.relu) + self.assertEqual(node.kwargs["input"], conv) + relu = node + elif node.op == "output": + self.assertEqual(relu, node.args[0]) + else: + self.fail(f"Unexpected node: {node.format_node()}") + + self.assertTrue(torch.equal(m(input), traced(input))) + + def test_conv1d(self): + """ + Test that a conv is traced as expected. + """ + + class TestModule(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv1d(8, 7, 3, stride=2) + + def forward(self, a: torch.Tensor) -> torch.Tensor: + return self.conv(a) + + m = TestModule() + input = torch.randn(3, 8, 8) + traced = acc_tracer.trace(m, [input]) + + ph = weight_attr = bias_attr = conv = None + for node in traced.graph.nodes: + if node.op == "placeholder": + self.assertEqual(str(node.target), "a") + ph = node + elif node.op == "get_attr" and node.target == "conv.weight": + weight_attr = node + elif node.op == "get_attr" and node.target == "conv.bias": + bias_attr = node + elif node.op == "call_function": + self.assertEqual(node.target, acc_ops.conv1d) + self.assertEqual(node.kwargs["input"], ph) + self.assertEqual(node.kwargs["weight"], weight_attr) + self.assertEqual(node.kwargs["bias"], bias_attr) + self.assertEqual(node.kwargs["stride"], (2,)) + self.assertEqual(node.kwargs["padding"], (0,)) + self.assertEqual(node.kwargs["dilation"], (1,)) + self.assertEqual(node.kwargs["groups"], 1) + conv = node + elif node.op == "output": + self.assertEqual(conv, node.args[0]) + else: + self.fail(f"Unexpected node: {node.format_node()}") + + self.assertTrue(torch.equal(m(input), traced(input))) + + def test_conv3d(self): + """ + Test that a conv is traced as expected. 
+ """ + + class TestModule(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv3d(8, 7, 3, stride=2) + + def forward(self, a: torch.Tensor) -> torch.Tensor: + return self.conv(a) + + m = TestModule() + input = torch.randn(3, 8, 8, 10, 10) + traced = acc_tracer.trace(m, [input]) + + ph = weight_attr = bias_attr = conv = None + for node in traced.graph.nodes: + if node.op == "placeholder": + self.assertEqual(str(node.target), "a") + ph = node + elif node.op == "get_attr" and node.target == "conv.weight": + weight_attr = node + elif node.op == "get_attr" and node.target == "conv.bias": + bias_attr = node + elif node.op == "call_function": + self.assertEqual(node.target, acc_ops.conv3d) + self.assertEqual(node.kwargs["input"], ph) + self.assertEqual(node.kwargs["weight"], weight_attr) + self.assertEqual(node.kwargs["bias"], bias_attr) + self.assertEqual(node.kwargs["stride"], (2, 2, 2)) + self.assertEqual(node.kwargs["padding"], (0, 0, 0)) + self.assertEqual(node.kwargs["dilation"], (1, 1, 1)) + self.assertEqual(node.kwargs["groups"], 1) + conv = node + elif node.op == "output": + self.assertEqual(conv, node.args[0]) + else: + self.fail(f"Unexpected node: {node.format_node()}") + + self.assertTrue(torch.equal(m(input), traced(input))) + + def test_conv_transpose2d(self): + """ + Test that a conv_transpose2d is traced as expected. + """ + + class TestModule(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.ConvTranspose2d(8, 7, 3, stride=2) + + def forward(self, a: torch.Tensor) -> torch.Tensor: + return self.conv(a) + + m = TestModule() + input = torch.randn(3, 8, 10, 10) + traced = acc_tracer.trace(m, [input]) + + ph = weight_attr = bias_attr = conv = None + for node in traced.graph.nodes: + if node.op == "placeholder": + self.assertEqual(str(node.target), "a") + ph = node + elif node.op == "get_attr" and node.target == "conv.weight": + weight_attr = node + elif node.op == "get_attr" and node.target == "conv.bias": + bias_attr = node + elif node.op == "call_function": + self.assertEqual(node.target, acc_ops.conv_transpose2d) + self.assertEqual(node.kwargs["input"], ph) + self.assertEqual(node.kwargs["weight"], weight_attr) + self.assertEqual(node.kwargs["bias"], bias_attr) + self.assertEqual(node.kwargs["stride"], (2, 2)) + self.assertEqual(node.kwargs["padding"], (0, 0)) + self.assertEqual(node.kwargs["output_padding"], (0, 0)) + self.assertEqual(node.kwargs["groups"], 1) + self.assertEqual(node.kwargs["dilation"], (1, 1)) + conv = node + elif node.op == "output": + self.assertEqual(conv, node.args[0]) + else: + self.fail(f"Unexpected node: {node.format_node()}") + + self.assertTrue(torch.equal(m(input), traced(input))) + + def test_conv_transpose3d(self): + """ + Test that a conv_transpose3d is traced as expected. 
+ """ + + class TestModule(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.ConvTranspose3d(8, 7, 3, stride=2) + + def forward(self, a: torch.Tensor) -> torch.Tensor: + return self.conv(a) + + m = TestModule() + input = torch.randn(3, 8, 8, 10, 10) + traced = acc_tracer.trace(m, [input]) + + ph = weight_attr = bias_attr = conv = None + for node in traced.graph.nodes: + if node.op == "placeholder": + self.assertEqual(str(node.target), "a") + ph = node + elif node.op == "get_attr" and node.target == "conv.weight": + weight_attr = node + elif node.op == "get_attr" and node.target == "conv.bias": + bias_attr = node + elif node.op == "call_function": + self.assertEqual(node.target, acc_ops.conv_transpose3d) + self.assertEqual(node.kwargs["input"], ph) + self.assertEqual(node.kwargs["weight"], weight_attr) + self.assertEqual(node.kwargs["bias"], bias_attr) + self.assertEqual(node.kwargs["stride"], (2, 2, 2)) + self.assertEqual(node.kwargs["padding"], (0, 0, 0)) + self.assertEqual(node.kwargs["output_padding"], (0, 0, 0)) + self.assertEqual(node.kwargs["dilation"], (1, 1, 1)) + self.assertEqual(node.kwargs["groups"], 1) + conv = node + elif node.op == "output": + self.assertEqual(conv, node.args[0]) + else: + self.fail(f"Unexpected node: {node.format_node()}") + + self.assertTrue(torch.equal(m(input), traced(input))) + + def test_embedding_bag(self): + """ + Test that an embedding_bag is traced as expected. + """ + + class TestModule(nn.Module): + def __init__(self): + super().__init__() + self.eb = nn.EmbeddingBag(10, 3, mode="sum", include_last_offset=True) + + def forward(self, inp: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor: + return self.eb(inp, offsets) + + m = TestModule() + inp = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9]) + offsets = torch.LongTensor([0, 4]) + traced = acc_tracer.trace(m, [inp, offsets]) + + inp_node = offsets_node = weight_attr = eb_node = None + for node in traced.graph.nodes: + if node.op == "placeholder": + if str(node.target) == "inp": + inp_node = node + elif str(node.target) == "offsets": + offsets_node = node + else: + self.fail(f"Unexpected placeholder {node.target}.") + continue + elif node.op == "get_attr" and node.target == "eb.weight": + weight_attr = node + elif node.op == "call_function": + self.assertEqual(node.target, acc_ops.embedding_bag) + # Note: Normalization called from acc_tracer means we use all kwargs. + self.assertEqual(node.kwargs["input"], inp_node) + self.assertEqual(node.kwargs["offsets"], offsets_node) + self.assertEqual(node.kwargs["weight"], weight_attr) + self.assertEqual(node.kwargs["mode"], "sum") + self.assertEqual(node.kwargs["include_last_offset"], True) + # The rest of these were unspecified, so verify they fell back + # to their respective default values thanks to normalization. + self.assertEqual(node.kwargs["max_norm"], None) + self.assertEqual(node.kwargs["norm_type"], 2.0) + self.assertEqual(node.kwargs["scale_grad_by_freq"], False) + self.assertEqual(node.kwargs["sparse"], False) + self.assertEqual(node.kwargs["per_sample_weights"], None) + eb_node = node + elif node.op == "output": + self.assertEqual(eb_node, node.args[0]) + + self.assertTrue(torch.equal(m(inp, offsets), traced(inp, offsets))) + + def test_embedding_bag_byte_and_4bit_rowwise_offsets(self): + """ + Test that 4 bit quantized embedding_bag is traced as expected. 
+ """ + + class TestModule(nn.Module): + def __init__( + self, + op, + q_weights, + per_index_weights, + ): + super().__init__() + self.emb = op + self.q_weights = q_weights + self.per_index_weights = per_index_weights + + def forward( + self, + indices, + offsets, + ): + return self.emb( + self.q_weights, + indices, + offsets, + mode=0, + per_sample_weights=self.per_index_weights, + include_last_offset=True, + ) + + def run_embedding_bag_test(is_4bit, use_weights): + # generate random indices, offsets, and weights. + num_embeddings = 16 + embedding_dim = 32 + num_lengths = 10 + + weights = torch.from_numpy( + (np.random.random_sample((num_embeddings, embedding_dim)) + 1).astype( + np.float32 + ) + ) + q_weights = ( + torch.ops.quantized.embedding_bag_4bit_prepack(weights) + if is_4bit + else torch.ops.quantized.embedding_bag_byte_prepack(weights) + ) + np_lengths = np.random.randint(0, num_lengths, size=10).astype(np.int32) + + num_lengths = np.sum(np_lengths) + indices = torch.from_numpy( + np.random.randint(low=0, high=num_embeddings, size=num_lengths) + ).int() + + lengths = torch.from_numpy(np_lengths) + offsets = torch.cat([torch.zeros([1]), torch.cumsum(lengths, 0)]).int() + + weights = torch.randint(low=0, high=4, size=indices.size()) + per_sample_weights = weights.to(torch.float32) + + indices = indices.to(torch.int32) + offsets = offsets.to(torch.int32) + inputs = [ + indices, + offsets, + ] + + op = ( + torch.ops.quantized.embedding_bag_4bit_rowwise_offsets + if is_4bit + else torch.ops.quantized.embedding_bag_byte_rowwise_offsets + ) + + m = TestModule( + op, + q_weights, + per_sample_weights, + ) + + traced = acc_tracer.trace(m, inputs) + _LOGGER.info(traced.graph) + + expected_target = ( + acc_ops.embedding_bag_4bit_rowwise_offsets + if is_4bit + else acc_ops.embedding_bag_byte_rowwise_offsets + ) + + for node in traced.graph.nodes: + if node.op == "placeholder": + if str(node.target) == "indices": + inp_node = node + elif str(node.target) == "offsets": + offsets_node = node + else: + self.fail(f"Unexpected placeholder {node.target}.") + continue + elif node.op == "get_attr" and node.target == "q_weights": + weight_attr = node + elif node.op == "call_function": + self.assertEqual(node.target, expected_target) + # Note: Normalization called from acc_tracer means we use all kwargs. + self.assertEqual(node.kwargs["indices"], inp_node) + self.assertEqual(node.kwargs["offsets"], offsets_node) + self.assertEqual(node.kwargs["weight"], weight_attr) + self.assertEqual(node.kwargs["mode"], 0) + self.assertEqual(node.kwargs["include_last_offset"], True) + # The rest of these were unspecified, so verify they fell back + # to their respective default values thanks to normalization. 
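+                    # (Nothing to check explicitly here; the end-to-end numeric
+                    # comparison below also covers the defaulted kwargs.)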
+ eb_node = node + elif node.op == "output": + self.assertEqual(eb_node, node.args[0]) + self.assertTrue(torch.equal(m(indices, offsets), traced(indices, offsets))) + + # test 8-bit + run_embedding_bag_test(is_4bit=False, use_weights=True) + # test 4-bit + run_embedding_bag_test(is_4bit=True, use_weights=True) + + def test_quantized_batch_norm2d(self): + class TestModule(nn.Module): + def __init__(self): + super().__init__() + self.bn = nn.quantized.BatchNorm2d(3) + + def forward(self, a: torch.Tensor) -> torch.Tensor: + return self.bn(a) + + m = TestModule() + m.eval() + input = torch.quantize_per_tensor( + torch.randn(1, 3, 1, 1), scale=0.01, zero_point=3, dtype=torch.quint8 + ) + traced = acc_tracer.trace(m, [input]) + ph = weight_attr = bias_attr = bn_mean = bn_var = bn = None + for node in traced.graph.nodes: + if node.op == "placeholder": + self.assertEqual(str(node.target), "a") + ph = node + elif node.op == "get_attr" and node.target == "bn.weight": + weight_attr = node + elif node.op == "get_attr" and node.target == "bn.bias": + bias_attr = node + elif node.op == "get_attr" and node.target == "bn.running_mean": + bn_mean = node + elif node.op == "get_attr" and node.target == "bn.running_var": + bn_var = node + elif node.op == "get_attr" and node.target == "bn.scale": + bn_scale = node + elif node.op == "get_attr" and node.target == "bn.zero_point": + bn_zero_point = node + elif node.op == "call_function": + self.assertEqual(node.target, acc_ops.quantized_batch_norm2d) + self.assertEqual(node.kwargs["input"], ph) + self.assertEqual(node.kwargs["weight"], weight_attr) + self.assertEqual(node.kwargs["bias"], bias_attr) + self.assertEqual(node.kwargs["running_mean"], bn_mean) + self.assertEqual(node.kwargs["running_var"], bn_var) + self.assertEqual(node.kwargs["acc_out_ty"][6]["scale"], bn_scale) + self.assertEqual( + node.kwargs["acc_out_ty"][6]["zero_point"], bn_zero_point + ) + bn = node + elif node.op == "output": + self.assertEqual(bn, node.args[0]) + else: + self.fail(f"Unexpected node: {node.format_node()}") + + self.assertTrue(torch.equal(m(input), traced(input))) + + def test_linear(self): + """ + Test that a linear is traced as expected, i.e. to the functional level and with + kwarg normalization. Also verify that symbolic shape inference worked as part of + the acc_tracer. 
+ """ + + class TestModule(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(3, 5, bias=True) + + def forward(self, a: torch.Tensor) -> torch.Tensor: + return self.linear(a) + + m = TestModule() + test_input = torch.randn(1, 3) + traced = acc_tracer.trace(m, [test_input]) + ph = weight_attr = bias_attr = linear = None + for node in traced.graph.nodes: + if node.op == "placeholder": + self.assertEqual(str(node.target), "a") + ph = node + elif node.op == "get_attr" and node.target == "linear.weight": + weight_attr = node + elif node.op == "get_attr" and node.target == "linear.bias": + bias_attr = node + elif node.op == "call_function": + self.assertEqual(node.target, acc_ops.linear) + self.assertEqual(node.kwargs["input"], ph) + self.assertEqual(node.kwargs["weight"], weight_attr) + self.assertEqual(node.kwargs["bias"], bias_attr) + linear = node + elif node.op == "output": + self.assertEqual(linear, node.args[0]) + else: + self.fail(f"Unexpected node: {node.format_node()}") + self.assertTrue(torch.equal(m(test_input), traced(test_input))) + + def test_quantized_linear(self): + class TestModule(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.quantized.Linear(3, 5) + + def forward(self, a: torch.Tensor) -> torch.Tensor: + return self.linear(a) + + m = TestModule() + input = torch.quantize_per_tensor( + torch.randn(2, 3), scale=0.01, zero_point=3, dtype=torch.quint8 + ) + traced = acc_tracer.trace(m, [input]) + ph = weight_attr = bias_attr = linear = None + for node in traced.graph.nodes: + if node.op == "placeholder": + self.assertEqual(str(node.target), "a") + ph = node + elif node.op == "get_attr" and node.target == "linear_weight": + weight_attr = node + elif node.op == "get_attr" and node.target == "linear_bias": + bias_attr = node + elif node.op == "call_function": + self.assertEqual(node.target, acc_ops.quantized_linear) + self.assertEqual(node.kwargs["input"], ph) + self.assertEqual(node.kwargs["weight"], weight_attr) + self.assertEqual(node.kwargs["bias"], bias_attr) + linear = node + elif node.op == "output": + self.assertEqual(linear, node.args[0]) + else: + self.fail(f"Unexpected node: {node.format_node()}") + + self.assertTrue(torch.equal(m(input), traced(input))) + + @parameterized.expand( + [ + param("remove_exceptions_false", remove_exceptions=False), + param("remove_exceptions_true", remove_exceptions=True), + ] + ) + def test_batch_norm(self, _, remove_exceptions): + """ + Test that a batch norm is traced as expected, i.e. to the functional level + and with kwarg normalization. Note that we also expect to see a + ConditionalExceptionWrapper in the graph that the AST rewriter converted + from `if x: raise y`. + + """ + + class TestModule(nn.Module): + def __init__(self): + super().__init__() + self.bn = torch.nn.BatchNorm2d(2) + + def forward(self, a: torch.Tensor) -> torch.Tensor: + return self.bn(a) + + m = TestModule() + input = torch.randn(2, 2, 1, 1) + # Note: Explicitly not removing exceptions so that we can check they + # were found and exist below. 
+ traced = acc_tracer.trace( + m, + [input], + remove_exceptions=remove_exceptions, + ) + + ph = exception_wrapper = weight = bias = mean = var = bn = None + for node in traced.graph.nodes: + if node.op == "placeholder": + self.assertEqual(str(node.target), "a") + ph = node + elif node.op == "get_attr" and node.target == "bn.weight": + weight = node + elif node.op == "get_attr" and node.target == "bn.bias": + bias = node + elif node.op == "get_attr" and node.target == "bn.running_mean": + mean = node + elif node.op == "get_attr" and node.target == "bn.running_var": + var = node + elif node.op == "call_function" and node.target == acc_ops.batch_norm: + # Note: Normalization called from acc_tracer means we use + # all kwargs. + self.assertEqual(node.kwargs["input"], ph) + self.assertEqual(node.kwargs["weight"], weight) + self.assertEqual(node.kwargs["bias"], bias) + self.assertEqual(node.kwargs["running_mean"], mean) + self.assertEqual(node.kwargs["running_var"], var) + bn = node + elif ( + node.op == "call_module" + and node.target == "bn._conditional_exception_wrapper_ValueError" + ): + exception_wrapper = node + elif node.op == "output": + self.assertEqual(bn, node.args[0]) + + self.assertTrue(remove_exceptions or exception_wrapper is not None) + + self.assertTrue(torch.equal(m(input), traced(input))) + + def test_remove_asserts(self): + """ + Test that a Module with asserts has the asserts automatically removed, as + well as calls to a class method that should be dead. + """ + + class TestModule(nn.Module): + def __init__(self): + super().__init__() + + def _test_method(self, a): + return a + + def forward(self, a: torch.Tensor) -> torch.Tensor: + assert torch.equal(self._test_method(a), a) + return a + + m = TestModule() + input = torch.randn(10) + traced = acc_tracer.trace(m, [input], ast_rewriter_allow_list={TestModule}) + # Check we have no call_functions. If remove asserts didn't work + # correctly we would see a call to torch._assert, _test_method, and + # torch.equal. + for node in traced.graph.nodes: + self.assertFalse(node.op == "call_function") + + self.assertTrue(torch.equal(m(input), traced(input))) + + def test_no_rewrite_leaf_module(self): + """ + Test that when we supply a leaf module, we don't rewrite it + """ + + class TestChildModule(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a: torch.Tensor) -> torch.Tensor: + return a.relu() + + class TestModule(nn.Module): + def __init__(self): + super().__init__() + self.child = TestChildModule() + + def forward(self, a: torch.Tensor) -> torch.Tensor: + return self.child(a) + self.child(a) + + m = TestModule() + input = torch.randn(10) + traced = acc_tracer.trace(m, [input], leaf_module_list={TestChildModule}) + # trace it again just in case + traced = acc_tracer.trace(traced, [input], leaf_module_list={TestChildModule}) + + for _, m in traced.named_children(): + self.assertFalse("__AccRewrittenModule" in str(type(m)), str(type(m))) + + def test_sequential(self): + """ + Test that the tracer works for torch.nn.Sequential. 
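+        The Sequential container itself should be inlined away, leaving only
+        acc_ops.sigmoid and acc_ops.relu call_function nodes in the graph.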
+ """ + + class TestModule(nn.Module): + def __init__(self): + super().__init__() + self.model = nn.Sequential(nn.Sigmoid(), nn.ReLU()) + + def forward(self, a: torch.Tensor) -> torch.Tensor: + return self.model(a) + + m = TestModule() + input = torch.randn(10) + traced = acc_tracer.trace(m, [input]) + + for node in traced.graph.nodes: + if node.op == "call_function": + is_sigmoid = node.target == acc_ops.sigmoid + is_relu = node.target == acc_ops.relu + self.assertTrue(is_sigmoid or is_relu) + else: + self.assertTrue(node.op == "placeholder" or node.op == "output") + + self.assertTrue(torch.equal(m(input), traced(input))) + + def test_unsqueeze(self): + """ + Test that torch.unsqueeze is traced correctly. + """ + self._make_acc_op_function_test( + acc_ops.unsqueeze, + torch.unsqueeze, + validate_same_kwargs=False, + dim=1, + ) + + def test_stack(self): + """ + Test that torch.stack is traced correctly. + """ + + class TestModule(torch.nn.Module): + def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return torch.stack((a, b), dim=1) + + a, b = torch.randn(4, 5, 6), torch.randn(4, 5, 6) + mod = TestModule() + traced = acc_tracer.trace(mod, [a, b]) + self.assertTrue(torch.equal(mod(a, b), traced(a, b))) + + ph_a = ph_b = unsqueeze_a = unsqueeze_b = cat_node = None + for node in traced.graph.nodes: + if node.op == "placeholder": + if str(node.target) == "a": + ph_a = node + else: + self.assertTrue(str(node.target) == "b") + ph_b = node + elif node.op == "call_function": + if node.target == acc_ops.unsqueeze: + if node.kwargs["input"] is ph_a: + unsqueeze_a = node + else: + self.assertEqual(node.kwargs["input"], ph_b) + unsqueeze_b = node + else: + self.assertEqual(node.target, acc_ops.cat) + self.assertEqual(node.kwargs["tensors"], [unsqueeze_a, unsqueeze_b]) + cat_node = node + elif node.op == "output": + self.assertEqual(cat_node, node.args[0]) + else: + self.fail(f"Unexpected node: {node.format_node()}") + + def test_no_raise(self): + """ + self that we can trace `if x: raise y(msg)` when the raise isn't executed. + """ + + class TestModule(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a, b): + if torch.equal(a, b): + raise AssertionError("a equaled b!") + return a + + m = TestModule() + in_a, in_b = torch.randn(5), torch.randn(5) + traced = acc_tracer.trace( + m, + [in_a, in_b], + remove_exceptions=False, + use_acc_normalization=False, + ast_rewriter_allow_list={TestModule}, + ) + + # Verify the structure of the graph, including the existence of the + # exception_wrapper. + ph_a = exception_wrapper = None + for node in traced.graph.nodes: + if node.op == "placeholder": + if str(node.target) == "a": + ph_a = node + else: + self.assertTrue(str(node.target) == "b") + elif node.op == "call_module": + self.assertEqual( + node.target, "_conditional_exception_wrapper_AssertionError" + ) + exception_wrapper = node + elif node.op == "output": + self.assertEqual(ph_a, node.args[0]) + + self.assertTrue(exception_wrapper is not None) + + self.assertTrue(torch.equal(m(in_a, in_b), traced(in_a, in_b))) + + def test_yes_raise(self): + """ + Test that we can trace `if x: raise y(msg)` when the raise is executed. + """ + err_str = "a equaled b!" 
+ + class TestModule(nn.Module): + def __init__(self): + super().__init__() + self.err_str = err_str + + def forward(self, a, b): + if torch.equal(a, b): + raise RuntimeError(self.err_str) + return a + + m = TestModule() + # Note: We must use different inputs here in order for shape_prop to work, as + # otherwise the exception is thrown (as expected/checked below). + in_a, in_b = torch.randn(5), torch.randn(5) + traced = acc_tracer.trace( + m, + [in_a, in_b], + remove_exceptions=False, + ast_rewriter_allow_list={TestModule}, + ) + + # Verify the structure of the graph, including the existence of the + # exception_wrapper. + ph_a = exception_wrapper = None + for node in traced.graph.nodes: + if node.op == "placeholder": + if str(node.target) == "a": + ph_a = node + else: + self.assertTrue(str(node.target) == "b") + elif node.op == "call_module": + self.assertEqual( + node.target, "_conditional_exception_wrapper_RuntimeError" + ) + exception_wrapper = node + elif node.op == "output": + self.assertEqual(ph_a, node.args[0]) + + self.assertTrue(exception_wrapper is not None) + + def test(mod): + try: + # Note: Use the same input here to ensure the exception is thrown. + mod(in_a, in_a) + self.fail("Shouldn't get here because exception should be thrown.") + except RuntimeError as e: + self.assertEqual(err_str, str(e)) + + test(m) + test(traced) + + def test_remove_raise(self): + """ + Test that we can trace `if x: raise y(msg)` and then remove the exception_wrapper. + """ + + class TestModule(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a, b): + if torch.equal(a, b): + raise AssertionError("a equaled b!") + return a + + m = TestModule() + in_a, in_b = torch.randn(5), torch.randn(5) + traced = acc_tracer.trace( + m, + [in_a, in_b], + remove_exceptions=True, + ast_rewriter_allow_list={TestModule}, + ) + + # Verify the structure of the graph, including the existence of the + # exception_wrapper. + ph_a = None + for node in traced.graph.nodes: + if node.op == "placeholder": + if str(node.target) == "a": + ph_a = node + else: + self.assertTrue(str(node.target) == "b") + elif node.op == "output": + self.assertEqual(ph_a, node.args[0]) + else: + # Should not encounter any call_modules, e.g. to the + # exception_wrapper. + self.assertFalse(node.op == "call_module") + + # Note: Using input in_a twice for the tracer version, which would + # trigger the raise if it was still there. + self.assertTrue(torch.equal(m(in_a, in_b), traced(in_a, in_a))) + + def test_raise_no_message(self): + """ + Test that we can trace `if x: raise y` when `y` has no message. + """ + + class TestModule(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a, b): + if torch.equal(a, b): + raise AssertionError + return a + + m = TestModule() + in_a, in_b = torch.randn(5), torch.randn(5) + traced = acc_tracer.trace( + m, + [in_a, in_b], + remove_exceptions=False, + use_acc_normalization=False, + ast_rewriter_allow_list={TestModule}, + ) + + # Verify the structure of the graph, including the existence of the + # exception_wrapper. 
+ ph_a = exception_wrapper = None + for node in traced.graph.nodes: + if node.op == "placeholder": + if str(node.target) == "a": + ph_a = node + else: + self.assertTrue(str(node.target) == "b") + elif node.op == "call_module": + self.assertEqual( + node.target, "_conditional_exception_wrapper_AssertionError" + ) + exception_wrapper = node + elif node.op == "output": + self.assertEqual(ph_a, node.args[0]) + + self.assertTrue(exception_wrapper is not None) + self.assertTrue(torch.equal(m(in_a, in_b), traced(in_a, in_b))) + + def test_quantized_add(self): + """ + Test that a quantized_add and acc_ops.quantize_per_tensor are traced as expected, + verifying the acc_out_tys are set as expected. + """ + + class TestModule(nn.Module): + def __init__(self): + super().__init__() + self.q_input = torch.nn.quantized.Quantize( + scale=1.0 / 128, zero_point=5, dtype=torch.quint8 + ) + self.q_other = torch.nn.quantized.Quantize( + scale=1.0 / 128, zero_point=10, dtype=torch.quint8 + ) + + def forward(self, input: torch.Tensor, other: torch.Tensor) -> torch.Tensor: + return torch.ops.quantized.add( + self.q_input(input), + self.q_other(other), + scale=0.05, + zero_point=1, + ) + + m = TestModule() + input, other = torch.randn(2, 3, 4), torch.randn(2, 3, 4) + traced = acc_tracer.trace(m, [input, other]) + + input_ph = other_ph = q_input = q_other = q_add = None + for node in traced.graph.nodes: + if node.op == "placeholder": + if str(node.target) == "input": + input_ph = node + else: + self.assertTrue(str(node.target) == "other") + other_ph = node + elif ( + node.op == "call_function" + and node.target == acc_ops.quantize_per_tensor + ): + qparams = { + "scale": 1.0 / 128, + "zero_point": 5, + } + expected_md = acc_utils.build_raw_tensor_meta( + dtype=torch.quint8, + qparams=qparams, + ) + if node.kwargs["input"] == input_ph: + q_input = node + else: + self.assertTrue(node.kwargs["input"] == other_ph) + q_other = node + qparams_copy = qparams.copy() + qparams_copy["zero_point"] = 10 + expected_md = expected_md._replace(qparams=qparams_copy) + self.assertEqual(node.kwargs["acc_out_ty"], expected_md) + elif node.op == "call_function" and node.target == acc_ops.quantized_add: + self.assertEqual(node.kwargs["input"], q_input) + self.assertEqual(node.kwargs["other"], q_other) + qparams = { + "scale": 0.05, + "zero_point": 1, + } + expected_md = acc_utils.build_raw_tensor_meta(qparams=qparams) + self.assertEqual(node.kwargs["acc_out_ty"], expected_md) + q_add = node + elif node.op == "output": + self.assertEqual(q_add, node.args[0]) + else: + self.fail(f"Unexpected node: {node.format_node()}") + + self.assertTrue(torch.equal(m(input, other), traced(input, other))) + + def test_quantized_mul(self): + """ + Test that a quantized_mul and acc_ops.quantize_per_tensor are traced as expected, + verifying the acc_out_tys are set as expected. 
+ """ + + class TestModule(nn.Module): + def __init__(self): + super().__init__() + self.q_input = torch.nn.quantized.Quantize( + scale=1.0 / 128, zero_point=5, dtype=torch.quint8 + ) + self.q_other = torch.nn.quantized.Quantize( + scale=1.0 / 128, zero_point=10, dtype=torch.quint8 + ) + + def forward(self, input: torch.Tensor, other: torch.Tensor) -> torch.Tensor: + return torch.ops.quantized.mul( + self.q_input(input), + self.q_other(other), + scale=0.05, + zero_point=1, + ) + + m = TestModule() + input, other = torch.randn(2, 3, 4), torch.randn(2, 3, 4) + traced = acc_tracer.trace(m, [input, other]) + + input_ph = other_ph = q_input = q_other = q_add = None + for node in traced.graph.nodes: + if node.op == "placeholder": + if str(node.target) == "input": + input_ph = node + else: + self.assertTrue(str(node.target) == "other") + other_ph = node + elif ( + node.op == "call_function" + and node.target == acc_ops.quantize_per_tensor + ): + qparams = { + "scale": 1.0 / 128, + "zero_point": 5, + } + expected_md = acc_utils.build_raw_tensor_meta( + dtype=torch.quint8, + qparams=qparams, + ) + if node.kwargs["input"] == input_ph: + q_input = node + else: + self.assertTrue(node.kwargs["input"] == other_ph) + q_other = node + qparams_copy = qparams.copy() + qparams_copy["zero_point"] = 10 + expected_md = expected_md._replace(qparams=qparams_copy) + self.assertEqual(node.kwargs["acc_out_ty"], expected_md) + elif node.op == "call_function" and node.target == acc_ops.quantized_mul: + self.assertEqual(node.kwargs["input"], q_input) + self.assertEqual(node.kwargs["other"], q_other) + qparams = { + "scale": 0.05, + "zero_point": 1, + } + expected_md = acc_utils.build_raw_tensor_meta(qparams=qparams) + self.assertEqual(node.kwargs["acc_out_ty"], expected_md) + q_add = node + elif node.op == "output": + self.assertEqual(q_add, node.args[0]) + else: + self.fail(f"Unexpected node: {node.format_node()}") + + self.assertTrue(torch.equal(m(input, other), traced(input, other))) + + def test_cat(self): + """ + Test that torch.cat is traced correctly. + """ + + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return torch.cat([a, a, b], 0) + + m = TestModule() + a, b = torch.randn(2, 2), torch.randn(2, 2) + traced = acc_tracer.trace(m, (a, b)) + + ph_a = ph_b = cat = None + for node in traced.graph.nodes: + if node.op == "placeholder": + if str(node.target) == "a": + ph_a = node + else: + self.assertTrue(str(node.target) == "b") + ph_b = node + elif node.op == "call_function": + self.assertEqual(node.target, acc_ops.cat) + self.assertEqual(node.kwargs["tensors"][0], ph_a) + self.assertEqual(node.kwargs["tensors"][1], ph_a) + self.assertEqual(node.kwargs["tensors"][2], ph_b) + self.assertEqual(node.kwargs["dim"], 0) + cat = node + elif node.op == "output": + self.assertEqual(cat, node.args[0]) + else: + self.fail(f"Unexpected node: {node.format_node()}") + + self.assertTrue(torch.equal(m(a, b), traced(a, b))) + + def test_square(self): + """ + Test that torch.square is traced correctly. + """ + self._make_acc_op_function_test(acc_ops.mul, torch.square) + + def test_reshape(self): + """ + Test that torch.reshape is traced correctly. 
+ """ + self._make_acc_op_function_test(acc_ops.reshape, torch.reshape, (1, -1)) + # arg = (1, -1) + self._make_acc_op_function_test(acc_ops.reshape, lambda x: x.reshape(1, -1)) + # arg = ((1, -1)) + self._make_acc_op_function_test(acc_ops.reshape, lambda x: x.reshape((1, -1))) + + def test_transpose(self): + """ + Test that torch.transpose is traced correctly. + """ + self._make_acc_op_function_test( + acc_ops.permute, lambda x: torch.transpose(x, 1, 0) + ) + + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a: torch.Tensor) -> torch.Tensor: + x = len(a.shape) - 2 + y = len(a.shape) - 1 + return a.transpose(x, y) + + m = TestModule() + m.eval() + + a = torch.randn(2, 3, 4, 5) + traced = acc_tracer.trace(m, [a]) + + ph_a = permute = None + for node in traced.graph.nodes: + if node.op == "placeholder": + ph_a = node + elif node.op == "call_function": + self.assertEqual(node.target, acc_ops.permute) + self.assertEqual(node.kwargs["input"], ph_a) + self.assertEqual(node.kwargs["permutation"], [0, 1, 3, 2]) + permute = node + elif node.op == "output": + self.assertEqual(permute, node.args[0]) + else: + self.fail(f"Unexpected node: {node.format_node()}") + + self.assertTrue(torch.equal(m(a), traced(a))) + + def test_permute(self): + """ + Test that torch.permute is traced correctly. + """ + + def torch_permute(a, *dim): + return a.permute(*dim) + + self._make_acc_op_function_test(acc_ops.permute, torch_permute, 1, 0) + + def test_min_full_reduce(self): + """ + Test that test_min_full_reduce is traced correctly. + """ + self._make_acc_op_function_test(acc_ops.min_full_reduce, torch.min) + + def test_matmul(self): + """ + Test that torch.matmul is traced correctly. + """ + + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return torch.matmul(a, b) + + m = TestModule() + a, b = torch.randn(2, 2), torch.randn(2, 2) + traced = acc_tracer.trace(m, [a, b]) + + ph_a = ph_b = matmul = None + for node in traced.graph.nodes: + if node.op == "placeholder": + if str(node.target) == "a": + ph_a = node + else: + self.assertTrue(str(node.target) == "b") + ph_b = node + elif node.op == "call_function": + self.assertEqual(node.target, acc_ops.matmul) + self.assertEqual(node.kwargs["input"], ph_a) + self.assertEqual(node.kwargs["other"], ph_b) + matmul = node + elif node.op == "output": + self.assertEqual(matmul, node.args[0]) + else: + self.fail(f"Unexpected node: {node.format_node()}") + + self.assertTrue(torch.equal(m(a, b), traced(a, b))) + + def test_bmm(self): + self._make_acc_op_function_test( + acc_ops.matmul, lambda x: torch.bmm(x, x), input_shape=(2, 4, 4) + ) + + def test_tile(self): + return self._make_acc_op_function_test( + acc_ops.tile, lambda x: torch.tile(x, (2, 1, 2)), input_shape=(1, 2) + ) + + def test_dropout(self): + self._make_acc_op_function_test( + None, + lambda x: nn.functional.dropout(x, training=False), + input_shape=(1, 2, 3), + ) + + def test_stochastic_depth(self): + self._make_acc_op_function_test( + None, + lambda x, p, mode, training: torchvision.ops.stochastic_depth( + x, p=p, mode=mode, training=training + ), + input_shape=(1, 2, 3), + p=0.5, + mode="row", + training=False, + ) + + def test_hardsigmoid(self): + self._make_acc_op_function_test( + acc_ops.hardsigmoid, + lambda x: nn.functional.hardsigmoid(x), + input_shape=(3, 4, 5), + ) + + def test_hardtanh(self): + self._make_acc_op_function_test( + acc_ops.hardtanh, + 
lambda x: nn.functional.hardtanh(x), + input_shape=(3, 4, 5), + ) + + def test_hardswish(self): + class TestModule(nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + y = nn.functional.hardswish(x) + return y + + m = TestModule() + x = torch.randn(3, 4, 5) + traced = acc_tracer.trace(m, [x]) + ph_x = hardsigmoid_y = res_y = None + for node in traced.graph.nodes: + if node.op == "placeholder": + ph_x = node + elif node.op == "call_function" and node.target == acc_ops.hardsigmoid: + hardsigmoid_y = node + self.assertEqual(node.kwargs["input"], ph_x) + elif node.op == "call_function" and node.target == acc_ops.mul: + res_y = node + self.assertEqual(node.kwargs["input"], hardsigmoid_y) + self.assertEqual(node.kwargs["other"], ph_x) + elif node.op == "output": + self.assertEqual(node.args[0], res_y) + else: + self.fail(f"Unexpected node: {node.format_node()}") + + ref = m(x) + res = traced(x) + torch.testing.assert_close(ref, res) + + def test_add_with_alpha(self): + """ + Test that normalization works for torch add with alpha, which requires special + normalization handling. + """ + + class TestModule(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + a1 = torch.add(a, b) + a2 = torch.add(a, b, alpha=1.0) + a3 = torch.add(a, b, alpha=0.5) + return a1, a2, a3 + + m = TestModule() + input_a = torch.randn(2, 3) + input_b = torch.randn(2, 3) + traced = acc_tracer.trace(m, [input_a, input_b]) + + ph_a = ph_b = add_1 = add_2 = add_3 = mul = None + for node in traced.graph.nodes: + if node.op == "placeholder": + if str(node.target) == "a": + ph_a = node + elif str(node.target) == "b": + ph_b = node + else: + self.fail(f"Unexpected placeholder {node.target}.") + elif node.op == "call_function" and node.target == acc_ops.mul: + mul = node + self.assertEqual(node.kwargs["input"], ph_b) + self.assertEqual(node.kwargs["other"], 0.5) + elif node.op == "call_function" and node.target == acc_ops.add: + if add_1 is None: + add_1 = node + self.assertEqual(node.kwargs["input"], ph_a) + self.assertEqual(node.kwargs["other"], ph_b) + elif add_2 is None: + add_2 = node + self.assertEqual(node.kwargs["input"], ph_a) + self.assertEqual(node.kwargs["other"], ph_b) + elif add_3 is None: + add_3 = node + self.assertEqual(node.kwargs["input"], ph_a) + self.assertEqual(node.kwargs["other"], mul) + else: + self.fail(f"Unexpected add: {node.format_node()}") + elif node.op == "output": + self.assertEqual(node.args[0][0], add_1) + self.assertEqual(node.args[0][1], add_2) + self.assertEqual(node.args[0][2], add_3) + else: + self.fail(f"Unexpected node: {node.format_node()}") + + ref = m(input_a, input_b) + res = traced(input_a, input_b) + self.assertTrue(torch.equal(ref[0], res[0])) + self.assertTrue(torch.equal(ref[1], res[1])) + self.assertTrue(torch.equal(ref[2], res[2])) + + def test_leaf_module_list(self): + """ + Test leaf_module_list is working properly. 
+ """ + + class LeafModule(nn.Module): + def forward(self, x): + return x + + class TestModule(nn.Module): + def __init__(self): + super().__init__() + self.mod = LeafModule() + + def forward(self, x): + return self.mod(x) + + x = torch.randn(1, 1) + mod = TestModule() + acc_mod = acc_tracer.trace( + mod, + [x], + leaf_module_list={LeafModule}, + ) + ph = leaf_module = None + for node in acc_mod.graph.nodes: + if node.op == "placeholder": + ph = node + elif node.op == "call_module": + leaf_module = node + self.assertEqual(leaf_module.target, "mod") + self.assertEqual(leaf_module.args[0], ph) + elif node.op == "output": + self.assertEqual(node.args[0], leaf_module) + else: + self.fail(f"Unexpected node: {node.format_node()}") + self.assertTrue(torch.equal(mod(x), acc_mod(x))) + + def test_sign(self): + self._make_acc_op_function_test(acc_ops.sign, torch.sign) + + def test_relu(self): + self._make_acc_op_function_test(acc_ops.relu, torch.relu) + + def test_leaky_relu(self): + self._make_acc_op_function_test( + acc_ops.leaky_relu, torch.nn.functional.leaky_relu + ) + + def test_elu(self): + self._make_acc_op_function_test(acc_ops.elu, torch.nn.functional.elu) + + def test_selu(self): + self._make_acc_op_function_test(acc_ops.selu, torch.nn.functional.selu) + + def test_softsign(self): + self._make_acc_op_function_test(acc_ops.softsign, torch.nn.functional.softsign) + + def test_sigmoid(self): + self._make_acc_op_function_test(acc_ops.sigmoid, torch.sigmoid) + + def test_sin(self): + self._make_acc_op_function_test(acc_ops.sin, torch.sin) + + def test_cos(self): + self._make_acc_op_function_test(acc_ops.cos, torch.cos) + + def test_tan(self): + self._make_acc_op_function_test(acc_ops.tan, torch.tan) + + def test_sinh(self): + self._make_acc_op_function_test(acc_ops.sinh, torch.sinh) + + def test_cosh(self): + self._make_acc_op_function_test(acc_ops.cosh, torch.cosh) + + def test_tanh(self): + self._make_acc_op_function_test(acc_ops.tanh, torch.tanh) + + def test_asin(self): + self._make_acc_op_function_test(acc_ops.asin, torch.asin) + + def test_acos(self): + self._make_acc_op_function_test(acc_ops.acos, torch.acos) + + def test_atan(self): + self._make_acc_op_function_test(acc_ops.atan, torch.atan) + + def test_exp(self): + self._make_acc_op_function_test(acc_ops.exp, torch.exp) + + def test_log(self): + self._make_acc_op_function_test(acc_ops.log, torch.log) + + def test_sqrt(self): + self._make_acc_op_function_test(acc_ops.sqrt, torch.sqrt) + + def test_reciprocal(self): + self._make_acc_op_function_test(acc_ops.reciprocal, torch.reciprocal) + + def test_abs(self): + self._make_acc_op_function_test(acc_ops.abs, torch.abs) + + def test_neg(self): + self._make_acc_op_function_test(acc_ops.neg, torch.neg) + + def test_floor(self): + self._make_acc_op_function_test(acc_ops.floor, torch.floor) + + def test_ceil(self): + self._make_acc_op_function_test(acc_ops.ceil, torch.ceil) + + def test_softmax(self): + self._make_acc_op_function_test(acc_ops.softmax, torch.nn.functional.softmax) + + def test_tensor_squeeze(self): + self._make_acc_op_function_test(acc_ops.squeeze, lambda x: x.squeeze()) + + def test_torch_squeeze(self): + self._make_acc_op_function_test(acc_ops.squeeze, lambda x: torch.squeeze(x)) + + def test_operator_mul(self): + self._make_acc_op_function_test(acc_ops.mul, lambda x: x * 7) + + def test_torch_mul(self): + self._make_acc_op_function_test(acc_ops.mul, lambda x: torch.mul(x, 7)) + + def test_torch_isinf(self): + self._make_acc_op_function_test(acc_ops.isinf, torch.isinf) + + 
def test_torch_any(self): + self._make_acc_op_function_test(acc_ops.any, torch.any) + + def test_div(self): + self._make_acc_op_function_test(acc_ops.div, lambda x: torch.div(x, 2)) + self._make_acc_op_function_test(acc_ops.div, lambda x: x / 2) + + def test_fmod(self): + self._make_acc_op_function_test(acc_ops.fmod, lambda x: torch.fmod(x, 1.3)) + self._make_acc_op_function_test(acc_ops.fmod, lambda x: torch.fmod(x, -0.4)) + + def test_floor_div(self): + self._make_acc_op_function_test( + acc_ops.floor_div, lambda x: torch.div(x, 2, rounding_mode="floor") + ) + + def test_trunc_div(self): + self._make_acc_op_function_test( + acc_ops.trunc_div, lambda x: torch.div(x, 2, rounding_mode="trunc") + ) + # does not behave the same as floor_divide + # self._make_acc_op_function_test( + # acc_ops.trunc_div, lambda x: torch.floor_divide(x, 2) + # ) + + def test_view(self): + """ + Test that Tensor.view is traced correctly. + """ + + self._make_acc_op_function_test(acc_ops.reshape, lambda x: x.view(1, -1)) + self._make_acc_op_function_test(acc_ops.reshape, lambda x: x.view([1, -1])) + + def test_narrow(self): + """ + Test that torch.narrow is traced correctly. + """ + return self._make_acc_op_function_test( + acc_ops.slice_tensor, + torch.narrow, + validate_same_kwargs=False, + dim=1, + start=1, + length=2, + ) + + def test_pow(self): + self._make_acc_op_function_test(acc_ops.pow, torch.pow, exponent=2) + + def test_numel(self): + class TestModule(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a): + return torch.numel(a) + + m = TestModule() + a = torch.randn(2, 1, 4) + traced = acc_tracer.trace(m, [a]) + + ph_a = numel = None + for node in traced.graph.nodes: + if node.op == "placeholder": + self.assertTrue(node.target == "a") + ph_a = node + elif node.op == "call_function" and node.target == acc_ops.numel: + numel = node + self.assertTrue(numel.kwargs["input"] is ph_a) + elif node.op == "output": + self.assertEqual(node.args[0], numel) + else: + self.fail(f"Unexpected node: {node.format_node()}") + + ref = m(a) + res = traced(a) + self.assertEqual(ref, res) + + def test_size(self): + class TestModule(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a): + idx = a.size(1) + return a.shape[idx] + + m = TestModule() + a = torch.randn(2, 1, 4) + traced = acc_tracer.trace(m, [a]) + + ph_a = size_1 = size_2 = getitem_1 = getitem_2 = None + for node in traced.graph.nodes: + if node.op == "placeholder": + self.assertTrue(node.target == "a") + ph_a = node + elif node.op == "call_function" and node.target == acc_ops.size: + if size_1: + size_2 = node + self.assertTrue(size_2.kwargs["input"] is ph_a) + else: + size_1 = node + self.assertTrue(size_1.kwargs["input"] is ph_a) + elif node.op == "call_function" and node.target == acc_ops.getitem: + if getitem_1: + getitem_2 = node + self.assertTrue(getitem_2.kwargs["idx"] == getitem_1) + self.assertTrue(getitem_2.kwargs["input"] == size_2) + else: + getitem_1 = node + self.assertTrue(getitem_1.kwargs["idx"] == 1) + self.assertTrue(getitem_1.kwargs["input"] == size_1) + elif node.op == "output": + self.assertEqual(node.args[0], getitem_2) + else: + self.fail(f"Unexpected node: {node.format_node()}") + + ref = m(a) + res = traced(a) + self.assertEqual(ref, res) + + def test_getattr_named_tuple(self): + """ + Test that call_function getattr on namedtuples is + traced correctly. 
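+        The field accesses should be normalized into acc_ops.getitem calls
+        keyed by the integer index of each field.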
+ """ + + class TestNamedTuple(NamedTuple): + foo: torch.Tensor + bar: torch.Tensor + + class TestModule(nn.Module): + def forward(self, a: TestNamedTuple): + return a.foo + a.bar + + m = TestModule() + a = TestNamedTuple(torch.randn(2, 2), torch.randn(2, 2)) + traced = acc_tracer.trace(m, [a]) + + ph_a = getitem_1 = getitem_2 = add = None + for node in traced.graph.nodes: + if node.op == "placeholder": + self.assertEqual(node.target, "a") + ph_a = node + + elif node.op == "call_function" and node.target == acc_ops.getitem: + if getitem_1: + getitem_2 = node + self.assertEqual(getitem_2.kwargs["idx"], 1) + else: + getitem_1 = node + self.assertEqual(getitem_1.kwargs["idx"], 0) + + self.assertEqual(node.kwargs["input"], ph_a) + + elif node.op == "call_function" and node.target == acc_ops.add: + self.assertEqual(node.kwargs["input"], getitem_1) + self.assertEqual(node.kwargs["other"], getitem_2) + add = node + + elif node.op == "output": + self.assertEqual(node.args[0], add) + + else: + self.fail(f"Unexpected node: {node.format_node()}") + + ref = m(a) + res = traced(a) + self.assertTrue(torch.equal(ref, res)) + + def test_flatten(self): + """ + Test that torch.flatten is traced correctly. + """ + self._make_acc_op_function_test( + acc_ops.flatten, torch.flatten, start_dim=1, end_dim=1 + ) + self._make_acc_op_function_test(acc_ops.flatten, lambda x: x.flatten()) + + def test_topk_multi_output(self): + """ + Test that torch.topk multi outputs work. + """ + + class TestModule(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a: torch.Tensor) -> torch.Tensor: + return torch.topk(a, 3)[1] + + m = TestModule() + input_a = torch.randn(10) + traced = acc_tracer.trace(m, [input_a]) + + ph_a = topk = getitem = None + for node in traced.graph.nodes: + if node.op == "placeholder" and str(node.target) == "a": + ph_a = node + elif node.op == "call_function" and node.target == acc_ops.topk: + topk = node + self.assertEqual(node.kwargs["input"], ph_a) + self.assertEqual(node.kwargs["k"], 3) + elif node.op == "call_function" and node.target == acc_ops.getitem: + getitem = node + self.assertEqual(node.kwargs["input"], topk) + self.assertEqual(node.kwargs["idx"], 1) + elif node.op == "output": + self.assertEqual(node.args[0], getitem) + else: + self.fail(f"Unexpected node: {node.format_node()}") + + self.assertTrue(torch.equal(m(input_a), traced(input_a))) + + def test_addmm_with_alpha_beta(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward( + self, input: torch.Tensor, a: torch.Tensor, b: torch.Tensor + ) -> torch.Tensor: + return torch.addmm(input, a, b, alpha=1.2, beta=1.1) + + m = TestModule() + input, a, b = torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2) + traced = acc_tracer.trace(m, [input, a, b]) + + ph_in = ph_a = ph_b = mm = add = mm_mul = add_mul = None + for node in traced.graph.nodes: + if node.op == "placeholder": + if str(node.target) == "a": + ph_a = node + elif str(node.target) == "b": + ph_b = node + else: + self.assertTrue(str(node.target) == "input") + ph_in = node + elif node.op == "call_function": + if node.target == acc_ops.matmul: + self.assertEqual(node.kwargs["input"], ph_a) + self.assertEqual(node.kwargs["other"], ph_b) + mm = node + elif node.target == acc_ops.add: + self.assertEqual(node.kwargs["input"], mm_mul) + self.assertEqual(node.kwargs["other"], add_mul) + add = node + elif mm_mul: + self.assertEqual(node.kwargs["input"], ph_in) + self.assertEqual(node.kwargs["other"], 1.1) + add_mul = 
node + else: + self.assertEqual(node.kwargs["input"], mm) + self.assertEqual(node.kwargs["other"], 1.2) + mm_mul = node + elif node.op == "output": + self.assertEqual(add, node.args[0]) + else: + self.fail(f"Unexpected node: {node.format_node()}") + + torch.testing.assert_close(m(input, a, b), traced(input, a, b)) + + def test_log1p(self): + class TestModule(torch.nn.Module): + def forward(self, input: torch.Tensor) -> torch.Tensor: + return torch.log1p(input) + + m = TestModule().eval() + input = torch.tensor([[1.2, 0.3, -0.4]]) + traced = acc_tracer.trace(m, [input]) + + ph_in = add = log = None + for node in traced.graph.nodes: + if node.op == "placeholder": + self.assertTrue(str(node.target) == "input") + ph_in = node + elif node.op == "call_function": + if node.target == acc_ops.add: + self.assertEqual(node.kwargs["input"], ph_in) + self.assertEqual(node.kwargs["other"], 1) + add = node + else: + self.assertEqual(node.target, acc_ops.log) + self.assertEqual(node.kwargs["input"], add) + log = node + elif node.op == "output": + self.assertEqual(log, node.args[0]) + else: + self.fail(f"Unexpected node: {node.format_node()}") + + torch.testing.assert_close(m(input), traced(input)) + + @parameterized.expand([(torch.float,), (torch.float16,)]) + def test_addmm(self, dtype): + class TestModule(torch.nn.Module): + def forward( + self, input: torch.Tensor, a: torch.Tensor, b: torch.Tensor + ) -> torch.Tensor: + return torch.addmm(input, a, b) + + m = TestModule() + input, a, b = ( + torch.randn(2, 2, dtype=dtype), + torch.randn(2, 2, dtype=dtype), + torch.randn(2, 2, dtype=dtype), + ) + traced = acc_tracer.trace(m, [input, a, b]) + + ph_in = ph_a = ph_b = mm = add = None + for node in traced.graph.nodes: + if node.op == "placeholder": + if str(node.target) == "a": + ph_a = node + elif str(node.target) == "b": + ph_b = node + else: + self.assertTrue(str(node.target) == "input") + ph_in = node + elif node.op == "call_function": + if node.target == acc_ops.matmul: + self.assertEqual(node.kwargs["input"], ph_a) + self.assertEqual(node.kwargs["other"], ph_b) + mm = node + else: + self.assertEqual(node.target, acc_ops.add) + self.assertEqual(node.kwargs["input"], mm) + self.assertEqual(node.kwargs["other"], ph_in) + add = node + elif node.op == "output": + self.assertEqual(add, node.args[0]) + else: + self.fail(f"Unexpected node: {node.format_node()}") + + for node in [ph_in, ph_a, ph_b, mm, add]: + self.assertEqual(acc_utils.get_tensor_meta(node).dtype, dtype) + + if dtype == torch.float: + self.assertTrue(torch.equal(m(input, a, b), traced(input, a, b))) + + def test_gelu(self): + return self._make_acc_op_function_test(acc_ops.gelu, torch.nn.functional.gelu) + + @parameterized.expand( + [ + (1, True), + (1, False), + (None, False), + ] + ) + def test_argmin(self, dim, keepdim): + class TestModule(torch.nn.Module): + def __init__(self, dim, keepdim): + super().__init__() + self.dim = dim + self.keepdim = keepdim + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return torch.argmin(input, dim=self.dim, keepdim=self.keepdim) + + m = TestModule(dim, keepdim) + input = torch.randn(2, 2) + traced = acc_tracer.trace(m, [input]) + + ph_in = flatten = topk = getitem = squeeze = None + for node in traced.graph.nodes: + if node.op == "placeholder": + self.assertTrue(str(node.target) == "input") + ph_in = node + elif node.op == "call_function": + if node.target == acc_ops.flatten: + self.assertEqual(node.kwargs["input"], ph_in) + flatten = node + elif node.target == acc_ops.topk: + 
self.assertEqual( + node.kwargs["input"], flatten if flatten else ph_in + ) + topk = node + elif node.target == acc_ops.getitem: + self.assertEqual(node.kwargs["input"], topk) + getitem = node + elif node.target == acc_ops.squeeze: + self.assertEqual(node.kwargs["input"], getitem) + squeeze = node + elif node.op == "output": + self.assertEqual(squeeze if squeeze else getitem, node.args[0]) + else: + self.fail(f"Unexpected node: {node.format_node()}") + if dim is None: + self.assertTrue(flatten is not None) + if not keepdim: + self.assertTrue(squeeze is not None) + self.assertTrue(torch.equal(m(input), traced(input))) + + def test_t(self): + """ + Test Tensor.t() is traced correctly. + """ + self._make_acc_op_function_test(acc_ops.permute, lambda x: x.t()) + self._make_acc_op_function_test( + acc_ops.permute, lambda x: x.t(), input_shape=(3,) + ) + + def test_split_size(self): + self._make_acc_op_function_test( + acc_ops.split, + torch.split, + validate_same_kwargs=False, + split_size_or_sections=2, + dim=1, + ) + + def test_split_sections(self): + class TestModule(torch.nn.Module): + def forward(self, input: torch.Tensor) -> torch.Tensor: + return torch.split(input, [2, 5, 3], 1) + + m = TestModule() + input = torch.randn(1, 10) + traced = acc_tracer.trace(m, [input]) + + ph_in = slice_node_0 = slice_node_1 = slice_node_2 = None + tuple_construct_node = None + for node in traced.graph.nodes: + if node.op == "placeholder": + self.assertTrue(str(node.target) == "input") + ph_in = node + elif node.op == "call_function": + if node.target == acc_ops.slice_tensor: + self.assertEqual(node.kwargs["input"], ph_in) + if slice_node_0: + if slice_node_1: + slice_node_2 = node + else: + slice_node_1 = node + else: + slice_node_0 = node + else: + self.assertEqual(node.target, acc_ops.tuple_construct) + self.assertEqual( + node.kwargs["tensors"], + (slice_node_0, slice_node_1, slice_node_2), + ) + tuple_construct_node = node + elif node.op == "output": + self.assertEqual(tuple_construct_node, node.args[0]) + else: + self.fail(f"Unexpected node: {node.format_node()}") + + ref_output = m(input) + output = traced(input) + for i, j in zip(ref_output, output): + self.assertTrue(torch.equal(i, j)) + + @parameterized.expand( + [ + ("neg_1", -1, 1, 3), + ("neg_2", -2, 1, 3), + ("neg_4", -4, 1, 1), + ] + ) + def test_negative_slicing(self, _, dim, start, length): + """ + Test that slicing with negative dims works. + """ + self._make_acc_op_function_test( + acc_ops.slice_tensor, + torch.narrow, + input_shape=(2, 3, 4, 5), + validate_same_kwargs=False, + dim=dim, + start=start, + length=length, + ) + + def test_list_input(self): + """ + Test that list inputs are traced correctly. 
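+        The list arrives as a single placeholder whose elements are read back
+        via acc_ops.getitem, with per-element tensor_rank metadata recorded.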
+ """ + + class TestModule(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a: List[torch.Tensor]) -> torch.Tensor: + return a[0] + a[1] + + m = TestModule() + input = [torch.randn(2, 3), torch.randn(2, 3)] + traced = acc_tracer.trace(m, [input]) + + ph = getitem_0 = getitem_1 = add = None + for node in traced.graph.nodes: + if node.op == "placeholder": + self.assertEqual(str(node.target), "a") + ph = node + elif node.op == "call_function" and node.target == acc_ops.getitem: + self.assertTrue(node.kwargs["idx"] == 0 or node.kwargs["idx"] == 1) + if node.kwargs["idx"] == 0: + getitem_0 = node + else: + getitem_1 = node + elif node.op == "call_function": + self.assertEqual(node.target, acc_ops.add) + self.assertEqual(node.kwargs["input"], getitem_0) + self.assertEqual(node.kwargs["other"], getitem_1) + add = node + elif node.op == "output": + self.assertEqual(add, node.args[0]) + else: + self.fail(f"Unexpected node: {node.format_node()}") + + # Check the tensor ranks are correct given the input is a list. + self.assertIsInstance(ph.meta["tensor_rank"], list) + self.assertEqual(len(ph.meta["tensor_rank"]), 2) + self.assertEqual(getitem_0.meta["tensor_rank"], ph.meta["tensor_rank"][0]) + self.assertEqual(getitem_1.meta["tensor_rank"], ph.meta["tensor_rank"][1]) + + self.assertTrue(torch.equal(m(input), traced(input))) + + def test_dict_input(self): + """ + Test that dict inputs are traced correctly. + """ + + class TestModule(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a: Dict[str, torch.Tensor]) -> torch.Tensor: + return a["foo"] + a["bar"] + + m = TestModule() + input = {"foo": torch.randn(2, 3), "bar": torch.randn(2, 3)} + traced = acc_tracer.trace(m, [input]) + + ph = getitem_0 = getitem_1 = add = None + for node in traced.graph.nodes: + if node.op == "placeholder": + self.assertEqual(str(node.target), "a") + ph = node + elif node.op == "call_function" and node.target == acc_ops.getitem: + self.assertTrue( + node.kwargs["idx"] == "foo" or node.kwargs["idx"] == "bar" + ) + if node.kwargs["idx"] == "foo": + getitem_0 = node + else: + getitem_1 = node + elif node.op == "call_function": + self.assertEqual(node.target, acc_ops.add) + self.assertEqual(node.kwargs["input"], getitem_0) + self.assertEqual(node.kwargs["other"], getitem_1) + add = node + elif node.op == "output": + self.assertEqual(add, node.args[0]) + else: + self.fail(f"Unexpected node: {node.format_node()}") + + # Check the tensor ranks are correct given the input is a dict. + self.assertIsInstance(ph.meta["tensor_rank"], dict) + self.assertEqual(len(ph.meta["tensor_rank"]), 2) + self.assertEqual(getitem_0.meta["tensor_rank"], ph.meta["tensor_rank"]["foo"]) + self.assertEqual(getitem_1.meta["tensor_rank"], ph.meta["tensor_rank"]["bar"]) + + self.assertTrue(torch.equal(m(input), traced(input))) + + def test_none_type_ret(self): + """ + Test that a NoneType is traced as expected. 
+ """ + + class TestModule(nn.Module): + def __init__(self): + super().__init__() + + def forward( + self, a: torch.Tensor + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + return a + a, None + + m = TestModule() + input = torch.randn(1, 2, 3) + try: + traced = acc_tracer.trace( + m, + [input], + ) + except RuntimeError as e: + self.assertEqual( + "This error should not be triggered, as NoneType should be lowered without an issue", + str(e), + ) + ans1, _ = m(input) + ans2, _ = traced(input) + self.assertTrue(torch.equal(ans1, ans2)) + + def test_mobilenet_v3(self): + """ + Test that we can trace mobilenet v3 small and run/compare against the untraced version. + """ + m = torchvision.models.mobilenet_v3_small(pretrained=True) + self._make_model_unit_test(m, enable_allclose=True) + + def test_mobilenet_v2(self): + """ + Test that we can trace mobilenet v2 small and run/compare against the untraced version. + """ + m = torchvision.models.mobilenet_v2(pretrained=True) + self._make_model_unit_test(m) + + def test_vgg16(self): + """ + Test that we can trace vgg16 and run/compare against the untraced version. + """ + m = torchvision.models.vgg16(pretrained=True) + self._make_model_unit_test(m) + + def test_resnet18(self): + """ + Test that we can trace resnet18 and run/compare against the untraced version. + """ + m = torchvision.models.resnet18(pretrained=True) + self._make_model_unit_test(m) + + def test_resnext50_32x4d(self): + """ + Test that we can trace resnext and run/compare against the untraced version. + """ + m = torchvision.models.resnext50_32x4d(pretrained=True) + self._make_model_unit_test(m) + + def test_cumsum(self): + # Tests call_function version + self._make_acc_op_function_test(acc_ops.cumsum, torch.cumsum, dim=1) + self._make_acc_op_function_test( + acc_ops.cumsum, torch.cumsum, dim=1, dtype=torch.float + ) + + # Tests call_method version + class TestModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, a: torch.Tensor) -> torch.Tensor: + return a.cumsum(dim=0) + + m = TestModule() + a = torch.rand(2, 2) + gm = acc_tracer.trace(m, [a]) + self.assertTrue(torch.equal(m(a), gm(a))) + + def test_chunk(self): + self._make_acc_op_function_test(acc_ops.chunk, torch.chunk, chunks=2, dim=0) + + def test_retrace_reshape(self): + """ + Retrace reshape to verify it's retraceable. 
+ """ + + class TestModule(torch.nn.Module): + def forward(self, a: torch.Tensor) -> torch.Tensor: + return a.reshape(a.size()[0], 1, 2) + + m = TestModule() + a = torch.randn(2, 2) + gm = acc_tracer.trace(m, [a]) + self.assertTrue(torch.equal(m(a), gm(a))) + gm_retrace = acc_tracer.trace(gm, [a]) + self.assertTrue(torch.equal(m(a), gm_retrace(a))) + + def test_index_select(self): + class TestModule(nn.Module): + def __init__(self, dim, index): + super().__init__() + self._dim = dim + self._index = index + + def forward(self, a: torch.Tensor) -> torch.Tensor: + return torch.index_select(a, self._dim, self._index) + + dim = 0 + index = torch.tensor([1, 0]) + m = TestModule(dim, index) + _input = [torch.randn(2, 3), torch.randn(2, 3)] + traced = acc_tracer.trace(m, _input) + + ph = index = index_select = None + + for node in traced.graph.nodes: + if node.op == "placeholder": + self.assertEqual(str(node.target), "a") + ph = node + elif node.op == "call_function" and node.target == acc_ops.index_select: + self.assertTrue(node.kwargs["input"] == ph) + self.assertTrue(node.kwargs["index"] == index) + self.assertTrue(node.kwargs["dim"] == dim) + index_select = node + elif node.op == "output": + self.assertEqual(index_select, node.args[0]) + elif node.op == "get_attr": + # There only be oneâ„¢ const node + self.assertTrue(index is None) + index = node + else: + self.fail(f"Unexpected node: {node.format_node()}") + + def test_gather(self): + class TestModule(nn.Module): + def __init__(self, dim, index): + super().__init__() + self._dim = dim + self._index = index + + def forward(self, a: torch.Tensor) -> torch.Tensor: + return torch.gather(a, self._dim, self._index) + + dim = 0 + index = torch.tensor([[1, 0], [0, 1]]) + m = TestModule(dim, index) + _input = [torch.randn(2, 3), torch.randn(2, 3)] + traced = acc_tracer.trace(m, _input) + + ph = index = gather = None + + for node in traced.graph.nodes: + if node.op == "placeholder": + self.assertEqual(str(node.target), "a") + ph = node + elif node.op == "call_function" and node.target == acc_ops.gather: + self.assertTrue(node.kwargs["input"] == ph) + self.assertTrue(node.kwargs["index"] == index) + self.assertTrue(node.kwargs["dim"] == dim) + gather = node + elif node.op == "output": + self.assertEqual(gather, node.args[0]) + elif node.op == "get_attr": + # There only be oneâ„¢ const node + self.assertTrue(index is None) + index = node + else: + self.fail(f"Unexpected node: {node.format_node()}") + + def test_where(self): + class TestModule(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a, b, c): + return torch.where(a, b, c) + + m = TestModule() + x = torch.randn(3, 2) + y = torch.ones(3, 2) + cond = x > 0 + traced = acc_tracer.trace(m, [cond, x, y]) + + ph_a = where = None + ph_b = None + ph_c = None + for node in traced.graph.nodes: + if node.op == "placeholder": + if node.target == "a": + ph_a = node + elif node.target == "b": + ph_b = node + elif node.target == "c": + ph_c = node + elif node.op == "call_function" and node.target == acc_ops.where: + where = node + self.assertTrue(where.kwargs["condition"] is ph_a) + self.assertTrue(where.kwargs["x"] is ph_b) + self.assertTrue(where.kwargs["y"] is ph_c) + elif node.op == "output": + self.assertEqual(node.args[0], where) + else: + self.fail(f"Unexpected node: {node.format_node()}") + + ref = m(cond, x, y) + res = traced(cond, x, y) + self.assertTrue(torch.equal(ref, res)) + + @parameterized.expand( + [ + ("sections divisible", 2, 0), + ("sections indivisible", 3, 0), + 
("indices list", [1, 3], 0), + ("indices tuple", (1, 3), 0), + ("indices tensor", torch.tensor([1, 3]), 0), + ("indices tensor dim1", torch.tensor([1, 3]), 1), + ("indices tensor dim2", torch.tensor([1, 3]), 2), + ("indices tensor long dim2", torch.tensor([1, 3, 5, 7]), 2), + ] + ) + def test_tensor_split(self, _, indices_or_sections, dim): + """ + Test that the tracer works for torch.tensor_split with indices and sections + """ + + class TestModule(nn.Module): + def __init__(self, indices_or_sections, dim): + super().__init__() + self._indices_or_sections = indices_or_sections + self._dim = dim + + def forward(self, a): + return torch.tensor_split(a, self._indices_or_sections, self._dim) + + m = TestModule(indices_or_sections, dim) + a = torch.randn(4, 8, 16) + traced = acc_tracer.trace(m, [a]) + + results = traced(a) + references = m(a) + for res, ref in zip(results, references): + self.assertTrue( + torch.equal(ref, res), f"Tensors at don't match {ref=} {res=}" + ) + + def test_inplace_raise(self): + """ + Test that encountering inplace is raised for exception + """ + + class TestModule(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a): + a = a + 2 + a.sub_(3) + return a + + m = TestModule() + in_a = torch.randn(5) + try: + acc_tracer.trace( + m, + [in_a], + ) + self.fail("Shouldn't get here because exception should be thrown.") + except RuntimeError as e: + self.assertEqual( + "Tried to trace mutable operation sub_. FX only supports functional code", + str(e), + ) + + def test_repeat_interleave(self): + class TestModule(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.repeat_interleave(x, 2, 1) + + # TODO: finish test later + m = TestModule() + x = torch.randn(3, 4) + traced = acc_tracer.trace(m, [x]) + ph_in = tile = size = getitem = unsqueeze = reshape = None + for node in traced.graph.nodes: + if node.op == "placeholder": + ph_in = node + elif node.op == "call_function": + if node.target == acc_ops.size: + self.assertEqual(node.kwargs["input"], ph_in) + size = node + elif node.target == acc_ops.getitem: + self.assertEqual(node.kwargs["input"], size) + getitem = node + elif node.target == acc_ops.reshape: + self.assertEqual(node.kwargs["input"], tile) + reshape = node + elif node.target == acc_ops.unsqueeze: + self.assertEqual(node.kwargs["input"], ph_in) + unsqueeze = node + elif node.target == acc_ops.tile: + self.assertEqual(node.kwargs["input"], unsqueeze) + tile = node + elif node.op == "output": + self.assertEqual(reshape, node.args[0]) + else: + self.fail(f"Unexpected node: {node.format_node()}") + if size is not None: + self.assertIsNotNone(getitem) + self.assertTrue(torch.equal(m(x), traced(x))) + + def test_acc_normalization_block_list(self): + class TestModule(nn.Module): + def forward(self, x: List[torch.Tensor]) -> torch.Tensor: + return x[0] + x[1] + + m = TestModule() + x = [torch.randn(1), torch.randn(1)] + traced = acc_tracer.trace( + m, [x], acc_normalization_block_list={("call_function", operator.getitem)} + ) + for node in traced.graph.nodes: + if "getitem" in node.name: + # Make sure we didn't convert to the acc version + self.assertEqual(node.target, operator.getitem) + + def test_detach(self): + class TestModule(nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.detach(x) + + m = TestModule() + sample_inputs = [torch.randn(8)] + traced = acc_tracer.trace(m, sample_inputs) + + placeholder = output = None + for node in 
traced.graph.nodes: + if node.op == "placeholder": + assert placeholder is None + placeholder = node + elif node.op == "output": + assert output is None + output = node + else: + raise RuntimeError(f"Unexpected Node {node.format_node()}") + + self.assertIsNotNone(placeholder) + self.assertIsNotNone(output) + + self.assertTrue(torch.equal(m(*sample_inputs), traced(*sample_inputs))) + + def test_all_acc_ops_registered(self): + self.assertEqual( + acc_normalizer._acc_ops, + { + acc_ops.linear, + acc_ops.embedding, + acc_ops.max_pool1d, + acc_ops.max_pool2d, + acc_ops.max_pool3d, + acc_ops.flatten, + acc_ops.adaptive_avg_pool2d, + acc_ops.adaptive_avg_pool3d, + acc_ops.avg_pool1d, + acc_ops.avg_pool2d, + acc_ops.avg_pool3d, + acc_ops.add, + acc_ops.min_full_reduce, + acc_ops.min_dim_reduce, + acc_ops.minimum, + acc_ops.cat, + acc_ops.softmax, + acc_ops.sign, + acc_ops.permute, + acc_ops.matmul, + acc_ops.quantize_per_tensor, + acc_ops.quantize_per_channel, + acc_ops.quantized_add, + acc_ops.quantized_mul, + acc_ops.dequantize, + acc_ops.sub, + acc_ops.mul, + acc_ops.div, + acc_ops.fmod, + acc_ops.floor_div, + acc_ops.trunc_div, + acc_ops.pow, + acc_ops.relu, + acc_ops.prelu, + acc_ops.leaky_relu, + acc_ops.elu, + acc_ops.selu, + acc_ops.softsign, + acc_ops.tuple_construct, + acc_ops.unsqueeze, + acc_ops.sigmoid, + acc_ops.sum, + acc_ops.prod, + acc_ops.max_full_reduce, + acc_ops.max_dim_reduce, + acc_ops.maximum, + acc_ops.sinh, + acc_ops.cosh, + acc_ops.tanh, + acc_ops.asin, + acc_ops.acos, + acc_ops.atan, + acc_ops.exp, + acc_ops.log, + acc_ops.sqrt, + acc_ops.reciprocal, + acc_ops.abs, + acc_ops.neg, + acc_ops.floor, + acc_ops.ceil, + acc_ops.size, + acc_ops.split, + acc_ops.conv1d, + acc_ops.conv2d, + acc_ops.conv3d, + acc_ops.conv_transpose2d, + acc_ops.conv_transpose3d, + acc_ops.batch_norm, + acc_ops.embedding_bag, + acc_ops.embedding_bag_byte_rowwise_offsets, + acc_ops.embedding_bag_4bit_rowwise_offsets, + acc_ops.contiguous, + acc_ops.pad, + acc_ops.sin, + acc_ops.cos, + acc_ops.tan, + acc_ops.topk, + acc_ops.getitem, + acc_ops.squeeze, + acc_ops.tile, + acc_ops.reshape, + acc_ops.quantized_linear, + acc_ops.quantized_conv2d, + acc_ops.quantized_batch_norm2d, + acc_ops.to_dtype, + acc_ops.clamp, + acc_ops.layer_norm, + acc_ops.linalg_norm, + acc_ops.slice_tensor, + acc_ops.hardsigmoid, + acc_ops.mean, + acc_ops.hardtanh, + acc_ops.gelu, + acc_ops.cumsum, + acc_ops.chunk, + acc_ops.rescale_quantize_per_tensor, + acc_ops.rescale_quantize_per_channel, + acc_ops.nan_to_num, + acc_ops.expand, + acc_ops.masked_fill, + acc_ops.eq, + acc_ops.gt, + acc_ops.lt, + acc_ops.logical_or, + acc_ops.logical_xor, + acc_ops.gather, + acc_ops.index_select, + acc_ops.interpolate, + acc_ops.logical_and, + acc_ops.logical_not, + acc_ops.ne, + acc_ops.device, + acc_ops.numel, + acc_ops.where, + acc_ops.dtype, + acc_ops.isinf, + acc_ops.any, + acc_ops.tensor_split, + acc_ops.new_empty, + acc_ops.new_ones, + acc_ops.einsum, + acc_ops.as_strided, + acc_ops.var, + acc_ops.grid_sample, + acc_ops.xl_weight, + }, + ) diff --git a/py/torch_tensorrt/dynamo/test/tracer/test_dispatch_tracer.py b/py/torch_tensorrt/dynamo/test/tracer/test_dispatch_tracer.py new file mode 100644 index 0000000000..4af730d67c --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/tracer/test_dispatch_tracer.py @@ -0,0 +1,245 @@ +import copy +import unittest + +import torch +import torch._dynamo as torchdynamo + +import torch._dynamo.config +import torchvision +from functorch.experimental import functionalize +from torch._dynamo.optimizations 
import backends
+from torch._dynamo.optimizations.normalize import normalize_ir
+
+from torch.library import Library
+from torch_tensorrt.dynamo.lower import compile
+from torch_tensorrt.fx.tracer.dispatch_tracer.tracer import make_fx
+from torch_tensorrt.dynamo.utils import LowerPrecision, proxytensor_trace
+
+# TODO(ezyang): remove this after we properly support fake example inputs
+torch._dynamo.config.DO_NOT_USE_legacy_non_fake_example_inputs = True
+
+torch.manual_seed(0)
+
+wrap_lib = Library("wrap", "DEF")
+"""
+There are two methods for setting a leaf module: leaf (op registration) and
+leaf (override call_module). Only leaf (op registration) can work together
+with functionalize. If you do not need functionalize, you can choose either
+of the leaf module methods.
+
+Test coverage:
+ProxytensorTracerTest.test_leaf_operator_reg: python_key tracer + functionalize + leaf(op registration)
+DispatchTracerTest.test_leaf_operator_reg: dispatch tracer + functionalize + leaf(op registration)
+DispatchTracerTest.test_leaf: dispatch tracer + leaf(override call_module)
+DispatchTracerTest.test_non_tensor_input: dispatch tracer
+DispatchTracerTest.test_reference_copy: dispatch tracer + functionalize
+DispatchTracerTest.test_reference_copy_torchdynamo: dispatch tracer + torchdynamo + functionalize
+"""
+
+
+class ProxytensorTracerTest(unittest.TestCase):
+    def test_leaf_operator_reg(self):
+        class Leaf(torch.nn.Module):
+            def forward(self, x, y):
+                return x + y + torch.nn.Parameter(torch.ones(5))
+
+        leaf = Leaf()
+        wrap_lib.define("wrapped_foo(Tensor x, Tensor y) -> Tensor")
+        wrap_lib.impl("wrapped_foo", leaf, "CPU")
+
+        class Bar(torch.nn.Module):
+            def __init__(self):
+                super(Bar, self).__init__()
+                self.foo = torch.ops.wrap.wrapped_foo
+                self.other = torch.nn.Parameter(torch.ones(5))
+
+            def forward(self, x, y):
+                x = self.foo(x, y)
+                x = x + self.other
+                return x
+
+        mod = Bar().eval()
+        inputs = [torch.ones(5), torch.ones(5)]
+        gm = proxytensor_trace(mod, inputs)
+        inputs_new = [torch.ones(5) + 5, torch.ones(5) + 8]
+        output = gm(*inputs_new)
+        ref_output = mod(*inputs_new)
+        torch.testing.assert_close(output, ref_output)
+
+    def test_resnet18_dynamo(self):
+        mod = torchvision.models.resnet18()
+        mod = mod.cuda().half().eval()
+
+        inputs = [torch.ones(32, 3, 224, 224)]
+        inputs = [i.cuda().half() for i in inputs]
+        ref_output = mod(*inputs)
+
+        torchdynamo.reset()
+        dynamo_mod = torchdynamo.optimize(backends.fx2trt_compiler_fp16)(mod)
+        dynamo_output = dynamo_mod(*inputs)
+        cos_val = torch.nn.functional.cosine_similarity(
+            dynamo_output.flatten(), ref_output.flatten(), dim=0, eps=1e-4
+        )
+        self.assertTrue(cos_val.detach().cpu().numpy() > 0.999)
+
+
+class DispatchTracerTest(unittest.TestCase):
+    def test_leaf_operator_reg(self):
+        class Leaf(torch.nn.Module):
+            def forward(self, x, y):
+                return x + y + torch.nn.Parameter(torch.ones(5))
+
+        leaf = Leaf()
+        wrap_lib.define("wrapped_leaf(Tensor x, Tensor y) -> Tensor")
+        wrap_lib.impl("wrapped_leaf", leaf, "CPU")
+
+        class Bar(torch.nn.Module):
+            def __init__(self):
+                super(Bar, self).__init__()
+                self.leaf = torch.ops.wrap.wrapped_leaf
+                self.other = torch.nn.Parameter(torch.ones(5))
+
+            def forward(self, x, y):
+                x = self.leaf(x, y)
+                x = x + self.other
+                return x
+
+        mod = Bar()
+
+        def f(x, y):
+            return mod(x, y)
+
+        gm = make_fx(functionalize(f))(torch.ones(5), torch.ones(5))
+        inputs = [torch.ones(5) + 5, torch.ones(5) + 8]
+        output = gm(*inputs)
+        ref_output = f(*inputs)
+        torch.testing.assert_close(output, ref_output)
+        # through the
op registration method, the module is defined in a call_function + call_function_node = None + for node in gm.graph.nodes: + if ( + node.op == "call_function" + and node.target == torch.ops.wrap.wrapped_leaf + ): + call_function_node = node + self.assertIsNotNone(call_function_node) + + ## The test is broken on Aug 27 as the leaf node does not work. P525693772 + # def test_leaf(self): + # class TestModuleLeaf(torch.nn.Module): + # def __init__(self): + # super().__init__() + # self.conv = torch.nn.Conv2d(3, 10, 1) + # self.relu = torch.nn.ReLU(inplace=True) + + # def forward(self, x): + # x = self.conv(x) + # return self.relu(x) + + # class TestModule(torch.nn.Module): + # def __init__(self): + # super().__init__() + + # self.relu = torch.nn.ReLU(inplace=True) + # self.leaf = TestModuleLeaf() + + # def forward(self, x): + # x = self.leaf(x) + # return self.relu(x) + + # mod = TestModule() + + # def f(x): + # return mod(x) + + # a = torch.randn(1, 3, 1, 1) + # ref_output = f(a) + # func = make_fx(f, leaf_module_list={"test_dispatch_tracer.TestModuleLeaf"}) + # gm = func(a) + # output = gm(a) + # torch.testing.assert_close(output, ref_output) + # import pdb;pdb.set_trace() + # # There should be a call module node in the graph. + # call_module_node = None + # for node in gm.graph.nodes: + # if node.op == "call_module": + # call_module_node = node + # self.assertIsNotNone(call_module_node) + # self.assertEqual(call_module_node.target, "TestModuleLeaf_0") + + def test_non_tensor_input(self): + def foo(x): + a = x["a"] + b = x["b"] + return a + b + + x = {"a": torch.randn(2, 2), "b": torch.randn(2, 2)} + ref_output = foo(x) + func = make_fx(foo) + gm = func(x) + output = gm(x) + torch.testing.assert_close(output, ref_output) + + def test_reference_copy(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + y[:, 0] = x[:, 0] + return y + + mod = TestModule() + + def f(x, y): + return mod(x, y) + + a = torch.ones(2, 2) + 2 + b = torch.ones(2, 2) + b_copy = torch.ones(2, 2) + ref_output = f(a, b) + gm = make_fx(functionalize(f))(a, b) + output = gm(a, b_copy) + torch.testing.assert_close(output, ref_output) + + def test_reference_copy_torchdynamo(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.relu = torch.nn.ReLU(inplace=True) + + def forward(self, x, y): + y = y + 3 + y = self.relu(y) + y[:, 0] = x[:, 0] + return y + + mod = TestModule() + + def f(x, y): + return mod(x, y) + + a = torch.ones(2, 2) + 2 + b = torch.ones(2, 2) + inputs = [a, b] + ref_output = f(*inputs) + + def compile_dispatch(gm, example_inputs): + # after normalization, relu in-place is removed + gm = normalize_ir(gm, example_inputs) + # dispatch tracer + nargs = len(example_inputs) + + def fake_signature(fn, nargs): + """FX gets confused by varargs, de-confuse it""" + argnames = ",".join(f"arg{i}" for i in range(nargs)) + return eval(f"lambda {argnames}: fn({argnames})", {"fn": fn}) + + gm = make_fx(functionalize(fake_signature(gm, nargs)))(*example_inputs) + return gm + + optimized_mod = torchdynamo.optimize( + compile_dispatch, + nopython=True, + )(mod) + output = optimized_mod(*inputs) + torch.testing.assert_close(output, ref_output) diff --git a/py/torch_tensorrt/dynamo/test/tracer/test_resnet.py b/py/torch_tensorrt/dynamo/test/tracer/test_resnet.py new file mode 100644 index 0000000000..cf04edc5d9 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/tracer/test_resnet.py @@ -0,0 +1,86 @@ +import unittest + +import torch 
+ +import torch._dynamo.config +import torchvision +from torch_tensorrt.dynamo.lower import compile +from torch_tensorrt.dynamo.utils import LowerPrecision + + +class ResnetTest(unittest.TestCase): + def test_resnet18_aten(self): + mod = torchvision.models.resnet18() + mod = mod.cuda().half().eval() + + inputs = [torch.ones(32, 3, 224, 224)] + inputs = [i.cuda().half() for i in inputs] + + aten_mod = compile( + mod, + inputs, + lower_precision=LowerPrecision.FP16, + verbose_log=False, + timing_cache_prefix="", + save_timing_cache=False, + cuda_graph_batch_size=-1, + is_aten=True, + ) + aten_output = aten_mod(*inputs) + aten_output = aten_output[0] + fx_mod = compile( + mod, + inputs, + lower_precision=LowerPrecision.FP16, + verbose_log=False, + timing_cache_prefix="", + save_timing_cache=False, + cuda_graph_batch_size=-1, + is_aten=False, + ) + fx_output = fx_mod(*inputs) + # Kernel selection is tricky in TRT with big variance as shown below: + # Mismatched elements: 30816 / 32000 (96.3%) + # Greatest absolute difference: 0.05859375 at index (0, 499) (up to 1e-05 allowed) + # Greatest relative difference: 3.293713681986265 at index (0, 142) (up to 0.001 allowed) + # so we choose to use cosine similarity + cos_val = torch.nn.functional.cosine_similarity( + aten_output.flatten(), fx_output.flatten(), dim=0, eps=1e-4 + ) + self.assertTrue(cos_val.detach().cpu().numpy() > 0.999) + + def test_resnet18_aten_dynamic(self): + mod = torchvision.models.resnet18() + mod = mod.cuda().half().eval() + + inputs = [torch.ones(32, 3, 224, 224)] + inputs = [i.cuda().half() for i in inputs] + + aten_mod = compile( + mod, + inputs, + lower_precision=LowerPrecision.FP16, + verbose_log=False, + timing_cache_prefix="", + save_timing_cache=False, + cuda_graph_batch_size=-1, + is_aten=True, + ) + aten_output = aten_mod(*inputs) + aten_output = aten_output[0] + fx_mod = compile( + mod, + inputs, + lower_precision=LowerPrecision.FP16, + verbose_log=False, + timing_cache_prefix="", + save_timing_cache=False, + cuda_graph_batch_size=-1, + is_aten=False, + ) + fx_output = fx_mod(*inputs) + + cos_val = torch.nn.functional.cosine_similarity( + aten_output.flatten(), fx_output.flatten(), dim=0, eps=1e-4 + ) + self.assertTrue(cos_val.detach().cpu().numpy() > 0.999) diff --git a/py/torch_tensorrt/dynamo/test/trt_lower/test_diagnostics.py b/py/torch_tensorrt/dynamo/test/trt_lower/test_diagnostics.py new file mode 100644 index 0000000000..709973ae22 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/trt_lower/test_diagnostics.py @@ -0,0 +1,200 @@ +# Owner(s): ["oncall: gpu_enablement"] +import functools +import glob +import logging +import os +import shutil +import tempfile +from typing import Union +from unittest import TestCase + +import torch_tensorrt.dynamo.diagnostics as diag + + +_LOGGER: logging.Logger = logging.getLogger(__name__) + + +def reset_diag(fn): + @functools.wraps(fn) + def reset(*a, **kw): + try: + tok1 = diag._CURRENT_COLLECTOR.set(None) + tok2 = diag._CURRENT_WRITER.set(None) + tok3 = diag._SUBSEQUENT_COLLECT_SUPPRESSED_BY.set(None) + return fn(*a, **kw) + finally: + diag._CURRENT_COLLECTOR.reset(tok1) + diag._CURRENT_WRITER.reset(tok2) + diag._SUBSEQUENT_COLLECT_SUPPRESSED_BY.reset(tok3) + + return reset + + +class Fx2trtDiagnosticsTest(TestCase): + @reset_diag + def test_diagnostics(self): + collector = diag.ZipDiagnosticsCollector(writer=diag.get_current_writer()) + + diag.set_current_collector(collector) + + try: + with diag.collect_when_fail(): + diag.write("aaa", "hello") + diag.write("bbb", lambda: 
"world") + diag.write("ccc", b"123") + diag.write("ddd", lambda: b"456") + + def boom() -> str: + raise AssertionError("Error generating diagnostics.") + + diag.write("eee", boom) + + diag.write("zzz", "done") + raise _UserDefinedError("Error while lowering") + except _UserDefinedError: + pass + + zip_fn = collector._last_zip_path_for_test + assert os.path.exists(zip_fn) + with tempfile.TemporaryDirectory() as tempdir: + _LOGGER.info(f"Unpacking into {tempdir}") + shutil.unpack_archive(zip_fn, tempdir) + _check_file(tempdir, "aaa", "hello") + _check_file(tempdir, "bbb", "world") + _check_file(tempdir, "ccc", b"123") + _check_file(tempdir, "ddd", b"456") + _check_file(tempdir, "zzz", "done") + # file eee should still exist to contain err msg + _check_file(tempdir, "eee", "") + + @reset_diag + def test_condition_func_name(self): + collector = diag.ZipDiagnosticsCollector(writer=diag.get_current_writer()) + diag.set_current_collector(collector) + + with diag.collect_when( + diag.CollectionConditions.when_called_by_function( + self.test_condition_func_name.__name__ + ) + ): + diag.write("aaa", "hello") + + zip_fn = collector._last_zip_path_for_test + assert os.path.exists(zip_fn) + with tempfile.TemporaryDirectory() as tempdir: + _LOGGER.info(f"Unpacking into {tempdir}") + shutil.unpack_archive(zip_fn, tempdir) + _check_file(tempdir, "aaa", "hello") + + @reset_diag + def test_write_without_collect(self): + collector = diag.ZipDiagnosticsCollector(writer=diag.get_current_writer()) + diag.set_current_collector(collector) + diag.write("aaa", "hello") + root_dir = diag.get_current_writer().root_dir() + res = glob.glob(f"{root_dir}/*") + assert not res # root dir should be empty + + def test_conditions(self): + + _test_cond( + diag.CollectionConditions.when_called_by_function( + self.test_conditions.__name__ + ), + should_collect=True, + ) + + _test_cond( + diag.CollectionConditions.when_called_by_function("moo_baa_la_la_la"), + should_collect=False, + ) + + _test_cond( + diag.CollectionConditions.any( + diag.CollectionConditions.never(), + diag.CollectionConditions.always(), + ), + True, + ) + + _test_cond( + diag.CollectionConditions.all( + diag.CollectionConditions.never(), + diag.CollectionConditions.always(), + ), + False, + ) + + _test_cond( + diag.CollectionConditions.not_( # returns False + diag.CollectionConditions.always(), # returns True + ), + False, + ) + + _test_cond( + diag.CollectionConditions.when_not_in_tests(), + False, # Yes we are in test right now + ) + + # nested + _test_cond( + diag.CollectionConditions.any( + diag.CollectionConditions.never(), + diag.CollectionConditions.any( + diag.CollectionConditions.always(), + ), + ), + True, + ) + + +@reset_diag +def _test_cond( + cond: diag.CollectionCondition, + should_collect: bool, +) -> None: + collector = diag.ZipDiagnosticsCollector(writer=diag.get_current_writer()) + diag.set_current_collector(collector) + + with diag.collect_when(cond): + diag.write("aaa", "hello") + + zip_fn = collector._last_zip_path_for_test + if should_collect: + assert os.path.exists(zip_fn) + with tempfile.TemporaryDirectory() as tempdir: + _LOGGER.info(f"Unpacking into {tempdir}") + shutil.unpack_archive(zip_fn, tempdir) + _check_file(tempdir, "aaa", "hello") + else: + assert not zip_fn, "the collection should not have triggered" + + +def _check_file(dir: str, fn: str, content: Union[str, bytes]): + fp = os.path.join(dir, fn) + res = glob.glob(f"{fp}*") + assert len(res) == 1 + fp = res[0] + if not os.path.exists(fp): + raise 
_CheckFileDoesNotExist(f"{fp} must exist") + if not content: + # don't check content then + return + if isinstance(content, bytes): + with open(fp, "rb") as f: + content_actual = f.read() + assert content == content_actual + else: + content: str + with open(fp, "r", encoding="utf-8") as f: + content_actual = f.read() + assert content == content_actual + + +class _UserDefinedError(Exception): + pass + + +class _CheckFileDoesNotExist(AssertionError): + pass diff --git a/py/torch_tensorrt/dynamo/test/trt_lower/test_fx2trt_lower.py b/py/torch_tensorrt/dynamo/test/trt_lower/test_fx2trt_lower.py new file mode 100644 index 0000000000..a626c739b0 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/trt_lower/test_fx2trt_lower.py @@ -0,0 +1,104 @@ +# Owner(s): ["oncall: gpu_enablement"] + +import logging +import unittest + +import torch +import torch.fx as fx +import torch.nn as nn +from torch_tensorrt.dynamo.lower import Lowerer, LowerSetting +from torch_tensorrt.dynamo.passes.lower_basic_pass import replace_mutable_op + +logger = logging.getLogger(__name__) + + +class Fx2trtLowerTests(unittest.TestCase): + def test_fx2trt_lower(self): + class _Mod(nn.Module): + def forward(self, x): + return (x, 2 * x) + + mod = _Mod() + mod_traced = fx.symbolic_trace(mod) + input = [torch.rand(4)] + lower = Lowerer.create(LowerSetting()) + lower(mod_traced, input) + + def test_lower_with_batchnorm_act_rewrite(self): + class MyBatchNorm(nn.BatchNorm2d): + def forward(self, x): + self._check_input_dim(x) + return x + 1 + + class TestModule(nn.Module): + def __init__(self): + super().__init__() + self.bn = MyBatchNorm(3) + + def forward(self, x): + return self.bn(x) + + module = TestModule() + inputs = [torch.randn(1, 3, 224, 224)] + lower = Lowerer.create(LowerSetting(ast_rewriter_allow_list={MyBatchNorm})) + lower(module, inputs) + + def test_lower_const_fold(self): + class TestModule(nn.Module): + def __init__(self): + super().__init__() + self.a = nn.Parameter(torch.randn(1)) + + def forward(self, x): + return (torch.sqrt(x), self.a) + + lower = Lowerer.create(LowerSetting()) + lower(TestModule(), [torch.randn([2, 2])]) + + def test_replace_mutable_op(self): + class TestModule(torch.nn.Module): + def forward(self, x, y): + xf = x.fill_(100) + yf = y.fill_(200) + c = torch.cat([xf, yf], dim=1) + return c + + lower = Lowerer.create(LowerSetting()) + mod_traced = fx.symbolic_trace(TestModule()) + lower(mod_traced, [torch.randn(3, 4), torch.randn(3, 4)]) + + def test_replace_mutable_op_dont_apply(self): + class TestModule(torch.nn.Module): + def forward(self, x): + s = x + 1 + t = s.fill_(5) + p = s + t + return p + + mod_traced = fx.symbolic_trace(TestModule()) + old_code = mod_traced.code + + transformed = replace_mutable_op(mod_traced) + new_code = transformed.code + + # s.fill_ shouldn't have been replaced + # because s is used later + self.assertEqual(old_code, new_code) + + def test_replace_mutable_op_do_apply(self): + class TestModule(torch.nn.Module): + def forward(self, x): + s = x + 1 + t = s.fill_(5) # s not used afterwards + p = x + t + return p + + mod_traced = fx.symbolic_trace(TestModule()) + old_code = mod_traced.code + + transformed = replace_mutable_op(mod_traced) + new_code = transformed.code + + # s.fill_ should have been replaced + # because s is not used afterwards + self.assertNotEqual(old_code, new_code) diff --git a/py/torch_tensorrt/dynamo/test/trt_lower/test_observer.py b/py/torch_tensorrt/dynamo/test/trt_lower/test_observer.py new file mode 100644 index 0000000000..185f3acc04 --- /dev/null 
+++ b/py/torch_tensorrt/dynamo/test/trt_lower/test_observer.py
@@ -0,0 +1,128 @@
+# Owner(s): ["oncall: gpu_enablement"]
+import functools
+import logging
+import typing as t
+from contextlib import contextmanager
+from unittest import TestCase
+
+import torch_tensorrt.dynamo.observer as ob
+from torch_tensorrt.dynamo.observer import observable
+
+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+def set_observer_callback_rethrow(fn):
+    """
+    Specify that observer callback exceptions should be re-thrown (the default
+    behavior is to swallow them). Re-throwing is only for test purposes.
+    """
+
+    @functools.wraps(fn)
+    def fn_(*args, **kwargs):
+        try:
+            ob.RETHROW_CALLBACK_EXCEPTION = True
+            return fn(*args, **kwargs)
+        finally:
+            ob.RETHROW_CALLBACK_EXCEPTION = False
+
+    return fn_
+
+
+class ObserverTests(TestCase):
+    @set_observer_callback_rethrow
+    def test_basics(self):
+        @observable()
+        def foo(x, y, z):
+            return x + y + z
+
+        with execution_verifier() as verify_execution:
+
+            @verify_execution
+            def log_pre(ctx: ob.ObserveContext) -> None:
+                _LOGGER.info(f"calling log: {ctx}")
+                assert ctx.callable is foo.orig_func
+                assert ctx.args == (1, 2)
+                assert ctx.kwargs == {"z": 3}
+                assert not ctx.return_value
+
+            @verify_execution
+            def log_post(ctx: ob.ObserveContext) -> None:
+                _LOGGER.info(f"calling log: {ctx}")
+                assert ctx.callable is foo.orig_func
+                assert ctx.args == (1, 2)
+                assert ctx.kwargs == {"z": 3}
+                assert ctx.return_value == 6
+
+            with foo.observers.pre.add(log_pre), foo.observers.post.add(log_post):
+                foo(1, 2, z=3)
+
+        with execution_verifier() as verify_execution:
+
+            @verify_execution
+            def log_pre(ctx: ob.ObserveContext) -> None:
+                _LOGGER.info(f"calling log: {ctx}")
+
+            @verify_execution
+            def log_post(ctx: ob.ObserveContext) -> None:
+                _LOGGER.info(f"calling log: {ctx}")
+
+            foo.observers.pre.add(log_pre)
+            foo.observers.post.add(log_post)
+            foo(1, 2, 3)
+
+        with execution_verifier() as verify_execution:
+
+            @verify_execution
+            def f1(ctx: ob.ObserveContext) -> None:
+                _LOGGER.info(f"calling f1: {ctx}")
+
+            @verify_execution
+            def f2(ctx: ob.ObserveContext) -> None:
+                _LOGGER.info(f"calling f2: {ctx}")
+
+            # Test that we can register the same observation point twice
+            with foo.observers.pre.add(f1):
+                with foo.observers.pre.add(f2):
+                    foo(1, 2, z=3)
+
+    def test_observer_callbacks_should_not_throw(self):
+        @observable()
+        def foo(x, y, z):
+            return x + y + z
+
+        with execution_verifier() as verify_execution:
+
+            @verify_execution
+            def log_pre(ctx: ob.ObserveContext) -> None:
+                _LOGGER.info(f"calling log: {ctx}")
+                raise CallbackError("TEST CALLBACK EXCEPTION")
+
+            with foo.observers.pre.add(log_pre):
+                foo(1, 2, 3)
+
+
+@contextmanager
+def execution_verifier():
+    _is_called: t.Dict[callable, bool] = {}
+
+    def verify_executed(fn):
+        _is_called[fn] = False
+
+        @functools.wraps(fn)
+        def fn_(*args, **kwargs):
+            _is_called[fn] = True
+            return fn(*args, **kwargs)
+
+        return fn_
+
+    try:
+        yield verify_executed
+    except:  # noqa: B001
+        raise
+    else:
+        for fn, was_executed in _is_called.items():
+            assert was_executed, f"{fn} was not executed"
+
+
+class CallbackError(Exception):
+    pass
diff --git a/py/torch_tensorrt/dynamo/test/trt_lower/test_observer_gpu.py b/py/torch_tensorrt/dynamo/test/trt_lower/test_observer_gpu.py
new file mode 100644
index 0000000000..1898f10980
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/test/trt_lower/test_observer_gpu.py
@@ -0,0 +1,53 @@
+# Owner(s): ["oncall: gpu_enablement"]
+import functools
+from unittest import TestCase
+
+import
torch_tensorrt.dynamo.observer as ob +from test_observer import execution_verifier, set_observer_callback_rethrow +from torch_tensorrt.dynamo.passes.lower_basic_pass import fuse_permute_linear + + +class ObserverGPUTests(TestCase): + @set_observer_callback_rethrow + def test_observe_lowerer(self): + """ + Test that we can observe the execution of `fuse_permute_linear` during + lowering. + """ + from dataclasses import replace + + import torch + import torch.nn as nn + + import torch_tensorrt.fx.lower as lower + from torch_tensorrt.fx.lower_setting import LowerSetting + + class Model(nn.Module): + def forward(self, x, y): + return x + y + + mod = Model().cuda() + inp = [torch.rand(1, 10), torch.rand(1, 10)] + inp = [i.cuda() for i in inp] + mod(*inp) + + with execution_verifier() as verify_execution: + + lowerer = lower.Lowerer.create( + lower_setting=LowerSetting(min_block_size=0) + ) + + @verify_execution + def observe_fuse_permute_linear_post(ctx: ob.ObserveContext): + """ + Called when fuse_permute_linear is executed. Decorated with + `verify_execution` so if this function is not executed, the + test fails. + """ + assert ctx.callable is fuse_permute_linear.orig_func + + # Register the observer callback and do the lowering + with fuse_permute_linear.observers.post.add( + observe_fuse_permute_linear_post + ): + lowerer(mod, inp) diff --git a/py/torch_tensorrt/dynamo/test/trt_lower/trt_operator_supported_test.py b/py/torch_tensorrt/dynamo/test/trt_lower/trt_operator_supported_test.py new file mode 100644 index 0000000000..699b787f0e --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/trt_lower/trt_operator_supported_test.py @@ -0,0 +1,80 @@ +# Owner(s): ["oncall: gpu_enablement"] + +import torch +import torch.fx +import torch.nn as nn +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops # noqa: F401 +from torch.testing._internal.common_utils import run_tests, TestCase +from torch_tensorrt.dynamo.tools.trt_splitter import create_trt_operator_support +from torch_tensorrt.fx.tracer.acc_tracer import acc_ops, acc_tracer + + +class TestTRTOperatorSupport(TestCase): + def test_supported_node_target(self): + class TestModule(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(1, 1) + + def forward(self, x): + x = self.linear(x) + x = x + 1 + return torch.add(input=x, other=x) + + mod = TestModule() + traced_mod = acc_tracer.trace(mod, [torch.randn(1, 2, 1, 1)]) + op_support = create_trt_operator_support() + for node in traced_mod.graph.nodes: + self.assertTrue(op_support.is_node_supported(mod, node)) + + def test_unsupport_node_explicit_batch_dim(self): + class TestModule(nn.Module): + def forward(self, x): + y = torch.add(input=x, other=x) + return torch.max_pool1d(y, 1) + + mod = TestModule() + traced_mod = acc_tracer.trace(mod, [torch.randn(5, 2)]) + op_support = create_trt_operator_support(use_implicit_batch_dim=False) + + for node in traced_mod.graph.nodes: + if node.target == acc_ops.add: + self.assertTrue(op_support.is_node_supported(mod, node)) + elif node.target == acc_ops.split: + self.assertFalse(op_support.is_node_supported(mod, node)) + + def test_unsupport_node_implicit_batch_dim(self): + class TestModule(nn.Module): + def forward(self, x): + y = torch.add(input=x, other=x) + return nn.functional.gelu(y) + + mod = TestModule() + traced_mod = acc_tracer.trace(mod, [torch.randn(5, 2)]) + op_support = create_trt_operator_support(use_implicit_batch_dim=True) + + for node in traced_mod.graph.nodes: + if node.target == acc_ops.add: + 
self.assertTrue(op_support.is_node_supported(mod, node))
+            elif node.target == acc_ops.gelu:
+                self.assertFalse(op_support.is_node_supported(mod, node))
+
+    def test_support_node_with_int_attr(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                zeros = torch.randint(3, 5, (1,))
+                zeros = zeros.to(torch.int64)
+                scale = torch.randn(1)
+                return torch.quantize_per_tensor(x, scale, zeros, torch.quint8)
+
+        mod = TestModule()
+        traced_mod = acc_tracer.trace(mod, [torch.randn(5, 2)])
+        op_support = create_trt_operator_support(use_implicit_batch_dim=True)
+
+        for node in traced_mod.graph.nodes:
+            if node.target == acc_ops.quantize_per_tensor:
+                self.assertTrue(op_support.is_node_supported(mod, node))
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/py/torch_tensorrt/dynamo/test/trt_lower/trt_splitter_test.py b/py/torch_tensorrt/dynamo/test/trt_lower/trt_splitter_test.py
new file mode 100644
index 0000000000..9d96bf78b0
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/test/trt_lower/trt_splitter_test.py
@@ -0,0 +1,1176 @@
+# Owner(s): ["oncall: gpu_enablement"]
+
+import operator
+
+import torch  # isort:skip
+import torch.fx  # isort:skip
+
+import torch.fx.passes.operator_support as op_support
+import torch.fx.passes.shape_prop as shape_prop
+import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
+from torch.fx.passes import splitter_base
+from torch.testing._internal.common_utils import run_tests, TestCase
+from torch_tensorrt.dynamo.tools.trt_splitter import TRTSplitter, TRTSplitterSetting
+from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer
+
+ERROR_MSG_NO_ACC_MODULE = "FX split failed: Did not find any ACC submodule!"
+ERROR_MSG_MULTI_ACC_MODULES = "FX split failed: Found more than one ACC submodules!"
+ACC_SUBMODULE_PREFIX = "_run_on_acc_"
+
+
+# Check that the split result has the expected number of ACC submodules.
+# If not, raise a RuntimeError.
+def verify_split_model(
+    mod: torch.fx.GraphModule,
+    acc_submodule_keyword: str = ACC_SUBMODULE_PREFIX,
+    expected_number: int = 1,
+) -> None:
+    acc_submodule_num = 0
+    for name, _ in mod.named_children():
+        if name.startswith(acc_submodule_keyword):
+            acc_submodule_num = acc_submodule_num + 1
+
+    if acc_submodule_num < expected_number:
+        raise RuntimeError(ERROR_MSG_NO_ACC_MODULE)
+    elif acc_submodule_num > expected_number:
+        raise RuntimeError(ERROR_MSG_MULTI_ACC_MODULES)
+
+
+def find_inputs(module):
+    return [n for n in module.graph.nodes if n.op == "placeholder"]
+
+
+def find_fun_calls(module, target):
+    return [
+        n for n in module.graph.nodes if n.op == "call_function" and n.target == target
+    ]
+
+
+def find_output(module):
+    return next(n for n in module.graph.nodes if n.op == "output")
+
+
+TENSOR_SIZE_DUMMY = "tensor_size_dummy"
+
+
+def find_call_targets(module: torch.fx.GraphModule):
+    result = set()
+    for n in module.graph.nodes:
+        n: torch.fx.Node
+        if n.op in {"call_module", "call_function", "call_method"}:
+            result.add(n.target)
+    return result
+
+
+# We test both FxNetSplitOnly and FxNetSplitter here, since they share most
+# functionalities. The only difference is that FxNetSplitOnly does not implement
+# split_preview() related functions, while FxNetSplitter does.
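+#
+# A minimal sketch of the flow the tests below exercise (the module and the
+# support dict here are illustrative only, not part of the library API):
+#
+#   mod = acc_tracer.trace(SomeModule(), [torch.randn(2, 3)])
+#   splitter = TRTSplitter(
+#       mod,
+#       (torch.randn(2, 3),),
+#       op_support_with_support_dict({"acc_ops.sin": None}),
+#   )
+#   split_mod = splitter()
+#   # split_mod has children named _run_on_acc_* / _run_on_gpu_*
+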
+class TestSplit(TestCase): + def test_demo(self): + """ + ==> b ==> + // \\ + a d + \\ // + ==> c ==> + """ + + class SimpleModule(torch.nn.Module): + def forward(self, a): + b = torch.sin(a) + c = torch.cos(a) + d = b + c + return d + + mod = acc_tracer.trace(SimpleModule(), [torch.randn(2, 3)]) + + # Making b and c run on ACC + splitter = TRTSplitter( + mod, + (torch.randn(2, 3),), + op_support_with_support_dict( + { + "acc_ops.sin": None, + "acc_ops.cos": None, + } + ), + ) + + st_split = splitter() + + [arg] = find_inputs(st_split) + + # First subgraph calculates b = sin(a) and c = cos(a) on ACC + [sin] = find_fun_calls(st_split._run_on_acc_0, acc_ops.sin) + self.assertEqual(arg.name, sin.kwargs["input"].name) + + [cos] = find_fun_calls(st_split._run_on_acc_0, acc_ops.cos) + self.assertEqual(arg.name, cos.kwargs["input"].name) + + # Second subgraph calculates d = b + c on CPU + [add] = find_fun_calls(st_split._run_on_gpu_1, acc_ops.add) + self.assertEqual(sin.name, add.kwargs["input"].name) + self.assertEqual(cos.name, add.kwargs["other"].name) + + def test_mod_with_getattr(self): + """ + CPU subgraph should have get_attr for self.a while ACC subgraph + should have get_attr for self.b. + """ + + class SimpleModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = torch.randn(1, 1, 1, 1) + self.b = torch.randn(1, 1, 1, 1) + self.conv = torch.nn.Conv2d(1, 1, 1) + self.linear = torch.nn.Linear(1, 1) + + def forward(self, x): + x = x + self.a + x = self.conv(x) + return self.linear(x - self.b) + + mod = acc_tracer.trace(SimpleModule(), [torch.randn(1, 1, 1, 1)]) + mod.eval() + + splitter = TRTSplitter( + mod, + (torch.randn(1, 1, 1, 1),), + op_support_with_support_dict( + { + "acc_ops.linear": None, + "acc_ops.sub": None, + } + ), + ) + + def test_splitter(splitter): + st_split = splitter() + verify_split_model(st_split) + # Should be "a", "conv.weight", "conv.bias". + get_attr_nodes = [ + node.target + for node in st_split._run_on_gpu_0.graph.nodes + if node.op == "get_attr" + ] + assert len(get_attr_nodes) == 3 and "a" in get_attr_nodes + + # Should be "b", "conv.weight", "conv.bias". 
+ get_attr_nodes = [ + node.target + for node in st_split._run_on_acc_1.graph.nodes + if node.op == "get_attr" + ] + assert len(get_attr_nodes) == 3 and "b" in get_attr_nodes + + test_splitter(splitter) + + def test_nothing_to_split(self): + class SimpleModule(torch.nn.Module): + def forward(self, a): + return a + + mod = acc_tracer.trace(SimpleModule(), [torch.randn(2, 3)]) + + # Mark any operation as runnable on ACC + class CustomOpSupport(op_support.OperatorSupportBase): + def is_node_supported(self, submodules, node): + return True + + splitter = TRTSplitter(mod, (torch.randn(2, 3),), CustomOpSupport()) + + def test_splitter(splitter): + st_split = splitter() + try: + verify_split_model(st_split) + except RuntimeError as err: + self.assertEqual(str(err), ERROR_MSG_NO_ACC_MODULE) + self.assertEqual(splitter.module.__dict__.keys(), st_split.__dict__.keys()) + + test_splitter(splitter) + + def test_multi_output(self): + class MultiOutputModule(torch.nn.Module): + def forward(self, x): + res, ind = torch.topk(x, 3) + return torch.sigmoid(res), ind + + mod = acc_tracer.trace(MultiOutputModule(), [torch.randn(2, 3)]) + + # Mark any operation as runnable on ACC + class CustomOpSupport(op_support.OperatorSupportBase): + def is_node_supported(self, submodules, node): + return True + + splitter = TRTSplitter(mod, (torch.randn(2, 3),), CustomOpSupport()) + + def test_splitter(splitter): + st_split = splitter() + verify_split_model(st_split) + [arg] = find_inputs(st_split) + + # There is only one subgraph that executes topk and sigmoid on ACC + [topk] = find_fun_calls(st_split._run_on_acc_0, acc_ops.topk) + self.assertEqual(arg.name, topk.kwargs["input"].name) + self.assertEqual(3, topk.kwargs["k"]) + + [topk_res1, topk_res2] = find_fun_calls( + st_split._run_on_acc_0, acc_ops.getitem + ) + + [sigmoid] = find_fun_calls(st_split._run_on_acc_0, acc_ops.sigmoid) + self.assertIn( + sigmoid.kwargs["input"].name, {topk_res1.name, topk_res2.name} + ) + + # Main graph returns a tuple + output = find_output(st_split._run_on_acc_0) + self.assertLess( + {output.args[0][0].name, output.args[0][1].name}, + {topk_res1.name, topk_res2.name, sigmoid.name}, + ) + + test_splitter(splitter) + + def test_nested_modules(self): + """ + x + // \\ + // \\ + relu(x) sin(x) + \\ // + \\ // + relu(x) + sin(x) + """ + + class ReluModule(torch.nn.Module): + def forward(self, x): + return torch.relu(x) + + class SinModule(torch.nn.Module): + def forward(self, x): + return torch.sin(x) + + class TestModule3(torch.nn.Module): + def __init__(self, relu_module, sin_module): + super().__init__() + self.relu_module = relu_module + self.sin_module = sin_module + + def forward(self, x): + return self.relu_module(x) + self.sin_module(x) + + mod = acc_tracer.trace( + TestModule3(ReluModule(), SinModule()), [torch.randn(2, 3)] + ) + + # Making sin(x) run on ACC + splitter = TRTSplitter( + mod, + (torch.randn(2, 3),), + op_support_with_support_dict( + { + "acc_ops.sin": None, + } + ), + ) + + def test_splitter(splitter): + st_split = splitter() + verify_split_model(st_split) + [arg] = find_inputs(st_split) + + # First subgraph calculates relu(x) on CPU + [relu] = find_fun_calls(st_split._run_on_gpu_0, acc_ops.relu) + self.assertEqual(arg.name, relu.kwargs["input"].name) + + # Second subgraph calculates sin(x) on ACC + [sin] = find_fun_calls(st_split._run_on_acc_1, acc_ops.sin) + self.assertEqual(arg.name, sin.kwargs["input"].name) + + # Third subgraph calculates sum on CPU + [add] = find_fun_calls(st_split._run_on_gpu_2, acc_ops.add) + 
self.assertEqual(relu.name, add.kwargs["input"].name)
+            self.assertEqual(sin.name, add.kwargs["other"].name)
+
+            # Check that applying the split module gives the same results
+            tensor = torch.randn(5)
+            self.assertTrue(torch.equal(mod(tensor), st_split(tensor)))
+
+        test_splitter(splitter)
+
+    def test_longer_chain(self):
+        """
+            sin     relu    cos     sigmoid    tanh
+        a ====> b =====> c ====> d ========> e =====> f
+        """
+
+        class TestModule(torch.nn.Module):
+            def forward(self, a):
+                b = torch.sin(a)
+                c = torch.relu(b)
+                d = torch.cos(c)
+                e = torch.sigmoid(d)
+                f = torch.tanh(e)
+                return f
+
+        mod = acc_tracer.trace(TestModule(), [torch.randn(2, 3)])
+
+        # Making relu and sigmoid execute on ACC
+        splitter = TRTSplitter(
+            mod,
+            (torch.randn(2, 3),),
+            op_support_with_support_dict(
+                {
+                    "acc_ops.relu": None,
+                    "acc_ops.sigmoid": None,
+                }
+            ),
+        )
+
+        def test_splitter(splitter):
+            st_split = splitter()
+            try:
+                verify_split_model(st_split)
+            except RuntimeError as err:
+                self.assertEqual(str(err), ERROR_MSG_MULTI_ACC_MODULES)
+            [arg] = find_inputs(st_split)
+
+            # First subgraph calculates b = sin(a) on CPU
+            [sin] = find_fun_calls(st_split._run_on_gpu_0, acc_ops.sin)
+            self.assertEqual(arg.name, sin.kwargs["input"].name)
+
+            # Second subgraph calculates c = relu(b) on ACC
+            [relu] = find_fun_calls(st_split._run_on_acc_1, acc_ops.relu)
+            self.assertEqual(sin.name, relu.kwargs["input"].name)
+
+            # Third subgraph calculates d = cos(c) on CPU
+            [cos] = find_fun_calls(st_split._run_on_gpu_2, acc_ops.cos)
+            self.assertEqual(relu.name, cos.kwargs["input"].name)
+
+            # Fourth subgraph calculates e = sigmoid(d) on ACC
+            [sigmoid] = find_fun_calls(st_split._run_on_acc_3, acc_ops.sigmoid)
+            self.assertEqual(cos.name, sigmoid.kwargs["input"].name)
+
+            # Fifth subgraph calculates f = tanh(e) on CPU
+            [tanh] = find_fun_calls(st_split._run_on_gpu_4, acc_ops.tanh)
+            self.assertEqual(sigmoid.name, tanh.kwargs["input"].name)
+
+        test_splitter(splitter)
+
+    def test_min_block_size(self):
+        """
+            sin     relu    cos     sigmoid    tanh
+        a ====> b =====> c ====> d ========> e =====> f
+
+        We set sin, cos and tanh as acc nodes but also set min_block_size to 2.
+        Each of these ops is isolated between unsupported ops, so every
+        candidate ACC block has size 1, and we expect the whole module to stay
+        on CPU.
+ """ + + class TestModule(torch.nn.Module): + def forward(self, a): + b = torch.sin(a) + c = torch.relu(b) + d = torch.cos(c) + e = torch.sigmoid(d) + f = torch.tanh(e) + return f + + mod = acc_tracer.trace(TestModule(), [torch.randn(2, 3)]) + + # Set sin, cos and tanh as acc node and split with settings + class CustomOpSupport(op_support.OperatorSupport): + _support_dict = { + "acc_ops.sin": None, + "acc_ops.cos": None, + "acc_ops.tanh": None, + } + + # Create splitter setting and set min_block_size to 2 + settings = splitter_base._SplitterSettingBase() + settings.min_block_size = 2 + splitter = TRTSplitter( + mod, + (torch.randn(2, 3),), + op_support_with_support_dict( + { + "acc_ops.sin": None, + "acc_ops.cos": None, + "acc_ops.tanh": None, + } + ), + settings, + ) + + def test_splitter(splitter): + st_split = splitter() + try: + verify_split_model(st_split) + except RuntimeError as err: + self.assertEqual(str(err), ERROR_MSG_NO_ACC_MODULE) + modules = list(st_split.named_modules()) + # Main module and a submodule + assert len(modules) == 2 + + assert modules[1][0] == "_run_on_gpu_0" + + test_splitter(splitter) + + def test_extend_acc_subgraph_after_split(self): + class TestModule(torch.nn.Module): + r""" a (input) + | + b + / \ + c d + \ / + e + / \ + | (g1, g2, g3, g4) + \ / | + f | + \ | + h + + c and f are not runnable on acc while all other nodes are supported by acc. + g1, g2, g3 and g4 should be in a fusion group, let's call it g. + + After split we have 2 cpu subgraphs (c) and (f), 3 acc subgraphs (b, d), (e, g) and (h). + We expect 3 acc subgraphs (b), (d, e, g) and (h) after extend the second acc subgraph. + And expect acc subgraphs stay the same after extend the third acc subgraph because of + the unbreakable fusion group. + """ + + def forward(self, a: torch.Tensor): + b = a + a + c = b - b + d = b + b + e = c + d + + # These four nodes should be in a fusion group + g1 = e.size() + g2 = g1[0] + g3 = e + g2 + g4 = g3 + g2 + + f = e - g3 + h = f + g4 + return h + + a = torch.randn(2) + mod = acc_tracer.trace(TestModule(), (a,)) + + # Allow all nodes expect subtract run on accelerator + class CustomOpSupport(op_support.OperatorSupportBase): + def is_node_supported(self, submodules, node): + return op_support.get_node_target(submodules, node) != "acc_ops.sub" + + splitter = TRTSplitter(mod, (a,), CustomOpSupport()) + + def test_splitter(splitter): + # Manually tag nodes first in case split algorithm changes in the future + nodes = list(splitter.module.graph.nodes) + # b and d + nodes[1].tag = "acc_0" + nodes[3].tag = "acc_0" + # c + nodes[2].tag = "cpu_1" + # e and g + nodes[4].tag = "acc_2" + nodes[5].tag = "acc_2" + nodes[6].tag = "acc_2" + nodes[7].tag = "acc_2" + nodes[8].tag = "acc_2" + # f + nodes[9].tag = "cpu_3" + # h + nodes[10].tag = "acc_4" + + splitter.tags = ["acc_0", "cpu_1", "acc_2", "cpu_3", "acc_4"] + split_module = splitter.split() + try: + verify_split_model(split_module, "acc_") + except RuntimeError as err: + self.assertEqual(str(err), ERROR_MSG_MULTI_ACC_MODULES) + try: + verify_split_model(split_module) + except RuntimeError as err: + self.assertEqual(str(err), ERROR_MSG_NO_ACC_MODULE) + + module_names = [name for name, _ in split_module.named_modules()] + # Main module, 2 cpu submodules and 3 acc submodule + assert len(module_names) == 6 + + # 1 Placeholder, 2 Adds and 1 Output + assert len(split_module.acc_0.graph.nodes) == 4 + # 2 Placeholder, 3 Adds, 1 Size, 1 GetItem and 1 Output + assert len(split_module.acc_2.graph.nodes) == 8 + + # Extend the 
second acc subgraph
+            splitter.extend_acc_subgraph("acc_2")
+            extend_module = splitter.split()
+            try:
+                verify_split_model(extend_module, "acc_")
+            except RuntimeError as err:
+                self.assertEqual(str(err), ERROR_MSG_MULTI_ACC_MODULES)
+
+            # 1 Placeholder, 1 Add and 1 Output
+            assert len(extend_module.acc_0.graph.nodes) == 3
+            # 2 Placeholders, 4 Adds, 1 Size, 1 GetItem and 1 Output
+            assert len(extend_module.acc_2.graph.nodes) == 9
+
+            # Extend the third acc subgraph
+            splitter.extend_acc_subgraph("acc_4")
+            extend_module = splitter.split()
+            try:
+                verify_split_model(extend_module, "acc_")
+            except RuntimeError as err:
+                self.assertEqual(str(err), ERROR_MSG_MULTI_ACC_MODULES)
+
+            assert len(extend_module.acc_2.graph.nodes) == 9
+            # 2 Placeholders, 1 Add and 1 Output
+            assert len(extend_module.acc_4.graph.nodes) == 4
+
+        test_splitter(splitter)
+
+    def test_get_attr_into_output(self):
+        """
+        Here we verify the case when a get_attr node is consumed directly by the
+        output. We don't expect any split to happen in this test, we just want to
+        make sure that the splitter code doesn't break.
+        """
+
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.a = torch.randn(2, 3)
+
+            def forward(self, x):
+                return (x, self.a)
+
+        # No need to put anything on ACC.
+        class TestOperatorSupport:
+            def is_node_supported(self, submodules, node):
+                return False
+
+        module_original = acc_tracer.trace(TestModule(), [torch.randn(4, 5)])
+
+        splitter = TRTSplitter(
+            module=module_original,
+            sample_input=torch.randn(4, 5),
+            operator_support=TestOperatorSupport(),
+        )
+
+        def test_splitter(splitter):
+            module_split = splitter()
+            try:
+                verify_split_model(module_split)
+            except RuntimeError as err:
+                self.assertEqual(str(err), ERROR_MSG_NO_ACC_MODULE)
+
+            output = find_output(module_split)
+            # Second argument of the output should be get_attr.
+            self.assertEqual("get_attr", output.args[0][1].op)
+
+            # Check if modules are equivalent.
+            tensor = torch.randn(10, 20)
+            result_original = module_original(tensor)
+            result_split = module_split(tensor)
+            self.assertTrue(torch.equal(result_original[0], result_split[0]))
+            self.assertTrue(torch.equal(result_original[1], result_split[1]))
+
+        test_splitter(splitter)
+
+    def test_get_attr_into_starter_node(self):
+        """
+        Here we verify the case when starter nodes depend on a get_attr node only.
+        We don't expect any split to happen in this test, we just want to make sure
+        that the splitter code doesn't break.
+        """
+
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.a = torch.randn(2, 3)
+
+            def forward(self):
+                m = self.a + self.a
+                o = m + m
+                return o
+
+        # No need to put anything on ACC.
+        class TestOperatorSupport:
+            def is_node_supported(self, submodules, node):
+                return False
+
+        module_original = acc_tracer.trace(TestModule(), [torch.randn(2, 3)])
+
+        splitter = TRTSplitter(
+            module=module_original,
+            sample_input=torch.randn(2, 3),
+            operator_support=TestOperatorSupport(),
+        )
+
+        def test_splitter(splitter):
+            module_split = splitter()
+            try:
+                verify_split_model(module_split)
+            except RuntimeError as err:
+                self.assertEqual(str(err), ERROR_MSG_NO_ACC_MODULE)
+
+            # Check if modules are equivalent.
+            result_original = module_original()
+            result_split = module_split()
+            self.assertTrue(torch.equal(result_original, result_split))
+
+        test_splitter(splitter)
+
+
+class TestSplitComplexGraph(TestCase):
+    """
+         a ======
+       //  \\     \\
+      b     c      d
+       \\  //     //
+         e       //
+          \\    //
+           \\  //
+             f
+    """
+
+    class TestModule(torch.nn.Module):
+        def forward(self, a):
+            b = torch.sin(a)
+            c = torch.relu(a)
+            d = torch.cos(a)
+            e = b + c
+            f = e - d
+            return f
+
+    def test_split_complex_graph_1(self):
+        mod = acc_tracer.trace(self.TestModule(), [torch.randn(2, 3)])
+
+        # Making 'c' and 'd' run on ACC
+        splitter = TRTSplitter(
+            mod,
+            (torch.randn(2, 3),),
+            op_support_with_support_dict(
+                {
+                    "acc_ops.cos": None,
+                    "acc_ops.relu": None,
+                }
+            ),
+        )
+
+        def test_splitter(splitter):
+            st_split = splitter()
+            verify_split_model(st_split)
+
+            [arg] = find_inputs(st_split)
+
+            # First subgraph calculates b = sin(a) on CPU
+            [sin] = find_fun_calls(st_split._run_on_gpu_0, acc_ops.sin)
+            self.assertEqual(arg.name, sin.kwargs["input"].name)
+
+            # Second subgraph calculates c = relu(a) and d = cos(a) on ACC
+            [relu] = find_fun_calls(st_split._run_on_acc_1, acc_ops.relu)
+            self.assertEqual(arg.name, relu.kwargs["input"].name)
+
+            [cos] = find_fun_calls(st_split._run_on_acc_1, acc_ops.cos)
+            self.assertEqual(arg.name, cos.kwargs["input"].name)
+
+            # Third subgraph calculates e = b + c and f = e - d on CPU
+            [add] = find_fun_calls(st_split._run_on_gpu_2, acc_ops.add)
+            self.assertEqual(sin.name, add.kwargs["input"].name)
+            self.assertEqual(relu.name, add.kwargs["other"].name)
+
+            [sub] = find_fun_calls(st_split._run_on_gpu_2, acc_ops.sub)
+            self.assertEqual(add.name, sub.kwargs["input"].name)
+            self.assertEqual(cos.name, sub.kwargs["other"].name)
+
+        test_splitter(splitter)
+
+    def test_split_complex_graph_2(self):
+        module_nn = self.TestModule()
+        module = acc_tracer.trace(module_nn, (torch.randn(2, 3),))
+
+        # Making 'c', 'd' and 'e' run on ACC
+        splitter = TRTSplitter(
+            module,
+            (torch.randn(2, 3),),
+            op_support_with_support_dict(
+                {
+                    "acc_ops.cos": None,
+                    "acc_ops.relu": None,
+                    "acc_ops.add": None,
+                }
+            ),
+        )
+
+        def test_splitter(splitter):
+            module_fx_split = splitter()
+            verify_split_model(module_fx_split)
+
+            [arg] = find_inputs(module)
+
+            # First subgraph calculates b = sin(a) on CPU
+            [sin] = find_fun_calls(module_fx_split._run_on_gpu_0, acc_ops.sin)
+            self.assertEqual(arg.name, sin.kwargs["input"].name)
+
+            # Second subgraph calculates c = relu(a), d = cos(a) and e = b + c on ACC
+            [relu] = find_fun_calls(module_fx_split._run_on_acc_1, acc_ops.relu)
+            self.assertEqual(arg.name, relu.kwargs["input"].name)
+
+            [cos] = find_fun_calls(module_fx_split._run_on_acc_1, acc_ops.cos)
+            self.assertEqual(arg.name, cos.kwargs["input"].name)
+
+            [add] = find_fun_calls(module_fx_split._run_on_acc_1, acc_ops.add)
+            self.assertEqual(sin.name, add.kwargs["input"].name)
+            self.assertEqual(relu.name, add.kwargs["other"].name)
+
+            # Third subgraph calculates f = e - d on CPU
+            [sub] = find_fun_calls(module_fx_split._run_on_gpu_2, acc_ops.sub)
+            self.assertEqual(add.name, sub.kwargs["input"].name)
+            self.assertEqual(cos.name, sub.kwargs["other"].name)
+
+        test_splitter(splitter)
+
+
+class TestSplitNonTensorEdges(TestCase):
+    """
+           a (relu)
+          //      \\
+      (b1,b2)      c (cos)
+          \\      //
+           d (add)
+             ||
+           e (sigmoid)
+    """
+
+    # Note the non-tensor edge between b2 and d
+    class TestModule(torch.nn.Module):
+        def forward(self, x):
+            a = torch.relu(x)
+
+            b1 = a.size()
+            b2 = b1[0]
+
+            c = torch.cos(a)
+
+            d = b2 + c
+            e = torch.sigmoid(d)
+            
return e + + def test_split_non_tensor_edges_1(self): + test_data = torch.randn(2, 3) + + module_nn = acc_tracer.trace(self.TestModule(), (test_data,)) + + # Making 'a', 'b1', 'b2', 'd' and 'e' run on ACC + splitter = TRTSplitter( + module_nn, + (test_data,), + op_support_with_support_dict( + { + "acc_ops.relu": None, + "acc_ops.sigmoid": None, + "acc_ops.add": None, + "acc_ops.getitem": None, + "acc_ops.size": None, + } + ), + ) + + def test_splitter(splitter): + module_fx_split = splitter() + try: + verify_split_model(module_fx_split) + except RuntimeError as err: + self.assertEqual(str(err), ERROR_MSG_MULTI_ACC_MODULES) + + self.assertEqual( + {acc_ops.relu}, find_call_targets(module_fx_split._run_on_acc_0) + ) + + self.assertEqual( + {acc_ops.cos}, find_call_targets(module_fx_split._run_on_gpu_1) + ) + + self.assertEqual( + {acc_ops.size, acc_ops.getitem, acc_ops.add, acc_ops.sigmoid}, + find_call_targets(module_fx_split._run_on_acc_2), + ) + + # Make sure we can compile to TorchScript + module_jit = torch.jit.trace_module(module_fx_split, {"forward": test_data}) + self.assertTrue(torch.allclose(module_nn(test_data), module_jit(test_data))) + + test_splitter(splitter) + + def test_split_non_tensor_edges_2(self): + test_data = torch.randn(2, 3) + + module_nn = acc_tracer.trace(self.TestModule(), (test_data,)) + + # Making 'a', 'b1', 'b2', 'd' and 'e' run on ACC with limit on ACC + # subgraph size + settings = splitter_base._SplitterSettingBase() + settings.min_block_size = 2 + splitter = TRTSplitter( + module_nn, + (test_data,), + op_support_with_support_dict( + { + "acc_ops.relu": None, + "acc_ops.sigmoid": None, + "acc_ops.add": None, + "acc_ops.getitem": None, + "acc_ops.size": None, + } + ), + settings, + ) + + def test_splitter(splitter): + module_fx_split = splitter() + verify_split_model(module_fx_split) + + self.assertEqual( + {acc_ops.relu, acc_ops.cos}, + find_call_targets(module_fx_split._run_on_gpu_0), + ) + + self.assertEqual( + {acc_ops.size, acc_ops.getitem, acc_ops.add, acc_ops.sigmoid}, + find_call_targets(module_fx_split._run_on_acc_1), + ) + + # Make sure we can compile to TorchScript + module_jit = torch.jit.trace_module(module_fx_split, {"forward": test_data}) + self.assertTrue(torch.allclose(module_nn(test_data), module_jit(test_data))) + + test_splitter(splitter) + + def test_split_non_tensor_edges_3(self): + test_data = torch.randn(2, 3) + + module_nn = acc_tracer.trace( + self.TestModule(), + (test_data,), + ) + + # Making 'a', 'c', 'd' and 'e' run on ACC + splitter = TRTSplitter( + module_nn, + (test_data,), + op_support_with_support_dict( + { + "acc_ops.relu": None, + "acc_ops.sigmoid": None, + "acc_ops.cos": None, + "acc_ops.add": None, + } + ), + ) + + def test_splitter(splitter): + module_fx_split = splitter() + try: + verify_split_model(module_fx_split) + except RuntimeError as err: + self.assertEqual(str(err), ERROR_MSG_MULTI_ACC_MODULES) + + self.assertEqual( + {acc_ops.relu, acc_ops.cos}, + find_call_targets(module_fx_split._run_on_acc_0), + ) + + self.assertEqual( + {acc_ops.size, acc_ops.getitem, acc_ops.add}, + find_call_targets(module_fx_split._run_on_gpu_1), + ) + + self.assertEqual( + {acc_ops.sigmoid}, + find_call_targets(module_fx_split._run_on_acc_2), + ) + + # Make sure we can compile to TorchScript + module_jit = torch.jit.trace_module(module_fx_split, {"forward": test_data}) + self.assertTrue(torch.allclose(module_nn(test_data), module_jit(test_data))) + + test_splitter(splitter) + + def test_split_non_tensor_edges_4(self): + test_data = 
torch.randn(2, 3)
+
+        module_nn = acc_tracer.trace(
+            self.TestModule(),
+            (test_data,),
+        )
+
+        # Making 'a', 'c', 'd' and 'e' run on ACC with limit on ACC
+        # subgraph size
+        settings = splitter_base._SplitterSettingBase()
+        settings.min_block_size = 2
+        splitter = TRTSplitter(
+            module_nn,
+            (test_data,),
+            op_support_with_support_dict(
+                {
+                    "acc_ops.relu": None,
+                    "acc_ops.sigmoid": None,
+                    "acc_ops.cos": None,
+                    "acc_ops.add": None,
+                }
+            ),
+            settings,
+        )
+
+        def test_splitter(splitter):
+            module_fx_split = splitter()
+            verify_split_model(module_fx_split)
+
+            self.assertEqual(
+                {acc_ops.relu, acc_ops.cos},
+                find_call_targets(module_fx_split._run_on_acc_0),
+            )
+
+            self.assertEqual(
+                {acc_ops.size, acc_ops.getitem, acc_ops.add, acc_ops.sigmoid},
+                find_call_targets(module_fx_split._run_on_gpu_1),
+            )
+
+            # Make sure we can compile to TorchScript
+            module_jit = torch.jit.trace_module(module_fx_split, {"forward": test_data})
+            self.assertTrue(torch.allclose(module_nn(test_data), module_jit(test_data)))
+
+        test_splitter(splitter)
+
+
+class TestAccNodesFinder(TestCase):
+    def test_acc_nodes_finder_1(self):
+        """
+        y ------------->
+                 |
+            ----> b ---->
+        x ----> a        d
+            ----> c ---->
+                 |
+        z ------------->
+        """
+
+        # Make 'a' produce non-tensor data
+        class TestModule(torch.nn.Module):
+            def forward(self, x, y, z):
+                a1 = x.size()
+                a1 = a1[0]
+
+                b = y + a1
+                c = z - a1
+
+                d = b + c
+
+                return d
+
+        module_nn = TestModule()
+        module_fx = torch.fx.symbolic_trace(module_nn)
+
+        # Make 'a' and 'c' lowerable to ACC
+        finder = torch.fx.passes.splitter_base.FxNetAccNodesFinder(
+            module_fx,
+            op_support_with_support_dict(
+                {
+                    "acc_ops.sub": None,
+                    "acc_ops.getitem": None,
+                    "acc_ops.size": None,
+                }
+            ),
+            False,
+        )
+        acc_nodes = finder()
+        self.assertEqual(set(), acc_nodes, "Shouldn't have ACC nodes")
+
+
+class TestAccFusionsFinder(TestCase):
+    """
+          x
+         / \\
+        a   b
+       /|\\
+      / | a2
+    a0 a1  |
+     |  /  |
+     c     |
+     |     |
+     d     |
+      \\   /
+        e
+    """
+
+    class TestModule(torch.nn.Module):
+        def forward(self, x):
+            a = x.size()
+            b = x + x
+
+            a0 = a[0]
+            a1 = a[1]
+            a2 = a[2]
+            c = x.view(a1, a0, -1)
+
+            d = c + c
+            e = d + a2
+            return b, e
+
+    def test_acc_fusions_finder_1(self):
+        """
+        Assume every node is an acc node. We should have one fusion group
+        (a, a0, a1, a2, c, d, e).
+        """
+        module_nn = self.TestModule()
+        module_fx = torch.fx.symbolic_trace(module_nn)
+        shape_prop.ShapeProp(module_fx).propagate(torch.randn(1, 1, 1))
+
+        acc_node = {
+            node
+            for node in module_fx.graph.nodes
+            if node.op in torch.fx.passes.tools_common.CALLABLE_NODE_OPS
+        }
+
+        fusions_finder = torch.fx.passes.splitter_base.FxNetAccFusionsFinder(
+            module_fx,
+            acc_node,
+        )
+        fusion_map = fusions_finder()
+
+        self.assertEqual(len(fusion_map), 7)
+        for _, v in fusion_map.items():
+            self.assertEqual(len(v), 7)
+
+    def test_acc_fusions_finder_2(self):
+        """
+        Let b and d be cpu nodes. After fusion, all nodes should be cpu nodes,
+        because d is included in the fusion group, which forces all other nodes
+        in the same fusion group to be on CPU too.
+ """ + module_nn = self.TestModule() + module_fx = torch.fx.symbolic_trace(module_nn) + shape_prop.ShapeProp(module_fx).propagate(torch.randn(1, 1, 1)) + + acc_node = { + node for node in module_fx.graph.nodes if node.target == operator.add + } + fusions_finder = torch.fx.passes.splitter_base.FxNetAccFusionsFinder( + module_fx, + acc_node, + ) + fusion_map = fusions_finder() + self.assertEqual(len(fusion_map), 0) + + def test_start_with_acc_module_(self): + """ + sin relu cos sigmoid tanh + a ====> b =====> c ====> d ========> e =====> f + + We set sin, relu and cos as acc node but also set min_block_size to 2 + and expect the whole module stay on CPU. + """ + + class TestModule(torch.nn.Module): + def forward(self, a): + b = torch.sin(a) + c = torch.relu(b) + d = torch.cos(c) + e = torch.sigmoid(d) + f = torch.tanh(e) + return f + + mod = acc_tracer.trace(TestModule(), [torch.randn(2, 3)]) + + # Set sin, cos and tanh as acc node and split with settings + class CustomOpSupport(op_support.OperatorSupport): + _support_dict = { + "acc_ops.sin": None, + "acc_ops.cos": None, + "acc_ops.relu": None, + } + + # Create splitter setting and set min_block_size to 2 + settings = splitter_base._SplitterSettingBase() + settings.min_block_size = 2 + splitter = TRTSplitter( + mod, + (torch.randn(2, 3),), + op_support_with_support_dict( + { + "acc_ops.sin": None, + "acc_ops.cos": None, + "acc_ops.relu": None, + } + ), + settings, + ) + + def test_splitter(splitter): + st_split = splitter() + try: + verify_split_model(st_split) + except RuntimeError as err: + self.assertEqual(str(err), ERROR_MSG_NO_ACC_MODULE) + modules = list(st_split.named_modules()) + # Main module and a submodule + assert len(modules) == 3 + + assert modules[1][0] == "_run_on_acc_0" + assert modules[2][0] == "_run_on_gpu_1" + + test_splitter(splitter) + + def test_exclude_support_node_by_name(self): + class TestModule(torch.nn.Module): + def forward(self, a): + b = torch.sin(a) + c = torch.relu(b) + d = torch.cos(c) + e = torch.sigmoid(d) + f = torch.tanh(e) + return f + + mod = acc_tracer.trace(TestModule(), [torch.randn(2, 3)]) + + # Set sin, cos and tanh as acc node and split with settings + class CustomOpSupport(op_support.OperatorSupport): + _support_dict = { + "acc_ops.sin": None, + "acc_ops.cos": None, + "acc_ops.relu": None, + "acc_ops.sigmoid": None, + "acc_ops.tanh": None, + } + + # For unsupport relu node, this would cut graph into acc_0, gpu_1 and acc_2 + # as three sub graphs. 
+        settings = TRTSplitterSetting()
+        settings.exclude_support_node_name = {"relu"}
+        splitter = TRTSplitter(
+            mod,
+            (torch.randn(2, 3),),
+            op_support_with_support_dict(
+                {
+                    "acc_ops.sin": None,
+                    "acc_ops.cos": None,
+                    "acc_ops.relu": None,
+                }
+            ),
+            settings,
+        )
+        res = splitter.generate_split_results()
+        self.assertEqual(len(res), 3)
+
+
+def op_support_with_support_dict(support_dict: dict) -> op_support.OperatorSupportBase:
+    return op_support.OperatorSupport(support_dict)
+
+
+if __name__ == "__main__":
+    run_tests()

From 80828fb5d57fd9f82a0329087121fd985fa2492a Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Tue, 4 Apr 2023 17:46:22 -0700
Subject: [PATCH 07/45] chore: Initial refactoring of TS changes to unify with
 FX backend in dynamo namespace

Signed-off-by: Dheeraj Peri
---
 py/setup.py                           |   6 ++
 py/torch_tensorrt/_Input.py           |  65 ++-------------
 py/torch_tensorrt/_compile.py         |  24 ++++++
 py/torch_tensorrt/ts/_compile_spec.py | 116 ++++++++++++++++----------
 tests/py/api/test_classes.py          |  60 +++++++++----
 5 files changed, 158 insertions(+), 113 deletions(-)

diff --git a/py/setup.py b/py/setup.py
index f7247a9f90..3c73aee3b5 100644
--- a/py/setup.py
+++ b/py/setup.py
@@ -356,6 +356,9 @@ def run(self):
     "torch_tensorrt.fx.tools",
     "torch_tensorrt.fx.tracer.acc_tracer",
     "torch_tensorrt.fx.tracer.dispatch_tracer",
+    "torch_tensorrt.dynamo",
+    "torch_tensorrt.dynamo.passes",
+    "torch_tensorrt.dynamo.tools",
 ]
 package_dir = {
     "torch_tensorrt.fx": "torch_tensorrt/fx",
@@ -364,6 +367,9 @@ def run(self):
     "torch_tensorrt.fx.tools": "torch_tensorrt/fx/tools",
     "torch_tensorrt.fx.tracer.acc_tracer": "torch_tensorrt/fx/tracer/acc_tracer",
     "torch_tensorrt.fx.tracer.dispatch_tracer": "torch_tensorrt/fx/tracer/dispatch_tracer",
+    "torch_tensorrt.dynamo": "torch_tensorrt/dynamo",
+    "torch_tensorrt.dynamo.passes": "torch_tensorrt/dynamo/passes",
+    "torch_tensorrt.dynamo.tools": "torch_tensorrt/dynamo/tools",
 }

 with open("README.md", "r", encoding="utf-8") as fh:
diff --git a/py/torch_tensorrt/_Input.py b/py/torch_tensorrt/_Input.py
index 324c385fab..dfc8db3bb8 100644
--- a/py/torch_tensorrt/_Input.py
+++ b/py/torch_tensorrt/_Input.py
@@ -4,7 +4,6 @@
 import torch

 from torch_tensorrt import _enums
-from torch_tensorrt import _C


 class Input(object):
@@ -41,6 +40,7 @@ class _ShapeMode(Enum):
     DOMAIN_OFFSET = 2.0
     low_tensor_domain_incl = 0.0
     high_tensor_domain_excl = low_tensor_domain_incl + DOMAIN_OFFSET
+    torch_dtype = None

     def __init__(self, *args, **kwargs):
         """__init__ Method for torch_tensorrt.Input
@@ -138,6 +138,9 @@ def __init__(self, *args, **kwargs):
             )

         if "dtype" in kwargs:
+            if isinstance(kwargs["dtype"], torch.dtype):
+                self.torch_dtype = kwargs["dtype"]
+
             self.dtype = Input._parse_dtype(kwargs["dtype"])
             self._explicit_set_dtype = True

@@ -173,59 +176,6 @@ def __str__(self) -> str:
         else:
             raise RuntimeError("Unknown input shape mode")

-    def _to_internal(self) -> _C.Input:
-        internal_in = _C.Input()
-        if self.shape_mode == Input._ShapeMode.DYNAMIC:
-            if not Input._supported_input_size_type(self.shape["min_shape"]):
-                raise TypeError(
-                    "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
-                    + str(type(self.shape["min_shape"]))
-                    + " for min_shape"
-                )
-            else:
-                internal_in.min = self.shape["min_shape"]
-
-            if not Input._supported_input_size_type(self.shape["opt_shape"]):
-                raise TypeError(
-                    "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
-                    + str(type(self.shape["opt_shape"]))
-                    + " for opt_shape"
-
) - else: - internal_in.opt = self.shape["opt_shape"] - - if not Input._supported_input_size_type(self.shape["max_shape"]): - raise TypeError( - "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " - + str(type(self.shape["max_shape"])) - + " for max_shape" - ) - else: - internal_in.max = self.shape["max_shape"] - internal_in.input_is_dynamic = True - else: - if not Input._supported_input_size_type(self.shape): - raise TypeError( - "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " - + str(type(self.shape)) - + " for shape" - ) - else: - internal_in.opt = self.shape - internal_in.input_is_dynamic = False - - if self.dtype != _enums.dtype.unknown: - self._explicit_set_dtype = True - else: - self._explicit_set_dtype = False - - internal_in.dtype = Input._parse_dtype(self.dtype) - internal_in._explicit_set_dtype = self._explicit_set_dtype - internal_in.format = Input._parse_format(self.format) - - internal_in.tensor_domain = Input._parse_tensor_domain(self.tensor_domain) - return internal_in - @staticmethod def _supported_input_size_type(input_size: Any) -> bool: if isinstance(input_size, torch.Size): @@ -304,6 +254,7 @@ def _parse_tensor_domain(domain: Optional[Tuple[float, float]]) -> Tuple: Input.low_tensor_domain_incl, Input.high_tensor_domain_excl, ) + elif len(domain) == 2: domain_lo, domain_hi = domain @@ -416,8 +367,10 @@ def example_tensor(self, optimization_profile_field: str = None) -> torch.Tensor ) if self.shape_mode == Input._ShapeMode.STATIC: - return torch.randn(self.shape).to(dtype=self.dtype) + return torch.randn(self.shape).to( + dtype=self.dtype if not self.torch_dtype else self.torch_dtype + ) else: return torch.randn(self.shape[optimization_profile_field]).to( - dtype=self.dtype + dtype=self.dtype if not self.torch_dtype else self.torch_dtype ) diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index cbd1b87c5c..e75f28ff22 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -15,6 +15,7 @@ class _IRType(Enum): ts = 0 fx = 1 + dynamo = 2 class _ModuleType(Enum): @@ -45,11 +46,14 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType: ir_targets_torchscript = any([ir == opt for opt in ["torchscript", "ts"]]) ir_targets_fx = ir == "fx" + ir_targets_dynamo = ir == "dynamo" if module_is_tsable and ir_targets_torchscript: return _IRType.ts elif module_is_fxable and ir_targets_fx: return _IRType.fx + elif module_is_fxable and ir_targets_dynamo: + return _IRType.dynamo else: if ir == "default": # Options are listed in order of preference @@ -148,6 +152,26 @@ def compile( dynamic_batch=False, **kwargs, ) + elif target_ir == _IRType.dynamo: + if ( + torch.float16 in enabled_precisions + or torch_tensorrt.dtype.half in enabled_precisions + ): + lower_precision = LowerPrecision.FP16 + elif ( + torch.float32 in enabled_precisions + or torch_tensorrt.dtype.float in enabled_precisions + ): + lower_precision = LowerPrecision.FP32 + else: + raise ValueError(f"Precision {enabled_precisions} not supported on FX") + + return torch_tensorrt.dynamo.compile( + module, + inputs, + lower_precision=lower_precision, + **kwargs, + ) else: raise RuntimeError("Module is an unknown format or the ir requested is unknown") diff --git a/py/torch_tensorrt/ts/_compile_spec.py b/py/torch_tensorrt/ts/_compile_spec.py index 8f06e2ef71..76d3605a1d 100644 --- a/py/torch_tensorrt/ts/_compile_spec.py +++ b/py/torch_tensorrt/ts/_compile_spec.py @@ -9,6 
+9,7 @@ from typing import Tuple, List, Dict import warnings from copy import deepcopy +from torch_tensorrt.ts.ts_input import TSInput def _internal_input_to_torch_class_input(i: _C.Input) -> torch.classes.tensorrt._Input: @@ -38,46 +39,6 @@ def _supported_input_size_type(input_size: Any) -> bool: ) -def _parse_input_ranges(input_sizes: List) -> List: - - if any( - not isinstance(i, dict) and not _supported_input_size_type(i) - for i in input_sizes - ): - raise KeyError( - "An input size must either be a static size or a range of three sizes (min, opt, max) as Dict" - ) - - parsed_input_sizes = [] - for i in input_sizes: - if isinstance(i, dict): - if all(k in i for k in ["min", "opt", "min"]): - parsed_input_sizes.append( - Input( - min_shape=i["min"], opt_shape=i["opt"], max_shape=i["max"] - )._to_internal() - ) - - elif "opt" in i: - parsed_input_sizes.append(Input(shape=i["opt"])._to_internal()) - - else: - raise KeyError( - "An input size must either be a static size or a range of three sizes (min, opt, max) as Dict" - ) - - elif isinstance(i, list): - parsed_input_sizes.append(Input(shape=i)._to_internal()) - - elif isinstance(i, tuple): - parsed_input_sizes.append(Input(shape=i)._to_internal()) - - elif isinstance(i, torch.Size): - parsed_input_sizes.append(Input(shape=i)._to_internal()) - - return parsed_input_sizes - - def _parse_op_precision(precision: Any) -> _enums.dtype: if isinstance(precision, torch.dtype): if precision == torch.int8: @@ -228,7 +189,23 @@ def _parse_input_signature(input_signature: Any, depth: int = 0): + "non-TRT types." ) - clone = _internal_input_to_torch_class_input(i._to_internal()) + ts_i = i + if i.shape_mode == Input._ShapeMode.STATIC: + ts_i = TSInput(shape=i.shape, dtype=i.dtype, format=i.format) + elif i.shape_mode == Input._ShapeMode.DYNAMIC: + ts_i = TSInput( + min_shape=i.shape["min_shape"], + opt_shape=i.shape["opt_shape"], + max_shape=i.shape["max_shape"], + dtype=i.dtype, + format=i.format, + ) + else: + raise ValueError( + "Invalid shape mode detected for input while parsing the input_signature" + ) + + clone = _internal_input_to_torch_class_input(ts_i._to_internal()) return clone else: raise KeyError( @@ -260,7 +237,25 @@ def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec: Input.from_tensor(i) if isinstance(i, torch.Tensor) else i for i in compile_spec["inputs"] ] - info.inputs = [i._to_internal() for i in inputs] + ts_inputs = [] + for i in inputs: + if i.shape_mode == Input._ShapeMode.STATIC: + ts_inputs.append( + TSInput( + shape=i.shape, dtype=i.dtype, format=i.format + )._to_internal() + ) + elif i.shape_mode == Input._ShapeMode.DYNAMIC: + ts_inputs.append( + TSInput( + min_shape=i.shape["min_shape"], + opt_shape=i.shape["opt_shape"], + max_shape=i.shape["max_shape"], + dtype=i.dtype, + format=i.format, + )._to_internal() + ) + info.inputs = ts_inputs elif compile_spec["input_signature"] is not None: log( @@ -268,7 +263,42 @@ def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec: "Input signature parsing is an experimental feature, behavior and APIs may change", ) signature = _parse_input_signature(compile_spec["input_signature"]) - info.input_signature = _C.InputSignature(signature) + info.input_signature = _C.InputSignature(signature) # py_object + + if not compile_spec["torch_fallback"]["enabled"]: + raise ValueError( + "Grouped inputs currently requires partial compilation to be enabled, this restriction will be relaxed in a future release" + ) + + log( + Level.Debug, + "Grouped 
inputs currently requires additional settings to enable the feature", + ) + log( + Level.Debug, + """Adding the following ops to torch_executed_ops: + - aten::__getitem__ + - prim::ListConstruct + - prim::ListUnpack + - prim::TupleIndex + - prim::TupleConstruct + - prim::TupleUnpack +""", + ) + compile_spec["torch_fallback"]["forced_fallback_ops"].append( + "aten::__getitem__" + ) + compile_spec["torch_fallback"]["forced_fallback_ops"].append( + "prim::ListConstruct" + ) + compile_spec["torch_fallback"]["forced_fallback_ops"].append("prim::ListUnpack") + compile_spec["torch_fallback"]["forced_fallback_ops"].append("prim::TupleIndex") + compile_spec["torch_fallback"]["forced_fallback_ops"].append( + "prim::TupleConstruct" + ) + compile_spec["torch_fallback"]["forced_fallback_ops"].append( + "prim::TupleUnpack" + ) else: raise KeyError( diff --git a/tests/py/api/test_classes.py b/tests/py/api/test_classes.py index 861efd84a7..b9729b9d4d 100644 --- a/tests/py/api/test_classes.py +++ b/tests/py/api/test_classes.py @@ -103,7 +103,8 @@ def test_infer_from_example_tensor(self): example_tensor = torch.randn(shape).half() i = torchtrt.Input.from_tensor(example_tensor) - self.assertTrue(self._verify_correctness(i, target)) + ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format) + self.assertTrue(self._verify_correctness(ts_i, target)) def test_static_shape(self): shape = [1, 3, 255, 255] @@ -118,22 +119,28 @@ def test_static_shape(self): } i = torchtrt.Input(shape) - self.assertTrue(self._verify_correctness(i, target)) + ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format) + self.assertTrue(self._verify_correctness(ts_i, target)) i = torchtrt.Input(tuple(shape)) - self.assertTrue(self._verify_correctness(i, target)) + ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format) + self.assertTrue(self._verify_correctness(ts_i, target)) i = torchtrt.Input(torch.randn(shape).shape) - self.assertTrue(self._verify_correctness(i, target)) + ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format) + self.assertTrue(self._verify_correctness(ts_i, target)) i = torchtrt.Input(shape=shape) - self.assertTrue(self._verify_correctness(i, target)) + ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format) + self.assertTrue(self._verify_correctness(ts_i, target)) i = torchtrt.Input(shape=tuple(shape)) - self.assertTrue(self._verify_correctness(i, target)) + ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format) + self.assertTrue(self._verify_correctness(ts_i, target)) i = torchtrt.Input(shape=torch.randn(shape).shape) - self.assertTrue(self._verify_correctness(i, target)) + ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format) + self.assertTrue(self._verify_correctness(ts_i, target)) def test_data_type(self): shape = [1, 3, 255, 255] @@ -148,10 +155,12 @@ def test_data_type(self): } i = torchtrt.Input(shape, dtype=torchtrt.dtype.half) - self.assertTrue(self._verify_correctness(i, target)) + ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format) + self.assertTrue(self._verify_correctness(ts_i, target)) i = torchtrt.Input(shape, dtype=torch.half) - self.assertTrue(self._verify_correctness(i, target)) + ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format) + self.assertTrue(self._verify_correctness(ts_i, target)) def test_tensor_format(self): shape = [1, 3, 255, 255] @@ -166,10 +175,12 @@ def test_tensor_format(self): } i = torchtrt.Input(shape, 
format=torchtrt.TensorFormat.channels_last) - self.assertTrue(self._verify_correctness(i, target)) + ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format) + self.assertTrue(self._verify_correctness(ts_i, target)) i = torchtrt.Input(shape, format=torch.channels_last) - self.assertTrue(self._verify_correctness(i, target)) + ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format) + self.assertTrue(self._verify_correctness(ts_i, target)) def test_dynamic_shape(self): min_shape = [1, 3, 128, 128] @@ -188,14 +199,28 @@ def test_dynamic_shape(self): i = torchtrt.Input( min_shape=min_shape, opt_shape=opt_shape, max_shape=max_shape ) - self.assertTrue(self._verify_correctness(i, target)) + ts_i = torchtrt.ts.TSInput( + min_shape=i.shape["min_shape"], + opt_shape=i.shape["opt_shape"], + max_shape=i.shape["max_shape"], + dtype=i.dtype, + format=i.format, + ) + self.assertTrue(self._verify_correctness(ts_i, target)) i = torchtrt.Input( min_shape=tuple(min_shape), opt_shape=tuple(opt_shape), max_shape=tuple(max_shape), ) - self.assertTrue(self._verify_correctness(i, target)) + ts_i = torchtrt.ts.TSInput( + min_shape=i.shape["min_shape"], + opt_shape=i.shape["opt_shape"], + max_shape=i.shape["max_shape"], + dtype=i.dtype, + format=i.format, + ) + self.assertTrue(self._verify_correctness(ts_i, target)) tensor_shape = lambda shape: torch.randn(shape).shape i = torchtrt.Input( @@ -203,7 +228,14 @@ def test_dynamic_shape(self): opt_shape=tensor_shape(opt_shape), max_shape=tensor_shape(max_shape), ) - self.assertTrue(self._verify_correctness(i, target)) + ts_i = torchtrt.ts.TSInput( + min_shape=i.shape["min_shape"], + opt_shape=i.shape["opt_shape"], + max_shape=i.shape["max_shape"], + dtype=i.dtype, + format=i.format, + ) + self.assertTrue(self._verify_correctness(ts_i, target)) class TestTRTModuleNext(unittest.TestCase): From b6338a0d3aeb982aa23a8efa14663f3ab3cbee61 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 4 Apr 2023 18:14:40 -0700 Subject: [PATCH 08/45] chore: Add dynamo backend tests to CircleCI Signed-off-by: Dheeraj Peri --- .circleci/config.yml | 294 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 294 insertions(+) diff --git a/.circleci/config.yml b/.circleci/config.yml index 91e6a71f7e..a6731200cf 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -508,6 +508,7 @@ commands: - store_artifacts: path: /tmp/testlogs +# =================== FX tests start ======================== # test-fx_core: description: "Test the fx core" steps: @@ -707,6 +708,210 @@ commands: - store_artifacts: path: /tmp/testlogs +# =================== FX tests end ======================== # + +# =================== Dynamo tests start ======================== # + test-dynamo_core: + description: "Test the Dynamo core" + steps: + - run: + name: Run Dynamo core tests + command: | + cd py/torch_tensorrt/dynamo/test + pushd core/ + pytest --junitxml=/tmp/artifacts/test_results/dynamo/core/test_results.xml + popd + + - store_test_results: + path: /tmp/artifacts + - store_artifacts: + path: /tmp/testlogs + + test-dynamo_converters_acc: + description: "Test the Dynamo acc converters" + steps: + - run: + name: Run FX converter tests + command: | + cd py/torch_tensorrt/dynamo/test + pushd converters/acc_op/ + pytest --junitxml=/tmp/artifacts/test_results/dynamo/converters/acc_op/test_results.xml + popd + + - store_test_results: + path: /tmp/artifacts + - store_artifacts: + path: /tmp/testlogs + + test-dynamo_converters_aten: + description: "Test the dynamo aten 
converters" + steps: + - run: + name: Run dynamo converter tests + command: | + cd py/torch_tensorrt/dynamo/test + pushd converters/aten_op/ + pytest --junitxml=/tmp/artifacts/test_results/dynamo/converters/aten_op/test_results.xml + popd + + - store_test_results: + path: /tmp/artifacts + - store_artifacts: + path: /tmp/testlogs + + test-dynamo_converters_vanilla: + description: "Test the dynamo vanilla converters" + steps: + - run: + name: Run dynamo converter tests + command: | + cd py/torch_tensorrt/dynamo/test + pushd converters/vanilla/ + pytest --junitxml=/tmp/artifacts/test_results/dynamo/converters/vanilla/test_results.xml + popd + + - store_test_results: + path: /tmp/artifacts + - store_artifacts: + path: /tmp/testlogs + + test-dynamo_passes: + description: "Test the dynamo passes" + steps: + - run: + name: Run dynamo passes + command: | + cd py/torch_tensorrt/dynamo/test + pushd passes + list_passes=$(ls | grep -v test_setitem*) + pytest $list_passes --junitxml=/tmp/artifacts/test_results/dynamo/passes/test_results.xml + popd + - store_test_results: + path: /tmp/artifacts + - store_artifacts: + path: /tmp/testlogs + + test-dynamo_tools: + description: "Test the dynamo tools" + steps: + - run: + name: Run dynamo tools + command: | + cd py/torch_tensorrt/dynamo/test + pushd tools + pytest --junitxml=/tmp/artifacts/test_results/dynamo/tools/test_results.xml + popd + - store_test_results: + path: /tmp/artifacts + - store_artifacts: + path: /tmp/testlogs + + test-dynamo_trt_lower: + description: "Test the dynamo TRT lowering" + steps: + - run: + name: Run dynamo TRT lowering + command: | + cd py/torch_tensorrt/dynamo/test + pushd trt_lower + pytest --junitxml=/tmp/artifacts/test_results/dynamo/trt_lower/test_results.xml + popd + - store_test_results: + path: /tmp/artifacts + - store_artifacts: + path: /tmp/testlogs + + test-dynamo_tracer: + description: "Test all dynamo tracers" + steps: + - run: + name: Run dynamo tracer + command: | + cd py/torch_tensorrt/dynamo/test + pushd tracer + list_tracer=$(ls | grep -v test_dispatch_*) + pytest $list_tracer --junitxml=/tmp/artifacts/test_results/fx/tracer/test_results.xml + popd + - store_test_results: + path: /tmp/artifacts + - store_artifacts: + path: /tmp/testlogs + + test-dynamo_tracer_acc: + description: "Test the dynamo acc tracer only" + steps: + - run: + name: Run dynamo tracer + command: | + cd py/torch_tensorrt/dynamo/test + pushd tracer + list_tracer=$(ls | grep test_acc) + pytest $list_tracer --junitxml=/tmp/artifacts/test_results/dynamo/tracer/test_results.xml + popd + - store_test_results: + path: /tmp/artifacts + - store_artifacts: + path: /tmp/testlogs + + test-dynamo_quant: + description: "Test the dynamo quant" + steps: + - run: + name: Run dynamo quant tests + command: | + cd py/torch_tensorrt/dynamo/test + pushd quant/ + pytest --junitxml=/tmp/artifacts/test_results/dynamo/quant/test_results.xml + popd + + - store_test_results: + path: /tmp/artifacts + - store_artifacts: + path: /tmp/testlogs + + test-dynamo: + description: "Test the dynamo backend" + steps: + - run: + name: Run dynamo tests + command: | + mkdir -p /tmp/artifacts/test_results + - test-dynamo_converters_acc + - test-dynamo_converters_aten + - test-dynamo_converters_vanilla + - test-dynamo_passes + - test-dynamo_tools + - test-dynamo_trt_lower + - test-dynamo_tracer + - test-dynamo_core + - test-dynamo_quant + - store_test_results: + path: /tmp/artifacts + - store_artifacts: + path: /tmp/testlogs + + test-dynamo-no-aten: + description: "Test the dynamo 
backend without aten operators" + steps: + - run: + name: Run dynamo tests without aten ops + command: | + mkdir -p /tmp/artifacts/test_results + - test-dynamo_converters_acc + - test-dynamo_converters_vanilla + - test-dynamo_passes + - test-dynamo_tools + - test-dynamo_trt_lower + - test-dynamo_tracer_acc + - test-dynamo_core + - test-dynamo_quant + - store_test_results: + path: /tmp/artifacts + - store_artifacts: + path: /tmp/testlogs + +# =================== Dynamo tests end ======================== # + # Define a job to be invoked later in a workflow. # See: https://circleci.com/docs/2.0/configuration-reference/#jobs jobs: @@ -883,6 +1088,68 @@ jobs: - dump-test-env - test-fx-no-aten + test-py-dynamo-x86_64-linux: + parameters: + torch-build: + type: string + torch-build-index: + type: string + trt-version-long: + type: string + machine: + image: ubuntu-2004-cuda-11.4:202110-01 + resource_class: gpu.nvidia.large + steps: + - checkout + - attach_workspace: + at: /tmp/dist/ + - install-torch-from-index: + torch-build: << parameters.torch-build >> + torch-build-index: << parameters.torch-build-index >> + - create-py-env: + trt-version-long: << parameters.trt-version-long >> + - install-cudnn + # - run: + # name: "Set LD_LIBRARY_PATH path to include the installed CUDNN" + # command: export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu/:$LD_LIBRARY_PATH + - run: + name: "Install torch-tensorrt" + command: pip3 install --pre /tmp/dist/x86_64-linux/*cp39-cp39*.whl + # We install torch after torch-trt because pip automatically enforces the version constraint otherwise + - dump-test-env + - test-dynamo + + test-py-dynamo-x86_64-linux-no-aten: + parameters: + torch-build: + type: string + torch-build-index: + type: string + trt-version-long: + type: string + machine: + image: ubuntu-2004-cuda-11.4:202110-01 + resource_class: gpu.nvidia.large + steps: + - checkout + - attach_workspace: + at: /tmp/dist/ + - install-torch-from-index: + torch-build: << parameters.torch-build >> + torch-build-index: << parameters.torch-build-index >> + - create-py-env: + trt-version-long: << parameters.trt-version-long >> + - install-cudnn + # - run: + # name: "Set LD_LIBRARY_PATH path to include the installed CUDNN" + # command: export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu/:$LD_LIBRARY_PATH + - run: + name: "Install torch-tensorrt" + command: pip3 install --pre /tmp/dist/x86_64-linux/*cp39-cp39*.whl + # We install torch after torch-trt because pip automatically enforces the version constraint otherwise + - dump-test-env + - test-dynamo-no-aten + package-x86_64-linux: parameters: enabled: @@ -1261,6 +1528,13 @@ workflows: requires: - build-x86_64-linux + - test-py-dynamo-x86_64-linux: + torch-build: << pipeline.parameters.torch-build >> + torch-build-index: << pipeline.parameters.torch-build-index >> + trt-version-long: << pipeline.parameters.trt-version-long >> + requires: + - build-x86_64-linux + - build-x86_64-linux: name: build-x86_64-linux-legacy torch-build: << pipeline.parameters.torch-build-legacy >> @@ -1291,6 +1565,12 @@ workflows: requires: - build-x86_64-linux-legacy + - test-py-dynamo-x86_64-linux-no-aten: + torch-build: << pipeline.parameters.torch-build-legacy >> + torch-build-index: << pipeline.parameters.torch-build-index-legacy >> + trt-version-long: << pipeline.parameters.trt-version-long >> + requires: + - build-x86_64-linux-legacy release: when: << pipeline.parameters.enable-packaging >> jobs: @@ -1328,6 +1608,13 @@ workflows: requires: - package-x86_64-linux + - test-py-dynamo-x86_64-linux: + 
torch-build: << pipeline.parameters.torch-build >> + torch-build-index: << pipeline.parameters.torch-build-index >> + trt-version-long: << pipeline.parameters.trt-version-long >> + requires: + - package-x86_64-linux + on-push: jobs: - build-x86_64-linux: @@ -1357,6 +1644,13 @@ workflows: requires: - build-x86_64-linux + - test-py-dynamo-x86_64-linux: + torch-build: << pipeline.parameters.torch-build >> + torch-build-index: << pipeline.parameters.torch-build-index >> + trt-version-long: << pipeline.parameters.trt-version-long >> + requires: + - build-x86_64-linux + - build-x86_64-linux-cmake: torch-build: << pipeline.parameters.torch-build >> torch-build-index: << pipeline.parameters.torch-build-index >> From c4a03ff34717be25395cf9b69c4ed1f7d41a1b3e Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 4 Apr 2023 18:18:36 -0700 Subject: [PATCH 09/45] chore: Linter fixes Signed-off-by: Dheeraj Peri --- py/torch_tensorrt/dynamo/input_tensor_spec.py | 1 + .../dynamo/passes/lower_pass_manager_builder.py | 8 ++++++-- py/torch_tensorrt/dynamo/test/core/test_input.py | 1 + .../passes/test_fix_clamp_numerical_limits_to_fp16.py | 4 +++- .../dynamo/test/trt_lower/test_observer_gpu.py | 4 +--- py/torch_tensorrt/dynamo/tools/trt_minimizer.py | 4 +--- 6 files changed, 13 insertions(+), 9 deletions(-) diff --git a/py/torch_tensorrt/dynamo/input_tensor_spec.py b/py/torch_tensorrt/dynamo/input_tensor_spec.py index 1e64c31c59..3eb9a115af 100644 --- a/py/torch_tensorrt/dynamo/input_tensor_spec.py +++ b/py/torch_tensorrt/dynamo/input_tensor_spec.py @@ -6,6 +6,7 @@ from .utils import get_dynamic_dims from torch_tensorrt._Input import Input + class InputTensorSpec(NamedTuple): """ This class contains the information of a input tensor. diff --git a/py/torch_tensorrt/dynamo/passes/lower_pass_manager_builder.py b/py/torch_tensorrt/dynamo/passes/lower_pass_manager_builder.py index d79c0e77b4..c3c69e0117 100644 --- a/py/torch_tensorrt/dynamo/passes/lower_pass_manager_builder.py +++ b/py/torch_tensorrt/dynamo/passes/lower_pass_manager_builder.py @@ -268,7 +268,9 @@ def build_trt_lower_pipeline( elif isinstance(input_obj, torch.Tensor): self._trt_input.append(InputTensorSpec.from_tensor(input_obj)) else: - raise ValueError("Invalid input type provided in the FX lowering. Expected type: torch_tensorrt.Input or torch.Tensor") + raise ValueError( + "Invalid input type provided in the FX lowering. Expected type: torch_tensorrt.Input or torch.Tensor" + ) self._additional_input = additional_input passes = [] @@ -294,7 +296,9 @@ def build_aten2trt_lower_pipeline( elif isinstance(input_obj, torch.Tensor): self._trt_input.append(InputTensorSpec.from_tensor(input_obj)) else: - raise ValueError("Invalid input type provided in the FX lowering. Expected type: torch_tensorrt.Input or torch.Tensor") + raise ValueError( + "Invalid input type provided in the FX lowering. 
Expected type: torch_tensorrt.Input or torch.Tensor" + ) self._additional_input = additional_input passes = [] diff --git a/py/torch_tensorrt/dynamo/test/core/test_input.py b/py/torch_tensorrt/dynamo/test/core/test_input.py index efe323c691..869482dbef 100644 --- a/py/torch_tensorrt/dynamo/test/core/test_input.py +++ b/py/torch_tensorrt/dynamo/test/core/test_input.py @@ -7,6 +7,7 @@ import torch_tensorrt from torch.testing._internal.common_utils import run_tests, TestCase + class TestInput(TestCase): def test_add_model(self): class TestModule(torch.nn.Module): diff --git a/py/torch_tensorrt/dynamo/test/passes/test_fix_clamp_numerical_limits_to_fp16.py b/py/torch_tensorrt/dynamo/test/passes/test_fix_clamp_numerical_limits_to_fp16.py index 91d21b7fd0..0a4805d45f 100644 --- a/py/torch_tensorrt/dynamo/test/passes/test_fix_clamp_numerical_limits_to_fp16.py +++ b/py/torch_tensorrt/dynamo/test/passes/test_fix_clamp_numerical_limits_to_fp16.py @@ -3,7 +3,9 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer -from torch_tensorrt.dynamo.passes.lower_basic_pass import fix_clamp_numerical_limits_to_fp16 +from torch_tensorrt.dynamo.passes.lower_basic_pass import ( + fix_clamp_numerical_limits_to_fp16, +) _LOGGER: logging.Logger = logging.getLogger(__name__) diff --git a/py/torch_tensorrt/dynamo/test/trt_lower/test_observer_gpu.py b/py/torch_tensorrt/dynamo/test/trt_lower/test_observer_gpu.py index 1898f10980..352ecd062a 100644 --- a/py/torch_tensorrt/dynamo/test/trt_lower/test_observer_gpu.py +++ b/py/torch_tensorrt/dynamo/test/trt_lower/test_observer_gpu.py @@ -33,9 +33,7 @@ def forward(self, x, y): with execution_verifier() as verify_execution: - lowerer = lower.Lowerer.create( - lower_setting=LowerSetting(min_block_size=0) - ) + lowerer = lower.Lowerer.create(lower_setting=LowerSetting(min_block_size=0)) @verify_execution def observe_fuse_permute_linear_post(ctx: ob.ObserveContext): diff --git a/py/torch_tensorrt/dynamo/tools/trt_minimizer.py b/py/torch_tensorrt/dynamo/tools/trt_minimizer.py index 78b2f252bb..f4886fab22 100644 --- a/py/torch_tensorrt/dynamo/tools/trt_minimizer.py +++ b/py/torch_tensorrt/dynamo/tools/trt_minimizer.py @@ -83,9 +83,7 @@ def run_a(self, mod, inputs): def run_b(self, mod, inputs): mod.eval() try: - mod = self.lower_fn( - mod, inputs, self.use_experiemental_rt - ) + mod = self.lower_fn(mod, inputs, self.use_experiemental_rt) output = mod(*inputs) except RuntimeError as e: raise net_min_base.FxNetMinimizerRunFuncError( From f8ad31ac0bf3269d40e59adb688947b8a3b0f553 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 4 Apr 2023 18:55:04 -0700 Subject: [PATCH 10/45] chore: add missing ts_input.py file Signed-off-by: Dheeraj Peri --- py/torch_tensorrt/ts/ts_input.py | 108 +++++++++++++++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100644 py/torch_tensorrt/ts/ts_input.py diff --git a/py/torch_tensorrt/ts/ts_input.py b/py/torch_tensorrt/ts/ts_input.py new file mode 100644 index 0000000000..00055d4f13 --- /dev/null +++ b/py/torch_tensorrt/ts/ts_input.py @@ -0,0 +1,108 @@ +from enum import Enum +from typing import List, Dict, Any, Tuple, Optional + +import torch + +from torch_tensorrt import _C +from torch_tensorrt import _enums +from torch_tensorrt import _Input +from torch_tensorrt._Input import Input + + +class TSInput(Input): + """ + Defines an input to a module in terms of expected shape, data type and tensor format. 
+
+    Attributes:
+        shape_mode (torch_tensorrt.Input._ShapeMode): Is input statically or dynamically shaped
+        shape (Tuple or Dict): Either a single Tuple or a dict of tuples defining the input shape.
+            Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form
+            ``{
+                "min_shape": Tuple,
+                "opt_shape": Tuple,
+                "max_shape": Tuple
+            }``
+        dtype (torch_tensorrt.dtype): The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
+        format (torch_tensorrt.TensorFormat): The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)
+    """
+
+    def __init__(self, *args, **kwargs):
+        """__init__ Method for torch_tensorrt.Input
+
+        Input accepts one of a few construction patterns
+
+        Args:
+            shape (Tuple or List, optional): Static shape of input tensor
+
+        Keyword Arguments:
+            shape (Tuple or List, optional): Static shape of input tensor
+            min_shape (Tuple or List, optional): Min size of input tensor's shape range
+                Note: All three of min_shape, opt_shape, max_shape must be provided, there must be no positional arguments, shape must not be defined and implicitly this sets Input's shape_mode to DYNAMIC
+            opt_shape (Tuple or List, optional): Opt size of input tensor's shape range
+                Note: All three of min_shape, opt_shape, max_shape must be provided, there must be no positional arguments, shape must not be defined and implicitly this sets Input's shape_mode to DYNAMIC
+            max_shape (Tuple or List, optional): Max size of input tensor's shape range
+                Note: All three of min_shape, opt_shape, max_shape must be provided, there must be no positional arguments, shape must not be defined and implicitly this sets Input's shape_mode to DYNAMIC
+            dtype (torch.dtype or torch_tensorrt.dtype): Expected data type for input tensor (default: torch_tensorrt.dtype.float32)
+            format (torch.memory_format or torch_tensorrt.TensorFormat): The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)
+            tensor_domain (Tuple(float, float), optional): The domain of allowed values for the tensor, as interval notation: [tensor_domain[0], tensor_domain[1]).
+                Note: Entering "None" (or not specifying) will set the bound to [0, 2)
+
+        Examples:
+            - Input([1,3,32,32], dtype=torch.float32, format=torch.channels_last)
+            - Input(shape=(1,3,32,32), dtype=torch_tensorrt.dtype.int32, format=torch_tensorrt.TensorFormat.NCHW)
+            - Input(min_shape=(1,3,32,32), opt_shape=[2,3,32,32], max_shape=(3,3,32,32)) #Implicitly dtype=torch_tensorrt.dtype.float32, format=torch_tensorrt.TensorFormat.NCHW
+        """
+        super(TSInput, self).__init__(*args, **kwargs)
+
+    def _to_internal(self) -> _C.Input:
+        internal_in = _C.Input()
+        if self.shape_mode == Input._ShapeMode.DYNAMIC:
+            if not Input._supported_input_size_type(self.shape["min_shape"]):
+                raise TypeError(
+                    "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
+                    + str(type(self.shape["min_shape"]))
+                    + " for min_shape"
+                )
+            else:
+                internal_in.min = self.shape["min_shape"]
+
+            if not Input._supported_input_size_type(self.shape["opt_shape"]):
+                raise TypeError(
+                    "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
+                    + str(type(self.shape["opt_shape"]))
+                    + " for opt_shape"
+                )
+            else:
+                internal_in.opt = self.shape["opt_shape"]
+
+            if not Input._supported_input_size_type(self.shape["max_shape"]):
+                raise TypeError(
+                    "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
+                    + str(type(self.shape["max_shape"]))
+                    + " for max_shape"
+                )
+            else:
+                internal_in.max = self.shape["max_shape"]
+            internal_in.input_is_dynamic = True
+        else:
+            if not Input._supported_input_size_type(self.shape):
+                raise TypeError(
+                    "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
+                    + str(type(self.shape))
+                    + " for shape"
+                )
+            else:
+                internal_in.opt = self.shape
+            internal_in.input_is_dynamic = False
+
+        if self.dtype != _enums.dtype.unknown:
+            self._explicit_set_dtype = True
+        else:
+            self._explicit_set_dtype = False
+
+        internal_in.dtype = Input._parse_dtype(self.dtype)
+        internal_in._explicit_set_dtype = self._explicit_set_dtype
+        internal_in.format = Input._parse_format(self.format)
+
+        internal_in.tensor_domain = Input._parse_tensor_domain(self.tensor_domain)
+        return internal_in

From 3cbf72c1ab57af0cbf2f238aa0073475166c4b4d Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Wed, 5 Apr 2023 19:26:44 -0700
Subject: [PATCH 11/45] chore: refactoring

Signed-off-by: Dheeraj Peri

---
 py/setup.py                                   | 57 ++++++++++---------
 py/torch_tensorrt/__init__.py                 |  2 +-
 py/torch_tensorrt/_compile.py                 | 24 ++------
 py/torch_tensorrt/dynamo/lower.py             | 11 +++-
 .../test/converters/acc_op/test_silu.py       |  2 +-
 py/torch_tensorrt/ts/__init__.py              |  1 +
 6 files changed, 47 insertions(+), 50 deletions(-)

diff --git a/py/setup.py b/py/setup.py
index 3c73aee3b5..de4ae92529 100644
--- a/py/setup.py
+++ b/py/setup.py
@@ -375,6 +375,36 @@ def run(self):
 with open("README.md", "r", encoding="utf-8") as fh:
     long_description = fh.read()

+if FX_ONLY:
+    package_data_list = ["_Input.py",]
+else:
+    package_data_list = [
+        "lib/*",
+        "include/torch_tensorrt/*.h",
+        "include/torch_tensorrt/core/*.h",
+        "include/torch_tensorrt/core/conversion/*.h",
+        "include/torch_tensorrt/core/conversion/conversionctx/*.h",
+        "include/torch_tensorrt/core/conversion/converters/*.h",
+        "include/torch_tensorrt/core/conversion/evaluators/*.h",
+        "include/torch_tensorrt/core/conversion/tensorcontainer/*.h",
+        "include/torch_tensorrt/core/conversion/var/*.h",
+        "include/torch_tensorrt/core/ir/*.h",
"include/torch_tensorrt/core/lowering/*.h", + "include/torch_tensorrt/core/lowering/passes/*.h", + "include/torch_tensorrt/core/partitioning/*.h", + "include/torch_tensorrt/core/partitioning/segmentedblock/*.h", + "include/torch_tensorrt/core/partitioning/partitioninginfo/*.h", + "include/torch_tensorrt/core/partitioning/partitioningctx/*.h", + "include/torch_tensorrt/core/plugins/*.h", + "include/torch_tensorrt/core/plugins/impl/*.h", + "include/torch_tensorrt/core/runtime/*.h", + "include/torch_tensorrt/core/util/*.h", + "include/torch_tensorrt/core/util/logging/*.h", + "bin/*", + "BUILD", + "WORKSPACE", + ] + setup( name="torch_tensorrt", version=__version__, @@ -418,32 +448,7 @@ def run(self): python_requires=">=3.7", include_package_data=True, package_data={ - "torch_tensorrt": [ - "lib/*", - "include/torch_tensorrt/*.h", - "include/torch_tensorrt/core/*.h", - "include/torch_tensorrt/core/conversion/*.h", - "include/torch_tensorrt/core/conversion/conversionctx/*.h", - "include/torch_tensorrt/core/conversion/converters/*.h", - "include/torch_tensorrt/core/conversion/evaluators/*.h", - "include/torch_tensorrt/core/conversion/tensorcontainer/*.h", - "include/torch_tensorrt/core/conversion/var/*.h", - "include/torch_tensorrt/core/ir/*.h", - "include/torch_tensorrt/core/lowering/*.h", - "include/torch_tensorrt/core/lowering/passes/*.h", - "include/torch_tensorrt/core/partitioning/*.h", - "include/torch_tensorrt/core/partitioning/segmentedblock/*.h", - "include/torch_tensorrt/core/partitioning/partitioninginfo/*.h", - "include/torch_tensorrt/core/partitioning/partitioningctx/*.h", - "include/torch_tensorrt/core/plugins/*.h", - "include/torch_tensorrt/core/plugins/impl/*.h", - "include/torch_tensorrt/core/runtime/*.h", - "include/torch_tensorrt/core/util/*.h", - "include/torch_tensorrt/core/util/logging/*.h", - "bin/*", - "BUILD", - "WORKSPACE", - ], + "torch_tensorrt": package_data_list, }, exclude_package_data={ "": ["*.cpp"], diff --git a/py/torch_tensorrt/__init__.py b/py/torch_tensorrt/__init__.py index 3261265215..360d6f2dbe 100644 --- a/py/torch_tensorrt/__init__.py +++ b/py/torch_tensorrt/__init__.py @@ -93,7 +93,7 @@ def _find_lib(name, paths): from torch_tensorrt._TRTModuleNext import TRTModuleNext from torch_tensorrt import fx - +from torch_tensorrt import dynamo def _register_with_torch(): trtorch_dir = os.path.dirname(__file__) diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index e75f28ff22..4aea98852c 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -1,6 +1,6 @@ from typing import List, Dict, Any -from torch_tensorrt import _enums import torch_tensorrt.ts + from torch_tensorrt import logging import torch import torch.fx @@ -78,7 +78,7 @@ def compile( module: Any, ir="default", inputs=[], - enabled_precisions=set([_enums.dtype.float]), + enabled_precisions=set([torch.float]), **kwargs, ): """Compile a PyTorch module for NVIDIA GPUs using TensorRT @@ -153,24 +153,8 @@ def compile( **kwargs, ) elif target_ir == _IRType.dynamo: - if ( - torch.float16 in enabled_precisions - or torch_tensorrt.dtype.half in enabled_precisions - ): - lower_precision = LowerPrecision.FP16 - elif ( - torch.float32 in enabled_precisions - or torch_tensorrt.dtype.float in enabled_precisions - ): - lower_precision = LowerPrecision.FP32 - else: - raise ValueError(f"Precision {enabled_precisions} not supported on FX") - return torch_tensorrt.dynamo.compile( - module, - inputs, - lower_precision=lower_precision, - **kwargs, + module, inputs=inputs, 
enabled_precisions=enabled_precisions, **kwargs ) else: raise RuntimeError("Module is an unknown format or the ir requested is unknown") @@ -181,7 +165,7 @@ def convert_method_to_trt_engine( method_name: str, ir="default", inputs=[], - enabled_precisions=set([_enums.dtype.float]), + enabled_precisions=set([torch.float]), **kwargs, ): """Convert a TorchScript module method to a serialized TensorRT engine diff --git a/py/torch_tensorrt/dynamo/lower.py b/py/torch_tensorrt/dynamo/lower.py index 75e0de9fe8..a5aa18926b 100644 --- a/py/torch_tensorrt/dynamo/lower.py +++ b/py/torch_tensorrt/dynamo/lower.py @@ -29,9 +29,9 @@ def compile( module: nn.Module, inputs, + enabled_precisions=set(), min_block_size: int = 10, max_workspace_size=1 << 25, - lower_precision=LowerPrecision.FP16, verbose_log=False, timing_cache_prefix="", save_timing_cache=False, @@ -48,7 +48,6 @@ def compile( input: Input for module. min_block_size: Minimal number of nodes for an accelerated submodule max_workspace_size: Maximum size of workspace given to TensorRT. - lower_precision: lower_precision config given to TRTModule. verbose_log: Enable verbose log for TensorRT if set True. timing_cache_prefix: Timing cache file name for timing cache used by fx2trt. save_timing_cache: Update timing cache with current timing cache data if set to True. @@ -62,6 +61,14 @@ def compile( "The experimental unifed runtime only supports explicit batch. Please make sure to set explicit_batch_dimension=True when use_experimental_fx_rt=True" ) + lower_precision = LowerPrecision.FP32 + if torch.float16 in enabled_precisions: + lower_precision = LowerPrecision.FP16 + elif torch.float32 in enabled_precisions: + lower_precision = LowerPrecision.FP32 + else: + raise ValueError(f"Precision {enabled_precisions} not supported on FX") + lower_setting = LowerSetting( min_block_size=min_block_size, max_workspace_size=max_workspace_size, diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_silu.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_silu.py index 684b2247e8..38d8f5b645 100644 --- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_silu.py +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_silu.py @@ -2,7 +2,7 @@ from torch import nn from torch.testing._internal.common_utils import run_tests from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec -from torch_tensorrt.dynamo.tracer.acc_tracer import acc_ops +from torch_tensorrt.fx.tracer.acc_tracer import acc_ops class TestSilu(AccTestCase): diff --git a/py/torch_tensorrt/ts/__init__.py b/py/torch_tensorrt/ts/__init__.py index ddee197aef..47ef249e55 100644 --- a/py/torch_tensorrt/ts/__init__.py +++ b/py/torch_tensorrt/ts/__init__.py @@ -1,2 +1,3 @@ from torch_tensorrt.ts._compiler import * from torch_tensorrt.ts._compile_spec import TensorRTCompileSpec +from torch_tensorrt.ts.ts_input import TSInput From 516248b51174b3e510c11071c4062966cf30c800 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Wed, 5 Apr 2023 22:44:42 -0700 Subject: [PATCH 12/45] fix: Revamp implementation and replace inplace ops - Update implementation to use Dynamo partition functionality - Update implementation to use Dynamo decompositions to replace inplace operators - Name backends using standard names - Add documentation, print statements, and helper functions to the code --- .../tensorrt_dynamo_backend.py | 202 +++++++++++++----- 1 file changed, 145 insertions(+), 57 deletions(-) diff --git 
a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py index 55c5e2df33..dad3c81a1b 100644 --- a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py +++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py @@ -2,31 +2,137 @@ import traceback import torch._dynamo as td +from typing import Dict + from torch_tensorrt.fx.fx2trt import ( InputTensorSpec, TRTInterpreter, ) +from torch._dynamo.backends.common import fake_tensor_unsupported import tensorrt as trt -from torch_tensorrt.fx.tools.trt_splitter import ( - TRTSplitter, - TRTSplitterSetting, -) -from torch_tensorrt.fx.tracer.dispatch_tracer import aten_tracer +from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner +from torch.fx.passes.operator_support import OperatorSupport +from torch_tensorrt.fx.converter_registry import CONVERTERS + from torch_tensorrt.fx.trt_module import TRTModule from torch_tensorrt.fx.utils import LowerPrecision -from torch._dynamo.backends.common import fake_tensor_unsupported - from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler +from torch._decomp import register_decomposition, core_aten_decompositions + + +DECOMPOSITIONS = {**core_aten_decompositions()} +MAX_NUM_TRT_ENGINES = 10 + +aten = torch.ops.aten + + +def replace_inplace_op(aten_op, outplace_op): + """Replace inplace operation with functional equivalent + Adapted from: + https://github.com/pytorch/pytorch/blob/3344d79e3f732dadd5c85b99a7aa1a022f187929/torch/_decomp/decompositions.py#L3355-L3361 + """ + + @register_decomposition(aten_op, registry=DECOMPOSITIONS) + def inplace_op(*args, **kwargs): + out = outplace_op(*args, **kwargs) + return args[0].copy_(out) + + return inplace_op + + +replace_inplace_op(aten.add_, aten.add) +replace_inplace_op(aten.addbmm_, aten.addbmm) +replace_inplace_op(aten.addmm_, aten.addmm) +replace_inplace_op(aten.addmv_, aten.addmv) +replace_inplace_op(aten.baddbmm_, aten.baddbmm) +replace_inplace_op(aten.cumprod_, aten.cumprod) +replace_inplace_op(aten.fill_, aten.fill) +replace_inplace_op(aten.gelu_, aten.gelu) +replace_inplace_op(aten.hardsigmoid_, aten.hardsigmoid) +replace_inplace_op(aten.index_put_, aten.index_put) +replace_inplace_op(aten.index_reduce_, aten.index_reduce) +replace_inplace_op(aten.logit_, aten.logit) +replace_inplace_op(aten.relu_, aten.relu) +replace_inplace_op(aten.renorm_, aten.renorm) +replace_inplace_op(aten.round_, aten.round) +replace_inplace_op(aten.scatter_, aten.scatter) +replace_inplace_op(aten.scatter_add_, aten.scatter_add) +replace_inplace_op(aten.scatter_reduce_, aten.scatter_reduce) + + +class TorchTensorRTOperatorSupport(OperatorSupport): + """Class to determine whether the aten operators have converters""" + + def __init__(self, support_dict=None): + super().__init__(support_dict) + + # Initialize sets of supported/unsupported operators + self.supported_operators = set() + self.unsupported_operators = set() + + def is_node_supported( + self, submodules: Dict[str, torch.nn.Module], node: torch.fx.Node + ) -> bool: + if node.target in CONVERTERS.keys(): + # If node is a proper computational node, store the operator + if not node.is_impure(): + node_name = node._pretty_print_target(node.target) + self.supported_operators.add(node_name) + + return True + else: + if not node.is_impure(): + node_name = node._pretty_print_target(node.target) + self.unsupported_operators.add(node_name) + + return False + + def 
print_support_overview(self, num_trt_blocks=None): + if num_trt_blocks is not None: + print(f"Number of TensorRT-Accelerated Subgraphs: {num_trt_blocks}\n") + + print("Supported Nodes:") + for node_name in self.supported_operators: + print(node_name) + + print("\nUnsupported Nodes:") + for node_name in self.unsupported_operators: + print(node_name) + + +def partition(gm: torch.fx.GraphModule, verbose=True): + """Partition an FX GraphModule with aten ops into TRT engines + + Partitioning is based on operator support + """ + supported_ops = TorchTensorRTOperatorSupport() + partitioner = CapabilityBasedPartitioner(gm, supported_ops) + + # Determine partitions, and raise error if the degree of partitioning + # exceeds a specified threshold + partitions = partitioner.propose_partitions() + num_blocks = len(partitions) + if num_blocks > MAX_NUM_TRT_ENGINES: + raise AssertionError( + f"The graph module has {num_blocks} TRT Engines which is larger than the " + + f"threshold={MAX_NUM_TRT_ENGINES}. Falling back to non-TRT module." + ) -from torch._inductor.decomposition import decompositions + # Fuse partitions and display overview of supported/unsupported operators + fused_graph = partitioner.fuse_partitions(partitions) + num_blocks = len(partitions) -DECOMPOSITIONS = decompositions.copy() -MAX_SPLITS_THRESHOLD = 10 + if verbose: + supported_ops.print_support_overview(num_blocks) + return fused_graph + +@td.register_backend(name="tensorrt") +@fake_tensor_unsupported def tensorrt_backend(gm, sample_inputs): - # Invoke AOTAutograd to compile model + # Invoke AOTAutograd to translate operators to aten return aot_module_simplified( gm, sample_inputs, @@ -36,35 +142,12 @@ def tensorrt_backend(gm, sample_inputs): def fx2trt(gm: torch.fx.GraphModule, example_inputs, **kwargs): - model = gm - inputs = example_inputs - - # Perform lowering pass on model - model = aten_tracer.opt_trace(model, inputs, perform_trace=False) - - # Split out unsupported ops --> Needs rewrite/revision for ATEN - splitter_setting = TRTSplitterSetting() - splitter_setting.use_implicit_batch_dim = False - splitter = TRTSplitter(model, inputs, settings=splitter_setting) - - splitter.node_support_preview() - split_mod = splitter() - num_pieces = 0 - - for name, _ in split_mod.named_children(): - print(f"Graph is split into {name}") - num_pieces += 1 - - # Select threshold above which segmentation is not beneficial and run graph in Torch - if num_pieces > MAX_SPLITS_THRESHOLD: - raise AssertionError( - f"The graph module is split into {num_pieces} which is large than the \ - threshold={MAX_SPLITS_THRESHOLD}. Falling back to non-TRT module." 
- ) + partitioned = partition(gm) precision = LowerPrecision.FP32 def get_submod_inputs(mod, submod, inputs): + """Helper function to get inputs to submodule""" acc_inputs = None def get_input(self, inputs): @@ -76,39 +159,44 @@ def get_input(self, inputs): handle.remove() return acc_inputs - for name, _ in split_mod.named_children(): - if "_run_on_acc" in name: - submod = getattr(split_mod, name) - acc_inputs = get_submod_inputs(split_mod, submod, inputs) + for name, _ in partitioned.named_children(): + submod = getattr(partitioned, name) - interp = TRTInterpreter( - submod, - InputTensorSpec.from_tensors(acc_inputs), - explicit_batch_dimension=True, - logger_level=trt.Logger.VERBOSE, - ) - r = interp.run( - max_workspace_size=20 << 30, - lower_precision=precision, - profiling_verbosity=trt.ProfilingVerbosity.VERBOSE, - ) + # Get submodule inputs + acc_inputs = get_submod_inputs(partitioned, submod, example_inputs) - trt_mod = TRTModule(*r) + # Create TRT Module from submodule + interp = TRTInterpreter( + submod, + InputTensorSpec.from_tensors(acc_inputs), + explicit_batch_dimension=True, + logger_level=trt.Logger.VERBOSE, + ) + + r = interp.run( + max_workspace_size=20 << 30, + lower_precision=precision, + profiling_verbosity=trt.ProfilingVerbosity.VERBOSE, + ) + trt_mod = TRTModule(*r) - setattr(split_mod, name, trt_mod) + # Replace FX Module with TRT Module + setattr(partitioned, name, trt_mod) - return split_mod + return partitioned -@td.register_backend +@td.register_backend(name="fx_tensorrt") @fake_tensor_unsupported def fx2trt_compiler(gm: torch.fx.GraphModule, example_inputs): + """Helper function to manage translation of FX module to TRT engines""" try: trt_compiled = fx2trt(gm, example_inputs) return trt_compiled - except Exception: + except: traceback.print_exc() print( - "FX2TRT conversion failed on the subgraph. See trace above. Returning GraphModule forward instead" + "FX2TRT conversion failed on the subgraph. See trace above. " + + "Returning GraphModule forward instead." 
) return gm.forward From 152bf43e0900a62d41b46c5d6f27c5628b0ac3c5 Mon Sep 17 00:00:00 2001 From: George S <113141689+gs-olive@users.noreply.github.com> Date: Wed, 5 Apr 2023 22:58:27 -0700 Subject: [PATCH 13/45] Undo changes to aten tracer --- .../fx/tracer/dispatch_tracer/aten_tracer.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py index 356ddc978e..e60c8f8d13 100644 --- a/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py +++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py @@ -130,7 +130,7 @@ def trace(f, args, *rest): @req_torch_version("2.dev") -def opt_trace(f, args, perform_trace=True, *rest): +def opt_trace(f, args, *rest): """ Optimized trace with necessary passes which re-compose some ops or replace some ops These passes should be general and functional purpose @@ -148,11 +148,7 @@ def opt_trace(f, args, perform_trace=True, *rest): replace_inplace_ops, # remove it once functionalization is enabled ] - if perform_trace: - fx_module, _ = trace(f, args) - else: - fx_module = f - + fx_module, _ = trace(f, args) print(fx_module.graph) for passes in passes_list: pr: PassResult = passes(fx_module) From cb6e946861ef34abcedcf34f3e522bb00f4596be Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Thu, 6 Apr 2023 14:04:51 -0700 Subject: [PATCH 14/45] chore: Fix tests and remove implicit batch dim support Signed-off-by: Dheeraj Peri --- py/setup.py | 4 +++- py/torch_tensorrt/__init__.py | 1 + py/torch_tensorrt/dynamo/fx2trt.py | 2 +- py/torch_tensorrt/dynamo/tools/common_fx2trt.py | 2 +- 4 files changed, 6 insertions(+), 3 deletions(-) diff --git a/py/setup.py b/py/setup.py index de4ae92529..200a43c533 100644 --- a/py/setup.py +++ b/py/setup.py @@ -376,7 +376,9 @@ def run(self): long_description = fh.read() if FX_ONLY: - package_data_list = ["_Input.py",] + package_data_list = [ + "_Input.py", + ] else: package_data_list = [ "lib/*", diff --git a/py/torch_tensorrt/__init__.py b/py/torch_tensorrt/__init__.py index 360d6f2dbe..3dd8748ee6 100644 --- a/py/torch_tensorrt/__init__.py +++ b/py/torch_tensorrt/__init__.py @@ -95,6 +95,7 @@ def _find_lib(name, paths): from torch_tensorrt import fx from torch_tensorrt import dynamo + def _register_with_torch(): trtorch_dir = os.path.dirname(__file__) torch.ops.load_library(trtorch_dir + "/lib/libtorchtrt.so") diff --git a/py/torch_tensorrt/dynamo/fx2trt.py b/py/torch_tensorrt/dynamo/fx2trt.py index 4140a344f0..b062a62a37 100644 --- a/py/torch_tensorrt/dynamo/fx2trt.py +++ b/py/torch_tensorrt/dynamo/fx2trt.py @@ -37,7 +37,7 @@ def __init__( self, module: torch.fx.GraphModule, input_specs: List[InputTensorSpec], - explicit_batch_dimension: bool = False, + explicit_batch_dimension: bool = True, explicit_precision: bool = False, logger_level=None, ): diff --git a/py/torch_tensorrt/dynamo/tools/common_fx2trt.py b/py/torch_tensorrt/dynamo/tools/common_fx2trt.py index be99562455..5157ef67fa 100644 --- a/py/torch_tensorrt/dynamo/tools/common_fx2trt.py +++ b/py/torch_tensorrt/dynamo/tools/common_fx2trt.py @@ -244,7 +244,7 @@ def run_test( unexpected_ops=None, apply_passes=None, test_explicit_batch_dim=True, - test_implicit_batch_dim=True, + test_implicit_batch_dim=False, test_explicit_precision=False, rtol=1e-03, atol=1e-03, From 6153d21368978b6012939a860504ceacc57b93d9 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Thu, 6 Apr 2023 14:17:37 -0700 Subject: [PATCH 15/45] chore: fix tests 
Signed-off-by: Dheeraj Peri --- py/torch_tensorrt/ts/_compile_spec.py | 35 --------------------------- 1 file changed, 35 deletions(-) diff --git a/py/torch_tensorrt/ts/_compile_spec.py b/py/torch_tensorrt/ts/_compile_spec.py index 76d3605a1d..b29d386118 100644 --- a/py/torch_tensorrt/ts/_compile_spec.py +++ b/py/torch_tensorrt/ts/_compile_spec.py @@ -265,41 +265,6 @@ def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec: signature = _parse_input_signature(compile_spec["input_signature"]) info.input_signature = _C.InputSignature(signature) # py_object - if not compile_spec["torch_fallback"]["enabled"]: - raise ValueError( - "Grouped inputs currently requires partial compilation to be enabled, this restriction will be relaxed in a future release" - ) - - log( - Level.Debug, - "Grouped inputs currently requires additional settings to enable the feature", - ) - log( - Level.Debug, - """Adding the following ops to torch_executed_ops: - - aten::__getitem__ - - prim::ListConstruct - - prim::ListUnpack - - prim::TupleIndex - - prim::TupleConstruct - - prim::TupleUnpack -""", - ) - compile_spec["torch_fallback"]["forced_fallback_ops"].append( - "aten::__getitem__" - ) - compile_spec["torch_fallback"]["forced_fallback_ops"].append( - "prim::ListConstruct" - ) - compile_spec["torch_fallback"]["forced_fallback_ops"].append("prim::ListUnpack") - compile_spec["torch_fallback"]["forced_fallback_ops"].append("prim::TupleIndex") - compile_spec["torch_fallback"]["forced_fallback_ops"].append( - "prim::TupleConstruct" - ) - compile_spec["torch_fallback"]["forced_fallback_ops"].append( - "prim::TupleUnpack" - ) - else: raise KeyError( 'Module input definitions are requried to compile module. Provide a list of torch_tensorrt.Input keyed to "inputs" in the compile spec' From 5ad35e3deb4ab92e42188fc882163986da83fd29 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Thu, 6 Apr 2023 15:26:23 -0700 Subject: [PATCH 16/45] chore: remove softmax test using implicit dim Signed-off-by: Dheeraj Peri --- .../test/converters/acc_op/test_softmax.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_softmax.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_softmax.py index 5c8b9ed58b..e64996eef6 100644 --- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_softmax.py +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_softmax.py @@ -60,22 +60,5 @@ def forward(self, x): Softmax(), input_specs, expected_ops={acc_ops.softmax} ) - def test_softmax_with_implicit_batch_dim0_fail(self): - class Softmax(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return nn.functional.softmax(x, dim=0) - - inputs = [torch.randn(1, 2, 3, 4)] - self.run_test_with_assert_error( - Softmax(), - inputs, - expect_error=AssertionError, - test_explicit_batch_dim=False, - ) - - if __name__ == "__main__": run_tests() From f76d2b61082f51c9baee79157ceb1d6b809df022 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Thu, 6 Apr 2023 16:35:32 -0700 Subject: [PATCH 17/45] fix: Remove unmodified files from FX, add device support Signed-off-by: Dheeraj Peri --- py/torch_tensorrt/dynamo/README.md | 28 +- py/torch_tensorrt/dynamo/__init__.py | 1 - py/torch_tensorrt/dynamo/diagnostics.py | 287 ------------------ py/torch_tensorrt/dynamo/lower.py | 8 +- py/torch_tensorrt/dynamo/lower_setting.py | 3 +- py/torch_tensorrt/dynamo/observer.py | 194 ------------ .../dynamo/passes/lower_basic_pass.py | 2 +- 
.../passes/lower_pass_manager_builder.py      |   4 +-
 py/torch_tensorrt/dynamo/passes/pass_utils.py |   4 +-
 .../dynamo/test/trt_lower/test_diagnostics.py |   2 +-
 .../dynamo/test/trt_lower/test_observer.py    |   4 +-
 .../test/trt_lower/test_observer_gpu.py       |   2 +-
 py/torch_tensorrt/dynamo/trt_module.py        | 239 ---------------
 13 files changed, 27 insertions(+), 751 deletions(-)
 delete mode 100644 py/torch_tensorrt/dynamo/diagnostics.py
 delete mode 100644 py/torch_tensorrt/dynamo/observer.py
 delete mode 100644 py/torch_tensorrt/dynamo/trt_module.py

diff --git a/py/torch_tensorrt/dynamo/README.md b/py/torch_tensorrt/dynamo/README.md
index d53f43a1d4..d2a9e295a3 100644
--- a/py/torch_tensorrt/dynamo/README.md
+++ b/py/torch_tensorrt/dynamo/README.md
@@ -1,21 +1,13 @@
-FX2TRT is merged as FX module in Torch-TensorRT
+The code in this directory is similar to `torch_tensorrt.fx`. We intend to make changes under the `dynamo` namespace to ensure we
+have the same top-level API as `torch_tensorrt.ts.compile`. Right now, the usage is as follows:

-- The user guide is in [link](../../../docsrc/tutorials/getting_started_with_fx_path.rst#installation)
-- The examples are moved to [link](../../../examples/fx)
-
-* Method 1. Follow the instrucions for Torch-TensorRT
-* Method 2. To install FX path only (Python path) and avoid the C++ build for torchscript path
 ```
-  $ conda create --name python_env python=3.8
-  $ conda activate python_env
-  # Recommend to install PyTorch 1.12 and later
-  $ conda install pytorch torchvision torchtext cudatoolkit=11.3 -c pytorch-nightly
-  # Install TensorRT python package
-  $ pip3 install nvidia-pyindex
-  $ pip3 install tensorrt==8.5.1.7
-  $ git clone https://github.com/pytorch/TensorRT.git
-  $ cd TensorRT/py && python setup.py install --fx-only && cd ..
-  $ pyton -c "import torch_tensorrt.fx"
-  # Test an example by
-  $ python py/torch_tensorrt/fx/example/lower_example.py
+import torch_tensorrt
+trt_module = torch_tensorrt.compile(
+    module,
+    ir="dynamo",
+    inputs=torchtrt_inputs,
+    enabled_precisions={torch.float32},
+)
 ```
+This will internally call `torch_tensorrt.dynamo.compile`, which has the same signature as `torch_tensorrt.ts.compile`. We intend to add features that already exist in the TorchScript backend (e.g. torch_executed_ops, torch_executed_modules, and many more) to this dynamo backend in the coming months.
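
For illustration, a fuller end-to-end sketch of the usage described in the README above follows. It is only a sketch: `MyModel`, the input shape, and the device placement are placeholders and not part of this patch; the only API it exercises is the `torch_tensorrt.compile(..., ir="dynamo")` entry point introduced in this series.

```python
import torch
import torch_tensorrt

# Placeholder module for the sketch; any traceable nn.Module would do.
class MyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1.0

model = MyModel().eval().cuda()
inputs = [torch.randn(1, 3, 224, 224, device="cuda")]

# ir="dynamo" dispatches to torch_tensorrt.dynamo.compile under the hood.
trt_module = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    enabled_precisions={torch.float32},
)

outputs = trt_module(*inputs)
```

Note that at this point in the series `ir="dynamo"` must be passed explicitly; the `"default"` setting still prefers the TorchScript frontend for scriptable modules.
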
diff --git a/py/torch_tensorrt/dynamo/__init__.py b/py/torch_tensorrt/dynamo/__init__.py index 8c40ecac76..85ce01ef20 100644 --- a/py/torch_tensorrt/dynamo/__init__.py +++ b/py/torch_tensorrt/dynamo/__init__.py @@ -9,7 +9,6 @@ from .fx2trt import TRTInterpreter, TRTInterpreterResult # noqa from .input_tensor_spec import InputTensorSpec # noqa from .lower_setting import LowerSetting # noqa -from .trt_module import TRTModule # noqa from .lower import compile # usort: skip #noqa logging.basicConfig(level=logging.INFO) diff --git a/py/torch_tensorrt/dynamo/diagnostics.py b/py/torch_tensorrt/dynamo/diagnostics.py deleted file mode 100644 index 0ba2a30652..0000000000 --- a/py/torch_tensorrt/dynamo/diagnostics.py +++ /dev/null @@ -1,287 +0,0 @@ -import contextlib -import inspect -import logging -import os -import os.path -import shutil -import tempfile -import time -import traceback -import typing as t -from contextvars import ContextVar -from dataclasses import dataclass - -TWrite = t.Union[str, bytes] -WriteObj = t.Union[TWrite, t.Callable[[], TWrite]] - -_CURRENT_WRITER: ContextVar["DiagnosticsWriter"] = ContextVar("_CURRENT_WRITER") -_CURRENT_COLLECTOR: ContextVar["DiagnosticsCollector"] = ContextVar( - "_CURRENT_COLLECTOR" -) -# Allows a collector to indicate subsequent collections should be suppressed to -# avoid duplicate collections. -_SUBSEQUENT_COLLECT_SUPPRESSED_BY: ContextVar[object] = ContextVar( - "_SUBSEQUENT_COLLECT_SUPPRESSED_BY" -) -# Indicates current execution context is within a context manager by -# `collect_when`. Only when it's set do we actually write diagnostics. -_IS_IN_COLLECT_CONTEXT: ContextVar[bool] = ContextVar("_IS_IN_COLLECT_CONTEXT") -_LOGGER = logging.getLogger(__name__) - - -@dataclass -class CollectionConditionContext: - exception: t.Optional[Exception] - - -CollectionCondition = t.Callable[[CollectionConditionContext], bool] - - -def collect_when( - condition: "CollectionCondition", supress_subsequent_collect: bool = True -): - """See `DiagnosticsCollector.collect_when`""" - return get_current_collector().collect_when(condition, supress_subsequent_collect) - - -def collect(): - return collect_when(CollectionConditions.always()) - - -def collect_when_fail(): - return collect_when(CollectionConditions.when_fail()) - - -def write(file_name: str, text: WriteObj): - return get_current_writer().write(file_name, text) - - -def get_current_writer() -> "DiagnosticsWriter": - """Get the writer for current execution context. - - Lazily instantiates and registers one if not already done. - """ - current_writer = _CURRENT_WRITER.get(None) - if not current_writer: - current_writer = DiagnosticsWriter() - _CURRENT_WRITER.set(current_writer) - return current_writer - - -def get_current_collector() -> "DiagnosticsCollector": - current_collector = _CURRENT_COLLECTOR.get(None) - if not current_collector: - current_collector = DiagnosticsCollector() - _CURRENT_COLLECTOR.set(current_collector) - return current_collector - - -def set_current_collector(collector: "DiagnosticsCollector"): - _CURRENT_COLLECTOR.set(collector) - - -class DiagnosticsWriter: - - # the root dir in which the diagnostics will be written - _root_dir: str - - def __init__(self): - self._root_dir = tempfile.mkdtemp(prefix="fx2trt.") - _LOGGER.info(f"Initializing DiagnosticsWriter with root_dir: {self._root_dir}") - - def write(self, file_name: str, data: WriteObj): - """ - TODO: Can be disabled by regex on file_name - """ - # Only write if we are inside a collect_when() context. 
- if not _IS_IN_COLLECT_CONTEXT.get(False): - return - - try: - res, err = _res_or_err(data) - if err: - to_write = err.encode("utf-8") - else: - if isinstance(res, str): - to_write = res.encode("utf-8") - elif isinstance(res, bytes): - to_write = res - else: - raise TypeError(f"Unknown data type: {type(res)}") - self._write(file_name, to_write) - except Exception as e: - # Log the error and swallow the exception, as this should not - # propagated into business logic - _LOGGER.warning(f"Error writing diagnostics: {e}") - - def root_dir(self) -> str: - return self._root_dir - - def _write(self, file_name: str, to_write: bytes): - # ms granularity - no naming collash, otherwise file will be - # overwritten. - ts = int(time.time() * 1000) - file_name = f"{file_name}.{ts}" - fn = os.path.join(self.root_dir(), file_name) - with open(fn, "wb") as f: - f.write(to_write) - - -class CollectionConditions: - @classmethod - def any(cls, *conditions: "CollectionCondition") -> "CollectionCondition": - return lambda ctx: any(cond(ctx) for cond in conditions) - - @classmethod - def all(cls, *conditions: "CollectionCondition") -> "CollectionCondition": - return lambda ctx: all(cond(ctx) for cond in conditions) - - @classmethod - def not_(cls, condition: "CollectionCondition") -> "CollectionCondition": - return lambda ctx: not condition(ctx) - - @classmethod - def always(cls) -> "CollectionCondition": - """Always collect""" - return lambda ctx: True - - @classmethod - def never(cls) -> "CollectionCondition": - """Never collect""" - return lambda ctx: False - - @classmethod - def when_fail(cls) -> "CollectionCondition": - """Collect when failed""" - ctx: CollectionConditionContext - return lambda ctx: ctx.exception is not None - - @classmethod - def when_called_by_function( - cls, func_name: str, match_prefix: bool = False - ) -> "CollectionCondition": - def _when_called_by_function(ctx: CollectionConditionContext) -> bool: - frames = inspect.stack() - for frame in frames: - if match_prefix: - if frame[3].startswith(func_name): - return True - else: - if frame[3] == func_name: - return True - return False - - return _when_called_by_function - - @classmethod - def when_not_in_tests(cls) -> CollectionCondition: - return CollectionConditions.not_( - CollectionConditions.when_called_by_function("test_", match_prefix=True) - ) - - -class DiagnosticsCollector: - @contextlib.contextmanager - def collect_when( - self, condition: "CollectionCondition", supress_subsequent_collect: bool = True - ): - """ - Context manager to collect diagnostics when the enclosed code completes - and *any* of the given condition is met. - - Args: - condition: - the condition only when met should the collection be done - supress_subsequent_collect: - When true, suppress any collections registered by this function - call. This is to ensure duplicate collections registered across - the callstack by different components. In this case, only the - outermost component will collect. - - When false, always collect (subject to given condition) regardless - of earlier collection registration's suppression. - - Returns: - a context manager that handles the collection when its enclosed - code finished run. 
- """ - this_collection_handle = object() - suppressed_by = _SUBSEQUENT_COLLECT_SUPPRESSED_BY.get(None) - reset_suppressed_by = False - if supress_subsequent_collect: - if suppressed_by and suppressed_by != this_collection_handle: - # Disable this collection since it's suppressed by a previously - # installed collection - condition = CollectionConditions.never() - else: - suppressed_by = this_collection_handle - _SUBSEQUENT_COLLECT_SUPPRESSED_BY.set(suppressed_by) - # don't forget to reset it in `finanlly` - reset_suppressed_by = True - - is_in_collect_context_tok = _IS_IN_COLLECT_CONTEXT.set(True) - exception: t.Optional[Exception] = None - try: - yield - except Exception as e: - exception = e - raise - finally: - if reset_suppressed_by: - _SUBSEQUENT_COLLECT_SUPPRESSED_BY.set(None) - if self._test_condition(condition, CollectionConditionContext(exception)): - try: - self.collect() - except Exception as e: - _LOGGER.warning( - f"Error while collecting diagnostics (THIS EXCEPTION IS HANDLED):\n" - f"{e}\n" - f"{traceback.format_exc()}" - ) - _IS_IN_COLLECT_CONTEXT.reset(is_in_collect_context_tok) - - def collect(self) -> str: - """Collect the diagnostics. Overridable in sub-classes.""" - return "" - - @classmethod - def _test_condition( - cls, cond: CollectionCondition, ctx: CollectionConditionContext - ) -> bool: - try: - return cond(ctx) - except Exception as e: - _LOGGER.warning(f"Error while testing condition: {e}") - return False - - -class ZipDiagnosticsCollector(DiagnosticsCollector): - _write: DiagnosticsWriter - _last_zip_path_for_test: str = "" # for test purpose only - - def __init__(self, writer: DiagnosticsWriter): - self._write = writer - - def collect(self) -> str: - _, fp = tempfile.mkstemp() - try: - zip_path = shutil.make_archive(fp, "zip", self._write.root_dir()) - self._last_zip_path_for_test = zip_path - return zip_path - finally: - os.remove(fp) - - -def _res_or_err(data: WriteObj) -> t.Tuple[TWrite, str]: - if isinstance(data, (str, bytes)): - return data, "" - if not callable(data): - raise TypeError( - f"data must be a callable that returns actual data to" - f"write, but got {type(data)}" - ) - try: - return data(), "" - except Exception as e: - _LOGGER.warning(f"Error getting data to write: {e}") - return "", str(e) diff --git a/py/torch_tensorrt/dynamo/lower.py b/py/torch_tensorrt/dynamo/lower.py index a5aa18926b..fbe6ec4b43 100644 --- a/py/torch_tensorrt/dynamo/lower.py +++ b/py/torch_tensorrt/dynamo/lower.py @@ -18,7 +18,7 @@ from .tools.trt_splitter import TRTSplitter, TRTSplitterSetting from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer -from .trt_module import TRTModule +from torch_tensorrt.fx.trt_module import TRTModule from .utils import LowerPrecision logger = logging.getLogger(__name__) @@ -29,8 +29,9 @@ def compile( module: nn.Module, inputs, + device=torch.device(torch.cuda.current_device()), enabled_precisions=set(), - min_block_size: int = 10, + min_block_size: int = 3, max_workspace_size=1 << 25, verbose_log=False, timing_cache_prefix="", @@ -70,6 +71,7 @@ def compile( raise ValueError(f"Precision {enabled_precisions} not supported on FX") lower_setting = LowerSetting( + device=device, min_block_size=min_block_size, max_workspace_size=max_workspace_size, lower_precision=lower_precision, @@ -274,10 +276,12 @@ def __call__( lower_setting = self.lower_pass_manager_builder.lower_setting atol = lower_setting.correctness_atol rtol = lower_setting.correctness_rtol + device = lower_setting.device @validate_inference( atol=atol, rtol=rtol, + 
device=device, ) def do_lower(module: nn.Module, inputs: Input) -> nn.Module: module.eval() diff --git a/py/torch_tensorrt/dynamo/lower_setting.py b/py/torch_tensorrt/dynamo/lower_setting.py index d72709c75b..83771d9a69 100644 --- a/py/torch_tensorrt/dynamo/lower_setting.py +++ b/py/torch_tensorrt/dynamo/lower_setting.py @@ -1,6 +1,6 @@ import dataclasses as dc from typing import List, Optional, Set, Type - +import torch from torch import nn from torch.fx.passes.pass_manager import PassManager @@ -24,6 +24,7 @@ class LowerSettingBasic: """ lower_precision: LowerPrecision = LowerPrecision.FP32 + device: torch.device = torch.device(torch.cuda.current_device()) min_block_size: int = 3 ast_rewriter_allow_list: Optional[Set[Type[nn.Module]]] = None leaf_module_list: Optional[Set[Type[nn.Module]]] = None diff --git a/py/torch_tensorrt/dynamo/observer.py b/py/torch_tensorrt/dynamo/observer.py deleted file mode 100644 index 3742bd2840..0000000000 --- a/py/torch_tensorrt/dynamo/observer.py +++ /dev/null @@ -1,194 +0,0 @@ -import contextlib -import functools -import logging -import traceback -import typing as t -from contextvars import ContextVar -from dataclasses import dataclass, field - -_LOGGER = logging.getLogger(__name__) - -# A context variable to hold registered callbacks for all the observers for the -# current execution context. The callbacks list could have been a member -# variable on the observer instance, however, contextvars document advice -# against creating context variables not at module-global level. -# https://docs.python.org/3/library/contextvars.html#contextvars.ContextVar -_CALLBACKS: ContextVar[t.Dict["Observer", t.List[t.Callable]]] = ContextVar( - "_CALLBACKS", default=None -) - -TObserverCallback = t.TypeVar("TObserverCallback", bound=t.Callable[..., t.Any]) - -# Whether to rethrow the exception caught while calling observer callbacks. -# Default to False. True is only used during tests. -RETHROW_CALLBACK_EXCEPTION: bool = False - - -@dataclass(frozen=True) -class Observer(t.Generic[TObserverCallback]): - """ - Usage: - - >>> some_observer: Observer = ... - >>> with some_observer.add(callback_func): - >>> # do stuff, and when some_observer.observe() is called, - >>> # it will execute callback_func() - >>> ... - - """ - - name: str = "" - # Ensure each Observer instance is considered a distinct key when stored in - # the `_CALLBACKS` dictionary. - unique_id: object = field(default_factory=lambda: object()) - - def add(self, callback: TObserverCallback) -> t.ContextManager: - self._get_callbacks().append(callback) - - # Cannot decorate the outer `add` directly with `contextmanager`, - # because if it were not used with a `with` statement, its body won't - # be executed. - @contextlib.contextmanager - def _add(): - try: - yield - finally: - try: - self._get_callbacks().remove(callback) - except ValueError: - # Callback should be in the callbacks list. I'm just being - # extra cautious here. I don't want it to throw and affect - # business logic. - pass - - return _add() - - def observe(self, *args, **kwargs) -> None: - for callback in self._get_callbacks(): - with _log_error( - "Error calling observer callback", rethrow=RETHROW_CALLBACK_EXCEPTION - ): - callback(*args, **kwargs) - - def _get_callbacks(self) -> t.List[t.Callable]: - """ - Gets the callbacks registered in current execution context. Any code - that manipulates the returned list (add, remove, iterate) is - concurrency safe. 
- """ - callbacks_dict = _CALLBACKS.get() - if callbacks_dict is None: - callbacks_dict = {} - _CALLBACKS.set(callbacks_dict) - - if self not in callbacks_dict: - callbacks_dict[self] = [] - - return callbacks_dict[self] - - -@dataclass(frozen=True) -class ObserveContext: - """ - Passed to the registered callables that observes any function decorated by - `observable`. See `observable` for detail. - - Attributes: - callable: the observed callable object - args: the args passed to the callable - kwargs: the kwargs passed to the callable - return_value: the return value returned by the callable, only available - when observing the callable after its invocation (via - `CallableObservers.post`) - """ - - callable: t.Callable - args: t.List[t.Any] - kwargs: t.Mapping[str, t.Any] - return_value: t.Any = None - - -def observable(): - """ - A decorator to turn a function into observable - - Example: - - >>> @observable() - >>> def func_to_observe(x, y) -> int: - >>> ... - >>> - >>> def log(ctx: ObserveContext): - >>> print( - >>> f"called {ctx.callable.__name__} with {ctx.args} {ctx.kwargs}" - >>> ) - >>> - >>> # register: - >>> with func_to_observe.observers.pre.add(log): - >>> func_to_observe(1, 2) - >>> # print out "called func_to_observe with (1,2) - >>> # here it won't print - """ - - def decorator(observed_func: callable) -> ObservedCallable: - wrapped_func = _make_observable(orig_func=observed_func) - return functools.wraps(observed_func)(wrapped_func) - - return decorator - - -@dataclass(frozen=True) -class CallableObservers: - pre: Observer[t.Callable[[ObserveContext], None]] - post: Observer[t.Callable[[ObserveContext], None]] - - -class ObservedCallable: - """ - Interface for an observed callable - """ - - observers: CallableObservers - orig_func: callable - - def __call__(self, *args, **kwargs) -> t.Any: - raise NotImplementedError() - - -def _make_observable(orig_func: t.Callable) -> ObservedCallable: - """ - A wrapper for a callable which is to be observed. 
- """ - - observers = CallableObservers( - pre=Observer(), - post=Observer(), - ) - - @functools.wraps(orig_func) - def observed_func(*args, **kwargs): - observers.pre.observe(ObserveContext(orig_func, args, kwargs)) - return_value = None - try: - return_value = orig_func(*args, **kwargs) - return return_value - finally: - observers.post.observe( - ObserveContext(orig_func, args, kwargs, return_value) - ) - - observed_func.orig_func = orig_func - observed_func.observers = observers - - return observed_func - - -@contextlib.contextmanager -def _log_error(msg: str, rethrow: bool = False) -> t.ContextManager: - try: - yield - except Exception as e: - _e = e # noqa: F841 - _LOGGER.info(f"{msg} (This error is handled): {traceback.format_exc()}") - if rethrow: - raise diff --git a/py/torch_tensorrt/dynamo/passes/lower_basic_pass.py b/py/torch_tensorrt/dynamo/passes/lower_basic_pass.py index 3fa4f69bc5..12aedafce5 100644 --- a/py/torch_tensorrt/dynamo/passes/lower_basic_pass.py +++ b/py/torch_tensorrt/dynamo/passes/lower_basic_pass.py @@ -10,7 +10,7 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils from torch.fx.experimental.const_fold import split_const_subgraphs -from ..observer import observable +from torch_tensorrt.fx.observer import observable from torch_tensorrt.fx.tracer.acc_tracer import acc_ops from torch_tensorrt.fx.tracer.acc_tracer.acc_utils import get_attr diff --git a/py/torch_tensorrt/dynamo/passes/lower_pass_manager_builder.py b/py/torch_tensorrt/dynamo/passes/lower_pass_manager_builder.py index c3c69e0117..b0e49ff825 100644 --- a/py/torch_tensorrt/dynamo/passes/lower_pass_manager_builder.py +++ b/py/torch_tensorrt/dynamo/passes/lower_pass_manager_builder.py @@ -13,7 +13,7 @@ from ..input_tensor_spec import InputTensorSpec from ..lower_setting import LowerSetting -from ..observer import Observer +from torch_tensorrt.fx.observer import Observer from ..passes.remove_duplicate_output_args import remove_duplicate_output_args from .graph_opts import common_subexpression_elimination from .pass_utils import extract_example_tensors_from_input @@ -260,7 +260,7 @@ def build_trt_lower_pipeline( self, input: Input, additional_input: Optional[Input] = None ) -> PassManager: - self._input = extract_example_tensors_from_input(input) + self._input = extract_example_tensors_from_input(input, self.lower_setting.device) self._trt_input = [] for input_obj in input: if isinstance(input_obj, _Input.Input): diff --git a/py/torch_tensorrt/dynamo/passes/pass_utils.py b/py/torch_tensorrt/dynamo/passes/pass_utils.py index 3fdd4c7541..3d010ad6c8 100644 --- a/py/torch_tensorrt/dynamo/passes/pass_utils.py +++ b/py/torch_tensorrt/dynamo/passes/pass_utils.py @@ -125,7 +125,7 @@ def parent_pass(module: fx.GraphModule, input: Input) -> fx.GraphModule: # (TODO(shirongwu): Add exception notification for fblearner flow when available, notify oncall # on pass that failed accuracy check. 
-def validate_inference(rtol=None, atol=None): +def validate_inference(rtol=None, atol=None, device=torch.device(torch.cuda.current_device())): def _validate_inference(pass_: PassFunc) -> PassFunc: """ Wraps a pass function to validate that its inference results before and @@ -139,7 +139,7 @@ def pass_with_validation( *args, **kwargs, ) -> fx.GraphModule: - input_tensors = extract_example_tensors_from_input(input) + input_tensors = extract_example_tensors_from_input(input, device) res0 = module(*input_tensors) processed_module = pass_(module, input, *args, **kwargs) res1 = processed_module(*input_tensors) diff --git a/py/torch_tensorrt/dynamo/test/trt_lower/test_diagnostics.py b/py/torch_tensorrt/dynamo/test/trt_lower/test_diagnostics.py index 709973ae22..3ce3b7ade8 100644 --- a/py/torch_tensorrt/dynamo/test/trt_lower/test_diagnostics.py +++ b/py/torch_tensorrt/dynamo/test/trt_lower/test_diagnostics.py @@ -8,7 +8,7 @@ from typing import Union from unittest import TestCase -import torch_tensorrt.dynamo.diagnostics as diag +import torch_tensorrt.fx.diagnostics as diag _LOGGER: logging.Logger = logging.getLogger(__name__) diff --git a/py/torch_tensorrt/dynamo/test/trt_lower/test_observer.py b/py/torch_tensorrt/dynamo/test/trt_lower/test_observer.py index 185f3acc04..58f23c0a13 100644 --- a/py/torch_tensorrt/dynamo/test/trt_lower/test_observer.py +++ b/py/torch_tensorrt/dynamo/test/trt_lower/test_observer.py @@ -5,8 +5,8 @@ from contextlib import contextmanager from unittest import TestCase -import torch_tensorrt.dynamo.observer as ob -from torch_tensorrt.dynamo.observer import observable +import torch_tensorrt.fx.observer as ob +from torch_tensorrt.fx.observer import observable _LOGGER: logging.Logger = logging.getLogger(__name__) diff --git a/py/torch_tensorrt/dynamo/test/trt_lower/test_observer_gpu.py b/py/torch_tensorrt/dynamo/test/trt_lower/test_observer_gpu.py index 352ecd062a..b067e93195 100644 --- a/py/torch_tensorrt/dynamo/test/trt_lower/test_observer_gpu.py +++ b/py/torch_tensorrt/dynamo/test/trt_lower/test_observer_gpu.py @@ -2,7 +2,7 @@ import functools from unittest import TestCase -import torch_tensorrt.dynamo.observer as ob +import torch_tensorrt.fx.observer as ob from test_observer import execution_verifier, set_observer_callback_rethrow from torch_tensorrt.dynamo.passes.lower_basic_pass import fuse_permute_linear diff --git a/py/torch_tensorrt/dynamo/trt_module.py b/py/torch_tensorrt/dynamo/trt_module.py deleted file mode 100644 index 099bbfcdc9..0000000000 --- a/py/torch_tensorrt/dynamo/trt_module.py +++ /dev/null @@ -1,239 +0,0 @@ -from typing import Any, List, Sequence - -# @manual=//deeplearning/trt/python:py_tensorrt -import tensorrt as trt -import torch - -from .utils import torch_dtype_from_trt - - -class TRTModule(torch.nn.Module): - def __init__( - self, engine=None, input_names=None, output_names=None, cuda_graph_batch_size=-1 - ): - super(TRTModule, self).__init__() - self._register_state_dict_hook(TRTModule._on_state_dict) - self.engine = engine - self.input_names = input_names - self.output_names = output_names - self.cuda_graph_batch_size = cuda_graph_batch_size - self.initialized = False - - if engine: - self._initialize() - - def _initialize(self): - self.initialized = True - self.context = self.engine.create_execution_context() - - # Indices of inputs/outputs in the trt engine bindings, in the order - # as they are in the original PyTorch model. 
- self.input_binding_indices_in_order: Sequence[int] = [ - self.engine.get_binding_index(name) for name in self.input_names - ] - self.output_binding_indices_in_order: Sequence[int] = [ - self.engine.get_binding_index(name) for name in self.output_names - ] - primary_input_outputs = set() - primary_input_outputs.update(self.input_binding_indices_in_order) - primary_input_outputs.update(self.output_binding_indices_in_order) - self.hidden_output_binding_indices_in_order: Sequence[int] = [] - self.hidden_output_names: Sequence[str] = [] - for i in range( - self.engine.num_bindings // self.engine.num_optimization_profiles - ): - if i not in primary_input_outputs: - self.hidden_output_binding_indices_in_order.append(i) - self.hidden_output_names.append(self.engine.get_binding_name(i)) - - assert (self.engine.num_bindings // self.engine.num_optimization_profiles) == ( - len(self.input_names) - + len(self.output_names) - + len(self.hidden_output_names) - ) - - self.input_dtypes: Sequence[torch.dtype] = [ - torch_dtype_from_trt(self.engine.get_binding_dtype(idx)) - for idx in self.input_binding_indices_in_order - ] - self.input_shapes: Sequence[Sequence[int]] = [ - tuple(self.engine.get_binding_shape(idx)) - for idx in self.input_binding_indices_in_order - ] - self.output_dtypes: Sequence[torch.dtype] = [ - torch_dtype_from_trt(self.engine.get_binding_dtype(idx)) - for idx in self.output_binding_indices_in_order - ] - self.output_shapes = [ - tuple(self.engine.get_binding_shape(idx)) - if self.engine.has_implicit_batch_dimension - else tuple() - for idx in self.output_binding_indices_in_order - ] - self.hidden_output_dtypes: Sequence[torch.dtype] = [ - torch_dtype_from_trt(self.engine.get_binding_dtype(idx)) - for idx in self.hidden_output_binding_indices_in_order - ] - self.hidden_output_shapes = [ - tuple(self.engine.get_binding_shape(idx)) - if self.engine.has_implicit_batch_dimension - else tuple() - for idx in self.hidden_output_binding_indices_in_order - ] - - def _check_initialized(self): - if not self.initialized: - raise RuntimeError("TRTModule is not initialized.") - - def _on_state_dict(self, state_dict, prefix, local_metadata): - self._check_initialized() - state_dict[prefix + "engine"] = bytearray(self.engine.serialize()) - state_dict[prefix + "input_names"] = self.input_names - state_dict[prefix + "output_names"] = self.output_names - state_dict[prefix + "cuda_graph_batch_size"] = self.cuda_graph_batch_size - - def _load_from_state_dict( - self, - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, - ): - engine_bytes = state_dict[prefix + "engine"] - - logger = trt.Logger() - runtime = trt.Runtime(logger) - self.engine = runtime.deserialize_cuda_engine(engine_bytes) - - self.input_names = state_dict[prefix + "input_names"] - self.output_names = state_dict[prefix + "output_names"] - self._initialize() - - def __getstate__(self): - state = self.__dict__.copy() - state["engine"] = bytearray(self.engine.serialize()) - state.pop("context", None) - return state - - def __setstate__(self, state): - logger = trt.Logger() - runtime = trt.Runtime(logger) - state["engine"] = runtime.deserialize_cuda_engine(state["engine"]) - self.__dict__.update(state) - if self.engine: - self.context = self.engine.create_execution_context() - - def forward(self, *inputs): - with torch.autograd.profiler.record_function("TRTModule:Forward"): - self._check_initialized() - - with torch.autograd.profiler.record_function("TRTModule:ProcessInputs"): - assert len(inputs) 
== len( - self.input_names - ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}." - - # This is only used when the trt engine is using implicit batch dim. - batch_size = inputs[0].shape[0] - contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs] - bindings: List[Any] = [None] * ( - len(self.input_names) - + len(self.output_names) - + len(self.hidden_output_names) - ) - - for i, input_name in enumerate(self.input_names): - assert inputs[ - i - ].is_cuda, f"{i}th input({input_name}) is not on cuda device." - assert ( - inputs[i].dtype == self.input_dtypes[i] - ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {inputs[i].dtype}." - - idx = self.input_binding_indices_in_order[i] - bindings[idx] = contiguous_inputs[i].data_ptr() - - if not self.engine.has_implicit_batch_dimension: - self.context.set_binding_shape( - idx, tuple(contiguous_inputs[i].shape) - ) - else: - assert inputs[i].size()[1:] == self.input_shapes[i], ( - f"Shape mismatch for {i}th input({input_name}). " - f"Expect {self.input_shapes[i]}, got {inputs[i].size()[1:]}." - ) - - with torch.autograd.profiler.record_function("TRTModule:ProcessOutputs"): - # create output tensors - outputs: List[torch.Tensor] = [] - - for i, idx in enumerate(self.output_binding_indices_in_order): - if self.engine.has_implicit_batch_dimension: - shape = (batch_size,) + self.output_shapes[i] - else: - shape = tuple(self.context.get_binding_shape(idx)) - - output = torch.empty( # type: ignore[call-overload] - size=shape, - dtype=self.output_dtypes[i], - device=torch.cuda.current_device(), - ) - outputs.append(output) - bindings[idx] = output.data_ptr() - - for i, idx in enumerate(self.hidden_output_binding_indices_in_order): - if self.engine.has_implicit_batch_dimension: - shape = (batch_size,) + self.hidden_output_shapes[i] - else: - shape = tuple(self.context.get_binding_shape(idx)) - - output = torch.empty( # type: ignore[call-overload] - size=shape, - dtype=self.hidden_output_dtypes[i], - device=torch.cuda.current_device(), - ) - bindings[idx] = output.data_ptr() - - with torch.autograd.profiler.record_function("TRTModule:TensorRTRuntime"): - if self.engine.has_implicit_batch_dimension: - self.context.execute_async( - batch_size, bindings, torch.cuda.current_stream().cuda_stream - ) - else: - self.context.execute_async_v2( - bindings, torch.cuda.current_stream().cuda_stream - ) - - if len(outputs) == 1: - return outputs[0] - - return tuple(outputs) - - def enable_profiling(self, profiler: "trt.IProfiler" = None): - """ - Enable TensorRT profiling. After calling this function, TensorRT will report - time spent on each layer in stdout for each forward run. - """ - self._check_initialized() - - if not self.context.profiler: - self.context.profiler = trt.Profiler() if profiler is None else profiler - - def disable_profiling(self): - """ - Disable TensorRT profiling. - """ - self._check_initialized() - - torch.cuda.synchronize() - del self.context - self.context = self.engine.create_execution_context() - - def get_layer_info(self) -> str: - """ - Get layer info of the engine. Only support for TRT > 8.2. 
- """ - inspector = self.engine.create_engine_inspector() - return inspector.get_engine_information(trt.LayerInformationFormat.JSON) From 0c5befd599d3b5bb2edc725897ad2542b054d76b Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Fri, 7 Apr 2023 10:59:35 -0700 Subject: [PATCH 18/45] fix: Refactor backend, add sample args --- py/torch_tensorrt/dynamo/__init__.py | 2 + py/torch_tensorrt/dynamo/_compiler.py | 114 ++++++++++ py/torch_tensorrt/dynamo/backends.py | 108 ++++++++++ py/torch_tensorrt/dynamo/lowering/__init__.py | 2 + .../dynamo/lowering/_decompositions.py | 45 ++++ .../dynamo/lowering/_partition.py | 78 +++++++ .../tensorrt_dynamo_backend.py | 202 ------------------ 7 files changed, 349 insertions(+), 202 deletions(-) create mode 100644 py/torch_tensorrt/dynamo/__init__.py create mode 100644 py/torch_tensorrt/dynamo/_compiler.py create mode 100644 py/torch_tensorrt/dynamo/backends.py create mode 100644 py/torch_tensorrt/dynamo/lowering/__init__.py create mode 100644 py/torch_tensorrt/dynamo/lowering/_decompositions.py create mode 100644 py/torch_tensorrt/dynamo/lowering/_partition.py delete mode 100644 py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py diff --git a/py/torch_tensorrt/dynamo/__init__.py b/py/torch_tensorrt/dynamo/__init__.py new file mode 100644 index 0000000000..fd036ffa8e --- /dev/null +++ b/py/torch_tensorrt/dynamo/__init__.py @@ -0,0 +1,2 @@ +from _compiler import compile +from backends import create_backend diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py new file mode 100644 index 0000000000..9804c691ff --- /dev/null +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -0,0 +1,114 @@ +import torch +import logging + +from torch_tensorrt import EngineCapability, Device + +from torch_tensorrt.dynamo.lowering._partition import partition +from torch_tensorrt.dynamo import create_backend + +from torch_tensorrt.fx.fx2trt import ( + InputTensorSpec, + TRTInterpreter, +) +import tensorrt as trt + +from torch_tensorrt.fx.trt_module import TRTModule +from torch_tensorrt.fx.utils import LowerPrecision + +logger = logging.getLogger(__name__) + + +def compile( + gm: torch.Module, + example_inputs, + *, + device=Device._current_device(), + disable_tf32=False, + sparse_weights=False, + enabled_precisions=set(), + refit=False, + debug=False, + capability=EngineCapability.default, + num_avg_timing_iters=1, + workspace_size=0, + dla_sram_size=1048576, + dla_local_dram_size=1073741824, + dla_global_dram_size=536870912, + calibrator=None, + truncate_long_and_double=False, + require_full_compilation=False, + min_block_size=3, + torch_executed_ops=[], + torch_executed_modules=[], +): + custom_backend = create_backend( + device=device, + disable_tf32=disable_tf32, + sparse_weights=sparse_weights, + enabled_precisions=enabled_precisions, + refit=refit, + debug=debug, + capability=capability, + num_avg_timing_iters=num_avg_timing_iters, + workspace_size=workspace_size, + dla_sram_size=dla_sram_size, + dla_local_dram_size=dla_local_dram_size, + dla_global_dram_size=dla_global_dram_size, + calibrator=calibrator, + truncate_long_and_double=truncate_long_and_double, + require_full_compilation=require_full_compilation, + min_block_size=min_block_size, + torch_executed_ops=torch_executed_ops, + torch_executed_modules=torch_executed_modules, + ) + + model = torch.compile(gm, backend=custom_backend) + # Ensure compilation + model(example_inputs) + + return model + + +def compile_logic(gm: 
+    partitioned = partition(gm)
+
+    precision = LowerPrecision.FP32
+
+    def get_submod_inputs(mod, submod, inputs):
+        """Helper function to get inputs to submodule"""
+        acc_inputs = None
+
+        def get_input(self, inputs):
+            nonlocal acc_inputs
+            acc_inputs = inputs
+
+        handle = submod.register_forward_pre_hook(get_input)
+        mod(*inputs)
+        handle.remove()
+        return acc_inputs
+
+    for name, _ in partitioned.named_children():
+        submod = getattr(partitioned, name)
+
+        # Get submodule inputs
+        acc_inputs = get_submod_inputs(partitioned, submod, example_inputs)
+
+        # Create TRT Module from submodule
+        interp = TRTInterpreter(
+            submod,
+            InputTensorSpec.from_tensors(acc_inputs),
+            explicit_batch_dimension=True,
+            logger_level=trt.Logger.VERBOSE,
+        )
+
+        r = interp.run(
+            max_workspace_size=20 << 30,
+            lower_precision=precision,
+            profiling_verbosity=trt.ProfilingVerbosity.VERBOSE,
+        )
+        trt_mod = TRTModule(*r)
+
+        # Replace FX Module with TRT Module
+        setattr(partitioned, name, trt_mod)
+
+    return partitioned
diff --git a/py/torch_tensorrt/dynamo/backends.py b/py/torch_tensorrt/dynamo/backends.py
new file mode 100644
index 0000000000..2d010c7f2a
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/backends.py
@@ -0,0 +1,108 @@
+import torch
+import logging
+import traceback
+from functools import partial
+import torch._dynamo as td
+from torch_tensorrt import EngineCapability, Device
+from torch_tensorrt.dynamo._compiler import compile_logic
+
+from torch._dynamo.backends.common import fake_tensor_unsupported
+
+from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
+
+from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
+
+logger = logging.getLogger(__name__)
+
+
+def create_backend(
+    input_signature=None,
+    device=Device._current_device(),
+    disable_tf32=False,
+    sparse_weights=False,
+    enabled_precisions=set(),
+    refit=False,
+    debug=False,
+    capability=EngineCapability.default,
+    num_avg_timing_iters=1,
+    workspace_size=20 << 30,
+    dla_sram_size=1048576,
+    dla_local_dram_size=1073741824,
+    dla_global_dram_size=536870912,
+    calibrator=None,
+    truncate_long_and_double=False,
+    require_full_compilation=False,
+    min_block_size=3,
+    torch_executed_ops=[],
+    torch_executed_modules=[],
+):
+    logger.warning(
+        "The Dynamo backend is an experimental feature, for which the "
+        + "following arguments are unsupported: "
+        + "{input_signature, disable_tf32, sparse_weights, refit, capability, "
+        + "num_avg_timing_iters, dla_sram_size, dla_local_dram_size, "
+        + "dla_global_dram_size, calibrator, truncate_long_and_double, "
+        + "require_full_compilation, min_block_size, torch_executed_ops, "
+        + "torch_executed_modules}"
+    )
+
+    return partial(
+        tensorrt_backend,
+        debug=debug,
+        enabled_precisions=enabled_precisions,
+        device=device,
+        workspace_size=workspace_size,
+    )
+
+
+@td.register_backend(name="tensorrt")
+@fake_tensor_unsupported
+def tensorrt_backend(
+    gm: torch.nn.Module,
+    sample_inputs,
+    *,
+    debug=False,
+    enabled_precisions=set(),
+    device=Device._current_device(),
+    workspace_size=20 << 30,
+):
+
+    custom_backend = partial(
+        fx_dynamo_backend,
+        debug=debug,
+        enabled_precisions=enabled_precisions,
+        device=device,
+        workspace_size=workspace_size,
+    )
+
+    # Invoke AOTAutograd to translate operators to aten
+    return aot_module_simplified(
+        gm,
+        sample_inputs,
+        fw_compiler=make_boxed_compiler(custom_backend),
+        decompositions=get_decompositions(),
+    )
+
+
+@td.register_backend(name="fx_tensorrt")
+@fake_tensor_unsupported
+def fx_dynamo_backend(
+    gm: torch.fx.GraphModule,
+    example_inputs,
+    *,
+    debug=False,
+    enabled_precisions=set(),
+    device=Device._current_device(),
+    workspace_size=20 << 30,
+):
+    """Helper function to manage translation of FX module to TRT engines"""
+    try:
+        # compile_logic performs the partitioning and TRT engine construction
+        trt_compiled = compile_logic(gm, example_inputs)
+        return trt_compiled
+    except Exception:
+        traceback.print_exc()
+        print(
+            "FX2TRT conversion failed on the subgraph. See trace above. "
+            + "Returning GraphModule forward instead."
+        )
+        return gm.forward
diff --git a/py/torch_tensorrt/dynamo/lowering/__init__.py b/py/torch_tensorrt/dynamo/lowering/__init__.py
new file mode 100644
index 0000000000..a57579d4ca
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/lowering/__init__.py
@@ -0,0 +1,2 @@
+from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
+from torch_tensorrt.dynamo.lowering._partition import partition
diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
new file mode 100644
index 0000000000..7aff1a79d1
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -0,0 +1,45 @@
+import torch
+from torch._decomp import register_decomposition, core_aten_decompositions
+
+
+DECOMPOSITIONS = {**core_aten_decompositions()}
+
+aten = torch.ops.aten
+
+
+def replace_inplace_op(aten_op, outplace_op):
+    """Replace inplace operation with functional equivalent
+    Adapted from:
+    https://github.com/pytorch/pytorch/blob/3344d79e3f732dadd5c85b99a7aa1a022f187929/torch/_decomp/decompositions.py#L3355-L3361
+    """
+
+    @register_decomposition(aten_op, registry=DECOMPOSITIONS)
+    def inplace_op(*args, **kwargs):
+        out = outplace_op(*args, **kwargs)
+        return args[0].copy_(out)
+
+    return inplace_op
+
+
+replace_inplace_op(aten.add_, aten.add)
+replace_inplace_op(aten.addbmm_, aten.addbmm)
+replace_inplace_op(aten.addmm_, aten.addmm)
+replace_inplace_op(aten.addmv_, aten.addmv)
+replace_inplace_op(aten.baddbmm_, aten.baddbmm)
+replace_inplace_op(aten.cumprod_, aten.cumprod)
+replace_inplace_op(aten.fill_, aten.fill)
+replace_inplace_op(aten.gelu_, aten.gelu)
+replace_inplace_op(aten.hardsigmoid_, aten.hardsigmoid)
+replace_inplace_op(aten.index_put_, aten.index_put)
+replace_inplace_op(aten.index_reduce_, aten.index_reduce)
+replace_inplace_op(aten.logit_, aten.logit)
+replace_inplace_op(aten.relu_, aten.relu)
+replace_inplace_op(aten.renorm_, aten.renorm)
+replace_inplace_op(aten.round_, aten.round)
+replace_inplace_op(aten.scatter_, aten.scatter)
+replace_inplace_op(aten.scatter_add_, aten.scatter_add)
+replace_inplace_op(aten.scatter_reduce_, aten.scatter_reduce)
+
+
+def get_decompositions():
+    return DECOMPOSITIONS
diff --git a/py/torch_tensorrt/dynamo/lowering/_partition.py b/py/torch_tensorrt/dynamo/lowering/_partition.py
new file mode 100644
index 0000000000..d96450f41d
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/lowering/_partition.py
@@ -0,0 +1,78 @@
+from typing import Dict
+
+import torch
+
+from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
+from torch.fx.passes.operator_support import OperatorSupport
+
+from torch_tensorrt.fx.converter_registry import CONVERTERS
+
+
+MAX_NUM_TRT_ENGINES = 10
+
+
+class TorchTensorRTOperatorSupport(OperatorSupport):
+    """Class to determine whether the aten operators have converters"""
+
+    def __init__(self, support_dict=None):
+        super().__init__(support_dict)
+
+        # Initialize sets of supported/unsupported operators
+        self.supported_operators = set()
+        self.unsupported_operators = set()
+
+    def
is_node_supported( + self, submodules: Dict[str, torch.nn.Module], node: torch.fx.Node + ) -> bool: + if node.target in CONVERTERS.keys(): + # If node is a proper computational node, store the operator + if not node.is_impure(): + node_name = node._pretty_print_target(node.target) + self.supported_operators.add(node_name) + + return True + else: + if not node.is_impure(): + node_name = node._pretty_print_target(node.target) + self.unsupported_operators.add(node_name) + + return False + + def print_support_overview(self, num_trt_blocks=None): + if num_trt_blocks is not None: + print(f"Number of TensorRT-Accelerated Subgraphs: {num_trt_blocks}\n") + + print("Supported Nodes:") + for node_name in self.supported_operators: + print(node_name) + + print("\nUnsupported Nodes:") + for node_name in self.unsupported_operators: + print(node_name) + + +def partition(gm: torch.fx.GraphModule, verbose=True): + """Partition an FX GraphModule with aten ops into TRT engines + Partitioning is based on operator support + """ + supported_ops = TorchTensorRTOperatorSupport() + partitioner = CapabilityBasedPartitioner(gm, supported_ops) + + # Determine partitions, and raise error if the degree of partitioning + # exceeds a specified threshold + partitions = partitioner.propose_partitions() + num_blocks = len(partitions) + if num_blocks > MAX_NUM_TRT_ENGINES: + raise AssertionError( + f"The graph module has {num_blocks} TRT Engines which is larger than the " + + f"threshold={MAX_NUM_TRT_ENGINES}. Falling back to non-TRT module." + ) + + # Fuse partitions and display overview of supported/unsupported operators + fused_graph = partitioner.fuse_partitions(partitions) + num_blocks = len(partitions) + + if verbose: + supported_ops.print_support_overview(num_blocks) + + return fused_graph diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py deleted file mode 100644 index dad3c81a1b..0000000000 --- a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py +++ /dev/null @@ -1,202 +0,0 @@ -import torch -import traceback -import torch._dynamo as td - -from typing import Dict - -from torch_tensorrt.fx.fx2trt import ( - InputTensorSpec, - TRTInterpreter, -) -from torch._dynamo.backends.common import fake_tensor_unsupported -import tensorrt as trt -from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner -from torch.fx.passes.operator_support import OperatorSupport -from torch_tensorrt.fx.converter_registry import CONVERTERS - -from torch_tensorrt.fx.trt_module import TRTModule -from torch_tensorrt.fx.utils import LowerPrecision - -from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler -from torch._decomp import register_decomposition, core_aten_decompositions - - -DECOMPOSITIONS = {**core_aten_decompositions()} -MAX_NUM_TRT_ENGINES = 10 - -aten = torch.ops.aten - - -def replace_inplace_op(aten_op, outplace_op): - """Replace inplace operation with functional equivalent - Adapted from: - https://github.com/pytorch/pytorch/blob/3344d79e3f732dadd5c85b99a7aa1a022f187929/torch/_decomp/decompositions.py#L3355-L3361 - """ - - @register_decomposition(aten_op, registry=DECOMPOSITIONS) - def inplace_op(*args, **kwargs): - out = outplace_op(*args, **kwargs) - return args[0].copy_(out) - - return inplace_op - - -replace_inplace_op(aten.add_, aten.add) -replace_inplace_op(aten.addbmm_, aten.addbmm) -replace_inplace_op(aten.addmm_, aten.addmm) -replace_inplace_op(aten.addmv_, 
aten.addmv) -replace_inplace_op(aten.baddbmm_, aten.baddbmm) -replace_inplace_op(aten.cumprod_, aten.cumprod) -replace_inplace_op(aten.fill_, aten.fill) -replace_inplace_op(aten.gelu_, aten.gelu) -replace_inplace_op(aten.hardsigmoid_, aten.hardsigmoid) -replace_inplace_op(aten.index_put_, aten.index_put) -replace_inplace_op(aten.index_reduce_, aten.index_reduce) -replace_inplace_op(aten.logit_, aten.logit) -replace_inplace_op(aten.relu_, aten.relu) -replace_inplace_op(aten.renorm_, aten.renorm) -replace_inplace_op(aten.round_, aten.round) -replace_inplace_op(aten.scatter_, aten.scatter) -replace_inplace_op(aten.scatter_add_, aten.scatter_add) -replace_inplace_op(aten.scatter_reduce_, aten.scatter_reduce) - - -class TorchTensorRTOperatorSupport(OperatorSupport): - """Class to determine whether the aten operators have converters""" - - def __init__(self, support_dict=None): - super().__init__(support_dict) - - # Initialize sets of supported/unsupported operators - self.supported_operators = set() - self.unsupported_operators = set() - - def is_node_supported( - self, submodules: Dict[str, torch.nn.Module], node: torch.fx.Node - ) -> bool: - if node.target in CONVERTERS.keys(): - # If node is a proper computational node, store the operator - if not node.is_impure(): - node_name = node._pretty_print_target(node.target) - self.supported_operators.add(node_name) - - return True - else: - if not node.is_impure(): - node_name = node._pretty_print_target(node.target) - self.unsupported_operators.add(node_name) - - return False - - def print_support_overview(self, num_trt_blocks=None): - if num_trt_blocks is not None: - print(f"Number of TensorRT-Accelerated Subgraphs: {num_trt_blocks}\n") - - print("Supported Nodes:") - for node_name in self.supported_operators: - print(node_name) - - print("\nUnsupported Nodes:") - for node_name in self.unsupported_operators: - print(node_name) - - -def partition(gm: torch.fx.GraphModule, verbose=True): - """Partition an FX GraphModule with aten ops into TRT engines - - Partitioning is based on operator support - """ - supported_ops = TorchTensorRTOperatorSupport() - partitioner = CapabilityBasedPartitioner(gm, supported_ops) - - # Determine partitions, and raise error if the degree of partitioning - # exceeds a specified threshold - partitions = partitioner.propose_partitions() - num_blocks = len(partitions) - if num_blocks > MAX_NUM_TRT_ENGINES: - raise AssertionError( - f"The graph module has {num_blocks} TRT Engines which is larger than the " - + f"threshold={MAX_NUM_TRT_ENGINES}. Falling back to non-TRT module." 
- ) - - # Fuse partitions and display overview of supported/unsupported operators - fused_graph = partitioner.fuse_partitions(partitions) - num_blocks = len(partitions) - - if verbose: - supported_ops.print_support_overview(num_blocks) - - return fused_graph - - -@td.register_backend(name="tensorrt") -@fake_tensor_unsupported -def tensorrt_backend(gm, sample_inputs): - # Invoke AOTAutograd to translate operators to aten - return aot_module_simplified( - gm, - sample_inputs, - fw_compiler=make_boxed_compiler(fx2trt_compiler), - decompositions=DECOMPOSITIONS, - ) - - -def fx2trt(gm: torch.fx.GraphModule, example_inputs, **kwargs): - partitioned = partition(gm) - - precision = LowerPrecision.FP32 - - def get_submod_inputs(mod, submod, inputs): - """Helper function to get inputs to submodule""" - acc_inputs = None - - def get_input(self, inputs): - nonlocal acc_inputs - acc_inputs = inputs - - handle = submod.register_forward_pre_hook(get_input) - mod(*inputs) - handle.remove() - return acc_inputs - - for name, _ in partitioned.named_children(): - submod = getattr(partitioned, name) - - # Get submodule inputs - acc_inputs = get_submod_inputs(partitioned, submod, example_inputs) - - # Create TRT Module from submodule - interp = TRTInterpreter( - submod, - InputTensorSpec.from_tensors(acc_inputs), - explicit_batch_dimension=True, - logger_level=trt.Logger.VERBOSE, - ) - - r = interp.run( - max_workspace_size=20 << 30, - lower_precision=precision, - profiling_verbosity=trt.ProfilingVerbosity.VERBOSE, - ) - trt_mod = TRTModule(*r) - - # Replace FX Module with TRT Module - setattr(partitioned, name, trt_mod) - - return partitioned - - -@td.register_backend(name="fx_tensorrt") -@fake_tensor_unsupported -def fx2trt_compiler(gm: torch.fx.GraphModule, example_inputs): - """Helper function to manage translation of FX module to TRT engines""" - try: - trt_compiled = fx2trt(gm, example_inputs) - return trt_compiled - except: - traceback.print_exc() - print( - "FX2TRT conversion failed on the subgraph. See trace above. " - + "Returning GraphModule forward instead." 
- ) - return gm.forward From a4047d29e4629151f411bf3133f023c14f59a7df Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Fri, 7 Apr 2023 11:46:10 -0700 Subject: [PATCH 19/45] chore: refactoring Signed-off-by: Dheeraj Peri --- py/torch_tensorrt/dynamo/fx2trt.py | 2 +- py/torch_tensorrt/dynamo/passes/lower_pass_manager_builder.py | 4 +++- py/torch_tensorrt/dynamo/passes/pass_utils.py | 4 +++- .../dynamo/test/converters/acc_op/test_softmax.py | 1 + py/torch_tensorrt/dynamo/tools/common_fx2trt.py | 3 ++- py/torch_tensorrt/dynamo/tools/trt_minimizer.py | 4 +++- py/torch_tensorrt/dynamo/tools/trt_splitter.py | 2 +- 7 files changed, 14 insertions(+), 6 deletions(-) diff --git a/py/torch_tensorrt/dynamo/fx2trt.py b/py/torch_tensorrt/dynamo/fx2trt.py index b062a62a37..4dae6e542b 100644 --- a/py/torch_tensorrt/dynamo/fx2trt.py +++ b/py/torch_tensorrt/dynamo/fx2trt.py @@ -15,7 +15,7 @@ from torch_tensorrt.dynamo import CONVERTERS from .input_tensor_spec import InputTensorSpec -from .observer import Observer +from torch_tensorrt.fx.observer import Observer from .utils import get_dynamic_dims, LowerPrecision, torch_dtype_to_trt _LOGGER: logging.Logger = logging.getLogger(__name__) diff --git a/py/torch_tensorrt/dynamo/passes/lower_pass_manager_builder.py b/py/torch_tensorrt/dynamo/passes/lower_pass_manager_builder.py index b0e49ff825..494f3922de 100644 --- a/py/torch_tensorrt/dynamo/passes/lower_pass_manager_builder.py +++ b/py/torch_tensorrt/dynamo/passes/lower_pass_manager_builder.py @@ -260,7 +260,9 @@ def build_trt_lower_pipeline( self, input: Input, additional_input: Optional[Input] = None ) -> PassManager: - self._input = extract_example_tensors_from_input(input, self.lower_setting.device) + self._input = extract_example_tensors_from_input( + input, self.lower_setting.device + ) self._trt_input = [] for input_obj in input: if isinstance(input_obj, _Input.Input): diff --git a/py/torch_tensorrt/dynamo/passes/pass_utils.py b/py/torch_tensorrt/dynamo/passes/pass_utils.py index 3d010ad6c8..96fa96cfae 100644 --- a/py/torch_tensorrt/dynamo/passes/pass_utils.py +++ b/py/torch_tensorrt/dynamo/passes/pass_utils.py @@ -125,7 +125,9 @@ def parent_pass(module: fx.GraphModule, input: Input) -> fx.GraphModule: # (TODO(shirongwu): Add exception notification for fblearner flow when available, notify oncall # on pass that failed accuracy check. 
-def validate_inference(rtol=None, atol=None, device=torch.device(torch.cuda.current_device())): +def validate_inference( + rtol=None, atol=None, device=torch.device(torch.cuda.current_device()) +): def _validate_inference(pass_: PassFunc) -> PassFunc: """ Wraps a pass function to validate that its inference results before and diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_softmax.py b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_softmax.py index e64996eef6..eab632c296 100644 --- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_softmax.py +++ b/py/torch_tensorrt/dynamo/test/converters/acc_op/test_softmax.py @@ -60,5 +60,6 @@ def forward(self, x): Softmax(), input_specs, expected_ops={acc_ops.softmax} ) + if __name__ == "__main__": run_tests() diff --git a/py/torch_tensorrt/dynamo/tools/common_fx2trt.py b/py/torch_tensorrt/dynamo/tools/common_fx2trt.py index 5157ef67fa..13dcbcda37 100644 --- a/py/torch_tensorrt/dynamo/tools/common_fx2trt.py +++ b/py/torch_tensorrt/dynamo/tools/common_fx2trt.py @@ -8,11 +8,12 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer import torch_tensorrt.fx.tracer.dispatch_tracer.aten_tracer as aten_tracer +from torch_tensorrt.fx import TRTModule from torch.fx.experimental.normalize import NormalizeArgs from torch.fx.passes import shape_prop from torch.fx.passes.infra.pass_base import PassResult from torch.testing._internal.common_utils import TestCase -from torch_tensorrt.dynamo import InputTensorSpec, TRTInterpreter, TRTModule +from torch_tensorrt.dynamo import InputTensorSpec, TRTInterpreter from torch_tensorrt.dynamo.passes.lower_basic_pass_aten import ( compose_bmm, compose_chunk, diff --git a/py/torch_tensorrt/dynamo/tools/trt_minimizer.py b/py/torch_tensorrt/dynamo/tools/trt_minimizer.py index f4886fab22..f5c15b049b 100644 --- a/py/torch_tensorrt/dynamo/tools/trt_minimizer.py +++ b/py/torch_tensorrt/dynamo/tools/trt_minimizer.py @@ -5,7 +5,9 @@ import torch.fx.passes.net_min_base as net_min_base from torch.fx.passes.tools_common import Tensors -from .. import InputTensorSpec, TRTInterpreter, TRTModule +from .. 
import InputTensorSpec, TRTInterpreter + +from torch_tensorrt.fx import TRTModule _LOGGER: logging.Logger = logging.getLogger(__name__) diff --git a/py/torch_tensorrt/dynamo/tools/trt_splitter.py b/py/torch_tensorrt/dynamo/tools/trt_splitter.py index bea925453f..c48f6d4e7d 100644 --- a/py/torch_tensorrt/dynamo/tools/trt_splitter.py +++ b/py/torch_tensorrt/dynamo/tools/trt_splitter.py @@ -11,8 +11,8 @@ NO_EXPLICIT_BATCH_DIM_SUPPORT, NO_IMPLICIT_BATCH_DIM_SUPPORT, TRTInterpreter, - TRTModule, ) +from torch_tensorrt.fx import TRTModule from ..tools.trt_minimizer import TensorRTMinimizer From cd4660d28c0b1a1650e3f2bcdade27d5661cf550 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Fri, 7 Apr 2023 14:36:25 -0700 Subject: [PATCH 20/45] chore: refactoring Signed-off-by: Dheeraj Peri --- py/torch_tensorrt/dynamo/lower_setting.py | 2 +- py/torch_tensorrt/dynamo/passes/graph_opts.py | 74 -- .../dynamo/passes/lower_basic_pass.py | 632 ------------------ .../dynamo/passes/lower_basic_pass_aten.py | 525 --------------- .../passes/lower_pass_manager_builder.py | 6 +- .../passes/remove_duplicate_output_args.py | 140 ---- .../test/core/test_input_tensor_spec.py | 28 +- .../dynamo/test/core/test_trt_module.py | 6 +- ...test_fix_clamp_numerical_limits_to_fp16.py | 2 +- .../test/passes/test_fix_reshape_batch_dim.py | 2 +- .../passes/test_fuse_permute_linear_trt.py | 104 +-- .../passes/test_fuse_permute_matmul_trt.py | 2 +- .../dynamo/test/passes/test_graph_opts.py | 2 +- .../dynamo/test/passes/test_multi_fuse_trt.py | 2 +- .../test_remove_duplicate_output_args.py | 2 +- .../dynamo/test/passes/test_setitem_trt.py | 2 +- .../dynamo/test/quant/test_quant_trt.py | 5 +- .../test/trt_lower/test_fx2trt_lower.py | 2 +- .../test/trt_lower/test_observer_gpu.py | 2 +- .../dynamo/tools/common_fx2trt.py | 2 +- py/torch_tensorrt/dynamo/utils.py | 2 +- 21 files changed, 73 insertions(+), 1471 deletions(-) delete mode 100644 py/torch_tensorrt/dynamo/passes/graph_opts.py delete mode 100644 py/torch_tensorrt/dynamo/passes/lower_basic_pass.py delete mode 100644 py/torch_tensorrt/dynamo/passes/lower_basic_pass_aten.py delete mode 100644 py/torch_tensorrt/dynamo/passes/remove_duplicate_output_args.py diff --git a/py/torch_tensorrt/dynamo/lower_setting.py b/py/torch_tensorrt/dynamo/lower_setting.py index 83771d9a69..9ca22d5ef8 100644 --- a/py/torch_tensorrt/dynamo/lower_setting.py +++ b/py/torch_tensorrt/dynamo/lower_setting.py @@ -5,7 +5,7 @@ from torch.fx.passes.pass_manager import PassManager from .input_tensor_spec import InputTensorSpec -from .passes.lower_basic_pass import fuse_permute_linear, fuse_permute_matmul +from torch_tensorrt.fx.passes.lower_basic_pass import fuse_permute_linear, fuse_permute_matmul from .utils import LowerPrecision diff --git a/py/torch_tensorrt/dynamo/passes/graph_opts.py b/py/torch_tensorrt/dynamo/passes/graph_opts.py deleted file mode 100644 index 2adc5c7fe3..0000000000 --- a/py/torch_tensorrt/dynamo/passes/graph_opts.py +++ /dev/null @@ -1,74 +0,0 @@ -from collections.abc import Sequence - -import torch -import torch.fx - - -def common_subexpression_elimination(graph_module: torch.fx.GraphModule) -> bool: - """ - Optimize quantization by removing repeated subexpressions. - - Args: - graph_module(torch.fx.GraphModule): target module to be optimized - - Returns: - Graph changed or not. 
- """ - - def seq_hashable(seq): - if seq is None: - return None - - items = [] - for old in seq: - if isinstance(old, Sequence) and not isinstance(old, str): - new = seq_hashable(old) - elif isinstance(old, dict): - new = dict_hashable(old) - elif isinstance(old, slice): - new = old.__reduce__() - else: - new = old - - items.append(new) - - return tuple(items) - - def dict_hashable(d): - if d is None: - return None - - items = [] - for k, old_v in d.items(): - if isinstance(old_v, Sequence): - new_v = seq_hashable(old_v) - elif isinstance(old_v, dict): - new_v = dict_hashable(old_v) - elif isinstance(old_v, slice): - new_v = old_v.__reduce__() - else: - new_v = old_v - - items.append((k, new_v)) - return tuple(sorted(items)) - - changed = False - env = {} - for n in graph_module.graph.nodes: - # do not CSE away impure ops - if n.op not in {"call_function", "call_method"} or n.is_impure(): - continue - - # hash target, args, kwargs - hash_val = (n.target, seq_hashable(n.args), dict_hashable(n.kwargs)) - - # check if a node has a substitute and can be eliminated - if hash_val in env: - n.replace_all_uses_with(env[hash_val]) - graph_module.graph.erase_node(n) - changed = True - continue - - env[hash_val] = n - - return changed diff --git a/py/torch_tensorrt/dynamo/passes/lower_basic_pass.py b/py/torch_tensorrt/dynamo/passes/lower_basic_pass.py deleted file mode 100644 index 12aedafce5..0000000000 --- a/py/torch_tensorrt/dynamo/passes/lower_basic_pass.py +++ /dev/null @@ -1,632 +0,0 @@ -import copy -import logging -import operator -import warnings -from typing import Any, Optional - -import torch -import torch.fx -import torch.fx as fx -import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils -from torch.fx.experimental.const_fold import split_const_subgraphs - -from torch_tensorrt.fx.observer import observable - -from torch_tensorrt.fx.tracer.acc_tracer import acc_ops -from torch_tensorrt.fx.tracer.acc_tracer.acc_utils import get_attr -from .pass_utils import log_before_after, validate_inference - -_LOGGER = logging.getLogger(__name__) - -# Create an alias for module input type to avoid littering pyre-ignore for Any -# throughout the file. -Input = Any - - -def replace_mutable_op(module: torch.fx.GraphModule) -> torch.fx.GraphModule: - if not isinstance(module, torch.fx.GraphModule): - return module - - # Before any lowering pass, replace mutable ops like torch.fill_ - # Because fx cannot deal with inplace ops - for n in module.graph.nodes: - # TODO: add more mutable ops - if (n.op == "call_method" and n.target == "fill_") or ( - n.op == "call_function" and n.target == torch.fill_ - ): - # Replace mutable op only if the modified variable - # is used by the rest of the graph - # only through this op - if set(n.args[0].users.keys()) == {n}: - with module.graph.inserting_after(n): - - # TODO: move this outside? - def fill_with_mul_zero_and_add(*args): - return args[0].mul(0.0).add(args[1]) - - new_node = module.graph.create_node( - "call_function", fill_with_mul_zero_and_add, args=n.args - ) - n.replace_all_uses_with(new_node) - module.graph.erase_node(n) - module.recompile() - return module - - -def run_const_fold(traced_mod: torch.fx.GraphModule) -> torch.fx.GraphModule: - # Now we do constant folding on traced module. We want to skip pattern like - # weights -> quant -> dequant -> op during constant folding when the model is - # a quantized int8 model. 
- def skip_folding_quant_dequant(node: torch.fx.Node): - if node.target != acc_ops.quantize_per_tensor: - return False - # If quantize_per_node -> dequantize, then skip folding. - for user in node.users: - if user.target == acc_ops.dequantize: - return True - return False - - const_split_mod = split_const_subgraphs(traced_mod, skip_folding_quant_dequant) - const_split_mod.run_folding() - return const_split_mod - - -def replace_op_with_indices(module: torch.fx.GraphModule) -> torch.fx.GraphModule: - for n in module.graph.nodes: - if n.op == "call_function" and n.target in ( - torch.ops.aten.max_pool2d_with_indices.default, - torch.ops.aten.max_pool3d_with_indices.default, - torch.ops.aten.native_batch_norm.default, - ): - if len(n.users) != 1: - raise RuntimeError( - f"{n.target} has users={len(n.users)}. We can only handle it with 1 user" - ) - if n.target == torch.ops.aten.max_pool2d_with_indices.default: - new_op = torch.ops.aten.max_pool2d - new_args = n.args - elif n.target == torch.ops.aten.max_pool3d_with_indices.default: - new_op = torch.ops.aten.max_pool3d - new_args = n.args - elif n.target == torch.ops.aten.native_batch_norm.default: - new_op = torch.ops.aten.batch_norm - new_args = list(n.args) - new_args.append(False) - new_args = tuple(new_args) - - getitem_node = next(iter(n.users)) - with module.graph.inserting_after(getitem_node): - new_node = module.graph.create_node( - "call_function", - new_op, - args=new_args, - kwargs=n.kwargs, - ) - getitem_node.replace_all_uses_with(new_node) - module.graph.erase_node(getitem_node) - module.graph.eliminate_dead_code() - module.recompile() - return module - - -@log_before_after -@validate_inference(atol=1e-3, rtol=1e-2) -def fuse_sparse_matmul_add(gm: torch.fx.GraphModule, input: Input): - """ - Replace acc_ops.matmul + acc_ops.add with acc_ops.linear - TRT8.2 can take advantage of structured sparsity (2:4), but the graph needs contain a single FC layer. - Later versions of TRT should work with matmul. - - Example before: - def forward(self, x): - a = self.a - b = self.b - addmm_mm = torch_tensorrt.fx.tracer.acc_tracer.acc_ops.matmul(input = a, other = b); a = b = None - addmm_add = torch_tensorrt.fx.tracer.acc_tracer.acc_ops.add(input = addmm_mm, other = x); addmm_mm = x = None - return addmm_add - - After: - def forward(self, x): - a = self.a - b = self.b - linear_1 = torch_tensorrt.fx.tracer.acc_tracer.acc_ops.linear(input = a, weight = b, bias = x); a = b = x = None - return linear_1 - """ - counter = 0 - for node in gm.graph.nodes: - if node.target != acc_ops.add: - continue - add_node = node - bias = add_node.kwargs["other"] - - if bias.op != "get_attr": - continue - # test that bias tensor is one-dimensional, should correspond to shape (out_features) - if get_attr(bias).dim() > 1: - continue - - node = add_node.kwargs["input"] - if node.target != acc_ops.matmul: - continue - matmul_node = node - a = matmul_node.kwargs["input"] - - node = matmul_node.kwargs["other"] - if node.op != "get_attr": - continue - - get_attr_node = node - weight = get_attr(get_attr_node) - # TODO: verify that weight comply with TRT structured sparsity requirements: - # For each output channel and for each spatial pixel in the kernel weights, - # every 4 input channels must have at least 2 zeros. 
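A hypothetical standalone check for that 2:4 requirement, restricted to the 2-D linear weights this pass handles (shown only to make the TODO concrete; it is not part of the deleted code):

    import torch

    def satisfies_2to4_sparsity(weight: torch.Tensor) -> bool:
        # For each output channel, every contiguous group of 4 input-channel
        # weights must contain at least 2 zeros (TRT 8.2 structured sparsity).
        out_features, in_features = weight.shape
        if in_features % 4 != 0:
            return False
        groups = weight.reshape(out_features, in_features // 4, 4)
        return bool(((groups == 0).sum(dim=-1) >= 2).all())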
- - # test that weight tensor is two-dimensional, should correspond to shape (out_features, in_features) - if weight.dim() != 2: - continue - - weight_t = weight.transpose(0, 1) - weight_t_name = "weight_t_tensor_" + str(counter) - gm.register_buffer(weight_t_name, weight_t) - counter += 1 - - with gm.graph.inserting_before(add_node): - weight_t_attr = gm.graph.get_attr(weight_t_name) - fused_node = gm.graph.call_function( - acc_ops.linear, - kwargs={"input": a, "weight": weight_t_attr, "bias": bias}, - ) - add_node.replace_all_uses_with(fused_node) - - gm.graph.eliminate_dead_code() - gm.graph.lint() - gm.recompile() - return gm - - -def trt_transposed_matmul( - lhs: torch.Tensor, rhs: torch.Tensor, lhs_transposed: bool, rhs_transposed: bool -): - if lhs_transposed: - lhs = lhs.transpose(-1, -2) - if rhs_transposed: - rhs = rhs.transpose(-1, -2) - return torch.matmul(lhs, rhs) - - -def trt_transposed_linear( - input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor -): - return torch.matmul(input.transpose(-1, -2), weight.t()) + bias - - -def check_permute(node: torch.fx.Node): - ranks = len(node.meta["tensor_meta"].shape) - permutation = list(i % ranks for i in node.kwargs["permutation"]) # type: ignore[union-attr] - allowed_permutation = list(i for i in range(ranks)) - allowed_permutation[-1] = ranks - 2 - allowed_permutation[-2] = ranks - 1 - return permutation == allowed_permutation - - -@observable() -@log_before_after -@validate_inference(atol=1e-3, rtol=1e-2) -def fuse_permute_linear(gm: torch.fx.GraphModule, input: Input): - """ - Fuse pattern like permute + linear if permute is transposing the last two dimension. - """ - for node in gm.graph.nodes: - if node.target == acc_ops.linear: - inp = node.kwargs["input"] - if inp.target == acc_ops.permute and check_permute(inp): - inp = inp.kwargs["input"] - weight = node.kwargs["weight"] - bias = node.kwargs["bias"] - with gm.graph.inserting_before(node): - fused_node = gm.graph.call_function( - trt_transposed_linear, args=(inp, weight, bias) - ) - node.replace_all_uses_with(fused_node) - - gm.graph.eliminate_dead_code() - gm.graph.lint() - gm.recompile() - return gm - - -@observable() -@log_before_after -@validate_inference(atol=1e-3, rtol=1e-2) -def fuse_permute_matmul(gm: torch.fx.GraphModule, input: Input): - """ - Fuse pattern like permute + matmul if permute is transposing the last two dimension. 
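Both fusions target the same permute-then-consume shape of pattern; toy modules that produce it under tracing (sizes illustrative, mirroring the fusion tests later in this series):

    import torch

    class PermuteLinear(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(10, 30)

        def forward(self, x):  # x: (N, 10, S)
            # permute swaps the last two dims -> fused into trt_transposed_linear
            return self.linear(x.permute(0, 2, 1))

    class PermuteMatmul(torch.nn.Module):
        def forward(self, x, y):
            # lhs permute of the last two dims -> fused into trt_transposed_matmul
            return torch.matmul(x.permute(0, 2, 1), y)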
- """ - for node in gm.graph.nodes: - if node.target == acc_ops.matmul: - lhs, rhs = node.kwargs["input"], node.kwargs["other"] - lhs_transposed = rhs_tranposed = False - skip = False - - if lhs.target == acc_ops.permute and check_permute(lhs): - lhs_transposed = True - lhs = lhs.kwargs["input"] - - if rhs.target == acc_ops.permute and check_permute(rhs): - rhs_tranposed = True - rhs = rhs.kwargs["input"] - - if (not skip) and (lhs_transposed or rhs_tranposed): - with gm.graph.inserting_before(node): - fused_node = gm.graph.call_function( - trt_transposed_matmul, - args=(lhs, rhs, lhs_transposed, rhs_tranposed), - ) - node.replace_all_uses_with(fused_node) - - gm.graph.eliminate_dead_code() - gm.graph.lint() - gm.recompile() - return gm - - -def slice_list(sli: slice, dim: int, size: int): - slice_all = slice(None, None, None) - if size == 1: - return [sli] - elif size == 2: - if dim == 0: - return [sli, slice_all] - elif dim == 1: - return [slice_all, sli] - elif size == 3: - if dim == 0: - return [sli, slice_all, slice_all] - elif dim == 1: - return [slice_all, sli, slice_all] - elif dim == 2: - return [slice_all, slice_all, sli] - elif size == 4: - if dim == 0: - return [sli, slice_all, slice_all, slice_all] - elif dim == 1: - return [slice_all, sli, slice_all, slice_all] - elif dim == 2: - return [slice_all, slice_all, sli, slice_all] - elif dim == 3: - return [slice_all, slice_all, slice_all, sli] - - -def split_across( - gm: torch.fx.GraphModule, sli: slice, input_node: torch.fx.Node, dim: int, size: int -): - start_node = end_node = mid_node = None - if sli.start is None and sli.stop is None: - return (start_node, input_node, end_node) - if sli.start is not None: - st_sli = slice(0, sli.start, None) - slice_list_gen = slice_list(st_sli, dim, size) - start_node = gm.graph.call_function( - operator.getitem, args=(input_node, slice_list_gen) - ) - if sli.stop is not None: - end_sli = slice(sli.stop, None, None) - slice_list_gen = slice_list(end_sli, dim, size) - end_node = gm.graph.call_function( - operator.getitem, args=(input_node, slice_list_gen) - ) - if dim != size - 1: - mid_sli = slice(sli.start, sli.stop, None) - slice_list_gen = slice_list(mid_sli, dim, size) - mid_node = gm.graph.call_function( - operator.getitem, args=(input_node, slice_list_gen) - ) - return (start_node, mid_node, end_node) - - -def list_gen( - start_node: torch.fx.Node, - end_node: torch.fx.Node, - input_node: torch.fx.Node, - gm: torch.fx.GraphModule, - dim: int, -): - if start_node: - if end_node: - concat_list = [start_node, input_node, end_node] - else: - concat_list = [start_node, input_node] - else: - if end_node: - concat_list = [input_node, end_node] - else: - concat_list = [input_node] - if len(concat_list) > 1: - concat_node = gm.graph.call_function(torch.cat, args=(concat_list, dim)) - else: - concat_node = concat_list[0] - return concat_node - - -def transform_setitem(gm: torch.fx.GraphModule, input: Input): - """ - Setitem is not tracable in fx and acc tracer but is available in dynamo trace. This pass works for dynamo trace only. - The implementation decompose the setitem into a few getitem op and assembly together again through concat. - The major reason is that TRT does not support in-place copy and memory reference. 
- """ - map_replace = {} - for node in gm.graph.nodes: - for old_node in map_replace: - node.replace_input_with(old_node, map_replace[old_node]) - - if node.target == operator.setitem: - input_node = node.args[0] - sli = node.args[1] - inp = node.args[2] - - inp_flag = False - if type(inp) == torch.fx.node.Node and inp.target == operator.getitem: - new_args = list(copy.deepcopy(inp.args[1])) - for ind, val in enumerate(new_args): - if type(val) == int: - inp_flag = True - if val == -1: - new_args[ind] = slice(-1, None, None) - else: - new_args[ind] = slice(val, val + 1, None) - - if inp_flag: - with gm.graph.inserting_before(inp): - new_node = gm.graph.call_function( - operator.getitem, args=(inp.args[0], new_args) - ) - inp.replace_all_uses_with(new_node) - inp = new_node - - if type(sli) is not tuple: - sli = [sli] - - tmp_sli = [] - for x in sli: - if type(x) == int: - if x == -1: - tmp_sli.append(slice(-1, None, None)) - else: - tmp_sli.append(slice(x, x + 1, None)) - else: - tmp_sli.append(x) - sli = tmp_sli - - dimension = len(sli) - with gm.graph.inserting_before(node): - if dimension == 1: - start_node_0, _, end_node_0 = split_across( - gm, sli[0], input_node, dim=0, size=1 - ) - concat_node_0 = list_gen(start_node_0, end_node_0, inp, gm, 0) - elif dimension == 2: - start_node_0, mid_node_0, end_node_0 = split_across( - gm, sli[0], input_node, dim=0, size=2 - ) - start_node_1, _, end_node_1 = split_across( - gm, sli[1], mid_node_0, dim=1, size=2 - ) - concat_node_1 = list_gen(start_node_1, end_node_1, inp, gm, 1) - concat_node_0 = list_gen( - start_node_0, end_node_0, concat_node_1, gm, 0 - ) - elif dimension == 3: - start_node_0, mid_node_0, end_node_0 = split_across( - gm, sli[0], input_node, dim=0, size=3 - ) - start_node_1, mid_node_1, end_node_1 = split_across( - gm, sli[1], mid_node_0, dim=1, size=3 - ) - start_node_2, _, end_node_2 = split_across( - gm, sli[2], mid_node_1, dim=2, size=3 - ) - concat_node_2 = list_gen(start_node_2, end_node_2, inp, gm, 2) - concat_node_1 = list_gen( - start_node_1, end_node_1, concat_node_2, gm, 1 - ) - concat_node_0 = list_gen( - start_node_0, end_node_0, concat_node_1, gm, 0 - ) - elif dimension == 4: - start_node_0, mid_node_0, end_node_0 = split_across( - gm, sli[0], input_node, dim=0, size=4 - ) - start_node_1, mid_node_1, end_node_1 = split_across( - gm, sli[1], mid_node_0, dim=1, size=4 - ) - start_node_2, mid_node_2, end_node_2 = split_across( - gm, sli[2], mid_node_1, dim=2, size=4 - ) - start_node_3, _, end_node_3 = split_across( - gm, sli[3], mid_node_2, dim=3, size=4 - ) - concat_node_3 = list_gen(start_node_3, end_node_3, inp, gm, 3) - concat_node_2 = list_gen( - start_node_2, end_node_2, concat_node_3, gm, 2 - ) - concat_node_1 = list_gen( - start_node_1, end_node_1, concat_node_2, gm, 1 - ) - concat_node_0 = list_gen( - start_node_0, end_node_0, concat_node_1, gm, 0 - ) - else: - warnings.warn(f"setitem does not support dimension={dimension}") - continue - node.replace_all_uses_with(concat_node_0) - map_replace[input_node] = concat_node_0 - gm.graph.erase_node(node) - - gm.graph.lint() - gm.recompile() - return gm - - -def fix_reshape_batch_dim(mod: fx.GraphModule) -> fx.GraphModule: - """\ - TRT cannot reason about shape patterns like x.reshape(y.size(0), -1, 256), - since the dynamic shape of the reshape comes from the dynamic shape of - another node (y). The compilation will fail with various memory related - errors, depending on the size of the input tensor. 
- - This pass fixes the issue by finding this reshape pattern, checking that: - - x.size(0) == y.size(0) - - And then replaces reshape's batch size from y.size(0) to x.size(0). - """ - - def get_reshape_batch_size_as_node(maybe_reshape: fx.Node) -> Optional[fx.Node]: - """\ - Try to find the reshape op's batch size as an input node. - - Match below graph structure and return `node_y`: - node_x.reshape({"acc_out_ty": {"shape": (node_y, ...)}}) - """ - if ( - maybe_reshape.op != "call_function" - or maybe_reshape.target != acc_ops.reshape - ): - return None - shape = getattr(maybe_reshape.kwargs["acc_out_ty"], "shape", None) - if not shape: - return None - batch_size = shape[0] - if isinstance(batch_size, fx.Node): - return batch_size - return None - - def get_reshape_batch_size_inferred_source( - batch_size_node: fx.Node, - ) -> Optional[fx.Node]: - """\ - Given a node representing the batch size used for reshape op, we want - to know if it is coming from below pattern: - - batch_size_node = src.size()[0] - - or in IR graph: - - src -> size(input=_) -> getitem(input=_, idx=0) - ^ ~~~ batch_size_node - - If so, return `src`. Otherwise, return `None`. - """ - if ( - batch_size_node.op != "call_function" - or batch_size_node.target != acc_ops.getitem - or batch_size_node.kwargs["idx"] != 0 - ): - return None - maybe_size: fx.Node = batch_size_node.all_input_nodes[0] - if maybe_size.op != "call_function" or maybe_size.target != acc_ops.size: - return None - return maybe_size.all_input_nodes[0] - - maybe_reshape: fx.Node - for maybe_reshape in mod.graph.nodes: - reshape_batch_size: Optional[fx.Node] = get_reshape_batch_size_as_node( - maybe_reshape - ) - if not reshape_batch_size: - continue - reshape_batch_size_inferred_source: Optional[ - fx.Node - ] = get_reshape_batch_size_inferred_source(reshape_batch_size) - if not reshape_batch_size_inferred_source: - continue - - reshape_input: fx.Node = maybe_reshape.kwargs["input"] - if reshape_input == reshape_batch_size_inferred_source: - continue - - if not _is_batch_size_equal(reshape_input, reshape_batch_size_inferred_source): - continue - - _LOGGER.info( - f"{fix_reshape_batch_dim}: Found bad pattern: y.reshape((x, ...)). Reshape node: {maybe_reshape}" - ) - - # Step 1: create a node to compute batch size, using the tensor which - # is being reshaped: reshape_input.size()[0]. This batch size is now - # derived from reshape_input, the same node as the reshape op's input. 
- with mod.graph.inserting_before(maybe_reshape): - reshape_batch_size_2: fx.Node = maybe_reshape.graph.call_function( - acc_ops.getitem, - kwargs={ - "idx": 0, - "input": maybe_reshape.graph.call_function( - acc_ops.size, - kwargs={ - "input": reshape_input, - }, - ), - }, - ) - - # Step 2: update `maybe_reshape`'s shape argument to be - # (reshape_batch_size_2, *DONT_CARE_JUST_COPY_OVER) - maybe_reshape.kwargs = { - **maybe_reshape.kwargs, - "acc_out_ty": acc_utils.build_raw_tensor_meta( - shape=( - reshape_batch_size_2, - *(maybe_reshape.kwargs["acc_out_ty"].shape[1:]), - ) - ), - } - - mod.graph.eliminate_dead_code() - mod.recompile() - return mod - - -def _is_batch_size_equal(x: fx.Node, y: fx.Node) -> bool: - """\ - Check that x.size(0) == y.size(0) - """ - x_size, y_size = _get_shape(x), _get_shape(y) - return ( - x_size - and y_size - # now both are non-empty - and x_size[0] == y_size[0] - ) - - -def _get_shape(node: fx.Node) -> Optional[torch.Size]: - if ( - not getattr(node, "meta", None) - or not node.meta.get("tensor_meta", None) - or not getattr(node.meta["tensor_meta"], "shape", None) - ): - # shape info not available - return None - return node.meta["tensor_meta"].shape - - -@log_before_after -@validate_inference(atol=1e-3, rtol=1e-2) -def fix_clamp_numerical_limits_to_fp16( - mod: torch.fx.GraphModule, input: Input -) -> torch.fx.GraphModule: - MIN_FP16 = -65504.0 - MAX_FP16 = 65504.0 - for node in mod.graph.nodes: - if node.op == "call_function" and "clamp" in str(node.target): - input_kwargs = node.kwargs - if input_kwargs["min"] < MIN_FP16 and input_kwargs["max"] > MAX_FP16: - new_kwargs = { - "input": input_kwargs["input"], - "min": MIN_FP16, - "max": MAX_FP16, - } - node.kwargs = new_kwargs - - mod.recompile() - return mod diff --git a/py/torch_tensorrt/dynamo/passes/lower_basic_pass_aten.py b/py/torch_tensorrt/dynamo/passes/lower_basic_pass_aten.py deleted file mode 100644 index 00063c3e21..0000000000 --- a/py/torch_tensorrt/dynamo/passes/lower_basic_pass_aten.py +++ /dev/null @@ -1,525 +0,0 @@ -import logging -import operator -from typing import Any - -import torch -import torch.fx -from torch.fx.experimental.const_fold import split_const_subgraphs -from torch.fx.passes.infra.pass_base import PassResult - -_LOGGER = logging.getLogger(__name__) - -# Create an alias for module input type to avoid littering pyre-ignore for Any -# throughout the file. -Input = Any - - -def run_const_fold(traced_mod: torch.fx.GraphModule) -> torch.fx.GraphModule: - # Now we do constant folding on traced module. 
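For reference, the effect of fix_clamp_numerical_limits_to_fp16 above in plain tensor terms (the example bounds are illustrative; +/-65504 is the finite FP16 range):

    import torch

    x = torch.randn(4)
    # clamp bounds that overflow FP16:
    y = torch.clamp(x, min=-1.0e38, max=1.0e38)
    # what the rewritten graph computes instead:
    y_safe = torch.clamp(x, min=-65504.0, max=65504.0)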
- def skip_folding(node: torch.fx.Node): - if node.target == torch.ops.aten.sym_size: - return True - - const_split_mod = split_const_subgraphs( - traced_mod, skip_folding_node_fn=skip_folding - ) - const_split_mod.run_folding() - return const_split_mod - - -def replace_inplace_ops( - module: torch.fx.GraphModule, -) -> torch.fx.GraphModule: - """ - Remove this func after functionalization is workable - """ - modified = False - map_func = { - torch.ops.aten.relu_.default: torch.ops.aten.relu.default, - torch.ops.aten.hardtanh_.default: torch.ops.aten.hardtanh.default, - torch.ops.aten.add_.Tensor: torch.ops.aten.add.Tensor, - } - for n in module.graph.nodes: - if n.op == "call_function" and n.target in map_func.keys(): - modified = True - node = n - with module.graph.inserting_after(node): - new_args = node.args - new_node = module.graph.create_node( - "call_function", - map_func[node.target], - args=new_args, - kwargs=None, - ) - node.replace_all_uses_with(new_node) - module.graph.erase_node(node) - module.graph.eliminate_dead_code() - module.recompile() - return PassResult(module, modified) - - -def replace_native_layernorm_with_layernorm( - module: torch.fx.GraphModule, -) -> torch.fx.GraphModule: - modified = False - for n in module.graph.nodes: - if ( - n.op == "call_function" - and n.target == torch.ops.aten.native_layer_norm.default - ): - for v in n.users: - if v.op == "call_function" and v.target == operator.getitem: - if v.args[1] != 0: - raise RuntimeError( - f"Got args[{v.args[1]}]!!\n" - "layernorm can only generate output (args[0]), " - "not mean (args[1]) or std (args[2])!" - ) - new_op = torch.ops.aten.layer_norm.default - new_args = (*n.args, True) # cudnn_enable=True - modified = True - else: - continue - - with module.graph.inserting_after(v): - new_node = module.graph.create_node( - "call_function", - new_op, - args=new_args, - kwargs=v.kwargs, - ) - v.replace_all_uses_with(new_node) - - module.graph.eliminate_dead_code() - module.recompile() - return PassResult(module, modified) - - -def replace_transpose_mm_op_with_linear( - module: torch.fx.GraphModule, -) -> torch.fx.GraphModule: - modified = False - for n in module.graph.nodes: - if n.op == "call_function" and n.target == torch.ops.aten.t.default: - to_erase = [] - for v in n.users: - if v.op == "call_function" and v.target == torch.ops.aten.addmm.default: - new_op = torch.ops.aten.linear - bias, inp, _ = list(v.args) - weight = list(n.args)[0] - new_args = (inp, weight, bias) - modified = True - elif v.op == "call_function" and v.target == torch.ops.aten.mm.default: - new_op = torch.ops.aten.linear - inp, _ = list(v.args) - weight = list(n.args)[0] - new_args = (inp, weight, None) - modified = True - # this pass should be after `compose_bmm` - elif v.op == "call_function" and v.target == aten_compose_bmm_2d: - new_op = torch.ops.aten.linear - inp, _ = list(v.args) - weight = list(n.args)[0] - new_args = (inp, weight, None) - modified = True - else: - continue - - with module.graph.inserting_after(v): - new_node = module.graph.create_node( - "call_function", - new_op, - args=new_args, - kwargs=v.kwargs, - ) - v.replace_all_uses_with(new_node) - to_erase.append(v) - for v in to_erase: - module.graph.erase_node(v) - module.graph.eliminate_dead_code() - module.recompile() - # handle the linear with multiple dim, remove the extra reshape - for n in module.graph.nodes: - if n.op == "call_function" and n.target == torch.ops.aten.linear: - before = n.args[0] - after = next(iter(n.users)) - if (len(n.users) == 1 and 
after.target == torch.ops.aten.view.default) and ( - before.target == torch.ops.aten.view.default and len(before.users) == 1 - ): - real_input = before.args[0] - new_args = list(n.args) - new_args[0] = real_input - n.args = tuple(new_args) - after.replace_all_uses_with(n) - module.graph.eliminate_dead_code() - module.recompile() - - return PassResult(module, modified) - - -def replace_aten_op_with_indices(module: torch.fx.GraphModule) -> torch.fx.GraphModule: - modified = False - for n in module.graph.nodes: - if n.op == "call_function" and n.target in ( - torch.ops.aten.max_pool2d_with_indices.default, - torch.ops.aten.max_pool3d_with_indices.default, - torch.ops.aten.native_batch_norm.default, - torch.ops.aten._native_batch_norm_legit.default, - torch.ops.aten._native_batch_norm_legit_no_training.default, - ): - modified = True - if len(n.users) != 1: - raise RuntimeError( - f"{n.target} has users={len(n.users)}. We can only handle it with 1 user" - ) - if n.target == torch.ops.aten.max_pool2d_with_indices.default: - new_op = torch.ops.aten.max_pool2d - new_args = n.args - elif n.target == torch.ops.aten.max_pool3d_with_indices.default: - new_op = torch.ops.aten.max_pool3d - new_args = n.args - elif ( - n.target == torch.ops.aten.native_batch_norm.default - or n.target == torch.ops.aten._native_batch_norm_legit.default - ): - new_op = torch.ops.aten.batch_norm - new_args = list(n.args) - new_args.append(False) - new_args = tuple(new_args) - elif ( - n.target == torch.ops.aten._native_batch_norm_legit_no_training.default - ): - new_op = torch.ops.aten.batch_norm - new_args = list(n.args) - new_args.append(False) - # _native_batch_norm_legit_no_training doesn't take in a training arg (assumed to be false) - # but batchnorm takes in a training arg at position 5. - new_args.insert(5, False) - new_args = tuple(new_args) - - getitem_node = next(iter(n.users)) - with module.graph.inserting_after(getitem_node): - new_node = module.graph.create_node( - "call_function", - new_op, - args=new_args, - kwargs=n.kwargs, - ) - getitem_node.replace_all_uses_with(new_node) - module.graph.erase_node(getitem_node) - module.graph.eliminate_dead_code() - module.recompile() - return PassResult(module, modified) - - -def replace_aten_reshape_alias_with_replace( - module: torch.fx.GraphModule, -) -> torch.fx.GraphModule: - # The stride parameter is not used. Replace with reshape without stride - modified = False - for n in module.graph.nodes: - if n.op == "call_function" and n.target in ( - torch.ops.aten._reshape_alias.default, - ): - modified = True - node = n - with module.graph.inserting_after(node): - new_args = (node.args[0], node.args[1]) - new_node = module.graph.create_node( - "call_function", - torch.ops.aten.reshape, - args=new_args, - kwargs=None, - ) - node.replace_all_uses_with(new_node) - module.graph.erase_node(node) - break - module.graph.eliminate_dead_code() - module.recompile() - return PassResult(module, modified) - - -def remove_ops( - module: torch.fx.GraphModule, -) -> torch.fx.GraphModule: - """ - 1. Remove clone, _unsafe_view node. #TODO Remove this func after functionalization is workable - 2. 
Remove inefficient op getitem(index=slice) P561572458 - """ - modified = False - for n in module.graph.nodes: - if n.op == "call_function" and n.target in (torch.ops.aten.clone.default,): - modified = True - node = n - input_n = node.all_input_nodes[0] - node.replace_all_uses_with(input_n) - module.graph.eliminate_dead_code() - module.recompile() - for n in module.graph.nodes: - if n.op == "call_function" and n.target in ( - torch.ops.aten._unsafe_view.default, - ): - modified = True - node = n - with module.graph.inserting_after(node): - new_node = module.graph.create_node( - "call_function", - torch.ops.aten.reshape, - args=node.args, - kwargs=node.kwargs, - ) - node.replace_all_uses_with(new_node) - module.graph.erase_node(node) - module.graph.eliminate_dead_code() - module.recompile() - return PassResult(module, modified) - - -def aten_operator_getitem(*args): - return operator.getitem(*args) - - -def replace_builtin_ops( - module: torch.fx.GraphModule, -) -> torch.fx.GraphModule: - """ - To differential the same op in fx2ait as they are registered in the same dictionary - """ - - modified = False - for n in module.graph.nodes: - if n.op == "call_function" and n.target in (operator.getitem,): - modified = True - n.target = aten_operator_getitem - module.graph.eliminate_dead_code() - module.recompile() - - module.graph.eliminate_dead_code() - module.recompile() - return PassResult(module, modified) - - -############### -""" -Trace compose. For some ops, we do not want to decompose further but want coarse granularity -For ex: -1. bmm -2. chunk -3. getitem(input, idx=(slice(),slice()...)) -""" - - -def aten_compose_getitem_slice(input, list_args): - for _, args in enumerate(list_args): - input = torch.ops.aten.slice.Tensor(input, *args) - return input - - -def compose_getitem_slice( - module: torch.fx.GraphModule, -) -> torch.fx.GraphModule: - """ - combine decomposed getitem(input, idx=(slice(),slice()...)) - """ - - def match_pattern(module, node): - if node.op == "call_function" and node.target == torch.ops.aten.slice.Tensor: - holder = [] - holder.append(node) - while ( - len(node.users.keys()) == 1 - and next(iter(node.users)).target == torch.ops.aten.slice.Tensor - and node.args[1] + 1 == next(iter(node.users)).args[1] - ): - node = next(iter(node.users)) - holder.append(node) - if len(holder) == 1: - return (False,) - else: - return (True, holder) - return (False,) - - modified = False - for node in module.graph.nodes: - res = match_pattern(module, node) - if res[0]: - modified = True - holder = res[1] - input_n = holder[0].args[0] - last_n = holder[-1] - list_args = [] - for h_n in holder: - list_args.append(h_n.args[1:]) - - with module.graph.inserting_after(last_n): - new_args = (input_n, list_args) - new_node = module.graph.create_node( - "call_function", - aten_compose_getitem_slice, - args=new_args, - kwargs=None, - ) - last_n.replace_all_uses_with(new_node) - module.graph.eliminate_dead_code() - module.recompile() - return PassResult(module, modified) - - -def aten_compose_bmm_2d(flat_args_1, flat_args_2): - sym_size = torch.ops.aten.sym_size(flat_args_1, 0) - sym_size_1 = torch.ops.aten.sym_size(flat_args_1, 1) - sym_size_2 = torch.ops.aten.sym_size(flat_args_1, 2) - expand = torch.ops.aten.expand.default( - flat_args_1, [sym_size, sym_size_1, sym_size_2] - ) - view = torch.ops.aten.view.default(expand, [sym_size, sym_size_1, sym_size_2]) - sym_size_3 = torch.ops.aten.sym_size(flat_args_2, 0) - sym_size_4 = torch.ops.aten.sym_size(flat_args_2, 1) - expand_1 = 
torch.ops.aten.expand.default( - flat_args_2, [sym_size, sym_size_3, sym_size_4] - ) - view_1 = torch.ops.aten.view.default(expand_1, [sym_size, sym_size_3, sym_size_4]) - bmm = torch.ops.aten.bmm.default(view, view_1) - view_2 = torch.ops.aten.view.default(bmm, [sym_size, sym_size_1, sym_size_4]) - return view_2 - - -def aten_compose_bmm_3d(flat_args_1, flat_args_2): - sym_size = torch.ops.aten.sym_size(flat_args_1, 0) - sym_size_1 = torch.ops.aten.sym_size(flat_args_1, 1) - sym_size_2 = torch.ops.aten.sym_size(flat_args_1, 2) - expand = torch.ops.aten.expand.default( - flat_args_1, [sym_size, sym_size_1, sym_size_2] - ) - view = torch.ops.aten.view.default(expand, [sym_size, sym_size_1, sym_size_2]) - sym_size_3 = torch.ops.aten.sym_size(flat_args_2, 1) - sym_size_4 = torch.ops.aten.sym_size(flat_args_2, 2) - expand_1 = torch.ops.aten.expand.default( - flat_args_2, [sym_size, sym_size_3, sym_size_4] - ) - view_1 = torch.ops.aten.view.default(expand_1, [sym_size, sym_size_3, sym_size_4]) - bmm = torch.ops.aten.bmm.default(view, view_1) - view_2 = torch.ops.aten.view.default(bmm, [sym_size, sym_size_1, sym_size_4]) - return view_2 - - -def compose_bmm( - module: torch.fx.GraphModule, -) -> torch.fx.GraphModule: - """ - combine decomposed bmm (matmul) - """ - modified = False - for n in module.graph.nodes: - if n.op == "call_function" and n.target in (torch.ops.aten.bmm.default,): - modified = True - node = n - input_n = node.all_input_nodes[0] - other_n = node.all_input_nodes[1] - output = next(iter(node.users)) - input_input_n = input_n.all_input_nodes[0] - if ( - input_input_n.target != torch.ops.aten.expand.default - and input_n.target != torch.ops.aten.view.default - ): - raise RuntimeError( - "Bmm is addressed in fixed pattern. A new pattern is met!" - ) - real_input = input_input_n.all_input_nodes[0] - input_other_n = other_n.all_input_nodes[0] - if ( - input_other_n.target != torch.ops.aten.expand.default - and other_n.target != torch.ops.aten.view.default - ): - raise RuntimeError( - "Bmm is addressed in fixed pattern. A new pattern is met!" 
- ) - real_other = input_other_n.all_input_nodes[0] - if len(real_other.meta["val"].size()) == 2: - new_func = aten_compose_bmm_2d - if len(real_other.meta["val"].size()) == 3: - new_func = aten_compose_bmm_3d - - with module.graph.inserting_after(node): - new_args = (real_input, real_other) - new_node = module.graph.create_node( - "call_function", - new_func, - args=new_args, - kwargs=None, - ) - output.replace_all_uses_with(new_node) - - module.graph.eliminate_dead_code() - module.recompile() - return PassResult(module, modified) - - -def aten_compose_chunk(flat_args_1, chunk, dim): - sym_size = torch.ops.aten.sym_size(flat_args_1, dim) - add = operator.add(sym_size, chunk) - sub = operator.sub(add, 1) - floordiv = operator.floordiv(sub, chunk) - split = torch.ops.aten.split.Tensor(flat_args_1, floordiv, dim) - return split - - -def compose_chunk( - module: torch.fx.GraphModule, -) -> torch.fx.GraphModule: - """ - combine decomposed chunk - """ - - def match_pattern(module, node): - if node.op == "call_function" and node.target in (torch.ops.aten.split.Tensor,): - div = node.args[1] - input = node.args[0] - if isinstance(div, int): - return (False,) - if div.target != operator.floordiv: - return (False,) - else: - div_const = div.args[1] - sub = div.args[0] - if sub.target != operator.sub: - return (False,) - else: - add = sub.args[0] - if add.target != operator.add: - return (False,) - else: - add_const = add.args[1] - if add_const != div_const: - return (False,) - symsize = add.args[0] - if symsize.target != torch.ops.aten.sym_size: - return (False,) - else: - symsize_input = symsize.args[0] - dim = symsize.args[1] - if symsize_input != input: - return (False,) - - return (True, div_const, dim) - else: - return (False,) - - modified = False - for node in module.graph.nodes: - res = match_pattern(module, node) - if res[0]: - modified = True - with module.graph.inserting_after(node): - new_args = (node.args[0], res[1], res[2]) - new_node = module.graph.create_node( - "call_function", - aten_compose_chunk, - args=new_args, - kwargs=None, - ) - node.replace_all_uses_with(new_node) - - module.graph.eliminate_dead_code() - module.recompile() - return PassResult(module, modified) diff --git a/py/torch_tensorrt/dynamo/passes/lower_pass_manager_builder.py b/py/torch_tensorrt/dynamo/passes/lower_pass_manager_builder.py index 494f3922de..4c086fee41 100644 --- a/py/torch_tensorrt/dynamo/passes/lower_pass_manager_builder.py +++ b/py/torch_tensorrt/dynamo/passes/lower_pass_manager_builder.py @@ -14,11 +14,11 @@ from ..lower_setting import LowerSetting from torch_tensorrt.fx.observer import Observer -from ..passes.remove_duplicate_output_args import remove_duplicate_output_args -from .graph_opts import common_subexpression_elimination +from torch_tensorrt.fx.passes.remove_duplicate_output_args import remove_duplicate_output_args +from torch_tensorrt.fx.passes.graph_opts import common_subexpression_elimination from .pass_utils import extract_example_tensors_from_input -from .lower_basic_pass import ( # noqa +from torch_tensorrt.fx.passes.lower_basic_pass import ( # noqa fix_clamp_numerical_limits_to_fp16, fix_reshape_batch_dim, replace_mutable_op, diff --git a/py/torch_tensorrt/dynamo/passes/remove_duplicate_output_args.py b/py/torch_tensorrt/dynamo/passes/remove_duplicate_output_args.py deleted file mode 100644 index 84a522a3f0..0000000000 --- a/py/torch_tensorrt/dynamo/passes/remove_duplicate_output_args.py +++ /dev/null @@ -1,140 +0,0 @@ -#!/usr/bin/env python3 - -import dataclasses as dc -import 
logging -import operator -import typing as t - -import torch.fx as fx - -_LOGGER = logging.getLogger(__name__) - -RemoveDuplicateOutputArgsFunc = t.Callable[ - [ - fx.GraphModule, - t.Collection[str], - ], - t.Mapping[str, "RemoveDuplicateResult"], -] - - -def remove_duplicate_output_args( - top_level: fx.GraphModule, target_subnets: t.Collection[str] -) -> t.Mapping[str, "RemoveDuplicateResult"]: - """Removes duplicate output args. - - This pass removes duplicate output args from the target subnets and fixes - their uses in the top level module where the subnets are called. This pass - must be called after acc split on the top-level net and subsequent calls to - the acc trace on the subnets. - - This pass will change both the subnets and top level module. - - Returns: - a mapping of the target subnet name to its dedupcate result - """ - - processed_subnets = {} - for node in top_level.graph.nodes: # type: fx.Node - if node.op == "call_module" and node.name in target_subnets: - assert isinstance(node.target, str) - sub_gm = top_level.get_submodule(node.target) - assert isinstance(sub_gm, fx.GraphModule) - - replace_res = _remove_duplicate_output_args(sub_gm) - processed_subnets[node.name] = replace_res - if replace_res.replacement_map is None: - continue - sub_gm.recompile() - - needs_recompile = False - # iterate on the copy since we will be changing elements of node.users - for user in list(node.users): - idx = _ensure_proper_output_use(user, node) - idx_new = replace_res.replacement_map[idx] - if idx_new != idx: - user.args = (user.args[0], idx_new) - needs_recompile = True - - if needs_recompile: - top_level.recompile() - return processed_subnets - - -@dc.dataclass(frozen=True) -class RemoveDuplicateResult: - replacement_map: t.Optional[t.List[int]] - module: fx.GraphModule - - -def _ensure_proper_output_use(user: fx.Node, target_node: fx.Node) -> int: - """ - Ensures the node looks in proper form of calling the output of an fx2trt - splitter sub-net. Specifically: - - 1. op is call function, target: operator.getitem - 2. args is a 2-element tuple - 3. args[0] is the name of the subnet's output - 4. args[1] is the index into the subnet output tuple - - E.g.: - - %getitem_4 : [#users=1] = call_function[target=operator.getitem](args = (%_run_on_acc_1, 4), kwargs = {}) - - returns the index into the subnet output tuple - """ - assert ( - user.op == "call_function" - and user.target == operator.getitem - and len(user.args) == 2 - and isinstance(user.args[0], fx.Node) - and user.args[0].name == target_node.name - and isinstance(user.args[1], int) - ), f"Node is not a proper user of splitter output: {user.format_node()}" - - return user.args[1] - - -def _remove_duplicate_output_args(gm: fx.GraphModule) -> RemoveDuplicateResult: - output_nodes = [n for n in gm.graph.nodes if n.op == "output"] - assert ( - len(output_nodes) == 1 - ), f"Expecting exactly one `output` node, but got {len(output_nodes)}" - - changed = False - # arg node name to its index in the new output args tuple - name_to_idx: t.Dict[str, int] = {} - output_node = output_nodes[0] - - # Output op only uses its `args[0]`, and it does not have `kwargs`. - # https://pytorch.org/docs/stable/fx.html#torch.fx.Node - args: t.Sequence[t.Any] = output_node.args[0] - - # Only concern outselves to the case where the args is an iterable of fx.Node. - # Other return cases (e.g., a single value) is possible and we don't handle - # that in this pass. 
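The replacement_map semantics can be illustrated on a hypothetical subnet returning (a, b, a); the parent-side fix-up rewrites each getitem index through the map:

    # old output tuple: (a, b, a)  ->  new output tuple: (a, b)
    replacement_map = [0, 1, 0]        # old index -> surviving index
    old_getitem_indices = [0, 1, 2]    # parent-side getitem users
    new_getitem_indices = [replacement_map[i] for i in old_getitem_indices]
    assert new_getitem_indices == [0, 1, 0]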
- if not (isinstance(args, t.Iterable) and all(isinstance(a, fx.Node) for a in args)): - return RemoveDuplicateResult(replacement_map=None, module=gm) - - # Map old index of the arg node to the remaining node's idx, - # initialized to `i => i` - replacement_map: t.List[int] = list(range(len(args))) - args_new = [] - for idx, a in enumerate(args): - assert isinstance(a, fx.Node), f"Expecting fx.Node instance, but got: {type(a)}" - - if a.name not in name_to_idx: - args_new.append(a) - name_to_idx[a.name] = len(args_new) - 1 - else: - changed = True - _LOGGER.warning( - f"Replaced duplicate output arg '{a.name}': " - f"{idx} -> {name_to_idx[a.name]}" - ) - replacement_map[idx] = name_to_idx[a.name] - - output_node.args = (tuple(args_new),) - if changed: - gm.recompile() - return RemoveDuplicateResult(replacement_map, module=gm) diff --git a/py/torch_tensorrt/dynamo/test/core/test_input_tensor_spec.py b/py/torch_tensorrt/dynamo/test/core/test_input_tensor_spec.py index 89fbafe82b..fa5d444d02 100644 --- a/py/torch_tensorrt/dynamo/test/core/test_input_tensor_spec.py +++ b/py/torch_tensorrt/dynamo/test/core/test_input_tensor_spec.py @@ -4,7 +4,7 @@ import torch from torch.testing._internal.common_utils import run_tests, TestCase -from torch_tensorrt.dynamo import generate_input_specs, InputTensorSpec, LowerSetting +from torch_tensorrt.dynamo import InputTensorSpec, LowerSetting class TestTRTModule(TestCase): @@ -63,31 +63,5 @@ def test_from_tensors_with_dynamic_batch_size_different_batch_dims(self): tensor_shape[i] = batch_size self.assertSequenceEqual(tensor_shape, shape) - def test_generate_input_specs(self): - lower_setting = LowerSetting( - explicit_batch_dimension=False, opt_profile_replica=2 - ) - - # Implicit batch dim. - inputs = [torch.randn(1, 2, 3)] - specs = generate_input_specs(inputs, lower_setting) - for spec, tensor in zip(specs, inputs): - self._validate_spec(spec, tensor) - - # Explicit batch dim without additional inputs. - lower_setting.explicit_batch_dimension = True - specs = generate_input_specs(inputs, lower_setting) - for spec, tensor in zip(specs, inputs): - self._validate_spec(spec, tensor, dynamic_dims=[0]) - self.assertEqual(len(spec.shape_ranges), lower_setting.opt_profile_replica) - - # Explicit batch dim with additional inputs. 
- additional_inputs = [torch.randn(1, 1, 3)] - specs = generate_input_specs(inputs, lower_setting, additional_inputs) - for spec, tensor in zip(specs, inputs): - self._validate_spec(spec, tensor, dynamic_dims=[1]) - self.assertEqual(len(spec.shape_ranges), lower_setting.opt_profile_replica) - - if __name__ == "__main__": run_tests() diff --git a/py/torch_tensorrt/dynamo/test/core/test_trt_module.py b/py/torch_tensorrt/dynamo/test/core/test_trt_module.py index 195c5ad65c..baf98c8d7c 100644 --- a/py/torch_tensorrt/dynamo/test/core/test_trt_module.py +++ b/py/torch_tensorrt/dynamo/test/core/test_trt_module.py @@ -8,10 +8,8 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer from torch.testing._internal.common_utils import run_tests, TestCase -from torch_tensorrt.dynamo import InputTensorSpec, TRTInterpreter, TRTModule - -# from torch_tensorrt import TRTModuleNext -# from torch_tensorrt import Device +from torch_tensorrt.fx import TRTModule +from torch_tensorrt.dynamo import InputTensorSpec, TRTInterpreter from torch_tensorrt.dynamo.utils import LowerPrecision diff --git a/py/torch_tensorrt/dynamo/test/passes/test_fix_clamp_numerical_limits_to_fp16.py b/py/torch_tensorrt/dynamo/test/passes/test_fix_clamp_numerical_limits_to_fp16.py index 0a4805d45f..2fbac9cfb4 100644 --- a/py/torch_tensorrt/dynamo/test/passes/test_fix_clamp_numerical_limits_to_fp16.py +++ b/py/torch_tensorrt/dynamo/test/passes/test_fix_clamp_numerical_limits_to_fp16.py @@ -3,7 +3,7 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer -from torch_tensorrt.dynamo.passes.lower_basic_pass import ( +from torch_tensorrt.fx.passes.lower_basic_pass import ( fix_clamp_numerical_limits_to_fp16, ) diff --git a/py/torch_tensorrt/dynamo/test/passes/test_fix_reshape_batch_dim.py b/py/torch_tensorrt/dynamo/test/passes/test_fix_reshape_batch_dim.py index 27ea6d038c..bd04692ad5 100644 --- a/py/torch_tensorrt/dynamo/test/passes/test_fix_reshape_batch_dim.py +++ b/py/torch_tensorrt/dynamo/test/passes/test_fix_reshape_batch_dim.py @@ -8,7 +8,7 @@ import torch.nn as nn from torch.testing._internal.common_utils import run_tests, TestCase -from torch_tensorrt.dynamo.passes.lower_basic_pass import fix_reshape_batch_dim +from torch_tensorrt.fx.passes.lower_basic_pass import fix_reshape_batch_dim from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer _LOGGER = logging.getLogger(__name__) diff --git a/py/torch_tensorrt/dynamo/test/passes/test_fuse_permute_linear_trt.py b/py/torch_tensorrt/dynamo/test/passes/test_fuse_permute_linear_trt.py index f2a4a89b69..12fd25447d 100644 --- a/py/torch_tensorrt/dynamo/test/passes/test_fuse_permute_linear_trt.py +++ b/py/torch_tensorrt/dynamo/test/passes/test_fuse_permute_linear_trt.py @@ -5,7 +5,7 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.passes.lower_basic_pass import ( +from torch_tensorrt.fx.passes.lower_basic_pass import ( fuse_permute_linear, trt_transposed_linear, ) @@ -31,57 +31,57 @@ def forward(self, x): apply_passes=[fuse_permute_linear], ) - def test_fuse_permute_linear_keep_permute(self): - """ - Fusion while keep permute node since permute has more than one consumers - """ - - class TestModule(torch.nn.Module): - def __init__(self, in_features, out_features): - super().__init__() - self.linear = torch.nn.Linear(in_features, out_features) - - def forward(self, x): - y = x.permute(0, 2, 1) - return self.linear(y), y - - inputs = 
[torch.randn(6, 10, 20)] - a = TestModule(10, 30) - self.run_test( - TestModule(10, 30), - inputs, - {acc_ops.permute, trt_transposed_linear}, - apply_passes=[fuse_permute_linear], - ) - - # TODO: The following test has been disabled due to a bug in TRT 8.5.1.7 - # with self.linear2. Issue : https://github.com/pytorch/TensorRT/issues/1444 - @unittest.skip( - reason="test_multi_fuse_permute_linear has been disabled due to a bug in TRT 8.5.1.7 https://github.com/pytorch/TensorRT/issues/1444" - ) - def test_multi_fuse_permute_linear(self): - """ - Fusion when permute output is shared by multiple linears - """ - - class TestModule(torch.nn.Module): - def __init__(self, in_features, out_features): - super().__init__() - self.linear1 = torch.nn.Linear(in_features, out_features) - self.linear2 = torch.nn.Linear(in_features, out_features) - - def forward(self, x): - y = x.permute(0, 2, 1) - return self.linear1(y) + self.linear2(y) - - inputs = [torch.randn(8, 10, 20)] - a = TestModule(10, 30) - self.run_test( - TestModule(10, 30), - inputs, - {trt_transposed_linear}, - apply_passes=[fuse_permute_linear], - ) + # def test_fuse_permute_linear_keep_permute(self): + # """ + # Fusion while keep permute node since permute has more than one consumers + # """ + # + # class TestModule(torch.nn.Module): + # def __init__(self, in_features, out_features): + # super().__init__() + # self.linear = torch.nn.Linear(in_features, out_features) + # + # def forward(self, x): + # y = x.permute(0, 2, 1) + # return self.linear(y), y + # + # inputs = [torch.randn(6, 10, 20)] + # a = TestModule(10, 30) + # self.run_test( + # TestModule(10, 30), + # inputs, + # {acc_ops.permute, trt_transposed_linear}, + # apply_passes=[fuse_permute_linear], + # ) + # + # # TODO: The following test has been disabled due to a bug in TRT 8.5.1.7 + # # with self.linear2. 
Issue : https://github.com/pytorch/TensorRT/issues/1444 + # @unittest.skip( + # reason="test_multi_fuse_permute_linear has been disabled due to a bug in TRT 8.5.1.7 https://github.com/pytorch/TensorRT/issues/1444" + # ) + # def test_multi_fuse_permute_linear(self): + # """ + # Fusion when permute output is shared by multiple linears + # """ + # + # class TestModule(torch.nn.Module): + # def __init__(self, in_features, out_features): + # super().__init__() + # self.linear1 = torch.nn.Linear(in_features, out_features) + # self.linear2 = torch.nn.Linear(in_features, out_features) + # + # def forward(self, x): + # y = x.permute(0, 2, 1) + # return self.linear1(y) + self.linear2(y) + # + # inputs = [torch.randn(8, 10, 20)] + # a = TestModule(10, 30) + # self.run_test( + # TestModule(10, 30), + # inputs, + # {trt_transposed_linear}, + # apply_passes=[fuse_permute_linear], + # ) if __name__ == "__main__": diff --git a/py/torch_tensorrt/dynamo/test/passes/test_fuse_permute_matmul_trt.py b/py/torch_tensorrt/dynamo/test/passes/test_fuse_permute_matmul_trt.py index f48c759be7..fec3edbf9a 100644 --- a/py/torch_tensorrt/dynamo/test/passes/test_fuse_permute_matmul_trt.py +++ b/py/torch_tensorrt/dynamo/test/passes/test_fuse_permute_matmul_trt.py @@ -4,7 +4,7 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import param, parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.passes.lower_basic_pass import ( +from torch_tensorrt.fx.passes.lower_basic_pass import ( fuse_permute_matmul, trt_transposed_matmul, ) diff --git a/py/torch_tensorrt/dynamo/test/passes/test_graph_opts.py b/py/torch_tensorrt/dynamo/test/passes/test_graph_opts.py index 8e75fbd17e..c91c456eb3 100644 --- a/py/torch_tensorrt/dynamo/test/passes/test_graph_opts.py +++ b/py/torch_tensorrt/dynamo/test/passes/test_graph_opts.py @@ -6,7 +6,7 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer -from torch_tensorrt.dynamo.passes.graph_opts import common_subexpression_elimination +from torch_tensorrt.fx.passes.graph_opts import common_subexpression_elimination _LOGGER: logging.Logger = logging.getLogger(__name__) diff --git a/py/torch_tensorrt/dynamo/test/passes/test_multi_fuse_trt.py b/py/torch_tensorrt/dynamo/test/passes/test_multi_fuse_trt.py index d8edef52b9..af4279db2f 100644 --- a/py/torch_tensorrt/dynamo/test/passes/test_multi_fuse_trt.py +++ b/py/torch_tensorrt/dynamo/test/passes/test_multi_fuse_trt.py @@ -4,7 +4,7 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.passes.lower_basic_pass import ( +from torch_tensorrt.fx.passes.lower_basic_pass import ( fuse_permute_linear, fuse_permute_matmul, trt_transposed_linear, diff --git a/py/torch_tensorrt/dynamo/test/passes/test_remove_duplicate_output_args.py b/py/torch_tensorrt/dynamo/test/passes/test_remove_duplicate_output_args.py index 2ab06be627..1bb76c6691 100644 --- a/py/torch_tensorrt/dynamo/test/passes/test_remove_duplicate_output_args.py +++ b/py/torch_tensorrt/dynamo/test/passes/test_remove_duplicate_output_args.py @@ -5,7 +5,7 @@ import torch.fx as fx import torch.nn as nn -import torch_tensorrt.dynamo.passes.remove_duplicate_output_args as dedup +import torch_tensorrt.fx.passes.remove_duplicate_output_args as dedup from torch.testing._internal.common_utils import run_tests, 
TestCase _LOGGER = logging.getLogger(__name__) diff --git a/py/torch_tensorrt/dynamo/test/passes/test_setitem_trt.py b/py/torch_tensorrt/dynamo/test/passes/test_setitem_trt.py index d5fce3778d..ded67e97dc 100644 --- a/py/torch_tensorrt/dynamo/test/passes/test_setitem_trt.py +++ b/py/torch_tensorrt/dynamo/test/passes/test_setitem_trt.py @@ -3,7 +3,7 @@ from parameterized import parameterized from torch._dynamo.optimizations import backends from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.passes.lower_basic_pass import transform_setitem +from torch_tensorrt.fx.passes.lower_basic_pass import transform_setitem from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase diff --git a/py/torch_tensorrt/dynamo/test/quant/test_quant_trt.py b/py/torch_tensorrt/dynamo/test/quant/test_quant_trt.py index 146f5a6932..e3fa371e38 100644 --- a/py/torch_tensorrt/dynamo/test/quant/test_quant_trt.py +++ b/py/torch_tensorrt/dynamo/test/quant/test_quant_trt.py @@ -28,8 +28,9 @@ QuantizationTestCase, ) from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo import InputTensorSpec, TRTInterpreter, TRTModule -from torch_tensorrt.dynamo.passes.lower_basic_pass import run_const_fold +from torch_tensorrt.dynamo import InputTensorSpec, TRTInterpreter +from torch_tensorrt.fx import TRTModule +from torch_tensorrt.fx.passes.lower_basic_pass import run_const_fold from torch_tensorrt.fx.tracer.acc_tracer import acc_ops from torch_tensorrt.dynamo.utils import LowerPrecision diff --git a/py/torch_tensorrt/dynamo/test/trt_lower/test_fx2trt_lower.py b/py/torch_tensorrt/dynamo/test/trt_lower/test_fx2trt_lower.py index a626c739b0..4c4948ba2a 100644 --- a/py/torch_tensorrt/dynamo/test/trt_lower/test_fx2trt_lower.py +++ b/py/torch_tensorrt/dynamo/test/trt_lower/test_fx2trt_lower.py @@ -7,7 +7,7 @@ import torch.fx as fx import torch.nn as nn from torch_tensorrt.dynamo.lower import Lowerer, LowerSetting -from torch_tensorrt.dynamo.passes.lower_basic_pass import replace_mutable_op +from torch_tensorrt.fx.passes.lower_basic_pass import replace_mutable_op logger = logging.getLogger(__name__) diff --git a/py/torch_tensorrt/dynamo/test/trt_lower/test_observer_gpu.py b/py/torch_tensorrt/dynamo/test/trt_lower/test_observer_gpu.py index b067e93195..824cc66be9 100644 --- a/py/torch_tensorrt/dynamo/test/trt_lower/test_observer_gpu.py +++ b/py/torch_tensorrt/dynamo/test/trt_lower/test_observer_gpu.py @@ -4,7 +4,7 @@ import torch_tensorrt.fx.observer as ob from test_observer import execution_verifier, set_observer_callback_rethrow -from torch_tensorrt.dynamo.passes.lower_basic_pass import fuse_permute_linear +from torch_tensorrt.fx.passes.lower_basic_pass import fuse_permute_linear class ObserverGPUTests(TestCase): diff --git a/py/torch_tensorrt/dynamo/tools/common_fx2trt.py b/py/torch_tensorrt/dynamo/tools/common_fx2trt.py index 13dcbcda37..da09c00ab9 100644 --- a/py/torch_tensorrt/dynamo/tools/common_fx2trt.py +++ b/py/torch_tensorrt/dynamo/tools/common_fx2trt.py @@ -14,7 +14,7 @@ from torch.fx.passes.infra.pass_base import PassResult from torch.testing._internal.common_utils import TestCase from torch_tensorrt.dynamo import InputTensorSpec, TRTInterpreter -from torch_tensorrt.dynamo.passes.lower_basic_pass_aten import ( +from torch_tensorrt.fx.passes.lower_basic_pass_aten import ( compose_bmm, compose_chunk, compose_getitem_slice, diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 52daf6f09e..79779f604e 100644 --- 
a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -7,7 +7,7 @@ import torch from functorch import make_fx from functorch.experimental import functionalize -from torch_tensorrt.dynamo.passes.lower_basic_pass import ( +from torch_tensorrt.fx.passes.lower_basic_pass import ( replace_op_with_indices, run_const_fold, ) From b647b5df3a970f3faf8ee0a865b3adb846f10a66 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Fri, 7 Apr 2023 14:59:13 -0700 Subject: [PATCH 21/45] chore: Linter fixes Signed-off-by: Dheeraj Peri --- py/torch_tensorrt/dynamo/lower_setting.py | 5 ++++- .../dynamo/passes/lower_pass_manager_builder.py | 4 +++- py/torch_tensorrt/dynamo/test/core/test_input_tensor_spec.py | 1 + 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/py/torch_tensorrt/dynamo/lower_setting.py b/py/torch_tensorrt/dynamo/lower_setting.py index 9ca22d5ef8..890b99d18b 100644 --- a/py/torch_tensorrt/dynamo/lower_setting.py +++ b/py/torch_tensorrt/dynamo/lower_setting.py @@ -5,7 +5,10 @@ from torch.fx.passes.pass_manager import PassManager from .input_tensor_spec import InputTensorSpec -from torch_tensorrt.fx.passes.lower_basic_pass import fuse_permute_linear, fuse_permute_matmul +from torch_tensorrt.fx.passes.lower_basic_pass import ( + fuse_permute_linear, + fuse_permute_matmul, +) from .utils import LowerPrecision diff --git a/py/torch_tensorrt/dynamo/passes/lower_pass_manager_builder.py b/py/torch_tensorrt/dynamo/passes/lower_pass_manager_builder.py index 4c086fee41..ea2e331619 100644 --- a/py/torch_tensorrt/dynamo/passes/lower_pass_manager_builder.py +++ b/py/torch_tensorrt/dynamo/passes/lower_pass_manager_builder.py @@ -14,7 +14,9 @@ from ..lower_setting import LowerSetting from torch_tensorrt.fx.observer import Observer -from torch_tensorrt.fx.passes.remove_duplicate_output_args import remove_duplicate_output_args +from torch_tensorrt.fx.passes.remove_duplicate_output_args import ( + remove_duplicate_output_args, +) from torch_tensorrt.fx.passes.graph_opts import common_subexpression_elimination from .pass_utils import extract_example_tensors_from_input diff --git a/py/torch_tensorrt/dynamo/test/core/test_input_tensor_spec.py b/py/torch_tensorrt/dynamo/test/core/test_input_tensor_spec.py index fa5d444d02..65c7ea4158 100644 --- a/py/torch_tensorrt/dynamo/test/core/test_input_tensor_spec.py +++ b/py/torch_tensorrt/dynamo/test/core/test_input_tensor_spec.py @@ -63,5 +63,6 @@ def test_from_tensors_with_dynamic_batch_size_different_batch_dims(self): tensor_shape[i] = batch_size self.assertSequenceEqual(tensor_shape, shape) + if __name__ == "__main__": run_tests() From 685bba19b92039bbddd37e52b0304daff1514387 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Fri, 7 Apr 2023 15:33:09 -0700 Subject: [PATCH 22/45] chore: add workspace_size, disable_tf32, sparse_weights settings Signed-off-by: Dheeraj Peri --- py/torch_tensorrt/dynamo/fx2trt.py | 14 +++++++++++--- py/torch_tensorrt/dynamo/lower.py | 14 ++++++++++---- py/torch_tensorrt/dynamo/lower_setting.py | 6 ++++-- 3 files changed, 25 insertions(+), 9 deletions(-) diff --git a/py/torch_tensorrt/dynamo/fx2trt.py b/py/torch_tensorrt/dynamo/fx2trt.py index 4dae6e542b..ab8c7c1930 100644 --- a/py/torch_tensorrt/dynamo/fx2trt.py +++ b/py/torch_tensorrt/dynamo/fx2trt.py @@ -153,9 +153,10 @@ def validate_conversion(self): def run( self, - max_workspace_size=1 << 25, + workspace_size=0, lower_precision=LowerPrecision.FP16, sparse_weights=False, + disable_tf32=False, force_fp32_output=False, strict_type_constraints=False, 
algorithm_selector=None, @@ -166,7 +167,7 @@ def run( """ Build TensorRT engine with some configs. Args: - max_workspace_size: set to the maximum size we can afford for temporary buffer + workspace_size: Amount of memory used by TensorRT to store intermediate buffers within an operation. lower_precision: the precision model layers are running on (TensorRT will choose the best perforamnce precision). sparse_weights: allow the builder to examine weights and use optimized functions when weights have suitable sparsity force_fp32_output: force output to be fp32 @@ -206,7 +207,11 @@ def run( build_engine_start_time = datetime.now() builder_config = self.builder.create_builder_config() - builder_config.max_workspace_size = max_workspace_size + + if (workspace_size != 0): + builder_config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_size) + + builder_config.workspace_size = workspace_size cache = None if timing_cache: @@ -231,6 +236,9 @@ def run( if sparse_weights: builder_config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS) + if disable_tf32: + builder_config.clear_flag(trt.BuilderFlag.TF32) + if strict_type_constraints: builder_config.set_flag(trt.BuilderFlag.STRICT_TYPES) diff --git a/py/torch_tensorrt/dynamo/lower.py b/py/torch_tensorrt/dynamo/lower.py index fbe6ec4b43..581aad50fc 100644 --- a/py/torch_tensorrt/dynamo/lower.py +++ b/py/torch_tensorrt/dynamo/lower.py @@ -30,9 +30,11 @@ def compile( module: nn.Module, inputs, device=torch.device(torch.cuda.current_device()), + disable_tf32=False, + sparse_weights=False, enabled_precisions=set(), min_block_size: int = 3, - max_workspace_size=1 << 25, + workspace_size=0, verbose_log=False, timing_cache_prefix="", save_timing_cache=False, @@ -48,7 +50,7 @@ def compile( module: Original module for lowering. input: Input for module. min_block_size: Minimal number of nodes for an accelerated submodule - max_workspace_size: Maximum size of workspace given to TensorRT. + workspace_size: Maximum size of workspace given to TensorRT. verbose_log: Enable verbose log for TensorRT if set True. timing_cache_prefix: Timing cache file name for timing cache used by fx2trt. save_timing_cache: Update timing cache with current timing cache data if set to True. 
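Stand-alone, the new builder-config calls from the fx2trt.py hunk above look like this (logger severity and the 1 GiB figure are illustrative):

    import tensorrt as trt

    builder = trt.Builder(trt.Logger(trt.Logger.WARNING))
    config = builder.create_builder_config()
    # pool-based replacement for the deprecated max_workspace_size attribute
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)
    # TF32 is enabled by default; clearing the flag mirrors disable_tf32=True
    config.clear_flag(trt.BuilderFlag.TF32)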
@@ -73,7 +75,9 @@ def compile( lower_setting = LowerSetting( device=device, min_block_size=min_block_size, - max_workspace_size=max_workspace_size, + disable_tf32=disable_tf32, + sparse_weights=sparse_weights, + workspace_size=workspace_size, lower_precision=lower_precision, verbose_log=verbose_log, timing_cache_prefix=timing_cache_prefix, @@ -128,8 +132,10 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult: ) interp_result: TRTInterpreterResult = interpreter.run( - max_workspace_size=self.lower_setting.max_workspace_size, + workspace_size=self.lower_setting.workspace_size, lower_precision=self.lower_setting.lower_precision, + sparse_weights=self.lower_setting.sparse_weights, + disable_tf32=self.lower_setting.disable_tf32, strict_type_constraints=self.lower_setting.strict_type_constraints, algorithm_selector=algo_selector, timing_cache=cache_data, diff --git a/py/torch_tensorrt/dynamo/lower_setting.py b/py/torch_tensorrt/dynamo/lower_setting.py index 890b99d18b..be0ada8b55 100644 --- a/py/torch_tensorrt/dynamo/lower_setting.py +++ b/py/torch_tensorrt/dynamo/lower_setting.py @@ -29,6 +29,8 @@ class LowerSettingBasic: lower_precision: LowerPrecision = LowerPrecision.FP32 device: torch.device = torch.device(torch.cuda.current_device()) min_block_size: int = 3 + disable_tf32: bool = False + sparse_weights: bool = False ast_rewriter_allow_list: Optional[Set[Type[nn.Module]]] = None leaf_module_list: Optional[Set[Type[nn.Module]]] = None verbose_profile: bool = False @@ -43,7 +45,7 @@ class LowerSetting(LowerSettingBasic): input_specs: Specs for inputs to engine, can either be a single size or a range defined by Min, Optimal, Max sizes. explicit_precision: Use explicit precision during lowering. - max_workspace_size: The maximum workspace size. The maximum GPU temporary + workspace_size: The maximum workspace size. The maximum GPU temporary memory which the TensorRT engine can use at execution time. strict_type_constraints: Require TensorRT engine to strictly follow data type setting at execution time. 
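Taken together, the three new settings map onto the TensorRT builder configuration as in the condensed sketch below. It mirrors the fx2trt.py hunks from this patch; the helper name `apply_build_settings` and the elided builder/config construction are illustrative assumptions, not upstream API:

    import tensorrt as trt

    def apply_build_settings(
        builder_config, workspace_size=0, disable_tf32=False, sparse_weights=False
    ):
        # workspace_size == 0 leaves TensorRT's default workspace pool limit untouched
        if workspace_size != 0:
            builder_config.set_memory_pool_limit(
                trt.MemoryPoolType.WORKSPACE, workspace_size
            )

        # structured-sparsity kernels must be opted into explicitly
        if sparse_weights:
            builder_config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS)

        # TF32 is enabled by default for FP32 layers on Ampere and newer GPUs;
        # clearing the flag forces full-precision FP32 math
        if disable_tf32:
            builder_config.clear_flag(trt.BuilderFlag.TF32)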
@@ -73,7 +75,7 @@ class LowerSetting(LowerSettingBasic): input_specs: List[InputTensorSpec] = dc.field(default_factory=list) explicit_batch_dimension: bool = True explicit_precision: bool = False - max_workspace_size: int = 1 << 30 + workspace_size: int = 0 strict_type_constraints: bool = False customized_fuse_pass: PassManager = dc.field( default_factory=lambda: PassManager.build_from_passlist([]) From 5cbf46d84eacf5dcaf390ddba5b6f7befa92a7bb Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Fri, 7 Apr 2023 15:37:38 -0700 Subject: [PATCH 23/45] chore: Linter fixes Signed-off-by: Dheeraj Peri --- py/torch_tensorrt/dynamo/fx2trt.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/py/torch_tensorrt/dynamo/fx2trt.py b/py/torch_tensorrt/dynamo/fx2trt.py index ab8c7c1930..b3a68f7a28 100644 --- a/py/torch_tensorrt/dynamo/fx2trt.py +++ b/py/torch_tensorrt/dynamo/fx2trt.py @@ -208,8 +208,10 @@ def run( builder_config = self.builder.create_builder_config() - if (workspace_size != 0): - builder_config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_size) + if workspace_size != 0: + builder_config.set_memory_pool_limit( + trt.MemoryPoolType.WORKSPACE, workspace_size + ) builder_config.workspace_size = workspace_size From 24793009d599e178b5c097aaa7c5657707e2eb33 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Fri, 7 Apr 2023 15:41:30 -0700 Subject: [PATCH 24/45] feat: Add new `convert_module` function - Improve overall documentation and commenting, improve code delineation and separation of functionality --- py/torch_tensorrt/dynamo/_compiler.py | 65 ++----------------- py/torch_tensorrt/dynamo/backends.py | 60 ++++++++++++++++- py/torch_tensorrt/dynamo/conversion.py | 48 ++++++++++++++ py/torch_tensorrt/dynamo/lowering/__init__.py | 2 +- .../dynamo/lowering/_partition.py | 49 ++++++++++++-- 5 files changed, 155 insertions(+), 69 deletions(-) create mode 100644 py/torch_tensorrt/dynamo/conversion.py diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 9804c691ff..50472264cd 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -1,26 +1,17 @@ import torch import logging +from typing import Sequence, Any from torch_tensorrt import EngineCapability, Device -from torch_tensorrt.dynamo.lowering._partition import partition from torch_tensorrt.dynamo import create_backend -from torch_tensorrt.fx.fx2trt import ( - InputTensorSpec, - TRTInterpreter, -) -import tensorrt as trt - -from torch_tensorrt.fx.trt_module import TRTModule -from torch_tensorrt.fx.utils import LowerPrecision - logger = logging.getLogger(__name__) def compile( gm: torch.Module, - example_inputs, + example_inputs: Sequence[Any], *, device=Device._current_device(), disable_tf32=False, @@ -30,7 +21,7 @@ def compile( debug=False, capability=EngineCapability.default, num_avg_timing_iters=1, - workspace_size=0, + workspace_size=20 << 30, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, @@ -63,52 +54,8 @@ def compile( ) model = torch.compile(gm, backend=custom_backend) - # Ensure compilation - model(example_inputs) - - return model - - -def compile_logic(gm: torch.fx.GraphModule, example_inputs): - partitioned = partition(gm) - - precision = LowerPrecision.FP32 - - def get_submod_inputs(mod, submod, inputs): - """Helper function to get inputs to submodule""" - acc_inputs = None - def get_input(self, inputs): - nonlocal acc_inputs - acc_inputs 
= inputs + # Ensure compilation occurs by calling the function with provided inputs + model(*example_inputs) - handle = submod.register_forward_pre_hook(get_input) - mod(*inputs) - handle.remove() - return acc_inputs - - for name, _ in partitioned.named_children(): - submod = getattr(partitioned, name) - - # Get submodule inputs - acc_inputs = get_submod_inputs(partitioned, submod, example_inputs) - - # Create TRT Module from submodule - interp = TRTInterpreter( - submod, - InputTensorSpec.from_tensors(acc_inputs), - explicit_batch_dimension=True, - logger_level=trt.Logger.VERBOSE, - ) - - r = interp.run( - max_workspace_size=20 << 30, - lower_precision=precision, - profiling_verbosity=trt.ProfilingVerbosity.VERBOSE, - ) - trt_mod = TRTModule(*r) - - # Replace FX Module with TRT Module - setattr(partitioned, name, trt_mod) - - return partitioned + return model diff --git a/py/torch_tensorrt/dynamo/backends.py b/py/torch_tensorrt/dynamo/backends.py index 2d010c7f2a..55680dd221 100644 --- a/py/torch_tensorrt/dynamo/backends.py +++ b/py/torch_tensorrt/dynamo/backends.py @@ -6,11 +6,22 @@ from torch_tensorrt import EngineCapability, Device from torch_tensorrt.dynamo import compile +from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions +from torch_tensorrt.dynamo.lowering._partition import partition, get_submod_inputs +from torch_tensorrt.dynamo.conversion import convert_module + from torch._dynamo.backends.common import fake_tensor_unsupported from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler -from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions +from torch_tensorrt.fx.fx2trt import ( + InputTensorSpec, + TRTInterpreter, +) +import tensorrt as trt + +from torch_tensorrt.fx.trt_module import TRTModule +from torch_tensorrt.fx.utils import LowerPrecision logger = logging.getLogger(__name__) @@ -97,7 +108,7 @@ def fx_dynamo_backend( ): """Helper function to manage translation of FX module to TRT engines""" try: - trt_compiled = compile(gm, example_inputs) + trt_compiled = compile_module(gm, example_inputs) return trt_compiled except: traceback.print_exc() @@ -106,3 +117,48 @@ def fx_dynamo_backend( + "Returning GraphModule forward instead." 
) return gm.forward + + +def compile_module( + gm: torch.fx.GraphModule, + example_inputs, + debug: bool = False, + workspace_size: int = 20 << 30, + precision: LowerPrecision = LowerPrecision.FP32, +) -> torch.fx.GraphModule: + """Convert an FX module to a TRT module + Args: + module: FX GraphModule to convert + inputs: Inputs to the module + debug: Whether to print out verbose debugging information + workspace_size: Maximum workspace TRT is allowed to use for the module + precision: Model Layer precision + Returns: + TRTModule or TRTModuleNext + """ + # Partition module into components that can be TRT-accelerated + partitioned_module = partition(gm) + + # Iterate over all components that can be accelerated + # Generate the corresponding TRT Module for those + for name, _ in partitioned_module.named_children(): + submodule = getattr(partitioned_module, name) + + # Get submodule inputs + submodule_inputs = get_submod_inputs( + partitioned_module, submodule, example_inputs + ) + + # Create TRT Module from submodule + trt_mod = convert_module( + submodule, + submodule_inputs, + debug=debug, + workspace_size=workspace_size, + precision=precision, + ) + + # Replace FX Module with TRT Module + setattr(partitioned_module, name, trt_mod) + + return partitioned_module diff --git a/py/torch_tensorrt/dynamo/conversion.py b/py/torch_tensorrt/dynamo/conversion.py new file mode 100644 index 0000000000..4f495dad4b --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion.py @@ -0,0 +1,48 @@ +from typing import Sequence, Union +import torch +from torch_tensorrt.fx.trt_module import TRTModule +from torch_tensorrt import TRTModuleNext +from torch_tensorrt.fx.fx2trt import ( + InputTensorSpec, + TRTInterpreter, +) +from torch_tensorrt.fx.utils import LowerPrecision + +import tensorrt as trt + + +def convert_module( + module: torch.fx.GraphModule, + inputs: Sequence[torch.Tensor], + debug: bool = False, + workspace_size: int = 20 << 30, + precision: LowerPrecision = LowerPrecision.FP32, +) -> Union[TRTModuleNext, TRTModule]: + """Convert an FX module to a TRT module + Args: + module: FX GraphModule to convert + inputs: Sequence of Tensors representing inputs to the module + debug: Whether to print out verbose debugging information + workspace_size: Maximum workspace TRT is allowed to use for the module + precision: Model Layer precision + Returns: + TRTModule or TRTModuleNext + """ + interp = TRTInterpreter( + module, + InputTensorSpec.from_tensors(inputs), + explicit_batch_dimension=True, + logger_level=(trt.Logger.VERBOSE if debug else trt.Logger.WARNING), + ) + + r = interp.run( + max_workspace_size=workspace_size, + lower_precision=precision, + profiling_verbosity=( + trt.ProfilingVerbosity.VERBOSE + if debug + else trt.ProfilingVerbosity.LAYER_NAMES_ONLY + ), + ) + + return TRTModule(*r) diff --git a/py/torch_tensorrt/dynamo/lowering/__init__.py b/py/torch_tensorrt/dynamo/lowering/__init__.py index a57579d4ca..930cd17fb6 100644 --- a/py/torch_tensorrt/dynamo/lowering/__init__.py +++ b/py/torch_tensorrt/dynamo/lowering/__init__.py @@ -1,2 +1,2 @@ from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions -from torch_tensorrt.dynamo.lowering._partition import partition +from torch_tensorrt.dynamo.lowering._partition import partition, get_submod_inputs diff --git a/py/torch_tensorrt/dynamo/lowering/_partition.py b/py/torch_tensorrt/dynamo/lowering/_partition.py index d96450f41d..13ba9c19a7 100644 --- a/py/torch_tensorrt/dynamo/lowering/_partition.py +++ 
b/py/torch_tensorrt/dynamo/lowering/_partition.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Dict, Optional, Sequence import torch @@ -12,7 +12,7 @@ class TorchTensorRTOperatorSupport(OperatorSupport): - """Class to determine whether the aten operators have converters""" + """Class to determine whether operators within a module are supported""" def __init__(self, support_dict=None): super().__init__(support_dict) @@ -38,7 +38,7 @@ def is_node_supported( return False - def print_support_overview(self, num_trt_blocks=None): + def print_support_overview(self, num_trt_blocks: Optional[int] = None): if num_trt_blocks is not None: print(f"Number of TensorRT-Accelerated Subgraphs: {num_trt_blocks}\n") @@ -51,9 +51,20 @@ def print_support_overview(self, num_trt_blocks=None): print(node_name) -def partition(gm: torch.fx.GraphModule, verbose=True): +def partition( + gm: torch.fx.GraphModule, + verbose: bool = True, + max_num_trt_engines: int = MAX_NUM_TRT_ENGINES, +) -> torch.fx.GraphModule: """Partition an FX GraphModule with aten ops into TRT engines - Partitioning is based on operator support + Partitioning is based on converter operator support + + Args: + gm: FX GraphModule to partition + verbose: Bool representing whether to print operator support + max_num_trt_engines: Maximum number of allowed TRT engines in partitioning + Returns: + torch.fx.GraphModule """ supported_ops = TorchTensorRTOperatorSupport() partitioner = CapabilityBasedPartitioner(gm, supported_ops) @@ -62,10 +73,10 @@ def partition(gm: torch.fx.GraphModule, verbose=True): # exceeds a specified threshold partitions = partitioner.propose_partitions() num_blocks = len(partitions) - if num_blocks > MAX_NUM_TRT_ENGINES: + if num_blocks > max_num_trt_engines: raise AssertionError( f"The graph module has {num_blocks} TRT Engines which is larger than the " - + f"threshold={MAX_NUM_TRT_ENGINES}. Falling back to non-TRT module." + + f"threshold={max_num_trt_engines}. Falling back to non-TRT module." 
) # Fuse partitions and display overview of supported/unsupported operators @@ -76,3 +87,27 @@ def partition(gm: torch.fx.GraphModule, verbose=True): supported_ops.print_support_overview(num_blocks) return fused_graph + + +def get_submod_inputs( + mod: torch.fx.GraphModule, submod: torch.fx.GraphModule, inputs +) -> Sequence[torch.Tensor]: + """Helper function to get inputs to a Torch submodule + + Args: + mod: Parent FX GraphModule + submod: Child FX GraphModule + inputs: Sample inputs to parent module + Returns: + Sequence of Tensors representing inputs to child module + """ + acc_inputs = None + + def get_input(self, inputs): + nonlocal acc_inputs + acc_inputs = inputs + + handle = submod.register_forward_pre_hook(get_input) + mod(*inputs) + handle.remove() + return acc_inputs From eea388414ebbd8399aebe103e7052f53b453a337 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Fri, 7 Apr 2023 22:42:58 -0700 Subject: [PATCH 25/45] fix: Improve `torch_tensorrt` Dynamo path - Add dedicated settings and defaults files to centralize data and improve code readability, as well as reduce duplication of code - Improve documentation of functions, types, and comments - Rework logic to make compiler more uniform with existing torch tensorrt compilers, while retaining key Dynamo keywords needed for compilation via the torch.compile path --- py/torch_tensorrt/dynamo/_compiler.py | 60 ++++++--- py/torch_tensorrt/dynamo/_defaults.py | 7 ++ py/torch_tensorrt/dynamo/_settings.py | 17 +++ py/torch_tensorrt/dynamo/backends.py | 116 ++++++++---------- .../dynamo/lowering/_partition.py | 4 +- 5 files changed, 114 insertions(+), 90 deletions(-) create mode 100644 py/torch_tensorrt/dynamo/_defaults.py create mode 100644 py/torch_tensorrt/dynamo/_settings.py diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 50472264cd..43b80dedde 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -1,10 +1,18 @@ import torch import logging -from typing import Sequence, Any +import torch_tensorrt +from typing import Sequence, Any from torch_tensorrt import EngineCapability, Device - from torch_tensorrt.dynamo import create_backend +from torch_tensorrt.fx.utils import LowerPrecision + +from torch_tensorrt.dynamo._defaults import ( + PRECISION, + DEBUG, + MAX_WORKSPACE_SIZE, +) + logger = logging.getLogger(__name__) @@ -18,10 +26,10 @@ def compile( sparse_weights=False, enabled_precisions=set(), refit=False, - debug=False, + debug=DEBUG, capability=EngineCapability.default, num_avg_timing_iters=1, - workspace_size=20 << 30, + workspace_size=MAX_WORKSPACE_SIZE, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, @@ -31,26 +39,38 @@ def compile( min_block_size=3, torch_executed_ops=[], torch_executed_modules=[], + **kwargs, ): + + logger.warn( + "The Dynamo backend is an experimental feature, for which only the " + + "following arguments are supported: " + + "{enabled_precisions, debug, workspace_size, max_num_trt_engines}" + ) + + if ( + torch.float16 in enabled_precisions + or torch_tensorrt.dtype.half in enabled_precisions + ): + lower_precision = LowerPrecision.FP16 + elif ( + torch.float32 in enabled_precisions + or torch_tensorrt.dtype.float in enabled_precisions + ): + lower_precision = LowerPrecision.FP32 + elif len(enabled_precisions) == 0: + logger.info(f"No precision specified, defaulting to {PRECISION}") + lower_precision = PRECISION + else: + raise 
ValueError( + f"Precision {enabled_precisions} not supported in the Dynamo Path" + ) + custom_backend = create_backend( - device=device, - disable_tf32=disable_tf32, - sparse_weights=sparse_weights, - enabled_precisions=enabled_precisions, - refit=refit, + precision=lower_precision, debug=debug, - capability=capability, - num_avg_timing_iters=num_avg_timing_iters, workspace_size=workspace_size, - dla_sram_size=dla_sram_size, - dla_local_dram_size=dla_local_dram_size, - dla_global_dram_size=dla_global_dram_size, - calibrator=calibrator, - truncate_long_and_double=truncate_long_and_double, - require_full_compilation=require_full_compilation, - min_block_size=min_block_size, - torch_executed_ops=torch_executed_ops, - torch_executed_modules=torch_executed_modules, + **kwargs, ) model = torch.compile(gm, backend=custom_backend) diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py new file mode 100644 index 0000000000..814331e158 --- /dev/null +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -0,0 +1,7 @@ +from torch_tensorrt.fx.utils import LowerPrecision + + +PRECISION = LowerPrecision.FP32 +DEBUG = False +MAX_WORKSPACE_SIZE = 20 << 30 +MAX_NUM_TRT_ENGINES = 10 diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py new file mode 100644 index 0000000000..39693651e7 --- /dev/null +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -0,0 +1,17 @@ +from dataclasses import dataclass + +from torch_tensorrt.fx.utils import LowerPrecision +from torch_tensorrt.dynamo._defaults import ( + PRECISION, + DEBUG, + MAX_WORKSPACE_SIZE, + MAX_NUM_TRT_ENGINES, +) + + +@dataclass(frozen=True) +class CompilationSettings: + precision: LowerPrecision = (PRECISION,) + debug: bool = (DEBUG,) + workspace_size: int = (MAX_WORKSPACE_SIZE,) + max_num_trt_engines: int = (MAX_NUM_TRT_ENGINES,) diff --git a/py/torch_tensorrt/dynamo/backends.py b/py/torch_tensorrt/dynamo/backends.py index 55680dd221..096da7b316 100644 --- a/py/torch_tensorrt/dynamo/backends.py +++ b/py/torch_tensorrt/dynamo/backends.py @@ -3,9 +3,14 @@ import traceback from functools import partial import torch._dynamo as td -from torch_tensorrt import EngineCapability, Device -from torch_tensorrt.dynamo import compile +from torch_tensorrt.dynamo._defaults import ( + PRECISION, + DEBUG, + MAX_WORKSPACE_SIZE, + MAX_NUM_TRT_ENGINES, +) +from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions from torch_tensorrt.dynamo.lowering._partition import partition, get_submod_inputs from torch_tensorrt.dynamo.conversion import convert_module @@ -14,55 +19,38 @@ from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler -from torch_tensorrt.fx.fx2trt import ( - InputTensorSpec, - TRTInterpreter, -) -import tensorrt as trt - -from torch_tensorrt.fx.trt_module import TRTModule from torch_tensorrt.fx.utils import LowerPrecision logger = logging.getLogger(__name__) def create_backend( - input_signature=None, - device=Device._current_device(), - disable_tf32=False, - sparse_weights=False, - enabled_precisions=set(), - refit=False, - debug=False, - capability=EngineCapability.default, - num_avg_timing_iters=1, - workspace_size=20 << 30, - dla_sram_size=1048576, - dla_local_dram_size=1073741824, - dla_global_dram_size=536870912, - calibrator=None, - truncate_long_and_double=False, - require_full_compilation=False, - min_block_size=3, - torch_executed_ops=[], - torch_executed_modules=[], + precision: 
LowerPrecision = PRECISION, + debug: bool = DEBUG, + workspace_size: int = MAX_WORKSPACE_SIZE, + max_num_trt_engines: int = MAX_NUM_TRT_ENGINES, + **kwargs ): - logger.warn( - "The Dynamo backend is an experimental feature, for which the " - + "following arguments are unsupported: " - + "{input_signature, disable_tf32, sparse_weights, refit, capability, " - + "num_avg_timing_iters, dla_sram_size, dla_local_dram_size, " - + "dla_global_dram_size, calibrator, truncate_long_and_double, " - + "require_full_compilation, min_block_size, torch_executed_ops, " - + "torch_executed_modules}" + """Create torch.compile backend given specified arguments + + Args: + precision: + debug: Whether to print out verbose debugging information + workspace_size: Maximum workspace TRT is allowed to use for the module + precision: Model Layer precision + Returns: + Backend for torch.compile + """ + settings = CompilationSettings( + debug=debug, + precision=precision, + workspace_size=workspace_size, + max_num_trt_engines=max_num_trt_engines, ) return partial( tensorrt_backend, - debug=debug, - enabled_precisions=enabled_precisions, - device=device, - workspace_size=workspace_size, + settings=settings, ) @@ -71,19 +59,12 @@ def create_backend( def tensorrt_backend( gm: torch.Module, sample_inputs, - *, - debug=False, - enabled_precisions=set(), - device=Device._current_device(), - workspace_size=20 << 30, + settings: CompilationSettings = CompilationSettings(), ): custom_backend = partial( fx_dynamo_backend, - debug=debug, - enabled_precisions=enabled_precisions, - device=device, - workspace_size=workspace_size, + settings=settings, ) # Invoke AOTAutograd to translate operators to aten @@ -100,15 +81,15 @@ def tensorrt_backend( def fx_dynamo_backend( gm: torch.fx.GraphModule, example_inputs, - *, - debug=False, - enabled_precisions=set(), - device=Device._current_device(), - workspace_size=20 << 30, + settings: CompilationSettings = CompilationSettings(), ): """Helper function to manage translation of FX module to TRT engines""" try: - trt_compiled = compile_module(gm, example_inputs) + trt_compiled = compile_module( + gm, + example_inputs, + settings=settings, + ) return trt_compiled except: traceback.print_exc() @@ -122,22 +103,23 @@ def fx_dynamo_backend( def compile_module( gm: torch.fx.GraphModule, example_inputs, - debug: bool = False, - workspace_size: int = 20 << 30, - precision: LowerPrecision = LowerPrecision.FP32, + settings: CompilationSettings = CompilationSettings(), ) -> torch.fx.GraphModule: - """Convert an FX module to a TRT module + """Compile an FX module + + Includes: Partitioning + Conversion Phases + Args: module: FX GraphModule to convert inputs: Inputs to the module - debug: Whether to print out verbose debugging information - workspace_size: Maximum workspace TRT is allowed to use for the module - precision: Model Layer precision + settings: Compilation settings Returns: - TRTModule or TRTModuleNext + Compiled FX GraphModule """ # Partition module into components that can be TRT-accelerated - partitioned_module = partition(gm) + partitioned_module = partition( + gm, verbose=settings.debug, max_num_trt_engines=settings.max_num_trt_engines + ) # Iterate over all components that can be accelerated # Generate the corresponding TRT Module for those @@ -153,9 +135,9 @@ def compile_module( trt_mod = convert_module( submodule, submodule_inputs, - debug=debug, - workspace_size=workspace_size, - precision=precision, + debug=settings.debug, + workspace_size=settings.workspace_size, + 
precision=settings.precision, ) # Replace FX Module with TRT Module diff --git a/py/torch_tensorrt/dynamo/lowering/_partition.py b/py/torch_tensorrt/dynamo/lowering/_partition.py index 13ba9c19a7..a4f2b79da0 100644 --- a/py/torch_tensorrt/dynamo/lowering/_partition.py +++ b/py/torch_tensorrt/dynamo/lowering/_partition.py @@ -2,15 +2,13 @@ import torch +from torch_tensorrt.dynamo._defaults import MAX_NUM_TRT_ENGINES from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner from torch.fx.passes.operator_support import OperatorSupport from torch_tensorrt.fx.converter_registry import CONVERTERS -MAX_NUM_TRT_ENGINES = 10 - - class TorchTensorRTOperatorSupport(OperatorSupport): """Class to determine whether operators within a module are supported""" From aa0dda88d2f7cd6374baffd244d0c37765c7f02d Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Sat, 8 Apr 2023 15:06:06 -0700 Subject: [PATCH 26/45] fix: Move key functions, fix bugs - Improve overall functionality, fix bugs - Move functions into __init__.py - Improve overall documentation, comments, function header typing, and code organization --- py/torch_tensorrt/dynamo/__init__.py | 121 +++++++++++++++++- py/torch_tensorrt/dynamo/_compiler.py | 81 ------------ py/torch_tensorrt/dynamo/backends.py | 61 ++------- .../dynamo/lowering/_partition.py | 20 ++- 4 files changed, 146 insertions(+), 137 deletions(-) delete mode 100644 py/torch_tensorrt/dynamo/_compiler.py diff --git a/py/torch_tensorrt/dynamo/__init__.py b/py/torch_tensorrt/dynamo/__init__.py index fd036ffa8e..1de768571f 100644 --- a/py/torch_tensorrt/dynamo/__init__.py +++ b/py/torch_tensorrt/dynamo/__init__.py @@ -1,2 +1,119 @@ -from _compiler import compile -from backends import create_backend +import torch +import logging +import torch_tensorrt +from functools import partial + +from typing import Sequence, Any +from torch_tensorrt import EngineCapability, Device +from torch_tensorrt.fx.utils import LowerPrecision + +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.backends import tensorrt_backend +from torch_tensorrt.dynamo._defaults import ( + PRECISION, + DEBUG, + MAX_WORKSPACE_SIZE, + MAX_NUM_TRT_ENGINES, +) + + +logger = logging.getLogger(__name__) + + +def compile( + gm: torch.nn.Module, + example_inputs: Sequence[Any], + *, + device=Device._current_device(), + disable_tf32=False, + sparse_weights=False, + enabled_precisions=set(), + refit=False, + debug=DEBUG, + capability=EngineCapability.default, + num_avg_timing_iters=1, + workspace_size=MAX_WORKSPACE_SIZE, + dla_sram_size=1048576, + dla_local_dram_size=1073741824, + dla_global_dram_size=536870912, + calibrator=None, + truncate_long_and_double=False, + require_full_compilation=False, + min_block_size=3, + torch_executed_ops=[], + torch_executed_modules=[], + **kwargs, +): + + logger.warn( + "The Dynamo backend is an experimental feature, for which only the " + + "following arguments are supported: " + + "{enabled_precisions, debug, workspace_size, max_num_trt_engines}" + ) + + if ( + torch.float16 in enabled_precisions + or torch_tensorrt.dtype.half in enabled_precisions + ): + lower_precision = LowerPrecision.FP16 + elif ( + torch.float32 in enabled_precisions + or torch_tensorrt.dtype.float in enabled_precisions + ): + lower_precision = LowerPrecision.FP32 + elif len(enabled_precisions) == 0: + logger.info(f"No precision specified, defaulting to {PRECISION}") + lower_precision = PRECISION + else: + raise ValueError( + 
f"Precision {enabled_precisions} not supported in the Dynamo Path" + ) + + custom_backend = create_backend( + precision=lower_precision, + debug=debug, + workspace_size=workspace_size, + **kwargs, + ) + + model = torch.compile(gm, backend=custom_backend) + + # Ensure compilation occurs by calling the function with provided inputs + model(*example_inputs) + + return model + + +from torch_tensorrt.fx.utils import LowerPrecision + +logger = logging.getLogger(__name__) + + +def create_backend( + precision: LowerPrecision = PRECISION, + debug: bool = DEBUG, + workspace_size: int = MAX_WORKSPACE_SIZE, + max_num_trt_engines: int = MAX_NUM_TRT_ENGINES, + **kwargs, +): + """Create torch.compile backend given specified arguments + + Args: + precision: + debug: Whether to print out verbose debugging information + workspace_size: Maximum workspace TRT is allowed to use for the module + precision: Model Layer precision + Returns: + Backend for torch.compile + """ + settings = CompilationSettings( + debug=debug, + precision=precision, + workspace_size=workspace_size, + max_num_trt_engines=max_num_trt_engines, + ) + + return partial( + tensorrt_backend, + settings=settings, + ) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py deleted file mode 100644 index 43b80dedde..0000000000 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ /dev/null @@ -1,81 +0,0 @@ -import torch -import logging -import torch_tensorrt - -from typing import Sequence, Any -from torch_tensorrt import EngineCapability, Device -from torch_tensorrt.dynamo import create_backend -from torch_tensorrt.fx.utils import LowerPrecision - -from torch_tensorrt.dynamo._defaults import ( - PRECISION, - DEBUG, - MAX_WORKSPACE_SIZE, -) - - -logger = logging.getLogger(__name__) - - -def compile( - gm: torch.Module, - example_inputs: Sequence[Any], - *, - device=Device._current_device(), - disable_tf32=False, - sparse_weights=False, - enabled_precisions=set(), - refit=False, - debug=DEBUG, - capability=EngineCapability.default, - num_avg_timing_iters=1, - workspace_size=MAX_WORKSPACE_SIZE, - dla_sram_size=1048576, - dla_local_dram_size=1073741824, - dla_global_dram_size=536870912, - calibrator=None, - truncate_long_and_double=False, - require_full_compilation=False, - min_block_size=3, - torch_executed_ops=[], - torch_executed_modules=[], - **kwargs, -): - - logger.warn( - "The Dynamo backend is an experimental feature, for which only the " - + "following arguments are supported: " - + "{enabled_precisions, debug, workspace_size, max_num_trt_engines}" - ) - - if ( - torch.float16 in enabled_precisions - or torch_tensorrt.dtype.half in enabled_precisions - ): - lower_precision = LowerPrecision.FP16 - elif ( - torch.float32 in enabled_precisions - or torch_tensorrt.dtype.float in enabled_precisions - ): - lower_precision = LowerPrecision.FP32 - elif len(enabled_precisions) == 0: - logger.info(f"No precision specified, defaulting to {PRECISION}") - lower_precision = PRECISION - else: - raise ValueError( - f"Precision {enabled_precisions} not supported in the Dynamo Path" - ) - - custom_backend = create_backend( - precision=lower_precision, - debug=debug, - workspace_size=workspace_size, - **kwargs, - ) - - model = torch.compile(gm, backend=custom_backend) - - # Ensure compilation occurs by calling the function with provided inputs - model(*example_inputs) - - return model diff --git a/py/torch_tensorrt/dynamo/backends.py b/py/torch_tensorrt/dynamo/backends.py index 096da7b316..ad8a14fd65 100644 --- 
a/py/torch_tensorrt/dynamo/backends.py +++ b/py/torch_tensorrt/dynamo/backends.py @@ -1,15 +1,9 @@ +from typing import Sequence import torch -import logging import traceback from functools import partial import torch._dynamo as td -from torch_tensorrt.dynamo._defaults import ( - PRECISION, - DEBUG, - MAX_WORKSPACE_SIZE, - MAX_NUM_TRT_ENGINES, -) from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions from torch_tensorrt.dynamo.lowering._partition import partition, get_submod_inputs @@ -19,49 +13,14 @@ from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler -from torch_tensorrt.fx.utils import LowerPrecision - -logger = logging.getLogger(__name__) - - -def create_backend( - precision: LowerPrecision = PRECISION, - debug: bool = DEBUG, - workspace_size: int = MAX_WORKSPACE_SIZE, - max_num_trt_engines: int = MAX_NUM_TRT_ENGINES, - **kwargs -): - """Create torch.compile backend given specified arguments - - Args: - precision: - debug: Whether to print out verbose debugging information - workspace_size: Maximum workspace TRT is allowed to use for the module - precision: Model Layer precision - Returns: - Backend for torch.compile - """ - settings = CompilationSettings( - debug=debug, - precision=precision, - workspace_size=workspace_size, - max_num_trt_engines=max_num_trt_engines, - ) - - return partial( - tensorrt_backend, - settings=settings, - ) - @td.register_backend(name="tensorrt") @fake_tensor_unsupported def tensorrt_backend( - gm: torch.Module, - sample_inputs, + gm: torch.nn.Module, + sample_inputs: Sequence[torch.Tensor], settings: CompilationSettings = CompilationSettings(), ): - custom_backend = partial( fx_dynamo_backend, settings=settings, @@ -80,10 +39,18 @@ def tensorrt_backend( @fake_tensor_unsupported def fx_dynamo_backend( gm: torch.fx.GraphModule, - example_inputs, + example_inputs: Sequence[torch.Tensor], settings: CompilationSettings = CompilationSettings(), ): - """Helper function to manage translation of FX module to TRT engines""" + """Helper function to manage translation of FX module to TRT engines + + Args: + module: FX GraphModule to convert + inputs: Inputs to the module + settings: Compilation settings + Returns: + Compiled FX GraphModule + """ try: trt_compiled = compile_module( gm, @@ -102,7 +69,7 @@ def fx_dynamo_backend( def compile_module( gm: torch.fx.GraphModule, - example_inputs, + example_inputs: Sequence[torch.Tensor], settings: CompilationSettings = CompilationSettings(), ) -> torch.fx.GraphModule: """Compile an FX module diff --git a/py/torch_tensorrt/dynamo/lowering/_partition.py b/py/torch_tensorrt/dynamo/lowering/_partition.py index a4f2b79da0..cbd4904515 100644 --- a/py/torch_tensorrt/dynamo/lowering/_partition.py +++ b/py/torch_tensorrt/dynamo/lowering/_partition.py @@ -38,15 +38,19 @@ def is_node_supported( def print_support_overview(self, num_trt_blocks: Optional[int] = None): if num_trt_blocks is not None: - print(f"Number of TensorRT-Accelerated Subgraphs: {num_trt_blocks}\n") + print(f"\nNumber of TensorRT-Accelerated Subgraphs: {num_trt_blocks}") - print("Supported Nodes:") + print("\nSupported Nodes:") for node_name in self.supported_operators: - print(node_name) + print("-", node_name) - print("\nUnsupported Nodes:") - for node_name in self.unsupported_operators: - print(node_name) + if len(self.unsupported_operators) != 0: + print("\nUnsupported Nodes:") + for node_name in self.unsupported_operators: + print("-", node_name) 
+ print("\n") + else: + print("\nAll Nodes Supported\n") def partition( @@ -88,7 +92,9 @@ def partition( def get_submod_inputs( - mod: torch.fx.GraphModule, submod: torch.fx.GraphModule, inputs + mod: torch.fx.GraphModule, + submod: torch.fx.GraphModule, + inputs: Sequence[torch.Tensor], ) -> Sequence[torch.Tensor]: """Helper function to get inputs to a Torch submodule From a6d3a64e3f187d2e97797aa18cdb335087bdd33d Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Mon, 10 Apr 2023 10:12:48 -0700 Subject: [PATCH 27/45] chore: refactor device related code Signed-off-by: Dheeraj Peri --- py/torch_tensorrt/_Device.py | 31 ++++++++++++++++++--------- py/torch_tensorrt/_enums.py | 3 ++- py/torch_tensorrt/dynamo/fx2trt.py | 2 -- py/torch_tensorrt/dynamo/lower.py | 14 +++++++++++- py/torch_tensorrt/ts/_compile_spec.py | 2 ++ 5 files changed, 38 insertions(+), 14 deletions(-) diff --git a/py/torch_tensorrt/_Device.py b/py/torch_tensorrt/_Device.py index 0662e17aa1..a954a8f919 100644 --- a/py/torch_tensorrt/_Device.py +++ b/py/torch_tensorrt/_Device.py @@ -1,11 +1,16 @@ import torch -from torch_tensorrt import _enums +# from torch_tensorrt import _enums +import tensorrt as trt from torch_tensorrt import logging -from torch_tensorrt import _C - import warnings +try: + from torch_tensorrt import _C +except: + warnings.warn("Unable to import _C extension of Torch-TensorRT. Some methods might be unavailable. You can ignore this error if you're \ + not using any functions dependent on internal C++ APIs") + class Device(object): """ @@ -51,7 +56,7 @@ def __init__(self, *args, **kwargs): ) else: (self.device_type, id) = Device._parse_device_str(args[0]) - if self.device_type == _enums.DeviceType.GPU: + if self.device_type == trt.DeviceType.GPU: self.gpu_id = id else: self.dla_core = id @@ -64,7 +69,7 @@ def __init__(self, *args, **kwargs): elif len(args) == 0: if "gpu_id" in kwargs or "dla_core" in kwargs: if "dla_core" in kwargs: - self.device_type = _enums.DeviceType.DLA + self.device_type = trt.DeviceType.DLA self.dla_core = kwargs["dla_core"] if "gpu_id" in kwargs: self.gpu_id = kwargs["gpu_id"] @@ -76,7 +81,7 @@ def __init__(self, *args, **kwargs): ) else: self.gpu_id = kwargs["gpu_id"] - self.device_type = _enums.DeviceType.GPU + self.device_type = trt.DeviceType.GPU else: raise ValueError( "Either gpu_id or dla_core or both must be defined if no string with device specs is provided as an arg" @@ -97,7 +102,7 @@ def __init__(self, *args, **kwargs): def __str__(self) -> str: return ( "Device(type={}, gpu_id={}".format(self.device_type, self.gpu_id) + ")" - if self.device_type == _enums.DeviceType.GPU + if self.device_type == trt.DeviceType.GPU else ", dla_core={}, allow_gpu_fallback={}".format( self.dla_core, self.allow_gpu_fallback ) @@ -105,7 +110,13 @@ def __str__(self) -> str: def _to_internal(self) -> _C.Device: internal_dev = _C.Device() - internal_dev.device_type = self.device_type + if (self.device_type == trt.DeviceType.GPU): + internal_dev.device_type = _C.DeviceType.GPU + elif (self.device_type == trt.DeviceType.DLA): + internal_dev.device_type = _C.DeviceType.DLA + else: + raise ValueError("Invalid DeviceType detected while parsing the Device class") + internal_dev.gpu_id = self.gpu_id internal_dev.dla_core = self.dla_core internal_dev.allow_gpu_fallback = self.allow_gpu_fallback @@ -136,6 +147,6 @@ def _parse_device_str(s): s = s.lower() spec = s.split(":") if spec[0] == "gpu" or spec[0] == "cuda": - return (_enums.DeviceType.GPU, int(spec[1])) + return (trt.DeviceType.GPU, int(spec[1])) 
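    # Device strings handled here take the form "gpu:<id>", "cuda:<id>", or "dla:<id>";
    # spec[1] is parsed as the integer GPU id (or DLA core id for "dla").
    # (Descriptive comment added for clarity; not part of the upstream patch.)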
elif spec[0] == "dla":
-            return (_enums.DeviceType.DLA, int(spec[1]))
+            return (trt.DeviceType.DLA, int(spec[1]))
diff --git a/py/torch_tensorrt/_enums.py b/py/torch_tensorrt/_enums.py
index bc9ed42df4..63dffceb9d 100644
--- a/py/torch_tensorrt/_enums.py
+++ b/py/torch_tensorrt/_enums.py
@@ -1 +1,2 @@
-from torch_tensorrt._C import dtype, DeviceType, EngineCapability, TensorFormat
+from torch_tensorrt._C import dtype, EngineCapability, TensorFormat
+from tensorrt import DeviceType
diff --git a/py/torch_tensorrt/dynamo/fx2trt.py b/py/torch_tensorrt/dynamo/fx2trt.py
index b3a68f7a28..1971e219e0 100644
--- a/py/torch_tensorrt/dynamo/fx2trt.py
+++ b/py/torch_tensorrt/dynamo/fx2trt.py
@@ -213,8 +213,6 @@ def run(
                 trt.MemoryPoolType.WORKSPACE, workspace_size
             )

-        builder_config.workspace_size = workspace_size
-
         cache = None
         if timing_cache:
             cache_file = numpy.array(timing_cache)
diff --git a/py/torch_tensorrt/dynamo/lower.py b/py/torch_tensorrt/dynamo/lower.py
index 581aad50fc..4e6ceebbd4 100644
--- a/py/torch_tensorrt/dynamo/lower.py
+++ b/py/torch_tensorrt/dynamo/lower.py
@@ -20,7 +20,7 @@
 from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer
 from torch_tensorrt.fx.trt_module import TRTModule
 from .utils import LowerPrecision
-
+from torch_tensorrt._Device import Device

 logger = logging.getLogger(__name__)

 Input = Sequence[Any]
@@ -64,6 +64,7 @@ def compile(
             "The experimental unified runtime only supports explicit batch. Please make sure to set explicit_batch_dimension=True when use_experimental_fx_rt=True"
         )

+    # Parse precision into LowerPrecision
     lower_precision = LowerPrecision.FP32
     if torch.float16 in enabled_precisions:
         lower_precision = LowerPrecision.FP16
@@ -71,6 +72,17 @@ def compile(
         lower_precision = LowerPrecision.FP32
     else:
         raise ValueError(f"Precision {enabled_precisions} not supported on FX")
+
+    # Parse device
+    if isinstance(device, Device):
+        if (device.gpu_id != -1):
+            device = torch.device(device.gpu_id)
+        else:
+            raise ValueError("Invalid GPU ID provided for the CUDA device provided")
+    elif isinstance(device, torch.device):
+        device = device
+    else:
+        raise ValueError("Invalid device provided.
Supported options: torch.device | torch_tensorrt.Device") lower_setting = LowerSetting( device=device, diff --git a/py/torch_tensorrt/ts/_compile_spec.py b/py/torch_tensorrt/ts/_compile_spec.py index b29d386118..1632952f39 100644 --- a/py/torch_tensorrt/ts/_compile_spec.py +++ b/py/torch_tensorrt/ts/_compile_spec.py @@ -10,6 +10,7 @@ import warnings from copy import deepcopy from torch_tensorrt.ts.ts_input import TSInput +# from torch_tensorrt.ts.ts_device import TSDevice def _internal_input_to_torch_class_input(i: _C.Input) -> torch.classes.tensorrt._Input: @@ -124,6 +125,7 @@ def _parse_device(device_info: Any) -> _C.Device: return info elif isinstance(device_info, Device): + # ts_device = TSDevice(gpu_id=device_info.gpu_id, dla_core=device_info.dla_core, allow_gpu_fallback=device_info.allow_gpu_fallback) return device_info._to_internal() elif isinstance(device_info, torch.device): return (Device._from_torch_device(device_info))._to_internal() From 5390259a3ed57bf62ace2e69a6c5bc4906e463cd Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Mon, 10 Apr 2023 10:13:57 -0700 Subject: [PATCH 28/45] chore: Linter fixes Signed-off-by: Dheeraj Peri --- py/torch_tensorrt/_Device.py | 14 +++++++++----- py/torch_tensorrt/dynamo/lower.py | 9 ++++++--- py/torch_tensorrt/ts/_compile_spec.py | 1 + 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/py/torch_tensorrt/_Device.py b/py/torch_tensorrt/_Device.py index a954a8f919..4a53ad8885 100644 --- a/py/torch_tensorrt/_Device.py +++ b/py/torch_tensorrt/_Device.py @@ -8,8 +8,10 @@ try: from torch_tensorrt import _C except: - warnings.warn("Unable to import _C extension of Torch-TensorRT. Some methods might be unavailable. You can ignore this error if you're \ - not using any functions dependent on internal C++ APIs") + warnings.warn( + "Unable to import _C extension of Torch-TensorRT. Some methods might be unavailable. You can ignore this error if you're \ + not using any functions dependent on internal C++ APIs" + ) class Device(object): @@ -110,12 +112,14 @@ def __str__(self) -> str: def _to_internal(self) -> _C.Device: internal_dev = _C.Device() - if (self.device_type == trt.DeviceType.GPU): + if self.device_type == trt.DeviceType.GPU: internal_dev.device_type = _C.DeviceType.GPU - elif (self.device_type == trt.DeviceType.DLA): + elif self.device_type == trt.DeviceType.DLA: internal_dev.device_type = _C.DeviceType.DLA else: - raise ValueError("Invalid DeviceType detected while parsing the Device class") + raise ValueError( + "Invalid DeviceType detected while parsing the Device class" + ) internal_dev.gpu_id = self.gpu_id internal_dev.dla_core = self.dla_core diff --git a/py/torch_tensorrt/dynamo/lower.py b/py/torch_tensorrt/dynamo/lower.py index 4e6ceebbd4..af81313f2f 100644 --- a/py/torch_tensorrt/dynamo/lower.py +++ b/py/torch_tensorrt/dynamo/lower.py @@ -21,6 +21,7 @@ from torch_tensorrt.fx.trt_module import TRTModule from .utils import LowerPrecision from torch_tensorrt._Device import Device + logger = logging.getLogger(__name__) Input = Sequence[Any] @@ -72,17 +73,19 @@ def compile( lower_precision = LowerPrecision.FP32 else: raise ValueError(f"Precision {enabled_precisions} not supported on FX") - + # Parse device if isinstance(device, Device): - if (device.gpu_id != -1): + if device.gpu_id != -1: device = torch.device(device.gpu_id) else: raise ValueError("Invalid GPU ID provided for the CUDA device provided") elif isinstance(device, torch.device): device = device else: - raise ValueError("Invalid device provided. 
Supported options: torch.device | torch_tensorrt.Device") + raise ValueError( + "Invalid device provided. Supported options: torch.device | torch_tensorrt.Device" + ) lower_setting = LowerSetting( device=device, diff --git a/py/torch_tensorrt/ts/_compile_spec.py b/py/torch_tensorrt/ts/_compile_spec.py index 1632952f39..f7258a5066 100644 --- a/py/torch_tensorrt/ts/_compile_spec.py +++ b/py/torch_tensorrt/ts/_compile_spec.py @@ -10,6 +10,7 @@ import warnings from copy import deepcopy from torch_tensorrt.ts.ts_input import TSInput + # from torch_tensorrt.ts.ts_device import TSDevice From 87e4c7741601b09018f7c8aa9fe32cfd10537c5c Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Mon, 10 Apr 2023 12:19:58 -0700 Subject: [PATCH 29/45] chore: linter fixes Signed-off-by: Dheeraj Peri --- py/torch_tensorrt/ts/_compile_spec.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/py/torch_tensorrt/ts/_compile_spec.py b/py/torch_tensorrt/ts/_compile_spec.py index f7258a5066..f17a9fa5bf 100644 --- a/py/torch_tensorrt/ts/_compile_spec.py +++ b/py/torch_tensorrt/ts/_compile_spec.py @@ -10,8 +10,7 @@ import warnings from copy import deepcopy from torch_tensorrt.ts.ts_input import TSInput - -# from torch_tensorrt.ts.ts_device import TSDevice +import tensorrt as trt def _internal_input_to_torch_class_input(i: _C.Input) -> torch.classes.tensorrt._Input: @@ -78,20 +77,24 @@ def _parse_enabled_precisions(precisions: Any) -> Set: def _parse_device_type(device: Any) -> _enums.DeviceType: if isinstance(device, torch.device): if device.type == "cuda": - return _enums.DeviceType.gpu + return _C.DeviceType.gpu else: ValueError( "Got a device type other than GPU or DLA (type: " + str(device.type) + ")" ) - elif isinstance(device, _enums.DeviceType): + elif isinstance(device, _C.DeviceType): return device + elif isinstance(device, trt.DeviceType): + if device == trt.DeviceType.DLA: + return _C.DeviceType.DLA + return _C.DeviceType.GPU elif isinstance(device, str): if device == "gpu" or device == "GPU": - return _enums.DeviceType.gpu + return _C.DeviceType.GPU elif device == "dla" or device == "DLA": - return _enums.DeviceType.dla + return _C.DeviceType.DLA else: ValueError( "Got a device type other than GPU or DLA (type: " + str(device) + ")" @@ -109,7 +112,6 @@ def _parse_device(device_info: Any) -> _C.Device: if "device_type" not in device_info: raise KeyError("Device type is required parameter") else: - assert isinstance(device_info["device_type"], _enums.DeviceType) info.device_type = _parse_device_type(device_info["device_type"]) if "gpu_id" in device_info: @@ -126,7 +128,6 @@ def _parse_device(device_info: Any) -> _C.Device: return info elif isinstance(device_info, Device): - # ts_device = TSDevice(gpu_id=device_info.gpu_id, dla_core=device_info.dla_core, allow_gpu_fallback=device_info.allow_gpu_fallback) return device_info._to_internal() elif isinstance(device_info, torch.device): return (Device._from_torch_device(device_info))._to_internal() From a12141c19b5b8e2a094dd85f939fc5a3f7fb3840 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Mon, 10 Apr 2023 15:25:04 -0700 Subject: [PATCH 30/45] chore: Fix dynamo tests Signed-off-by: Dheeraj Peri --- .../dynamo/test/trt_lower/test_fx2trt_lower.py | 8 ++++---- .../dynamo/test/trt_lower/test_observer_gpu.py | 4 ++-- .../dynamo/test/trt_lower/trt_splitter_test.py | 18 +++++++++--------- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/py/torch_tensorrt/dynamo/test/trt_lower/test_fx2trt_lower.py 
b/py/torch_tensorrt/dynamo/test/trt_lower/test_fx2trt_lower.py index 4c4948ba2a..7546346d18 100644 --- a/py/torch_tensorrt/dynamo/test/trt_lower/test_fx2trt_lower.py +++ b/py/torch_tensorrt/dynamo/test/trt_lower/test_fx2trt_lower.py @@ -20,7 +20,7 @@ def forward(self, x): mod = _Mod() mod_traced = fx.symbolic_trace(mod) - input = [torch.rand(4)] + input = [torch.rand(4).cuda()] lower = Lowerer.create(LowerSetting()) lower(mod_traced, input) @@ -39,7 +39,7 @@ def forward(self, x): return self.bn(x) module = TestModule() - inputs = [torch.randn(1, 3, 224, 224)] + inputs = [torch.randn(1, 3, 224, 224).cuda()] lower = Lowerer.create(LowerSetting(ast_rewriter_allow_list={MyBatchNorm})) lower(module, inputs) @@ -53,7 +53,7 @@ def forward(self, x): return (torch.sqrt(x), self.a) lower = Lowerer.create(LowerSetting()) - lower(TestModule(), [torch.randn([2, 2])]) + lower(TestModule(), [torch.randn([2, 2]).cuda()]) def test_replace_mutable_op(self): class TestModule(torch.nn.Module): @@ -65,7 +65,7 @@ def forward(self, x, y): lower = Lowerer.create(LowerSetting()) mod_traced = fx.symbolic_trace(TestModule()) - lower(mod_traced, [torch.randn(3, 4), torch.randn(3, 4)]) + lower(mod_traced, [torch.randn(3, 4).cuda(), torch.randn(3, 4).cuda()]) def test_replace_mutable_op_dont_apply(self): class TestModule(torch.nn.Module): diff --git a/py/torch_tensorrt/dynamo/test/trt_lower/test_observer_gpu.py b/py/torch_tensorrt/dynamo/test/trt_lower/test_observer_gpu.py index 824cc66be9..005891cbfd 100644 --- a/py/torch_tensorrt/dynamo/test/trt_lower/test_observer_gpu.py +++ b/py/torch_tensorrt/dynamo/test/trt_lower/test_observer_gpu.py @@ -19,8 +19,8 @@ def test_observe_lowerer(self): import torch import torch.nn as nn - import torch_tensorrt.fx.lower as lower - from torch_tensorrt.fx.lower_setting import LowerSetting + import torch_tensorrt.dynamo.lower as lower + from torch_tensorrt.dynamo.lower_setting import LowerSetting class Model(nn.Module): def forward(self, x, y): diff --git a/py/torch_tensorrt/dynamo/test/trt_lower/trt_splitter_test.py b/py/torch_tensorrt/dynamo/test/trt_lower/trt_splitter_test.py index 9d96bf78b0..7738b829b9 100644 --- a/py/torch_tensorrt/dynamo/test/trt_lower/trt_splitter_test.py +++ b/py/torch_tensorrt/dynamo/test/trt_lower/trt_splitter_test.py @@ -358,12 +358,12 @@ def test_splitter(splitter): test_splitter(splitter) - def test_min_block_size(self): + def test_min_acc_module_size(self): """ sin relu cos sigmoid tanh a ====> b =====> c ====> d ========> e =====> f - We set sin, cos and tanh as acc node but also set min_block_size to 2 + We set sin, cos and tanh as acc node but also set min_acc_module_size to 2 and expect the whole module stay on CPU. 
""" @@ -386,9 +386,9 @@ class CustomOpSupport(op_support.OperatorSupport): "acc_ops.tanh": None, } - # Create splitter setting and set min_block_size to 2 + # Create splitter setting and set min_acc_module_size to 2 settings = splitter_base._SplitterSettingBase() - settings.min_block_size = 2 + settings.min_acc_module_size = 2 splitter = TRTSplitter( mod, (torch.randn(2, 3),), @@ -815,7 +815,7 @@ def test_split_non_tensor_edges_2(self): # Making 'a', 'b1', 'b2', 'd' and 'e' run on ACC with limit on ACC # subgraph size settings = splitter_base._SplitterSettingBase() - settings.min_block_size = 2 + settings.min_acc_module_size = 2 splitter = TRTSplitter( module_nn, (test_data,), @@ -912,7 +912,7 @@ def test_split_non_tensor_edges_4(self): # Making 'a', 'c', 'd' and 'e' run on ACC with limit on ACC # subgraph size settings = splitter_base._SplitterSettingBase() - settings.min_block_size = 2 + settings.min_acc_module_size = 2 splitter = TRTSplitter( module_nn, (test_data,), @@ -1072,7 +1072,7 @@ def test_start_with_acc_module_(self): sin relu cos sigmoid tanh a ====> b =====> c ====> d ========> e =====> f - We set sin, relu and cos as acc node but also set min_block_size to 2 + We set sin, relu and cos as acc node but also set min_acc_module_size to 2 and expect the whole module stay on CPU. """ @@ -1095,9 +1095,9 @@ class CustomOpSupport(op_support.OperatorSupport): "acc_ops.relu": None, } - # Create splitter setting and set min_block_size to 2 + # Create splitter setting and set min_acc_module_size to 2 settings = splitter_base._SplitterSettingBase() - settings.min_block_size = 2 + settings.min_acc_module_size = 2 splitter = TRTSplitter( mod, (torch.randn(2, 3),), From 6d2e01a02924f9edab022bbe67cfa04176c611cb Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Mon, 10 Apr 2023 21:12:34 -0700 Subject: [PATCH 31/45] fix: Add test cases and improve backend - Add support for Input objects, add utilities - Add modeling e2e test cases for Dynamo backend - Improve defaults and settings in Dynamo class --- py/torch_tensorrt/_Input.py | 29 +++++- py/torch_tensorrt/dynamo/__init__.py | 13 ++- py/torch_tensorrt/dynamo/_defaults.py | 2 +- py/torch_tensorrt/dynamo/_settings.py | 8 +- py/torch_tensorrt/dynamo/utils.py | 66 +++++++++++++ tests/py/api/test_dynamo_backend.py | 136 ++++++++++++++++++++++++++ 6 files changed, 243 insertions(+), 11 deletions(-) create mode 100644 py/torch_tensorrt/dynamo/utils.py create mode 100644 tests/py/api/test_dynamo_backend.py diff --git a/py/torch_tensorrt/_Input.py b/py/torch_tensorrt/_Input.py index 324c385fab..4bd0cc3fb5 100644 --- a/py/torch_tensorrt/_Input.py +++ b/py/torch_tensorrt/_Input.py @@ -237,6 +237,27 @@ def _supported_input_size_type(input_size: Any) -> bool: else: return False + @staticmethod + def _dtype_to_torch_type(dtype: _enums.dtype) -> torch.dtype: + if isinstance(dtype, _enums.dtype): + if dtype == _enums.dtype.long: + return torch.long + elif dtype == _enums.dtype.int32: + return torch.int32 + elif dtype == _enums.dtype.half: + return torch.half + elif dtype == _enums.dtype.float: + return torch.float + elif dtype == _enums.dtype.bool: + return torch.bool + else: + raise TypeError( + "Provided an unsupported data type as an input data type (support: bool, int32, long, half, float), got: " + + str(dtype) + ) + else: + raise ValueError("Did not provide an _enums.dtype type as input.") + @staticmethod def _parse_dtype(dtype: Any) -> _enums.dtype: if isinstance(dtype, torch.dtype): @@ -416,8 +437,10 @@ def 
example_tensor(self, optimization_profile_field: str = None) -> torch.Tensor ) if self.shape_mode == Input._ShapeMode.STATIC: - return torch.randn(self.shape).to(dtype=self.dtype) + return torch.rand(self.shape).to( + dtype=Input._dtype_to_torch_type(self.dtype) + ) else: - return torch.randn(self.shape[optimization_profile_field]).to( - dtype=self.dtype + return torch.rand(self.shape[optimization_profile_field]).to( + dtype=Input._dtype_to_torch_type(self.dtype) ) diff --git a/py/torch_tensorrt/dynamo/__init__.py b/py/torch_tensorrt/dynamo/__init__.py index 1de768571f..2497b99789 100644 --- a/py/torch_tensorrt/dynamo/__init__.py +++ b/py/torch_tensorrt/dynamo/__init__.py @@ -1,13 +1,15 @@ import torch import logging +import collections.abc import torch_tensorrt from functools import partial -from typing import Sequence, Any +from typing import Any from torch_tensorrt import EngineCapability, Device from torch_tensorrt.fx.utils import LowerPrecision from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.utils import prepare_inputs, prepare_device from torch_tensorrt.dynamo.backends import tensorrt_backend from torch_tensorrt.dynamo._defaults import ( PRECISION, @@ -22,7 +24,7 @@ def compile( gm: torch.nn.Module, - example_inputs: Sequence[Any], + inputs: Any, *, device=Device._current_device(), disable_tf32=False, @@ -51,6 +53,11 @@ def compile( + "{enabled_precisions, debug, workspace_size, max_num_trt_engines}" ) + if not isinstance(inputs, collections.abc.Sequence): + inputs = [inputs] + + inputs = prepare_inputs(inputs, prepare_device(device)) + if ( torch.float16 in enabled_precisions or torch_tensorrt.dtype.half in enabled_precisions @@ -79,7 +86,7 @@ def compile( model = torch.compile(gm, backend=custom_backend) # Ensure compilation occurs by calling the function with provided inputs - model(*example_inputs) + model(*inputs) return model diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 814331e158..48c9a26f9e 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -4,4 +4,4 @@ PRECISION = LowerPrecision.FP32 DEBUG = False MAX_WORKSPACE_SIZE = 20 << 30 -MAX_NUM_TRT_ENGINES = 10 +MAX_NUM_TRT_ENGINES = 200 diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 39693651e7..c632943f53 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -11,7 +11,7 @@ @dataclass(frozen=True) class CompilationSettings: - precision: LowerPrecision = (PRECISION,) - debug: bool = (DEBUG,) - workspace_size: int = (MAX_WORKSPACE_SIZE,) - max_num_trt_engines: int = (MAX_NUM_TRT_ENGINES,) + precision: LowerPrecision = PRECISION + debug: bool = DEBUG + workspace_size: int = MAX_WORKSPACE_SIZE + max_num_trt_engines: int = MAX_NUM_TRT_ENGINES diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py new file mode 100644 index 0000000000..c096eb9397 --- /dev/null +++ b/py/torch_tensorrt/dynamo/utils.py @@ -0,0 +1,66 @@ +import torch + +from typing import Any, Union, Sequence, Dict +from torch_tensorrt import _Input, Device + + +def prepare_inputs( + inputs: Union[_Input.Input, torch.Tensor, Sequence, Dict], + device: torch.device = torch.device("cuda"), +) -> Any: + if isinstance(inputs, _Input.Input): + if isinstance(inputs.shape, dict): + return inputs.example_tensor(optimization_profile_field="opt_shape").to( + device + ) + else: + return inputs.example_tensor().to(device) + + 
elif isinstance(inputs, torch.Tensor):
+        return inputs
+
+    elif isinstance(inputs, list):
+        prepared_input = list()
+
+        for input_obj in inputs:
+            # Propagate the target device through the recursive call
+            prepared_input.append(prepare_inputs(input_obj, device))
+
+        return prepared_input
+
+    elif isinstance(inputs, tuple):
+        prepared_input = list()
+
+        for input_obj in inputs:
+            # Propagate the target device through the recursive call
+            prepared_input.append(prepare_inputs(input_obj, device))
+
+        return tuple(prepared_input)
+
+    elif isinstance(inputs, dict):
+        prepared_input = dict()
+
+        for key, input_obj in inputs.items():
+            # Propagate the target device through the recursive call
+            prepared_input[key] = prepare_inputs(input_obj, device)
+
+        return prepared_input
+
+    else:
+        raise ValueError(
+            f"Invalid input type {type(inputs)} encountered in the torch_compile input parsing. "
+            + "Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}"
+        )
+
+
+def prepare_device(device: Union[Device, torch.device]) -> torch.device:
+    if isinstance(device, Device):
+        if device.gpu_id != -1:
+            return torch.device(device.gpu_id)
+        else:
+            raise ValueError("Invalid GPU ID provided for the CUDA device")
+
+    elif isinstance(device, torch.device):
+        return device
+
+    else:
+        raise ValueError(
+            "Invalid device provided. Supported options: torch.device | torch_tensorrt.Device"
+        )
diff --git a/tests/py/api/test_dynamo_backend.py b/tests/py/api/test_dynamo_backend.py
new file mode 100644
index 0000000000..77e2f344cd
--- /dev/null
+++ b/tests/py/api/test_dynamo_backend.py
@@ -0,0 +1,136 @@
+import unittest
+import torch
+import timm
+
+import torch_tensorrt as torchtrt
+import torchvision.models as models
+
+from transformers import BertModel
+from utils import COSINE_THRESHOLD, cosine_similarity
+
+
+class TestModels(unittest.TestCase):
+    def test_resnet18(self):
+        self.model = models.resnet18(pretrained=True).eval().to("cuda")
+        self.input = torch.randn((1, 3, 224, 224)).to("cuda")
+
+        compile_spec = {
+            "inputs": [
+                torchtrt.Input(
+                    self.input.shape, dtype=torch.float, format=torch.contiguous_format
+                )
+            ],
+            "device": torchtrt.Device("cuda:0"),
+            "enabled_precisions": {torch.float},
+        }
+
+        trt_mod = torchtrt.dynamo.compile(self.model, **compile_spec)
+        cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input))
+        self.assertTrue(
+            cos_sim > COSINE_THRESHOLD,
+            msg=f"ResNet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+        )
+
+    def test_mobilenet_v2(self):
+        self.model = models.mobilenet_v2(pretrained=True).eval().to("cuda")
+        self.input = torch.randn((1, 3, 224, 224)).to("cuda")
+
+        compile_spec = {
+            "inputs": [
+                torchtrt.Input(
+                    self.input.shape, dtype=torch.float, format=torch.contiguous_format
+                )
+            ],
+            "device": torchtrt.Device("cuda:0"),
+            "enabled_precisions": {torch.float},
+        }
+
+        trt_mod = torchtrt.dynamo.compile(self.model, **compile_spec)
+        cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input))
+        self.assertTrue(
+            cos_sim > COSINE_THRESHOLD,
+            msg=f"Mobilenet v2 TRT outputs don't match with the original model. 
Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+        )
+
+    def test_efficientnet_b0(self):
+        self.model = (
+            timm.create_model("efficientnet_b0", pretrained=True).eval().to("cuda")
+        )
+        self.input = torch.randn((1, 3, 224, 224)).to("cuda")
+
+        compile_spec = {
+            "inputs": [
+                torchtrt.Input(
+                    self.input.shape, dtype=torch.float, format=torch.contiguous_format
+                )
+            ],
+            "device": torchtrt.Device("cuda:0"),
+            "enabled_precisions": {torch.float},
+        }
+
+        trt_mod = torchtrt.dynamo.compile(self.model, **compile_spec)
+        cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input))
+        self.assertTrue(
+            cos_sim > COSINE_THRESHOLD,
+            msg=f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+        )
+
+    def test_bert_base_uncased(self):
+        self.model = BertModel.from_pretrained("bert-base-uncased").cuda().eval()
+        self.input = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")
+        self.input2 = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")
+
+        compile_spec = {
+            "inputs": [
+                torchtrt.Input(
+                    self.input.shape,
+                    dtype=self.input.dtype,
+                    format=torch.contiguous_format,
+                ),
+                torchtrt.Input(
+                    self.input2.shape,
+                    dtype=self.input2.dtype,
+                    format=torch.contiguous_format,
+                ),
+            ],
+            "device": torchtrt.Device("cuda:0"),
+            "enabled_precisions": {torch.float},
+            "truncate_long_and_double": True,
+            "debug": True,
+        }
+        trt_mod = torchtrt.dynamo.compile(self.model, **compile_spec)
+
+        model_outputs = self.model(self.input, self.input2)
+        trt_model_outputs = trt_mod(self.input, self.input2)
+        for key in model_outputs.keys():
+            out, trt_out = model_outputs[key], trt_model_outputs[key]
+            cos_sim = cosine_similarity(out, trt_out)
+            self.assertTrue(
+                cos_sim > COSINE_THRESHOLD,
+                msg=f"HF BERT base-uncased TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+            )
+
+    def test_resnet18_half(self):
+        self.model = models.resnet18(pretrained=True).eval().to("cuda").half()
+        self.input = torch.randn((1, 3, 224, 224)).to("cuda").half()
+
+        compile_spec = {
+            "inputs": [
+                torchtrt.Input(
+                    self.input.shape, dtype=torch.half, format=torch.contiguous_format
+                )
+            ],
+            "device": torchtrt.Device("cuda:0"),
+            "enabled_precisions": {torch.half},
+        }
+
+        trt_mod = torchtrt.dynamo.compile(self.model, **compile_spec)
+        cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input))
+        self.assertTrue(
+            cos_sim > COSINE_THRESHOLD,
+            msg=f"ResNet18 Half TRT outputs don't match with the original model. 
Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + +if __name__ == "__main__": + unittest.main() From 48618f4246979c2808445aaa54da8e7f9c20a117 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 11 Apr 2023 14:25:01 -0700 Subject: [PATCH 32/45] chore: Fix device dict Signed-off-by: Dheeraj Peri --- py/torch_tensorrt/dynamo/lower.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/py/torch_tensorrt/dynamo/lower.py b/py/torch_tensorrt/dynamo/lower.py index af81313f2f..d16eb008c1 100644 --- a/py/torch_tensorrt/dynamo/lower.py +++ b/py/torch_tensorrt/dynamo/lower.py @@ -82,6 +82,12 @@ def compile( raise ValueError("Invalid GPU ID provided for the CUDA device provided") elif isinstance(device, torch.device): device = device + elif isinstance(device, dict): + if "device_type" in device and device["device_type"] == trt.DeviceType.GPU: + if "gpu_id" in device: + device = torch.device(device["gpu_id"]) + else: + device = torch.device("cuda:0") else: raise ValueError( "Invalid device provided. Supported options: torch.device | torch_tensorrt.Device" From d890b7d9e7ccde19b42f204ce64f403a4d7a8815 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 11 Apr 2023 15:20:21 -0700 Subject: [PATCH 33/45] chore: Nest it to dynamo/fx_ts_compat Signed-off-by: Dheeraj Peri --- .circleci/config.yml | 82 +++++------ .../fx_ts_compat/Dynamic_Shape_Support.md | 137 ++++++++++++++++++ .../dynamo/{ => fx_ts_compat}/README.md | 0 .../dynamo/{ => fx_ts_compat}/__init__.py | 0 .../dynamo/{ => fx_ts_compat}/fx2trt.py | 2 +- .../{ => fx_ts_compat}/input_tensor_spec.py | 0 .../dynamo/{ => fx_ts_compat}/lower.py | 0 .../{ => fx_ts_compat}/lower_setting.py | 0 .../{ => fx_ts_compat}/passes/__init__.py | 0 .../passes/lower_pass_manager_builder.py | 2 +- .../{ => fx_ts_compat}/passes/pass_utils.py | 0 .../acc_op/test_adaptive_avgpool.py | 2 +- .../test/converters/acc_op/test_any.py | 4 +- .../test/converters/acc_op/test_as_strided.py | 4 +- .../test/converters/acc_op/test_avgpool.py | 2 +- .../test/converters/acc_op/test_batchnorm.py | 2 +- .../test/converters/acc_op/test_binary_ops.py | 2 +- .../test/converters/acc_op/test_cat.py | 2 +- .../test/converters/acc_op/test_chunk.py | 2 +- .../test/converters/acc_op/test_clamp.py | 2 +- .../converters/acc_op/test_convolution.py | 2 +- .../test/converters/acc_op/test_dequantize.py | 2 +- .../test/converters/acc_op/test_einsum.py | 2 +- .../test/converters/acc_op/test_elu.py | 2 +- .../test/converters/acc_op/test_embedding.py | 2 +- .../test/converters/acc_op/test_eq.py | 2 +- .../test/converters/acc_op/test_expand.py | 2 +- .../test/converters/acc_op/test_flatten.py | 0 .../test/converters/acc_op/test_gelu.py | 2 +- .../test/converters/acc_op/test_getitem.py | 2 +- .../test/converters/acc_op/test_gt.py | 2 +- .../converters/acc_op/test_hard_sigmoid.py | 2 +- .../test/converters/acc_op/test_hardtanh.py | 2 +- .../converters/acc_op/test_interpolate.py | 2 +- .../test/converters/acc_op/test_isinf.py | 2 +- .../test/converters/acc_op/test_leaky_relu.py | 2 +- .../test/converters/acc_op/test_linear.py | 2 +- .../converters/acc_op/test_logical_and.py | 2 +- .../test/converters/acc_op/test_logical_or.py | 2 +- .../converters/acc_op/test_logical_xor.py | 2 +- .../test/converters/acc_op/test_lt.py | 2 +- .../converters/acc_op/test_masked_fill.py | 2 +- .../test/converters/acc_op/test_matmul.py | 2 +- .../test/converters/acc_op/test_max.py | 2 +- .../test/converters/acc_op/test_maximum.py | 2 +- .../test/converters/acc_op/test_maxpool.py | 2 +- 
.../test/converters/acc_op/test_min.py | 2 +- .../test/converters/acc_op/test_minimum.py | 2 +- .../test/converters/acc_op/test_narrow.py | 2 +- .../test/converters/acc_op/test_ne.py | 2 +- .../test/converters/acc_op/test_new_ones.py | 2 +- .../test/converters/acc_op/test_numel.py | 2 +- .../test/converters/acc_op/test_pad.py | 2 +- .../test/converters/acc_op/test_permute.py | 2 +- .../test/converters/acc_op/test_prod.py | 2 +- .../acc_op/test_quantize_per_tensor.py | 2 +- .../test/converters/acc_op/test_reduce_ops.py | 2 +- .../test/converters/acc_op/test_relu.py | 2 +- .../acc_op/test_repeat_interleave.py | 2 +- .../test/converters/acc_op/test_reshape.py | 2 +- .../test/converters/acc_op/test_selu.py | 2 +- .../test/converters/acc_op/test_sigmoid.py | 2 +- .../test/converters/acc_op/test_silu.py | 0 .../test/converters/acc_op/test_size.py | 2 +- .../test/converters/acc_op/test_softmax.py | 2 +- .../test/converters/acc_op/test_softsign.py | 2 +- .../test/converters/acc_op/test_split.py | 2 +- .../test/converters/acc_op/test_squeeze.py | 2 +- .../test/converters/acc_op/test_std.py | 2 +- .../test/converters/acc_op/test_tanh.py | 2 +- .../test/converters/acc_op/test_tile.py | 2 +- .../test/converters/acc_op/test_to_dtype.py | 4 +- .../test/converters/acc_op/test_topk.py | 2 +- .../acc_op/test_transpose_convolution.py | 2 +- .../test/converters/acc_op/test_type_as.py | 4 +- .../test/converters/acc_op/test_unary_ops.py | 2 +- .../test/converters/acc_op/test_unsqueeze.py | 2 +- .../test/converters/acc_op/test_where.py | 2 +- .../aten_op/test_adaptive_avgpool_aten.py | 2 +- .../converters/aten_op/test_batchnorm_aten.py | 2 +- .../aten_op/test_binary_ops_aten.py | 2 +- .../test/converters/aten_op/test_cat_aten.py | 2 +- .../aten_op/test_convolution_aten.py | 2 +- .../converters/aten_op/test_expand_aten.py | 2 +- .../converters/aten_op/test_flatten_aten.py | 2 +- .../converters/aten_op/test_linear_aten.py | 2 +- .../converters/aten_op/test_maxpool_aten.py | 2 +- .../test/converters/aten_op/test_relu_aten.py | 2 +- .../converters/aten_op/test_reshape_aten.py | 2 +- .../converters/vanilla/test_add_vanilla.py | 2 +- .../vanilla/test_convolution_vanilla.py | 2 +- .../test/core/test_import_fx2trt.py | 2 +- .../test/core/test_input.py | 0 .../test/core/test_input_tensor_spec.py | 2 +- .../test/core/test_trt_module.py | 4 +- ...test_fix_clamp_numerical_limits_to_fp16.py | 0 .../test/passes/test_fix_reshape_batch_dim.py | 0 .../passes/test_fuse_permute_linear_trt.py | 2 +- .../passes/test_fuse_permute_matmul_trt.py | 2 +- .../test/passes/test_graph_opts.py | 0 .../test/passes/test_multi_fuse_trt.py | 2 +- .../test_remove_duplicate_output_args.py | 0 .../test/passes/test_setitem_trt.py | 2 +- .../test/quant/test_quant_trt.py | 4 +- .../test/tools/test_model_packager.py | 2 +- .../test/tracer/test_acc_shape_prop.py | 0 .../test/tracer/test_acc_tracer.py | 0 .../test/tracer/test_dispatch_tracer.py | 4 +- .../test/tracer/test_resnet.py | 12 +- .../test/trt_lower/test_diagnostics.py | 0 .../test/trt_lower/test_fx2trt_lower.py | 2 +- .../test/trt_lower/test_observer.py | 0 .../test/trt_lower/test_observer_gpu.py | 4 +- .../trt_lower/trt_operator_supported_test.py | 2 +- .../test/trt_lower/trt_splitter_test.py | 2 +- .../{ => fx_ts_compat}/tools/__init__.py | 0 .../{ => fx_ts_compat}/tools/common_fx2trt.py | 6 +- .../tools/engine_layer_visualize.py | 0 .../{ => fx_ts_compat}/tools/graph_util.py | 0 .../tools/model_packager.py | 0 .../{ => fx_ts_compat}/tools/node_profiler.py | 0 .../{ => 
fx_ts_compat}/tools/tensor_prop.py | 0 .../tools/timing_cache_utils.py | 0 .../{ => fx_ts_compat}/tools/trt_minimizer.py | 0 .../tools/trt_profiler_sorted.py | 0 .../{ => fx_ts_compat}/tools/trt_splitter.py | 0 .../dynamo/{ => fx_ts_compat}/types.py | 0 .../dynamo/{ => fx_ts_compat}/utils.py | 0 128 files changed, 289 insertions(+), 152 deletions(-) create mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/Dynamic_Shape_Support.md rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/README.md (100%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/__init__.py (100%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/fx2trt.py (99%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/input_tensor_spec.py (100%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/lower.py (100%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/lower_setting.py (100%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/passes/__init__.py (100%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/passes/lower_pass_manager_builder.py (99%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/passes/pass_utils.py (100%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_adaptive_avgpool.py (96%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_any.py (93%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_as_strided.py (91%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_avgpool.py (98%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_batchnorm.py (95%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_binary_ops.py (98%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_cat.py (96%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_chunk.py (95%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_clamp.py (95%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_convolution.py (98%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_dequantize.py (95%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_einsum.py (95%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_elu.py (94%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_embedding.py (97%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_eq.py (98%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_expand.py (92%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_flatten.py (100%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_gelu.py (96%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_getitem.py (98%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_gt.py (98%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_hard_sigmoid.py (94%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_hardtanh.py (94%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_interpolate.py (98%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_isinf.py (94%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_leaky_relu.py (94%) rename py/torch_tensorrt/dynamo/{ => 
fx_ts_compat}/test/converters/acc_op/test_linear.py (94%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_logical_and.py (98%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_logical_or.py (98%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_logical_xor.py (98%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_lt.py (98%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_masked_fill.py (96%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_matmul.py (97%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_max.py (98%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_maximum.py (96%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_maxpool.py (99%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_min.py (98%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_minimum.py (96%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_narrow.py (93%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_ne.py (98%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_new_ones.py (95%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_numel.py (93%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_pad.py (97%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_permute.py (96%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_prod.py (97%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_quantize_per_tensor.py (95%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_reduce_ops.py (97%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_relu.py (94%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_repeat_interleave.py (96%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_reshape.py (97%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_selu.py (94%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_sigmoid.py (91%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_silu.py (100%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_size.py (95%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_softmax.py (95%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_softsign.py (94%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_split.py (96%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_squeeze.py (92%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_std.py (97%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_tanh.py (93%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_tile.py (97%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_to_dtype.py (98%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_topk.py (96%) rename py/torch_tensorrt/dynamo/{ => 
fx_ts_compat}/test/converters/acc_op/test_transpose_convolution.py (97%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_type_as.py (96%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_unary_ops.py (98%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_unsqueeze.py (94%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/acc_op/test_where.py (97%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/aten_op/test_adaptive_avgpool_aten.py (97%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/aten_op/test_batchnorm_aten.py (95%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/aten_op/test_binary_ops_aten.py (98%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/aten_op/test_cat_aten.py (93%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/aten_op/test_convolution_aten.py (98%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/aten_op/test_expand_aten.py (90%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/aten_op/test_flatten_aten.py (95%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/aten_op/test_linear_aten.py (95%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/aten_op/test_maxpool_aten.py (98%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/aten_op/test_relu_aten.py (93%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/aten_op/test_reshape_aten.py (96%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/vanilla/test_add_vanilla.py (87%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/converters/vanilla/test_convolution_vanilla.py (97%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/core/test_import_fx2trt.py (85%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/core/test_input.py (100%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/core/test_input_tensor_spec.py (97%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/core/test_trt_module.py (96%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/passes/test_fix_clamp_numerical_limits_to_fp16.py (100%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/passes/test_fix_reshape_batch_dim.py (100%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/passes/test_fuse_permute_linear_trt.py (97%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/passes/test_fuse_permute_matmul_trt.py (98%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/passes/test_graph_opts.py (100%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/passes/test_multi_fuse_trt.py (95%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/passes/test_remove_duplicate_output_args.py (100%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/passes/test_setitem_trt.py (99%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/quant/test_quant_trt.py (99%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/tools/test_model_packager.py (95%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/tracer/test_acc_shape_prop.py (100%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/tracer/test_acc_tracer.py (100%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/tracer/test_dispatch_tracer.py (98%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/tracer/test_resnet.py (88%) rename py/torch_tensorrt/dynamo/{ => 
fx_ts_compat}/test/trt_lower/test_diagnostics.py (100%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/trt_lower/test_fx2trt_lower.py (97%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/trt_lower/test_observer.py (100%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/trt_lower/test_observer_gpu.py (91%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/trt_lower/trt_operator_supported_test.py (96%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/test/trt_lower/trt_splitter_test.py (99%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/tools/__init__.py (100%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/tools/common_fx2trt.py (98%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/tools/engine_layer_visualize.py (100%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/tools/graph_util.py (100%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/tools/model_packager.py (100%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/tools/node_profiler.py (100%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/tools/tensor_prop.py (100%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/tools/timing_cache_utils.py (100%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/tools/trt_minimizer.py (100%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/tools/trt_profiler_sorted.py (100%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/tools/trt_splitter.py (100%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/types.py (100%) rename py/torch_tensorrt/dynamo/{ => fx_ts_compat}/utils.py (100%) diff --git a/.circleci/config.yml b/.circleci/config.yml index a6731200cf..35ff68eea3 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -711,13 +711,13 @@ commands: # =================== FX tests end ======================== # # =================== Dynamo tests start ======================== # - test-dynamo_core: + test-dynamo-fx_ts_core: description: "Test the Dynamo core" steps: - run: name: Run Dynamo core tests command: | - cd py/torch_tensorrt/dynamo/test + cd py/torch_tensorrt/dynamo/fx_ts_compat/test pushd core/ pytest --junitxml=/tmp/artifacts/test_results/dynamo/core/test_results.xml popd @@ -727,13 +727,13 @@ commands: - store_artifacts: path: /tmp/testlogs - test-dynamo_converters_acc: + test-dynamo-fx_ts_converters_acc: description: "Test the Dynamo acc converters" steps: - run: name: Run FX converter tests command: | - cd py/torch_tensorrt/dynamo/test + cd py/torch_tensorrt/dynamo/fx_ts_compat/test pushd converters/acc_op/ pytest --junitxml=/tmp/artifacts/test_results/dynamo/converters/acc_op/test_results.xml popd @@ -743,13 +743,13 @@ commands: - store_artifacts: path: /tmp/testlogs - test-dynamo_converters_aten: + test-dynamo-fx_ts_converters_aten: description: "Test the dynamo aten converters" steps: - run: name: Run dynamo converter tests command: | - cd py/torch_tensorrt/dynamo/test + cd py/torch_tensorrt/dynamo/fx_ts_compat/test pushd converters/aten_op/ pytest --junitxml=/tmp/artifacts/test_results/dynamo/converters/aten_op/test_results.xml popd @@ -759,13 +759,13 @@ commands: - store_artifacts: path: /tmp/testlogs - test-dynamo_converters_vanilla: + test-dynamo-fx_ts_converters_vanilla: description: "Test the dynamo vanilla converters" steps: - run: name: Run dynamo converter tests command: | - cd py/torch_tensorrt/dynamo/test + cd py/torch_tensorrt/dynamo/fx_ts_compat/test pushd converters/vanilla/ pytest --junitxml=/tmp/artifacts/test_results/dynamo/converters/vanilla/test_results.xml popd @@ -775,13 +775,13 @@ commands: 
- store_artifacts: path: /tmp/testlogs - test-dynamo_passes: + test-dynamo-fx_ts_passes: description: "Test the dynamo passes" steps: - run: name: Run dynamo passes command: | - cd py/torch_tensorrt/dynamo/test + cd py/torch_tensorrt/dynamo/fx_ts_compat/test pushd passes list_passes=$(ls | grep -v test_setitem*) pytest $list_passes --junitxml=/tmp/artifacts/test_results/dynamo/passes/test_results.xml @@ -791,13 +791,13 @@ commands: - store_artifacts: path: /tmp/testlogs - test-dynamo_tools: + test-dynamo-fx_ts_tools: description: "Test the dynamo tools" steps: - run: name: Run dynamo tools command: | - cd py/torch_tensorrt/dynamo/test + cd py/torch_tensorrt/dynamo/fx_ts_compat/test pushd tools pytest --junitxml=/tmp/artifacts/test_results/dynamo/tools/test_results.xml popd @@ -806,13 +806,13 @@ commands: - store_artifacts: path: /tmp/testlogs - test-dynamo_trt_lower: + test-dynamo-fx_ts_trt_lower: description: "Test the dynamo TRT lowering" steps: - run: name: Run dynamo TRT lowering command: | - cd py/torch_tensorrt/dynamo/test + cd py/torch_tensorrt/dynamo/fx_ts_compat/test pushd trt_lower pytest --junitxml=/tmp/artifacts/test_results/dynamo/trt_lower/test_results.xml popd @@ -821,13 +821,13 @@ commands: - store_artifacts: path: /tmp/testlogs - test-dynamo_tracer: + test-dynamo-fx_ts_tracer: description: "Test all dynamo tracers" steps: - run: name: Run dynamo tracer command: | - cd py/torch_tensorrt/dynamo/test + cd py/torch_tensorrt/dynamo/fx_ts_compat/test pushd tracer list_tracer=$(ls | grep -v test_dispatch_*) pytest $list_tracer --junitxml=/tmp/artifacts/test_results/fx/tracer/test_results.xml @@ -837,13 +837,13 @@ commands: - store_artifacts: path: /tmp/testlogs - test-dynamo_tracer_acc: + test-dynamo-fx_ts_tracer_acc: description: "Test the dynamo acc tracer only" steps: - run: name: Run dynamo tracer command: | - cd py/torch_tensorrt/dynamo/test + cd py/torch_tensorrt/dynamo/fx_ts_compat/test pushd tracer list_tracer=$(ls | grep test_acc) pytest $list_tracer --junitxml=/tmp/artifacts/test_results/dynamo/tracer/test_results.xml @@ -853,13 +853,13 @@ commands: - store_artifacts: path: /tmp/testlogs - test-dynamo_quant: + test-dynamo-fx_ts_quant: description: "Test the dynamo quant" steps: - run: name: Run dynamo quant tests command: | - cd py/torch_tensorrt/dynamo/test + cd py/torch_tensorrt/dynamo/fx_ts_compat/test pushd quant/ pytest --junitxml=/tmp/artifacts/test_results/dynamo/quant/test_results.xml popd @@ -869,42 +869,42 @@ commands: - store_artifacts: path: /tmp/testlogs - test-dynamo: + test-dynamo-fx_ts: description: "Test the dynamo backend" steps: - run: name: Run dynamo tests command: | mkdir -p /tmp/artifacts/test_results - - test-dynamo_converters_acc - - test-dynamo_converters_aten - - test-dynamo_converters_vanilla - - test-dynamo_passes - - test-dynamo_tools - - test-dynamo_trt_lower - - test-dynamo_tracer - - test-dynamo_core - - test-dynamo_quant + - test-dynamo-fx_ts_converters_acc + - test-dynamo-fx_ts_converters_aten + - test-dynamo-fx_ts_converters_vanilla + - test-dynamo-fx_ts_passes + - test-dynamo-fx_ts_tools + - test-dynamo-fx_ts_trt_lower + - test-dynamo-fx_ts_tracer + - test-dynamo-fx_ts_core + - test-dynamo-fx_ts_quant - store_test_results: path: /tmp/artifacts - store_artifacts: path: /tmp/testlogs - test-dynamo-no-aten: + test-dynamo-fx_ts-no-aten: description: "Test the dynamo backend without aten operators" steps: - run: name: Run dynamo tests without aten ops command: | mkdir -p /tmp/artifacts/test_results - - test-dynamo_converters_acc - - 
test-dynamo_converters_vanilla
-      - test-dynamo_passes
-      - test-dynamo_tools
-      - test-dynamo_trt_lower
-      - test-dynamo_tracer_acc
-      - test-dynamo_core
-      - test-dynamo_quant
+      - test-dynamo-fx_ts_converters_acc
+      - test-dynamo-fx_ts_converters_vanilla
+      - test-dynamo-fx_ts_passes
+      - test-dynamo-fx_ts_tools
+      - test-dynamo-fx_ts_trt_lower
+      - test-dynamo-fx_ts_tracer_acc
+      - test-dynamo-fx_ts_core
+      - test-dynamo-fx_ts_quant
       - store_test_results:
           path: /tmp/artifacts
       - store_artifacts:
           path: /tmp/testlogs

@@ -1117,7 +1117,7 @@ jobs:
            command: pip3 install --pre /tmp/dist/x86_64-linux/*cp39-cp39*.whl
            # We install torch after torch-trt because pip automatically enforces the version constraint otherwise
        - dump-test-env
-      - test-dynamo
+      - test-dynamo-fx_ts

  test-py-dynamo-x86_64-linux-no-aten:
    parameters:
@@ -1148,7 +1148,7 @@ jobs:
            command: pip3 install --pre /tmp/dist/x86_64-linux/*cp39-cp39*.whl
            # We install torch after torch-trt because pip automatically enforces the version constraint otherwise
        - dump-test-env
-      - test-dynamo-no-aten
+      - test-dynamo-fx_ts-no-aten

  package-x86_64-linux:
    parameters:
diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/Dynamic_Shape_Support.md b/py/torch_tensorrt/dynamo/fx_ts_compat/Dynamic_Shape_Support.md
new file mode 100644
index 0000000000..eb4454340e
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/Dynamic_Shape_Support.md
@@ -0,0 +1,174 @@
+# PyTorch Operations Dynamic Shape Support Summary
+
+
+
+| Operation | Test Method | Supports Dynamic Shape | Shape | Num of dimensions | Reason |
+| --- | --- | --- | --- | --- | --- |
+| adaptive_avgpool | | partially | (-1, -1, 256, 256) | 2 | AdaptiveAvgPool2d and AdaptiveAvgPool3d currently don't support dynamic shapes for the last two dims. |
+| any | | no | | | torch.zeros(tuple(\[*input_t.shape\])). Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] |
+| as_strided | | no | | | RuntimeError: setStorage: sizes \[2, 3\], strides \[1, 2\], storage offset 0, and itemsize 8 requiring a storage size of 48 are out of bounds for storage of size 16 |
+| avg_pool | avg_pool2d | yes | (-1, -1, -1, -1) | 4 | |
+| | avg_pool1d | partially | (-1, 3, 3) | 1 | |
+| batchnorm | | partially | (-1, 3, -1, -1) | 3 | "Channel dim can't be dynamic for batch norm." |
+| binary_ops | | yes | (-1, -1, -1, -1) | 4 | |
+| cat | | yes | (-1, -1, -1, -1) | 4 | |
+| chunk | | partially | (-1, 1, 3, -1) | any (not chunk dim) | AssertionError: Can't chunk on dynamic shape dimension! |
+| clamp | | yes | (-1, -1, -1, -1) | | |
+| convolution | conv2d | partially | (-1, 3, -1, -1) | 3 | AssertionError: Channel dim can't be dynamic for convolution. |
+| | conv1d | partially | (-1, 3, 3) | 1 | |
+| | conv3d | partially | (-1, 3, -1, -1, -1) | 4 | AssertionError: Channel dim can't be dynamic for convolution. |
+| dequantize | | yes | (-1, -1, -1, -1) | 4 | |
+| einsum | | yes | (-1, -1, -1, -1) | 4 | |
+| elu | | yes | (-1, -1, -1, -1) | 4 | |
+| embedding | | yes | (-1, -1, -1, -1) | 4 | |
+| eq | SimpleConverter | yes | (-1, -1, -1, -1) | 4 | |
+| | ConstInputConverter | yes | (-1, -1, -1, -1) | 4 | |
+| | EqMethodConverter | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] |
+| | EqOperatorConverter | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] |
+| | EqOperatorConstant | partially | (3, -1) | 1 | |
+| | EqConverter | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] |
+| expand | | no | | | Dynamic shape is not suitable for the expand operation. |
+| flatten | | yes | (-1, -1, -1, -1, -1) | 5 | |
+| gelu | | yes | (-1, -1, -1, -1) | 4 | |
+| getitem | | yes | (-1, -1, -1, -1) | 4 | |
+| gt | EqOperatorSimpleConverter | yes | (-1, -1, -1, -1) | 4 | |
+| | ConstInputConverter | yes | (-1, -1, -1, -1) | 4 | |
+| | GtConverter | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] |
+| | GtMethodConverter | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] |
+| | GtOperator | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] |
+| | EqOperator | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] |
+| hardsigmoid | | yes | (-1, -1, -1, -1) | 4 | |
+| hardtanh | | yes | (-1, -1, -1, -1) | 4 | |
+| interpolate | | yes | (-1, -1, -1, -1) | 4 | |
+| isinf | | yes | (-1, -1, -1, -1) | 4 | |
+| leaky_relu | | yes | (-1, -1, -1, -1) | 4 | |
+| linear | | partially | (-1, 3, 5) | 1 | AssertionError: Currently we only support one dynamic dim for linear and it can't be the last dim. |
+| logical_and | | yes | (-1, -1, -1, -1) | 4 | |
+| logical_or | | yes | (-1, -1, -1, -1) | 4 | |
+| logical_xor | | yes | (-1, -1, -1, -1) | 4 | |
+| lt | | yes | (-1, -1, -1, -1) | 4 | |
+| masked_fill | | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] |
+| mat_mul | | yes | batch dim | | |
+| max | MaxFullReduce | yes | (-1, -1, -1, -1) | 4 | |
+| | MaxDimReduce | yes | (-1, -1, -1, -1) | 4 | |
+| | MaxMethod | yes | (-1, -1, -1, -1) | 4 | |
+| maximum | | yes | (-1, -1, -1, -1) | 4 | |
+| maxpool | max_pool1d | partially | (1, 1, -1) | 1 | shape is not set to (-1, -1, -1) as a reshape dimension with more than one -1 wildcard is not allowed while adding the unsqueeze layer |
+| | max_pool2d | yes | (-1, -1, -1, -1) | 4 | |
+| | max_pool3d | yes | (-1, -1, -1, -1, -1) | 5 | |
+| min | MinFullReduce | yes | (-1, -1, -1, -1) | 4 | |
+| | MinDimReduce | yes | (-1, -1, -1, -1) | 4 | |
+| | MinMethod | yes | (-1, -1, -1, -1) | 4 | |
+| minimum | | yes | (-1, -1, -1, -1) | 4 | |
+| narrow | | partially | (-1, 3, -1, -1) | 3 | AssertionError: Can't chunk on dynamic shape dimension! |
+| ne | NeFunctionConverter | yes | (-1, -1, -1, -1) | 4 | |
+| | NeMethodConverter | yes | (-1, -1, -1, -1) | 4 | |
+| | NeOperatorConverter | yes | (-1, -1, -1, -1) | 4 | |
+| | ConstInputConverter | yes | (-1, -1, -1, -1) | 4 | |
+| | NeOperatorConstantConverter | partially | (3, -1) | 1 | |
+| new_ones | | yes | (-1, -1, -1, -1) | 4 | |
+| numel | | no | limitation in converter | | RuntimeError: numel does not support dynamic shapes. |
+| pad | | no | limitation in converter | | test\_pad\_with\_dynamic\_shape\_four\_dimensions\_0\_2d (deeplearning.trt.torch\_tensorrt.py.torch\_tensorrt.fx.test.converters.acc\_op.test\_pad.TestPadConverter) ... \[07/15/2022-09:23:18\] \[TRT\] \[E\] 2: \[intInterval.cpp::max::26\] Error Code 2: Internal Error (Assertion !empty() failed. |
+| permute | | yes | (-1, -1, -1, -1) | 4 | |
+| prod | | yes | (-1, -1, -1, -1) | 4 | |
+| quantize\_per\_tensor | | yes | (-1, -1, -1, -1) | 4 | |
+| reduce op | | yes | (-1, -1, -1, -1) | 4 | |
+| relu | | yes | (-1, -1, -1, -1) | 4 | |
+| repeat interleave | | partially | (-1, 3, 2) | 1 | AssertionError: Currently we don't support unsqueeze with more than one dynamic dim. |
+| reshape | | yes | (-1, -1, -1, -1) | 4 | |
+| selu | | yes | (-1, -1, -1, -1) | 4 | |
+| sigmoid | | yes | (-1, -1, -1, -1) | 4 | |
+| silu | | yes | (-1, -1, -1, -1) | 4 | |
+| size | | yes | (-1, -1, -1, -1) | 4 | |
+| softmax | | yes | (-1, -1, -1, -1) | 4 | |
+| softsign | | yes | (-1, -1, -1, -1) | 4 | |
+| split | | partially | (-1, 10, -1) | 2 | AssertionError: Can't chunk on dynamic shape dimension! |
+| squeeze | | partially | (1, -1, 2) | 1 | AssertionError: Currently more than one dynamic dim for input to squeeze is not supported. |
+| std | | yes | (-1, -1, -1, -1) | 4 | |
+| tanh | | yes | (-1, -1, -1, -1) | 4 | |
+| tile | | yes | (-1, -1, -1, -1) | 4 | |
+| to_dtype | int | yes | (-1, -1, -1, -1) | 4 | |
+| | float | yes | (-1, -1, -1, -1) | 4 | |
+| topk | | yes | (-1, -1, -1, -1) | 4 | |
+| transpose_convolution | conv_transpose2d | partially | (-1, 3, -1, -1) | 3 | |
+| | conv_transpose3d | partially | (-1, 3, -1, -1, -1) | 4 | |
+| type_as | | yes | (-1, -1, -1, -1) | 4 | RuntimeError: ShapeProp error for: node=%type\_1 : \[#users=1\] = call\_method\[target=type\](args = (%input_1,), kwargs = {dtype: torch.float32}) with meta={} |
+| unary ops | | yes | (-1, -1, -1, -1) | 4 | |
+| unsqueeze | | partially | (-1, 2, 3) | 1 | AssertionError: Currently we don't support unsqueeze with more than one dynamic dim. |
+| where | | no | limitation in converter | | torch.broadcast_shape cannot handle -1 dimension in shape \[-1, 2, 2\] |
+
+
+
+Binary ops include the following operations:
+|Binary Ops |
+|----------|
+|add |
+|sub |
+|div |
+|mul |
+|floor_div |
+|fmod |
+|floor_divide|
+|pow |
+
+
+Unary ops include the following operations:
+|Unary Ops |
+|----------|
+|rsqrt |
+|sin |
+|cos |
+|tan |
+|sinh |
+|cosh |
+|asin |
+|acos |
+|atan |
+|abs |
+|neg |
+|reciprocal|
+|sqrt |
+|log |
+|exp |
+|floor |
+|ceil |
+|sign |
+
+Note: For more information about the test method, please refer to the operation test files. Additionally, test files include information about errors encountered during dynamic shape testing.
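+
+## Example: declaring a dynamic dimension
+
+As a minimal sketch of how the shapes above are exercised, a single dynamic batch
+dimension can be declared through the top-level `torch_tensorrt.Input` API. The
+module and the exact shape bounds below are illustrative placeholders, not part of
+the test suite:
+
+```python
+import torch
+import torch_tensorrt
+
+
+class MyModule(torch.nn.Module):
+    def forward(self, x):
+        return torch.relu(x)
+
+
+model = MyModule().eval().cuda()
+
+# Only the batch dimension is dynamic; channel and spatial dims stay static,
+# matching the "(-1, 3, 224, 224)"-style entries in the table above.
+trt_model = torch_tensorrt.compile(
+    model,
+    inputs=[
+        torch_tensorrt.Input(
+            min_shape=(1, 3, 224, 224),
+            opt_shape=(8, 3, 224, 224),
+            max_shape=(16, 3, 224, 224),
+            dtype=torch.float,
+        )
+    ],
+    enabled_precisions={torch.float},
+)
+
+out = trt_model(torch.randn(8, 3, 224, 224, device="cuda"))
+```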
diff --git a/py/torch_tensorrt/dynamo/README.md b/py/torch_tensorrt/dynamo/fx_ts_compat/README.md similarity index 100% rename from py/torch_tensorrt/dynamo/README.md rename to py/torch_tensorrt/dynamo/fx_ts_compat/README.md diff --git a/py/torch_tensorrt/dynamo/__init__.py b/py/torch_tensorrt/dynamo/fx_ts_compat/__init__.py similarity index 100% rename from py/torch_tensorrt/dynamo/__init__.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/__init__.py diff --git a/py/torch_tensorrt/dynamo/fx2trt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py similarity index 99% rename from py/torch_tensorrt/dynamo/fx2trt.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py index 1971e219e0..18e514d83d 100644 --- a/py/torch_tensorrt/dynamo/fx2trt.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py @@ -13,7 +13,7 @@ from torch.fx.node import _get_qualified_name from torch.fx.passes.shape_prop import TensorMetadata -from torch_tensorrt.dynamo import CONVERTERS +from torch_tensorrt.dynamo.fx_ts_compat import CONVERTERS from .input_tensor_spec import InputTensorSpec from torch_tensorrt.fx.observer import Observer from .utils import get_dynamic_dims, LowerPrecision, torch_dtype_to_trt diff --git a/py/torch_tensorrt/dynamo/input_tensor_spec.py b/py/torch_tensorrt/dynamo/fx_ts_compat/input_tensor_spec.py similarity index 100% rename from py/torch_tensorrt/dynamo/input_tensor_spec.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/input_tensor_spec.py diff --git a/py/torch_tensorrt/dynamo/lower.py b/py/torch_tensorrt/dynamo/fx_ts_compat/lower.py similarity index 100% rename from py/torch_tensorrt/dynamo/lower.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/lower.py diff --git a/py/torch_tensorrt/dynamo/lower_setting.py b/py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py similarity index 100% rename from py/torch_tensorrt/dynamo/lower_setting.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py diff --git a/py/torch_tensorrt/dynamo/passes/__init__.py b/py/torch_tensorrt/dynamo/fx_ts_compat/passes/__init__.py similarity index 100% rename from py/torch_tensorrt/dynamo/passes/__init__.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/passes/__init__.py diff --git a/py/torch_tensorrt/dynamo/passes/lower_pass_manager_builder.py b/py/torch_tensorrt/dynamo/fx_ts_compat/passes/lower_pass_manager_builder.py similarity index 99% rename from py/torch_tensorrt/dynamo/passes/lower_pass_manager_builder.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/passes/lower_pass_manager_builder.py index ea2e331619..cb012c4f4e 100644 --- a/py/torch_tensorrt/dynamo/passes/lower_pass_manager_builder.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/passes/lower_pass_manager_builder.py @@ -8,7 +8,7 @@ from torch.fx.passes.pass_manager import inplace_wrapper, PassManager from torch.fx.passes.shape_prop import ShapeProp from torch.fx.passes.splitter_base import generate_inputs_for_submodules, SplitResult -from torch_tensorrt.dynamo.utils import LowerPrecision +from torch_tensorrt.dynamo.fx_ts_compat.utils import LowerPrecision from torch_tensorrt import _Input from ..input_tensor_spec import InputTensorSpec diff --git a/py/torch_tensorrt/dynamo/passes/pass_utils.py b/py/torch_tensorrt/dynamo/fx_ts_compat/passes/pass_utils.py similarity index 100% rename from py/torch_tensorrt/dynamo/passes/pass_utils.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/passes/pass_utils.py diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_adaptive_avgpool.py 
b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_adaptive_avgpool.py similarity index 96% rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_adaptive_avgpool.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_adaptive_avgpool.py index 0b194e4c77..ddd647707c 100644 --- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_adaptive_avgpool.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_adaptive_avgpool.py @@ -2,7 +2,7 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestAdaptiveAvgPoolConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_any.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_any.py similarity index 93% rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_any.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_any.py index 1e46e3cff1..7b50fd4515 100644 --- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_any.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_any.py @@ -3,9 +3,9 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase -# from torch_tensorrt.dynamo.tools.common_fx2trt import InputTensorSpec +# from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import InputTensorSpec class TestAnyConverters(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_as_strided.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_as_strided.py similarity index 91% rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_as_strided.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_as_strided.py index 72eecb5810..3aff0638d6 100644 --- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_as_strided.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_as_strided.py @@ -3,9 +3,9 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase -# from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec +# from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_avgpool.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_avgpool.py similarity index 98% rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_avgpool.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_avgpool.py index 88f55c58a9..9eb579d326 100644 --- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_avgpool.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_avgpool.py @@ 
-2,7 +2,7 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import param, parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestAvgPoolConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_batchnorm.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_batchnorm.py similarity index 95% rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_batchnorm.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_batchnorm.py index 965bbcd729..c22772e8ad 100644 --- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_batchnorm.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_batchnorm.py @@ -1,7 +1,7 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestBatchNormConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_binary_ops.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_binary_ops.py similarity index 98% rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_binary_ops.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_binary_ops.py index f2a9fb1620..6c226ec405 100644 --- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_binary_ops.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_binary_ops.py @@ -6,7 +6,7 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec NEED_TEST_BOTH_CONSTANTS_CASE = True diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_cat.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_cat.py similarity index 96% rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_cat.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_cat.py index e9232c35d7..5124cd5f05 100644 --- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_cat.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_cat.py @@ -3,7 +3,7 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import param, parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestCatConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_chunk.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_chunk.py similarity index 95% rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_chunk.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_chunk.py index 49fb8cff5b..81e575a3e7 100644 --- 
a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_chunk.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_chunk.py
@@ -3,7 +3,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestChunkConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_clamp.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_clamp.py
similarity index 95%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_clamp.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_clamp.py
index 96e611626c..f7fdfd653d 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_clamp.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_clamp.py
@@ -2,7 +2,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import param, parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestClampConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_convolution.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_convolution.py
similarity index 98%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_convolution.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_convolution.py
index bedc75f194..bbf03b7888 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_convolution.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_convolution.py
@@ -2,7 +2,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import param, parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestConvolutionConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_dequantize.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_dequantize.py
similarity index 95%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_dequantize.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_dequantize.py
index 212a77ec63..40a316aa2d 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_dequantize.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_dequantize.py
@@ -6,7 +6,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 @unittest.skip(
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_einsum.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_einsum.py
similarity index 95%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_einsum.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_einsum.py
index 88a7e5fae7..d7f3268b55 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_einsum.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_einsum.py
@@ -3,7 +3,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_elu.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_elu.py
similarity index 94%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_elu.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_elu.py
index 313d8ec022..bc52c5688b 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_elu.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_elu.py
@@ -2,7 +2,7 @@
 import torch.nn as nn
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestELUConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_embedding.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_embedding.py
similarity index 97%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_embedding.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_embedding.py
index ecfa171f0c..34853249cc 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_embedding.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_embedding.py
@@ -5,7 +5,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import param, parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 @unittest.skip(
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_eq.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_eq.py
similarity index 98%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_eq.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_eq.py
index befe675232..e83b1c8c16 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_eq.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_eq.py
@@ -2,7 +2,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestEqConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_expand.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_expand.py
similarity index 92%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_expand.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_expand.py
index fd369459f3..e7021e2353 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_expand.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_expand.py
@@ -3,7 +3,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase


 class TestExpandConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_flatten.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_flatten.py
similarity index 100%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_flatten.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_flatten.py
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_gelu.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_gelu.py
similarity index 96%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_gelu.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_gelu.py
index e7b7bd806d..36d427d49f 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_gelu.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_gelu.py
@@ -4,7 +4,7 @@
 import torch.nn as nn
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 @unittest.skip(
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_getitem.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_getitem.py
similarity index 98%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_getitem.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_getitem.py
index 9cc68ad87e..012f6f8d05 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_getitem.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_getitem.py
@@ -3,7 +3,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestGetitemConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_gt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_gt.py
similarity index 98%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_gt.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_gt.py
index b6e8e602d7..b08df22e1d 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_gt.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_gt.py
@@ -2,7 +2,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestGtConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_hard_sigmoid.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_hard_sigmoid.py
similarity index 94%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_hard_sigmoid.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_hard_sigmoid.py
index 86d5b1a099..40592245ab 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_hard_sigmoid.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_hard_sigmoid.py
@@ -2,7 +2,7 @@
 from parameterized import parameterized
 from torch import nn
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec
 from torch_tensorrt.fx.tracer.acc_tracer import acc_ops


diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_hardtanh.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_hardtanh.py
similarity index 94%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_hardtanh.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_hardtanh.py
index 97d326851a..f729a2f4d2 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_hardtanh.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_hardtanh.py
@@ -3,7 +3,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestHardtanhConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_interpolate.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_interpolate.py
similarity index 98%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_interpolate.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_interpolate.py
index c3e10f96ee..4c903f2f0c 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_interpolate.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_interpolate.py
@@ -3,7 +3,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestInterpolateConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_isinf.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_isinf.py
similarity index 94%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_isinf.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_isinf.py
index 9717eb52c1..2e52c4161d 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_isinf.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_isinf.py
@@ -4,7 +4,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 @unittest.skip("Implementation is commented out due to accuracy issue T113156424")
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_leaky_relu.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_leaky_relu.py
similarity index 94%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_leaky_relu.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_leaky_relu.py
index 601aa7ee91..12244316de 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_leaky_relu.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_leaky_relu.py
@@ -2,7 +2,7 @@
 import torch.nn as nn
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestLeakyReLUConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_linear.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_linear.py
similarity index 94%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_linear.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_linear.py
index 361a25fa04..3eda648854 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_linear.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_linear.py
@@ -2,7 +2,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestLinearConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_logical_and.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_logical_and.py
similarity index 98%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_logical_and.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_logical_and.py
index 85f18ea3f3..23b63ca03c 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_logical_and.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_logical_and.py
@@ -2,7 +2,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestAndMethodSimpleConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_logical_or.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_logical_or.py
similarity index 98%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_logical_or.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_logical_or.py
index 265f5735eb..1db491bf05 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_logical_or.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_logical_or.py
@@ -2,7 +2,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestLogicalOrMethodSimpleConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_logical_xor.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_logical_xor.py
similarity index 98%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_logical_xor.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_logical_xor.py
index 0cd6174950..b27de8f492 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_logical_xor.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_logical_xor.py
@@ -2,7 +2,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestLogicalXorMethodSimpleConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_lt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_lt.py
similarity index 98%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_lt.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_lt.py
index df51e6bf58..45000c8b80 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_lt.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_lt.py
@@ -2,7 +2,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestLtConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_masked_fill.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_masked_fill.py
similarity index 96%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_masked_fill.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_masked_fill.py
index 9e3ca83015..3c56d50750 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_masked_fill.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_masked_fill.py
@@ -3,7 +3,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase


 class TestMaskedFill(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_matmul.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_matmul.py
similarity index 97%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_matmul.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_matmul.py
index 7e7456c437..19fa661df5 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_matmul.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_matmul.py
@@ -3,7 +3,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestMatMulConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_max.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_max.py
similarity index 98%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_max.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_max.py
index c2cf7d252d..8857f87351 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_max.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_max.py
@@ -2,7 +2,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestMaxConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_maximum.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_maximum.py
similarity index 96%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_maximum.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_maximum.py
index e0bec6f15d..9c819edd67 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_maximum.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_maximum.py
@@ -1,7 +1,7 @@
 import torch
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestMaximumConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_maxpool.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_maxpool.py
similarity index 99%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_maxpool.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_maxpool.py
index ddb48b4b69..8018f632d8 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_maxpool.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_maxpool.py
@@ -2,7 +2,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import param, parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestMaxPoolConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_min.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_min.py
similarity index 98%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_min.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_min.py
index 9f37238240..be0b1b0c9e 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_min.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_min.py
@@ -2,7 +2,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestMinConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_minimum.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_minimum.py
similarity index 96%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_minimum.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_minimum.py
index a4b605cf66..6c3431d1a3 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_minimum.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_minimum.py
@@ -1,7 +1,7 @@
 import torch
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestMinimumConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_narrow.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_narrow.py
similarity index 93%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_narrow.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_narrow.py
index 93cf4ea523..c4dab876df 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_narrow.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_narrow.py
@@ -3,7 +3,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestNarrowConverterWithDynamicShape(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_ne.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_ne.py
similarity index 98%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_ne.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_ne.py
index affbc57aae..5b3f4e7c94 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_ne.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_ne.py
@@ -2,7 +2,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestNeFunctionConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_new_ones.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_new_ones.py
similarity index 95%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_new_ones.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_new_ones.py
index dabfedb139..4b5dacbea7 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_new_ones.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_new_ones.py
@@ -2,7 +2,7 @@
 import torch.nn as nn
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestNewOnesConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_numel.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_numel.py
similarity index 93%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_numel.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_numel.py
index a2eafd4fdc..a79e29600d 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_numel.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_numel.py
@@ -2,7 +2,7 @@
 import torch.nn as nn
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase


 class TestNumelConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_pad.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_pad.py
similarity index 97%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_pad.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_pad.py
index e850268fde..7625f4edee 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_pad.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_pad.py
@@ -7,7 +7,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase


 # from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_permute.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_permute.py
similarity index 96%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_permute.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_permute.py
index 9e4ebc9cf4..6bed54f73e 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_permute.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_permute.py
@@ -3,7 +3,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestPermuteConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_prod.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_prod.py
similarity index 97%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_prod.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_prod.py
index 0d6c16b98e..fea3658644 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_prod.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_prod.py
@@ -2,7 +2,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 # NOTE torch.prod will only accept one dim unlike other reduce ops which accept tuples
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_quantize_per_tensor.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_quantize_per_tensor.py
similarity index 95%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_quantize_per_tensor.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_quantize_per_tensor.py
index 5830a3e463..29576ae37e 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_quantize_per_tensor.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_quantize_per_tensor.py
@@ -6,7 +6,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 @unittest.skip(
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_reduce_ops.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_reduce_ops.py
similarity index 97%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_reduce_ops.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_reduce_ops.py
index 988eb7b477..20d68925af 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_reduce_ops.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_reduce_ops.py
@@ -2,7 +2,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 reduce_ops = [(torch.sum, acc_ops.sum), (torch.mean, acc_ops.mean)]
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_relu.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_relu.py
similarity index 94%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_relu.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_relu.py
index e520b742c1..d35492a35d 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_relu.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_relu.py
@@ -2,7 +2,7 @@
 import torch.nn as nn
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestReLUConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_repeat_interleave.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_repeat_interleave.py
similarity index 96%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_repeat_interleave.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_repeat_interleave.py
index efb1f80c0f..d34706d2d8 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_repeat_interleave.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_repeat_interleave.py
@@ -3,7 +3,7 @@
 from parameterized import parameterized
 from torch import nn
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestRepeatInterLeave(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_reshape.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_reshape.py
similarity index 97%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_reshape.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_reshape.py
index b5b1dc8f6c..716db89f3a 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_reshape.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_reshape.py
@@ -2,7 +2,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestReshapeConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_selu.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_selu.py
similarity index 94%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_selu.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_selu.py
index 4a89f364ee..b5d4ab774e 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_selu.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_selu.py
@@ -2,7 +2,7 @@
 import torch.nn as nn
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestSeLUConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_sigmoid.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_sigmoid.py
similarity index 91%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_sigmoid.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_sigmoid.py
index d8abf37707..36908a0720 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_sigmoid.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_sigmoid.py
@@ -2,7 +2,7 @@
 import torch.nn as nn
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestSigmoid(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_silu.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_silu.py
similarity index 100%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_silu.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_silu.py
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_size.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_size.py
similarity index 95%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_size.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_size.py
index 9fd1e45015..5834577e9c 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_size.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_size.py
@@ -2,7 +2,7 @@
 import torch.nn as nn
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestSizeConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_softmax.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_softmax.py
similarity index 95%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_softmax.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_softmax.py
index eab632c296..dfed9e1f12 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_softmax.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_softmax.py
@@ -3,7 +3,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestSoftmaxConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_softsign.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_softsign.py
similarity index 94%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_softsign.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_softsign.py
index 47241685fb..6a10923596 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_softsign.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_softsign.py
@@ -2,7 +2,7 @@
 import torch.nn as nn
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestSoftsignConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_split.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_split.py
similarity index 96%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_split.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_split.py
index 4861fecc34..adafb65655 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_split.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_split.py
@@ -3,7 +3,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestSplitConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_squeeze.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_squeeze.py
similarity index 92%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_squeeze.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_squeeze.py
index bc65e010e4..69479daf2e 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_squeeze.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_squeeze.py
@@ -2,7 +2,7 @@
 import torch.nn as nn
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestSqueeze(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_std.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_std.py
similarity index 97%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_std.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_std.py
index cd38314295..fbd4eaf1ca 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_std.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_std.py
@@ -2,7 +2,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestMinConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_tanh.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_tanh.py
similarity index 93%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_tanh.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_tanh.py
index 94c442a4ed..6bdaf5f71c 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_tanh.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_tanh.py
@@ -2,7 +2,7 @@
 import torch.nn as nn
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestTanh(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_tile.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_tile.py
similarity index 97%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_tile.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_tile.py
index 1d14987adc..bda604ea8e 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_tile.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_tile.py
@@ -3,7 +3,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestTile(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_to_dtype.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_to_dtype.py
similarity index 98%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_to_dtype.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_to_dtype.py
index c057088c77..3b0642d720 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_to_dtype.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_to_dtype.py
@@ -1,8 +1,8 @@
 import torch
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
-from torch_tensorrt.dynamo.utils import LowerPrecision
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.utils import LowerPrecision


 class TestToConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_topk.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_topk.py
similarity index 96%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_topk.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_topk.py
index 7790857f5a..874edbe2f0 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_topk.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_topk.py
@@ -3,7 +3,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestTopKConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_transpose_convolution.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_transpose_convolution.py
similarity index 97%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_transpose_convolution.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_transpose_convolution.py
index 934b4c0d81..81dd8025b6 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_transpose_convolution.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_transpose_convolution.py
@@ -4,7 +4,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import param, parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestTransposeConvolutionConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_type_as.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_type_as.py
similarity index 96%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_type_as.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_type_as.py
index 24f99b5bff..e7455ec84b 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_type_as.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_type_as.py
@@ -3,8 +3,8 @@
 import torch
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
-from torch_tensorrt.dynamo.utils import LowerPrecision
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.utils import LowerPrecision


 class TestTypeAsConverter(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_unary_ops.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_unary_ops.py
similarity index 98%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_unary_ops.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_unary_ops.py
index f88c07c97a..5d50a4ec9b 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_unary_ops.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_unary_ops.py
@@ -6,7 +6,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec

 unary_ops = [
     (torch.sin, acc_ops.sin, False),
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_unsqueeze.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_unsqueeze.py
similarity index 94%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_unsqueeze.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_unsqueeze.py
index a422f1b6fe..19c6b2969c 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_unsqueeze.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_unsqueeze.py
@@ -4,7 +4,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestUnsqueeze(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_where.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_where.py
similarity index 97%
rename from py/torch_tensorrt/dynamo/test/converters/acc_op/test_where.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_where.py
index 2985042f6b..9c846709bf 100644
--- a/py/torch_tensorrt/dynamo/test/converters/acc_op/test_where.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_where.py
@@ -3,7 +3,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase


 class TestWhere(AccTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/aten_op/test_adaptive_avgpool_aten.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_adaptive_avgpool_aten.py
similarity index 97%
rename from py/torch_tensorrt/dynamo/test/converters/aten_op/test_adaptive_avgpool_aten.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_adaptive_avgpool_aten.py
index b51c9a8f9a..5b4930af64 100644
--- a/py/torch_tensorrt/dynamo/test/converters/aten_op/test_adaptive_avgpool_aten.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_adaptive_avgpool_aten.py
@@ -2,7 +2,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import DispatchTestCase, InputTensorSpec


 class TestAdaptiveAvgPoolConverter(DispatchTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/aten_op/test_batchnorm_aten.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_batchnorm_aten.py
similarity index 95%
rename from py/torch_tensorrt/dynamo/test/converters/aten_op/test_batchnorm_aten.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_batchnorm_aten.py
index aed68ba35f..db16461abe 100644
--- a/py/torch_tensorrt/dynamo/test/converters/aten_op/test_batchnorm_aten.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_batchnorm_aten.py
@@ -1,6 +1,6 @@
 import torch
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import DispatchTestCase, InputTensorSpec


 class TestBatchNormConverter(DispatchTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/aten_op/test_binary_ops_aten.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_binary_ops_aten.py
similarity index 98%
rename from py/torch_tensorrt/dynamo/test/converters/aten_op/test_binary_ops_aten.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_binary_ops_aten.py
index b80cd514c1..75e40a2109 100644
--- a/py/torch_tensorrt/dynamo/test/converters/aten_op/test_binary_ops_aten.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_binary_ops_aten.py
@@ -5,7 +5,7 @@
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import DispatchTestCase, InputTensorSpec

 NEED_TEST_BOTH_CONSTANTS_CASE = True
diff --git a/py/torch_tensorrt/dynamo/test/converters/aten_op/test_cat_aten.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_cat_aten.py
similarity index 93%
rename from py/torch_tensorrt/dynamo/test/converters/aten_op/test_cat_aten.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_cat_aten.py
index 1d181c0442..6ab72b00a8 100644
--- a/py/torch_tensorrt/dynamo/test/converters/aten_op/test_cat_aten.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_cat_aten.py
@@ -2,7 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import DispatchTestCase, InputTensorSpec


 class TestCatConverter(DispatchTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/aten_op/test_convolution_aten.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_convolution_aten.py
similarity index 98%
rename from py/torch_tensorrt/dynamo/test/converters/aten_op/test_convolution_aten.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_convolution_aten.py
index 60971038fa..a1f48534d2 100644
--- a/py/torch_tensorrt/dynamo/test/converters/aten_op/test_convolution_aten.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_convolution_aten.py
@@ -1,7 +1,7 @@
 import torch
 from parameterized import param, parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import DispatchTestCase, InputTensorSpec


 class TestConvolutionConverter(DispatchTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/aten_op/test_expand_aten.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_expand_aten.py
similarity index 90%
rename from py/torch_tensorrt/dynamo/test/converters/aten_op/test_expand_aten.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_expand_aten.py
index 380cdc4db3..fe8b32692e 100644
--- a/py/torch_tensorrt/dynamo/test/converters/aten_op/test_expand_aten.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_expand_aten.py
@@ -2,7 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import DispatchTestCase
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import DispatchTestCase


 class TestExpandConverter(DispatchTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/aten_op/test_flatten_aten.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_flatten_aten.py
similarity index 95%
rename from py/torch_tensorrt/dynamo/test/converters/aten_op/test_flatten_aten.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_flatten_aten.py
index b1b0b584f0..d828f12d77 100644
--- a/py/torch_tensorrt/dynamo/test/converters/aten_op/test_flatten_aten.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_flatten_aten.py
@@ -4,7 +4,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import DispatchTestCase, InputTensorSpec


 class TestFlattenConverter(DispatchTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/aten_op/test_linear_aten.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_linear_aten.py
similarity index 95%
rename from py/torch_tensorrt/dynamo/test/converters/aten_op/test_linear_aten.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_linear_aten.py
index 5c06035e37..f9f90c0f09 100644
--- a/py/torch_tensorrt/dynamo/test/converters/aten_op/test_linear_aten.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_linear_aten.py
@@ -1,7 +1,7 @@
 import torch
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import DispatchTestCase, InputTensorSpec


 class TestLinearConverter(DispatchTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/aten_op/test_maxpool_aten.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_maxpool_aten.py
similarity index 98%
rename from py/torch_tensorrt/dynamo/test/converters/aten_op/test_maxpool_aten.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_maxpool_aten.py
index 5a121f0e07..6602f23fe7 100644
--- a/py/torch_tensorrt/dynamo/test/converters/aten_op/test_maxpool_aten.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_maxpool_aten.py
@@ -4,7 +4,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import param, parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import DispatchTestCase, InputTensorSpec


 class TestMaxPoolConverter(DispatchTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/aten_op/test_relu_aten.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_relu_aten.py
similarity index 93%
rename from py/torch_tensorrt/dynamo/test/converters/aten_op/test_relu_aten.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_relu_aten.py
index fb7fe2f509..78f3e381b2 100644
--- a/py/torch_tensorrt/dynamo/test/converters/aten_op/test_relu_aten.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_relu_aten.py
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import DispatchTestCase, InputTensorSpec


 class TestReLUConverter(DispatchTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/aten_op/test_reshape_aten.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_reshape_aten.py
similarity index 96%
rename from py/torch_tensorrt/dynamo/test/converters/aten_op/test_reshape_aten.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_reshape_aten.py
index 09dcb65ab1..0356faa382 100644
--- a/py/torch_tensorrt/dynamo/test/converters/aten_op/test_reshape_aten.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_reshape_aten.py
@@ -4,7 +4,7 @@
 import torch
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import DispatchTestCase, InputTensorSpec


 class TestReshapeConverter(DispatchTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/vanilla/test_add_vanilla.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/vanilla/test_add_vanilla.py
similarity index 87%
rename from py/torch_tensorrt/dynamo/test/converters/vanilla/test_add_vanilla.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/vanilla/test_add_vanilla.py
index 6f805421f4..1e6c748cc1 100644
--- a/py/torch_tensorrt/dynamo/test/converters/vanilla/test_add_vanilla.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/vanilla/test_add_vanilla.py
@@ -5,7 +5,7 @@
 import torch
 import torch.fx
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import VanillaTestCase
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import VanillaTestCase


 class TestAddConverter(VanillaTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/converters/vanilla/test_convolution_vanilla.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/vanilla/test_convolution_vanilla.py
similarity index 97%
rename from py/torch_tensorrt/dynamo/test/converters/vanilla/test_convolution_vanilla.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/vanilla/test_convolution_vanilla.py
index c73c30a30e..4bd1c7519d 100644
--- a/py/torch_tensorrt/dynamo/test/converters/vanilla/test_convolution_vanilla.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/vanilla/test_convolution_vanilla.py
@@ -4,7 +4,7 @@
 import torch.fx
 from parameterized import param, parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.tools.common_fx2trt import VanillaTestCase
+from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import VanillaTestCase


 class TestConvolutionConverter(VanillaTestCase):
diff --git a/py/torch_tensorrt/dynamo/test/core/test_import_fx2trt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_import_fx2trt.py
similarity index 85%
rename from py/torch_tensorrt/dynamo/test/core/test_import_fx2trt.py
rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_import_fx2trt.py
index 7373ddc4fe..12e47ef112 100644
--- a/py/torch_tensorrt/dynamo/test/core/test_import_fx2trt.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_import_fx2trt.py
@@ -5,7 +5,7 @@
 # Test that this import should not trigger any error when run
 # in non-GPU hosts, or in any build mode.
-import torch_tensorrt.dynamo.lower as fxl # noqa: F401 +import torch_tensorrt.dynamo.fx_ts_compat.lower as fxl # noqa: F401 from torch.testing._internal.common_utils import run_tests, TestCase diff --git a/py/torch_tensorrt/dynamo/test/core/test_input.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_input.py similarity index 100% rename from py/torch_tensorrt/dynamo/test/core/test_input.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_input.py diff --git a/py/torch_tensorrt/dynamo/test/core/test_input_tensor_spec.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_input_tensor_spec.py similarity index 97% rename from py/torch_tensorrt/dynamo/test/core/test_input_tensor_spec.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_input_tensor_spec.py index 65c7ea4158..7794b1bac8 100644 --- a/py/torch_tensorrt/dynamo/test/core/test_input_tensor_spec.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_input_tensor_spec.py @@ -4,7 +4,7 @@ import torch from torch.testing._internal.common_utils import run_tests, TestCase -from torch_tensorrt.dynamo import InputTensorSpec, LowerSetting +from torch_tensorrt.dynamo.fx_ts_compat import InputTensorSpec, LowerSetting class TestTRTModule(TestCase): diff --git a/py/torch_tensorrt/dynamo/test/core/test_trt_module.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_trt_module.py similarity index 96% rename from py/torch_tensorrt/dynamo/test/core/test_trt_module.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_trt_module.py index baf98c8d7c..8043b753ac 100644 --- a/py/torch_tensorrt/dynamo/test/core/test_trt_module.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_trt_module.py @@ -9,8 +9,8 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer from torch.testing._internal.common_utils import run_tests, TestCase from torch_tensorrt.fx import TRTModule -from torch_tensorrt.dynamo import InputTensorSpec, TRTInterpreter -from torch_tensorrt.dynamo.utils import LowerPrecision +from torch_tensorrt.dynamo.fx_ts_compat import InputTensorSpec, TRTInterpreter +from torch_tensorrt.dynamo.fx_ts_compat.utils import LowerPrecision class TestTRTModule(TestCase): diff --git a/py/torch_tensorrt/dynamo/test/passes/test_fix_clamp_numerical_limits_to_fp16.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_fix_clamp_numerical_limits_to_fp16.py similarity index 100% rename from py/torch_tensorrt/dynamo/test/passes/test_fix_clamp_numerical_limits_to_fp16.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_fix_clamp_numerical_limits_to_fp16.py diff --git a/py/torch_tensorrt/dynamo/test/passes/test_fix_reshape_batch_dim.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_fix_reshape_batch_dim.py similarity index 100% rename from py/torch_tensorrt/dynamo/test/passes/test_fix_reshape_batch_dim.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_fix_reshape_batch_dim.py diff --git a/py/torch_tensorrt/dynamo/test/passes/test_fuse_permute_linear_trt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_fuse_permute_linear_trt.py similarity index 97% rename from py/torch_tensorrt/dynamo/test/passes/test_fuse_permute_linear_trt.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_fuse_permute_linear_trt.py index 12fd25447d..b9b39c8a1b 100644 --- a/py/torch_tensorrt/dynamo/test/passes/test_fuse_permute_linear_trt.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_fuse_permute_linear_trt.py @@ -9,7 +9,7 
@@ fuse_permute_linear, trt_transposed_linear, ) -from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase class TestFusePermuteLinear(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/test/passes/test_fuse_permute_matmul_trt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_fuse_permute_matmul_trt.py similarity index 98% rename from py/torch_tensorrt/dynamo/test/passes/test_fuse_permute_matmul_trt.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_fuse_permute_matmul_trt.py index fec3edbf9a..6570f6c276 100644 --- a/py/torch_tensorrt/dynamo/test/passes/test_fuse_permute_matmul_trt.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_fuse_permute_matmul_trt.py @@ -8,7 +8,7 @@ fuse_permute_matmul, trt_transposed_matmul, ) -from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase def tranpose_last_two_dims(x): diff --git a/py/torch_tensorrt/dynamo/test/passes/test_graph_opts.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_graph_opts.py similarity index 100% rename from py/torch_tensorrt/dynamo/test/passes/test_graph_opts.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_graph_opts.py diff --git a/py/torch_tensorrt/dynamo/test/passes/test_multi_fuse_trt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_multi_fuse_trt.py similarity index 95% rename from py/torch_tensorrt/dynamo/test/passes/test_multi_fuse_trt.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_multi_fuse_trt.py index af4279db2f..9712ca3b91 100644 --- a/py/torch_tensorrt/dynamo/test/passes/test_multi_fuse_trt.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_multi_fuse_trt.py @@ -10,7 +10,7 @@ trt_transposed_linear, trt_transposed_matmul, ) -from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase def permute021(x): diff --git a/py/torch_tensorrt/dynamo/test/passes/test_remove_duplicate_output_args.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_remove_duplicate_output_args.py similarity index 100% rename from py/torch_tensorrt/dynamo/test/passes/test_remove_duplicate_output_args.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_remove_duplicate_output_args.py diff --git a/py/torch_tensorrt/dynamo/test/passes/test_setitem_trt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_setitem_trt.py similarity index 99% rename from py/torch_tensorrt/dynamo/test/passes/test_setitem_trt.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_setitem_trt.py index ded67e97dc..777796a083 100644 --- a/py/torch_tensorrt/dynamo/test/passes/test_setitem_trt.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_setitem_trt.py @@ -4,7 +4,7 @@ from torch._dynamo.optimizations import backends from torch.testing._internal.common_utils import run_tests from torch_tensorrt.fx.passes.lower_basic_pass import transform_setitem -from torch_tensorrt.dynamo.tools.common_fx2trt import AccTestCase +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase class TestTransformSetitem(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/test/quant/test_quant_trt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/quant/test_quant_trt.py similarity index 99% rename from 
py/torch_tensorrt/dynamo/test/quant/test_quant_trt.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/quant/test_quant_trt.py index e3fa371e38..fabd94e24c 100644 --- a/py/torch_tensorrt/dynamo/test/quant/test_quant_trt.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/quant/test_quant_trt.py @@ -28,11 +28,11 @@ QuantizationTestCase, ) from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo import InputTensorSpec, TRTInterpreter +from torch_tensorrt.dynamo.fx_ts_compat import InputTensorSpec, TRTInterpreter from torch_tensorrt.fx import TRTModule from torch_tensorrt.fx.passes.lower_basic_pass import run_const_fold from torch_tensorrt.fx.tracer.acc_tracer import acc_ops -from torch_tensorrt.dynamo.utils import LowerPrecision +from torch_tensorrt.dynamo.fx_ts_compat.utils import LowerPrecision def lower_to_trt(model, inputs, shape_ranges): diff --git a/py/torch_tensorrt/dynamo/test/tools/test_model_packager.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/tools/test_model_packager.py similarity index 95% rename from py/torch_tensorrt/dynamo/test/tools/test_model_packager.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/tools/test_model_packager.py index 00293ccadb..209181137e 100644 --- a/py/torch_tensorrt/dynamo/test/tools/test_model_packager.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/tools/test_model_packager.py @@ -5,7 +5,7 @@ import torch.fx from torch import nn from torch.package import PackageImporter -from torch_tensorrt.dynamo.tools.model_packager import ( +from torch_tensorrt.dynamo.fx_ts_compat.tools.model_packager import ( generate_standalone_repro, ModelPackager, ) diff --git a/py/torch_tensorrt/dynamo/test/tracer/test_acc_shape_prop.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/tracer/test_acc_shape_prop.py similarity index 100% rename from py/torch_tensorrt/dynamo/test/tracer/test_acc_shape_prop.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/tracer/test_acc_shape_prop.py diff --git a/py/torch_tensorrt/dynamo/test/tracer/test_acc_tracer.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/tracer/test_acc_tracer.py similarity index 100% rename from py/torch_tensorrt/dynamo/test/tracer/test_acc_tracer.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/tracer/test_acc_tracer.py diff --git a/py/torch_tensorrt/dynamo/test/tracer/test_dispatch_tracer.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/tracer/test_dispatch_tracer.py similarity index 98% rename from py/torch_tensorrt/dynamo/test/tracer/test_dispatch_tracer.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/tracer/test_dispatch_tracer.py index 4af730d67c..a066bc4413 100644 --- a/py/torch_tensorrt/dynamo/test/tracer/test_dispatch_tracer.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/tracer/test_dispatch_tracer.py @@ -11,9 +11,9 @@ from torch._dynamo.optimizations.normalize import normalize_ir from torch.library import Library -from torch_tensorrt.dynamo.lower import compile +from torch_tensorrt.dynamo.fx_ts_compat.lower import compile from torch_tensorrt.fx.tracer.dispatch_tracer.tracer import make_fx -from torch_tensorrt.dynamo.utils import LowerPrecision, proxytensor_trace +from torch_tensorrt.dynamo.fx_ts_compat.utils import LowerPrecision, proxytensor_trace # TODO(ezyang): remove this after we properly support fake example inputs torch._dynamo.config.DO_NOT_USE_legacy_non_fake_example_inputs = True diff --git a/py/torch_tensorrt/dynamo/test/tracer/test_resnet.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/tracer/test_resnet.py similarity index 88% 
rename from py/torch_tensorrt/dynamo/test/tracer/test_resnet.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/tracer/test_resnet.py index cf04edc5d9..1dfdfa7125 100644 --- a/py/torch_tensorrt/dynamo/test/tracer/test_resnet.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/tracer/test_resnet.py @@ -4,8 +4,8 @@ import torch._dynamo.config import torchvision -from torch_tensorrt.dynamo.lower import compile -from torch_tensorrt.dynamo.utils import LowerPrecision +from torch_tensorrt.dynamo.fx_ts_compat.lower import compile +from torch_tensorrt.dynamo.fx_ts_compat.utils import LowerPrecision class ResnetTest(unittest.TestCase): @@ -19,7 +19,7 @@ def test_resnet18_aten(self): aten_mod = compile( mod, inputs, - lower_precision=LowerPrecision.FP16, + enabled_precisions={torch.float16}, verbose_log=False, timing_cache_prefix="", save_timing_cache=False, @@ -31,7 +31,7 @@ def test_resnet18_aten(self): fx_mod = compile( mod, inputs, - lower_precision=LowerPrecision.FP16, + enabled_precisions={torch.float16}, verbose_log=False, timing_cache_prefix="", save_timing_cache=False, @@ -59,7 +59,7 @@ def test_resnet18_aten_dynamic(self): aten_mod = compile( mod, inputs, - lower_precision=LowerPrecision.FP16, + enabled_precisions={torch.float16}, verbose_log=False, timing_cache_prefix="", save_timing_cache=False, @@ -71,7 +71,7 @@ def test_resnet18_aten_dynamic(self): fx_mod = compile( mod, inputs, - lower_precision=LowerPrecision.FP16, + enabled_precisions={torch.float16}, verbose_log=False, timing_cache_prefix="", save_timing_cache=False, diff --git a/py/torch_tensorrt/dynamo/test/trt_lower/test_diagnostics.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/test_diagnostics.py similarity index 100% rename from py/torch_tensorrt/dynamo/test/trt_lower/test_diagnostics.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/test_diagnostics.py diff --git a/py/torch_tensorrt/dynamo/test/trt_lower/test_fx2trt_lower.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/test_fx2trt_lower.py similarity index 97% rename from py/torch_tensorrt/dynamo/test/trt_lower/test_fx2trt_lower.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/test_fx2trt_lower.py index 7546346d18..4077fe1491 100644 --- a/py/torch_tensorrt/dynamo/test/trt_lower/test_fx2trt_lower.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/test_fx2trt_lower.py @@ -6,7 +6,7 @@ import torch import torch.fx as fx import torch.nn as nn -from torch_tensorrt.dynamo.lower import Lowerer, LowerSetting +from torch_tensorrt.dynamo.fx_ts_compat.lower import Lowerer, LowerSetting from torch_tensorrt.fx.passes.lower_basic_pass import replace_mutable_op logger = logging.getLogger(__name__) diff --git a/py/torch_tensorrt/dynamo/test/trt_lower/test_observer.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/test_observer.py similarity index 100% rename from py/torch_tensorrt/dynamo/test/trt_lower/test_observer.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/test_observer.py diff --git a/py/torch_tensorrt/dynamo/test/trt_lower/test_observer_gpu.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/test_observer_gpu.py similarity index 91% rename from py/torch_tensorrt/dynamo/test/trt_lower/test_observer_gpu.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/test_observer_gpu.py index 005891cbfd..bd17f42e72 100644 --- a/py/torch_tensorrt/dynamo/test/trt_lower/test_observer_gpu.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/test_observer_gpu.py @@ -19,8 +19,8 
@@ def test_observe_lowerer(self): import torch import torch.nn as nn - import torch_tensorrt.dynamo.lower as lower - from torch_tensorrt.dynamo.lower_setting import LowerSetting + import torch_tensorrt.dynamo.fx_ts_compat.lower as lower + from torch_tensorrt.dynamo.fx_ts_compat.lower_setting import LowerSetting class Model(nn.Module): def forward(self, x, y): diff --git a/py/torch_tensorrt/dynamo/test/trt_lower/trt_operator_supported_test.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/trt_operator_supported_test.py similarity index 96% rename from py/torch_tensorrt/dynamo/test/trt_lower/trt_operator_supported_test.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/trt_operator_supported_test.py index 699b787f0e..400afe0457 100644 --- a/py/torch_tensorrt/dynamo/test/trt_lower/trt_operator_supported_test.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/trt_operator_supported_test.py @@ -5,7 +5,7 @@ import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops # noqa: F401 from torch.testing._internal.common_utils import run_tests, TestCase -from torch_tensorrt.dynamo.tools.trt_splitter import create_trt_operator_support +from torch_tensorrt.dynamo.fx_ts_compat.tools.trt_splitter import create_trt_operator_support from torch_tensorrt.fx.tracer.acc_tracer import acc_ops, acc_tracer diff --git a/py/torch_tensorrt/dynamo/test/trt_lower/trt_splitter_test.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/trt_splitter_test.py similarity index 99% rename from py/torch_tensorrt/dynamo/test/trt_lower/trt_splitter_test.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/trt_splitter_test.py index 7738b829b9..e9d2433684 100644 --- a/py/torch_tensorrt/dynamo/test/trt_lower/trt_splitter_test.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/trt_splitter_test.py @@ -10,7 +10,7 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from torch.fx.passes import splitter_base from torch.testing._internal.common_utils import run_tests, TestCase -from torch_tensorrt.dynamo.tools.trt_splitter import TRTSplitter, TRTSplitterSetting +from torch_tensorrt.dynamo.fx_ts_compat.tools.trt_splitter import TRTSplitter, TRTSplitterSetting from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer ERROR_MSG_NO_ACC_MODULE = "FX split failed: Did not find any ACC submodule!" 
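[Editorial note] For downstream code, the renames above amount to a one-line import-path change: the tools, lowering, and utility modules keep their names but move under the fx_ts_compat namespace. A minimal sketch of the migration, assuming a build of this patch series is installed; the module and symbol names below are taken directly from the renames shown in this patch:

    # Before this series (old flat layout, shown here commented out):
    # from torch_tensorrt.dynamo.tools.trt_splitter import TRTSplitter, TRTSplitterSetting
    # from torch_tensorrt.dynamo.lower import compile
    # from torch_tensorrt.dynamo.utils import LowerPrecision

    # After this series: identical modules, new fx_ts_compat prefix.
    from torch_tensorrt.dynamo.fx_ts_compat.tools.trt_splitter import (
        TRTSplitter,
        TRTSplitterSetting,
    )
    from torch_tensorrt.dynamo.fx_ts_compat.lower import compile
    from torch_tensorrt.dynamo.fx_ts_compat.utils import LowerPrecision

Only the import prefix changes; call sites using these names are untouched, which is why the bulk of this patch is pure renames.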
diff --git a/py/torch_tensorrt/dynamo/tools/__init__.py b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/__init__.py similarity index 100% rename from py/torch_tensorrt/dynamo/tools/__init__.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/tools/__init__.py diff --git a/py/torch_tensorrt/dynamo/tools/common_fx2trt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/common_fx2trt.py similarity index 98% rename from py/torch_tensorrt/dynamo/tools/common_fx2trt.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/tools/common_fx2trt.py index da09c00ab9..5c0d8bbc76 100644 --- a/py/torch_tensorrt/dynamo/tools/common_fx2trt.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/common_fx2trt.py @@ -13,7 +13,7 @@ from torch.fx.passes import shape_prop from torch.fx.passes.infra.pass_base import PassResult from torch.testing._internal.common_utils import TestCase -from torch_tensorrt.dynamo import InputTensorSpec, TRTInterpreter +from torch_tensorrt.dynamo.fx_ts_compat import InputTensorSpec, TRTInterpreter from torch_tensorrt.fx.passes.lower_basic_pass_aten import ( compose_bmm, compose_chunk, @@ -26,8 +26,8 @@ replace_transpose_mm_op_with_linear, run_const_fold, ) -from torch_tensorrt.dynamo.passes.pass_utils import chain_passes -from torch_tensorrt.dynamo.utils import LowerPrecision, proxytensor_trace +from torch_tensorrt.dynamo.fx_ts_compat.passes.pass_utils import chain_passes +from torch_tensorrt.dynamo.fx_ts_compat.utils import LowerPrecision, proxytensor_trace _LOGGER: logging.Logger = logging.getLogger(__name__) diff --git a/py/torch_tensorrt/dynamo/tools/engine_layer_visualize.py b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/engine_layer_visualize.py similarity index 100% rename from py/torch_tensorrt/dynamo/tools/engine_layer_visualize.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/tools/engine_layer_visualize.py diff --git a/py/torch_tensorrt/dynamo/tools/graph_util.py b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/graph_util.py similarity index 100% rename from py/torch_tensorrt/dynamo/tools/graph_util.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/tools/graph_util.py diff --git a/py/torch_tensorrt/dynamo/tools/model_packager.py b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/model_packager.py similarity index 100% rename from py/torch_tensorrt/dynamo/tools/model_packager.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/tools/model_packager.py diff --git a/py/torch_tensorrt/dynamo/tools/node_profiler.py b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/node_profiler.py similarity index 100% rename from py/torch_tensorrt/dynamo/tools/node_profiler.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/tools/node_profiler.py diff --git a/py/torch_tensorrt/dynamo/tools/tensor_prop.py b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/tensor_prop.py similarity index 100% rename from py/torch_tensorrt/dynamo/tools/tensor_prop.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/tools/tensor_prop.py diff --git a/py/torch_tensorrt/dynamo/tools/timing_cache_utils.py b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/timing_cache_utils.py similarity index 100% rename from py/torch_tensorrt/dynamo/tools/timing_cache_utils.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/tools/timing_cache_utils.py diff --git a/py/torch_tensorrt/dynamo/tools/trt_minimizer.py b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/trt_minimizer.py similarity index 100% rename from py/torch_tensorrt/dynamo/tools/trt_minimizer.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/tools/trt_minimizer.py diff --git 
a/py/torch_tensorrt/dynamo/tools/trt_profiler_sorted.py b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/trt_profiler_sorted.py similarity index 100% rename from py/torch_tensorrt/dynamo/tools/trt_profiler_sorted.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/tools/trt_profiler_sorted.py diff --git a/py/torch_tensorrt/dynamo/tools/trt_splitter.py b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/trt_splitter.py similarity index 100% rename from py/torch_tensorrt/dynamo/tools/trt_splitter.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/tools/trt_splitter.py diff --git a/py/torch_tensorrt/dynamo/types.py b/py/torch_tensorrt/dynamo/fx_ts_compat/types.py similarity index 100% rename from py/torch_tensorrt/dynamo/types.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/types.py diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/fx_ts_compat/utils.py similarity index 100% rename from py/torch_tensorrt/dynamo/utils.py rename to py/torch_tensorrt/dynamo/fx_ts_compat/utils.py From 8b15cdca52bd92339e0df3acabe813adb9671c5f Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 11 Apr 2023 15:23:54 -0700 Subject: [PATCH 34/45] chore: Update setup.py Signed-off-by: Dheeraj Peri --- py/setup.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/py/setup.py b/py/setup.py index 200a43c533..d3c969b1ee 100644 --- a/py/setup.py +++ b/py/setup.py @@ -357,8 +357,9 @@ def run(self): "torch_tensorrt.fx.tracer.acc_tracer", "torch_tensorrt.fx.tracer.dispatch_tracer", "torch_tensorrt.dynamo", - "torch_tensorrt.dynamo.passes", - "torch_tensorrt.dynamo.tools", + "torch_tensorrt.dynamo.fx_ts_compat", + "torch_tensorrt.dynamo.fx_ts_compat.passes", + "torch_tensorrt.dynamo.fx_ts_compat.tools", ] package_dir = { "torch_tensorrt.fx": "torch_tensorrt/fx", @@ -368,8 +369,9 @@ def run(self): "torch_tensorrt.fx.tracer.acc_tracer": "torch_tensorrt/fx/tracer/acc_tracer", "torch_tensorrt.fx.tracer.dispatch_tracer": "torch_tensorrt/fx/tracer/dispatch_tracer", "torch_tensorrt.dynamo": "torch_tensorrt/dynamo", - "torch_tensorrt.dynamo.passes": "torch_tensorrt/dynamo/passes", - "torch_tensorrt.dynamo.tools": "torch_tensorrt/dynamo/tools", + "torch_tensorrt.dynamo.fx_ts_compat": "torch_tensorrt/dynamo/fx_ts_compat", + "torch_tensorrt.dynamo.fx_ts_compat.passes": "torch_tensorrt/dynamo/fx_ts_compat/passes", + "torch_tensorrt.dynamo.fx_ts_compat.tools": "torch_tensorrt/dynamo/fx_ts_compat/tools", } with open("README.md", "r", encoding="utf-8") as fh: From 8f42a1872c3992bce2eeb32ffc83c0fa2a3be345 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 11 Apr 2023 18:16:41 -0700 Subject: [PATCH 35/45] chore: Rename ir to fx_ts_compat Signed-off-by: Dheeraj Peri --- py/torch_tensorrt/_compile.py | 12 ++++++------ .../dynamo/fx_ts_compat/test/core/test_input.py | 6 +++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index 4aea98852c..fb96f5e373 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -15,7 +15,7 @@ class _IRType(Enum): ts = 0 fx = 1 - dynamo = 2 + fx_ts_compat = 2 class _ModuleType(Enum): @@ -46,14 +46,14 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType: ir_targets_torchscript = any([ir == opt for opt in ["torchscript", "ts"]]) ir_targets_fx = ir == "fx" - ir_targets_dynamo = ir == "dynamo" + ir_targets_fx_ts_compat = ir == "fx_ts_compat" if module_is_tsable and ir_targets_torchscript: return _IRType.ts elif module_is_fxable and ir_targets_fx: return 
_IRType.fx - elif module_is_fxable and ir_targets_dynamo: - return _IRType.dynamo + elif module_is_fxable and ir_targets_fx_ts_compat: + return _IRType.fx_ts_compat else: if ir == "default": # Options are listed in order of preference @@ -152,8 +152,8 @@ def compile( dynamic_batch=False, **kwargs, ) - elif target_ir == _IRType.dynamo: - return torch_tensorrt.dynamo.compile( + elif target_ir == _IRType.fx_ts_compat: + return torch_tensorrt.dynamo.fx_ts_compat.compile( module, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs ) else: diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_input.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_input.py index 869482dbef..b7dd8153cb 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_input.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_input.py @@ -21,7 +21,7 @@ def forward(self, x): trt_mod = torch_tensorrt.compile( mod, - ir="dynamo", + ir="fx_ts_compat", inputs=inputs, min_block_size=1, ) @@ -45,7 +45,7 @@ def forward(self, x): trt_mod = torch_tensorrt.compile( mod, - ir="dynamo", + ir="fx_ts_compat", inputs=inputs, min_block_size=1, ) @@ -76,7 +76,7 @@ def forward(self, x): trt_mod = torch_tensorrt.compile( mod, - ir="dynamo", + ir="fx_ts_compat", inputs=inputs, min_block_size=1, ) From d49b46c6a5f110eb7a0ba93f709cd89e0fd13791 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 11 Apr 2023 18:51:01 -0700 Subject: [PATCH 36/45] chore: Fix import Signed-off-by: Dheeraj Peri --- py/torch_tensorrt/dynamo/__init__.py | 1 + 1 file changed, 1 insertion(+) create mode 100644 py/torch_tensorrt/dynamo/__init__.py diff --git a/py/torch_tensorrt/dynamo/__init__.py b/py/torch_tensorrt/dynamo/__init__.py new file mode 100644 index 0000000000..c018cfc332 --- /dev/null +++ b/py/torch_tensorrt/dynamo/__init__.py @@ -0,0 +1 @@ +from torch_tensorrt.dynamo import fx_ts_compat From cf5bb20e3d036c6c01368bdd8cc187e7ab74530a Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 11 Apr 2023 18:52:32 -0700 Subject: [PATCH 37/45] chore: Linter fixes Signed-off-by: Dheeraj Peri --- .../test/converters/acc_op/test_adaptive_avgpool.py | 5 ++++- .../fx_ts_compat/test/converters/acc_op/test_avgpool.py | 5 ++++- .../fx_ts_compat/test/converters/acc_op/test_batchnorm.py | 5 ++++- .../fx_ts_compat/test/converters/acc_op/test_binary_ops.py | 5 ++++- .../dynamo/fx_ts_compat/test/converters/acc_op/test_cat.py | 5 ++++- .../dynamo/fx_ts_compat/test/converters/acc_op/test_chunk.py | 5 ++++- .../dynamo/fx_ts_compat/test/converters/acc_op/test_clamp.py | 5 ++++- .../fx_ts_compat/test/converters/acc_op/test_convolution.py | 5 ++++- .../fx_ts_compat/test/converters/acc_op/test_dequantize.py | 5 ++++- .../fx_ts_compat/test/converters/acc_op/test_einsum.py | 5 ++++- .../dynamo/fx_ts_compat/test/converters/acc_op/test_elu.py | 5 ++++- .../fx_ts_compat/test/converters/acc_op/test_embedding.py | 5 ++++- .../dynamo/fx_ts_compat/test/converters/acc_op/test_eq.py | 5 ++++- .../dynamo/fx_ts_compat/test/converters/acc_op/test_gelu.py | 5 ++++- .../fx_ts_compat/test/converters/acc_op/test_getitem.py | 5 ++++- .../dynamo/fx_ts_compat/test/converters/acc_op/test_gt.py | 5 ++++- .../fx_ts_compat/test/converters/acc_op/test_hard_sigmoid.py | 5 ++++- .../fx_ts_compat/test/converters/acc_op/test_hardtanh.py | 5 ++++- .../fx_ts_compat/test/converters/acc_op/test_interpolate.py | 5 ++++- .../dynamo/fx_ts_compat/test/converters/acc_op/test_isinf.py | 5 ++++- .../fx_ts_compat/test/converters/acc_op/test_leaky_relu.py | 5 ++++- 
.../fx_ts_compat/test/converters/acc_op/test_linear.py | 5 ++++- .../fx_ts_compat/test/converters/acc_op/test_logical_and.py | 5 ++++- .../fx_ts_compat/test/converters/acc_op/test_logical_or.py | 5 ++++- .../fx_ts_compat/test/converters/acc_op/test_logical_xor.py | 5 ++++- .../dynamo/fx_ts_compat/test/converters/acc_op/test_lt.py | 5 ++++- .../fx_ts_compat/test/converters/acc_op/test_matmul.py | 5 ++++- .../dynamo/fx_ts_compat/test/converters/acc_op/test_max.py | 5 ++++- .../fx_ts_compat/test/converters/acc_op/test_maximum.py | 5 ++++- .../fx_ts_compat/test/converters/acc_op/test_maxpool.py | 5 ++++- .../dynamo/fx_ts_compat/test/converters/acc_op/test_min.py | 5 ++++- .../fx_ts_compat/test/converters/acc_op/test_minimum.py | 5 ++++- .../fx_ts_compat/test/converters/acc_op/test_narrow.py | 5 ++++- .../dynamo/fx_ts_compat/test/converters/acc_op/test_ne.py | 5 ++++- .../fx_ts_compat/test/converters/acc_op/test_new_ones.py | 5 ++++- .../fx_ts_compat/test/converters/acc_op/test_permute.py | 5 ++++- .../dynamo/fx_ts_compat/test/converters/acc_op/test_prod.py | 5 ++++- .../test/converters/acc_op/test_quantize_per_tensor.py | 5 ++++- .../fx_ts_compat/test/converters/acc_op/test_reduce_ops.py | 5 ++++- .../dynamo/fx_ts_compat/test/converters/acc_op/test_relu.py | 5 ++++- .../test/converters/acc_op/test_repeat_interleave.py | 5 ++++- .../fx_ts_compat/test/converters/acc_op/test_reshape.py | 5 ++++- .../dynamo/fx_ts_compat/test/converters/acc_op/test_selu.py | 5 ++++- .../fx_ts_compat/test/converters/acc_op/test_sigmoid.py | 5 ++++- .../dynamo/fx_ts_compat/test/converters/acc_op/test_size.py | 5 ++++- .../fx_ts_compat/test/converters/acc_op/test_softmax.py | 5 ++++- .../fx_ts_compat/test/converters/acc_op/test_softsign.py | 5 ++++- .../dynamo/fx_ts_compat/test/converters/acc_op/test_split.py | 5 ++++- .../fx_ts_compat/test/converters/acc_op/test_squeeze.py | 5 ++++- .../dynamo/fx_ts_compat/test/converters/acc_op/test_std.py | 5 ++++- .../dynamo/fx_ts_compat/test/converters/acc_op/test_tanh.py | 5 ++++- .../dynamo/fx_ts_compat/test/converters/acc_op/test_tile.py | 5 ++++- .../fx_ts_compat/test/converters/acc_op/test_to_dtype.py | 5 ++++- .../dynamo/fx_ts_compat/test/converters/acc_op/test_topk.py | 5 ++++- .../test/converters/acc_op/test_transpose_convolution.py | 5 ++++- .../fx_ts_compat/test/converters/acc_op/test_type_as.py | 5 ++++- .../fx_ts_compat/test/converters/acc_op/test_unary_ops.py | 5 ++++- .../fx_ts_compat/test/converters/acc_op/test_unsqueeze.py | 5 ++++- .../test/converters/aten_op/test_adaptive_avgpool_aten.py | 5 ++++- .../test/converters/aten_op/test_batchnorm_aten.py | 5 ++++- .../test/converters/aten_op/test_binary_ops_aten.py | 5 ++++- .../fx_ts_compat/test/converters/aten_op/test_cat_aten.py | 5 ++++- .../test/converters/aten_op/test_convolution_aten.py | 5 ++++- .../test/converters/aten_op/test_flatten_aten.py | 5 ++++- .../fx_ts_compat/test/converters/aten_op/test_linear_aten.py | 5 ++++- .../test/converters/aten_op/test_maxpool_aten.py | 5 ++++- .../fx_ts_compat/test/converters/aten_op/test_relu_aten.py | 5 ++++- .../test/converters/aten_op/test_reshape_aten.py | 5 ++++- .../test/trt_lower/trt_operator_supported_test.py | 4 +++- .../dynamo/fx_ts_compat/test/trt_lower/trt_splitter_test.py | 5 ++++- 70 files changed, 279 insertions(+), 70 deletions(-) diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_adaptive_avgpool.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_adaptive_avgpool.py index ddd647707c..37f8dcade8 100644 --- 
a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_adaptive_avgpool.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_adaptive_avgpool.py @@ -2,7 +2,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestAdaptiveAvgPoolConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_avgpool.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_avgpool.py index 9eb579d326..f9cb1cb9cd 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_avgpool.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_avgpool.py @@ -2,7 +2,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import param, parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestAvgPoolConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_batchnorm.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_batchnorm.py index c22772e8ad..d52bcd8905 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_batchnorm.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_batchnorm.py @@ -1,7 +1,10 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestBatchNormConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_binary_ops.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_binary_ops.py index 6c226ec405..ae006e03a9 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_binary_ops.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_binary_ops.py @@ -6,7 +6,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) NEED_TEST_BOTH_CONSTANTS_CASE = True diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_cat.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_cat.py index 5124cd5f05..807ab8842e 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_cat.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_cat.py @@ -3,7 +3,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import param, parameterized from torch.testing._internal.common_utils import run_tests -from 
torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestCatConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_chunk.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_chunk.py index 81e575a3e7..42706d8e1f 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_chunk.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_chunk.py @@ -3,7 +3,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestChunkConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_clamp.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_clamp.py index f7fdfd653d..a64d58a98b 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_clamp.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_clamp.py @@ -2,7 +2,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import param, parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestClampConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_convolution.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_convolution.py index bbf03b7888..ab29f0dfc3 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_convolution.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_convolution.py @@ -2,7 +2,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import param, parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestConvolutionConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_dequantize.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_dequantize.py index 40a316aa2d..1f7f6cbe88 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_dequantize.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_dequantize.py @@ -6,7 +6,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) @unittest.skip( diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_einsum.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_einsum.py index 
d7f3268b55..c6beebdf4c 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_einsum.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_einsum.py @@ -3,7 +3,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_elu.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_elu.py index bc52c5688b..c35154bd76 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_elu.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_elu.py @@ -2,7 +2,10 @@ import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestELUConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_embedding.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_embedding.py index 34853249cc..05186300b4 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_embedding.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_embedding.py @@ -5,7 +5,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import param, parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) @unittest.skip( diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_eq.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_eq.py index e83b1c8c16..8cb9185673 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_eq.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_eq.py @@ -2,7 +2,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestEqConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_gelu.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_gelu.py index 36d427d49f..1c7c8264f2 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_gelu.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_gelu.py @@ -4,7 +4,10 @@ import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from 
torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) @unittest.skip( diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_getitem.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_getitem.py index 012f6f8d05..880cbe2418 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_getitem.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_getitem.py @@ -3,7 +3,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestGetitemConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_gt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_gt.py index b08df22e1d..fac763acb0 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_gt.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_gt.py @@ -2,7 +2,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestGtConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_hard_sigmoid.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_hard_sigmoid.py index 40592245ab..cfe0e2b52e 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_hard_sigmoid.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_hard_sigmoid.py @@ -2,7 +2,10 @@ from parameterized import parameterized from torch import nn from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) from torch_tensorrt.fx.tracer.acc_tracer import acc_ops diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_hardtanh.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_hardtanh.py index f729a2f4d2..469816e2b4 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_hardtanh.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_hardtanh.py @@ -3,7 +3,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestHardtanhConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_interpolate.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_interpolate.py index 4c903f2f0c..8eefb88ed9 100644 --- 
a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_interpolate.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_interpolate.py @@ -3,7 +3,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestInterpolateConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_isinf.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_isinf.py index 2e52c4161d..89c65e7eff 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_isinf.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_isinf.py @@ -4,7 +4,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) @unittest.skip("Implementation is commented out due to accuracy issue T113156424") diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_leaky_relu.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_leaky_relu.py index 12244316de..02deb0ee57 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_leaky_relu.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_leaky_relu.py @@ -2,7 +2,10 @@ import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestLeakyReLUConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_linear.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_linear.py index 3eda648854..25353e8f29 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_linear.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_linear.py @@ -2,7 +2,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestLinearConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_logical_and.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_logical_and.py index 23b63ca03c..71851221c2 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_logical_and.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_logical_and.py @@ -2,7 +2,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from 
torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestAndMethodSimpleConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_logical_or.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_logical_or.py index 1db491bf05..4f45612b34 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_logical_or.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_logical_or.py @@ -2,7 +2,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestLogicalOrMethodSimpleConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_logical_xor.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_logical_xor.py index b27de8f492..591c7322bf 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_logical_xor.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_logical_xor.py @@ -2,7 +2,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestLogicalXorMethodSimpleConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_lt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_lt.py index 45000c8b80..6d037145ac 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_lt.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_lt.py @@ -2,7 +2,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestLtConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_matmul.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_matmul.py index 19fa661df5..2f979f1243 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_matmul.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_matmul.py @@ -3,7 +3,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestMatMulConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_max.py 
b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_max.py index 8857f87351..be6b4cdedc 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_max.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_max.py @@ -2,7 +2,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestMaxConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_maximum.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_maximum.py index 9c819edd67..8c1522d3ad 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_maximum.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_maximum.py @@ -1,7 +1,10 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestMaximumConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_maxpool.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_maxpool.py index 8018f632d8..7ed6301467 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_maxpool.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_maxpool.py @@ -2,7 +2,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import param, parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestMaxPoolConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_min.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_min.py index be0b1b0c9e..6d09db1d5c 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_min.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_min.py @@ -2,7 +2,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestMinConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_minimum.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_minimum.py index 6c3431d1a3..7778f784a2 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_minimum.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_minimum.py @@ -1,7 +1,10 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from torch.testing._internal.common_utils import run_tests 
-from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestMinimumConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_narrow.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_narrow.py index c4dab876df..13d0e257ac 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_narrow.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_narrow.py @@ -3,7 +3,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestNarrowConverterWithDynamicShape(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_ne.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_ne.py index 5b3f4e7c94..2fd99787b4 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_ne.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_ne.py @@ -2,7 +2,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestNeFunctionConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_new_ones.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_new_ones.py index 4b5dacbea7..79754b38d4 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_new_ones.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_new_ones.py @@ -2,7 +2,10 @@ import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestNewOnesConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_permute.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_permute.py index 6bed54f73e..a8b8c95f0b 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_permute.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_permute.py @@ -3,7 +3,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestPermuteConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_prod.py 
b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_prod.py index fea3658644..e13c8b3048 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_prod.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_prod.py @@ -2,7 +2,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) # NOTE torch.prod will only accept one dim unlike other reduce ops which accept tuples diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_quantize_per_tensor.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_quantize_per_tensor.py index 29576ae37e..eaef10df94 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_quantize_per_tensor.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_quantize_per_tensor.py @@ -6,7 +6,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) @unittest.skip( diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_reduce_ops.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_reduce_ops.py index 20d68925af..4fe7f8511c 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_reduce_ops.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_reduce_ops.py @@ -2,7 +2,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) reduce_ops = [(torch.sum, acc_ops.sum), (torch.mean, acc_ops.mean)] diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_relu.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_relu.py index d35492a35d..774cd6fec7 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_relu.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_relu.py @@ -2,7 +2,10 @@ import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestReLUConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_repeat_interleave.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_repeat_interleave.py index d34706d2d8..0c4360d53f 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_repeat_interleave.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_repeat_interleave.py @@ -3,7 +3,10 @@ from 
parameterized import parameterized from torch import nn from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestRepeatInterLeave(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_reshape.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_reshape.py index 716db89f3a..dba833276f 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_reshape.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_reshape.py @@ -2,7 +2,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestReshapeConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_selu.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_selu.py index b5d4ab774e..cbc4c04117 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_selu.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_selu.py @@ -2,7 +2,10 @@ import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestSeLUConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_sigmoid.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_sigmoid.py index 36908a0720..77aa8c9392 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_sigmoid.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_sigmoid.py @@ -2,7 +2,10 @@ import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestSigmoid(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_size.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_size.py index 5834577e9c..411b8b6a46 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_size.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_size.py @@ -2,7 +2,10 @@ import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestSizeConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_softmax.py 
b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_softmax.py index dfed9e1f12..20c4ab744d 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_softmax.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_softmax.py @@ -3,7 +3,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestSoftmaxConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_softsign.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_softsign.py index 6a10923596..73b97a02b6 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_softsign.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_softsign.py @@ -2,7 +2,10 @@ import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestSoftsignConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_split.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_split.py index adafb65655..20f63ab958 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_split.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_split.py @@ -3,7 +3,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestSplitConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_squeeze.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_squeeze.py index 69479daf2e..f1cc4fe96d 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_squeeze.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_squeeze.py @@ -2,7 +2,10 @@ import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestSqueeze(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_std.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_std.py index fbd4eaf1ca..bc1d0ece89 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_std.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_std.py @@ -2,7 +2,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils 
import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestMinConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_tanh.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_tanh.py index 6bdaf5f71c..dd39d29d41 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_tanh.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_tanh.py @@ -2,7 +2,10 @@ import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestTanh(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_tile.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_tile.py index bda604ea8e..c370c58eba 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_tile.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_tile.py @@ -3,7 +3,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestTile(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_to_dtype.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_to_dtype.py index 3b0642d720..788a252e6e 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_to_dtype.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_to_dtype.py @@ -1,7 +1,10 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) from torch_tensorrt.dynamo.fx_ts_compat.utils import LowerPrecision diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_topk.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_topk.py index 874edbe2f0..83de8eb894 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_topk.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_topk.py @@ -3,7 +3,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestTopKConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_transpose_convolution.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_transpose_convolution.py 
index 81dd8025b6..1f837c12f7 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_transpose_convolution.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_transpose_convolution.py @@ -4,7 +4,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import param, parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestTransposeConvolutionConverter(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_type_as.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_type_as.py index e7455ec84b..2b6869d0f0 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_type_as.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_type_as.py @@ -3,7 +3,10 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) from torch_tensorrt.dynamo.fx_ts_compat.utils import LowerPrecision diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_unary_ops.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_unary_ops.py index 5d50a4ec9b..2015fc21ef 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_unary_ops.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_unary_ops.py @@ -6,7 +6,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) unary_ops = [ (torch.sin, acc_ops.sin, False), diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_unsqueeze.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_unsqueeze.py index 19c6b2969c..059374194c 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_unsqueeze.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_unsqueeze.py @@ -4,7 +4,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + AccTestCase, + InputTensorSpec, +) class TestUnsqueeze(AccTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_adaptive_avgpool_aten.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_adaptive_avgpool_aten.py index 5b4930af64..b3d8550bb6 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_adaptive_avgpool_aten.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_adaptive_avgpool_aten.py @@ -2,7 +2,10 @@ import 
torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import DispatchTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + DispatchTestCase, + InputTensorSpec, +) class TestAdaptiveAvgPoolConverter(DispatchTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_batchnorm_aten.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_batchnorm_aten.py index db16461abe..2ca9b7ed82 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_batchnorm_aten.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_batchnorm_aten.py @@ -1,6 +1,9 @@ import torch from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import DispatchTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + DispatchTestCase, + InputTensorSpec, +) class TestBatchNormConverter(DispatchTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_binary_ops_aten.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_binary_ops_aten.py index 75e40a2109..a328b8655c 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_binary_ops_aten.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_binary_ops_aten.py @@ -5,7 +5,10 @@ from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import DispatchTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + DispatchTestCase, + InputTensorSpec, +) NEED_TEST_BOTH_CONSTANTS_CASE = True diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_cat_aten.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_cat_aten.py index 6ab72b00a8..50190113ad 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_cat_aten.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_cat_aten.py @@ -2,7 +2,10 @@ import torch.nn as nn from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import DispatchTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + DispatchTestCase, + InputTensorSpec, +) class TestCatConverter(DispatchTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_convolution_aten.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_convolution_aten.py index a1f48534d2..9c4ceaa9bf 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_convolution_aten.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_convolution_aten.py @@ -1,7 +1,10 @@ import torch from parameterized import param, parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import DispatchTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + DispatchTestCase, + InputTensorSpec, +) class TestConvolutionConverter(DispatchTestCase): diff --git 
a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_flatten_aten.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_flatten_aten.py index d828f12d77..ca9e8143ce 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_flatten_aten.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_flatten_aten.py @@ -4,7 +4,10 @@ import torch.nn as nn from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import DispatchTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + DispatchTestCase, + InputTensorSpec, +) class TestFlattenConverter(DispatchTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_linear_aten.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_linear_aten.py index f9f90c0f09..8790cdeecc 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_linear_aten.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_linear_aten.py @@ -1,7 +1,10 @@ import torch from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import DispatchTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + DispatchTestCase, + InputTensorSpec, +) class TestLinearConverter(DispatchTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_maxpool_aten.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_maxpool_aten.py index 6602f23fe7..3ffd59ed19 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_maxpool_aten.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_maxpool_aten.py @@ -4,7 +4,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import param, parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import DispatchTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + DispatchTestCase, + InputTensorSpec, +) class TestMaxPoolConverter(DispatchTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_relu_aten.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_relu_aten.py index 78f3e381b2..3367e237fb 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_relu_aten.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_relu_aten.py @@ -1,7 +1,10 @@ import torch import torch.nn as nn from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import DispatchTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + DispatchTestCase, + InputTensorSpec, +) class TestReLUConverter(DispatchTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_reshape_aten.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_reshape_aten.py index 0356faa382..0382ad7788 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_reshape_aten.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_reshape_aten.py @@ -4,7 
+4,10 @@ import torch from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import DispatchTestCase, InputTensorSpec +from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( + DispatchTestCase, + InputTensorSpec, +) class TestReshapeConverter(DispatchTestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/trt_operator_supported_test.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/trt_operator_supported_test.py index 400afe0457..ebccd3c08b 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/trt_operator_supported_test.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/trt_operator_supported_test.py @@ -5,7 +5,9 @@ import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops # noqa: F401 from torch.testing._internal.common_utils import run_tests, TestCase -from torch_tensorrt.dynamo.fx_ts_compat.tools.trt_splitter import create_trt_operator_support +from torch_tensorrt.dynamo.fx_ts_compat.tools.trt_splitter import ( + create_trt_operator_support, +) from torch_tensorrt.fx.tracer.acc_tracer import acc_ops, acc_tracer diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/trt_splitter_test.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/trt_splitter_test.py index e9d2433684..6421f662fc 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/trt_splitter_test.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/trt_splitter_test.py @@ -10,7 +10,10 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from torch.fx.passes import splitter_base from torch.testing._internal.common_utils import run_tests, TestCase -from torch_tensorrt.dynamo.fx_ts_compat.tools.trt_splitter import TRTSplitter, TRTSplitterSetting +from torch_tensorrt.dynamo.fx_ts_compat.tools.trt_splitter import ( + TRTSplitter, + TRTSplitterSetting, +) from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer ERROR_MSG_NO_ACC_MODULE = "FX split failed: Did not find any ACC submodule!" 
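Taken together, the next patch moves the experimental Dynamo backend under torch_tensorrt.dynamo.torch_compile and exposes it through the top-level API via a new ir="torch_compile" option. A minimal usage sketch of that flow, mirroring the compile specs in the tests added below — this is an illustration under stated assumptions (a CUDA-enabled torch-tensorrt build with torchvision installed), not an excerpt from the patch:

    import torch
    import torch_tensorrt as torchtrt
    import torchvision.models as models

    model = models.resnet18(pretrained=True).eval().to("cuda")
    x = torch.randn((1, 3, 224, 224)).to("cuda")

    # ir="torch_compile" routes compilation through the Dynamo backend
    # rather than the TorchScript or FX frontends.
    trt_model = torchtrt.compile(
        model,
        inputs=[
            torchtrt.Input(x.shape, dtype=torch.float, format=torch.contiguous_format)
        ],
        device=torchtrt.Device("cuda:0"),
        enabled_precisions={torch.float},
        ir="torch_compile",
    )
    print(trt_model(x).shape)  # expect torch.Size([1, 1000]) for resnet18

The accompanying test suite selects the IR on the command line (pytest --ir torch_compile), which is how the new CircleCI test-dynamo-torch_compile command invokes it.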
From 226cc793424bb1435e1dfbce9009ab993974ba2f Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Tue, 11 Apr 2023 17:09:05 -0700 Subject: [PATCH 38/45] fix: Reorganize, add tests --- .circleci/config.yml | 106 +++++++++++++ py/torch_tensorrt/__init__.py | 2 + py/torch_tensorrt/_compile.py | 12 ++ py/torch_tensorrt/dynamo/__init__.py | 127 +-------------- py/torch_tensorrt/dynamo/lowering/__init__.py | 2 - py/torch_tensorrt/dynamo/test/conftest.py | 18 +++ .../dynamo/test/test_dynamo_backend.py | 144 ++++++++++++++++++ py/torch_tensorrt/dynamo/test/utils.py | 54 +++++++ .../dynamo/torch_compile/__init__.py | 126 +++++++++++++++ .../dynamo/{ => torch_compile}/_defaults.py | 0 .../dynamo/{ => torch_compile}/_settings.py | 2 +- .../dynamo/{ => torch_compile}/backends.py | 13 +- .../dynamo/{ => torch_compile}/conversion.py | 0 .../dynamo/torch_compile/lowering/__init__.py | 7 + .../lowering/_decompositions.py | 0 .../lowering/_partition.py | 2 +- .../dynamo/{ => torch_compile}/utils.py | 0 tests/py/api/test_dynamo_backend.py | 136 ----------------- 18 files changed, 481 insertions(+), 270 deletions(-) delete mode 100644 py/torch_tensorrt/dynamo/lowering/__init__.py create mode 100644 py/torch_tensorrt/dynamo/test/conftest.py create mode 100644 py/torch_tensorrt/dynamo/test/test_dynamo_backend.py create mode 100644 py/torch_tensorrt/dynamo/test/utils.py create mode 100644 py/torch_tensorrt/dynamo/torch_compile/__init__.py rename py/torch_tensorrt/dynamo/{ => torch_compile}/_defaults.py (100%) rename py/torch_tensorrt/dynamo/{ => torch_compile}/_settings.py (86%) rename py/torch_tensorrt/dynamo/{ => torch_compile}/backends.py (89%) rename py/torch_tensorrt/dynamo/{ => torch_compile}/conversion.py (100%) create mode 100644 py/torch_tensorrt/dynamo/torch_compile/lowering/__init__.py rename py/torch_tensorrt/dynamo/{ => torch_compile}/lowering/_decompositions.py (100%) rename py/torch_tensorrt/dynamo/{ => torch_compile}/lowering/_partition.py (98%) rename py/torch_tensorrt/dynamo/{ => torch_compile}/utils.py (100%) delete mode 100644 tests/py/api/test_dynamo_backend.py diff --git a/.circleci/config.yml b/.circleci/config.yml index 91e6a71f7e..1bf8ccd41c 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -707,6 +707,23 @@ commands: - store_artifacts: path: /tmp/testlogs + test-dynamo-torch_compile: + description: "Test the Dynamo torch_compile path" + steps: + - run: + name: Run Dynamo torch_compile E2E tests + command: | + cd py/torch_tensorrt/dynamo/ + pushd test/ + pip3 install timm + pip3 install transformers + pytest --junitxml=/tmp/artifacts/test_results/dynamo/test_results.xml --ir torch_compile + popd + - store_test_results: + path: /tmp/artifacts + - store_artifacts: + path: /tmp/testlogs + # Define a job to be invoked later in a workflow. 
# See: https://circleci.com/docs/2.0/configuration-reference/#jobs jobs: @@ -883,6 +900,68 @@ jobs: - dump-test-env - test-fx-no-aten + test-py-dynamo-x86_64-linux: + parameters: + torch-build: + type: string + torch-build-index: + type: string + trt-version-long: + type: string + machine: + image: ubuntu-2004-cuda-11.4:202110-01 + resource_class: gpu.nvidia.large + steps: + - checkout + - attach_workspace: + at: /tmp/dist/ + - install-torch-from-index: + torch-build: << parameters.torch-build >> + torch-build-index: << parameters.torch-build-index >> + - create-py-env: + trt-version-long: << parameters.trt-version-long >> + - install-cudnn + # - run: + # name: "Set LD_LIBRARY_PATH path to include the installed CUDNN" + # command: export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu/:$LD_LIBRARY_PATH + - run: + name: "Install torch-tensorrt" + command: pip3 install --pre /tmp/dist/x86_64-linux/*cp39-cp39*.whl + # We install torch after torch-trt because pip automatically enforces the version constraint otherwise + - dump-test-env + - test-dynamo-torch_compile + + test-py-dynamo-x86_64-linux-no-aten: + parameters: + torch-build: + type: string + torch-build-index: + type: string + trt-version-long: + type: string + machine: + image: ubuntu-2004-cuda-11.4:202110-01 + resource_class: gpu.nvidia.large + steps: + - checkout + - attach_workspace: + at: /tmp/dist/ + - install-torch-from-index: + torch-build: << parameters.torch-build >> + torch-build-index: << parameters.torch-build-index >> + - create-py-env: + trt-version-long: << parameters.trt-version-long >> + - install-cudnn + # - run: + # name: "Set LD_LIBRARY_PATH path to include the installed CUDNN" + # command: export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu/:$LD_LIBRARY_PATH + - run: + name: "Install torch-tensorrt" + command: pip3 install --pre /tmp/dist/x86_64-linux/*cp39-cp39*.whl + # We install torch after torch-trt because pip automatically enforces the version constraint otherwise + - dump-test-env + - test-dynamo-torch_compile + package-x86_64-linux: parameters: enabled: @@ -1261,6 +1340,13 @@ workflows: requires: - build-x86_64-linux + - test-py-dynamo-x86_64-linux: + torch-build: << pipeline.parameters.torch-build >> + torch-build-index: << pipeline.parameters.torch-build-index >> + trt-version-long: << pipeline.parameters.trt-version-long >> + requires: + - build-x86_64-linux + - build-x86_64-linux: name: build-x86_64-linux-legacy torch-build: << pipeline.parameters.torch-build-legacy >> @@ -1291,6 +1377,12 @@ workflows: requires: - build-x86_64-linux-legacy + - test-py-dynamo-x86_64-linux-no-aten: + torch-build: << pipeline.parameters.torch-build-legacy >> + torch-build-index: << pipeline.parameters.torch-build-index-legacy >> + trt-version-long: << pipeline.parameters.trt-version-long >> + requires: + - build-x86_64-linux-legacy release: when: << pipeline.parameters.enable-packaging >> jobs: @@ -1328,6 +1420,13 @@ workflows: requires: - package-x86_64-linux + - test-py-dynamo-x86_64-linux: + torch-build: << pipeline.parameters.torch-build >> + torch-build-index: << pipeline.parameters.torch-build-index >> + trt-version-long: << pipeline.parameters.trt-version-long >> + requires: + - package-x86_64-linux + on-push: jobs: - build-x86_64-linux: @@ -1357,6 +1456,13 @@ workflows: requires: - build-x86_64-linux + - test-py-dynamo-x86_64-linux: + torch-build: << pipeline.parameters.torch-build >> + torch-build-index: << pipeline.parameters.torch-build-index >> + trt-version-long: << pipeline.parameters.trt-version-long >> + requires: 
+ - build-x86_64-linux + - build-x86_64-linux-cmake: torch-build: << pipeline.parameters.torch-build >> torch-build-index: << pipeline.parameters.torch-build-index >> diff --git a/py/torch_tensorrt/__init__.py b/py/torch_tensorrt/__init__.py index 3261265215..015a31f465 100644 --- a/py/torch_tensorrt/__init__.py +++ b/py/torch_tensorrt/__init__.py @@ -93,6 +93,8 @@ def _find_lib(name, paths): from torch_tensorrt._TRTModuleNext import TRTModuleNext from torch_tensorrt import fx +from torch_tensorrt import dynamo +from torch_tensorrt.dynamo import torch_compile def _register_with_torch(): diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index cbd1b87c5c..8075245c6d 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -15,6 +15,8 @@ class _IRType(Enum): ts = 0 fx = 1 + torch_compile = 2 + fx_ts_compat_compile = 3 class _ModuleType(Enum): @@ -45,11 +47,17 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType: ir_targets_torchscript = any([ir == opt for opt in ["torchscript", "ts"]]) ir_targets_fx = ir == "fx" + ir_targets_torch_compile = ir == "torch_compile" + ir_targets_fx_ts_compat_compile = ir == "fx_ts_compat_compile" if module_is_tsable and ir_targets_torchscript: return _IRType.ts elif module_is_fxable and ir_targets_fx: return _IRType.fx + elif module_is_fxable and ir_targets_torch_compile: + return _IRType.torch_compile + elif module_is_fxable and ir_targets_fx_ts_compat_compile: + return _IRType.fx_ts_compat_compile else: if ir == "default": # Options are listed in order of preference @@ -148,6 +156,10 @@ def compile( dynamic_batch=False, **kwargs, ) + elif target_ir == _IRType.torch_compile: + return torch_tensorrt.dynamo.torch_compile( + module, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs + ) else: raise RuntimeError("Module is an unknown format or the ir requested is unknown") diff --git a/py/torch_tensorrt/dynamo/__init__.py b/py/torch_tensorrt/dynamo/__init__.py index 2497b99789..9046119756 100644 --- a/py/torch_tensorrt/dynamo/__init__.py +++ b/py/torch_tensorrt/dynamo/__init__.py @@ -1,126 +1 @@ -import torch -import logging -import collections.abc -import torch_tensorrt -from functools import partial - -from typing import Any -from torch_tensorrt import EngineCapability, Device -from torch_tensorrt.fx.utils import LowerPrecision - -from torch_tensorrt.dynamo._settings import CompilationSettings -from torch_tensorrt.dynamo.utils import prepare_inputs, prepare_device -from torch_tensorrt.dynamo.backends import tensorrt_backend -from torch_tensorrt.dynamo._defaults import ( - PRECISION, - DEBUG, - MAX_WORKSPACE_SIZE, - MAX_NUM_TRT_ENGINES, -) - - -logger = logging.getLogger(__name__) - - -def compile( - gm: torch.nn.Module, - inputs: Any, - *, - device=Device._current_device(), - disable_tf32=False, - sparse_weights=False, - enabled_precisions=set(), - refit=False, - debug=DEBUG, - capability=EngineCapability.default, - num_avg_timing_iters=1, - workspace_size=MAX_WORKSPACE_SIZE, - dla_sram_size=1048576, - dla_local_dram_size=1073741824, - dla_global_dram_size=536870912, - calibrator=None, - truncate_long_and_double=False, - require_full_compilation=False, - min_block_size=3, - torch_executed_ops=[], - torch_executed_modules=[], - **kwargs, -): - - logger.warn( - "The Dynamo backend is an experimental feature, for which only the " - + "following arguments are supported: " - + "{enabled_precisions, debug, workspace_size, max_num_trt_engines}" - ) - - if not isinstance(inputs, 
collections.abc.Sequence): - inputs = [inputs] - - inputs = prepare_inputs(inputs, prepare_device(device)) - - if ( - torch.float16 in enabled_precisions - or torch_tensorrt.dtype.half in enabled_precisions - ): - lower_precision = LowerPrecision.FP16 - elif ( - torch.float32 in enabled_precisions - or torch_tensorrt.dtype.float in enabled_precisions - ): - lower_precision = LowerPrecision.FP32 - elif len(enabled_precisions) == 0: - logger.info(f"No precision specified, defaulting to {PRECISION}") - lower_precision = PRECISION - else: - raise ValueError( - f"Precision {enabled_precisions} not supported in the Dynamo Path" - ) - - custom_backend = create_backend( - precision=lower_precision, - debug=debug, - workspace_size=workspace_size, - **kwargs, - ) - - model = torch.compile(gm, backend=custom_backend) - - # Ensure compilation occurs by calling the function with provided inputs - model(*inputs) - - return model - - -from torch_tensorrt.fx.utils import LowerPrecision - -logger = logging.getLogger(__name__) - - -def create_backend( - precision: LowerPrecision = PRECISION, - debug: bool = DEBUG, - workspace_size: int = MAX_WORKSPACE_SIZE, - max_num_trt_engines: int = MAX_NUM_TRT_ENGINES, - **kwargs, -): - """Create torch.compile backend given specified arguments - - Args: - precision: - debug: Whether to print out verbose debugging information - workspace_size: Maximum workspace TRT is allowed to use for the module - precision: Model Layer precision - Returns: - Backend for torch.compile - """ - settings = CompilationSettings( - debug=debug, - precision=precision, - workspace_size=workspace_size, - max_num_trt_engines=max_num_trt_engines, - ) - - return partial( - tensorrt_backend, - settings=settings, - ) +from .torch_compile import compile as torch_compile diff --git a/py/torch_tensorrt/dynamo/lowering/__init__.py b/py/torch_tensorrt/dynamo/lowering/__init__.py deleted file mode 100644 index 930cd17fb6..0000000000 --- a/py/torch_tensorrt/dynamo/lowering/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions -from torch_tensorrt.dynamo.lowering._partition import partition, get_submod_inputs diff --git a/py/torch_tensorrt/dynamo/test/conftest.py b/py/torch_tensorrt/dynamo/test/conftest.py new file mode 100644 index 0000000000..26299953d6 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/conftest.py @@ -0,0 +1,18 @@ +import pytest + + +def pytest_addoption(parser): + parser.addoption( + "--ir", + metavar="Internal Representation", + nargs=1, + type=str, + required=True, + help="IR to compile with", + choices=["torch_compile", "fx_ts_compat_compile"], + ) + + +@pytest.fixture +def ir(request): + return request.config.getoption("--ir")[0] diff --git a/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py b/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py new file mode 100644 index 0000000000..4852f033bd --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py @@ -0,0 +1,144 @@ +import torch +import timm +import pytest + +import torch_tensorrt as torchtrt +import torchvision.models as models + +from transformers import BertModel + +from utils import COSINE_THRESHOLD, cosine_similarity + + +@pytest.mark.unit +def test_resnet18(ir): + model = models.resnet18(pretrained=True).eval().to("cuda") + input = torch.randn((1, 3, 224, 224)).to("cuda") + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "device": torchtrt.Device("cuda:0"), + 
"enabled_precisions": {torch.float}, + "ir": ir, + } + + trt_mod = torchtrt.compile(model, **compile_spec) + cos_sim = cosine_similarity(model(input), trt_mod(input)) + assert ( + cos_sim > COSINE_THRESHOLD, + f"Resnet50 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + +@pytest.mark.unit +def test_mobilenet_v2(ir): + model = models.mobilenet_v2(pretrained=True).eval().to("cuda") + input = torch.randn((1, 3, 224, 224)).to("cuda") + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "device": torchtrt.Device("cuda:0"), + "enabled_precisions": {torch.float}, + "ir": ir, + } + + trt_mod = torchtrt.compile(model, **compile_spec) + cos_sim = cosine_similarity(model(input), trt_mod(input)) + assert ( + cos_sim > COSINE_THRESHOLD, + f"Mobilenet v2 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + +@pytest.mark.unit +def test_efficientnet_b0(ir): + model = timm.create_model("efficientnet_b0", pretrained=True).eval().to("cuda") + input = torch.randn((1, 3, 224, 224)).to("cuda") + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "device": torchtrt.Device("cuda:0"), + "enabled_precisions": {torch.float}, + "ir": ir, + } + + trt_mod = torchtrt.compile(model, **compile_spec) + cos_sim = cosine_similarity(model(input), trt_mod(input)) + assert ( + cos_sim > COSINE_THRESHOLD, + f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + +@pytest.mark.unit +def test_bert_base_uncased(ir): + model = BertModel.from_pretrained("bert-base-uncased").cuda().eval() + input = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda") + input2 = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda") + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, + dtype=input.dtype, + format=torch.contiguous_format, + ), + torchtrt.Input( + input.shape, + dtype=input.dtype, + format=torch.contiguous_format, + ), + ], + "device": torchtrt.Device("cuda:0"), + "enabled_precisions": {torch.float}, + "truncate_long_and_double": True, + "debug": True, + "ir": ir, + } + trt_mod = torchtrt.compile(model, **compile_spec) + + model_outputs = model(input, input2) + trt_model_outputs = trt_mod(input, input2) + for key in model_outputs.keys(): + out, trt_out = model_outputs[key], trt_model_outputs[key] + cos_sim = cosine_similarity(out, trt_out) + assert ( + cos_sim > COSINE_THRESHOLD, + f"HF BERT base-uncased TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + +@pytest.mark.unit +def test_resnet18_half(ir): + model = models.resnet18(pretrained=True).eval().to("cuda").half() + input = torch.randn((1, 3, 224, 224)).to("cuda").half() + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.half, format=torch.contiguous_format + ) + ], + "device": torchtrt.Device("cuda:0"), + "enabled_precisions": {torch.half}, + "ir": ir, + } + + trt_mod = torchtrt.compile(model, **compile_spec) + cos_sim = cosine_similarity(model(input), trt_mod(input)) + assert ( + cos_sim > COSINE_THRESHOLD, + f"Resnet50 Half TRT outputs don't match with the original model. 
Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) diff --git a/py/torch_tensorrt/dynamo/test/utils.py b/py/torch_tensorrt/dynamo/test/utils.py new file mode 100644 index 0000000000..ff6bc39158 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/utils.py @@ -0,0 +1,54 @@ +import torch + +COSINE_THRESHOLD = 0.99 + + +def cosine_similarity(gt_tensor, pred_tensor): + gt_tensor = gt_tensor.flatten().to(torch.float32) + pred_tensor = pred_tensor.flatten().to(torch.float32) + if torch.sum(gt_tensor) == 0.0 or torch.sum(pred_tensor) == 0.0: + if torch.allclose(gt_tensor, pred_tensor, atol=1e-4, rtol=1e-4, equal_nan=True): + return 1.0 + res = torch.nn.functional.cosine_similarity(gt_tensor, pred_tensor, dim=0, eps=1e-6) + res = res.cpu().detach().item() + + return res + + +def same_output_format(trt_output, torch_output): + # For each encountered collection type, ensure the torch and trt outputs agree + # on type and size, checking recursively through all member elements. + if isinstance(trt_output, tuple): + return ( + isinstance(torch_output, tuple) + and (len(trt_output) == len(torch_output)) + and all( + same_output_format(trt_entry, torch_entry) + for trt_entry, torch_entry in zip(trt_output, torch_output) + ) + ) + elif isinstance(trt_output, list): + return ( + isinstance(torch_output, list) + and (len(trt_output) == len(torch_output)) + and all( + same_output_format(trt_entry, torch_entry) + for trt_entry, torch_entry in zip(trt_output, torch_output) + ) + ) + elif isinstance(trt_output, dict): + return ( + isinstance(torch_output, dict) + and (len(trt_output) == len(torch_output)) + and (trt_output.keys() == torch_output.keys()) + and all( + same_output_format(trt_output[key], torch_output[key]) + for key in trt_output.keys() + ) + ) + elif isinstance(trt_output, set) or isinstance(trt_output, frozenset): + raise AssertionError( + "Unsupported output type 'set' encountered in output format check." 
+ ) + else: + return type(trt_output) is type(torch_output) diff --git a/py/torch_tensorrt/dynamo/torch_compile/__init__.py b/py/torch_tensorrt/dynamo/torch_compile/__init__.py new file mode 100644 index 0000000000..32e5567c51 --- /dev/null +++ b/py/torch_tensorrt/dynamo/torch_compile/__init__.py @@ -0,0 +1,126 @@ +import torch +import logging +import collections.abc +import torch_tensorrt +from functools import partial + +from typing import Any +from torch_tensorrt import EngineCapability, Device +from torch_tensorrt.fx.utils import LowerPrecision + +from torch_tensorrt.dynamo.torch_compile._settings import CompilationSettings +from torch_tensorrt.dynamo.torch_compile.utils import prepare_inputs, prepare_device +from torch_tensorrt.dynamo.torch_compile.backends import tensorrt_backend +from torch_tensorrt.dynamo.torch_compile._defaults import ( + PRECISION, + DEBUG, + MAX_WORKSPACE_SIZE, + MAX_NUM_TRT_ENGINES, +) + + +logger = logging.getLogger(__name__) + + +def compile( + gm: torch.nn.Module, + inputs: Any, + *, + device=Device._current_device(), + disable_tf32=False, + sparse_weights=False, + enabled_precisions=set(), + refit=False, + debug=DEBUG, + capability=EngineCapability.default, + num_avg_timing_iters=1, + workspace_size=MAX_WORKSPACE_SIZE, + dla_sram_size=1048576, + dla_local_dram_size=1073741824, + dla_global_dram_size=536870912, + calibrator=None, + truncate_long_and_double=False, + require_full_compilation=False, + min_block_size=3, + torch_executed_ops=[], + torch_executed_modules=[], + **kwargs, +): + + logger.warn( + "The Dynamo backend is an experimental feature, for which only the " + + "following arguments are supported: " + + "{enabled_precisions, debug, workspace_size, max_num_trt_engines}" + ) + + if not isinstance(inputs, collections.abc.Sequence): + inputs = [inputs] + + inputs = prepare_inputs(inputs, prepare_device(device)) + + if ( + torch.float16 in enabled_precisions + or torch_tensorrt.dtype.half in enabled_precisions + ): + lower_precision = LowerPrecision.FP16 + elif ( + torch.float32 in enabled_precisions + or torch_tensorrt.dtype.float in enabled_precisions + ): + lower_precision = LowerPrecision.FP32 + elif len(enabled_precisions) == 0: + logger.info(f"No precision specified, defaulting to {PRECISION}") + lower_precision = PRECISION + else: + raise ValueError( + f"Precision {enabled_precisions} not supported in the Dynamo Path" + ) + + custom_backend = create_backend( + precision=lower_precision, + debug=debug, + workspace_size=workspace_size, + **kwargs, + ) + + model = torch.compile(gm, backend=custom_backend) + + # Ensure compilation occurs by calling the function with provided inputs + model(*inputs) + + return model + + +from torch_tensorrt.fx.utils import LowerPrecision + +logger = logging.getLogger(__name__) + + +def create_backend( + precision: LowerPrecision = PRECISION, + debug: bool = DEBUG, + workspace_size: int = MAX_WORKSPACE_SIZE, + max_num_trt_engines: int = MAX_NUM_TRT_ENGINES, + **kwargs, +): + """Create torch.compile backend given specified arguments + + Args: + precision: + debug: Whether to print out verbose debugging information + workspace_size: Maximum workspace TRT is allowed to use for the module + precision: Model Layer precision + Returns: + Backend for torch.compile + """ + settings = CompilationSettings( + debug=debug, + precision=precision, + workspace_size=workspace_size, + max_num_trt_engines=max_num_trt_engines, + ) + + return partial( + tensorrt_backend, + settings=settings, + ) diff --git 
a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/torch_compile/_defaults.py similarity index 100% rename from py/torch_tensorrt/dynamo/_defaults.py rename to py/torch_tensorrt/dynamo/torch_compile/_defaults.py diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/torch_compile/_settings.py similarity index 86% rename from py/torch_tensorrt/dynamo/_settings.py rename to py/torch_tensorrt/dynamo/torch_compile/_settings.py index c632943f53..276b8742ff 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/torch_compile/_settings.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from torch_tensorrt.fx.utils import LowerPrecision -from torch_tensorrt.dynamo._defaults import ( +from torch_tensorrt.dynamo.torch_compile._defaults import ( PRECISION, DEBUG, MAX_WORKSPACE_SIZE, diff --git a/py/torch_tensorrt/dynamo/backends.py b/py/torch_tensorrt/dynamo/torch_compile/backends.py similarity index 89% rename from py/torch_tensorrt/dynamo/backends.py rename to py/torch_tensorrt/dynamo/torch_compile/backends.py index ad8a14fd65..9ceab947f0 100644 --- a/py/torch_tensorrt/dynamo/backends.py +++ b/py/torch_tensorrt/dynamo/torch_compile/backends.py @@ -4,10 +4,15 @@ from functools import partial import torch._dynamo as td -from torch_tensorrt.dynamo._settings import CompilationSettings -from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions -from torch_tensorrt.dynamo.lowering._partition import partition, get_submod_inputs -from torch_tensorrt.dynamo.conversion import convert_module +from torch_tensorrt.dynamo.torch_compile._settings import CompilationSettings +from torch_tensorrt.dynamo.torch_compile.lowering._decompositions import ( + get_decompositions, +) +from torch_tensorrt.dynamo.torch_compile.lowering._partition import ( + partition, + get_submod_inputs, +) +from torch_tensorrt.dynamo.torch_compile.conversion import convert_module from torch._dynamo.backends.common import fake_tensor_unsupported diff --git a/py/torch_tensorrt/dynamo/conversion.py b/py/torch_tensorrt/dynamo/torch_compile/conversion.py similarity index 100% rename from py/torch_tensorrt/dynamo/conversion.py rename to py/torch_tensorrt/dynamo/torch_compile/conversion.py diff --git a/py/torch_tensorrt/dynamo/torch_compile/lowering/__init__.py b/py/torch_tensorrt/dynamo/torch_compile/lowering/__init__.py new file mode 100644 index 0000000000..e0a41df755 --- /dev/null +++ b/py/torch_tensorrt/dynamo/torch_compile/lowering/__init__.py @@ -0,0 +1,7 @@ +from torch_tensorrt.dynamo.torch_compile.lowering._decompositions import ( + get_decompositions, +) +from torch_tensorrt.dynamo.torch_compile.lowering._partition import ( + partition, + get_submod_inputs, +) diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/torch_compile/lowering/_decompositions.py similarity index 100% rename from py/torch_tensorrt/dynamo/lowering/_decompositions.py rename to py/torch_tensorrt/dynamo/torch_compile/lowering/_decompositions.py diff --git a/py/torch_tensorrt/dynamo/lowering/_partition.py b/py/torch_tensorrt/dynamo/torch_compile/lowering/_partition.py similarity index 98% rename from py/torch_tensorrt/dynamo/lowering/_partition.py rename to py/torch_tensorrt/dynamo/torch_compile/lowering/_partition.py index cbd4904515..1dd38e0bd9 100644 --- a/py/torch_tensorrt/dynamo/lowering/_partition.py +++ b/py/torch_tensorrt/dynamo/torch_compile/lowering/_partition.py @@ -2,7 +2,7 @@ import torch -from torch_tensorrt.dynamo._defaults 
import MAX_NUM_TRT_ENGINES +from torch_tensorrt.dynamo.torch_compile._defaults import MAX_NUM_TRT_ENGINES from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner from torch.fx.passes.operator_support import OperatorSupport diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/torch_compile/utils.py similarity index 100% rename from py/torch_tensorrt/dynamo/utils.py rename to py/torch_tensorrt/dynamo/torch_compile/utils.py diff --git a/tests/py/api/test_dynamo_backend.py b/tests/py/api/test_dynamo_backend.py deleted file mode 100644 index 77e2f344cd..0000000000 --- a/tests/py/api/test_dynamo_backend.py +++ /dev/null @@ -1,136 +0,0 @@ -import unittest -import torch -import timm - -import torch_tensorrt as torchtrt -import torchvision.models as models - -from transformers import BertModel -from utils import COSINE_THRESHOLD, cosine_similarity - - -class TestModels(unittest.TestCase): - def test_resnet18(self): - self.model = models.resnet18(pretrained=True).eval().to("cuda") - self.input = torch.randn((1, 3, 224, 224)).to("cuda") - - compile_spec = { - "inputs": [ - torchtrt.Input( - self.input.shape, dtype=torch.float, format=torch.contiguous_format - ) - ], - "device": torchtrt.Device("cuda:0"), - "enabled_precisions": {torch.float}, - } - - trt_mod = torchtrt.dynamo.compile(self.model, **compile_spec) - cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input)) - self.assertTrue( - cos_sim > COSINE_THRESHOLD, - msg=f"Resnet50 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", - ) - - def test_mobilenet_v2(self): - self.model = models.mobilenet_v2(pretrained=True).eval().to("cuda") - self.input = torch.randn((1, 3, 224, 224)).to("cuda") - - compile_spec = { - "inputs": [ - torchtrt.Input( - self.input.shape, dtype=torch.float, format=torch.contiguous_format - ) - ], - "device": torchtrt.Device("cuda:0"), - "enabled_precisions": {torch.float}, - } - - trt_mod = torchtrt.dynamo.compile(self.model, **compile_spec) - cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input)) - self.assertTrue( - cos_sim > COSINE_THRESHOLD, - msg=f"Mobilenet v2 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", - ) - - def test_efficientnet_b0(self): - self.model = ( - timm.create_model("efficientnet_b0", pretrained=True).eval().to("cuda") - ) - self.input = torch.randn((1, 3, 224, 224)).to("cuda") - - compile_spec = { - "inputs": [ - torchtrt.Input( - self.input.shape, dtype=torch.float, format=torch.contiguous_format - ) - ], - "device": torchtrt.Device("cuda:0"), - "enabled_precisions": {torch.float}, - } - - trt_mod = torchtrt.dynamo.compile(self.model, **compile_spec) - cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input)) - self.assertTrue( - cos_sim > COSINE_THRESHOLD, - msg=f"EfficientNet-B0 TRT outputs don't match with the original model. 
Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", - ) - - def test_bert_base_uncased(self): - self.model = BertModel.from_pretrained("bert-base-uncased").cuda().eval() - self.input = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda") - self.input2 = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda") - - compile_spec = { - "inputs": [ - torchtrt.Input( - self.input.shape, - dtype=self.input.dtype, - format=torch.contiguous_format, - ), - torchtrt.Input( - self.input.shape, - dtype=self.input.dtype, - format=torch.contiguous_format, - ), - ], - "device": torchtrt.Device("cuda:0"), - "enabled_precisions": {torch.float}, - "truncate_long_and_double": True, - "debug": True, - } - trt_mod = torchtrt.dynamo.compile(self.model, **compile_spec) - - model_outputs = self.model(self.input, self.input2) - trt_model_outputs = trt_mod(self.input, self.input2) - for key in model_outputs.keys(): - out, trt_out = model_outputs[key], trt_model_outputs[key] - cos_sim = cosine_similarity(out, trt_out) - self.assertTrue( - cos_sim > COSINE_THRESHOLD, - msg=f"HF BERT base-uncased TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", - ) - - def test_resnet18_half(self): - self.model = models.resnet18(pretrained=True).eval().to("cuda").half() - self.input = torch.randn((1, 3, 224, 224)).to("cuda").half() - - compile_spec = { - "inputs": [ - torchtrt.Input( - self.input.shape, dtype=torch.half, format=torch.contiguous_format - ) - ], - "device": torchtrt.Device("cuda:0"), - "enabled_precisions": {torch.half}, - } - - trt_mod = torchtrt.dynamo.compile(self.model, **compile_spec) - cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input)) - self.assertTrue( - cos_sim > COSINE_THRESHOLD, - msg=f"Resnet50 Half TRT outputs don't match with the original model. 
Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", - ) - - -if __name__ == "__main__": - unittest.main() From dc7dd04d8efe96a06941355c9efa6a0cf73bff43 Mon Sep 17 00:00:00 2001 From: George S <113141689+gs-olive@users.noreply.github.com> Date: Wed, 12 Apr 2023 15:23:34 -0700 Subject: [PATCH 39/45] Update conftest.py --- py/torch_tensorrt/dynamo/test/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/test/conftest.py b/py/torch_tensorrt/dynamo/test/conftest.py index 26299953d6..98be643435 100644 --- a/py/torch_tensorrt/dynamo/test/conftest.py +++ b/py/torch_tensorrt/dynamo/test/conftest.py @@ -9,7 +9,7 @@ def pytest_addoption(parser): type=str, required=True, help="IR to compile with", - choices=["torch_compile", "fx_ts_compat_compile"], + choices=["torch_compile", "fx_ts_compat"], ) From 0d68a47b21742561b762ad4264ae170d83e13dc6 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Fri, 14 Apr 2023 10:16:59 -0700 Subject: [PATCH 40/45] chore: Confine the testing to input and core level Signed-off-by: Dheeraj Peri --- .../acc_op/test_adaptive_avgpool.py | 102 - .../test/converters/acc_op/test_any.py | 91 - .../test/converters/acc_op/test_as_strided.py | 59 - .../test/converters/acc_op/test_avgpool.py | 268 -- .../test/converters/acc_op/test_batchnorm.py | 69 - .../test/converters/acc_op/test_binary_ops.py | 215 -- .../test/converters/acc_op/test_cat.py | 104 - .../test/converters/acc_op/test_chunk.py | 82 - .../test/converters/acc_op/test_clamp.py | 74 - .../converters/acc_op/test_convolution.py | 202 -- .../test/converters/acc_op/test_dequantize.py | 71 - .../test/converters/acc_op/test_einsum.py | 69 - .../test/converters/acc_op/test_elu.py | 55 - .../test/converters/acc_op/test_embedding.py | 110 - .../test/converters/acc_op/test_eq.py | 292 -- .../test/converters/acc_op/test_expand.py | 34 - .../test/converters/acc_op/test_flatten.py | 70 - .../test/converters/acc_op/test_gelu.py | 98 - .../test/converters/acc_op/test_getitem.py | 201 -- .../test/converters/acc_op/test_gt.py | 279 -- .../converters/acc_op/test_hard_sigmoid.py | 62 - .../test/converters/acc_op/test_hardtanh.py | 60 - .../converters/acc_op/test_interpolate.py | 153 - .../test/converters/acc_op/test_isinf.py | 66 - .../test/converters/acc_op/test_leaky_relu.py | 55 - .../test/converters/acc_op/test_linear.py | 63 - .../converters/acc_op/test_logical_and.py | 233 -- .../test/converters/acc_op/test_logical_or.py | 204 -- .../converters/acc_op/test_logical_xor.py | 204 -- .../test/converters/acc_op/test_lt.py | 277 -- .../converters/acc_op/test_masked_fill.py | 72 - .../test/converters/acc_op/test_matmul.py | 120 - .../test/converters/acc_op/test_max.py | 163 - .../test/converters/acc_op/test_maximum.py | 85 - .../test/converters/acc_op/test_maxpool.py | 382 --- .../test/converters/acc_op/test_min.py | 162 - .../test/converters/acc_op/test_minimum.py | 85 - .../test/converters/acc_op/test_narrow.py | 58 - .../test/converters/acc_op/test_ne.py | 307 -- .../test/converters/acc_op/test_new_ones.py | 76 - .../test/converters/acc_op/test_numel.py | 41 - .../test/converters/acc_op/test_pad.py | 102 - .../test/converters/acc_op/test_permute.py | 90 - .../test/converters/acc_op/test_prod.py | 121 - .../acc_op/test_quantize_per_tensor.py | 68 - .../test/converters/acc_op/test_reduce_ops.py | 111 - .../test/converters/acc_op/test_relu.py | 55 - .../acc_op/test_repeat_interleave.py | 79 - .../test/converters/acc_op/test_reshape.py | 141 - .../test/converters/acc_op/test_selu.py | 55 
- .../test/converters/acc_op/test_sigmoid.py | 38 - .../test/converters/acc_op/test_silu.py | 52 - .../test/converters/acc_op/test_size.py | 74 - .../test/converters/acc_op/test_softmax.py | 68 - .../test/converters/acc_op/test_softsign.py | 55 - .../test/converters/acc_op/test_split.py | 110 - .../test/converters/acc_op/test_squeeze.py | 44 - .../test/converters/acc_op/test_std.py | 120 - .../test/converters/acc_op/test_tanh.py | 55 - .../test/converters/acc_op/test_tile.py | 148 - .../test/converters/acc_op/test_to_dtype.py | 322 -- .../test/converters/acc_op/test_topk.py | 87 - .../acc_op/test_transpose_convolution.py | 140 - .../test/converters/acc_op/test_type_as.py | 153 - .../test/converters/acc_op/test_unary_ops.py | 168 - .../test/converters/acc_op/test_unsqueeze.py | 63 - .../test/converters/acc_op/test_where.py | 114 - .../aten_op/test_adaptive_avgpool_aten.py | 130 - .../converters/aten_op/test_batchnorm_aten.py | 68 - .../aten_op/test_binary_ops_aten.py | 208 -- .../test/converters/aten_op/test_cat_aten.py | 61 - .../aten_op/test_convolution_aten.py | 206 -- .../converters/aten_op/test_expand_aten.py | 31 - .../converters/aten_op/test_flatten_aten.py | 73 - .../converters/aten_op/test_linear_aten.py | 74 - .../converters/aten_op/test_maxpool_aten.py | 248 -- .../test/converters/aten_op/test_relu_aten.py | 54 - .../converters/aten_op/test_reshape_aten.py | 105 - .../converters/vanilla/test_add_vanilla.py | 28 - .../vanilla/test_convolution_vanilla.py | 113 - ...test_fix_clamp_numerical_limits_to_fp16.py | 74 - .../test/passes/test_fix_reshape_batch_dim.py | 51 - .../passes/test_fuse_permute_linear_trt.py | 88 - .../passes/test_fuse_permute_matmul_trt.py | 142 - .../test/passes/test_graph_opts.py | 187 -- .../test/passes/test_multi_fuse_trt.py | 66 - .../test_remove_duplicate_output_args.py | 73 - .../test/passes/test_setitem_trt.py | 600 ---- .../fx_ts_compat/test/quant/test_quant_trt.py | 908 ------ .../test/tools/test_model_packager.py | 56 - .../test/tracer/test_acc_shape_prop.py | 98 - .../test/tracer/test_acc_tracer.py | 2801 ----------------- .../test/tracer/test_dispatch_tracer.py | 245 -- .../fx_ts_compat/test/tracer/test_resnet.py | 86 - .../test/trt_lower/test_diagnostics.py | 200 -- .../test/trt_lower/test_fx2trt_lower.py | 104 - .../test/trt_lower/test_observer.py | 128 - .../test/trt_lower/test_observer_gpu.py | 51 - .../trt_lower/trt_operator_supported_test.py | 82 - .../test/trt_lower/trt_splitter_test.py | 1179 ------- 100 files changed, 16724 deletions(-) delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_adaptive_avgpool.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_any.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_as_strided.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_avgpool.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_batchnorm.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_binary_ops.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_cat.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_chunk.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_clamp.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_convolution.py delete mode 100644 
py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_dequantize.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_einsum.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_elu.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_embedding.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_eq.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_expand.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_flatten.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_gelu.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_getitem.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_gt.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_hard_sigmoid.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_hardtanh.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_interpolate.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_isinf.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_leaky_relu.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_linear.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_logical_and.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_logical_or.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_logical_xor.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_lt.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_masked_fill.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_matmul.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_max.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_maximum.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_maxpool.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_min.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_minimum.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_narrow.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_ne.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_new_ones.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_numel.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_pad.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_permute.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_prod.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_quantize_per_tensor.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_reduce_ops.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_relu.py delete mode 100644 
py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_repeat_interleave.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_reshape.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_selu.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_sigmoid.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_silu.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_size.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_softmax.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_softsign.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_split.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_squeeze.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_std.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_tanh.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_tile.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_to_dtype.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_topk.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_transpose_convolution.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_type_as.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_unary_ops.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_unsqueeze.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_where.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_adaptive_avgpool_aten.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_batchnorm_aten.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_binary_ops_aten.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_cat_aten.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_convolution_aten.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_expand_aten.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_flatten_aten.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_linear_aten.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_maxpool_aten.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_relu_aten.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_reshape_aten.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/vanilla/test_add_vanilla.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/vanilla/test_convolution_vanilla.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_fix_clamp_numerical_limits_to_fp16.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_fix_reshape_batch_dim.py delete mode 100644 
py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_fuse_permute_linear_trt.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_fuse_permute_matmul_trt.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_graph_opts.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_multi_fuse_trt.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_remove_duplicate_output_args.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_setitem_trt.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/quant/test_quant_trt.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/tools/test_model_packager.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/tracer/test_acc_shape_prop.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/tracer/test_acc_tracer.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/tracer/test_dispatch_tracer.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/tracer/test_resnet.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/test_diagnostics.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/test_fx2trt_lower.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/test_observer.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/test_observer_gpu.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/trt_operator_supported_test.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/trt_splitter_test.py diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_adaptive_avgpool.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_adaptive_avgpool.py deleted file mode 100644 index 37f8dcade8..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_adaptive_avgpool.py +++ /dev/null @@ -1,102 +0,0 @@ -import torch -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestAdaptiveAvgPoolConverter(AccTestCase): - @parameterized.expand( - [ - ((64, 64),), - ((128, 64),), - (64,), - ] - ) - def test_adaptive_avgpool( - self, - output_size, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.AdaptiveAvgPool2d(output_size) - - def forward(self, x): - return self.pool(x) - - inputs = [torch.randn(1, 3, 256, 256)] - self.run_test(TestModule(), inputs, expected_ops={acc_ops.adaptive_avg_pool2d}) - - def test_adaptive_avgpool_with_dynamic_shape(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.AdaptiveAvgPool2d((64, 64)) - - def forward(self, x): - return self.pool(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, 256, 256), - dtype=torch.float32, - shape_ranges=[((1, 1, 256, 256), (3, 3, 256, 256), (5, 5, 256, 256))], - ), - ] - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.adaptive_avg_pool2d} - ) - - @parameterized.expand( - [ - ((16, 16, 16),), - ((32, 16, 4),), - (32,), - ] - ) - def test_adaptive_avgpool3d( - self, - output_size, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.pool 
= torch.nn.AdaptiveAvgPool3d(output_size) - - def forward(self, x): - return self.pool(x) - - inputs = [torch.randn(1, 3, 32, 64, 64)] - self.run_test(TestModule(), inputs, expected_ops={acc_ops.adaptive_avg_pool3d}) - - def test_adaptive_avgpool3d_with_dynamic_shape(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.AdaptiveAvgPool3d((16, 16, 16)) - - def forward(self, x): - return self.pool(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, 32, 64, 64), - dtype=torch.float32, - shape_ranges=[ - ((1, 1, 32, 64, 64), (3, 3, 32, 64, 64), (5, 5, 32, 64, 64)) - ], - ), - ] - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.adaptive_avg_pool3d} - ) - - # Testing with shape (-1, -1, -1, -1) results in the error: "AdaptiveAvgPool2d and AdaptiveAvgPool3d currently doesn't support dynamic shapes for last two dims." - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_any.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_any.py deleted file mode 100644 index 7b50fd4515..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_any.py +++ /dev/null @@ -1,91 +0,0 @@ -import torch -import torch.nn as nn -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase - -# from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import InputTensorSpec - - -class TestAnyConverters(AccTestCase): - @parameterized.expand( - [ - ("bool", torch.bool), - ("int", torch.int), - ("float", torch.float), - ] - ) - def test_ops(self, _, input_dtype): - class TestModule(nn.Module): - def forward(self, x): - return torch.any(x) - - inputs = [torch.randn(2, 3).to(input_dtype)] - self.run_test( - TestModule(), - inputs, - expected_ops={acc_ops.any}, - test_implicit_batch_dim=False, - ) - - @parameterized.expand( - [ - ("bool", torch.bool, 0), - ("int", torch.int, 1), - ("float", torch.float, 0), - ] - ) - def test_ops_dim(self, _, input_dtype, dim): - class TestModule(nn.Module): - def forward(self, x): - return torch.any(x, dim, keepdim=True) - - inputs = [torch.randn(2, 3).to(input_dtype)] - self.run_test( - TestModule(), inputs, expected_ops={}, test_implicit_batch_dim=False - ) - - @parameterized.expand( - [ - ("bool", torch.bool), - ("int", torch.int), - ("float", torch.float), - ] - ) - def test_ops_method(self, _, input_dtype): - class TestModule(nn.Module): - def forward(self, x): - return x.any() - - inputs = [torch.randn(2, 3).to(input_dtype)] - self.run_test( - TestModule(), - inputs, - expected_ops={acc_ops.any}, - test_implicit_batch_dim=False, - ) - - # Testing with shape (-1, -1, -1, -1) results in the error: torch.zeros(tuple([*input_t.shape])).
Trying to create tensor with negative dimension -1: [-1, -1, -1, -1] - """ - def test_ops_with_dynamic_shape_four_dimensions(self): - class TestModule(nn.Module): - def forward(self, x): - return torch.any(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 256, 256), (3, 3, 256, 256), (5, 5, 256, 256))], - ), - ] - - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.any} - ) - """ - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_as_strided.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_as_strided.py deleted file mode 100644 index 3aff0638d6..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_as_strided.py +++ /dev/null @@ -1,59 +0,0 @@ -import torch -import torch.nn as nn -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase - -# from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase, InputTensorSpec - - -class TestConverter(AccTestCase): - @parameterized.expand( - [ - ("2d_dim_v1", (5, 5), (2, 3), (1, 2), 0), - ("2d_dim_v2", (5, 5), (2, 3), (2, 2), 1), - ("3d_dim_v1", (20, 20), (2, 3, 2), (2, 2, 2), 0), - # take long time on large dimensions, we do not have better implementation yet - # ("4d_dim_v1", (200, 200, 200, 200), (9, 9, 3, 2), (2, 2, 2, 3), 0), - # ("4d_dim_v2", (200, 200, 200, 200), (1, 15, 512, 1), (4096, 256, 1, 1), 0), - ] - ) - def test_as_strided(self, _, x_size, size, stride, offset): - class Stride(nn.Module): - def forward(self, x): - return torch.as_strided(x, size, stride, offset) - - inputs = [torch.randn(*x_size)] - self.run_test( - Stride(), - inputs, - expected_ops={acc_ops.as_strided}, - test_implicit_batch_dim=False, - ) - - # Testing with shape (-1, 3) results into error: - # RuntimeError: setStorage: sizes [2, 3], strides [1, 2], storage offset 0, and itemsize 8 requiring a storage size of 48 are out of bounds for storage of size 16 - - """ - def test_as_strided_with_dynamic_shape_four_dimensions(self): - class Stride(nn.Module): - def forward(self, x): - return torch.as_strided(torch.tensor([5, 5]), (2, 3), (1, 2), 0) - - input_specs = [ - InputTensorSpec( - shape=(-1, 3), - dtype=torch.float32, - shape_ranges=[((1, 3), (2, 3), (2, 3))], - ), - ] - - self.run_test_with_dynamic_shape( - Stride(), input_specs, expected_ops={acc_ops.as_strided} - ) - """ - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_avgpool.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_avgpool.py deleted file mode 100644 index f9cb1cb9cd..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_avgpool.py +++ /dev/null @@ -1,268 +0,0 @@ -import torch -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import param, parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestAvgPoolConverter(AccTestCase): - @parameterized.expand( - [ - ("default", 1), - ("kernal_size", 3), - ("stride", 1, 2), - ("tuple_parameters", 2, (1,), (1,)), - param("padding", 2, padding=1), - 
param("ceil_mode", 1, ceil_mode=True), - param("include_pad", 2, padding=1, count_include_pad=False), - ] - ) - def test_avg_pool1d( - self, - test_name, - kernel_size, - stride=1, - padding=0, - ceil_mode=False, - count_include_pad=True, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.avg_pool = torch.nn.AvgPool1d( - kernel_size, stride, padding, ceil_mode, count_include_pad - ) - - def forward(self, x): - return self.avg_pool(x) - - inputs = [torch.randn(1, 3, 224)] - self.run_test(TestModule(), inputs, expected_ops={acc_ops.avg_pool1d}) - - @parameterized.expand( - [ - ("default", 1), - ("kernal_size", 3), - ("stride", 1, 2), - ("tuple_parameters", 2, (1,), (1,)), - param("padding", 2, padding=1), - param("ceil_mode", 1, ceil_mode=True), - param("include_pad", 2, padding=1, count_include_pad=False), - ] - ) - def test_avg_pool1d_with_dynamic_shape( - self, - test_name="default", - kernel_size=1, - stride=1, - padding=0, - ceil_mode=False, - count_include_pad=True, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.avg_pool = torch.nn.AvgPool1d( - kernel_size, stride, padding, ceil_mode, count_include_pad - ) - - def forward(self, x): - return self.avg_pool(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, 3, 3), - dtype=torch.float32, - shape_ranges=[((1, 3, 3), (3, 3, 3), (3, 3, 3))], - ), - ] - - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.avg_pool1d} - ) - - def test_avg_pool2d_with_dynamic_shape_four_dimensions( - self, - test_name="default", - kernel_size=1, - stride=1, - padding=0, - ceil_mode=False, - count_include_pad=True, - divisor_override=None, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.avg_pool = torch.nn.AvgPool2d( - kernel_size, - stride, - padding, - ceil_mode, - count_include_pad, - divisor_override, - ) - - def forward(self, x): - return self.avg_pool(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))], - ), - ] - - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.avg_pool2d} - ) - - @parameterized.expand( - [ - ("default", 1), - ("kernal_size", 3), - ("stride", 1, 2), - ("tuple_parameters", 2, (1, 1), (1, 1)), - param("padding", 2, padding=1), - param("ceil_mode", 1, ceil_mode=True), - param("include_pad", 2, padding=1, count_include_pad=False), - ] - ) - def test_avg_pool2d( - self, - test_name, - kernel_size, - stride=1, - padding=0, - ceil_mode=False, - count_include_pad=True, - divisor_override=None, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.avg_pool = torch.nn.AvgPool2d( - kernel_size, - stride, - padding, - ceil_mode, - count_include_pad, - divisor_override, - ) - - def forward(self, x): - return self.avg_pool(x) - - inputs = [torch.randn(1, 3, 224, 224)] - self.run_test(TestModule(), inputs, expected_ops={acc_ops.avg_pool2d}) - - @parameterized.expand( - [ - ("kernal_size", 1), - param("stride", 2, stride=()), - ] - ) - def test_stride_none_avg_pool1d( - self, - test_name, - kernel_size, - stride=None, - padding=0, - ceil_mode=False, - count_include_pad=True, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return torch.nn.functional.avg_pool1d( - x, - kernel_size, - stride=stride, - padding=padding, - ceil_mode=ceil_mode, - 
count_include_pad=count_include_pad, - ) - - inputs = [torch.randn(1, 3, 224)] - self.run_test(TestModule(), inputs, expected_ops={acc_ops.avg_pool1d}) - - @parameterized.expand( - [ - ("kernal_size", 2), - param("stride", 2, stride=()), - ] - ) - def test_stride_none_avg_pool2d( - self, - test_name, - kernel_size, - stride=None, - padding=0, - ceil_mode=False, - count_include_pad=True, - divisor_override=None, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return torch.nn.functional.avg_pool2d( - x, - kernel_size, - stride=stride, - padding=padding, - ceil_mode=ceil_mode, - count_include_pad=count_include_pad, - divisor_override=divisor_override, - ) - - inputs = [torch.randn(1, 3, 224, 224)] - self.run_test(TestModule(), inputs, expected_ops={acc_ops.avg_pool2d}) - - def test_stride_none_avg_pool2d_with_dynamic_shape_four_dimensions( - self, - test_name="default", - kernel_size=1, - stride=None, - padding=0, - ceil_mode=False, - count_include_pad=True, - divisor_override=None, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return torch.nn.functional.avg_pool2d( - x, - kernel_size, - stride=stride, - padding=padding, - ceil_mode=ceil_mode, - count_include_pad=count_include_pad, - divisor_override=divisor_override, - ) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))], - ), - ] - - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.avg_pool2d} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_batchnorm.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_batchnorm.py deleted file mode 100644 index d52bcd8905..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_batchnorm.py +++ /dev/null @@ -1,69 +0,0 @@ -import torch -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestBatchNormConverter(AccTestCase): - def test_batchnorm(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.bn = torch.nn.BatchNorm2d(3) - - def forward(self, x): - return self.bn(x) - - inputs = [torch.randn(1, 3, 224, 224)] - self.run_test(TestModule(), inputs, expected_ops={acc_ops.batch_norm}) - - def test_batchnorm1d_with_dynamic_shape(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.bn = torch.nn.BatchNorm1d(3) - - def forward(self, x): - return self.bn(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, 3, 5), - dtype=torch.float32, - shape_ranges=[((2, 3, 5), (6, 3, 5), (10, 3, 5))], - ), - ] - - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.batch_norm} - ) - - def test_batchnorm_with_dynamic_shape(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.bn = torch.nn.BatchNorm2d(3) - - def forward(self, x): - return self.bn(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, 3, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 3, 1, 1), (1, 3, 5, 5), (2, 3, 10, 10))], - ), - ] - - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.batch_norm} 
- ) - - # Testing with shape=(-1, -1, -1, -1) results in AssertionError: Channel dim can't be dynamic for batch norm. - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_binary_ops.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_binary_ops.py deleted file mode 100644 index ae006e03a9..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_binary_ops.py +++ /dev/null @@ -1,215 +0,0 @@ -from typing import Callable - -import torch -import torch.nn as nn - -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - -NEED_TEST_BOTH_CONSTANTS_CASE = True - -elementwise_ops = [ - ((lambda x, y: x + y), acc_ops.add, NEED_TEST_BOTH_CONSTANTS_CASE), - ((lambda x, y: x - y), acc_ops.sub, NEED_TEST_BOTH_CONSTANTS_CASE), - ((lambda x, y: torch.sub(x, y)), acc_ops.sub, False), - ((lambda x, y: x.sub(y)), acc_ops.sub, False), - ((lambda x, y: x / y), acc_ops.div, NEED_TEST_BOTH_CONSTANTS_CASE), - ((lambda x, y: x // y), acc_ops.floor_div, NEED_TEST_BOTH_CONSTANTS_CASE), - ( - (lambda x, y: torch.div(x, y, rounding_mode="trunc")), - acc_ops.trunc_div, - not NEED_TEST_BOTH_CONSTANTS_CASE, - ), - ( - (lambda x, y: torch.div(x, y, rounding_mode="floor")), - acc_ops.floor_div, - NEED_TEST_BOTH_CONSTANTS_CASE, - ), - ((lambda x, y: torch.div(x, y)), acc_ops.div, NEED_TEST_BOTH_CONSTANTS_CASE), - ((lambda x, y: torch.fmod(x, y)), acc_ops.fmod, not NEED_TEST_BOTH_CONSTANTS_CASE), - # torch.floor_divide rounds result toward zero, rather than -Inf. - # https://github.com/pytorch/pytorch/issues/43874 - ( - (lambda x, y: torch.floor_divide(x, y)), - acc_ops.trunc_div, - not NEED_TEST_BOTH_CONSTANTS_CASE, - ), - ((lambda x, y: x * y), acc_ops.mul, NEED_TEST_BOTH_CONSTANTS_CASE), - (torch.pow, acc_ops.pow, not NEED_TEST_BOTH_CONSTANTS_CASE), -] - - -class TestBinaryOpConverters(AccTestCase): - @parameterized.expand([(op[1].__name__, op[0], op[1]) for op in elementwise_ops]) - def test_elementwise_ops(self, name, orig_op: Callable, expected_op): - class TestModule(nn.Module): - def __init__(self, orig_op): - super().__init__() - self.orig_op = orig_op - - def forward(self, x): - return self.orig_op(x, x) - - m = TestModule(orig_op) - # Avoid dividing by 0. 
- inputs = [torch.rand(1, 1) + 1] - self.run_test(m, inputs, expected_ops={expected_op}) - - @parameterized.expand([(op[1].__name__, op[0], op[1]) for op in elementwise_ops]) - def test_elementwise_ops_with_one_constant( - self, name, orig_op: Callable, expected_op - ): - class TestModule(nn.Module): - def __init__(self, orig_op): - super().__init__() - self.constant = torch.randn(1) - self.orig_op = orig_op - - def forward(self, x): - x = self.orig_op(x, self.constant) - return self.orig_op(x, -2) - - m = TestModule(orig_op) - inputs = [torch.randn(2, 2)] - self.run_test(m, inputs, expected_ops={expected_op}) - - @parameterized.expand( - [(op[1].__name__, op[0], op[1]) for op in elementwise_ops if op[2]] - ) - def test_elementwise_op_with_both_constants( - self, name, orig_op: Callable, expected_op - ): - class TestModule(nn.Module): - def __init__(self, orig_op): - super().__init__() - self.constant0 = torch.nn.Parameter(torch.randn(1)) - self.constant1 = torch.nn.Parameter(torch.randn(1)) - self.orig_op = orig_op - - def forward(self, x): - const = self.orig_op(self.constant0, self.constant1) - return self.orig_op(x, const) - - m = TestModule(orig_op) - inputs = [torch.randn(2, 2)] - self.run_test(m, inputs, expected_ops={expected_op}) - - @parameterized.expand( - [ - ( - f"no_broadcast_{op[1].__name__}", - (-1, -1), - ((1, 1), (2, 2), (3, 3)), - (-1, -1), - ((1, 1), (2, 2), (3, 3)), - op[0], - op[1], - ) - for op in elementwise_ops - ] - + [ - ( - f"broadcast_{op[1].__name__}", - (-1, -1, -1), - ((1, 1, 1), (2, 2, 2), (3, 3, 3)), - (-1, -1), - ((1, 1), (2, 2), (3, 3)), - op[0], - op[1], - ) - for op in elementwise_ops - ] - ) - def test_elementwise_op_with_dynamic_shape( - self, _, x_shape, x_shape_ranges, y_shape, y_shape_ranges, orig_op, expected_op - ): - class Op(nn.Module): - def forward(self, x, y): - return orig_op(x, y) - - input_specs = [ - InputTensorSpec( - shape=x_shape, - dtype=torch.float32, - shape_ranges=[x_shape_ranges], - ), - InputTensorSpec( - shape=y_shape, - dtype=torch.float32, - shape_ranges=[y_shape_ranges], - ), - ] - - self.run_test_with_dynamic_shape(Op(), input_specs, expected_ops={expected_op}) - - @parameterized.expand( - [ - ( - f"no_broadcast_{op[1].__name__}", - op[0], - op[1], - ) - for op in elementwise_ops - ] - + [ - ( - f"broadcast_{op[1].__name__}", - op[0], - op[1], - ) - for op in elementwise_ops - ] - ) - def test_elementwise_op_with_dynamic_shape_four_dimensions( - self, _, orig_op, expected_op - ): - class Op(nn.Module): - def forward(self, x, y): - return orig_op(x, y) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))], - ), - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))], - ), - ] - - self.run_test_with_dynamic_shape(Op(), input_specs, expected_ops={expected_op}) - - def test_elementwise_ops_with_scalar_lhs(self): - def orig_op(x, y): - return x + y - - class TestModule(nn.Module): - def __init__(self, orig_op): - super().__init__() - self.constant = torch.randn(1) - self.orig_op = orig_op - - def forward(self, x): - return self.orig_op(x, self.constant) - - m = TestModule(orig_op) - inputs = [torch.randn(10)] - self.run_test( - m, - inputs, - expected_ops={acc_ops.add}, - test_explicit_batch_dim=False, - test_implicit_batch_dim=True, - ) - - -if __name__ == "__main__": - run_tests() diff --git 
a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_cat.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_cat.py deleted file mode 100644 index 807ab8842e..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_cat.py +++ /dev/null @@ -1,104 +0,0 @@ -import torch -import torch.nn as nn -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import param, parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestCatConverter(AccTestCase): - @parameterized.expand( - [ - param("cat", torch.cat), - param("concat", torch.concat), - ] - ) - def test_cat(self, _, op): - class Cat(nn.Module): - def forward(self, x, y, z): - return op((x, y, z), 1) - - inputs = [torch.randn(1, 2, 3), torch.randn(1, 1, 3), torch.randn(1, 3, 3)] - self.run_test(Cat(), inputs, expected_ops={acc_ops.cat}) - - @parameterized.expand( - [ - param("cat", torch.cat), - param("concat", torch.concat), - ] - ) - def test_cat_neg(self, _, op): - class Cat(nn.Module): - def forward(self, x, y, z): - return op((x, y, z), -1) - - inputs = [torch.randn(1, 2, 3), torch.randn(1, 2, 3), torch.randn(1, 2, 2)] - self.run_test(Cat(), inputs, expected_ops={acc_ops.cat}) - - @parameterized.expand( - [ - param("cat", torch.cat), - param("concat", torch.concat), - ] - ) - def test_cat_with_dynamic_shape(self, _, op): - class Cat(nn.Module): - def forward(self, x, y): - x = x + y - return op((x, y), 0) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1), (2, 3, 4), (2, 3, 10))], - ), - InputTensorSpec( - shape=(-1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1), (2, 3, 4), (2, 3, 10))], - ), - ] - self.run_test_with_dynamic_shape(Cat(), input_specs, expected_ops={acc_ops.cat}) - - @parameterized.expand( - [ - param("cat", torch.cat), - param("concat", torch.concat), - ] - ) - def test_cat_with_dynamic_shape_four_dimensions(self, _, op): - class Cat(nn.Module): - def forward(self, x, y): - x = x + y - return op((x, y), 0) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 4), (2, 3, 10, 10))], - ), - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 4), (2, 3, 10, 10))], - ), - ] - - self.run_test_with_dynamic_shape(Cat(), input_specs, expected_ops={acc_ops.cat}) - - def test_concat(self): - class Cat(nn.Module): - def forward(self, x, y, z): - return torch.concat((x, y, z), 1) - - inputs = [torch.randn(1, 2, 3), torch.randn(1, 1, 3), torch.randn(1, 3, 3)] - self.run_test(Cat(), inputs, expected_ops={acc_ops.cat}) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_chunk.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_chunk.py deleted file mode 100644 index 42706d8e1f..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_chunk.py +++ /dev/null @@ -1,82 +0,0 @@ -import torch -import torch.nn as nn -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class 
TestChunkConverter(AccTestCase): - @parameterized.expand( - [ - ("chunk", 3, 1), - ("chunk", 2000, 2), - ("chunk", 3, -2), - ] - ) - def test_chunk(self, _, chunk, dim): - class Chunk(nn.Module): - def forward(self, x): - return x.chunk(chunk, dim)[0] - - inputs = [torch.randn(3, 10, 20)] - self.run_test( - Chunk(), - inputs, - expected_ops={acc_ops.chunk}, - ) - - @parameterized.expand( - [ - ("chunk", 3, 1), - ("chunk", 2000, 1), - ("chunk", 3, -2), - ] - ) - def test_chunk_with_dynamic_shape(self, _, chunk, dim): - class Chunk(nn.Module): - def forward(self, x): - return x.chunk(chunk, dim)[0] - - input_specs = [ - InputTensorSpec( - shape=(-1, 10, -1), - dtype=torch.float32, - shape_ranges=[((1, 10, 20), (5, 10, 20), (10, 10, 20))], - ), - ] - self.run_test_with_dynamic_shape( - Chunk(), input_specs, expected_ops={acc_ops.chunk} - ) - - # Testing with (-1, -1, -1, -1) results in Error: AssertionError: Can't chunk on dynamic shape dimension! - @parameterized.expand( - [ - ("chunk", 3, 1), - ("chunk", 2000, 1), - ("chunk", 3, -2), - ] - ) - def test_chunk_with_dynamic_shape_four_dimensions(self, _, chunk, dim): - class Chunk(nn.Module): - def forward(self, x): - return x.chunk(chunk, dim)[0] - - input_specs = [ - InputTensorSpec( - shape=(-1, 1, 3, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 3, 5), (3, 1, 3, 5), (5, 1, 3, 5))], - ), - ] - - self.run_test_with_dynamic_shape( - Chunk(), input_specs, expected_ops={acc_ops.chunk} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_clamp.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_clamp.py deleted file mode 100644 index a64d58a98b..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_clamp.py +++ /dev/null @@ -1,74 +0,0 @@ -import torch -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import param, parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestClampConverter(AccTestCase): - @parameterized.expand( - [ - param("default", min=-1, max=0), - param("min", min=0.5), - param("max", max=0.5), - param("minBiggerThanMax", min=1, max=0), - param("float32Boundary", min=-3.4028234663852886e38), - ] - ) - def test_clamp( - self, - test_name, - min=None, - max=None, - ): - class TestModule(torch.nn.Module): - def forward(self, x): - return torch.clamp(x, min, max) - - inputs = [torch.randn(3, 4)] - self.run_test(TestModule(), inputs, expected_ops={acc_ops.clamp}) - - @parameterized.expand( - [ - param("default", min=-1, max=0), - param("min", min=0.5), - param("max", max=0.5), - param("minBiggerThanMax", min=1, max=0), - ] - ) - def test_clamp_with_dynamic_shape_four_dimensions( - self, - test_name, - min=None, - max=None, - ): - class TestModule(torch.nn.Module): - def forward(self, x): - return torch.clamp(x, min, max) - - class TestScalarModule(torch.nn.Module): - def forward(self, x): - y = torch.sum(x) - return torch.clamp(y, min, max) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, 3, 3), - dtype=torch.float32, - shape_ranges=[((1, 1, 3, 3), (3, 3, 3, 3), (5, 5, 3, 3))], - ), - ] - - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.clamp} - ) - self.run_test_with_dynamic_shape( - TestScalarModule(), input_specs, expected_ops={acc_ops.clamp} - ) - - -if __name__ == "__main__": - run_tests() 
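[Editorial note on the deleted converter suites, before the next file diff: every suite removed by this commit follows the same two-path shape seen above and below — a static-shape case driven through AccTestCase.run_test, and a dynamic-shape case driven through run_test_with_dynamic_shape with InputTensorSpec (min, opt, max) shape_ranges, where -1 marks the dynamic dimensions. The following is a minimal sketch of that shared pattern for reference, using only APIs that appear verbatim in the deleted files; acc_ops.relu stands in for whichever converter a given suite targeted, and this sketch is illustration only, not code added by the patch.]

import torch
import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import (
    AccTestCase,
    InputTensorSpec,
)

class TestConverterPattern(AccTestCase):
    def test_static_shape(self):
        # Static path: trace the module with fixed-size inputs and assert
        # that the expected acc_ops converter handled the graph.
        class TestModule(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.relu(x)

        inputs = [torch.randn(1, 3, 4)]
        self.run_test(TestModule(), inputs, expected_ops={acc_ops.relu})

    def test_dynamic_shape(self):
        # Dynamic path: -1 marks dynamic dims; each shape_ranges entry is a
        # (min, opt, max) triple used to build the TRT optimization profile.
        class TestModule(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.relu(x)

        input_specs = [
            InputTensorSpec(
                shape=(-1, -1, -1),
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
            ),
        ]
        self.run_test_with_dynamic_shape(
            TestModule(), input_specs, expected_ops={acc_ops.relu}
        )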
diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_convolution.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_convolution.py deleted file mode 100644 index ab29f0dfc3..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_convolution.py +++ /dev/null @@ -1,202 +0,0 @@ -import torch -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import param, parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestConvolutionConverter(AccTestCase): - @parameterized.expand( - [ - ("default", 1), - param("no_bias", 1, bias=False), - ("tuple_parameters", 1, (1), (1)), - param("non_zero_padding", 1, padding=1), - param("dilation", 1, dilation=2), - param("groups", 1, groups=3), - ] - ) - def test_conv1d( - self, - _, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=True, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv1d( - 3, 6, kernel_size, stride, padding, dilation, groups, bias - ) - - def forward(self, x): - return self.conv(x) - - inputs = [torch.randn(1, 3, 32)] - self.run_test( - TestModule(), - inputs, - expected_ops={acc_ops.conv1d}, - test_explicit_precision=True, - ) - - @parameterized.expand( - [ - ("default", 1), - ] - ) - def test_conv1d_with_dynamic_shape( - self, - _, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=True, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv1d( - 3, 6, kernel_size, stride, padding, dilation, groups, bias - ) - - def forward(self, x): - return self.conv(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, 3, 3), - dtype=torch.float32, - shape_ranges=[((1, 3, 3), (3, 3, 3), (5, 3, 3))], - ), - ] - - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.conv1d} - ) - - @parameterized.expand( - [ - ("default", 1), - param("no_bias", 1, bias=False), - ("tuple_parameters", 1, (1, 1), (1, 1)), - param("non_zero_padding", 1, padding=1), - param("dilation", 1, dilation=2), - param("groups", 1, groups=3), - ] - ) - def test_conv2d( - self, - _, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=True, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d( - 3, 6, kernel_size, stride, padding, dilation, groups, bias - ) - - def forward(self, x): - return self.conv(x) - - inputs = [torch.randn(1, 3, 32, 32)] - self.run_test(TestModule(), inputs, expected_ops={acc_ops.conv2d}) - - # Testing with (-1, -1, -1, -1) results in the error: - # AssertionError: Channel dim can't be dynamic for convolution.
- - def test_conv2d_with_dynamic_shape(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 6, 1) - - def forward(self, x): - return self.conv(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, 3, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 3, 1, 1), (1, 3, 4, 4), (32, 3, 128, 128))], - ), - ] - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.conv2d} - ) - - @parameterized.expand( - [ - ("default", 1), - param("no_bias", 1, bias=False), - ("tuple_parameters", 1, (1, 1, 1), (1, 1, 1)), - param("non_zero_padding", 1, padding=1), - param("dilation", 1, dilation=2), - # TODO: TRT 8.4.1 will trigger an issue with this test. T127981773 - # param("groups", 1, groups=3), - ] - ) - def test_conv3d( - self, - _, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=True, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv3d( - 3, 6, kernel_size, stride, padding, dilation, groups, bias - ) - - def forward(self, x): - return self.conv(x) - - inputs = [torch.randn(1, 3, 32, 32, 32)] - self.run_test(TestModule(), inputs, expected_ops={acc_ops.conv3d}) - - # Testing with (-1, -1, -1, -1, -1) results in the error: - # AssertionError: Channel dim can't be dynamic for convolution. - - def test_conv3d_with_dynamic_shape(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv3d(3, 6, 1) - - def forward(self, x): - return self.conv(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, 3, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 3, 1, 1, 1), (1, 3, 4, 4, 4), (8, 3, 32, 32, 32))], - ), - ] - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.conv3d} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_dequantize.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_dequantize.py deleted file mode 100644 index 1f7f6cbe88..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_dequantize.py +++ /dev/null @@ -1,71 +0,0 @@ -import unittest - -import tensorrt as trt -import torch.fx -import torch.nn as nn - -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -@unittest.skip( - """ - Tests related to quantize have issues creating the engine; disabled for now.
- """ -) -@unittest.skipIf( - trt.__version__ < "8.0", - "Explicit quantization only supported in TensorRT 8.0 and later", -) -class TestDequantizeConverter(AccTestCase): - def test_dequantize(self): - class TestModule(nn.Module): - def forward(self, x): - x = torch.quantize_per_tensor(x, 1, 0, torch.quint8) - return x.dequantize() - - inputs = [torch.randn(1, 10)] - self.run_test(TestModule(), inputs, expected_ops={acc_ops.dequantize}) - - def test_dequantize_with_dynamic_shape(self): - class TestModule(nn.Module): - def forward(self, x): - x = torch.quantize_per_tensor(x, 1, 0, torch.quint8) - return x.dequantize() - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], - ), - ] - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.dequantize} - ) - - def test_dequantize_with_dynamic_shape_four_dimensions(self): - class TestModule(nn.Module): - def forward(self, x): - x = torch.quantize_per_tensor(x, 1, 0, torch.quint8) - return x.dequantize() - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))], - ), - ] - - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.dequantize} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_einsum.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_einsum.py deleted file mode 100644 index c6beebdf4c..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_einsum.py +++ /dev/null @@ -1,69 +0,0 @@ -import torch -import torch.nn as nn -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestConverter(AccTestCase): - @parameterized.expand( - [ - ("2d_dim", "ij,jk->ik", (2, 3), (3, 4)), - ("2d_dim_ext", "ij,kj->ik", (2, 3), (4, 3)), - ("3d_dim", "cxd,cyd->cxy", (3, 4, 5), (3, 6, 5)), - ("4d_dim", "bcwd,bcdh->bcwh", (2, 3, 4, 5), (2, 3, 5, 6)), - ("4d_dim_ext", "bcxd,bcyd->bcxy", (2, 3, 4, 5), (2, 3, 6, 5)), - # TRT does not support ellipsis or diagonal operations - ] - ) - def test_einsum(self, _, equation, x_size, y_size): - class Einsum(nn.Module): - def forward(self, x, y): - return torch.einsum(equation, x, y) - - inputs = [torch.randn(*x_size), torch.randn(*y_size)] - self.run_test( - Einsum(), - inputs, - expected_ops={acc_ops.einsum}, - test_implicit_batch_dim=False, - ) - - @parameterized.expand( - [ - ("4d_dim", "bcwd,bcdh->bcwh", (2, 3, 4, 5), (2, 3, 5, 6)), - ("4d_dim_ext", "bcxd,bcyd->bcxy", (2, 3, 4, 5), (2, 3, 6, 5)), - # TRT does not support ellipsis or diagonal operations - ] - ) - def test_einsum_with_dynamic_shape_four_dimensions( - self, _, equation, x_size, y_size - ): - class Einsum(nn.Module): - def forward(self, x, y): - return torch.einsum(equation, x, y) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 3, 3), (1, 2, 3, 3), (3, 3, 3, 3))], - ), - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 3, 3), (1, 2, 3, 3), (3, 3, 3, 3))], - ), - ] - - self.run_test_with_dynamic_shape( - Einsum(), input_specs, expected_ops={acc_ops.einsum} - ) - - -if __name__ 
== "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_elu.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_elu.py deleted file mode 100644 index c35154bd76..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_elu.py +++ /dev/null @@ -1,55 +0,0 @@ -import torch -import torch.nn as nn -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestELUConverter(AccTestCase): - def test_elu(self): - class TestModule(nn.Module): - def forward(self, x): - return nn.functional.elu(x, alpha=1.5) - - inputs = [torch.randn(1, 10)] - self.run_test(TestModule(), inputs, expected_ops={acc_ops.elu}) - - def test_elu_with_dynamic_shape(self): - class TestModule(nn.Module): - def forward(self, x): - return nn.functional.elu(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], - ), - ] - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.elu} - ) - - def test_elu_with_dynamic_shape_four_dimensions(self): - class TestModule(nn.Module): - def forward(self, x): - return nn.functional.elu(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 5), (3, 3, 3, 5))], - ), - ] - - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.elu} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_embedding.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_embedding.py deleted file mode 100644 index 05186300b4..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_embedding.py +++ /dev/null @@ -1,110 +0,0 @@ -import unittest - -import torch - -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import param, parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -@unittest.skip( - "Current implementation is limited. All implementations in hf use int64. 
T113156424" -) -class TestEmbeddingConverter(AccTestCase): - @parameterized.expand( - [ - param( - test_name="1d_indices", - indices_tensor=torch.tensor([3, 1, 2]), - weights_tensor=torch.randn(5, 10), - ), - param( - test_name="2d_indices", - indices_tensor=torch.tensor([[3, 1, 2], [4, 1, 3]]), - weights_tensor=torch.randn(5, 10), - ), - param( - test_name="3d_indices", - indices_tensor=torch.tensor([[[0, 1], [2, 3]], [[3, 4], [4, 0]]]), - weights_tensor=torch.randn(5, 10), - ), - ] - ) - def test_embedding( - self, - test_name, - indices_tensor, - weights_tensor, - padding_idx=None, - max_norm=None, - norm_type=2.0, - scale_grad_by_freq=False, - sparse=False, - ): - class TestEmbedding(torch.nn.Module): - def forward(self, indices, weights): - return torch.nn.functional.embedding( - input=indices, - weight=weights, - padding_idx=padding_idx, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse, - ) - - self.run_test( - TestEmbedding(), - inputs=[indices_tensor.int(), weights_tensor.float()], - expected_ops={acc_ops.embedding}, - test_implicit_batch_dim=False, - test_explicit_batch_dim=True, - ) - - def test_embedding_with_dynamic_shape_four_dimensions( - self, - test_name, - indices_tensor, - weights_tensor, - padding_idx=None, - max_norm=None, - norm_type=2.0, - scale_grad_by_freq=False, - sparse=False, - ): - class TestEmbedding(torch.nn.Module): - def forward(self, indices, weights): - return torch.nn.functional.embedding( - input=indices, - weight=weights, - padding_idx=padding_idx, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse, - ) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], - ), - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], - ), - ] - - self.run_test_with_dynamic_shape( - TestEmbedding(), input_specs, expected_ops={acc_ops.embedding} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_eq.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_eq.py deleted file mode 100644 index 8cb9185673..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_eq.py +++ /dev/null @@ -1,292 +0,0 @@ -import torch -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestEqConverter(AccTestCase): - @parameterized.expand( - [ - ("rand_2d", torch.randn(3, 4), torch.randn(3, 4)), - ("rand_3d", torch.randn(3, 4, 5), torch.randn(3, 4, 5)), - ("rand_4d", torch.randn(3, 4, 5, 6), torch.randn(3, 4, 5, 6)), - ( - "rand_2d_int_bool", - torch.randn(3, 4).to(torch.int), - torch.zeros(3, 4).to(torch.bool), - ), - ( - "rand_2d_float_bool", - torch.randn(3, 4).to(torch.float), - torch.zeros(3, 4).to(torch.bool), - ), - ( - "rand_2d_float_int", - torch.randn(3, 4).to(torch.float), - torch.zeros(3, 4).to(torch.int), - ), - ( - "rand_4d_float_bool_dim", - torch.randn(3, 4, 5, 6).to(torch.float), - torch.randn(3, 1, 1, 6).to(torch.bool), - ), - ] - ) - def test_eq(self, _, input, other): - class Eq(torch.nn.Module): - def forward(self, x, y): - mask = torch.eq(x, y) - return 
x.masked_fill(mask, 5) - - inputs = [ - input, - other, - ] - self.run_test( - Eq(), inputs, expected_ops={acc_ops.eq}, test_implicit_batch_dim=False - ) - - -class TestEqMethodConverter(AccTestCase): - @parameterized.expand( - [ - ("rand_2d", torch.randn(3, 4), torch.randn(3, 4)), - ("rand_3d", torch.randn(3, 4, 5), torch.randn(3, 4, 5)), - ("rand_4d", torch.randn(3, 4, 5, 6), torch.randn(3, 4, 5, 6)), - ( - "rand_2d_int_bool", - torch.randn(3, 4).to(torch.int), - torch.zeros(3, 4).to(torch.bool), - ), - ( - "rand_2d_float_bool", - torch.randn(3, 4).to(torch.float), - torch.zeros(3, 4).to(torch.bool), - ), - ( - "rand_2d_float_int", - torch.randn(3, 4).to(torch.float), - torch.zeros(3, 4).to(torch.int), - ), - ( - "rand_4d_float_bool_dim", - torch.randn(3, 4, 5, 6).to(torch.float), - torch.randn(3, 1, 1, 6).to(torch.bool), - ), - ] - ) - def test_eq(self, _, input, other): - class Eq(torch.nn.Module): - def forward(self, x, y): - mask = x.eq(y) - return x.masked_fill(mask, 5) - - inputs = [ - input, - other, - ] - self.run_test( - Eq(), inputs, expected_ops={acc_ops.eq}, test_implicit_batch_dim=False - ) - - -class TestEqOperatorConverter(AccTestCase): - @parameterized.expand( - [ - ("rand_2d", torch.randn(3, 4), torch.randn(3, 4)), - ("rand_3d", torch.randn(3, 4, 5), torch.randn(3, 4, 5)), - ("rand_4d", torch.randn(3, 4, 5, 6), torch.randn(3, 4, 5, 6)), - ( - "rand_2d_int_bool", - torch.randn(3, 4).to(torch.int), - torch.zeros(3, 4).to(torch.bool), - ), - ( - "rand_2d_float_bool", - torch.randn(3, 4).to(torch.float), - torch.zeros(3, 4).to(torch.bool), - ), - ( - "rand_2d_float_int", - torch.randn(3, 4).to(torch.float), - torch.zeros(3, 4).to(torch.int), - ), - ( - "rand_4d_float_bool_dim", - torch.randn(3, 4, 5, 6).to(torch.float), - torch.randn(3, 1, 1, 6).to(torch.bool), - ), - ] - ) - def test_eq(self, _, input, other): - class Eq(torch.nn.Module): - def forward(self, x, y): - mask = x == y - return x.masked_fill(mask, 5) - - inputs = [ - input, - other, - ] - self.run_test( - Eq(), inputs, expected_ops={acc_ops.eq}, test_implicit_batch_dim=False - ) - - -class TestEqOperatorSimpleConverter(AccTestCase): - @parameterized.expand( - [ - ("rand_2d_float_bool", torch.randn(3, 4), torch.randn(3, 4).to(torch.bool)), - ( - "rand_2d_int_bool", - torch.randn(3, 4).to(torch.int), - torch.randn(3, 4).to(torch.bool), - ), - ( - "rand_2d_bool_bool", - torch.randn(3, 4).to(torch.bool), - torch.randn(3, 4).to(torch.bool), - ), - ( - "rand_2d_float_int", - torch.randn(3, 4).to(torch.float), - torch.randn(3, 4).to(torch.int), - ), - ( - "rand_2d_float_single_bool", - torch.randn(3, 4), - torch.tensor(0).to(torch.bool), - ), - ( - "rand_2d_int_single_bool", - torch.randn(3, 4).to(torch.int), - torch.tensor(0).to(torch.bool), - ), - ( - "rand_2d_bool_single_bool", - torch.randn(3, 4).to(torch.bool), - torch.tensor(0).to(torch.bool), - ), - ] - ) - def test_eq(self, _, input, other): - class Eq(torch.nn.Module): - def forward(self, x, y): - return x == y - - inputs = [ - input, - other, - ] - self.run_test( - Eq(), inputs, expected_ops={acc_ops.eq}, test_implicit_batch_dim=False - ) - - -class TestEqOperatorSimpleConverterWithDynamicShape(AccTestCase): - def test_eq(self): - class Eq(torch.nn.Module): - def forward(self, x, y): - return x == y - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], - ), - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 
5), (2, 3, 10, 10))], - ), - ] - - self.run_test_with_dynamic_shape(Eq(), input_specs, expected_ops={acc_ops.eq}) - - -class TestEqOperatorConstantConverter(AccTestCase): - @parameterized.expand( - [ - ("rand_2d_float_bool", torch.randn(3, 4), torch.randn(3, 4).to(torch.bool)), - ( - "rand_2d_int_bool", - torch.randn(3, 4).to(torch.int), - torch.randn(3, 4).to(torch.bool), - ), - ( - "rand_2d_bool_bool", - torch.randn(3, 4).to(torch.bool), - torch.randn(3, 4).to(torch.bool), - ), - ( - "rand_2d_float_int", - torch.randn(3, 4).to(torch.float), - torch.randn(3, 4).to(torch.int), - ), - ("rand_2d_float_single_bool", torch.randn(3, 4), False), - ("rand_2d_int_single_bool", torch.randn(3, 4).to(torch.int), False), - ("rand_2d_bool_single_bool", torch.randn(3, 4).to(torch.bool), False), - ] - ) - def test_eq(self, _, input, other): - class Eq(torch.nn.Module): - def __init__(self): - super().__init__() - self.other = other - - def forward(self, x): - return x == self.other - - inputs = [ - input, - ] - self.run_test( - Eq(), inputs, expected_ops={acc_ops.eq}, test_implicit_batch_dim=False - ) - - -class TestConstInputConverter(AccTestCase): - def test_eq(self): - class Eq(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return x.shape[0] == 4 - - input = torch.randn(3, 4) - inputs = [ - input, - ] - self.run_test( - Eq(), inputs, expected_ops={acc_ops.eq}, test_implicit_batch_dim=False - ) - - -class TestConstInputConverterWithDynamicShape(AccTestCase): - def test_eq(self): - class Eq(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return x.shape[0] == 4 - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))], - ), - ] - - self.run_test_with_dynamic_shape(Eq(), input_specs, expected_ops={acc_ops.eq}) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_expand.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_expand.py deleted file mode 100644 index e7021e2353..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_expand.py +++ /dev/null @@ -1,34 +0,0 @@ -import torch -import torch.nn as nn -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase - - -class TestExpandConverter(AccTestCase): - @parameterized.expand( - [ - ("2d_dim", (2, 3), (2, 1)), - ("3d_dim", (2, 3, 4), (2, 1, 1)), - ("4d_dim", (2, 3, 4, 5), (2, 1, 1, 1)), - ("keep_dim", (2, 3, -1, -1), (2, 1, 5, 5)), - ] - ) - def test_expand(self, _, sizes, init_size): - class Expand(nn.Module): - def forward(self, x): - return x.expand(*sizes) - - inputs = [torch.randn(*init_size)] - self.run_test( - Expand(), - inputs, - expected_ops={acc_ops.expand}, - ) - - # Dynamic shape is not suitable for the expand operation. 
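For context, torch.Tensor.expand() only materializes size-1 dimensions, and -1 means "keep this dimension's size", so every target extent must be known when the engine is built. A minimal plain-PyTorch sketch of the keep_dim case above (shapes taken from the deleted test; no TRT involved):

    import torch

    x = torch.randn(2, 1, 5, 5)
    # -1 keeps a dimension's existing size; only size-1 dims may be broadcast up.
    y = x.expand(2, 3, -1, -1)
    assert y.shape == (2, 3, 5, 5)

With a dynamic input those target sizes are not known at engine-build time, which is presumably why no dynamic-shape variant of this test was ever added.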
- - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_flatten.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_flatten.py deleted file mode 100644 index 346669d695..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_flatten.py +++ /dev/null @@ -1,70 +0,0 @@ -import torch -import torch.nn as nn -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec - - -class TestFlattenConverter(AccTestCase): - @parameterized.expand( - [ - ("flatten_middle_dims", 1, 2), - ("flatten_last_3_dims", 1, 3), - ("flatten_last_1", 3, 3), - ("flatten_all", 0, 3), - ] - ) - def test_flatten(self, _, start_dim, end_dim): - class Flatten(nn.Module): - def __init__(self, start, end): - super().__init__() - self.start = start - self.end = end - - def forward(self, x): - return torch.flatten(x, self.start, self.end) - - inputs = [torch.randn(1, 2, 3, 1)] - self.run_test( - Flatten(start_dim, end_dim), - inputs, - expected_ops={acc_ops.flatten}, - test_implicit_batch_dim=(start_dim != 0), - ) - - @parameterized.expand( - [ - ("flatten_middle_dims", 1, 2), - ("flatten_last_3_dims", 2, 4), - ("flatten_last_1", 4, 4), - ("flatten_first_2", 0, 1), - ("flatten_all", 0, 4), - ] - ) - def test_flatten_with_dynamic_shape(self, _, start_dim, end_dim): - class Flatten(nn.Module): - def __init__(self, start, end): - super().__init__() - self.start = start - self.end = end - - def forward(self, x): - return torch.flatten(x, self.start, self.end) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1, 1), (1, 2, 3, 2, 1), (3, 3, 3, 3, 3))], - ), - ] - self.run_test_with_dynamic_shape( - Flatten(start_dim, end_dim), - input_specs, - expected_ops={acc_ops.flatten}, - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_gelu.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_gelu.py deleted file mode 100644 index 1c7c8264f2..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_gelu.py +++ /dev/null @@ -1,98 +0,0 @@ -import unittest - -import torch -import torch.nn as nn -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -@unittest.skip( - reason="Could not find CustomGeluPluginDynamic. 
Enable it once we upgrade TRT to 8.4" -) -class TestGELU(AccTestCase): - def test_gelu(self): - class TestModule(nn.Module): - def forward(self, x): - return nn.functional.gelu(x) - - inputs = [torch.randn(3, 10, 20)] - self.run_test( - TestModule(), - inputs, - expected_ops={acc_ops.gelu}, - test_implicit_batch_dim=False, - ) - - def test_gelu_with_dynamic_shape(self): - class TestModule(nn.Module): - def forward(self, x): - return nn.functional.gelu(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], - ), - ] - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.gelu} - ) - - def test_gelu_with_dynamic_shape_four_dimensions(self): - class TestModule(nn.Module): - def forward(self, x): - return nn.functional.gelu(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))], - ), - ] - - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.gelu} - ) - - def test_gelu_module(self): - class TestModule(nn.Module): - def __init__(self): - super().__init__() - self.gelu = torch.nn.GELU() - - def forward(self, x): - return self.gelu(x) - - inputs = [torch.randn(3, 10, 20)] - self.run_test( - TestModule(), - inputs, - expected_ops={acc_ops.gelu}, - test_implicit_batch_dim=False, - ) - - def test_gelu_module_throw(self): - class TestModule(nn.Module): - def __init__(self): - super().__init__() - self.gelu = torch.nn.GELU(approximate="tanh") - - def forward(self, x): - return self.gelu(x) - - inputs = [torch.randn(3, 10, 20)] - self.run_test_with_assert_error( - TestModule(), - inputs, - expect_error=RuntimeError, - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_getitem.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_getitem.py deleted file mode 100644 index 880cbe2418..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_getitem.py +++ /dev/null @@ -1,201 +0,0 @@ -import torch -import torch.nn as nn -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestGetitemConverter(AccTestCase): - @parameterized.expand( - [ - ("slice_batch_dim", slice(None, None, None)), - ("slice_basic", (slice(None, None, None), slice(0, 3, 2))), - ("slice_full", (slice(None, None, None), slice(0, 10, 3))), - ("ellipsis", (slice(None, None, None), ..., slice(0, 3, 2))), - ( - "slice_all_none", - (slice(None, None, None), slice(None, None, None)), - ), - ( - "slice_start_none", - (slice(None, None, None), slice(None, 2, 1)), - ), - ("slice_end_none", (slice(None, None, None), slice(1, None, 1))), - ( - "slice_step_none", - (slice(None, None, None), slice(0, 3, None)), - ), - ("slice_neg_idx", (slice(None, None, None), -1)), - ("slice_neg_slice", (slice(None, None, None), slice(-8, -2, 3))), - ("multi_dim", (slice(None, None, None), 0, 1)), - ( - "slice_multi_dim", - (slice(None, None, None), slice(0, 3, 2), slice(1, -1, 3)), - ), - ( - "none", - (slice(None, None, None), None, slice(1, -1, 3), 1), - ), - ( - "slice_zero_slice", - (slice(None, None, None), slice(None, None, None), slice(0, 0, None)), - ), - ] - ) - def 
test_getitem(self, _, idx): - class Getitem(nn.Module): - def __init__(self, idx): - super().__init__() - self.idx = idx - - def forward(self, x): - x = x + x - return x[self.idx] - - inputs = [torch.randn(2, 10, 10, 10)] - self.run_test(Getitem(idx), inputs, expected_ops={acc_ops.getitem}) - - @parameterized.expand( - [ - ("slice_batch_dim", slice(None, None, None)), - ("ellipsis", (slice(None, None, None), ..., slice(0, -3, 2))), - ( - "slice_all_none", - (slice(None, None, None), slice(None, None, None)), - ), - ( - "slice_end_none", - (slice(None, None, None), slice(None, None, None), slice(1, None, 1)), - ), - ( - "slice_step_none", - (slice(None, None, None), slice(None, None, None), slice(0, 3, None)), - ), - ("slice_neg_idx", (slice(None, None, None), -1, slice(None, None, None))), - ( - "slice_neg_slice", - (slice(None, None, None), slice(None, None, None), slice(-8, -2, 3)), - ), - ("multi_dim", (slice(None, None, None), 0, 1)), - ( - "slice_multi_dim", - (slice(None, None, None), slice(0, 3, 2), slice(1, -1, 3)), - ), - ( - "none", - (slice(None, None, None), None, slice(1, -1, 3)), - ), - ] - ) - def test_getitem_with_dynamic_shape(self, _, idx): - class Getitem(nn.Module): - def __init__(self, idx): - super().__init__() - self.idx = idx - - def forward(self, x): - x = x + x - return x[self.idx] - - input_specs = [ - InputTensorSpec( - shape=(-1, 256, 256), - dtype=torch.float32, - shape_ranges=[((1, 256, 256), (3, 256, 256), (5, 256, 256))], - ), - ] - self.run_test_with_dynamic_shape( - Getitem(idx), input_specs, expected_ops={acc_ops.getitem} - ) - - @parameterized.expand( - [ - ("slice_batch_dim", slice(None, None, None)), - ("ellipsis", (slice(None, None, None), ..., slice(0, -3, 2))), - ( - "slice_all_none", - (slice(None, None, None), slice(None, None, None)), - ), - ( - "slice_end_none", - (slice(None, None, None), slice(None, None, None), slice(1, None, 1)), - ), - ( - "slice_step_none", - (slice(None, None, None), slice(None, None, None), slice(0, 3, None)), - ), - ] - ) - def test_getitem_with_multi_dynamic_shape(self, _, idx): - class Getitem(nn.Module): - def __init__(self, idx): - super().__init__() - self.idx = idx - - def forward(self, x): - x = x + x - return x[self.idx] - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, 256), - dtype=torch.float32, - shape_ranges=[((1, 128, 256), (3, 192, 256), (5, 256, 256))], - ), - ] - self.run_test_with_dynamic_shape( - Getitem(idx), input_specs, expected_ops={acc_ops.getitem} - ) - - # Testing with following parameters results into Error: - # AssertionError: We don't support slicing tensor on dynamic shape. 
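The assertion quoted above marks the supported/unsupported boundary: slicing fixed dimensions while the batch dimension is dynamic (as in the passing cases) works, whereas slicing a dimension that is itself dynamic is rejected. A small plain-PyTorch sketch of the supported pattern, mirroring the slice_multi_dim case (concrete shapes chosen for illustration only):

    import torch

    class Getitem(torch.nn.Module):
        def forward(self, x):
            x = x + x
            return x[:, 0:3:2, 1:-1:3]  # slice only the two static dims

    # The batch dim (here 3) is the one left dynamic in the specs above;
    # the sliced dims are fixed at 256 in the InputTensorSpec.
    out = Getitem()(torch.randn(3, 256, 256))
    assert out.shape == (3, 2, 85)

The parameter sets quoted in the block below are the ones that trip the assertion.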
- """ - ("ellipsis", (slice(None, None, None), ..., slice(0, -3, 2))), - ( - "slice_end_none", - (slice(None, None, None), slice(None, None, None), slice(1, None, 1)), - ), - ( - "slice_step_none", - (slice(None, None, None), slice(None, None, None), slice(0, 3, None)), - ), - """ - - @parameterized.expand( - [ - ("slice_batch_dim", slice(None, None, None)), - ( - "slice_all_none", - (slice(None, None, None), slice(None, None, None)), - ), - ] - ) - def test_getitem_with_dynamic_shape_four_dimensions(self, _, idx): - class Getitem(nn.Module): - def __init__(self, idx): - super().__init__() - self.idx = idx - - def forward(self, x): - x = x + x - return x[self.idx] - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))], - ), - ] - - self.run_test_with_dynamic_shape( - Getitem(idx), input_specs, expected_ops={acc_ops.getitem} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_gt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_gt.py deleted file mode 100644 index fac763acb0..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_gt.py +++ /dev/null @@ -1,279 +0,0 @@ -import torch -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestGtConverter(AccTestCase): - @parameterized.expand( - [ - ("rand_2d", torch.randn(3, 4), torch.randn(3, 4)), - ("rand_3d", torch.randn(3, 4, 5), torch.randn(3, 4, 5)), - ("rand_4d", torch.randn(3, 4, 5, 6), torch.randn(3, 4, 5, 6)), - ( - "rand_2d_int_bool", - torch.randn(3, 4).to(torch.int), - torch.zeros(3, 4).to(torch.bool), - ), - ( - "rand_2d_float_bool", - torch.randn(3, 4).to(torch.float), - torch.zeros(3, 4).to(torch.bool), - ), - ( - "rand_2d_float_int", - torch.randn(3, 4).to(torch.float), - torch.zeros(3, 4).to(torch.int), - ), - ] - ) - def test_gt(self, _, input, other): - class Gt(torch.nn.Module): - def forward(self, x, y): - mask = torch.gt(x, y) - return x.masked_fill(mask, 5) - - inputs = [ - input, - other, - ] - self.run_test( - Gt(), inputs, expected_ops={acc_ops.gt}, test_implicit_batch_dim=False - ) - - -class TestGtMethodConverter(AccTestCase): - @parameterized.expand( - [ - ("rand_2d", torch.randn(3, 4), torch.randn(3, 4)), - ("rand_3d", torch.randn(3, 4, 5), torch.randn(3, 4, 5)), - ("rand_4d", torch.randn(3, 4, 5, 6), torch.randn(3, 4, 5, 6)), - ( - "rand_2d_int_bool", - torch.randn(3, 4).to(torch.int), - torch.zeros(3, 4).to(torch.bool), - ), - ( - "rand_2d_float_bool", - torch.randn(3, 4).to(torch.float), - torch.zeros(3, 4).to(torch.bool), - ), - ( - "rand_2d_float_int", - torch.randn(3, 4).to(torch.float), - torch.zeros(3, 4).to(torch.int), - ), - ] - ) - def test_gt(self, _, input, other): - class Gt(torch.nn.Module): - def forward(self, x, y): - mask = x.gt(y) - return x.masked_fill(mask, 5) - - inputs = [ - input, - other, - ] - self.run_test( - Gt(), inputs, expected_ops={acc_ops.gt}, test_implicit_batch_dim=False - ) - - -class TestGtOperatorConverter(AccTestCase): - @parameterized.expand( - [ - ("rand_2d", torch.randn(3, 4), torch.randn(3, 4)), - ("rand_3d", torch.randn(3, 4, 5), torch.randn(3, 4, 5)), - ("rand_4d", torch.randn(3, 4, 5, 6), torch.randn(3, 4, 5, 6)), - ( - 
"rand_2d_int_bool", - torch.randn(3, 4).to(torch.int), - torch.zeros(3, 4).to(torch.bool), - ), - ( - "rand_2d_float_bool", - torch.randn(3, 4).to(torch.float), - torch.zeros(3, 4).to(torch.bool), - ), - ( - "rand_2d_float_int", - torch.randn(3, 4).to(torch.float), - torch.zeros(3, 4).to(torch.int), - ), - ] - ) - def test_gt(self, _, input, other): - class Gt(torch.nn.Module): - def forward(self, x, y): - mask = x > y - return x.masked_fill(mask, 5) - - inputs = [ - input, - other, - ] - self.run_test( - Gt(), inputs, expected_ops={acc_ops.gt}, test_implicit_batch_dim=False - ) - - -class TestEqOperatorSimpleConverter(AccTestCase): - @parameterized.expand( - [ - ("rand_2d_float_bool", torch.randn(3, 4), torch.randn(3, 4).to(torch.bool)), - ( - "rand_2d_int_bool", - torch.randn(3, 4).to(torch.int), - torch.randn(3, 4).to(torch.bool), - ), - ( - "rand_2d_bool_bool", - torch.randn(3, 4).to(torch.bool), - torch.randn(3, 4).to(torch.bool), - ), - ( - "rand_2d_float_int", - torch.randn(3, 4).to(torch.float), - torch.randn(3, 4).to(torch.int), - ), - ( - "rand_2d_float_single_bool", - torch.randn(3, 4), - torch.tensor(0).to(torch.bool), - ), - ( - "rand_2d_int_single_bool", - torch.randn(3, 4).to(torch.int), - torch.tensor(0).to(torch.bool), - ), - ( - "rand_2d_bool_single_bool", - torch.randn(3, 4).to(torch.bool), - torch.tensor(0).to(torch.bool), - ), - ] - ) - def test_eq(self, _, input, other): - class Eq(torch.nn.Module): - def forward(self, x, y): - return x > y - - inputs = [ - input, - other, - ] - self.run_test( - Eq(), inputs, expected_ops={acc_ops.gt}, test_implicit_batch_dim=False - ) - - -class TestEqOperatorSimpleConverterWithDynamicShape(AccTestCase): - def test_eq( - self, - ): - class Eq(torch.nn.Module): - def forward(self, x, y): - return x > y - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], - ), - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], - ), - ] - - self.run_test_with_dynamic_shape(Eq(), input_specs, expected_ops={acc_ops.gt}) - - -class TestEqOperatorConstantConverter(AccTestCase): - @parameterized.expand( - [ - ("rand_2d_float_bool", torch.randn(3, 4), torch.randn(3, 4).to(torch.bool)), - ( - "rand_2d_int_bool", - torch.randn(3, 4).to(torch.int), - torch.randn(3, 4).to(torch.bool), - ), - ( - "rand_2d_bool_bool", - torch.randn(3, 4).to(torch.bool), - torch.randn(3, 4).to(torch.bool), - ), - ( - "rand_2d_float_int", - torch.randn(3, 4).to(torch.float), - torch.randn(3, 4).to(torch.int), - ), - ("rand_2d_float_single_bool", torch.randn(3, 4), False), - ("rand_2d_int_single_bool", torch.randn(3, 4).to(torch.int), False), - ("rand_2d_bool_single_bool", torch.randn(3, 4).to(torch.bool), False), - ] - ) - def test_eq(self, _, input, other): - class Eq(torch.nn.Module): - def __init__(self): - super().__init__() - self.other = other - - def forward(self, x): - return x > self.other - - inputs = [ - input, - ] - self.run_test( - Eq(), inputs, expected_ops={acc_ops.gt}, test_implicit_batch_dim=False - ) - - -class TestConstInputConverter(AccTestCase): - def test_gt(self): - class Gt(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return x.shape[0] > 4 - - input = torch.randn(3, 4) - inputs = [ - input, - ] - self.run_test( - Gt(), inputs, expected_ops={acc_ops.gt}, test_implicit_batch_dim=False - ) - - -class 
TestConstInputConverterWithDynamicShape(AccTestCase): - def test_gt(self): - class Gt(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return x.shape[0] > 4 - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))], - ), - ] - - self.run_test_with_dynamic_shape(Gt(), input_specs, expected_ops={acc_ops.gt}) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_hard_sigmoid.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_hard_sigmoid.py deleted file mode 100644 index cfe0e2b52e..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_hard_sigmoid.py +++ /dev/null @@ -1,62 +0,0 @@ -import torch -from parameterized import parameterized -from torch import nn -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) -from torch_tensorrt.fx.tracer.acc_tracer import acc_ops - - -class TestHardSigmoid(AccTestCase): - @parameterized.expand( - [ - ("3", 3), - ("0", 0), - ("-3", -4), - ] - ) - def test_hardsigmoid(self, _, pad): - class Hardsigmoid(nn.Module): - def forward(self, x): - return torch.nn.functional.hardsigmoid(x) - - inputs = [torch.randn(1, 2, 3) + pad] - self.run_test(Hardsigmoid(), inputs, expected_ops={acc_ops.hardsigmoid}) - - def test_hardsigmoid_with_dynamic_shape(self): - class Hardsigmoid(nn.Module): - def forward(self, x): - return torch.nn.functional.hardsigmoid(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], - ), - ] - self.run_test_with_dynamic_shape( - Hardsigmoid(), input_specs, expected_ops={acc_ops.hardsigmoid} - ) - - def test_hardsigmoid_with_dynamic_shape_four_dimensions(self): - class Hardsigmoid(nn.Module): - def forward(self, x): - return torch.nn.functional.hardsigmoid(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))], - ), - ] - self.run_test_with_dynamic_shape( - Hardsigmoid(), input_specs, expected_ops={acc_ops.hardsigmoid} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_hardtanh.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_hardtanh.py deleted file mode 100644 index 469816e2b4..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_hardtanh.py +++ /dev/null @@ -1,60 +0,0 @@ -import torch -import torch.nn as nn -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestHardtanhConverter(AccTestCase): - @parameterized.expand( - [ - (-2.0, 6), - (0, 1), - (0.5, 7), - ] - ) - def test_hardtanh(self, test_min_value, test_max_value): - class Hardtanh(nn.Module): - def forward(self, x): - return nn.functional.hardtanh( - x, min_val=test_min_value, max_val=test_max_value - ) - - inputs = [torch.randn(2, 10, 10, 10)] - self.run_test(Hardtanh(), inputs, expected_ops={acc_ops.hardtanh}) - - -class TestHardtanhConverterWithDynamicShape(AccTestCase): - @parameterized.expand( - 
[ - (-2.0, 6), - (0, 1), - (0.5, 7), - ] - ) - def test_hardtanh(self, test_min_value, test_max_value): - class Hardtanh(nn.Module): - def forward(self, x): - return nn.functional.hardtanh( - x, min_val=test_min_value, max_val=test_max_value - ) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))], - ), - ] - - self.run_test_with_dynamic_shape( - Hardtanh(), input_specs, expected_ops={acc_ops.hardtanh} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_interpolate.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_interpolate.py deleted file mode 100644 index 8eefb88ed9..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_interpolate.py +++ /dev/null @@ -1,153 +0,0 @@ -import torch -import torch.nn as nn -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestInterpolateConverter(AccTestCase): - @parameterized.expand( - [ - # 3D - ("3d_dim_scale", (2, 3, 4), (None), (2), ("nearest"), (None)), - ("3d_dim_scale_seq", (2, 3, 4), (None), (2,), ("nearest"), (None)), - ("3d_dim_size", (2, 3, 4), (2), (None), ("nearest"), (None)), - ("3d_dim_size_seq", (2, 3, 4), (8,), (None), ("nearest"), (None)), - ( - "3d_dim_scale_linear", - (2, 3, 4), - (None), - (2), - ("linear"), - (None), - ), # linear for 3D only - ( - "3d_dim_scale_align", - (2, 3, 4), - (None), - (2), - ("linear"), - (True), - ), # align_corners for linear,bilinear,trilinear,bicubic only - # 4D - ("4d_dim_scale", (2, 3, 4, 5), (None), (2), ("nearest"), (None)), - ("4d_dim_scale_seq", (2, 3, 4, 5), (None), (2, 2), ("nearest"), (None)), - ("4d_dim_size", (2, 3, 4, 5), (2), (None), ("nearest"), (None)), - ("4d_dim_size_seq", (2, 3, 4, 5), (8, 10), (None), ("nearest"), (None)), - ( - "4d_dim_scale_bilinear", - (2, 3, 4, 5), - (None), - (2), - ("bilinear"), - (None), - ), # linear for 4D only - ( - "4d_dim_scale_bilinear_align_corners_bool", - (2, 3, 4, 5), - (None), - (2), - ("bilinear"), - (False), - ), # linear for 4D only - ( - "4d_dim_scale_align", - (2, 3, 4, 5), - (None), - (2), - ("bilinear"), - (True), - ), # align_corners for linear,bilinear,trilinear,bicubic only - # 5D - ("5d_dim_scale", (2, 3, 4, 5, 6), (None), (2), ("nearest"), (None)), - ( - "5d_dim_scale_seq", - (2, 3, 4, 5, 6), - (None), - (2, 2, 2), - ("nearest"), - (None), - ), - ("5d_dim_size", (2, 3, 4, 5, 6), (2), (None), ("nearest"), (None)), - ( - "5d_dim_size_seq", - (2, 3, 4, 5, 6), - (8, 10, 12), - (None), - ("nearest"), - (None), - ), - ( - "5d_dim_scale_trilinear", - (2, 3, 4, 5, 6), - (None), - (2), - ("trilinear"), - (None), - ), # trilinear for 5D only - ( - "5d_dim_scale_align", - (2, 3, 4, 5, 6), - (None), - (2), - ("trilinear"), - (True), - ), # align_corners for linear,bilinear,trilinear,bicubic only - ] - ) - def test_interpolate(self, _, init_size, size, scale_factor, mode, align_corners): - class Interpolate(nn.Module): - def forward(self, x): - return torch.nn.functional.interpolate( - x, - size=size, - scale_factor=scale_factor, - mode=mode, - align_corners=align_corners, - ) # only one of size or scale_factor should be defined - - inputs = [torch.randn(*init_size)] - self.run_test( - Interpolate(), - 
inputs, - expected_ops={acc_ops.interpolate}, - ) - - @parameterized.expand( - [ - # 4D - ("4d_dim_scale", (2, 3, 4, 5), (None), (2), ("nearest"), (None)), - ] - ) - def test_interpolate_with_dynamic_shape_four_dimensions( - self, _, init_size, size, scale_factor, mode, align_corners - ): - class Interpolate(nn.Module): - def forward(self, x): - return torch.nn.functional.interpolate( - x, - size=size, - scale_factor=scale_factor, - mode=mode, - align_corners=align_corners, - ) # only one of size or scale_factor should be defined - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))], - ), - ] - - self.run_test_with_dynamic_shape( - Interpolate(), input_specs, expected_ops={acc_ops.interpolate} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_isinf.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_isinf.py deleted file mode 100644 index 89c65e7eff..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_isinf.py +++ /dev/null @@ -1,66 +0,0 @@ -import unittest - -import torch - -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -@unittest.skip("Implementation is commented out due to accuracy issue T113156424") -class TestInfConverter(AccTestCase): - def test_isinf(self): - class Test(torch.nn.Module): - def forward(self, x): - return torch.isinf(x) - - input = torch.randn(2, 3) - input[0][0] = float("inf") - input[0][1] = float("-inf") - input.cuda() - inputs = [ - input, - ] - self.run_test( - Test(), inputs, expected_ops={acc_ops.isinf}, test_implicit_batch_dim=False - ) - - def test_isinf_large(self): - class Test(torch.nn.Module): - def forward(self, x): - return torch.isinf(x) - - input = torch.randn(2, 3, 4, 5) - input[0][0][0][:] = float("inf") - input[0][0][1][:] = float("-inf") - input.cuda() - inputs = [ - input, - ] - self.run_test( - Test(), inputs, expected_ops={acc_ops.isinf}, test_implicit_batch_dim=False - ) - - def test_isinf_large_with_dynamic_shape_four_dimensions(self): - class Test(torch.nn.Module): - def forward(self, x): - return torch.isinf(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))], - ), - ] - - self.run_test_with_dynamic_shape( - Test(), input_specs, expected_ops={acc_ops.isinf} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_leaky_relu.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_leaky_relu.py deleted file mode 100644 index 02deb0ee57..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_leaky_relu.py +++ /dev/null @@ -1,55 +0,0 @@ -import torch -import torch.nn as nn -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestLeakyReLUConverter(AccTestCase): - def test_leaky_relu(self): - class TestModule(nn.Module): - def forward(self, x): - return nn.functional.leaky_relu(x, negative_slope=0.05) - - inputs = [torch.randn(1, 10)] - 
self.run_test(TestModule(), inputs, expected_ops={acc_ops.leaky_relu}) - - def test_leaky_relu_with_dynamic_shape(self): - class TestModule(nn.Module): - def forward(self, x): - return nn.functional.leaky_relu(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], - ), - ] - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.leaky_relu} - ) - - def test_leaky_relu_with_dynamic_shape_four_dimensions(self): - class TestModule(nn.Module): - def forward(self, x): - return nn.functional.leaky_relu(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 5))], - ), - ] - - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.leaky_relu} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_linear.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_linear.py deleted file mode 100644 index 25353e8f29..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_linear.py +++ /dev/null @@ -1,63 +0,0 @@ -import torch -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestLinearConverter(AccTestCase): - @parameterized.expand( - [ - ("default", [1, 512]), - ("matrix", [32, 512]), - ("no_bias", [1, 512], False), - ] - ) - def test_linear( - self, - test_name, - shape, - bias=True, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(512, 256, bias) - - def forward(self, x): - return self.linear(x) - - inputs = [torch.randn(shape)] - self.run_test(TestModule(), inputs, expected_ops={acc_ops.linear}) - - def test_linear_with_dynamic_shape(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(512, 256) - - def forward(self, x): - return self.linear(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, 3, 512), - dtype=torch.float32, - shape_ranges=[((1, 3, 512), (3, 3, 512), (4, 3, 512))], - ), - ] - self.run_test_with_dynamic_shape( - TestModule(), - input_specs, - expected_ops={acc_ops.linear}, - ) - - # Testing with (-1, -1, 512) results into following error: - # AssertionError: Currently we only support one dynmaic dim for linear and it can't be the last dim. 
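The restriction recorded above means only one dimension, and not the last one (the in_features axis), may be dynamic for acc_ops.linear. A minimal sketch of the distinction, reusing the InputTensorSpec API imported by these deleted tests (the commented-out spec is the shape the note reports as failing):

    import torch
    from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import InputTensorSpec

    # Supported: a single dynamic (batch) dim; in_features (512) stays static.
    good_spec = InputTensorSpec(
        shape=(-1, 3, 512),
        dtype=torch.float32,
        shape_ranges=[((1, 3, 512), (3, 3, 512), (4, 3, 512))],
    )

    # Rejected per the note above: two dynamic dims.
    # bad_spec = InputTensorSpec(
    #     shape=(-1, -1, 512),
    #     dtype=torch.float32,
    #     shape_ranges=[((1, 1, 512), (3, 3, 512), (4, 4, 512))],
    # )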
- - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_logical_and.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_logical_and.py deleted file mode 100644 index 71851221c2..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_logical_and.py +++ /dev/null @@ -1,233 +0,0 @@ -import torch -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestAndMethodSimpleConverter(AccTestCase): - @parameterized.expand( - [ - ("rand_2d_float_bool", torch.randn(3, 4), torch.randn(3, 4).to(torch.bool)), - ( - "rand_2d_int_bool", - torch.randn(3, 4).to(torch.int), - torch.randn(3, 4).to(torch.bool), - ), - ( - "rand_2d_bool_bool", - torch.randn(3, 4).to(torch.bool), - torch.randn(3, 4).to(torch.bool), - ), - ( - "rand_2d_float_int", - torch.randn(3, 4).to(torch.float), - torch.randn(3, 4).to(torch.int), - ), - ( - "rand_2d_float_single_bool", - torch.randn(3, 4), - torch.tensor(0).to(torch.bool), - ), - ( - "rand_2d_int_single_bool", - torch.randn(3, 4).to(torch.int), - torch.tensor(0).to(torch.bool), - ), - ( - "rand_2d_bool_single_bool", - torch.randn(3, 4).to(torch.bool), - torch.tensor(0).to(torch.bool), - ), - ] - ) - def test_and(self, _, input, other): - class And(torch.nn.Module): - def forward(self, x, y): - return x.logical_and(y) - - inputs = [ - input, - other, - ] - self.run_test( - And(), - inputs, - expected_ops={acc_ops.logical_and}, - test_implicit_batch_dim=False, - ) - - -class TestAndMethodSimpleConverterWithDynamicShape(AccTestCase): - def test_and(self): - class And(torch.nn.Module): - def forward(self, x, y): - return x.logical_and(y) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], - ), - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], - ), - ] - - self.run_test_with_dynamic_shape( - And(), input_specs, expected_ops={acc_ops.logical_and} - ) - - -class TestAndFunctionSimpleConverter(AccTestCase): - @parameterized.expand( - [ - ("rand_2d_float_bool", torch.randn(3, 4), torch.randn(3, 4).to(torch.bool)), - ( - "rand_2d_int_bool", - torch.randn(3, 4).to(torch.int), - torch.randn(3, 4).to(torch.bool), - ), - ( - "rand_2d_bool_bool", - torch.randn(3, 4).to(torch.bool), - torch.randn(3, 4).to(torch.bool), - ), - ( - "rand_2d_float_int", - torch.randn(3, 4).to(torch.float), - torch.randn(3, 4).to(torch.int), - ), - ( - "rand_2d_float_single_bool", - torch.randn(3, 4), - torch.tensor(0).to(torch.bool), - ), - ( - "rand_2d_int_single_bool", - torch.randn(3, 4).to(torch.int), - torch.tensor(0).to(torch.bool), - ), - ( - "rand_2d_bool_single_bool", - torch.randn(3, 4).to(torch.bool), - torch.tensor(0).to(torch.bool), - ), - ] - ) - def test_and(self, _, input, other): - class And(torch.nn.Module): - def forward(self, x, y): - return torch.logical_and(x, y) - - inputs = [ - input, - other, - ] - self.run_test( - And(), - inputs, - expected_ops={acc_ops.logical_and}, - test_implicit_batch_dim=False, - ) - - -class TestAndOperatorSimpleConverter(AccTestCase): - @parameterized.expand( - [ - ( - "rand_2d_bool_bool", - torch.randn(3, 4).to(torch.bool), - torch.randn(3, 
4).to(torch.bool), - ), - ( - "rand_2d_bool_single_bool", - torch.randn(3, 4).to(torch.bool), - torch.tensor(0).to(torch.bool), - ), - ] - ) - def test_and(self, _, input, other): - class And(torch.nn.Module): - def forward(self, x, y): - return x & y - - inputs = [ - input, - other, - ] - self.run_test( - And(), - inputs, - expected_ops={acc_ops.bitwise_and}, - test_implicit_batch_dim=False, - ) - - -class TestAndOperatorConstantConverter(AccTestCase): - @parameterized.expand( - [ - ( - "rand_2d_bool_bool", - torch.randn(3, 4).to(torch.bool), - torch.randn(3, 4).to(torch.bool), - ), - ( - "rand_2d_bool_single_bool", - torch.randn(3, 4).to(torch.bool), - torch.tensor(0).to(torch.bool), - ), - ] - ) - def test_and(self, _, input, other): - class And(torch.nn.Module): - def __init__(self): - super().__init__() - self.other = other - - def forward(self, x): - return x & self.other - - inputs = [ - input, - ] - self.run_test( - And(), - inputs, - expected_ops={acc_ops.bitwise_and}, - test_implicit_batch_dim=False, - ) - - -class TestAndFunctionSimpleConverterWithDynamicShape(AccTestCase): - def test_and(self): - class And(torch.nn.Module): - def forward(self, x, y): - return torch.logical_and(x, y) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.bool, - shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))], - ), - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.bool, - shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))], - ), - ] - - self.run_test_with_dynamic_shape( - And(), input_specs, expected_ops={acc_ops.logical_and} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_logical_or.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_logical_or.py deleted file mode 100644 index 4f45612b34..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_logical_or.py +++ /dev/null @@ -1,204 +0,0 @@ -import torch -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestLogicalOrMethodSimpleConverter(AccTestCase): - @parameterized.expand( - [ - ("rand_2d_bool_bool", torch.randn(3, 4) > 0, torch.randn(3, 4) > 0), - ("rand_3d_bool_bool", torch.randn(3, 4, 5) > 0, torch.randn(3, 4, 5) > 0), - ( - "rand_4d_bool_bool", - torch.randn(3, 4, 5, 6) > 0, - torch.randn(3, 4, 5, 6) > 0, - ), - ("rand_2d_bool_single_bool", torch.randn(3, 4) > 0, torch.tensor(0) > 0), - ( - "rand_2d_int_bool", - torch.randn(3, 4).to(torch.int), - torch.randn(3, 4) > 0, - ), - ( - "rand_2d_int_single_bool", - torch.randn(3, 4).to(torch.int), - torch.tensor(0) > 0, - ), - ] - ) - def test_logical_or(self, _, input, other): - class LogicalOr(torch.nn.Module): - def forward(self, x, y): - return x.logical_or(y) - - inputs = [ - input, - other, - ] - self.run_test( - LogicalOr(), - inputs, - expected_ops={acc_ops.logical_or}, - test_implicit_batch_dim=False, - ) - - -class TestLogicalOrMethodSimpleConverterWithDynamicShape(AccTestCase): - def test_logical_or(self): - class LogicalOr(torch.nn.Module): - def forward(self, x, y): - return x.logical_or(y) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], - ), - InputTensorSpec( - shape=(-1, -1, -1, 
-1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], - ), - ] - - self.run_test_with_dynamic_shape( - LogicalOr(), input_specs, expected_ops={acc_ops.logical_or} - ) - - -class TestLogicalOrFunctionSimpleConverter(AccTestCase): - @parameterized.expand( - [ - ("rand_2d_bool_bool", torch.randn(3, 4) > 0, torch.randn(3, 4) > 0), - ("rand_3d_bool_bool", torch.randn(3, 4, 5) > 0, torch.randn(3, 4, 5) > 0), - ( - "rand_4d_bool_bool", - torch.randn(3, 4, 5, 6) > 0, - torch.randn(3, 4, 5, 6) > 0, - ), - ("rand_2d_bool_single_bool", torch.randn(3, 4) > 0, torch.tensor(0) > 0), - ( - "rand_2d_int_bool", - torch.randn(3, 4).to(torch.int), - torch.randn(3, 4) > 0, - ), - ( - "rand_2d_int_single_bool", - torch.randn(3, 4).to(torch.int), - torch.tensor(0) > 0, - ), - ] - ) - def test_logical_or(self, _, input, other): - class LogicalOr(torch.nn.Module): - def forward(self, x, y): - return torch.logical_or(x, y) - - inputs = [ - input, - other, - ] - self.run_test( - LogicalOr(), - inputs, - expected_ops={acc_ops.logical_or}, - test_implicit_batch_dim=False, - ) - - -class TestLogicalOrFunctionSimpleConverterWithDynamicShape(AccTestCase): - def test_logical_or(self): - class LogicalOr(torch.nn.Module): - def forward(self, x, y): - return torch.logical_or(x, y) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], - ), - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], - ), - ] - - self.run_test_with_dynamic_shape( - LogicalOr(), input_specs, expected_ops={acc_ops.logical_or} - ) - - -class TestLogicalOrOperatorSimpleConverter(AccTestCase): - @parameterized.expand( - [ - ("rand_2d_bool_bool", torch.randn(3, 4) > 0, torch.randn(3, 4) > 0), - ("rand_3d_bool_bool", torch.randn(3, 4, 5) > 0, torch.randn(3, 4, 5) > 0), - ( - "rand_4d_bool_bool", - torch.randn(3, 4, 5, 6) > 0, - torch.randn(3, 4, 5, 6) > 0, - ), - ("rand_2d_bool_single_bool", torch.randn(3, 4) > 0, torch.tensor(0) > 0), - ( - "rand_2d_int_bool", - torch.randn(3, 4).to(torch.int), - torch.randn(3, 4) > 0, - ), - ( - "rand_2d_int_single_bool", - torch.randn(3, 4).to(torch.int), - torch.tensor(0) > 0, - ), - ] - ) - def test_logical_or(self, _, input, other): - class LogicalOr(torch.nn.Module): - def forward(self, x, y): - return x | y - - inputs = [ - input, - other, - ] - self.run_test( - LogicalOr(), - inputs, - expected_ops={acc_ops.logical_or}, - test_implicit_batch_dim=False, - ) - - -class TestLogicalOrOperatorSimpleConverterWithDynamicShape(AccTestCase): - def test_logical_or(self): - class LogicalOr(torch.nn.Module): - def forward(self, x, y): - return x | y - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.bool, - shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))], - ), - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.bool, - shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))], - ), - ] - - self.run_test_with_dynamic_shape( - LogicalOr(), input_specs, expected_ops={acc_ops.logical_or} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_logical_xor.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_logical_xor.py deleted file mode 100644 index 591c7322bf..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_logical_xor.py +++ /dev/null @@ -1,204 +0,0 
@@ -import torch -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestLogicalXorMethodSimpleConverter(AccTestCase): - @parameterized.expand( - [ - ("rand_2d_bool_bool", torch.randn(3, 4) > 0, torch.randn(3, 4) > 0), - ("rand_3d_bool_bool", torch.randn(3, 4, 5) > 0, torch.randn(3, 4, 5) > 0), - ( - "rand_4d_bool_bool", - torch.randn(3, 4, 5, 6) > 0, - torch.randn(3, 4, 5, 6) > 0, - ), - ("rand_2d_bool_single_bool", torch.randn(3, 4) > 0, torch.tensor(0) > 0), - ( - "rand_2d_int_bool", - torch.randn(3, 4).to(torch.int), - torch.randn(3, 4) > 0, - ), - ( - "rand_2d_int_single_bool", - torch.randn(3, 4).to(torch.int), - torch.tensor(0) > 0, - ), - ] - ) - def test_logical_xor(self, _, input, other): - class LogicalXor(torch.nn.Module): - def forward(self, x, y): - return x.logical_xor(y) - - inputs = [ - input, - other, - ] - self.run_test( - LogicalXor(), - inputs, - expected_ops={acc_ops.logical_xor}, - test_implicit_batch_dim=False, - ) - - -class TestLogicalXorMethodSimpleConverterWithDynamicShape(AccTestCase): - def test_logical_xor(self): - class LogicalXor(torch.nn.Module): - def forward(self, x, y): - return x.logical_xor(y) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], - ), - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], - ), - ] - - self.run_test_with_dynamic_shape( - LogicalXor(), input_specs, expected_ops={acc_ops.logical_xor} - ) - - -class TestLogicalXorFunctionSimpleConverter(AccTestCase): - @parameterized.expand( - [ - ("rand_2d_bool_bool", torch.randn(3, 4) > 0, torch.randn(3, 4) > 0), - ("rand_3d_bool_bool", torch.randn(3, 4, 5) > 0, torch.randn(3, 4, 5) > 0), - ( - "rand_4d_bool_bool", - torch.randn(3, 4, 5, 6) > 0, - torch.randn(3, 4, 5, 6) > 0, - ), - ("rand_2d_bool_single_bool", torch.randn(3, 4) > 0, torch.tensor(0) > 0), - ( - "rand_2d_int_bool", - torch.randn(3, 4).to(torch.int), - torch.randn(3, 4) > 0, - ), - ( - "rand_2d_int_single_bool", - torch.randn(3, 4).to(torch.int), - torch.tensor(0) > 0, - ), - ] - ) - def test_logical_xor(self, _, input, other): - class LogicalXor(torch.nn.Module): - def forward(self, x, y): - return torch.logical_xor(x, y) - - inputs = [ - input, - other, - ] - self.run_test( - LogicalXor(), - inputs, - expected_ops={acc_ops.logical_xor}, - test_implicit_batch_dim=False, - ) - - -class TestLogicalXorFunctionSimpleConverterWithDynamicShape(AccTestCase): - def test_logical_xor(self): - class LogicalXor(torch.nn.Module): - def forward(self, x, y): - return torch.logical_xor(x, y) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], - ), - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], - ), - ] - - self.run_test_with_dynamic_shape( - LogicalXor(), input_specs, expected_ops={acc_ops.logical_xor} - ) - - -class TestLogicalXorOperatorSimpleConverter(AccTestCase): - @parameterized.expand( - [ - ("rand_2d_bool_bool", torch.randn(3, 4) > 0, torch.randn(3, 4) > 0), - ("rand_3d_bool_bool", torch.randn(3, 4, 5) > 0, torch.randn(3, 4, 5) > 0), - ( - 
"rand_4d_bool_bool", - torch.randn(3, 4, 5, 6) > 0, - torch.randn(3, 4, 5, 6) > 0, - ), - ("rand_2d_bool_single_bool", torch.randn(3, 4) > 0, torch.tensor(0) > 0), - ( - "rand_2d_int_bool", - torch.randn(3, 4).to(torch.int), - torch.randn(3, 4) > 0, - ), - ( - "rand_2d_int_single_bool", - torch.randn(3, 4).to(torch.int), - torch.tensor(0) > 0, - ), - ] - ) - def test_logical_xor(self, _, input, other): - class LogicalXor(torch.nn.Module): - def forward(self, x, y): - return x ^ y - - inputs = [ - input, - other, - ] - self.run_test( - LogicalXor(), - inputs, - expected_ops={acc_ops.logical_xor}, - test_implicit_batch_dim=False, - ) - - -class TestLogicalXorOperatorSimpleConverterWithDynamicShape(AccTestCase): - def test_logical_xor(self): - class LogicalXor(torch.nn.Module): - def forward(self, x, y): - return x ^ y - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.bool, - shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))], - ), - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.bool, - shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))], - ), - ] - - self.run_test_with_dynamic_shape( - LogicalXor(), input_specs, expected_ops={acc_ops.logical_xor} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_lt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_lt.py deleted file mode 100644 index 6d037145ac..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_lt.py +++ /dev/null @@ -1,277 +0,0 @@ -import torch -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestLtConverter(AccTestCase): - @parameterized.expand( - [ - ("rand_2d", torch.randn(3, 4), torch.randn(3, 4)), - ("rand_3d", torch.randn(3, 4, 5), torch.randn(3, 4, 5)), - ("rand_4d", torch.randn(3, 4, 5, 6), torch.randn(3, 4, 5, 6)), - ( - "rand_2d_int_bool", - (torch.randn(3, 4)).to(torch.int), - torch.zeros(3, 4).to(torch.bool), - ), - ( - "rand_2d_float_bool", - torch.randn(3, 4).to(torch.float), - torch.zeros(3, 4).to(torch.bool), - ), - ( - "rand_2d_float_int", - torch.randn(3, 4).to(torch.float), - torch.zeros(3, 4).to(torch.int), - ), - ] - ) - def test_lt(self, _, input, other): - class Lt(torch.nn.Module): - def forward(self, x, y): - mask = torch.lt(x, y) - return x.masked_fill(mask, 5) - - inputs = [ - input, - other, - ] - self.run_test( - Lt(), inputs, expected_ops={acc_ops.lt}, test_implicit_batch_dim=False - ) - - -class TestLtMethodConverter(AccTestCase): - @parameterized.expand( - [ - ("rand_2d", torch.randn(3, 4), torch.randn(3, 4)), - ("rand_3d", torch.randn(3, 4, 5), torch.randn(3, 4, 5)), - ("rand_4d", torch.randn(3, 4, 5, 6), torch.randn(3, 4, 5, 6)), - ( - "rand_2d_int_bool", - torch.randn(3, 4).to(torch.int), - torch.zeros(3, 4).to(torch.bool), - ), - ( - "rand_2d_float_bool", - torch.randn(3, 4).to(torch.float), - torch.zeros(3, 4).to(torch.bool), - ), - ( - "rand_2d_float_int", - torch.randn(3, 4).to(torch.float), - torch.zeros(3, 4).to(torch.int), - ), - ] - ) - def test_lt(self, _, input, other): - class Lt(torch.nn.Module): - def forward(self, x, y): - mask = x.lt(y) - return x.masked_fill(mask, 5) - - inputs = [ - input, - other, - ] - self.run_test( - Lt(), inputs, expected_ops={acc_ops.lt}, 
test_implicit_batch_dim=False - ) - - -class TestLtOperatorConverter(AccTestCase): - @parameterized.expand( - [ - ("rand_2d", torch.randn(3, 4), torch.randn(3, 4)), - ("rand_3d", torch.randn(3, 4, 5), torch.randn(3, 4, 5)), - ("rand_4d", torch.randn(3, 4, 5, 6), torch.randn(3, 4, 5, 6)), - ( - "rand_2d_int_bool", - torch.randn(3, 4).to(torch.int), - torch.zeros(3, 4).to(torch.bool), - ), - ( - "rand_2d_float_bool", - torch.randn(3, 4).to(torch.float), - torch.zeros(3, 4).to(torch.bool), - ), - ( - "rand_2d_float_int", - torch.randn(3, 4).to(torch.float), - torch.zeros(3, 4).to(torch.int), - ), - ] - ) - def test_lt(self, _, input, other): - class Lt(torch.nn.Module): - def forward(self, x, y): - mask = x < y - return x.masked_fill(mask, 5) - - inputs = [ - input, - other, - ] - self.run_test( - Lt(), inputs, expected_ops={acc_ops.lt}, test_implicit_batch_dim=False - ) - - -class TestEqOperatorSimpleConverter(AccTestCase): - @parameterized.expand( - [ - ("rand_2d_float_bool", torch.randn(3, 4), torch.randn(3, 4).to(torch.bool)), - ( - "rand_2d_int_bool", - torch.randn(3, 4).to(torch.int), - torch.randn(3, 4).to(torch.bool), - ), - ( - "rand_2d_bool_bool", - torch.randn(3, 4).to(torch.bool), - torch.randn(3, 4).to(torch.bool), - ), - ( - "rand_2d_float_int", - torch.randn(3, 4).to(torch.float), - torch.randn(3, 4).to(torch.int), - ), - ( - "rand_2d_float_single_bool", - torch.randn(3, 4), - torch.tensor(0).to(torch.bool), - ), - ( - "rand_2d_int_single_bool", - torch.randn(3, 4).to(torch.int), - torch.tensor(0).to(torch.bool), - ), - ( - "rand_2d_bool_single_bool", - torch.randn(3, 4).to(torch.bool), - torch.tensor(0).to(torch.bool), - ), - ] - ) - def test_eq(self, _, input, other): - class Eq(torch.nn.Module): - def forward(self, x, y): - return x < y - - inputs = [ - input, - other, - ] - self.run_test( - Eq(), inputs, expected_ops={acc_ops.lt}, test_implicit_batch_dim=False - ) - - -class TestEqOperatorSimpleConverterWithDynamicShape(AccTestCase): - def test_eq(self): - class Eq(torch.nn.Module): - def forward(self, x, y): - return x < y - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], - ), - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], - ), - ] - - self.run_test_with_dynamic_shape(Eq(), input_specs, expected_ops={acc_ops.lt}) - - -class TestEqOperatorConstantConverter(AccTestCase): - @parameterized.expand( - [ - ("rand_2d_float_bool", torch.randn(3, 4), torch.randn(3, 4).to(torch.bool)), - ( - "rand_2d_int_bool", - torch.randn(3, 4).to(torch.int), - torch.randn(3, 4).to(torch.bool), - ), - ( - "rand_2d_bool_bool", - torch.randn(3, 4).to(torch.bool), - torch.randn(3, 4).to(torch.bool), - ), - ( - "rand_2d_float_int", - torch.randn(3, 4).to(torch.float), - torch.randn(3, 4).to(torch.int), - ), - ("rand_2d_float_single_bool", torch.randn(3, 4), False), - ("rand_2d_int_single_bool", torch.randn(3, 4).to(torch.int), False), - ("rand_2d_bool_single_bool", torch.randn(3, 4).to(torch.bool), False), - ] - ) - def test_eq(self, _, input, other): - class Eq(torch.nn.Module): - def __init__(self): - super().__init__() - self.other = other - - def forward(self, x): - return x < self.other - - inputs = [ - input, - ] - self.run_test( - Eq(), inputs, expected_ops={acc_ops.lt}, test_implicit_batch_dim=False - ) - - -class TestConstInputConverter(AccTestCase): - def test_lt(self): - class Lt(torch.nn.Module): - def 
__init__(self): - super().__init__() - - def forward(self, x): - return x.shape[0] < 4 - - input = torch.randn(3, 4) - inputs = [ - input, - ] - self.run_test( - Lt(), inputs, expected_ops={acc_ops.lt}, test_implicit_batch_dim=False - ) - - -class TestConstInputConverterWithDynamicShape(AccTestCase): - def test_lt(self): - class Lt(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return x.shape[0] < 4 - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], - ) - ] - - self.run_test_with_dynamic_shape(Lt(), input_specs, expected_ops={acc_ops.lt}) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_masked_fill.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_masked_fill.py deleted file mode 100644 index 3c56d50750..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_masked_fill.py +++ /dev/null @@ -1,72 +0,0 @@ -import torch -import torch.nn as nn -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase - - -class TestMaskedFill(AccTestCase): - @parameterized.expand( - [ - ("same_dims", (2, 3), 5), - ("same_dims_tensor", (2, 3), torch.tensor(5)), - ("not_same_dims", (2, 1), 5), - ("not_same_dims_tensor", (2, 1), torch.tensor(5)), - ] - ) - def test_masked_fill(self, _, input_shape, value): - class MaskedFill(nn.Module): - def __init__(self, input_shape): - super().__init__() - self.mask = torch.zeros(input_shape) - self.mask[0, 0] = 1 - self.mask = self.mask.to(torch.bool) - self.value = value - - def forward(self, x): - return x.masked_fill(self.mask, self.value) - - inputs = [torch.ones(*input_shape)] - self.run_test( - MaskedFill(input_shape), - inputs, - expected_ops={acc_ops.masked_fill}, - test_implicit_batch_dim=False, - ) - - # Testing with (-1, -1, -1, -1) results into following error: - # RuntimeError: Trying to create tensor with negative dimension -1: [-1, -1, -1, -1] - - @parameterized.expand( - [ - ("same_dims", (2, 3), (2, 3), 5), - ("expand_first_dims", (2, 3), (1, 3), 5), - ("expand_second_dims", (2, 3), (2, 1), 5), - ("expand_third_dims", (2, 3, 4), (2, 3, 1), 5), - ] - ) - def test_masked_fill_expand(self, _, input_shape, mask_shape, value): - class MaskedFill(nn.Module): - def __init__(self, input_shape): - super().__init__() - self.value = value - - def forward(self, x, mask_input): - return x.masked_fill(mask_input, self.value) - - mask_input = torch.zeros(*mask_shape) - index = (0) * len(mask_shape) - mask_input[index] = 1 - mask_input = mask_input.to(torch.bool) - inputs = [torch.ones(*input_shape), mask_input] - self.run_test( - MaskedFill(input_shape), - inputs, - expected_ops={acc_ops.masked_fill}, - test_implicit_batch_dim=False, - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_matmul.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_matmul.py deleted file mode 100644 index 2f979f1243..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_matmul.py +++ /dev/null @@ -1,120 +0,0 @@ -import torch -import torch.nn as nn -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized 
import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestMatMulConverter(AccTestCase): - @parameterized.expand( - [ - ("2_2", (2, 3), (3, 2)), - ("2_1", (2, 3), (3,)), - ("4_2", (1, 2, 2, 3), (3, 2)), - ("1_2", (3,), (3, 2)), - ] - ) - def test_matmul_other_constant(self, _, input_shape, other_shape): - class MatMul(nn.Module): - def __init__(self): - super().__init__() - self.other = nn.Parameter(torch.randn(*other_shape)) - - def forward(self, input): - return torch.matmul(input, self.other) - - inputs = [torch.randn(*input_shape)] - self.run_test( - MatMul(), - inputs, - expected_ops={acc_ops.matmul}, - test_implicit_batch_dim=(len(input_shape) > 1), - ) - - @parameterized.expand( - [ - ("2_2", (2, 3), (3, 2)), - ("1_2", (3,), (3, 2)), - ("3_4", (2, 2, 3), (3, 1, 3, 3)), - ] - ) - def test_matmul_input_constant(self, _, input_shape, other_shape): - class MatMul(nn.Module): - def __init__(self): - super().__init__() - self.input = nn.Parameter(torch.randn(*input_shape)) - - def forward(self, other): - return torch.matmul(self.input, other) - - inputs = [torch.randn(*other_shape)] - self.run_test( - MatMul(), - inputs, - expected_ops={acc_ops.matmul}, - test_implicit_batch_dim=(len(other_shape) > 2), - ) - - @parameterized.expand( - [ - ("4_4", (2, 2, 2, 3), (2, 1, 3, 2)), - ("4_2", (2, 1, 2, 3), (3, 2)), - ("2_3", (2, 3), (2, 3, 4)), - ("2_2", (2, 3), (3, 2)), - ("2_1", (2, 3), (3,)), - ("1_2", (3,), (3, 2)), - ("1_1", (3,), (3,)), - ] - ) - def test_matmul(self, _, input_shape, other_shape): - class MatMul(nn.Module): - def forward(self, input, other): - return torch.matmul(input, other) - - inputs = [torch.randn(*input_shape), torch.randn(*other_shape)] - test_implicit_batch_dim = ( - input_shape[0] == other_shape[0] - and len(input_shape) > 2 - and len(other_shape) > 2 - ) - self.run_test( - MatMul(), - inputs, - expected_ops={acc_ops.matmul}, - test_implicit_batch_dim=test_implicit_batch_dim, - ) - - def test_matmal_dynamic_shape( - self, - ): - class Matmul(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, input, other): - return torch.matmul(input, other) - - input_specs = [ - InputTensorSpec( - shape=(-1, 1, 2, 3), - dtype=torch.float32, - shape_ranges=[((1, 1, 2, 3), (9, 1, 2, 3), (9, 1, 2, 3))], - ), - InputTensorSpec( - shape=(-1, -1, 3, 3), - dtype=torch.float32, - shape_ranges=[((1, 1, 3, 3), (9, 4, 3, 3), (9, 4, 3, 3))], - ), - ] - - self.run_test_with_dynamic_shape( - Matmul(), input_specs, expected_ops={acc_ops.matmul} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_max.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_max.py deleted file mode 100644 index be6b4cdedc..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_max.py +++ /dev/null @@ -1,163 +0,0 @@ -import torch -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestMaxConverter(AccTestCase): - @parameterized.expand( - [ - ("dim0_keepdim", 0, True, torch.randn(2, 2, 3)), - ("dim1_keepdim", 1, True, torch.randn(2, 2, 3)), - ("dim2_keepdim", 2, True, torch.randn(2, 2, 3)), - 
("dim3_keepdim", 3, True, torch.randn(2, 2, 3, 3)), - ("dim2_no_keepdim", 2, False, torch.randn(2, 2, 3)), - ("dim1_no_keepdim", 1, False, torch.randn(2, 2, 3)), - ("dim0_no_keepdim", 0, False, torch.randn(2, 2, 3)), - ] - ) - def test_max_dim_reduce(self, test_name, dim, keepdim, input): - class MaxDimReduce(torch.nn.Module): - def __init__(self, dim, keepdim): - super().__init__() - self.dim = dim - self.keepdim = keepdim - - def forward(self, x): - return torch.max(x, self.dim, self.keepdim) - - inputs = [input] - self.run_test( - MaxDimReduce(dim, keepdim), - inputs, - expected_ops={acc_ops.max_dim_reduce}, - test_implicit_batch_dim=(dim != 0), - ) - - @parameterized.expand( - [ - ("no_dim_no_keepdim"), - ] - ) - def test_max_full_reduce( - self, - test_name, - ): - class MaxFullReduce(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return torch.max(x) - - inputs = [torch.randn(3, 2, 3, 3)] - self.run_test( - MaxFullReduce(), - inputs, - expected_ops={acc_ops.max_full_reduce}, - # We can't do a full reduce over the batch dimension - test_implicit_batch_dim=False, - ) - - @parameterized.expand( - [ - ("max_method_no_dim_no_keepdim"), - ("max_method_no_dim_no_keepdim"), - ] - ) - def test_max_method(self, test_name): - class MaxMethod(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, input, other): - return input.max(other) - - inputs = [torch.randn(3, 4), torch.randn(3, 4)] - self.run_test(MaxMethod(), inputs, expected_ops={acc_ops.maximum}) - - -class TestMaxConverterWithDynamicShape(AccTestCase): - @parameterized.expand( - [ - # keepdim can not be False for dynamic shape - ("dim0_keepdim", 0, True), - ("dim1_keepdim", 1, True), - ("dim2_keepdim", 2, True), - ("dim3_keepdim", 3, True), - ] - ) - def test_max_dim_reduce(self, _, dim, keepdim): - class MaxDimReduce(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return torch.max(x, dim, keepdim=keepdim) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))], - ), - ] - - self.run_test_with_dynamic_shape( - MaxDimReduce(), input_specs, expected_ops={acc_ops.max_dim_reduce} - ) - - def test_max_full_reduce( - self, - ): - class MaxFullReduce(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return torch.max(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))], - ), - ] - - self.run_test_with_dynamic_shape( - MaxFullReduce(), input_specs, expected_ops={acc_ops.max_full_reduce} - ) - - def test_max_method(self): - class MaxMethod(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, input, other): - return input.max(other) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))], - ), - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))], - ), - ] - - self.run_test_with_dynamic_shape( - MaxMethod(), input_specs, expected_ops={acc_ops.maximum} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_maximum.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_maximum.py deleted file mode 100644 index 8c1522d3ad..0000000000 
--- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_maximum.py +++ /dev/null @@ -1,85 +0,0 @@ -import torch -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestMaximumConverter(AccTestCase): - def test_maximum(self): - class Maximum(torch.nn.Module): - def forward(self, x, y): - return torch.maximum(x, y) - - inputs = [ - torch.randn(3, 4), - torch.randn(3, 4), - ] - self.run_test(Maximum(), inputs, expected_ops={acc_ops.maximum}) - - -class TestMaximumConverterWithDynamicShape(AccTestCase): - def test_maximum(self): - class Maximum(torch.nn.Module): - def forward(self, x, y): - return torch.maximum(x, y) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], - ), - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], - ), - ] - - self.run_test_with_dynamic_shape( - Maximum(), input_specs, expected_ops={acc_ops.maximum} - ) - - -class TestMaximumMethodConverter(AccTestCase): - def test_maximum(self): - class Maximum(torch.nn.Module): - def forward(self, x, y): - return x.maximum(y) - - inputs = [ - torch.randn(3, 4), - torch.randn(3, 4), - ] - self.run_test(Maximum(), inputs, expected_ops={acc_ops.maximum}) - - -class TestMaximumMethodConverterWithDynamicShape(AccTestCase): - def test_maximum(self): - class Maximum(torch.nn.Module): - def forward(self, x, y): - return x.maximum(y) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], - ), - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], - ), - ] - - self.run_test_with_dynamic_shape( - Maximum(), input_specs, expected_ops={acc_ops.maximum} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_maxpool.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_maxpool.py deleted file mode 100644 index 7ed6301467..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_maxpool.py +++ /dev/null @@ -1,382 +0,0 @@ -import torch -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import param, parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestMaxPoolConverter(AccTestCase): - @parameterized.expand( - [ - ("default", 1), - ("kernel_3", 3), - ("stride", 1, 2), - param("padding", 2, padding=1), - param("padding_even", 5, padding=2), - param("ceil_mode", 1, ceil_mode=True), - ] - ) - def test_max_pool1d( - self, - test_name, - kernel_size, - stride=1, - padding=0, - dilation=1, - ceil_mode=False, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.max_pool = torch.nn.MaxPool1d( - kernel_size, stride, padding, ceil_mode=ceil_mode, dilation=dilation - ) - - def forward(self, x): - return self.max_pool(x) - - inputs = [torch.randn(1, 3, 224)] - self.run_test( - TestModule(), - inputs, - expected_ops={acc_ops.max_pool1d}, - ) - - def test_max_pool1d_with_dynamic_shape( - self, - 
): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.max_pool = torch.nn.MaxPool1d(1) - - def forward(self, x): - return self.max_pool(x) - - # shape is not set to (-1, -1, -1) as reshape dimension with - # more than one -1 wildcard is not allowed while adding unsqueeze layer - input_specs = [ - InputTensorSpec( - shape=(1, 1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1), (1, 1, 4), (1, 1, 4))], - ), - ] - self.run_test_with_dynamic_shape( - TestModule(), - input_specs, - expected_ops={acc_ops.max_pool1d}, - ) - - @parameterized.expand( - [ - ("default", 1), - ("stride", 1, 2), - ("tuple_parameters", 2, (1, 1), (1, 1)), - param("padding", 2, padding=1), - param("ceil_mode", 1, ceil_mode=True), - ] - ) - def test_max_pool2d( - self, - test_name, - kernel_size, - stride=1, - padding=0, - ceil_mode=False, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.max_pool = torch.nn.MaxPool2d( - kernel_size, stride, padding, ceil_mode=ceil_mode - ) - - def forward(self, x): - return self.max_pool(x) - - inputs = [torch.randn(1, 3, 224, 224)] - self.run_test(TestModule(), inputs, expected_ops={acc_ops.max_pool2d}) - - def test_max_pool2d_with_dynamic_shape( - self, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.max_pool = torch.nn.MaxPool2d(1, 1) - - def forward(self, x): - return self.max_pool(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (1, 2, 4, 4), (2, 4, 4, 4))], - ), - ] - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.max_pool2d} - ) - - @parameterized.expand( - [ - ("default", 1), - ("stride", 1, 2), - ("tuple_parameters", 2, (1, 1, 1), (1, 1, 1)), - param("padding", 2, padding=1), - param("ceil_mode", 1, ceil_mode=True), - ] - ) - def test_max_pool3d( - self, - test_name, - kernel_size, - stride=1, - padding=0, - ceil_mode=False, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.max_pool = torch.nn.MaxPool3d( - kernel_size, stride, padding, ceil_mode=ceil_mode - ) - - def forward(self, x): - return self.max_pool(x) - - inputs = [torch.randn(1, 3, 32, 32, 32)] - self.run_test(TestModule(), inputs, expected_ops={acc_ops.max_pool3d}) - - def test_max_pool3d_with_dynamic_shape(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.max_pool = torch.nn.MaxPool3d(1, 1) - - def forward(self, x): - return self.max_pool(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1, 1), (1, 2, 4, 4, 4), (2, 4, 4, 4, 4))], - ), - ] - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.max_pool3d} - ) - - @parameterized.expand( - [ - ("default", 1), - param("stride", 2, stride=()), - ] - ) - def test_stride_none_max_pool1d( - self, - test_name, - kernel_size, - stride=None, - padding=0, - dilation=1, - ceil_mode=False, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return torch.nn.functional.max_pool1d( - x, - kernel_size, - stride=stride, - padding=padding, - ceil_mode=ceil_mode, - dilation=dilation, - ) - - inputs = [torch.randn(1, 3, 224)] - self.run_test( - TestModule(), - inputs, - expected_ops={acc_ops.max_pool1d}, - test_explicit_batch_dim=False, - ) - - @parameterized.expand( - [ - ("default", 1), - param("stride", 
2, stride=()), - ] - ) - def test_stride_none_max_pool2d( - self, - test_name, - kernel_size, - stride=None, - padding=0, - ceil_mode=False, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return torch.nn.functional.max_pool2d( - x, kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode - ) - - inputs = [torch.randn(1, 3, 224, 224)] - self.run_test(TestModule(), inputs, expected_ops={acc_ops.max_pool2d}) - - @parameterized.expand( - [ - ("default", 1), - param("stride", 2, stride=()), - ] - ) - def test_stride_none_max_pool3d( - self, - test_name, - kernel_size, - stride=None, - padding=0, - ceil_mode=False, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return torch.nn.functional.max_pool3d( - x, kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode - ) - - inputs = [torch.randn(1, 3, 32, 32, 32)] - self.run_test(TestModule(), inputs, expected_ops={acc_ops.max_pool3d}) - - @parameterized.expand( - [ - ("default", 1), - param("stride", 2, stride=()), - ] - ) - def test_stride_none_max_pool1d_with_dynamic_shape( - self, - test_name, - kernel_size, - stride=None, - padding=0, - dilation=1, - ceil_mode=False, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return torch.nn.functional.max_pool1d( - x, - kernel_size, - stride=stride, - padding=padding, - ceil_mode=ceil_mode, - dilation=dilation, - ) - - # shape is not set to (-1, -1, -1) as reshape dimension with - # more than one -1 wildcard is not allowed while adding unsqueeze layer - input_specs = [ - InputTensorSpec( - shape=(1, 1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1), (1, 1, 4), (1, 1, 4))], - ), - ] - - self.run_test_with_dynamic_shape( - TestModule(), - input_specs, - expected_ops={acc_ops.max_pool1d}, - ) - - @parameterized.expand( - [ - ("default", 1), - param("stride", 2, stride=()), - ] - ) - def test_stride_none_max_pool2d_with_dynamic_shape( - self, - test_name, - kernel_size, - stride=None, - padding=0, - ceil_mode=False, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return torch.nn.functional.max_pool2d( - x, kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode - ) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (1, 2, 4, 4), (2, 4, 4, 4))], - ), - ] - - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.max_pool2d} - ) - - @parameterized.expand( - [ - ("default", 1), - param("stride", 2, stride=()), - ] - ) - def test_stride_none_max_pool3d_with_dynamic_shape( - self, - test_name, - kernel_size, - stride=None, - padding=0, - ceil_mode=False, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return torch.nn.functional.max_pool3d( - x, kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode - ) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1, 1), (1, 2, 4, 4, 4), (2, 4, 4, 4, 4))], - ), - ] - - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.max_pool3d} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_min.py 
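
[Review note] The stride_none parameterizations above pass an empty stride through to the functional pooling calls; PyTorch treats a missing stride as stride = kernel_size, and that default is what the converter has to reproduce. A quick sketch of the equivalence (torch only; the tests use stride=() where this uses the documented None):

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)

# With stride unset, max_pool2d defaults to stride == kernel_size.
a = F.max_pool2d(x, kernel_size=2, stride=None)
b = F.max_pool2d(x, kernel_size=2, stride=2)

assert torch.equal(a, b) and a.shape == (1, 3, 4, 4)
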
b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_min.py deleted file mode 100644 index 6d09db1d5c..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_min.py +++ /dev/null @@ -1,162 +0,0 @@ -import torch -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestMinConverter(AccTestCase): - @parameterized.expand( - [ - ("dim0_keepdim", 0, True, torch.randn(2, 2, 3)), - ("dim1_keepdim", 1, True, torch.randn(2, 2, 3)), - ("dim2_keepdim", 2, True, torch.randn(2, 2, 3)), - ("dim3_keepdim", 3, True, torch.randn(2, 2, 3, 3)), - ("dim2_no_keepdim", 2, False, torch.randn(2, 2, 3)), - ("dim1_no_keepdim", 1, False, torch.randn(2, 2, 3)), - ("dim0_no_keepdim", 0, False, torch.randn(2, 2, 3)), - ] - ) - def test_min_dim_reduce(self, test_name, dim, keepdim, input): - class MinDimReduce(torch.nn.Module): - def __init__(self, dim, keepdim): - super().__init__() - self.dim = dim - self.keepdim = keepdim - - def forward(self, x): - return torch.min(x, self.dim, self.keepdim) - - inputs = [input] - self.run_test( - MinDimReduce(dim, keepdim), - inputs, - expected_ops={acc_ops.min_dim_reduce}, - test_implicit_batch_dim=(dim != 0), - ) - - @parameterized.expand( - [ - ("no_dim_no_keepdim"), - ] - ) - def test_min_full_reduce( - self, - test_name, - ): - class MinFullReduce(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return torch.min(x) - - inputs = [torch.randn(3, 2, 3, 3)] - self.run_test( - MinFullReduce(), - inputs, - expected_ops={acc_ops.min_full_reduce}, - # We can't do a full reduce over the batch dimension - test_implicit_batch_dim=False, - ) - - @parameterized.expand( - [ - ("min_method_no_dim_no_keepdim"), - ("min_method_no_dim_no_keepdim"), - ] - ) - def test_min_method(self, test_name): - class MinMethod(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, input, other): - return input.min(other) - - inputs = [torch.randn(3, 4), torch.randn(3, 4)] - self.run_test(MinMethod(), inputs, expected_ops={acc_ops.minimum}) - - -class TestMinConverterWithDynamicShape(AccTestCase): - @parameterized.expand( - [ - ("dim0_keepdim", 0, True), - ("dim1_keepdim", 1, True), - ("dim2_keepdim", 2, True), - ("dim3_keepdim", 3, True), - ] - ) - def test_min_dim_reduce(self, test_name, dim, keepdim): - class MinDimReduce(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return torch.min(x, dim, keepdim) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))], - ), - ] - - self.run_test_with_dynamic_shape( - MinDimReduce(), input_specs, expected_ops={acc_ops.min_dim_reduce} - ) - - def test_min_full_reduce( - self, - ): - class MinFullReduce(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return torch.min(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))], - ), - ] - - self.run_test_with_dynamic_shape( - MinFullReduce(), input_specs, expected_ops={acc_ops.min_full_reduce} - ) - - def test_min_method(self): - class MinMethod(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, 
input, other): - return input.min(other) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))], - ), - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))], - ), - ] - - self.run_test_with_dynamic_shape( - MinMethod(), input_specs, expected_ops={acc_ops.minimum} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_minimum.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_minimum.py deleted file mode 100644 index 7778f784a2..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_minimum.py +++ /dev/null @@ -1,85 +0,0 @@ -import torch -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestMinimumConverter(AccTestCase): - def test_minimum(self): - class Minimum(torch.nn.Module): - def forward(self, x, y): - return torch.minimum(x, y) - - inputs = [ - torch.randn(3, 4), - torch.randn(3, 4), - ] - self.run_test(Minimum(), inputs, expected_ops={acc_ops.minimum}) - - -class TestMinimumMethodConverter(AccTestCase): - def test_minimum(self): - class Minimum(torch.nn.Module): - def forward(self, x, y): - return x.minimum(y) - - inputs = [ - torch.randn(3, 4), - torch.randn(3, 4), - ] - self.run_test(Minimum(), inputs, expected_ops={acc_ops.minimum}) - - -class TestMinimumConverterWithDynamicShape(AccTestCase): - def test_minimum(self): - class Minimum(torch.nn.Module): - def forward(self, x, y): - return torch.minimum(x, y) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], - ), - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], - ), - ] - - self.run_test_with_dynamic_shape( - Minimum(), input_specs, expected_ops={acc_ops.minimum} - ) - - -class TestMinimumMethodConverterWithDynamicShape(AccTestCase): - def test_minimum(self): - class Minimum(torch.nn.Module): - def forward(self, x, y): - return x.minimum(y) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], - ), - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], - ), - ] - - self.run_test_with_dynamic_shape( - Minimum(), input_specs, expected_ops={acc_ops.minimum} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_narrow.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_narrow.py deleted file mode 100644 index 13d0e257ac..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_narrow.py +++ /dev/null @@ -1,58 +0,0 @@ -import torch -import torch.nn as nn -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestNarrowConverterWithDynamicShape(AccTestCase): - 
@parameterized.expand( - [ - ("positive_dim", 1, 0, 1), - ] - ) - def test_narrow(self, _, dim, start, length): - class Narrow(nn.Module): - def forward(self, x): - return x.narrow(dim, start, length) - - input_specs = [ - InputTensorSpec( - shape=(-1, 3, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 3, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))], - ), - ] - - self.run_test_with_dynamic_shape( - Narrow(), input_specs, expected_ops={acc_ops.slice_tensor} - ) - - -class TestNarrowConverter(AccTestCase): - @parameterized.expand( - [ - ("positive_dim", 1, 0, 1), - ("negative_dim", -1, 1, 2), - ] - ) - def test_narrow(self, _, dim, start, length): - class Narrow(nn.Module): - def forward(self, x): - return x.narrow(dim, start, length) - - inputs = [torch.randn(1, 2, 3, 4)] - self.run_test( - Narrow(), - inputs, - expected_ops={acc_ops.slice_tensor}, - test_explicit_batch_dim=False, - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_ne.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_ne.py deleted file mode 100644 index 2fd99787b4..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_ne.py +++ /dev/null @@ -1,307 +0,0 @@ -import torch -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestNeFunctionConverter(AccTestCase): - @parameterized.expand( - [ - ("rand_2d_float_bool", torch.randn(3, 4), torch.randn(3, 4).to(torch.bool)), - ( - "rand_2d_int_bool", - torch.randn(3, 4).to(torch.int), - torch.randn(3, 4).to(torch.bool), - ), - ( - "rand_2d_bool_bool", - torch.randn(3, 4).to(torch.bool), - torch.randn(3, 4).to(torch.bool), - ), - ( - "rand_2d_float_int", - torch.randn(3, 4).to(torch.float), - torch.randn(3, 4).to(torch.int), - ), - ( - "rand_2d_float_single_bool", - torch.randn(3, 4), - torch.tensor(0).to(torch.bool), - ), - ( - "rand_2d_int_single_bool", - torch.randn(3, 4).to(torch.int), - torch.tensor(0).to(torch.bool), - ), - ( - "rand_2d_bool_single_bool", - torch.randn(3, 4).to(torch.bool), - torch.tensor(0).to(torch.bool), - ), - ] - ) - def test_ne(self, _, input, other): - class Ne(torch.nn.Module): - def forward(self, x, y): - return torch.ne(x, y) - - inputs = [ - input, - other, - ] - self.run_test( - Ne(), inputs, expected_ops={acc_ops.ne}, test_implicit_batch_dim=False - ) - - -class TestNeFunctionConverterWithDynamicShape(AccTestCase): - def test_ne(self): - class Ne(torch.nn.Module): - def forward(self, x, y): - return torch.ne(x, y) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], - ), - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], - ), - ] - - self.run_test_with_dynamic_shape(Ne(), input_specs, expected_ops={acc_ops.ne}) - - -class TestNeMethodConverter(AccTestCase): - @parameterized.expand( - [ - ("rand_2d_float_bool", torch.randn(3, 4), torch.randn(3, 4).to(torch.bool)), - ( - "rand_2d_int_bool", - torch.randn(3, 4).to(torch.int), - torch.randn(3, 4).to(torch.bool), - ), - ( - "rand_2d_bool_bool", - torch.randn(3, 4).to(torch.bool), - torch.randn(3, 4).to(torch.bool), - ), - ( - "rand_2d_float_int", - torch.randn(3, 4).to(torch.float), - 
torch.randn(3, 4).to(torch.int), - ), - ( - "rand_2d_float_single_bool", - torch.randn(3, 4), - torch.tensor(0).to(torch.bool), - ), - ( - "rand_2d_int_single_bool", - torch.randn(3, 4).to(torch.int), - torch.tensor(0).to(torch.bool), - ), - ( - "rand_2d_bool_single_bool", - torch.randn(3, 4).to(torch.bool), - torch.tensor(0).to(torch.bool), - ), - ] - ) - def test_ne(self, _, input, other): - class Ne(torch.nn.Module): - def forward(self, x, y): - return x.ne(y) - - inputs = [ - input, - other, - ] - self.run_test( - Ne(), inputs, expected_ops={acc_ops.ne}, test_implicit_batch_dim=False - ) - - -class TestNeMethodConverterWithDynamicShape(AccTestCase): - def test_ne(self): - class Ne(torch.nn.Module): - def forward(self, x, y): - return x.ne(y) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], - ), - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], - ), - ] - - self.run_test_with_dynamic_shape(Ne(), input_specs, expected_ops={acc_ops.ne}) - - -class TestNeOperatorConverter(AccTestCase): - @parameterized.expand( - [ - ("rand_2d_float_bool", torch.randn(3, 4), torch.randn(3, 4).to(torch.bool)), - ( - "rand_2d_int_bool", - torch.randn(3, 4).to(torch.int), - torch.randn(3, 4).to(torch.bool), - ), - ( - "rand_2d_bool_bool", - torch.randn(3, 4).to(torch.bool), - torch.randn(3, 4).to(torch.bool), - ), - ( - "rand_2d_float_int", - torch.randn(3, 4).to(torch.float), - torch.randn(3, 4).to(torch.int), - ), - ( - "rand_2d_float_single_bool", - torch.randn(3, 4), - torch.tensor(0).to(torch.bool), - ), - ( - "rand_2d_int_single_bool", - torch.randn(3, 4).to(torch.int), - torch.tensor(0).to(torch.bool), - ), - ( - "rand_2d_bool_single_bool", - torch.randn(3, 4).to(torch.bool), - torch.tensor(0).to(torch.bool), - ), - ] - ) - def test_ne(self, _, input, other): - class Ne(torch.nn.Module): - def forward(self, x, y): - return x != y - - inputs = [ - input, - other, - ] - self.run_test( - Ne(), inputs, expected_ops={acc_ops.ne}, test_implicit_batch_dim=False - ) - - -class TestNeOperatorConverterWithDynamicShape(AccTestCase): - def test_ne(self): - class Ne(torch.nn.Module): - def forward(self, x, y): - return x != y - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], - ), - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], - ), - ] - - self.run_test_with_dynamic_shape(Ne(), input_specs, expected_ops={acc_ops.ne}) - - -class TestNeOperatorConstantConverter(AccTestCase): - @parameterized.expand( - [ - ("rand_2d_float_bool", torch.randn(3, 4), torch.randn(3, 4).to(torch.bool)), - ( - "rand_2d_int_bool", - torch.randn(3, 4).to(torch.int), - torch.randn(3, 4).to(torch.bool), - ), - ( - "rand_2d_bool_bool", - torch.randn(3, 4).to(torch.bool), - torch.randn(3, 4).to(torch.bool), - ), - ( - "rand_2d_float_int", - torch.randn(3, 4).to(torch.float), - torch.randn(3, 4).to(torch.int), - ), - ("rand_2d_float_single_bool", torch.randn(3, 4), False), - ("rand_2d_int_single_bool", torch.randn(3, 4).to(torch.int), False), - ("rand_2d_bool_single_bool", torch.randn(3, 4).to(torch.bool), False), - ] - ) - def test_ne(self, _, input, other): - class Ne(torch.nn.Module): - def __init__(self): - super().__init__() - self.other = other - - def forward(self, x): - return x 
!= self.other - - inputs = [ - input, - ] - self.run_test( - Ne(), inputs, expected_ops={acc_ops.ne}, test_implicit_batch_dim=False - ) - - -class TestConstInputConverter(AccTestCase): - def test_ne(self): - class Ne(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return x.shape[0] != 4 - - input = torch.randn(3, 4) - inputs = [ - input, - ] - self.run_test( - Ne(), inputs, expected_ops={acc_ops.ne}, test_implicit_batch_dim=False - ) - - -class TestConstInputConverterWithDynamicShape(AccTestCase): - def test_ne(self): - class Ne(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return x.shape[0] != 4 - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))], - ), - ] - - self.run_test_with_dynamic_shape(Ne(), input_specs, expected_ops={acc_ops.ne}) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_new_ones.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_new_ones.py deleted file mode 100644 index 79754b38d4..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_new_ones.py +++ /dev/null @@ -1,76 +0,0 @@ -import torch -import torch.nn as nn -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestNewOnesConverter(AccTestCase): - def test_newone_no_dtype(self): - class TestModule(nn.Module): - def forward(self, x): - return x.new_ones((3, 5)) - - inputs = [torch.randn(1, 10)] - self.run_test( - TestModule(), - inputs, - expected_ops={acc_ops.new_ones}, - test_implicit_batch_dim=False, - ) - - def test_newone_device(self): - class TestModule(nn.Module): - def forward(self, x): - return x.new_ones((3, 5), device="cuda") - - inputs = [torch.randn(1, 10)] - self.run_test( - TestModule(), - inputs, - expected_ops={acc_ops.new_ones}, - test_implicit_batch_dim=False, - ) - - -class TestNewOnesConverterWithDynamicShape(AccTestCase): - def test_newone_no_dtype(self): - class TestModule(nn.Module): - def forward(self, x): - return x.new_ones((3, 5)) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], - ), - ] - - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.new_ones} - ) - - def test_newone_device(self): - class TestModule(nn.Module): - def forward(self, x): - return x.new_ones((3, 5), device="cuda") - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], - ), - ] - - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.new_ones} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_numel.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_numel.py deleted file mode 100644 index a79e29600d..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_numel.py +++ /dev/null @@ -1,41 +0,0 @@ -import torch -import torch.nn as nn -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_utils import run_tests -from 
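
[Review note] For the new_ones cases above, the key semantic is that the new tensor inherits dtype and device from the source tensor unless they are explicitly overridden, independent of the source's shape. A small eager-mode sketch, torch only:

import torch

x = torch.randn(1, 10, dtype=torch.float32)
y = x.new_ones((3, 5))  # shape is explicit; dtype/device come from x

assert y.shape == (3, 5)
assert y.dtype == x.dtype and y.device == x.device
assert torch.all(y == 1)
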
torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase - - -class TestNumelConverter(AccTestCase): - def test_numel(self): - class Numel(nn.Module): - def forward(self, x): - return torch.numel(x) * x - - inputs = [torch.ones(1, 2, 3, 4)] - self.run_test(Numel(), inputs, expected_ops={acc_ops.numel}) - - -# Testing with (-1, -1, -1 , -1) results in following error: -# RuntimeError: numel does not support dynamic shapes. -""" -class TestNumelConverterWithDynamicShape(AccTestCase): - def test_numel(self): - class Numel(nn.Module): - def forward(self, x): - return torch.numel(x) * x - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], - ), - ] - - self.run_test_with_dynamic_shape( - Numel(), input_specs, expected_ops={acc_ops.numel} - ) -""" - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_pad.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_pad.py deleted file mode 100644 index 7625f4edee..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_pad.py +++ /dev/null @@ -1,102 +0,0 @@ -import unittest - -import tensorrt as trt -import torch -import torch.nn as nn - -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase - -# from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec - - -class TestPadConverter(AccTestCase): - @parameterized.expand( - [ - ("1d", (1, 2), 9), - ("2d", (2, 0, 0, 1), 10), - ] - ) - def test_pad_value(self, _, pad, value): - class Pad(nn.Module): - def forward(self, x): - return torch.nn.functional.pad(x, pad, value=value) - - inputs = [torch.randn(1, 2, 3, 4)] - self.run_test( - Pad(), - inputs, - expected_ops={acc_ops.pad}, - # enable value will not work with implicit batch - test_implicit_batch_dim=False, - ) - - @parameterized.expand( - [ - ("1d", (1, 2)), - ("2d", (2, 0, 0, 1)), - ] - ) - def test_pad(self, _, pad): - class Pad(nn.Module): - def forward(self, x): - return torch.nn.functional.pad(x, pad) - - inputs = [torch.randn(1, 2, 3, 4)] - self.run_test( - Pad(), - inputs, - expected_ops={acc_ops.pad}, - # enable value will not work with implicit batch - test_implicit_batch_dim=False, - ) - - # Testing with (-1, 3, 3, 3) results into following error: - # test_pad_with_dynamic_shape_four_dimensions_0_2d (deeplearning.trt.torch_tensorrt.py.torch_tensorrt.fx.test.converters.acc_op.test_pad.TestPadConverter) ... [07/15/2022-09:23:18] [TRT] [E] 2: [intInterval.cpp::max::26] Error Code 2: Internal Error (Assertion !empty() failed. 
) - # Segmentation fault (core dumped) - - """ - def test_pad_with_dynamic_shape_four_dimensions(self): - class Pad(nn.Module): - def forward(self, x): - return torch.nn.functional.pad(x, (1, 1)) - - input_specs = [ - InputTensorSpec( - shape=(-1, 3, 3, 3), - dtype=torch.float32, - shape_ranges=[((1, 3, 3, 3), (2, 3, 3, 3), (2, 3, 3, 3))], - ), - ] - - self.run_test_with_dynamic_shape(Pad(), input_specs, expected_ops={acc_ops.pad}) - """ - - @parameterized.expand( - [ - ("3d", (2, 2, 3, 1, 2, 2)), - ] - ) - @unittest.skipIf( - trt.__version__ < "8.2", - "Padding 3d only supported in TensorRT 8.2 and later", - ) - def test_pad_3d(self, _, pad): - class Pad(nn.Module): - def forward(self, x): - return torch.nn.functional.pad(x, pad) - - inputs = [torch.randn(1, 2, 3, 4)] - self.run_test( - Pad(), - inputs, - expected_ops={acc_ops.pad}, - # enable value will not work with implicit batch - test_implicit_batch_dim=False, - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_permute.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_permute.py deleted file mode 100644 index a8b8c95f0b..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_permute.py +++ /dev/null @@ -1,90 +0,0 @@ -import torch -import torch.nn as nn -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestPermuteConverter(AccTestCase): - @parameterized.expand( - [ - ("positive", [0, 2, 1]), - ("negative", [0, -1, -2]), - ] - ) - def test_permute_list(self, _, permutation): - class Permute(nn.Module): - def forward(self, x): - return x.permute(permutation) - - inputs = [torch.randn(1, 3, 2)] - self.run_test(Permute(), inputs, expected_ops={acc_ops.permute}) - - @parameterized.expand( - [ - ("positive", [0, 2, 1]), - ("negative", [0, -1, -2]), - ] - ) - def test_permute(self, _, permutation): - class Permute(nn.Module): - def forward(self, x): - return x.permute(*permutation) - - inputs = [torch.randn(1, 3, 2)] - self.run_test(Permute(), inputs, expected_ops={acc_ops.permute}) - - @parameterized.expand( - [ - ("positive", (1, 2)), - ("negative", (-1, -2)), - ] - ) - def test_transpose(self, _, dims): - class Transpose(nn.Module): - def forward(self, x): - return x.transpose(*dims) - - inputs = [torch.randn(1, 2, 3)] - self.run_test(Transpose(), inputs, expected_ops={acc_ops.permute}) - - def test_permute_with_dynamic_shape(self): - class Permute(nn.Module): - def forward(self, x): - return x.permute(1, 2, 0) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], - ), - ] - self.run_test_with_dynamic_shape( - Permute(), input_specs, expected_ops={acc_ops.permute} - ) - - def test_permute_with_dynamic_shape_four_dimensions(self): - class Permute(nn.Module): - def forward(self, x): - return x.permute(1, 2, 3, 0) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], - ), - ] - - self.run_test_with_dynamic_shape( - Permute(), input_specs, expected_ops={acc_ops.permute} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_prod.py 
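
[Review note] The transpose cases above assert against the permute op rather than a dedicated transpose converter: a two-axis transpose is just a permutation that swaps those axes, with negative dims normalized first. In eager mode:

import torch

x = torch.randn(1, 2, 3)

# transpose(1, 2) is the permutation (0, 2, 1).
assert torch.equal(x.transpose(1, 2), x.permute(0, 2, 1))
# Negative dims normalize the same way: (-1, -2) == (2, 1) on a 3-d tensor.
assert torch.equal(x.transpose(-1, -2), x.permute(0, 2, 1))
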
b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_prod.py deleted file mode 100644 index e13c8b3048..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_prod.py +++ /dev/null @@ -1,121 +0,0 @@ -import torch -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - -# NOTE torch.prod will only accept one dim unlike other reduce ops which accept tuples - - -class TestProdConverter(AccTestCase): - @parameterized.expand( - [ - ( - f"{acc_ops.prod.__name__}_dim0_keepdim", - 0, - True, - torch.prod, - acc_ops.prod, - ), - ( - f"{acc_ops.prod.__name__}_dim0_no_keepdim", - 0, - False, - torch.prod, - acc_ops.prod, - ), - ( - f"{acc_ops.prod.__name__}_dim1_keepdim", - 1, - True, - torch.prod, - acc_ops.prod, - ), - ( - f"{acc_ops.prod.__name__}_dim1_no_keepdim", - 1, - False, - torch.prod, - acc_ops.prod, - ), - ( - f"{acc_ops.prod.__name__}_dim1_keepdim", - 2, - True, - torch.prod, - acc_ops.prod, - ), - ( - f"{acc_ops.prod.__name__}_dim1_no_keepdim", - 2, - False, - torch.prod, - acc_ops.prod, - ), - ] - ) - def test_prod(self, test_name, dim, keepdim, op, expected_acc_op): - class Prod(torch.nn.Module): - def __init__(self, dim, keepdim): - super().__init__() - self.dim = dim - self.keepdim = keepdim - - def forward(self, x): - return op(x, dim=self.dim, keepdim=self.keepdim) - - inputs = [torch.randn(1, 2, 3, 4)] - self.run_test( - Prod(dim, keepdim), - inputs, - expected_ops={expected_acc_op}, - test_implicit_batch_dim=(dim != 0), - ) - - @parameterized.expand( - [(f"{acc_ops.prod.__name__}_no_dim_no_keepdim", torch.prod, acc_ops.prod)] - ) - def test_prod_all_dims( - self, - test_name, - op, - expected_acc_op, - ): - class Prod(torch.nn.Module): - def forward(self, x): - return op(x) - - inputs = [torch.randn(1, 2, 3, 4)] - self.run_test( - Prod(), - inputs, - expected_ops={expected_acc_op}, - test_implicit_batch_dim=False, - ) - - def test_prod_all_dims_with_dynamic_shape( - self, - op=torch.prod, - ): - class Prod(torch.nn.Module): - def forward(self, x): - return op(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], - ), - ] - - self.run_test_with_dynamic_shape( - Prod(), input_specs, expected_ops={acc_ops.prod} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_quantize_per_tensor.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_quantize_per_tensor.py deleted file mode 100644 index eaef10df94..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_quantize_per_tensor.py +++ /dev/null @@ -1,68 +0,0 @@ -import unittest - -import tensorrt as trt -import torch.fx -import torch.nn as nn - -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -@unittest.skip( - """ - Tests related to quantize have issue creating engine, disable now. 
- """ -) -@unittest.skipIf( - trt.__version__ < "8.0", - "Explicit quantization only supported in TensorRT 8.0 and later", -) -class TestQuantizePerTensorConverter(AccTestCase): - def test_quantize_per_tensor(self): - class TestModule(nn.Module): - def forward(self, x): - return torch.quantize_per_tensor(x, 1, 0, torch.quint8) - - inputs = [torch.randn(1, 10)] - self.run_test(TestModule(), inputs, expected_ops={acc_ops.quantize_per_tensor}) - - def test_quantize_per_tensor_with_dynamic_shape(self): - class TestModule(nn.Module): - def forward(self, x): - return torch.quantize_per_tensor(x, 1, 0, torch.quint8) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], - ), - ] - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.quantize_per_tensor} - ) - - def test_quantize_per_tensor_with_dynamic_shape_four_dimensions(self): - class TestModule(nn.Module): - def forward(self, x): - return torch.quantize_per_tensor(x, 1, 0, torch.quint8) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))], - ), - ] - - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.quantize_per_tensor} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_reduce_ops.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_reduce_ops.py deleted file mode 100644 index 4fe7f8511c..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_reduce_ops.py +++ /dev/null @@ -1,111 +0,0 @@ -import torch -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - -reduce_ops = [(torch.sum, acc_ops.sum), (torch.mean, acc_ops.mean)] - - -class TestReduceConverter(AccTestCase): - @parameterized.expand( - case - for op, acc_op in reduce_ops - for case in [ - (f"{acc_op.__name__}_single_dim_no_keepdim", 1, False, op, acc_op), - (f"{acc_op.__name__}_single_dim_keepdim", 1, True, op, acc_op), - (f"{acc_op.__name__}_two_dim_no_keepdim", (1, 2), False, op, acc_op), - (f"{acc_op.__name__}_two_dim_keepdim", (1, 2), True, op, acc_op), - (f"{acc_op.__name__}_three_dim_no_keepdim", (1, 2, 3), False, op, acc_op), - (f"{acc_op.__name__}_three_dim_keepdim", (1, 2, 3), True, op, acc_op), - (f"{acc_op.__name__}_dim0_keepdim", 0, True, op, acc_op), - (f"{acc_op.__name__}_dim0_no_keepdim", 0, False, op, acc_op), - (f"{acc_op.__name__}_neg_single_dim_no_keepdim", -1, False, op, acc_op), - (f"{acc_op.__name__}_neg_single_dim_keepdim", -1, True, op, acc_op), - (f"{acc_op.__name__}_neg_two_dim_no_keepdim", (-1, -2), False, op, acc_op), - (f"{acc_op.__name__}_neg_two_dim_keepdim", (-1, -2), True, op, acc_op), - ( - f"{acc_op.__name__}_neg_pos_two_dim_no_keepdim", - (-1, 1), - False, - op, - acc_op, - ), - (f"{acc_op.__name__}_neg_pos_two_dim_keepdim", (-1, 1), True, op, acc_op), - ] - ) - def test_reduce(self, test_name, dim, keepdim, op, expected_acc_op): - class Reduce(torch.nn.Module): - def __init__(self, dim, keepdim): - super().__init__() - self.dim = dim - self.keepdim = keepdim - - def forward(self, x): - return op(x, dim=self.dim, keepdim=self.keepdim) - - inputs = [torch.randn(1, 2, 3, 4)] 
- self.run_test( - Reduce(dim, keepdim), - inputs, - expected_ops={expected_acc_op}, - test_implicit_batch_dim=(dim != 0), - ) - - @parameterized.expand( - [ - (f"{acc_op.__name__}_no_dim_no_keepdim", op, acc_op) - for op, acc_op in reduce_ops - ] - ) - def test_reduce_all_dims( - self, - test_name, - op, - expected_acc_op, - ): - class Reduce(torch.nn.Module): - def forward(self, x): - return op(x) - - inputs = [torch.randn(1, 2, 3, 4)] - self.run_test( - Reduce(), - inputs, - expected_ops={expected_acc_op}, - test_implicit_batch_dim=False, - ) - - @parameterized.expand( - [ - (f"{acc_op.__name__}_no_dim_no_keepdim", op, acc_op) - for op, acc_op in reduce_ops - ] - ) - def test_reduce_all_dims_with_dynamic_shape_four_dimensions( - self, - test_name, - op, - expected_acc_op, - ): - class Reduce(torch.nn.Module): - def forward(self, x): - return op(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))], - ), - ] - self.run_test_with_dynamic_shape( - Reduce(), input_specs, expected_ops={expected_acc_op} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_relu.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_relu.py deleted file mode 100644 index 774cd6fec7..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_relu.py +++ /dev/null @@ -1,55 +0,0 @@ -import torch -import torch.nn as nn -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestReLUConverter(AccTestCase): - def test_relu(self): - class TestModule(nn.Module): - def forward(self, x): - return nn.functional.relu(x) - - inputs = [torch.randn(1, 10)] - self.run_test(TestModule(), inputs, expected_ops={acc_ops.relu}) - - def test_relu_with_dynamic_shape(self): - class TestModule(nn.Module): - def forward(self, x): - return nn.functional.relu(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], - ), - ] - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.relu} - ) - - def test_relu_with_dynamic_shape_four_dimensions(self): - class TestModule(nn.Module): - def forward(self, x): - return nn.functional.relu(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], - ), - ] - - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.relu} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_repeat_interleave.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_repeat_interleave.py deleted file mode 100644 index 0c4360d53f..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_repeat_interleave.py +++ /dev/null @@ -1,79 +0,0 @@ -import torch -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch import nn -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestRepeatInterLeave(AccTestCase): - 
@parameterized.expand( - [ - ("none_dim", (2, 3, 4), 3, None), - ("dim_0", (2, 3, 4), 3, 0), - ("dim_1", (2, 3, 4), 3, 1), - ("dim_2", (2, 3, 4), 3, 2), - ] - ) - def test_repeat_interleave(self, _, input_shape, repeat, dim): - class RepeatInterleave(nn.Module): - def __init__(self, dim): - super().__init__() - self.repeat = repeat - self.dim = dim - - def forward(self, x): - return torch.repeat_interleave(x, self.repeat, self.dim) - - inputs = [torch.randn(*input_shape)] - expected_ops = {acc_ops.tile, acc_ops.unsqueeze, acc_ops.reshape} - if dim is not None: - expected_ops.update({acc_ops.getitem, acc_ops.size}) - self.run_test( - RepeatInterleave(dim), - inputs, - expected_ops=expected_ops, - test_implicit_batch_dim=dim is not None and dim != 0, - ) - - @parameterized.expand( - [ - ("none_dim", (-1, 2, 3), 3, None), - ("dim_0", (-1, 2, 3), 3, 0), - ("dim_1", (-1, 2, 3), 3, 1), - ("dim_2", (-1, 3, 2), 3, 2), - ] - ) - def test_repeat_interleave_with_dynamic_shape(self, _, input_shape, repeat, dim): - class RepeatInterleave(nn.Module): - def __init__(self, dim): - super().__init__() - self.repeat = repeat - self.dim = dim - - def forward(self, x): - return torch.repeat_interleave(x, self.repeat, self.dim) - - input_specs = [ - InputTensorSpec( - shape=input_shape, - dtype=torch.float32, - shape_ranges=[ - ( - tuple(i if i != -1 else 1 for i in input_shape), - tuple(i if i != -1 else 2 for i in input_shape), - tuple(i if i != -1 else 3 for i in input_shape), - ) - ], - ), - ] - self.run_test_with_dynamic_shape( - RepeatInterleave(dim), input_specs, expected_ops={acc_ops.tile} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_reshape.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_reshape.py deleted file mode 100644 index dba833276f..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_reshape.py +++ /dev/null @@ -1,141 +0,0 @@ -import torch -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestReshapeConverter(AccTestCase): - @parameterized.expand( - [ - ((1, 20),), - ((1, 10, -1),), - ] - ) - def test_reshape(self, target_shape): - class TestModule(torch.nn.Module): - def __init__(self, target_shape): - super().__init__() - self.target_shape = target_shape - - def forward(self, x): - return torch.reshape(x, self.target_shape) - - inputs = [torch.randn(1, 2, 10)] - self.run_test(TestModule(target_shape), inputs, expected_ops={acc_ops.reshape}) - - @parameterized.expand( - [ - ((-1, 2),), - ((1, 2, -1),), - ] - ) - def test_reshape_with_dynamic_shape(self, target_shape): - class TestModule(torch.nn.Module): - def __init__(self, target_shape): - super().__init__() - self.target_shape = target_shape - - def forward(self, x): - return torch.reshape(x, self.target_shape) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], - ), - ] - self.run_test_with_dynamic_shape( - TestModule(target_shape), input_specs, expected_ops={acc_ops.reshape} - ) - - @parameterized.expand( - [ - ((-1, 2),), - ((1, 2, -1),), - ] - ) - def test_reshape_with_dynamic_shape_with_four_dimensions(self, target_shape): - class TestModule(torch.nn.Module): - def __init__(self, 
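
[Review note] The expected_ops sets above (tile, unsqueeze, reshape, plus size/getitem when a dim is given) reflect that a scalar-repeat repeat_interleave is decomposed rather than converted directly. A sketch of one equivalent decomposition, as an illustration only (torch; repeat_interleave_decomposed is a hypothetical reference helper, not the tracer's actual pass):

import torch

def repeat_interleave_decomposed(x, repeat, dim=None):
    # Hypothetical reference decomposition; not the tracer's implementation.
    if dim is None:
        # dim=None flattens first, then repeats each element `repeat` times.
        return x.reshape(-1, 1).tile(1, repeat).reshape(-1)
    # Insert a unit axis after `dim`, tile along it, then fold it back in.
    x = x.unsqueeze(dim + 1)
    reps = [1] * x.ndim
    reps[dim + 1] = repeat
    x = x.tile(tuple(reps))
    shape = list(x.shape)
    shape[dim] = shape[dim] * shape.pop(dim + 1)
    return x.reshape(shape)

x = torch.randn(2, 3, 4)
for d in (None, 0, 1, 2):
    expected = torch.repeat_interleave(x, 3, dim=d)
    assert torch.equal(repeat_interleave_decomposed(x, 3, d), expected)
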
target_shape): - super().__init__() - self.target_shape = target_shape - - def forward(self, x): - return torch.reshape(x, self.target_shape) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))], - ), - ] - - self.run_test_with_dynamic_shape( - TestModule(target_shape), input_specs, expected_ops={acc_ops.reshape} - ) - - def test_reshape_with_dynamic_shape_size(self): - class TestModule(torch.nn.Module): - def forward(self, x, y): - shape_y = y.shape - t = shape_y[1] - return torch.reshape(x, [-1, t, 3]) - - input_specs = [ - InputTensorSpec( - shape=(-1, 5, 6), - dtype=torch.float32, - shape_ranges=[((1, 5, 6), (2, 5, 6), (3, 5, 6))], - ), - InputTensorSpec( - shape=(-1, 5), - dtype=torch.float32, - shape_ranges=[((1, 5), (1, 5), (3, 5))], - ), - ] - - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.reshape} - ) - - def test_reshape_with_dynamic_shape_mul(self): - class TestModule(torch.nn.Module): - def forward(self, x, y, z): - t = 8000 - a = torch.reshape(x, [-1, t, 64]) - b = torch.reshape(y, [-1, t, 64]) - c = torch.reshape(z, [-1, t, 64]) - d = a + b + c - return d - - input_specs = [ - InputTensorSpec( - shape=(-1, 42, 512), - dtype=torch.float32, - shape_ranges=[((1, 42, 512), (1000, 42, 512), (1000, 42, 512))], - ), - InputTensorSpec( - shape=(-1, 42, 512), - dtype=torch.float32, - shape_ranges=[((1, 42, 512), (1000, 42, 512), (1000, 42, 512))], - ), - InputTensorSpec( - shape=(-1, 42, 512), - dtype=torch.float32, - shape_ranges=[((1, 42, 512), (1000, 42, 512), (1000, 42, 512))], - ), - ] - - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.reshape} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_selu.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_selu.py deleted file mode 100644 index cbc4c04117..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_selu.py +++ /dev/null @@ -1,55 +0,0 @@ -import torch -import torch.nn as nn -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestSeLUConverter(AccTestCase): - def test_selu(self): - class TestModule(nn.Module): - def forward(self, x): - return nn.functional.selu(x) - - inputs = [torch.randn(1, 10)] - self.run_test(TestModule(), inputs, expected_ops={acc_ops.selu}) - - def test_selu_with_dynamic_shape(self): - class TestModule(nn.Module): - def forward(self, x): - return nn.functional.selu(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], - ), - ] - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.selu} - ) - - def test_selu_with_dynamic_shape_four_dimensions(self): - class TestModule(nn.Module): - def forward(self, x): - return nn.functional.selu(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))], - ), - ] - - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.selu} - ) - - -if __name__ == "__main__": - run_tests() diff --git 
a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_sigmoid.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_sigmoid.py deleted file mode 100644 index 77aa8c9392..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_sigmoid.py +++ /dev/null @@ -1,38 +0,0 @@ -import torch -import torch.nn as nn -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestSigmoid(AccTestCase): - def test_sigmoid(self): - class Sigmoid(nn.Module): - def forward(self, x): - return torch.sigmoid(x) - - inputs = [torch.randn(1, 2, 3)] - self.run_test(Sigmoid(), inputs, expected_ops={acc_ops.sigmoid}) - - def test_sigmoid_with_dynamic_shape_four_dimensions(self): - class Sigmoid(nn.Module): - def forward(self, x): - return torch.sigmoid(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))], - ), - ] - self.run_test_with_dynamic_shape( - Sigmoid(), input_specs, expected_ops={acc_ops.sigmoid} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_silu.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_silu.py deleted file mode 100644 index 38d8f5b645..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_silu.py +++ /dev/null @@ -1,52 +0,0 @@ -import torch -from torch import nn -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec -from torch_tensorrt.fx.tracer.acc_tracer import acc_ops - - -class TestSilu(AccTestCase): - def test_silu(self): - class Silu(nn.Module): - def forward(self, x): - return torch.nn.functional.silu(x) - - inputs = [torch.randn(1, 2, 3)] - self.run_test(Silu(), inputs, expected_ops={acc_ops.sigmoid, acc_ops.mul}) - - def test_silu_with_dynamic_shape(self): - class Silu(nn.Module): - def forward(self, x): - return torch.nn.functional.silu(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], - ), - ] - self.run_test_with_dynamic_shape( - Silu(), input_specs, expected_ops={acc_ops.sigmoid, acc_ops.mul} - ) - - def test_silu_with_dynamic_shape_four_dimensions(self): - class Silu(nn.Module): - def forward(self, x): - return torch.nn.functional.silu(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))], - ), - ] - - self.run_test_with_dynamic_shape( - Silu(), input_specs, expected_ops={acc_ops.sigmoid, acc_ops.mul} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_size.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_size.py deleted file mode 100644 index 411b8b6a46..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_size.py +++ /dev/null @@ -1,74 +0,0 @@ -import torch -import torch.nn as nn -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class 
TestSizeConverter(AccTestCase): - def test_size(self): - class Size(nn.Module): - def forward(self, x): - bs = x.size(0) - return x.view(bs, -1) - - inputs = [torch.randn(1, 2, 3, 4)] - self.run_test(Size(), inputs, expected_ops={acc_ops.size}) - - def test_size_param(self): - class Size(nn.Module): - def __init__(self, x): - super().__init__() - self.param = torch.nn.Parameter(x) - - def forward(self, y): - bs = self.param.size(0) - return y.view(bs, -1) - - self.run_test( - Size(torch.randn(1, 2, 3, 4)), - [torch.randn(1, 2, 3, 4)], - expected_ops={acc_ops.size}, - ) - - def test_size_dynamic_shape(self): - class Size(nn.Module): - def forward(self, x): - bs = x.size(0) - return x.view(bs, -1) - - input_specs = [ - InputTensorSpec( - shape=(-1, 12, 32), - dtype=torch.float32, - shape_ranges=[((1, 12, 32), (3, 12, 32), (100, 12, 32))], - ), - ] - self.run_test_with_dynamic_shape( - Size(), input_specs, expected_ops={acc_ops.size} - ) - - def test_size_dynamic_shape_four_dimensions(self): - class Size(nn.Module): - def forward(self, x): - bs = x.size(0) - return x.view(bs, -1) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 12, 32, 3), (3, 12, 32, 3), (100, 12, 32, 3))], - ), - ] - - self.run_test_with_dynamic_shape( - Size(), input_specs, expected_ops={acc_ops.size} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_softmax.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_softmax.py deleted file mode 100644 index 20c4ab744d..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_softmax.py +++ /dev/null @@ -1,68 +0,0 @@ -import torch -import torch.nn as nn -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestSoftmaxConverter(AccTestCase): - @parameterized.expand( - [("none_dim", None), ("basic", 1), ("batch_dim", 0), ("negative_dim", -2)] - ) - def test_softmax(self, _, dim): - class Softmax(nn.Module): - def __init__(self, dim): - super().__init__() - self.dim = dim - - def forward(self, x): - return nn.functional.softmax(x, dim=self.dim) - - inputs = [torch.randn(1, 2, 3, 4)] - self.run_test( - Softmax(dim), - inputs, - expected_ops={acc_ops.softmax}, - test_implicit_batch_dim=(dim is None or dim % len(inputs[0].shape) != 0), - ) - - def test_softmax_with_dynamic_shape(self): - class Softmax(nn.Module): - def forward(self, x): - return nn.functional.softmax(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], - ), - ] - self.run_test_with_dynamic_shape( - Softmax(), input_specs, expected_ops={acc_ops.softmax} - ) - - def test_softmax_with_dynamic_shape_four_dimensions(self): - class Softmax(nn.Module): - def forward(self, x): - return nn.functional.softmax(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))], - ), - ] - - self.run_test_with_dynamic_shape( - Softmax(), input_specs, expected_ops={acc_ops.softmax} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_softsign.py 
b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_softsign.py deleted file mode 100644 index 73b97a02b6..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_softsign.py +++ /dev/null @@ -1,55 +0,0 @@ -import torch -import torch.nn as nn -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestSoftsignConverter(AccTestCase): - def test_softsign(self): - class TestModule(nn.Module): - def forward(self, x): - return nn.functional.softsign(x) - - inputs = [torch.randn(1, 10)] - self.run_test(TestModule(), inputs, expected_ops={acc_ops.softsign}) - - def test_softsign_with_dynamic_shape(self): - class TestModule(nn.Module): - def forward(self, x): - return nn.functional.softsign(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], - ), - ] - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.softsign} - ) - - def test_softsign_with_dynamic_shape_four_dimensions(self): - class TestModule(nn.Module): - def forward(self, x): - return nn.functional.softsign(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))], - ), - ] - - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.softsign} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_split.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_split.py deleted file mode 100644 index 20f63ab958..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_split.py +++ /dev/null @@ -1,110 +0,0 @@ -import torch -import torch.nn as nn -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestSplitConverter(AccTestCase): - @parameterized.expand( - [ - ("split_size", 3, 1), - ("sections", [5, 2, 3], 1), - ] - ) - def test_split(self, _, split_size_or_sections, dim): - class Split(nn.Module): - def forward(self, x): - return x.split(split_size_or_sections, dim)[0] - - inputs = [torch.randn(1, 10)] - self.run_test( - Split(), - inputs, - expected_ops={ - acc_ops.split - if isinstance(split_size_or_sections, int) - else acc_ops.slice_tensor - }, - test_explicit_batch_dim=False, - ) - - @parameterized.expand( - [ - ("split_with_size", [2, 3, 5], 1), - ] - ) - def test_split_with_size(self, _, split_size, dim): - class Split(nn.Module): - def forward(self, x): - return x.split_with_sizes(split_size, dim) - - inputs = [torch.randn(1, 10)] - self.run_test( - Split(), - inputs, - expected_ops={acc_ops.slice_tensor}, - test_explicit_batch_dim=False, - ) - - @parameterized.expand( - [ - ("split_size", 3, 1), - ("sections", [5, 2, 3], 1), - ] - ) - def test_split_with_dynamic_shape(self, _, split_size_or_sections, dim): - class Split(nn.Module): - def forward(self, x): - return x.split(split_size_or_sections, dim)[0] - - input_specs = [ - InputTensorSpec( - shape=(-1, 10, -1), - dtype=torch.float32, - shape_ranges=[((1, 10, 10), (5, 10, 15), (10, 
10, 20))], - ), - ] - self.run_test_with_dynamic_shape( - Split(), - input_specs, - expected_ops={ - acc_ops.split - if isinstance(split_size_or_sections, int) - else acc_ops.slice_tensor - }, - ) - - # Testing with (-1, -1, -1) results into following error: - # AssertionError: Can't chunk on dynamic shape dimension! - - @parameterized.expand( - [ - ("split_with_size", [2, 3, 5], 1), - ] - ) - def test_split_with_size_dynamic_shape(self, _, split_size, dim): - class Split(nn.Module): - def forward(self, x): - return x.split_with_sizes(split_size, dim) - - input_specs = [ - InputTensorSpec( - shape=(-1, 10, -1), - dtype=torch.float32, - shape_ranges=[((1, 10, 20), (5, 10, 20), (10, 10, 20))], - ), - ] - self.run_test_with_dynamic_shape( - Split(), - input_specs, - expected_ops={acc_ops.slice_tensor}, - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_squeeze.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_squeeze.py deleted file mode 100644 index f1cc4fe96d..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_squeeze.py +++ /dev/null @@ -1,44 +0,0 @@ -import torch -import torch.nn as nn -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestSqueeze(AccTestCase): - def test_squeeze(self): - class Squeeze(nn.Module): - def forward(self, x): - return x.squeeze(2) - - inputs = [torch.randn(1, 2, 1)] - self.run_test(Squeeze(), inputs, expected_ops={acc_ops.squeeze}) - - # Testing with shape=(-1, -1, -1, -1) results in error: - # AssertionError: We don't support squeeze dynamic dim. - - # Testing with more than one dynamic dim results in error: - # AssertionError: Currently more than one dynamic dim for input to squeeze is not supported. 
- - def test_squeeze_with_dynamic_shape(self): - class Squeeze(nn.Module): - def forward(self, x): - return x.squeeze(0) - - input_specs = [ - InputTensorSpec( - shape=(1, -1, 2), - dtype=torch.float32, - shape_ranges=[((1, 1, 2), (1, 2, 2), (1, 3, 2))], - ), - ] - self.run_test_with_dynamic_shape( - Squeeze(), input_specs, expected_ops={acc_ops.squeeze} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_std.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_std.py deleted file mode 100644 index bc1d0ece89..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_std.py +++ /dev/null @@ -1,120 +0,0 @@ -import torch -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestMinConverter(AccTestCase): - @parameterized.expand( - [ - ("norm_1d", (-1), False), - ("norm_1d", (-1), True), - ("norm_2d", (2, 3), False), - ("norm_2d", (2, 3), True), - ] - ) - def test_std(self, _, dim, unbiased): - class Std(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return torch.std(x, dim, unbiased=unbiased, keepdim=True) - - inputs = [torch.randn(2, 3, 4, 5)] - self.run_test( - Std(), - inputs, - expected_ops={acc_ops.mean, acc_ops.sub, acc_ops.pow, acc_ops.sqrt}, - ) - - @parameterized.expand( - [ - ("norm_1d", (-1), False), - ("norm_1d", (-1), True), - ("norm_2d", (2, 3), False), - ("norm_2d", (2, 3), True), - ] - ) - def test_std_with_dynamic_shape_four_dimensions(self, _, dim, unbiased): - class Std(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return torch.std(x, dim, unbiased=unbiased, keepdim=True) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))], - ), - ] - - self.run_test_with_dynamic_shape( - Std(), - input_specs, - expected_ops={acc_ops.mean, acc_ops.sub, acc_ops.pow, acc_ops.sqrt}, - ) - - @parameterized.expand( - [ - ("norm_1d", (-1), True), - ("norm_1d", (-1), False), - ("norm_2d", (2, 3), True), - ("norm_2d", (2, 3), False), - ] - ) - def test_std_method(self, _, dim, unbiased): - class Std(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return x.std(dim, unbiased=unbiased, keepdim=True) - - inputs = [torch.randn(2, 3, 4, 5)] - self.run_test( - Std(), - inputs, - expected_ops={acc_ops.mean, acc_ops.sub, acc_ops.pow, acc_ops.sqrt}, - ) - - @parameterized.expand( - [ - ("norm_1d", (-1), True), - ("norm_1d", (-1), False), - ("norm_2d", (2, 3), True), - ("norm_2d", (2, 3), False), - ] - ) - def test_std_method_with_dynamic_shape_four_dimensions(self, _, dim, unbiased): - class Std(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return x.std(dim, unbiased=unbiased, keepdim=True) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))], - ), - ] - - self.run_test_with_dynamic_shape( - Std(), - input_specs, - expected_ops={acc_ops.mean, acc_ops.sub, acc_ops.pow, acc_ops.sqrt}, - ) - - -if __name__ == "__main__": - run_tests() diff --git 
a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_tanh.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_tanh.py deleted file mode 100644 index dd39d29d41..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_tanh.py +++ /dev/null @@ -1,55 +0,0 @@ -import torch -import torch.nn as nn -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestTanh(AccTestCase): - def test_tanh(self): - class Tanh(nn.Module): - def forward(self, x): - return torch.tanh(x) - - inputs = [torch.randn(1, 2, 3)] - self.run_test(Tanh(), inputs, expected_ops={acc_ops.tanh}) - - def test_tanh_with_dynamic_shape(self): - class Tanh(nn.Module): - def forward(self, x): - return torch.tanh(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], - ), - ] - self.run_test_with_dynamic_shape( - Tanh(), input_specs, expected_ops={acc_ops.tanh} - ) - - def test_tanh_with_dynamic_shape_four_dimensions(self): - class Tanh(nn.Module): - def forward(self, x): - return torch.tanh(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 3), (1, 2, 3, 3), (3, 3, 3, 3))], - ), - ] - - self.run_test_with_dynamic_shape( - Tanh(), input_specs, expected_ops={acc_ops.tanh} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_tile.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_tile.py deleted file mode 100644 index c370c58eba..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_tile.py +++ /dev/null @@ -1,148 +0,0 @@ -import torch -import torch.nn as nn -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestTile(AccTestCase): - @parameterized.expand( - [ - ("same_num_dims", (2, 2, 3), (1, 2, 2)), - ("less_dims", (2, 2, 3), (2,)), - ("more_dims", (2, 3), (1, 2, 2, 1)), - ] - ) - def test_tile(self, _, input_shape, dims): - class Tile(nn.Module): - def __init__(self, dims): - super().__init__() - self.dims = dims - - def forward(self, x): - return torch.tile(x, self.dims) - - inputs = [torch.randn(*input_shape)] - self.run_test( - Tile(dims), - inputs, - expected_ops={acc_ops.tile}, - test_implicit_batch_dim=( - len(input_shape) > len(dims) - or (len(input_shape) == len(dims) and dims[0] == 1) - ), - ) - - @parameterized.expand( - [ - ("same_num_dims", (-1, 2, 3), (1, 2, 2)), - ("less_dims", (-1, 2, 3), (2,)), - ("more_dims", (-1, 3), (1, 2, 2, 1)), - ("all_dynamic_dim", (-1, -1), (1, 2, 2, 1)), - ] - ) - def test_tile_with_dynamic_shape(self, _, shape, dims): - class Tile(nn.Module): - def __init__(self, dims): - super().__init__() - self.dims = dims - - def forward(self, x): - return torch.tile(x, self.dims) - - input_specs = [ - InputTensorSpec( - shape=shape, - dtype=torch.float32, - shape_ranges=[ - ( - tuple(i if i != -1 else 1 for i in shape), - tuple(i if i != -1 else 2 for i in shape), - tuple(i if i != -1 else 3 for i in shape), - ) - ], - ), - ] - self.run_test_with_dynamic_shape( - 
Tile(dims), input_specs, expected_ops={acc_ops.tile} - ) - - @parameterized.expand( - [ - ("all_dynamic_dim", (-1, -1), (1, 2, 2, 1)), - ] - ) - def test_tile_with_dynamic_shape_four_dimensions(self, _, shape, dims): - class Tile(nn.Module): - def __init__(self, dims): - super().__init__() - self.dims = dims - - def forward(self, x): - return torch.tile(x, self.dims) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 3), (3, 3, 3, 3), (3, 3, 3, 3))], - ), - ] - - self.run_test_with_dynamic_shape( - Tile(dims), input_specs, expected_ops={acc_ops.tile} - ) - - def test_tile_non_int_dims(self): - class Tile(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - y = y * 2 - return torch.tile(x, (1, y.shape[1], y.shape[1])) - - inputs = [torch.randn(2, 2, 3), torch.randn(2, 2, 3)] - batch_size_range = (1, 2, 3) - input_specs = InputTensorSpec.from_tensors_with_dynamic_batch_size( - inputs, batch_size_range - ) - self.run_test_with_dynamic_shape( - Tile(), - input_specs, - expected_ops={acc_ops.tile}, - ) - - def test_tile_non_int_dims_with_dynamic_shape_four_dimensions(self): - class Tile(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - y = y * 2 - return torch.tile(x, (1, y.shape[1], y.shape[1])) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 3), (3, 3, 3, 3), (3, 3, 3, 3))], - ), - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 3), (3, 3, 3, 3), (3, 3, 3, 3))], - ), - ] - - self.run_test_with_dynamic_shape( - Tile(), input_specs, expected_ops={acc_ops.tile} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_to_dtype.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_to_dtype.py deleted file mode 100644 index 788a252e6e..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_to_dtype.py +++ /dev/null @@ -1,322 +0,0 @@ -import torch -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) -from torch_tensorrt.dynamo.fx_ts_compat.utils import LowerPrecision - - -class TestToConverter(AccTestCase): - def test_fp16(self): - class To(torch.nn.Module): - def forward(self, x): - return x.to(torch.float16) - - input = torch.randn(2, 2) - inputs = [ - input, - ] - self.run_test( - To(), - inputs, - expected_ops={acc_ops.to_dtype}, - test_implicit_batch_dim=False, - precision=LowerPrecision.FP16, - ) - - # Testing with shape shape=(-1, -1, -1, -1) results into following error: - # Error: assert engine - """ - def test_fp16_with_dynamic_shape_four_dimension(self): - class To(torch.nn.Module): - def forward(self, x): - return x.to(torch.float16) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float16, - shape_ranges=[((1, 1, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3))], - ).cuda(), - ] - - self.run_test_with_dynamic_shape( - To(), input_specs, expected_ops={acc_ops.to_dtype} - ) - """ - - def test_fp32(self): - class To(torch.nn.Module): - def forward(self, x): - return x.to(torch.float32) - - input = torch.randn(2, 2).to(torch.float16) - inputs = [ - input, - ] - self.run_test( - To(), inputs, expected_ops={acc_ops.to_dtype}, 
test_implicit_batch_dim=False - ) - - def test_cuda_fp16(self): - class To(torch.nn.Module): - def forward(self, x): - return x.to(torch.device("cuda:0"), torch.float16) - - input = torch.randn(2, 2) - inputs = [ - input, - ] - self.run_test( - To(), - inputs, - expected_ops={acc_ops.to_dtype}, - test_implicit_batch_dim=False, - precision=LowerPrecision.FP16, - ) - - def test_cuda(self): - class To(torch.nn.Module): - def forward(self, x): - x = x.to(torch.device("cuda")) - # append extra layer since to(device) is skipped in TRT - return x + torch.randn(2, 2).cuda() - - input = torch.randn(2, 2) - inputs = [ - input, - ] - self.run_test( - To(), - inputs, - expected_ops={acc_ops.to_dtype, acc_ops.add}, - test_implicit_batch_dim=False, - precision=LowerPrecision.FP32, - ) - - def test_cuda_with_dynamic_shape_four_dimensions(self): - class To(torch.nn.Module): - def forward(self, x): - x = x.to(torch.device("cuda")) - # append extra layer since to(device) is skipped in TRT - return x + torch.randn(3, 3, 3, 3).cuda() - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float16, - shape_ranges=[((1, 1, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3))], - ), - ] - - self.run_test_with_dynamic_shape( - To(), input_specs, expected_ops={acc_ops.to_dtype, acc_ops.add} - ) - - def test_device(self): - class To(torch.nn.Module): - def __init__(self): - super().__init__() - self.a = torch.randn(2, 2) - - def forward(self, x): - idevice = x.device - a = self.a.to(idevice) - return x + a - - input = torch.randn(2, 2).cuda() - inputs = [ - input, - ] - self.run_test( - To(), - inputs, - expected_ops={acc_ops.to_dtype}, - test_implicit_batch_dim=False, - precision=LowerPrecision.FP32, - ) - - def test_device_with_dynamic_shape_four_dimensions(self): - class To(torch.nn.Module): - def __init__(self): - super().__init__() - self.a = torch.randn(3, 3, 3, 3) - - def forward(self, x): - idevice = x.device - a = self.a.to(idevice) - return x + a - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float16, - shape_ranges=[((1, 1, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3))], - ), - ] - - self.run_test_with_dynamic_shape( - To(), input_specs, expected_ops={acc_ops.to_dtype, acc_ops.add} - ) - - def test_device_fp16(self): - class To(torch.nn.Module): - def __init__(self): - super().__init__() - self.a = torch.randn(2, 2) - - def forward(self, x): - idevice = x.device - idtype = x.dtype - a = self.a.to(idevice) - # fx tracer could not handle "to(idevice, torch.float16)" - # TypeError: to() received an invalid combination of arguments - got (Attribute, torch.dtype) - return a.to(idtype) - - input = torch.randn(2, 2).half().cuda() - inputs = [ - input, - ] - self.run_test( - To(), - inputs, - expected_ops={acc_ops.to_dtype}, - test_implicit_batch_dim=False, - precision=LowerPrecision.FP16, - ) - - # Testing with shape shape=(-1, -1, -1, -1) results into following error: - # Error: assert engine - """ - def test_device_fp16_with_dynamic_shape_four_dimensions(self): - class To(torch.nn.Module): - def __init__(self): - super().__init__() - self.a = torch.randn(2, 2) - - def forward(self, x): - idevice = x.device - idtype = x.dtype - a = self.a.to(idevice) - # fx tracer could not handle "to(idevice, torch.float16)" - # TypeError: to() received an invalid combination of arguments - got (Attribute, torch.dtype) - return a.to(idtype) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float16, - shape_ranges=[((2, 2, 2, 2), (4, 4, 4, 4), (4, 4, 4, 4))], - ), - ] - - 
self.run_test_with_dynamic_shape( - To(), input_specs, expected_ops={acc_ops.to_dtype} - ) - """ - - # tensor.float() - def test_float(self): - class To(torch.nn.Module): - def forward(self, x): - return x.float() - - input = torch.randn(2, 2).half() - inputs = [ - input, - ] - self.run_test( - To(), - inputs, - expected_ops={acc_ops.to_dtype}, - test_implicit_batch_dim=False, - precision=LowerPrecision.FP32, - ) - - # tensor.float() - def test_float_with_dynamic_shape_four_dimensions(self): - class To(torch.nn.Module): - def forward(self, x): - return x.float() - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))], - ), - ] - - self.run_test_with_dynamic_shape( - To(), input_specs, expected_ops={acc_ops.to_dtype} - ) - - # Half is not suitable for dynamic shape - # Error: assert engine - - # tensor.half() - def test_half(self): - class To(torch.nn.Module): - def forward(self, x): - return x.half() - - input = torch.randn(2, 2) - inputs = [ - input, - ] - self.run_test( - To(), - inputs, - expected_ops={acc_ops.to_dtype}, - test_implicit_batch_dim=False, - precision=LowerPrecision.FP16, - ) - - # TODO Open in future. TRT 8.5 does not work for this test - # The test is a rare case. We need to remove it in graph maybe. - # def test_int(self): - # class To(torch.nn.Module): - # def forward(self, x): - # x = x.int() - # # we do not expect int to be output type, so add an extra layer - # x = x.float() - # return x - - # input = torch.randn(2, 2) - # inputs = [ - # input, - # ] - # self.run_test( - # To(), - # inputs, - # expected_ops={acc_ops.to_dtype}, - # test_implicit_batch_dim=False, - # precision=LowerPrecision.FP32, - # ) - - # # tensor.int() - # def test_int_with_dynamic_shape_four_dimensions(self): - # class To(torch.nn.Module): - # def forward(self, x): - # x = x.int() - # # we do not expect int to be output type, so add an extra layer - # x = x.float() - # return x - - # input_specs = [ - # InputTensorSpec( - # shape=(-1, -1, -1, -1), - # dtype=torch.int, - # shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))], - # ), - # ] - - # self.run_test_with_dynamic_shape( - # To(), input_specs, expected_ops={acc_ops.to_dtype} - # ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_topk.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_topk.py deleted file mode 100644 index 83de8eb894..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_topk.py +++ /dev/null @@ -1,87 +0,0 @@ -import torch -import torch.nn as nn -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestTopKConverter(AccTestCase): - @parameterized.expand( - [ - ("top1", 1, -1), - ("top2", 2, -1), - ("none_dim", 1, None), - ("smallest", 1, -1, False), - ("top1_dim0", 1, 0, False), - ] - ) - def test_topk(self, _, k, dim, largest=True): - class TopK(nn.Module): - def __init__(self, k, dim): - super().__init__() - self.k = k - self.dim = dim - self.largest = largest - - def forward(self, x): - if self.dim is not None: - out = torch.topk( - x, k=self.k, dim=self.dim, largest=self.largest, sorted=False - ) - else: - out = torch.topk(x, k=self.k, largest=self.largest, sorted=False) - 
return out[0], out[1] - - inputs = [torch.randn(1, 2, 3, 4)] - self.run_test( - TopK(k, dim), - inputs, - expected_ops={acc_ops.topk}, - test_implicit_batch_dim=(dim != 0), - ) - - @parameterized.expand( - [ - ("top1", 1, -1), - ("top2", 2, -1), - ("none_dim", 1, None), - ("smallest", 1, -1, False), - ("top1_dim0", 1, 0, False), - ] - ) - def test_topk_with_dynamic_shape_four_dimensions(self, _, k, dim, largest=True): - class TopK(nn.Module): - def __init__(self, k, dim): - super().__init__() - self.k = k - self.dim = dim - self.largest = largest - - def forward(self, x): - if self.dim is not None: - out = torch.topk( - x, k=self.k, dim=self.dim, largest=self.largest, sorted=False - ) - else: - out = torch.topk(x, k=self.k, largest=self.largest, sorted=False) - return out[0], out[1] - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))], - ), - ] - - self.run_test_with_dynamic_shape( - TopK(k, dim), input_specs, expected_ops={acc_ops.topk} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_transpose_convolution.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_transpose_convolution.py deleted file mode 100644 index 1f837c12f7..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_transpose_convolution.py +++ /dev/null @@ -1,140 +0,0 @@ -# Owner(s): ["oncall: gpu_enablement"] - -import torch -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import param, parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestTransposeConvolutionConverter(AccTestCase): - @parameterized.expand( - [ - ("default", 1), - param("no_bias", 1, bias=False), - ("tuple_parameters", 1, (1, 1), (1, 1)), - param("non_zero_padding", 1, padding=1), - param("dilation", 1, dilation=2), - param("groups", 1, groups=3), - ] - ) - def test_conv_transpose2d( - self, - _, - kernel_size, - stride=1, - padding=0, - output_padding=0, - groups=1, - bias=True, - dilation=1, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv_transpose = torch.nn.ConvTranspose2d( - in_channels=3, - out_channels=6, - kernel_size=kernel_size, - stride=stride, - padding=padding, - output_padding=output_padding, - groups=groups, - bias=bias, - dilation=dilation, - ) - - def forward(self, x): - return self.conv_transpose(x) - - inputs = [torch.randn(1, 3, 224, 224)] - self.run_test(TestModule(), inputs, expected_ops={acc_ops.conv_transpose2d}) - - def test_conv_transpose2d_with_dynamic_shape(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv_transpose = torch.nn.ConvTranspose2d(3, 3, 1) - - def forward(self, x): - return self.conv_transpose(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, 3, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 3, 1, 1), (1, 3, 4, 4), (32, 3, 128, 128))], - ), - ] - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.conv_transpose2d} - ) - - @parameterized.expand( - [ - ("default", 1), - param("no_bias", 1, bias=False), - ("tuple_parameters", 1, (1, 1, 1), (1, 1, 1)), - param("non_zero_padding", 1, padding=1), - param("dilation", 1, dilation=2), - param("groups", 1, groups=3), - ] - ) - def 
test_conv_transpose3d( - self, - _, - kernel_size, - stride=1, - padding=0, - output_padding=0, - groups=1, - bias=True, - dilation=1, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv_transpose = torch.nn.ConvTranspose3d( - in_channels=3, - out_channels=6, - kernel_size=kernel_size, - stride=stride, - padding=padding, - output_padding=output_padding, - groups=groups, - bias=bias, - dilation=dilation, - ) - - def forward(self, x): - return self.conv_transpose(x) - - inputs = [torch.randn(1, 3, 32, 32, 32)] - self.run_test(TestModule(), inputs, expected_ops={acc_ops.conv_transpose3d}) - - def test_conv_transpose3d_with_dynamic_shape(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv_transpose = torch.nn.ConvTranspose3d(3, 6, 1) - - def forward(self, x): - return self.conv_transpose(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, 3, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 3, 1, 1, 1), (1, 3, 4, 4, 4), (8, 3, 32, 32, 32))], - ), - ] - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.conv_transpose3d} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_type_as.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_type_as.py deleted file mode 100644 index 2b6869d0f0..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_type_as.py +++ /dev/null @@ -1,153 +0,0 @@ -import unittest - -import torch -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) -from torch_tensorrt.dynamo.fx_ts_compat.utils import LowerPrecision - - -class TestTypeAsConverter(AccTestCase): - def test_device_fp32(self): - class Type_as(torch.nn.Module): - def __init__(self): - super().__init__() - self.a = torch.randn(2, 2) - - def forward(self, x): - b = self.a.type_as(x) - return b - - # self.a = self.a.type_as(x) # error is throw - # return self.a - - input = torch.randn(2, 2).cuda() - inputs = [ - input, - ] - self.run_test( - Type_as(), - inputs, - expected_ops={acc_ops.to_dtype, acc_ops.device, acc_ops.dtype}, - test_implicit_batch_dim=False, - ) - - def test_device_fp16(self): - class Type_as(torch.nn.Module): - def __init__(self): - super().__init__() - self.a = torch.randn(2, 2) - - def forward(self, x): - return self.a.type_as(x) - - input = torch.randn(2, 2).half().cuda() - inputs = [ - input, - ] - self.run_test( - Type_as(), - inputs, - expected_ops={acc_ops.to_dtype, acc_ops.device, acc_ops.dtype}, - test_implicit_batch_dim=False, - precision=LowerPrecision.FP16, - ) - - def test_device_fp32_tensor(self): - class Type_as(torch.nn.Module): - def forward(self, input, other): - return other.type_as(input) - - input = torch.randn(2, 2).cuda() - other = torch.randn(2, 2) - inputs = [ - input, - other, - ] - self.run_test( - Type_as(), - inputs, - expected_ops={acc_ops.to_dtype, acc_ops.device, acc_ops.dtype}, - ) - - def test_device_fp16_tensor(self): - class Type_as(torch.nn.Module): - def forward(self, input, other): - return other.type_as(input) - - input = torch.randn(2, 2).half().cuda() - other = torch.randn(2, 2) - inputs = [ - input, - other, - ] - self.run_test( - Type_as(), - inputs, - expected_ops={acc_ops.to_dtype, acc_ops.device, acc_ops.dtype}, - 
precision=LowerPrecision.FP16, - ) - - def test_type_tensor(self): - class Type_as(torch.nn.Module): - def forward(self, input): - return input.type(dtype=torch.float16) - - input = torch.randn(2, 2) - - inputs = [ - input, - ] - self.run_test( - Type_as(), - inputs, - expected_ops={acc_ops.to_dtype}, - precision=LowerPrecision.FP16, - ) - - @unittest.skip("Does not pass in TRT 8.4.1 T127981773") - def test_type_tensor_with_dynamic_shape_four_dimensions(self): - class Type_as(torch.nn.Module): - def forward(self, input): - return input.type(dtype=torch.float32) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.int, - shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))], - ), - ] - - self.run_test_with_dynamic_shape( - Type_as(), - input_specs, - expected_ops={acc_ops.to_dtype}, - ) - - def test_type_tensor_ext(self): - class Type_as(torch.nn.Module): - def forward(self, input, other): - t = input.type() - return other.type(t) - - input = torch.randn(2, 2).to(dtype=torch.float16) - other = torch.randn(2, 2) - - inputs = [ - input, - other, - ] - self.run_test( - Type_as(), - inputs, - expected_ops={acc_ops.to_dtype, acc_ops.dtype}, - precision=LowerPrecision.FP16, - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_unary_ops.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_unary_ops.py deleted file mode 100644 index 2015fc21ef..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_unary_ops.py +++ /dev/null @@ -1,168 +0,0 @@ -from typing import Callable - -import torch -import torch.nn as nn - -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - -unary_ops = [ - (torch.sin, acc_ops.sin, False), - (torch.cos, acc_ops.cos, False), - (torch.tan, acc_ops.tan, False), - (torch.sinh, acc_ops.sinh, False), - (torch.cosh, acc_ops.cosh, False), - (torch.asin, acc_ops.asin, True), - (torch.acos, acc_ops.acos, True), - (torch.atan, acc_ops.atan, True), - (torch.abs, acc_ops.abs, False), - (torch.neg, acc_ops.neg, False), - (torch.reciprocal, acc_ops.reciprocal, False), - (torch.sqrt, acc_ops.sqrt, False), - (torch.log, acc_ops.log, False), - (torch.exp, acc_ops.exp, False), - (torch.floor, acc_ops.floor, False), - (torch.ceil, acc_ops.ceil, False), - (torch.sign, acc_ops.sign, False), -] - - -class TestUnaryOpConverters(AccTestCase): - @parameterized.expand([(op[1].__name__, op[0], op[1], op[2]) for op in unary_ops]) - def test_unary_ops( - self, name, orig_op: Callable, expected_op: Callable, range_req: bool - ): - class TestModule(nn.Module): - def __init__(self, orig_op): - super().__init__() - self.orig_op = orig_op - - def forward(self, x): - return self.orig_op(x) - - m = TestModule(orig_op) - inputs = ( - [torch.distributions.uniform.Uniform(-1, 1).sample([2, 2, 3])] - if range_req - else [torch.randn(2, 2, 3)] - ) - self.run_test(m, inputs, expected_ops={expected_op}) - - -class TestUnaryVOpConvertersWithDynamicShapeFourDimensions(AccTestCase): - @parameterized.expand([(op[1].__name__, op[0], op[1]) for op in unary_ops]) - def test_unary_ops(self, name, orig_op: Callable, expected_op): - class TestModule(nn.Module): - def __init__(self, orig_op): - super().__init__() - self.orig_op = orig_op - - def forward(self, x): - 
return self.orig_op(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))], - ), - ] - - self.run_test_with_dynamic_shape( - TestModule(orig_op), input_specs, expected_ops={expected_op} - ) - - -class TestUnaryOpNotConverters(AccTestCase): - @parameterized.expand( - [ - ("not_bool", torch.logical_not, acc_ops.logical_not, torch.bool), - ("not_float", torch.logical_not, acc_ops.logical_not, torch.float), - ("not_int", torch.logical_not, acc_ops.logical_not, torch.int), - ] - ) - def test_unary_ops(self, name, orig_op: Callable, expected_op, input_dtype): - class TestModule(nn.Module): - def __init__(self, orig_op): - super().__init__() - self.orig_op = orig_op - - def forward(self, x): - x = self.orig_op(x) - return self.orig_op(x) - - m = TestModule(orig_op) - inputs = [torch.randn(2, 2, 3).to(input_dtype)] - self.run_test( - m, inputs, expected_ops={expected_op}, test_implicit_batch_dim=False - ) - - -class TestUnaryOpNotConvertersWithDynamicShapeFourDimensions(AccTestCase): - @parameterized.expand( - [ - ("not_bool", torch.logical_not, acc_ops.logical_not, torch.bool), - ("not_float", torch.logical_not, acc_ops.logical_not, torch.float), - ("not_int", torch.logical_not, acc_ops.logical_not, torch.int), - ] - ) - def test_unary_ops(self, name, orig_op: Callable, expected_op, input_dtype): - class TestModule(nn.Module): - def __init__(self, orig_op): - super().__init__() - self.orig_op = orig_op - - def forward(self, x): - x = self.orig_op(x) - return self.orig_op(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))], - ), - ] - - self.run_test_with_dynamic_shape( - TestModule(orig_op), input_specs, expected_ops={expected_op} - ) - - -class TestUnaryRSQRTConverters(AccTestCase): - def test_unary_ops(self): - class TestModule(nn.Module): - def forward(self, x): - return torch.rsqrt(x) - - m = TestModule() - inputs = [torch.randn(2, 2, 3)] - self.run_test(m, inputs, expected_ops={acc_ops.sqrt, acc_ops.reciprocal}) - - -class TestUnaryRSQRTConvertersWithDynamicShapeFourDimensions(AccTestCase): - def test_unary_ops(self): - class TestModule(nn.Module): - def forward(self, x): - return torch.rsqrt(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))], - ), - ] - - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.sqrt, acc_ops.reciprocal} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_unsqueeze.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_unsqueeze.py deleted file mode 100644 index 059374194c..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_unsqueeze.py +++ /dev/null @@ -1,63 +0,0 @@ -import torch -import torch.fx -import torch.nn as nn -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - AccTestCase, - InputTensorSpec, -) - - -class TestUnsqueeze(AccTestCase): - @parameterized.expand( - [ - ("negative_dim", -2), - ("positive_dim", 2), - ] - ) - def test_unsqueeze(self, _, dim): - class Unsqueeze(nn.Module): - def __init__(self, dim): - super().__init__() 
- self.dim = dim - - def forward(self, x): - return torch.unsqueeze(x, self.dim) - - inputs = [torch.randn(1, 2, 3)] - self.run_test(Unsqueeze(dim), inputs, expected_ops={acc_ops.unsqueeze}) - - # Testing with more than one dynamic dims results in following error: - # AssertionError: Currently we don't support unsqueeze with more than one dynamic dims. - - @parameterized.expand( - [ - ("negative_dim_dynamic", -4), - ("positive_dim_dynamic", 1), - ] - ) - def test_unsqueeze_with_dynamic_shape(self, _, dim): - class Unsqueeze(nn.Module): - def __init__(self, dim): - super().__init__() - self.dim = dim - - def forward(self, x): - return torch.unsqueeze(x, self.dim) - - input_specs = [ - InputTensorSpec( - shape=(-1, 2, 3), - dtype=torch.float32, - shape_ranges=[((1, 2, 3), (2, 2, 3), (3, 2, 3))], - ), - ] - self.run_test_with_dynamic_shape( - Unsqueeze(dim), input_specs, expected_ops={acc_ops.unsqueeze} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_where.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_where.py deleted file mode 100644 index 9c846709bf..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/acc_op/test_where.py +++ /dev/null @@ -1,114 +0,0 @@ -import torch -import torch.nn as nn -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase - - -class TestWhere(AccTestCase): - @parameterized.expand( - [ - ("same_shape", (1, 3, 2), (1, 3, 2), (1, 3, 2)), - ("broadcast_shape", (1, 3, 2), (1, 1, 1), (1, 1, 1)), - ("broadcast_shape", (1, 3, 2), (1, 1, 1), (1, 1, 2)), - ] - ) - def test_where(self, _, condition_shape, x_shape, y_shape): - class Where(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, condition, x, y): - return torch.where(condition, x, y) - - inputs = [ - (torch.randn(condition_shape) > 0), - torch.randn(x_shape), - torch.ones(y_shape), - ] - self.run_test( - Where(), - inputs, - expected_ops={acc_ops.where}, - test_implicit_batch_dim=False, - ) - - @parameterized.expand( - [ - ("same_shape", (1, 3, 2), (1, 3, 2), (1, 3, 2)), - ("broadcast_shape", (1, 3, 2), (1, 1, 1), (1, 1, 1)), - ("broadcast_shape", (1, 3, 2), (1, 1, 1), (1, 1, 2)), - ] - ) - def test_where_attribute_condition(self, _, condition_shape, x_shape, y_shape): - class Where(nn.Module): - def __init__(self, condition_shape): - super().__init__() - self.condition = torch.randn(condition_shape) > 0 - - def forward(self, x, y): - return torch.where(self.condition, x, y) - - inputs = [torch.randn(x_shape), torch.ones(y_shape)] - self.run_test( - Where(condition_shape), - inputs, - expected_ops={acc_ops.where}, - test_implicit_batch_dim=False, - ) - - @parameterized.expand( - [ - ("same_shape", (1, 3, 2), (1, 3, 2), (1, 3, 2)), - ("broadcast_shape", (1, 3, 2), (1, 1, 1), (1, 1, 1)), - ("broadcast_shape", (1, 3, 2), (1, 1, 1), (1, 1, 2)), - ] - ) - def test_where_attribute_condition_x(self, _, condition_shape, x_shape, y_shape): - class Where(nn.Module): - def __init__(self, condition_shape, x_shape): - super().__init__() - self.condition = torch.randn(condition_shape) > 0 - self.x = torch.randn(x_shape) - - def forward(self, y): - return torch.where(self.condition, self.x, y) - - inputs = [torch.ones(y_shape)] - self.run_test( - Where(condition_shape, x_shape), - inputs, - 
expected_ops={acc_ops.where}, - test_implicit_batch_dim=False, - ) - - @parameterized.expand( - [ - ("same_shape", (1, 3, 2), (1, 3, 2), (1, 3, 2)), - ("broadcast_shape", (1, 3, 2), (1, 1, 1), (1, 1, 1)), - ("broadcast_shape", (1, 3, 2), (1, 1, 1), (1, 1, 2)), - ] - ) - def test_where_attribute_x_y(self, _, condition_shape, x_shape, y_shape): - class Where(nn.Module): - def __init__(self, x_shape, y_shape): - super().__init__() - - self.x = torch.randn(x_shape) - self.y = torch.ones(y_shape) - - def forward(self, condition): - return torch.where(condition, self.x, self.y) - - inputs = [(torch.randn(condition_shape) > 0)] - self.run_test( - Where(x_shape, y_shape), - inputs, - expected_ops={acc_ops.where}, - test_implicit_batch_dim=False, - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_adaptive_avgpool_aten.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_adaptive_avgpool_aten.py deleted file mode 100644 index b3d8550bb6..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_adaptive_avgpool_aten.py +++ /dev/null @@ -1,130 +0,0 @@ -import torch -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - DispatchTestCase, - InputTensorSpec, -) - - -class TestAdaptiveAvgPoolConverter(DispatchTestCase): - def test_adaptive_avgpool_mean(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.AdaptiveAvgPool2d((1, 1)) - - def forward(self, x): - return self.pool(x) - - inputs = [torch.randn(1, 3, 256, 256)] - self.run_test( - TestModule(), - inputs, - expected_ops={torch.ops.aten.mean.dim}, - ) - - @parameterized.expand( - [ - ((64, 64),), - ((128, 64),), - (64,), - ] - ) - def test_adaptive_avgpool( - self, - output_size, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.AdaptiveAvgPool2d(output_size) - - def forward(self, x): - return self.pool(x) - - inputs = [torch.randn(1, 3, 256, 256)] - self.run_test( - TestModule(), - inputs, - expected_ops={torch.ops.aten._adaptive_avg_pool2d.default}, - ) - - def test_adaptive_avgpool_with_dynamic_shape(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.AdaptiveAvgPool2d((64, 64)) - - def forward(self, x): - return self.pool(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, 256, 256), - dtype=torch.float32, - shape_ranges=[((1, 1, 256, 256), (3, 3, 256, 256), (5, 5, 256, 256))], - ), - ] - self.run_test_with_dynamic_shape( - TestModule(), - input_specs, - expected_ops={torch.ops.aten._adaptive_avg_pool2d.default}, - ) - - @parameterized.expand( - [ - ((16, 16, 16),), - ((32, 16, 4),), - (32,), - ] - ) - def test_adaptive_avgpool3d( - self, - output_size, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.AdaptiveAvgPool3d(output_size) - - def forward(self, x): - return self.pool(x) - - inputs = [torch.randn(1, 3, 32, 64, 64)] - self.run_test( - TestModule(), - inputs, - expected_ops={torch.ops.aten._adaptive_avg_pool3d.default}, - ) - - def test_adaptive_avgpool3d_with_dynamic_shape(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.AdaptiveAvgPool3d((16, 
16, 16)) - - def forward(self, x): - return self.pool(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, 32, 64, 64), - dtype=torch.float32, - shape_ranges=[ - ((1, 1, 32, 64, 64), (3, 3, 32, 64, 64), (5, 5, 32, 64, 64)) - ], - ), - ] - self.run_test_with_dynamic_shape( - TestModule(), - input_specs, - expected_ops={torch.ops.aten._adaptive_avg_pool3d.default}, - ) - - # Testing with shape(-1, -1, -1, -1) results into error: "AdaptiveAvgPool2d and AdaptiveAvgPool3d currently doesn't support dynamic shapes for last two dims." - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_batchnorm_aten.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_batchnorm_aten.py deleted file mode 100644 index 2ca9b7ed82..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_batchnorm_aten.py +++ /dev/null @@ -1,68 +0,0 @@ -import torch -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - DispatchTestCase, - InputTensorSpec, -) - - -class TestBatchNormConverter(DispatchTestCase): - def test_batchnorm(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.bn = torch.nn.BatchNorm2d(3) - - def forward(self, x): - return self.bn(x) - - inputs = [torch.randn(1, 3, 224, 224)] - self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.batch_norm}) - - def test_batchnorm1d_with_dynamic_shape(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.bn = torch.nn.BatchNorm1d(3) - - def forward(self, x): - return self.bn(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, 3, 5), - dtype=torch.float32, - shape_ranges=[((2, 3, 5), (6, 3, 5), (10, 3, 5))], - ), - ] - - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={torch.ops.aten.batch_norm} - ) - - def test_batchnorm_with_dynamic_shape(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.bn = torch.nn.BatchNorm2d(3) - - def forward(self, x): - return self.bn(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, 3, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 3, 1, 1), (1, 3, 5, 5), (2, 3, 10, 10))], - ), - ] - - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={torch.ops.aten.batch_norm} - ) - - # Testing with shape=(-1, -1, -1, -1) results in AssertionError: Channel dim can't be dynamic for batch norm. 
- - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_binary_ops_aten.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_binary_ops_aten.py deleted file mode 100644 index a328b8655c..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_binary_ops_aten.py +++ /dev/null @@ -1,208 +0,0 @@ -from typing import Callable - -import torch -import torch.nn as nn - -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - DispatchTestCase, - InputTensorSpec, -) - -NEED_TEST_BOTH_CONSTANTS_CASE = True - -elementwise_ops = [ - ((lambda x, y: x + y), torch.ops.aten.add.Tensor, NEED_TEST_BOTH_CONSTANTS_CASE), - ( - (lambda x, y: torch.add(x, y)), - torch.ops.aten.add.Tensor, - NEED_TEST_BOTH_CONSTANTS_CASE, - ), - ((lambda x, y: x.add(y)), torch.ops.aten.add.Tensor, NEED_TEST_BOTH_CONSTANTS_CASE), - ((lambda x, y: x - y), torch.ops.aten.sub.Tensor, NEED_TEST_BOTH_CONSTANTS_CASE), - ((lambda x, y: torch.sub(x, y)), torch.ops.aten.sub.Tensor, False), - ((lambda x, y: x.sub(y)), torch.ops.aten.sub.Tensor, False), - ((lambda x, y: x / y), torch.ops.aten.div.Tensor, NEED_TEST_BOTH_CONSTANTS_CASE), - ( - (lambda x, y: x // y), - torch.ops.aten.floor_divide.default, - NEED_TEST_BOTH_CONSTANTS_CASE, - ), - ( - (lambda x, y: torch.div(x, y, rounding_mode="trunc")), - torch.ops.aten.div.Tensor_mode, - not NEED_TEST_BOTH_CONSTANTS_CASE, - ), - ( - (lambda x, y: torch.div(x, y, rounding_mode="floor")), - torch.ops.aten.div.Tensor_mode, - NEED_TEST_BOTH_CONSTANTS_CASE, - ), - ( - (lambda x, y: torch.div(x, y)), - torch.ops.aten.div.Tensor, - NEED_TEST_BOTH_CONSTANTS_CASE, - ), - ( - (lambda x, y: torch.fmod(x, y)), - torch.ops.aten.fmod.Tensor, - not NEED_TEST_BOTH_CONSTANTS_CASE, - ), - ## torch.floor_divide rounds result toward zero, rather than -Inf. - ## https://github.com/pytorch/pytorch/issues/43874 - ( - (lambda x, y: torch.floor_divide(x, y)), - torch.ops.aten.floor_divide.default, - not NEED_TEST_BOTH_CONSTANTS_CASE, - ), - ((lambda x, y: x * y), torch.ops.aten.mul.Tensor, NEED_TEST_BOTH_CONSTANTS_CASE), - (torch.pow, torch.ops.aten.pow.Tensor_Tensor, not NEED_TEST_BOTH_CONSTANTS_CASE), -] - - -class TestBinaryOpConverters(DispatchTestCase): - @parameterized.expand([(op[1].__name__, op[0], op[1]) for op in elementwise_ops]) - def test_elementwise_ops(self, name, orig_op: Callable, expected_op): - class TestModule(nn.Module): - def __init__(self, orig_op): - super().__init__() - self.orig_op = orig_op - - def forward(self, x): - return self.orig_op(x, x) - - m = TestModule(orig_op) - # Avoid dividing by 0. 
- inputs = [torch.rand(1, 1) + 1] - self.run_test(m, inputs, expected_ops={expected_op}) - - @parameterized.expand([(op[1].__name__, op[0], op[1]) for op in elementwise_ops]) - def test_elementwise_ops_with_one_constant( - self, name, orig_op: Callable, expected_op - ): - class TestModule(nn.Module): - def __init__(self, orig_op): - super().__init__() - self.constant = torch.randn(1) - self.orig_op = orig_op - - def forward(self, x): - x = self.orig_op(x, self.constant) - return self.orig_op(x, -2) - - m = TestModule(orig_op) - inputs = [torch.randn(2, 2)] - self.run_test(m, inputs, expected_ops={expected_op}) - - @parameterized.expand( - [(op[1].__name__, op[0], op[1]) for op in elementwise_ops if op[2]] - ) - def test_elementwise_op_with_both_constants( - self, name, orig_op: Callable, expected_op - ): - class TestModule(nn.Module): - def __init__(self, orig_op): - super().__init__() - self.constant0 = torch.nn.Parameter(torch.randn(1)) - self.constant1 = torch.nn.Parameter(torch.randn(1)) - self.orig_op = orig_op - - def forward(self, x): - const = self.orig_op(self.constant0, self.constant1) - return self.orig_op(x, const) - - m = TestModule(orig_op) - inputs = [torch.randn(2, 2)] - self.run_test(m, inputs, expected_ops={expected_op}) - - # Dynamic shape test - @parameterized.expand( - [ - ( - f"no_broadcast_{op[1].__name__}", - (-1, -1), - ((1, 1), (2, 2), (3, 3)), - (-1, -1), - ((1, 1), (2, 2), (3, 3)), - op[0], - op[1], - ) - for op in elementwise_ops - ] - + [ - ( - f"broadcast_{op[1].__name__}", - (-1, -1, -1), - ((1, 1, 1), (2, 2, 2), (3, 3, 3)), - (-1, -1), - ((1, 1), (2, 2), (3, 3)), - op[0], - op[1], - ) - for op in elementwise_ops - ] - ) - def test_elementwise_op_with_dynamic_shape( - self, _, x_shape, x_shape_ranges, y_shape, y_shape_ranges, orig_op, expected_op - ): - class Op(nn.Module): - def forward(self, x, y): - return orig_op(x, y) - - input_specs = [ - InputTensorSpec( - shape=x_shape, - dtype=torch.float32, - shape_ranges=[x_shape_ranges], - ), - InputTensorSpec( - shape=y_shape, - dtype=torch.float32, - shape_ranges=[y_shape_ranges], - ), - ] - self.run_test_with_dynamic_shape(Op(), input_specs, expected_ops={expected_op}) - - @parameterized.expand( - [ - ( - f"no_broadcast_{op[1].__name__}", - op[0], - op[1], - ) - for op in elementwise_ops - ] - + [ - ( - f"broadcast_{op[1].__name__}", - op[0], - op[1], - ) - for op in elementwise_ops - ] - ) - def test_elementwise_op_with_dynamic_shape_four_dimensions( - self, _, orig_op, expected_op - ): - class Op(nn.Module): - def forward(self, x, y): - return orig_op(x, y) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))], - ), - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))], - ), - ] - self.run_test_with_dynamic_shape(Op(), input_specs, expected_ops={expected_op}) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_cat_aten.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_cat_aten.py deleted file mode 100644 index 50190113ad..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_cat_aten.py +++ /dev/null @@ -1,61 +0,0 @@ -import torch -import torch.nn as nn -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import 
( - DispatchTestCase, - InputTensorSpec, -) - - -class TestCatConverter(DispatchTestCase): - @parameterized.expand( - [ - ("pos", 1), - # ("neg", -2), #Dynamo tracer issue - ] - ) - def test_cat(self, _, dim): - class Cat(nn.Module): - def forward(self, x, y, z): - return torch.cat((x, y, z), dim) - - inputs = [torch.randn(1, 2, 3), torch.randn(1, 1, 3), torch.randn(1, 3, 3)] - self.run_test( - Cat(), - inputs, - expected_ops={torch.ops.aten.cat.default}, - ) - - @parameterized.expand( - [ - ("pos", 1), - # ("neg", -2), #Dynamo tracer issue - ] - ) - def test_cat_dynamic_shape(self, _, dim): - class Cat(nn.Module): - def forward(self, x, y): - return torch.cat((x, y), dim) - - input_specs = [ - InputTensorSpec( - shape=(16, -1, 3), - dtype=torch.float32, - shape_ranges=[((16, 2, 3), (16, 3, 3), (16, 32, 3))], - ), - InputTensorSpec( - shape=(16, -1, 3), - dtype=torch.float32, - shape_ranges=[((16, 2, 3), (16, 16, 3), (16, 32, 3))], - ), - ] - self.run_test_with_dynamic_shape( - Cat(), - input_specs, - expected_ops={torch.ops.aten.cat.default}, - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_convolution_aten.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_convolution_aten.py deleted file mode 100644 index 9c4ceaa9bf..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_convolution_aten.py +++ /dev/null @@ -1,206 +0,0 @@ -import torch -from parameterized import param, parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - DispatchTestCase, - InputTensorSpec, -) - - -class TestConvolutionConverter(DispatchTestCase): - @parameterized.expand( - [ - ("default", 1), - param("no_bias", 1, bias=False), - ("tuple_parameters", 1, (1), (1)), - param("non_zero_padding", 1, padding=1), - param("dilation", 1, dilation=2), - param("groups", 1, groups=3), - ] - ) - def test_conv1d( - self, - _, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=True, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv1d( - 3, 6, kernel_size, stride, padding, dilation, groups, bias - ) - - def forward(self, x): - return self.conv(x) - - inputs = [torch.randn(1, 3, 32)] - self.run_test( - TestModule(), - inputs, - expected_ops={torch.ops.aten.convolution.default}, - test_explicit_precision=True, - ) - - def test_conv1d_with_dynamic_shape( - self, - kernel_size=1, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=True, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv1d( - 3, 6, kernel_size, stride, padding, dilation, groups, bias - ) - - def forward(self, x): - return self.conv(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, 3, 3), - dtype=torch.float32, - shape_ranges=[((1, 3, 3), (3, 3, 3), (5, 3, 3))], - ), - ] - - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={torch.ops.aten.convolution.default} - ) - - @parameterized.expand( - [ - ("default", 1), - param("no_bias", 1, bias=False), - ("tuple_parameters", 1, (1, 1), (1, 1)), - param("non_zero_padding", 1, padding=1), - param("dilation", 1, dilation=2), - param("groups", 1, groups=3), - ] - ) - def test_conv2d( - self, - _, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=True, - ): - class TestModule(torch.nn.Module): - def 
__init__(self): - super().__init__() - self.conv = torch.nn.Conv2d( - 3, - 6, - kernel_size, - stride, - padding, - dilation, - groups, - bias, - ) - - def forward(self, x): - return self.conv(x) - - inputs = [torch.randn(1, 3, 32, 32)] - self.run_test( - TestModule(), inputs, expected_ops={torch.ops.aten.convolution.default} - ) - - # Testing with (-1, -1, -1, -1) results into Error: - # AssertionError: Channel dim can't be dynamic for convolution. - - def test_conv2d_with_dynamic_shape(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 6, 1) - - def forward(self, x): - return self.conv(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, 3, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 3, 1, 1), (1, 3, 4, 4), (32, 3, 128, 128))], - ), - ] - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={torch.ops.aten.convolution.default} - ) - - @parameterized.expand( - [ - ("default", 1), - param("no_bias", 1, bias=False), - ("tuple_parameters", 1, (1, 1, 1), (1, 1, 1)), - param("non_zero_padding", 1, padding=1), - param("dilation", 1, dilation=2), - ## TODO TRT 8.4.1 will trigger issue with this test. T127981773 - # param("groups", 1, groups=3), - ] - ) - def test_conv3d( - self, - _, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=True, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv3d( - 3, 6, kernel_size, stride, padding, dilation, groups, bias - ) - - def forward(self, x): - return self.conv(x) - - inputs = [torch.randn(1, 3, 32, 32, 32)] - self.run_test( - TestModule(), inputs, expected_ops={torch.ops.aten.convolution.default} - ) - - # Testing with (-1, -1, -1, -1, -1) results into Error: - # AssertionError: Channel dim can't be dynamic for convolution. 
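# Illustrative sketch (not from the deleted file) of why the channel dim must
# stay static for convolution: the weight tensor fixes in_channels when the
# module is constructed, so only batch and spatial dims can vary at runtime.
import torch

conv = torch.nn.Conv2d(3, 6, kernel_size=1)
# Weight layout is (out_channels, in_channels / groups, kH, kW); in_channels=3
# is baked into the weight, which is why a dynamic channel dim cannot work.
print(conv.weight.shape)  # torch.Size([6, 3, 1, 1])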
- - def test_conv3d_with_dynamic_shape(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv3d(3, 6, 1) - - def forward(self, x): - return self.conv(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, 3, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 3, 1, 1, 1), (1, 3, 4, 4, 4), (8, 3, 32, 32, 32))], - ), - ] - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={torch.ops.aten.convolution.default} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_expand_aten.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_expand_aten.py deleted file mode 100644 index fe8b32692e..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_expand_aten.py +++ /dev/null @@ -1,31 +0,0 @@ -import torch -import torch.nn as nn -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import DispatchTestCase - - -class TestExpandConverter(DispatchTestCase): - @parameterized.expand( - [ - ("2d_dim", (2, 3), (2, 1)), - ("3d_dim", (2, 3, 4), (2, 1, 1)), - ("4d_dim", (2, 3, 4, 5), (2, 1, 1, 1)), - ("keep_dim", (2, 3, -1, -1), (2, 1, 5, 5)), - ] - ) - def test_expand(self, _, sizes, init_size): - class Expand(nn.Module): - def forward(self, x): - return x.expand(*sizes) - - inputs = [torch.randn(*init_size)] - self.run_test( - Expand(), - inputs, - expected_ops={torch.ops.aten.expand.default}, - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_flatten_aten.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_flatten_aten.py deleted file mode 100644 index ca9e8143ce..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_flatten_aten.py +++ /dev/null @@ -1,73 +0,0 @@ -import unittest - -import torch -import torch.nn as nn -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - DispatchTestCase, - InputTensorSpec, -) - - -class TestFlattenConverter(DispatchTestCase): - @parameterized.expand( - [ - ("flatten_middle_dims", 1, 2), - ("flatten_last_3_dims", 1, 3), - ("flatten_all", 0, 3), - ] - ) - @unittest.skip("Not support yet") - def test_flatten(self, _, start_dim, end_dim): - class Flatten(nn.Module): - def __init__(self, start, end): - super().__init__() - self.start = start - self.end = end - - def forward(self, x): - return torch.flatten(x, self.start, self.end) - - inputs = [torch.randn(1, 2, 3, 1)] - self.run_test( - Flatten(start_dim, end_dim), - inputs, - expected_ops={torch.ops.aten.view.default}, - ) - - ## Dynamic shape does not work due to flatten converts to reshape in tracing. And batch or dynamic dimension is converted to fixed integer and loose dynamic - ## For ex., flatten (1, 512, 1, 1) with start_dim=1, end_dim=-1. After convert to reshape, output size=(1, 512) which is not correct since dim=0 is -1. - ## This problem may be solved using dynamic shape propogation. And we will know dim=0 is dynamic and we should set -1 in converter. 
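# Illustrative sketch (not from the deleted file) of the limitation described in
# the comment above: flatten is just a reshape whose target sizes are read from
# the example input, so a tracer that records those sizes freezes the batch dim.
import torch

x = torch.randn(1, 512, 1, 1)
# flatten(x, start_dim=1) == x.reshape(1, 512); the leading 1 comes from the
# example input, so a recorded reshape bakes in batch=1 instead of keeping -1.
assert torch.flatten(x, 1).shape == x.reshape(1, 512).shape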
- - # @parameterized.expand( - # [ - # ("flatten_middle_dims", 1, 2), - # ] - # ) - # def test_flatten_with_dynamic_shape(self, _, start_dim, end_dim): - # class Flatten(nn.Module): - # def __init__(self, start, end): - # super().__init__() - # self.start = start - # self.end = end - - # def forward(self, x): - # return torch.flatten(x, self.start, self.end) - - # input_specs = [ - # InputTensorSpec( - # shape=(-1, -1, -1, -1, -1), - # dtype=torch.float32, - # shape_ranges=[((1, 1, 1, 1, 1), (1, 2, 3, 2, 1), (3, 3, 3, 3, 3))], - # ), - # ] - # self.run_test_with_dynamic_shape( - # Flatten(start_dim, end_dim), - # input_specs, - # expected_ops={torch.ops.aten._reshape_alias.default}, - # ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_linear_aten.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_linear_aten.py deleted file mode 100644 index 8790cdeecc..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_linear_aten.py +++ /dev/null @@ -1,74 +0,0 @@ -import torch -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - DispatchTestCase, - InputTensorSpec, -) - - -class TestLinearConverter(DispatchTestCase): - @parameterized.expand( - [ - ("default", [1, 512], True, torch.ops.aten.linear), - ("matrix", [5, 512], True, torch.ops.aten.linear), - ("no_bias", [1, 512], False, torch.ops.aten.linear), - ( - "multi_dim_matrix", - [4, 5, 512], - True, - torch.ops.aten.linear, - ), - ( - "multi_dim_matrix", - [4, 5, 512], - False, - torch.ops.aten.linear, - ), - ] - ) - def test_linear(self, test_name, shape, bias, op): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(512, 256, bias) - - def forward(self, x): - return self.linear(x) - - inputs = [torch.randn(shape)] - self.run_test(TestModule(), inputs, expected_ops={op}) - - # linear will be decomposed to P531484488 and view(reshape) can not handle reshape pattern - # like (2, 3, n)->(6, n) in implicit mode which is similar to dynamic shape test below. - - # Input is transposed through view [3,3,512]->[9,512]. Converter does not know dim=0 is dynamic now. - - # def test_linear_with_dynamic_shape(self): - # class TestModule(torch.nn.Module): - # def __init__(self): - # super().__init__() - # self.linear = torch.nn.Linear(512, 256) - - # def forward(self, x): - # return self.linear(x) - - # input_specs = [ - # InputTensorSpec( - # shape=(-1, 3, 512), - # dtype=torch.float32, - # shape_ranges=[((1, 3, 512), (3, 3, 512), (4, 3, 512))], - # ), - # ] - # self.run_test_with_dynamic_shape( - # TestModule(), - # input_specs, - # expected_ops={torch.ops.aten.addmm.default}, - # ) - - ## Testing with (-1, -1, 512) results into following error: - ## AssertionError: Currently we only support one dynmaic dim for linear and it can't be the last dim. 
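# Illustrative sketch (not from the deleted file) of the decomposition referenced
# in the comments above, assuming aten.linear lowers to an addmm for 2-D inputs;
# the view/reshape wrinkle the comments describe arises for >2-D inputs.
import torch

x = torch.randn(5, 512)
lin = torch.nn.Linear(512, 256)
# F.linear(x, W, b) is equivalent to addmm(b, x, W^T), which is the pattern a
# converter would see after decomposition.
y_addmm = torch.addmm(lin.bias, x, lin.weight.t())
y_linear = torch.nn.functional.linear(x, lin.weight, lin.bias)
assert torch.allclose(y_linear, y_addmm, atol=1e-5)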
- - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_maxpool_aten.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_maxpool_aten.py deleted file mode 100644 index 3ffd59ed19..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_maxpool_aten.py +++ /dev/null @@ -1,248 +0,0 @@ -import unittest - -import torch -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import param, parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - DispatchTestCase, - InputTensorSpec, -) - - -class TestMaxPoolConverter(DispatchTestCase): - # TODO max_pool1d. It needs support of squeeze and unsqueeze - - @parameterized.expand( - [ - ("default", 1), - ("stride", 1, 2), - ("tuple_parameters", 2, (1, 1), (1, 1)), - param("padding", 2, padding=1), - param("ceil_mode", 1, ceil_mode=True), - ] - ) - def test_max_pool2d( - self, - test_name, - kernel_size, - stride=1, - padding=0, - ceil_mode=False, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.max_pool = torch.nn.MaxPool2d( - kernel_size, stride, padding, ceil_mode=ceil_mode - ) - - def forward(self, x): - return self.max_pool(x) - - inputs = [torch.randn(1, 3, 224, 224)] - self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.max_pool2d}) - - def test_max_pool2d_with_dynamic_shape( - self, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.max_pool = torch.nn.MaxPool2d(1, 1) - - def forward(self, x): - return self.max_pool(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (1, 2, 4, 4), (2, 4, 4, 4))], - ), - ] - self.run_test_with_dynamic_shape( - TestModule(), - input_specs, - expected_ops={torch.ops.aten.max_pool2d}, - ) - - @parameterized.expand( - [ - ("default", 1), - # ("stride", 1, 2), - # ("tuple_parameters", 2, (1, 1, 1), (1, 1, 1)), - # param("padding", 2, padding=1), - # param("ceil_mode", 1, ceil_mode=True), - ] - ) - @unittest.skip("PT2 tracer issue") - def test_max_pool3d( - self, - test_name, - kernel_size, - stride=1, - padding=0, - ceil_mode=False, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.max_pool = torch.nn.MaxPool3d( - kernel_size, stride, padding, ceil_mode=ceil_mode - ) - - def forward(self, x): - return self.max_pool(x) - - inputs = [torch.randn(1, 3, 32, 32, 32)] - self.run_test(TestModule(), inputs, expected_ops={}) - - @unittest.skip("PT2 tracer issue") - def test_max_pool3d_with_dynamic_shape(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.max_pool = torch.nn.MaxPool3d(1, 1) - - def forward(self, x): - return self.max_pool(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1, 1), (1, 2, 4, 4, 4), (2, 4, 4, 4, 4))], - ), - ] - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={torch.ops.aten.max_pool3d} - ) - - @parameterized.expand( - [ - ("default", 1), - # param("stride", 2, stride=()), #PT2 tracer issue - ] - ) - def test_stride_none_max_pool2d( - self, - test_name, - kernel_size, - stride=None, - padding=0, - ceil_mode=False, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - 
return torch.nn.functional.max_pool2d( - x, kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode - ) - - inputs = [torch.randn(1, 3, 224, 224)] - self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.max_pool2d}) - - @parameterized.expand( - [ - ("default", 1), - param("stride", 2, stride=()), - ] - ) - @unittest.skip("PT2 tracer issue") - def test_stride_none_max_pool3d( - self, - test_name, - kernel_size, - stride=None, - padding=0, - ceil_mode=False, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return torch.nn.functional.max_pool3d( - x, kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode - ) - - inputs = [torch.randn(1, 3, 32, 32, 32)] - self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.max_pool3d}) - - @parameterized.expand( - [ - ("default", 1), - param("stride", 2, stride=()), - ] - ) - def test_stride_none_max_pool2d_with_dynamic_shape( - self, - test_name, - kernel_size, - stride=None, - padding=0, - ceil_mode=False, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return torch.nn.functional.max_pool2d( - x, kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode - ) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (1, 2, 4, 4), (2, 4, 4, 4))], - ), - ] - - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={torch.ops.aten.max_pool2d} - ) - - @parameterized.expand( - [ - ("default", 1), - param("stride", 2, stride=()), - ] - ) - @unittest.skip("PT2 tracer issue") - def test_stride_none_max_pool3d_with_dynamic_shape( - self, - test_name, - kernel_size, - stride=None, - padding=0, - ceil_mode=False, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return torch.nn.functional.max_pool3d( - x, kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode - ) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1, 1), (1, 2, 4, 4, 4), (2, 4, 4, 4, 4))], - ), - ] - - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={torch.ops.aten.max_pool3d} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_relu_aten.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_relu_aten.py deleted file mode 100644 index 3367e237fb..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_relu_aten.py +++ /dev/null @@ -1,54 +0,0 @@ -import torch -import torch.nn as nn -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - DispatchTestCase, - InputTensorSpec, -) - - -class TestReLUConverter(DispatchTestCase): - def test_relu(self): - class TestModule(nn.Module): - def forward(self, x): - return nn.functional.relu(x) - - inputs = [torch.randn(1, 10)] - self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.relu.default}) - - def test_relu_with_dynamic_shape(self): - class TestModule(nn.Module): - def forward(self, x): - return nn.functional.relu(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], - ), - ] - self.run_test_with_dynamic_shape( - TestModule(), input_specs, 
expected_ops={torch.ops.aten.relu.default} - ) - - def test_relu_with_dynamic_shape_four_dimensions(self): - class TestModule(nn.Module): - def forward(self, x): - return nn.functional.relu(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], - ), - ] - - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={torch.ops.aten.relu.default} - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_reshape_aten.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_reshape_aten.py deleted file mode 100644 index 0382ad7788..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/aten_op/test_reshape_aten.py +++ /dev/null @@ -1,105 +0,0 @@ -import unittest - -import tensorrt as trt -import torch -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import ( - DispatchTestCase, - InputTensorSpec, -) - - -class TestReshapeConverter(DispatchTestCase): - @parameterized.expand( - [ - ((1, 20),), - ((1, 10, -1),), - ] - ) - @unittest.skipIf( - trt.__version__ < "8.5", - "Shape tensor supported well in TensorRT 8.5 and later", - ) - def test_reshape(self, target_shape): - class TestModule(torch.nn.Module): - def __init__(self, target_shape): - super().__init__() - self.target_shape = target_shape - - def forward(self, x): - return torch.reshape(x, self.target_shape) - - inputs = [torch.randn(1, 2, 10)] - self.run_test( - TestModule(target_shape), - inputs, - expected_ops={torch.ops.aten.view.default}, - ) - - @parameterized.expand( - [ - ((-1, 10),), - ((-1, 5),), - ((2, 2, -1),), - ] - ) - @unittest.skipIf( - trt.__version__ < "8.5", - "Shape tensor supported well in TensorRT 8.5 and later", - ) - def test_reshape_with_dynamic_shape(self, target_shape): - class TestModule(torch.nn.Module): - def __init__(self, target_shape): - super().__init__() - self.target_shape = target_shape - - def forward(self, x): - return torch.reshape(x, self.target_shape) - - input_specs = [ - InputTensorSpec( - shape=(-1, 2, 5), - dtype=torch.float32, - shape_ranges=[((1, 2, 5), (10, 2, 5), (10, 2, 5))], - ), - ] - self.run_test_with_dynamic_shape( - TestModule(target_shape), - input_specs, - expected_ops={torch.ops.aten.view.default}, - ) - - @unittest.skipIf( - trt.__version__ < "8.5", - "Shape tensor supported well in TensorRT 8.5 and later", - ) - def test_reshape_with_dynamic_shape_size(self): - class TestModule(torch.nn.Module): - def forward(self, x, y): - shape_y = y.shape - t = shape_y[1] - return torch.reshape(x, [-1, t, 3]) - - input_specs = [ - InputTensorSpec( - shape=(-1, 5, 6), - dtype=torch.float32, - shape_ranges=[((1, 5, 6), (3, 5, 6), (3, 5, 6))], - ), - InputTensorSpec( - shape=(-1, 5), - dtype=torch.float32, - shape_ranges=[((1, 5), (3, 5), (3, 5))], - ), - ] - - self.run_test_with_dynamic_shape( - TestModule(), - input_specs, - expected_ops={torch.ops.aten.view.default}, - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/vanilla/test_add_vanilla.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/vanilla/test_add_vanilla.py deleted file mode 100644 index 1e6c748cc1..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/vanilla/test_add_vanilla.py +++ /dev/null @@ -1,28 +0,0 @@ -# 
Owner(s): ["oncall: gpu_enablement"] - -import operator - -import torch -import torch.fx -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import VanillaTestCase - - -class TestAddConverter(VanillaTestCase): - def test_operator_add(self): - def add(x): - return x + x - - inputs = [torch.randn(1, 1)] - self.run_test(add, inputs, expected_ops={operator.add}) - - def test_torch_add(self): - def add(x): - return torch.add(x, x) - - inputs = [torch.randn(1, 1)] - self.run_test(add, inputs, expected_ops={torch.add}) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/vanilla/test_convolution_vanilla.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/vanilla/test_convolution_vanilla.py deleted file mode 100644 index 4bd1c7519d..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/converters/vanilla/test_convolution_vanilla.py +++ /dev/null @@ -1,113 +0,0 @@ -# Owner(s): ["oncall: gpu_enablement"] - -import torch -import torch.fx -from parameterized import param, parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import VanillaTestCase - - -class TestConvolutionConverter(VanillaTestCase): - @parameterized.expand( - [ - ("default", 1), - param("no_bias", 1, bias=False), - ("tuple_parameters", 1, (1), (0)), - param("non_zero_padding", 1, padding=1), - param("dilation", 1, dilation=2), - param("groups", 1, groups=3), - ] - ) - def test_conv1d( - self, - test_name, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=True, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv1d( - 3, 6, kernel_size, stride, padding, dilation, groups, bias - ) - - def forward(self, x): - return self.conv(x) - - inputs = [torch.randn(1, 3, 224)] - self.run_test(TestModule(), inputs, expected_ops={torch.nn.modules.conv.Conv1d}) - - @parameterized.expand( - [ - ("default", 1), - param("no_bias", 1, bias=False), - ("tuple_parameters", 1, (1, 1), (0, 0)), - param("non_zero_padding", 1, padding=1), - param("dilation", 1, dilation=2), - param("groups", 1, groups=3), - ] - ) - def test_conv2d( - self, - test_name, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=True, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d( - 3, 6, kernel_size, stride, padding, dilation, groups, bias - ) - - def forward(self, x): - return self.conv(x) - - inputs = [torch.randn(1, 3, 224, 224)] - self.run_test(TestModule(), inputs, expected_ops={torch.nn.modules.conv.Conv2d}) - - @parameterized.expand( - [ - ("default", 1), - param("no_bias", 1, bias=False), - ("tuple_parameters", 1, (1, 1, 1), (0, 0, 0)), - param("non_zero_padding", 1, padding=1), - param("dilation", 1, dilation=2), - # TODO: Enable this when TRT fixes https://github.com/pytorch/TensorRT/issues/1445 - # param("groups", 1, groups=3), - ] - ) - def test_conv3d( - self, - test_name, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=True, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv3d( - 3, 6, kernel_size, stride, padding, dilation, groups, bias - ) - - def forward(self, x): - return self.conv(x) - - inputs = [torch.randn(1, 3, 32, 32, 32)] - self.run_test(TestModule(), inputs, 
expected_ops={torch.nn.modules.conv.Conv3d}) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_fix_clamp_numerical_limits_to_fp16.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_fix_clamp_numerical_limits_to_fp16.py deleted file mode 100644 index 2fbac9cfb4..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_fix_clamp_numerical_limits_to_fp16.py +++ /dev/null @@ -1,74 +0,0 @@ -import logging -import unittest - -import torch -import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer -from torch_tensorrt.fx.passes.lower_basic_pass import ( - fix_clamp_numerical_limits_to_fp16, -) - - -_LOGGER: logging.Logger = logging.getLogger(__name__) - - -def debug_print_graph_module(mod_graph: torch.fx.GraphModule) -> None: - """ - Helper func to print model's graph in plain and tabular format, also print code. - """ - _LOGGER.info(mod_graph.graph) - mod_graph.graph.print_tabular() - _LOGGER.info(mod_graph.code) - - -class ClampNumericalLimitsTest(unittest.TestCase): - def setUp(self): - torch.manual_seed(0) - - def test_clamp_numerical_limits_to_fp16(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - y = torch.clamp(x + x, min=-1e8, max=1e8) - return y - - module = TestModule() - inputs = [torch.rand(3, 2, 1)] - - module.eval() - - # Before Opt - before_results = module(*inputs) - mod_traced = acc_tracer.trace(module, inputs) - before_node_list = list(mod_traced.graph.nodes) - clamp_node_before = [node for node in before_node_list if "clamp" in str(node)] - min_val_before = clamp_node_before[0].kwargs["min"] - max_val_before = clamp_node_before[0].kwargs["max"] - _LOGGER.info("Model before opt.") - debug_print_graph_module(mod_traced) - - # Apply Opt - module_after_pass = fix_clamp_numerical_limits_to_fp16(mod_traced, inputs) - - # After Opt - after_results = module_after_pass(*inputs) - after_node_list = list(mod_traced.graph.nodes) - clamp_node_after = [node for node in after_node_list if "clamp" in str(node)] - min_val_after = clamp_node_after[0].kwargs["min"] - max_val_after = clamp_node_after[0].kwargs["max"] - _LOGGER.info("Model after opt.") - mod_traced.recompile() - debug_print_graph_module(mod_traced) - - # Tests - # * Numerics - tol_args = {"rtol": 1e-2, "atol": 1e-2} - torch.testing.assert_close(before_results, after_results, **tol_args) - - # graph should not change - self.assertTrue(before_node_list == after_node_list) - - # values of clamp node changed - self.assertTrue(min_val_before != min_val_after) - self.assertTrue(max_val_before != max_val_after) diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_fix_reshape_batch_dim.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_fix_reshape_batch_dim.py deleted file mode 100644 index bd04692ad5..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_fix_reshape_batch_dim.py +++ /dev/null @@ -1,51 +0,0 @@ -# Owner(s): ["oncall: gpu_enablement"] - -import logging -from copy import deepcopy - -import torch -import torch.fx as fx -import torch.nn as nn - -from torch.testing._internal.common_utils import run_tests, TestCase -from torch_tensorrt.fx.passes.lower_basic_pass import fix_reshape_batch_dim -from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer - -_LOGGER = logging.getLogger(__name__) - - -class TestFixReshapeBatchDim(TestCase): - def test_fix_reshape_batch_dim(self): - class Repro(nn.Module): - def __init__(self): - 
super().__init__() - - def forward(self, x, y): - return y.view(x.size(0), -1, 3) - - mod = Repro() - modt = fx.symbolic_trace(mod) - inp = [ - torch.rand([10, 60]), - torch.rand([10, 60]), - ] - mod(*inp) - mod_acc_traced = acc_tracer.trace(modt, inp) - mod_fixed = fix_reshape_batch_dim(deepcopy(mod_acc_traced)) - - expected_graph = r""" -graph(): - %x : [#users=0] = placeholder[target=x] - %y : [#users=2] = placeholder[target=y] - %size : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.size](args = (), kwargs = {input: %y}) - %getitem_1 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.getitem](args = (), kwargs = {idx: 0, input: %size}) - %reshape : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.reshape](args = (), kwargs = {input: %y, acc_out_ty: ((%getitem_1, -1, 3), None, None, None, None, None, None)}) - return reshape -""" - assert ( - str(mod_fixed.graph).strip() == expected_graph.strip() - ), f"Unexpected fixed graph. \nActual: {str(mod_fixed.graph)} \nExpected: {expected_graph}" - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_fuse_permute_linear_trt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_fuse_permute_linear_trt.py deleted file mode 100644 index b9b39c8a1b..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_fuse_permute_linear_trt.py +++ /dev/null @@ -1,88 +0,0 @@ -# Owner(s): ["oncall: gpu_enablement"] - -import unittest - -import torch -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.fx.passes.lower_basic_pass import ( - fuse_permute_linear, - trt_transposed_linear, -) -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase - - -class TestFusePermuteLinear(AccTestCase): - def test_fuse_permute_linear(self): - class TestModule(torch.nn.Module): - def __init__(self, in_features, out_features): - super().__init__() - self.linear = torch.nn.Linear(in_features, out_features) - - def forward(self, x): - return self.linear(x.permute(0, 2, 1)) - - inputs = [torch.randn(6, 10, 20)] - a = TestModule(10, 30) - self.run_test( - TestModule(10, 30), - inputs, - {trt_transposed_linear}, - apply_passes=[fuse_permute_linear], - ) - - # def test_fuse_permute_linear_keep_permute(self): - # """ - # Fusion while keep permute node since permute has more than one consumers - # """ - # - # class TestModule(torch.nn.Module): - # def __init__(self, in_features, out_features): - # super().__init__() - # self.linear = torch.nn.Linear(in_features, out_features) - # - # def forward(self, x): - # y = x.permute(0, 2, 1) - # return self.linear(y), y - # - # inputs = [torch.randn(6, 10, 20)] - # a = TestModule(10, 30) - # self.run_test( - # TestModule(10, 30), - # inputs, - # {acc_ops.permute, trt_transposed_linear}, - # apply_passes=[fuse_permute_linear], - # ) - # - # # TODO: The following test has been disabled due to a bug in TRT 8.5.1.7 - # # with self.linear2. 
Issue : https://github.com/pytorch/TensorRT/issues/1444 - # @unittest.skip( - # reason="test_multi_fuse_permute_linear has been disabled due to a bug in TRT 8.5.1.7 https://github.com/pytorch/TensorRT/issues/1444" - # ) - # def test_multi_fuse_permute_linear(self): - # """ - # Fusion when permute output is shared by multiple linears - # """ - # - # class TestModule(torch.nn.Module): - # def __init__(self, in_features, out_features): - # super().__init__() - # self.linear1 = torch.nn.Linear(in_features, out_features) - # self.linear2 = torch.nn.Linear(in_features, out_features) - # - # def forward(self, x): - # y = x.permute(0, 2, 1) - # return self.linear1(y) + self.linear2(y) - # - # inputs = [torch.randn(8, 10, 20)] - # a = TestModule(10, 30) - # self.run_test( - # TestModule(10, 30), - # inputs, - # {trt_transposed_linear}, - # apply_passes=[fuse_permute_linear], - # ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_fuse_permute_matmul_trt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_fuse_permute_matmul_trt.py deleted file mode 100644 index 6570f6c276..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_fuse_permute_matmul_trt.py +++ /dev/null @@ -1,142 +0,0 @@ -# Owner(s): ["oncall: gpu_enablement"] - -import torch -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import param, parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.fx.passes.lower_basic_pass import ( - fuse_permute_matmul, - trt_transposed_matmul, -) -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase - - -def tranpose_last_two_dims(x): - return x.transpose(-1, -2) - - -def permute021(x): - return x.permute(0, 2, 1) - - -class TestFusePermuteMatmul(AccTestCase): - @parameterized.expand( - [ - ("transpose_lhs_bmm", (3, 3, 2), (3, 3, 4), tranpose_last_two_dims), - param( - "transpose_rhs_bmm", (3, 2, 3), (3, 4, 3), rhs_op=tranpose_last_two_dims - ), - ("permute_lhs_bmm", (3, 3, 2), (3, 3, 4), permute021), - param("permute_rhs_bmm", (3, 2, 3), (3, 4, 3), rhs_op=permute021), - ("permute_both_bmm", (3, 3, 2), (3, 4, 3), permute021, permute021), - ( - "permute_both_matmul", - (3, 2, 3, 2), - (3, 2, 4, 3), - lambda x: x.permute(0, 1, 3, 2), - lambda x: x.permute(0, 1, 3, 2), - torch.matmul, - ), - param( - "transpose_lhs_bmm_broadcast", - (3, 2), - (3, 3, 4), - tranpose_last_two_dims, - op=torch.matmul, - ), - param( - "transpose_rhs_bmm_broadcast", - (3, 3, 4), - (3, 4), - rhs_op=tranpose_last_two_dims, - op=torch.matmul, - ), - ] - ) - def test_fuse_permute_matmul( - self, - _, - lhs_shape, - rhs_shape, - lhs_op=lambda x: x, - rhs_op=lambda x: x, - op=torch.bmm, - ): - class TestModule(torch.nn.Module): - def forward(self, x, y): - return op(lhs_op(x), rhs_op(y)) - - inputs = [torch.randn(*lhs_shape), torch.randn(*rhs_shape)] - self.run_test( - TestModule(), - inputs, - {trt_transposed_matmul}, - apply_passes=[fuse_permute_matmul], - test_implicit_batch_dim=(len(lhs_shape) == len(rhs_shape)), - ) - - @parameterized.expand( - [ - ("permute_both_bmm", (3, 3, 2), (3, 4, 3), permute021, permute021), - ] - ) - def test_fuse_permute_matmul_keep_permute( - self, - _, - lhs_shape, - rhs_shape, - lhs_op=lambda x: x, - rhs_op=lambda x: x, - op=torch.bmm, - ): - """ - Fusion permute while keep permute node which has more than one consumers - """ - - class TestModule(torch.nn.Module): - def forward(self, x, y): - z = lhs_op(x) - return 
op(z, rhs_op(y)), z - - inputs = [torch.randn(*lhs_shape), torch.randn(*rhs_shape)] - self.run_test( - TestModule(), - inputs, - {trt_transposed_matmul, acc_ops.permute}, - apply_passes=[fuse_permute_matmul], - ) - - @parameterized.expand( - [ - ("permute_both_bmm", (3, 3, 2), (3, 4, 3), (3, 4, 3)), - ] - ) - def test_multifuse_permute_matmul( - self, - _, - x_shape, - y_shape, - z_shape, - ): - """ - Test cases when we have multiple bmm users of one permute - """ - - class TestModule(torch.nn.Module): - def forward(self, x, y, z): - x = permute021(x) - y = permute021(y) - z = permute021(z) - return torch.bmm(x, y) + torch.bmm(x, z) - - inputs = [torch.randn(*x_shape), torch.randn(*y_shape), torch.randn(*z_shape)] - self.run_test( - TestModule(), - inputs, - {trt_transposed_matmul}, - apply_passes=[fuse_permute_matmul], - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_graph_opts.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_graph_opts.py deleted file mode 100644 index c91c456eb3..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_graph_opts.py +++ /dev/null @@ -1,187 +0,0 @@ -import logging -import unittest -from collections import Counter -from typing import Callable, Dict, List - -import torch -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer -from torch_tensorrt.fx.passes.graph_opts import common_subexpression_elimination - - -_LOGGER: logging.Logger = logging.getLogger(__name__) - - -def debug_print_graph_module(mod_graph: torch.fx.GraphModule) -> None: - """ - Helper func to print model's graph in plain and tabular format, also print code. - """ - _LOGGER.info(mod_graph.graph) - mod_graph.graph.print_tabular() - _LOGGER.info(mod_graph.code) - - -@torch.fx.wrap -def _test_op(keys, value): - return value - - -class GraphOptsTest(unittest.TestCase): - def setUp(self): - torch.manual_seed(0) - - def _test_opt_with_module( - self, - module: torch.nn.Module, - inputs: List, - opt: Callable, - should_change_graph: bool, - deleted_ops: Dict = None, - created_ops: Dict = None, - rtol: float = None, - atol: float = None, - ): - assert should_change_graph or not bool(deleted_ops or created_ops) - deleted_ops = deleted_ops or {} - created_ops = created_ops or {} - module.eval() - - # Before Opt - before_results = module(*inputs) - mod_traced = acc_tracer.trace(module, inputs) - before_node_list = list(mod_traced.graph.nodes) - _LOGGER.info("Model before opt.") - debug_print_graph_module(mod_traced) - - # Apply Opt - graph_changed = bool(opt(mod_traced)) - - # After Opt - after_results = mod_traced(*inputs) - after_node_list = list(mod_traced.graph.nodes) - _LOGGER.info("Model after opt.") - mod_traced.recompile() - debug_print_graph_module(mod_traced) - - # Tests - # * Numerics - tol_args = {} - if rtol is not None: - tol_args["rtol"] = rtol - if atol is not None: - tol_args["atol"] = atol - torch.testing.assert_close(before_results, after_results, **tol_args) - - # * opt changes graph - self.assertEqual(graph_changed, before_node_list != after_node_list) - self.assertEqual(should_change_graph, graph_changed) - - # * modified nodes - before_node_set = set(before_node_list) - after_node_set = set(after_node_list) - self.assertEqual( - dict(Counter([node.target for node in before_node_set - after_node_set])), - deleted_ops, - ) - self.assertEqual( - dict(Counter([node.target for node in after_node_set - 
before_node_set])), - created_ops, - ) - - return mod_traced - - def test_common_subexpression_elimination(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - xx = x + x - xx2 = x + x - return xx * xx2 - x - - self._test_opt_with_module( - module=TestModule(), - inputs=[torch.rand(3, 2, 1)], - opt=common_subexpression_elimination, - should_change_graph=True, - deleted_ops={acc_ops.add: 1}, - ) - - def test_common_subexpression_elimination2(self): - class TestModule2(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return x + x - - self._test_opt_with_module( - module=TestModule2(), - inputs=[torch.rand(3, 2, 1)], - opt=common_subexpression_elimination, - should_change_graph=False, - ) - - def test_common_subexpression_elimination3(self): - class TestModule3(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, a, b, c): - x = a * b - y = b - c - z = a * b - xy = x + y - zy = z + y - return xy - zy - - self._test_opt_with_module( - module=TestModule3(), - inputs=[ - torch.rand(3, 2, 1), - torch.rand(3, 2, 1), - torch.rand(3, 2, 1), - ], - opt=common_subexpression_elimination, - should_change_graph=True, - deleted_ops={acc_ops.add: 1, acc_ops.mul: 1}, - ) - - def test_common_subexpression_elimination4(self): - class TestModule3(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, a, b, c): - x = torch.cat([a, b, c]) - y = torch.cat([a, b, c]) - z = torch.cat([c, b, a]) - return x + y + z - - self._test_opt_with_module( - module=TestModule3(), - inputs=[ - torch.rand(3, 2, 1), - torch.rand(3, 2, 1), - torch.rand(3, 2, 1), - ], - opt=common_subexpression_elimination, - should_change_graph=True, - deleted_ops={acc_ops.cat: 1}, - ) - - def test_common_subexpression_elimination_string_arg(self): - class TestModule(torch.nn.Module): - def forward(self, a): - x = _test_op(["foo", "bar"], a) - return x - - self._test_opt_with_module( - module=TestModule(), - inputs=[ - torch.rand(3, 2, 1), - ], - opt=common_subexpression_elimination, - should_change_graph=False, - ) diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_multi_fuse_trt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_multi_fuse_trt.py deleted file mode 100644 index 9712ca3b91..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_multi_fuse_trt.py +++ /dev/null @@ -1,66 +0,0 @@ -# Owner(s): ["oncall: gpu_enablement"] - -import torch -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.fx.passes.lower_basic_pass import ( - fuse_permute_linear, - fuse_permute_matmul, - trt_transposed_linear, - trt_transposed_matmul, -) -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase - - -def permute021(x): - return x.permute(0, 2, 1) - - -class TestMultiFuse(AccTestCase): - @parameterized.expand( - [ - ("permute_both_bmm", (3, 3, 2), (3, 4, 3), permute021, permute021), - ] - ) - def test_fuse_permute_matmul( - self, - _, - lhs_shape, - rhs_shape, - lhs_op=lambda x: x, - rhs_op=lambda x: x, - op=torch.bmm, - ): - """ - Module: permute1 with linear and matmul, permute2 with matmul. - Permute1 permute2 - | | | - linear matmul - Fusion should crete pass fuse_permute_matmul and fuse_permute_linear, and eliminate both - permute node. 
- """ - - class TestModule(torch.nn.Module): - def __init__(self, in_features, out_features): - super().__init__() - self.linear = torch.nn.Linear(in_features, out_features) - - def forward(self, x, y): - z = lhs_op(x) - bmm = op(z, rhs_op(y)) - linear = self.linear(z) - return (bmm, linear) - - inputs = [torch.randn(*lhs_shape), torch.randn(*rhs_shape)] - self.run_test( - TestModule(3, 6), - inputs, - {trt_transposed_matmul, trt_transposed_linear}, - {acc_ops.permute}, - apply_passes=[fuse_permute_matmul, fuse_permute_linear], - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_remove_duplicate_output_args.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_remove_duplicate_output_args.py deleted file mode 100644 index 1bb76c6691..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_remove_duplicate_output_args.py +++ /dev/null @@ -1,73 +0,0 @@ -# Owner(s): ["oncall: gpu_enablement"] - -import logging - -import torch.fx as fx -import torch.nn as nn - -import torch_tensorrt.fx.passes.remove_duplicate_output_args as dedup -from torch.testing._internal.common_utils import run_tests, TestCase - -_LOGGER = logging.getLogger(__name__) - - -class TestFx2TrtPasses(TestCase): - def test_remove_duplicate_output_args(self): - class Sub(nn.Module): - def forward(self, x): - return (x, x) - - class Top(nn.Module): - def __init__(self): - super().__init__() - self.a = Sub() - - def forward(self, x): - a_res = self.a(x) - return a_res[0] + a_res[1] - - class Tracer(fx.Tracer): - def is_leaf_module(self, m, qn): - if isinstance(m, Sub): # don't trace into - return True - return False - - top = Top() - ttop = fx.GraphModule(top, Tracer().trace(top), "top") - ttop.a = fx.symbolic_trace(ttop.a) - - name_to_processed_subnet = dedup.remove_duplicate_output_args(ttop, ["a"]) - - ttop(1) # run inference should work - - processed_a = name_to_processed_subnet["a"] - *_, a_output = processed_a.module.graph.nodes - a_output: fx.Node - - ttop_graph_actual = str(ttop.graph).strip() - ttop_graph_expected = """ -graph(): - %x : [#users=1] = placeholder[target=x] - %a : [#users=2] = call_module[target=a](args = (%x,), kwargs = {}) - %getitem : [#users=1] = call_function[target=operator.getitem](args = (%a, 0), kwargs = {}) - %getitem_1 : [#users=1] = call_function[target=operator.getitem](args = (%a, 0), kwargs = {}) - %add : [#users=1] = call_function[target=operator.add](args = (%getitem, %getitem_1), kwargs = {}) - return add -""".strip() - assert ( - ttop_graph_expected == ttop_graph_actual - ), f"Unexpected ttop graph: {ttop_graph_actual}" - - ttop_a_graph_actual = str(ttop.a.graph).strip() - ttop_a_graph_expected = """ -graph(): - %x : [#users=1] = placeholder[target=x] - return (x,) -""".strip() - assert ( - ttop_a_graph_expected == ttop_a_graph_actual - ), f"Unexpected ttop.a graph: {ttop_a_graph_actual}" - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_setitem_trt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_setitem_trt.py deleted file mode 100644 index 777796a083..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/passes/test_setitem_trt.py +++ /dev/null @@ -1,600 +0,0 @@ -import torch -import torch._dynamo as torchdynamo -from parameterized import parameterized -from torch._dynamo.optimizations import backends -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.fx.passes.lower_basic_pass import 
transform_setitem -from torch_tensorrt.dynamo.fx_ts_compat.tools.common_fx2trt import AccTestCase - - -class TestTransformSetitem(AccTestCase): - def test_setitem1d(self): - class TestModule(torch.nn.Module): - def forward(self, x, y): - y[0:2] = x - return y - - inputs = [torch.randn(2), torch.randn(3)] - m = TestModule() - - inputs = [i.cuda() for i in inputs] - m.cuda() - - def transform_fx(gm, example_inputs): - gm = transform_setitem(gm, example_inputs) - return gm - - optimize_mod = torchdynamo.optimize( - transform_fx, - nopython=True, - )(m) - - optimize_mod(*inputs) - - def test_setitem1d_c2(self): - class TestModule(torch.nn.Module): - def forward(self, x, y): - y[:-1] = x - y[1:] = x - return y - - inputs = [torch.randn(2), torch.randn(3)] - m = TestModule() - - inputs = [i.cuda() for i in inputs] - m.cuda() - - def transform_fx(gm, example_inputs): - gm = transform_setitem(gm, example_inputs) - return gm - - optimize_mod = torchdynamo.optimize( - transform_fx, - nopython=True, - )(m) - - optimize_mod(*inputs) - - def test_setitem1d_c3(self): - class TestModule(torch.nn.Module): - def forward(self, x, y): - y[1] = x - return y - - inputs = [torch.randn(2), torch.randn(3)] - m = TestModule() - - inputs = [i.cuda() for i in inputs] - m.cuda() - - def transform_fx(gm, example_inputs): - gm = transform_setitem(gm, example_inputs) - return gm - - optimize_mod = torchdynamo.optimize( - transform_fx, - nopython=True, - )(m) - - optimize_mod(*inputs) - - @parameterized.expand( - [ - ("c1", (4, 2), (4, 5), 0, 2), - ("c2", (4, 2), (4, 5), 1, 3), - ] - ) - def test_setitem2d_1v(self, name, x_shape, y_shape, y_start, y_end): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - y[:, y_start:y_end] = x - return y - - inputs = [torch.randn(x_shape), torch.randn(y_shape)] - m = TestModule() - - inputs = [i.cuda() for i in inputs] - m.cuda() - - def transform_fx(gm, example_inputs): - gm = transform_setitem(gm, example_inputs) - return gm - - optimize_mod = torchdynamo.optimize( - transform_fx, - nopython=True, - )(m) - - optimize_mod(*inputs) - - @parameterized.expand( - [ - ("c1", (4, 2), (8, 2), 0, 2), - ("c2", (4, 2), (8, 2), 1, 3), - ] - ) - def test_setitem2d_1v_ex(self, name, x_shape, y_shape, y_start, y_end): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - y[y_start:y_end, :] = x - return y - - inputs = [torch.randn(x_shape), torch.randn(y_shape)] - m = TestModule() - - inputs = [i.cuda() for i in inputs] - m.cuda() - - def transform_fx(gm, example_inputs): - gm = transform_setitem(gm, example_inputs) - return gm - - optimize_mod = torchdynamo.optimize( - transform_fx, - nopython=True, - )(m) - - optimize_mod(*inputs) - - @parameterized.expand( - [ - ("c1", (4, 2), (4, 2), 0, 1), - ] - ) - def test_setitem2d_1v_ex2(self, name, x_shape, y_shape, y_start, y_end): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - y[:, y_start:y_end] = x[:, 0] - return y - - inputs = [torch.randn(x_shape), torch.randn(y_shape)] - m = TestModule() - - inputs = [i.cuda() for i in inputs] - m.cuda() - - def transform_fx(gm, example_inputs): - gm = transform_setitem(gm, example_inputs) - return gm - - optimize_mod = torchdynamo.optimize( - transform_fx, - nopython=True, - )(m) - - optimize_mod(*inputs) - - @parameterized.expand( - [ - ("c1", (3, 2), (4, 5), 0, 3, 0, 2), - ("c2", (3, 2), (4, 5), 1, 4, 1, 3), - ] - ) - def 
test_setitem2d_2v(self, name, x_shape, y_shape, x_start, x_end, y_start, y_end): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - y[x_start:x_end, y_start:y_end] = x - y = y + 3 - return y - - inputs = [torch.randn(x_shape), torch.randn(y_shape)] - m = TestModule() - - inputs = [i.cuda() for i in inputs] - m.cuda() - - def transform_fx(gm, example_inputs): - gm = transform_setitem(gm, example_inputs) - return gm - - optimize_mod = torchdynamo.optimize( - transform_fx, - nopython=True, - )(m) - - optimize_mod(*inputs) - - @parameterized.expand( - [ - ("c1", (2, 3, 4), (2, 5, 6), 0, 3, 0, 4), - ("c2", (2, 3, 4), (2, 5, 6), 1, 4, 1, 5), - ] - ) - def test_setitem3d_2v(self, name, x_shape, y_shape, start_1, end_1, start_2, end_2): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - y[:, start_1:end_1, start_2:end_2] = x - y = y + 3 - return y - - inputs = [torch.randn(x_shape), torch.randn(y_shape)] - m = TestModule() - - inputs = [i.cuda() for i in inputs] - m.cuda() - - def transform_fx(gm, example_inputs): - gm = transform_setitem(gm, example_inputs) - return gm - - optimize_mod = torchdynamo.optimize( - transform_fx, - nopython=True, - )(m) - - optimize_mod(*inputs) - - @parameterized.expand( - [ - ("c1", (3, 2, 4), (5, 2, 6), 0, 3, 0, 4), - ("c2", (3, 2, 4), (5, 2, 6), 1, 4, 1, 5), - ] - ) - def test_setitem3d_2v_ext( - self, name, x_shape, y_shape, start_0, end_0, start_2, end_2 - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - y[start_0:end_0, :, start_2:end_2] = x - y = y + 3 - return y - - inputs = [torch.randn(x_shape), torch.randn(y_shape)] - m = TestModule() - - inputs = [i.cuda() for i in inputs] - m.cuda() - - def transform_fx(gm, example_inputs): - gm = transform_setitem(gm, example_inputs) - return gm - - optimize_mod = torchdynamo.optimize( - transform_fx, - nopython=True, - )(m) - - optimize_mod(*inputs) - - @parameterized.expand( - [ - ("c1", (2, 3, 4), (4, 5, 6), 0, 2, 0, 3, 0, 4), - ("c2", (2, 3, 4), (4, 5, 6), 1, 3, 1, 4, 1, 5), - ] - ) - def test_setitem3d_3v( - self, name, x_shape, y_shape, start_0, end_0, start_1, end_1, start_2, end_2 - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - y[start_0:end_0, start_1:end_1, start_2:end_2] = x - y = y + 3 - x = y[start_0:end_0, start_1:end_1, start_2:end_2] - return x - - inputs = [torch.randn(x_shape), torch.randn(y_shape)] - m = TestModule() - - inputs = [i.cuda() for i in inputs] - m.cuda() - - def transform_fx(gm, example_inputs): - gm = transform_setitem(gm, example_inputs) - return gm - - optimize_mod = torchdynamo.optimize( - transform_fx, - nopython=True, - )(m) - - optimize_mod(*inputs) - - @parameterized.expand( - [ - ("c1", (2, 3, 4, 5), (2, 3, 6, 7), 0, 4, 0, 5), - ("c2", (2, 3, 4, 5), (2, 3, 6, 7), 1, 5, 1, 6), - ] - ) - def test_setitem4d_2v(self, name, x_shape, y_shape, start_2, end_2, start_3, end_3): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - y[:, :, start_2:end_2, start_3:end_3] = x - y = y + 3 - return y - - inputs = [torch.randn(x_shape), torch.randn(y_shape)] - m = TestModule() - - inputs = [i.cuda() for i in inputs] - m.cuda() - - def transform_fx(gm, example_inputs): - gm = transform_setitem(gm, example_inputs) - return gm - - optimize_mod = torchdynamo.optimize( - transform_fx, - 
nopython=True, - )(m) - - optimize_mod(*inputs) - - @parameterized.expand( - [ - ("c1", (2, 3, 4, 5), (2, 5, 4, 7), 0, 3, 0, 5), - ("c2", (2, 3, 4, 5), (2, 5, 4, 7), 1, 4, 1, 6), - ] - ) - def test_setitem4d_2v_ext( - self, name, x_shape, y_shape, start_1, end_1, start_3, end_3 - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - y[:, start_1:end_1, :, start_3:end_3] = x - y = y + 3 - return y - - inputs = [torch.randn(x_shape), torch.randn(y_shape)] - m = TestModule() - - inputs = [i.cuda() for i in inputs] - m.cuda() - - def transform_fx(gm, example_inputs): - gm = transform_setitem(gm, example_inputs) - return gm - - optimize_mod = torchdynamo.optimize( - transform_fx, - nopython=True, - )(m) - - optimize_mod(*inputs) - - @parameterized.expand( - [ - ("c1", (2, 3, 4, 5), (2, 5, 6, 7), 0, 3, 0, 4, 0, 5), - ("c2", (2, 3, 4, 5), (2, 5, 6, 7), 1, 4, 1, 5, 1, 6), - ] - ) - def test_setitem4d_3v( - self, name, x_shape, y_shape, start_1, end_1, start_2, end_2, start_3, end_3 - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - y[:, start_1:end_1, start_2:end_2, start_3:end_3] = x - y = y + 3 - return y - - inputs = [torch.randn(x_shape), torch.randn(y_shape)] - m = TestModule() - - inputs = [i.cuda() for i in inputs] - m.cuda() - - def transform_fx(gm, example_inputs): - gm = transform_setitem(gm, example_inputs) - return gm - - optimize_mod = torchdynamo.optimize( - transform_fx, - nopython=True, - )(m) - - optimize_mod(*inputs) - - @parameterized.expand( - [ - ("c1", (2, 3, 4, 5), (4, 5, 6, 7), 0, 2, 0, 3, 0, 4, 0, 5), - ("c2", (2, 3, 4, 5), (4, 5, 6, 7), 1, 3, 1, 4, 1, 5, 1, 6), - ] - ) - def test_setitem4d_4v( - self, - name, - x_shape, - y_shape, - start_0, - end_0, - start_1, - end_1, - start_2, - end_2, - start_3, - end_3, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - y[start_0:end_0, start_1:end_1, start_2:end_2, start_3:end_3] = x - y = y + 3 - x = y[start_0:end_0, start_1:end_1, start_2:end_2, start_3:end_3] - return x - - inputs = [torch.randn(x_shape), torch.randn(y_shape)] - m = TestModule() - - inputs = [i.cuda() for i in inputs] - m.cuda() - - def transform_fx(gm, example_inputs): - gm = transform_setitem(gm, example_inputs) - return gm - - optimize_mod = torchdynamo.optimize( - transform_fx, - nopython=True, - )(m) - - optimize_mod(*inputs) - - @parameterized.expand( - [ - ("c1", (2, 3, 4, 5, 6), (4, 5, 6, 7, 6), 0, 2, 0, 3, 0, 4, 0, 5), - ] - ) - def test_setitem5d_warning( - self, - name, - x_shape, - y_shape, - start_0, - end_0, - start_1, - end_1, - start_2, - end_2, - start_3, - end_3, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - y[start_0:end_0, start_1:end_1, start_2:end_2, start_3:end_3, :] = x - y = y + 3 - x = y[start_0:end_0, start_1:end_1, start_2:end_2, start_3:end_3] - return x - - inputs = [torch.randn(x_shape), torch.randn(y_shape)] - m = TestModule() - - inputs = [i.cuda() for i in inputs] - m.cuda() - - def transform_fx(gm, example_inputs): - gm = transform_setitem(gm, example_inputs) - return gm - - optimize_mod = torchdynamo.optimize( - transform_fx, - nopython=True, - )(m) - - optimize_mod(*inputs) - - # test with torchdynamo - def test_setitem1d_trt(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - y[1] = x - return y - - 
inputs = [torch.randn(1), torch.randn(3)] - m = TestModule() - - inputs = [i.cuda() for i in inputs] - m.cuda() - ref_output = m(*inputs) - - optimize_mod = torchdynamo.optimize(backends.fx2trt_compiler, nopython=True)(m) - - output = optimize_mod(*inputs) - self.assertTrue(torch.allclose(ref_output, output)) - - @parameterized.expand( - [ - ("c1", (4, 2), (4, 5), 0, 2), - ("c2", (4, 2), (4, 5), 1, 3), - ] - ) - def test_setitem2d_1v_trt(self, name, x_shape, y_shape, y_start, y_end): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - y[:, y_start:y_end] = x - return y - - inputs = [torch.randn(x_shape), torch.randn(y_shape)] - m = TestModule() - - inputs = [i.cuda() for i in inputs] - m.cuda() - - ref_output = m(*inputs) - optimize_mod = torchdynamo.optimize(backends.fx2trt_compiler, nopython=True)(m) - output = optimize_mod(*inputs) - self.assertTrue(torch.allclose(ref_output, output)) - - @parameterized.expand( - [ - ("c1", (2, 3, 4, 5), (4, 5, 6, 7), 0, 2, 0, 3, 0, 4, 0, 5), - ("c2", (2, 3, 4, 5), (4, 5, 6, 7), 1, 3, 1, 4, 1, 5, 1, 6), - ] - ) - def test_setitem4d_4v_trt( - self, - name, - x_shape, - y_shape, - start_0, - end_0, - start_1, - end_1, - start_2, - end_2, - start_3, - end_3, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - y[start_0:end_0, start_1:end_1, start_2:end_2, start_3:end_3] = x - y = y + 3 - x = y[start_0:end_0, start_1:end_1, start_2:end_2, start_3:end_3] - return x - - inputs = [torch.randn(x_shape), torch.randn(y_shape)] - m = TestModule() - - inputs = [i.cuda() for i in inputs] - m.cuda() - - ref_output = m(*inputs) - optimize_mod = torchdynamo.optimize(backends.fx2trt_compiler, nopython=True)(m) - output = optimize_mod(*inputs) - self.assertTrue(torch.allclose(ref_output, output)) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/quant/test_quant_trt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/quant/test_quant_trt.py deleted file mode 100644 index fabd94e24c..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/quant/test_quant_trt.py +++ /dev/null @@ -1,908 +0,0 @@ -# Owner(s): ["oncall: quantization"] - -import copy -import itertools -import operator -import unittest - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.nn.quantized._reference as nnqr - -import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer -from torch.ao.quantization import default_qconfig -from torch.ao.quantization.backend_config import ( - get_tensorrt_backend_config_dict, - ObservationType, -) -from torch.ao.quantization.fx.match_utils import MatchAllNode -from torch.ao.quantization.quantize_fx import ( - convert_to_reference_fx, - prepare_fx, - prepare_qat_fx, -) -from torch.testing._internal.common_cuda import TEST_CUDA -from torch.testing._internal.common_quantization import ( - NodeSpec as ns, - QuantizationTestCase, -) -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.fx_ts_compat import InputTensorSpec, TRTInterpreter -from torch_tensorrt.fx import TRTModule -from torch_tensorrt.fx.passes.lower_basic_pass import run_const_fold -from torch_tensorrt.fx.tracer.acc_tracer import acc_ops -from torch_tensorrt.dynamo.fx_ts_compat.utils import LowerPrecision - - -def lower_to_trt(model, inputs, shape_ranges): - """Lower a quantized model to TensorRT""" - assert len(inputs) == 1, "lower_to_trt only works for one input 
currently" - model = acc_tracer.trace(model, inputs) # type: ignore[attr-defined] - # TODO: test multiple inputs setting and enable multiple inputs - input_specs = [ - InputTensorSpec( - torch.Size([-1, *inputs[0].shape[1:]]), - torch.float, - shape_ranges=shape_ranges, - has_batch_dim=True, - ) - ] - - interp = TRTInterpreter( - model, input_specs, explicit_batch_dimension=True, explicit_precision=True - ) - result = interp.run(lower_precision=LowerPrecision.INT8) - trt_mod = TRTModule(result.engine, result.input_names, result.output_names) - return trt_mod - - -class TestConvertFxDoNotUse(QuantizationTestCase): - def setUp(self): - super().setUp() - self.trt_qconfig = torch.ao.quantization.QConfig( - activation=torch.ao.quantization.observer.HistogramObserver.with_args( - qscheme=torch.per_tensor_symmetric, dtype=torch.qint8 - ), - weight=torch.ao.quantization.default_weight_observer, - ) - self.trt_backend_config_dict = get_tensorrt_backend_config_dict() - - def _test_quantized_inputs_outputs( - self, prepare_custom_config_dict, prepare_count_check, convert_count_check - ): - """ - Test the option to have inputs and outputs of the graph quantized - """ - - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv1 = torch.nn.Conv2d(1, 1, 1) - self.conv2 = torch.nn.Conv2d(1, 1, 1) - - def forward(self, x): - x = self.conv1(x) - x = self.conv2(x) - return x - - # quantized input, quantized output - m = M() - m.eval() - qconfig_dict = {"": torch.ao.quantization.default_qconfig} - example_inputs = (torch.rand(1, 1, 3, 3),) - mp = torch.ao.quantization.quantize_fx.prepare_fx( - m, - qconfig_dict, - example_inputs, - prepare_custom_config=prepare_custom_config_dict, - ) - self.checkGraphModuleNodes(mp, expected_node_occurrence=prepare_count_check) - mp(torch.randn(1, 1, 4, 4)) - mq = convert_to_reference_fx(mp, backend_config=self.trt_backend_config_dict) - self.checkGraphModuleNodes(mq, expected_node_occurrence=convert_count_check) - - def test_quantized_input_quantized_output(self): - prepare_custom_config_dict = { - "input_quantized_idxs": [0], - "output_quantized_idxs": [0], - } - prepare_count_check = { - ns.call_module(torch.ao.quantization.MinMaxObserver): 2, - } - convert_count_check = { - # output of ref conv1 and output of ref conv2 - ns.call_function(torch.quantize_per_tensor): 2, - # input of ref conv1 and input of ref conv2 - ns.call_method("dequantize"): 2, - } - self._test_quantized_inputs_outputs( - prepare_custom_config_dict, prepare_count_check, convert_count_check - ) - - def test_fp32_input_quantized_output(self): - prepare_custom_config_dict = {"output_quantized_idxs": [0]} - prepare_count_check = { - ns.call_module(torch.ao.quantization.MinMaxObserver): 3, - } - convert_count_check = { - # input, output of conv1 and output of conv2 - ns.call_function(torch.quantize_per_tensor): 3, - # input of conv1, conv2 - ns.call_method("dequantize"): 2, - } - self._test_quantized_inputs_outputs( - prepare_custom_config_dict, prepare_count_check, convert_count_check - ) - - def test_quantized_input_fp32_output(self): - prepare_custom_config_dict = {"input_quantized_idxs": [0]} - prepare_count_check = { - ns.call_module(torch.ao.quantization.MinMaxObserver): 2, - } - convert_count_check = { - # output of conv1, conv2 - ns.call_function(torch.quantize_per_tensor): 2, - # input of ref conv1, input of ref conv2, final output - ns.call_method("dequantize"): 3, - } - self._test_quantized_inputs_outputs( - prepare_custom_config_dict, prepare_count_check, 
convert_count_check - ) - - def test_fp32_input_fp32_output(self): - prepare_custom_config_dict = {} - prepare_count_check = { - ns.call_module(torch.ao.quantization.MinMaxObserver): 3, - } - convert_count_check = { - ns.call_function(torch.quantize_per_tensor): 3, - ns.call_method("dequantize"): 3, - } - self._test_quantized_inputs_outputs( - prepare_custom_config_dict, prepare_count_check, convert_count_check - ) - - def _test_standalone_module( - self, - interface_config, - prepare_count_check, - standalone_prepare_count_check, - convert_count_check, - standalone_convert_count_check, - qconfig=None, - backend_config_dict=None, - ): - """Test standalone module with different quantized input/quantized output - configurations - """ - - class StandaloneModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(1, 1, 1) - - def forward(self, x): - return self.conv(x) - - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(1, 1, 1) - self.standalone = StandaloneModule() - - def forward(self, x): - x = self.conv(x) - x = self.standalone(x) - return x - - class RefM(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv1 = torch.nn.Conv2d(1, 1, 1) - self.conv2 = torch.nn.Conv2d(1, 1, 1) - - def forward(self, x): - x = self.conv1(x) - x = self.conv2(x) - return x - - if backend_config_dict is None: - backend_config_dict = self.trt_backend_config_dict - if qconfig is None: - qconfig = self.trt_qconfig - - data = torch.randn(1, 1, 1, 1) - # instantiate M and RefM and align the parameters - original_m = M().eval() - original_ref_m = RefM().eval() - original_ref_m.conv1.weight = torch.nn.Parameter( - original_m.conv.weight.detach() - ) - original_ref_m.conv1.bias = torch.nn.Parameter(original_m.conv.bias.detach()) - original_ref_m.conv2.weight = torch.nn.Parameter( - original_m.standalone.conv.weight.detach() - ) - original_ref_m.conv2.bias = torch.nn.Parameter( - original_m.standalone.conv.bias.detach() - ) - - sm_example_inputs = (data,) - prepare_config = { - "standalone_module_name": [ - ( - "standalone", - None, - sm_example_inputs, - interface_config, - backend_config_dict, - ) - ] - } - - original_m_copy = copy.deepcopy(original_m) - original_ref_m_copy = copy.deepcopy(original_ref_m) - - qconfig_dict = {"": qconfig} - example_inputs = (data,) - # check prepared model - m = prepare_fx( - original_m_copy, - qconfig_dict, - example_inputs, - prepare_custom_config=prepare_config, - backend_config=backend_config_dict, - ) - # calibration - m(data) - self.checkGraphModuleNodes(m, expected_node_occurrence=prepare_count_check) - self.checkGraphModuleNodes( - m.standalone, expected_node_occurrence=standalone_prepare_count_check - ) - - # check converted/quantized model - m = convert_to_reference_fx(m, backend_config=backend_config_dict) - self.checkGraphModuleNodes(m, expected_node_occurrence=convert_count_check) - self.checkGraphModuleNodes( - m.standalone, expected_node_occurrence=standalone_convert_count_check - ) - res = m(data) - - # quantize the reference model - ref_m = prepare_fx( - original_ref_m_copy, - qconfig_dict, - example_inputs, - backend_config=backend_config_dict, - ) - ref_m(data) - ref_m = convert_to_reference_fx(ref_m, backend_config=backend_config_dict) - ref_res = ref_m(data) - self.assertEqual(res, ref_res) - - def test_standalone_module_float_interface(self): - float_interface_config = { - "input_quantized_idxs": [], # float input - "output_quantized_idxs": [], # float 
output
-        }
-        interface_config = float_interface_config
-        # input and output of first conv, observer for standalone module
-        # will be inserted in the standalone module itself
-        prepare_count_check = {
-            ns.call_module(torch.ao.quantization.HistogramObserver): 2
-        }
-        # for input and output of conv in the standalone module
-        standalone_prepare_count_check = {
-            ns.call_module(torch.ao.quantization.HistogramObserver): 2
-        }
-        convert_count_check = {
-            # input and output of reference conv
-            ns.call_function(torch.quantize_per_tensor): 2,
-            ns.call_module(nnqr.Conv2d): 1,
-            ns.call_method("dequantize"): 2,
-        }
-        standalone_convert_count_check = {
-            # the standalone module takes float as input and output,
-            # so we'll see quantize and dequantize in the module
-            ns.call_function(torch.quantize_per_tensor): 2,
-            ns.call_module(nnqr.Conv2d): 1,
-            ns.call_method("dequantize"): 2,
-        }
-        self._test_standalone_module(
-            interface_config,
-            prepare_count_check,
-            standalone_prepare_count_check,
-            convert_count_check,
-            standalone_convert_count_check,
-        )
-
-    def test_standalone_module_quantized_interface(self):
-        quantized_interface_config = {
-            "input_quantized_idxs": [0],  # quantized input
-            "output_quantized_idxs": [0],  # quantized output
-        }
-        interface_config = quantized_interface_config
-        # TODO: input_quantized_idxs only supports quint8; we can remove this
-        # custom_backend_config_dict once `input_quantized_idxs` supports more
-        # complicated configurations. As a first step, we can change it to use
-        # a dictionary from index to dtype.
-        qconfig = torch.ao.quantization.QConfig(
-            activation=torch.ao.quantization.observer.HistogramObserver.with_args(
-                qscheme=torch.per_tensor_symmetric, dtype=torch.quint8
-            ),
-            weight=torch.ao.quantization.default_weight_observer,
-        )
-        weighted_op_quint8_dtype_config = {
-            # optional, input activation dtype
-            "input_dtype": torch.quint8,
-            # optional, weight dtype
-            "weight_dtype": torch.qint8,
-            # optional, bias dtype
-            "bias_dtype": torch.float,
-            # optional, output activation dtype
-            "output_dtype": torch.quint8,
-        }
-        conv_module_config = {
-            "pattern": torch.nn.Conv2d,
-            "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
-            "dtype_configs": [
-                weighted_op_quint8_dtype_config,
-            ],
-            "root_module": torch.nn.Conv2d,
-            "reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d,
-        }
-        custom_backend_config_dict = {"configs": [conv_module_config]}
-        # observer for input and output of first conv
-        prepare_count_check = {
-            ns.call_module(torch.ao.quantization.HistogramObserver): 2
-        }
-        # for output of conv in the standalone module
-        standalone_prepare_count_check = {
-            ns.call_module(torch.ao.quantization.HistogramObserver): 1
-        }
-        convert_count_check = {
-            # quantizing input/output for reference conv
-            ns.call_function(torch.quantize_per_tensor): 2,
-            ns.call_module(nnqr.Conv2d): 1,
-            # dequantize the input of the reference conv and
-            # dequantize the output of the standalone module
-            ns.call_method("dequantize"): 2,
-        }
-        standalone_convert_count_check = {
-            # quantization of input happens in the parent module;
-            # quantization of output happens in the standalone module
-            ns.call_function(torch.quantize_per_tensor): 1,
-            ns.call_module(nnqr.Conv2d): 1,
-            # dequantization of input happens in the standalone module;
-            # dequantization of output happens in the parent module
-            ns.call_method("dequantize"): 1,
-        }
-        self._test_standalone_module(
-            interface_config,
-            prepare_count_check,
-            standalone_prepare_count_check,
-            convert_count_check,
-            standalone_convert_count_check,
-            qconfig=qconfig,
-            backend_config_dict=custom_backend_config_dict,
-        )
-
-
-@unittest.skipIf(not TEST_CUDA, "gpu is not available.")
-class TestQuantizeFxTRTOps(QuantizationTestCase):
-    """Test TensorRT operator support"""
-
-    def setUp(self):
-        super().setUp()
-        self.trt_qconfig = torch.ao.quantization.QConfig(
-            activation=torch.ao.quantization.observer.HistogramObserver.with_args(
-                qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
-            ),
-            weight=torch.ao.quantization.default_weight_observer,
-        )
-        self.trt_backend_config_dict = get_tensorrt_backend_config_dict()
-
-    def _test_module(
-        self, m, inputs, shape_ranges, no_prepare=None, no_convert=None, is_qat=False
-    ):
-        """
-        Args:
-          m: the float module we want to test
-          inputs: list of inputs for the module
-          shape_ranges: a list of shape_range, where every shape_range is a tuple of
-            three tuples
-            ((min_input_shape), (optimized_input_shape), (max_input_shape)).
-            Each shape_range is used to populate a TensorRT optimization profile.
-            e.g. If the input shape varies from (1, 224) to (100, 224) and we want to
-            optimize for (25, 224) because it's the most common input shape, then we
-            set shape_ranges to ((1, 224), (25, 224), (100, 224))
-          no_prepare: node occurrence after prepare
-          no_convert: node occurrence after convert
-        """
-        if is_qat:
-            m = m.train()
-            prepare = prepare_qat_fx
-        else:
-            m = m.eval()
-            prepare = prepare_fx
-        example_inputs = tuple(inputs)
-        prepared = prepare(
-            m,
-            {"": self.trt_qconfig},
-            example_inputs,
-            backend_config=self.trt_backend_config_dict,
-        )
-        self.checkGraphModuleNodes(prepared, expected_node_occurrence=no_prepare)
-        # calibration
-        prepared(*inputs)
-        quantized = convert_to_reference_fx(
-            prepared,
-            backend_config=self.trt_backend_config_dict,
-        )
-        self.checkGraphModuleNodes(quantized, expected_node_occurrence=no_convert)
-        # lower to trt
-        trt_mod = lower_to_trt(quantized, inputs, shape_ranges)
-        inputs_cuda = [i.cuda() for i in inputs]
-        # make sure it runs
-        trt_mod(*inputs_cuda)
-
-    def test_conv_relu_module(self):
-        conv_module = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
-
-        conv1d_input = torch.rand(1, 3, 10)
-        conv2d_input = torch.rand(1, 3, 10, 10)
-        conv3d_input = torch.rand(1, 3, 10, 10, 10)
-        conv_input = {1: conv1d_input, 2: conv2d_input, 3: conv3d_input}
-
-        class ConvNdModule(torch.nn.Module):
-            def __init__(self, dim, has_relu=False, f_relu=False):
-                super().__init__()
-                self.conv = conv_module[dim](3, 3, 3).float()
-                if has_relu:
-                    if f_relu:
-                        self.relu = F.relu
-                    else:
-                        self.relu = torch.nn.ReLU()
-                else:
-                    self.relu = torch.nn.Identity()
-
-            def forward(self, x):
-                return self.relu(self.conv(x))
-
-        # only testing conv1d and conv2d, since conv3d is not supported in fx2trt
-        for dim, has_relu, f_relu, is_qat in itertools.product(
-            [1, 2], [True, False], [True, False], [True, False]
-        ):
-            # when has_relu=False, we have torch.nn.Identity, which would introduce
-            # an extra quant-dequant pair
-            no_convert = {
-                ns.call_function(torch.quantize_per_tensor): 2 + int(not has_relu),
-                ns.call_method("dequantize"): 2 + int(not has_relu),
-            }
-            self._test_module(
-                ConvNdModule(dim, has_relu, f_relu),
-                [conv_input[dim]],
-                [
-                    (
-                        (1, *conv_input[dim].shape[1:]),
-                        (5, *conv_input[dim].shape[1:]),
-                        (10, *conv_input[dim].shape[1:]),
-                    )
-                ],
-                no_convert=no_convert,
-                is_qat=is_qat,
-            )
-
-    def test_linear_relu_module(self):
-        class LinearModule(torch.nn.Module):
-            def __init__(self, has_relu=False, f_relu=False):
-                super().__init__()
-                self.linear = torch.nn.Linear(5, 10).float()
-                if has_relu:
-                    if f_relu:
-                        self.relu = F.relu
-                    else:
-                        self.relu = torch.nn.ReLU()
-                else:
-                    self.relu = torch.nn.Identity()
-
-            def forward(self, x):
-                return self.relu(self.linear(x))
-
-        linear_input = torch.rand(8, 5)
-
-        shape_ranges = [((1, 5), (5, 5), (10, 5))]
-        for has_relu, f_relu, is_qat in itertools.product(
-            [True, False], [True, False], [True, False]
-        ):
-            # when has_relu=False, we have torch.nn.Identity, which would introduce
-            # an extra quant-dequant pair
-            no_convert = {
-                ns.call_function(torch.quantize_per_tensor): 2 + int(not has_relu),
-                ns.call_method("dequantize"): 2 + int(not has_relu),
-            }
-            self._test_module(
-                LinearModule(has_relu, f_relu),
-                [linear_input],
-                shape_ranges,
-                no_convert=no_convert,
-                is_qat=is_qat,
-            )
-
-    def test_ops(self):
-        class M(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.conv = torch.nn.Conv2d(3, 3, 3)
-                self.linear = torch.nn.Linear(5, 5)
-                self.relu = torch.nn.ReLU()
-
-            def forward(self, x):
-                x = self.conv(x)
-                x = self.linear(x)
-                x = x + 3
-                x = self.relu(x)
-                x = x + 6
-                return x
-
-        m = M().eval()
-        example_inputs = (torch.rand(1, 3, 5, 5),)
-        m = prepare_fx(
-            m,
-            {"": self.trt_qconfig},
-            example_inputs,
-            backend_config=self.trt_backend_config_dict,
-        )
-        m = convert_to_reference_fx(m, backend_config=self.trt_backend_config_dict)
-        expected_occurrence = {
-            ns.call_function(torch.quantize_per_tensor): 5,
-            ns.call_method("dequantize"): 5,
-            ns.call_module(torch.nn.quantized._reference.Linear): 1,
-            ns.call_module(torch.nn.quantized._reference.Conv2d): 1,
-        }
-        self.checkGraphModuleNodes(m, expected_node_occurrence=expected_occurrence)
-
-    def test_unsupported_qconfig(self):
-        """Check that we won't quantize the model if the qconfig is not supported"""
-
-        class LinearModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.linear = torch.nn.Linear(5, 10)
-
-            def forward(self, x):
-                return self.linear(x)
-
-        linear_module_input = torch.rand(8, 5)
-
-        m = LinearModule().eval()
-        trt_unsupported_qconfig = default_qconfig
-        example_inputs = (torch.rand(1, 5),)
-        prepared = prepare_fx(
-            m,
-            {"": trt_unsupported_qconfig},
-            example_inputs=example_inputs,
-            backend_config=self.trt_backend_config_dict,
-        )
-        # calibration
-        prepared(linear_module_input)
-        quantized = convert_to_reference_fx(
-            prepared,
-            backend_config=self.trt_backend_config_dict,
-        )
-        node_occurrence = {
-            ns.call_function(torch.quantize_per_tensor): 0,
-            ns.call_method("dequantize"): 0,
-            ns.call_module(torch.nn.Linear): 1,
-            ns.call_module(torch.nn.quantized._reference.Linear): 0,
-        }
-        # check that the model is not quantized
-        self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence)
-
-    def test_cat(self):
-        class M(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-
-            def forward(self, x):
-                return torch.cat([x, x], 1)
-
-        m = M().eval()
-        example_inputs = (torch.rand(2, 2),)
-        prepared = prepare_fx(
-            m,
-            {"": self.trt_qconfig},
-            example_inputs,
-            backend_config=self.trt_backend_config_dict,
-        )
-        self.assertTrue(len(dict(prepared.named_children())) == 1)
-        quantized = convert_to_reference_fx(
-            prepared,
-            backend_config=self.trt_backend_config_dict,
-        )
-        node_occurrence = {
-            ns.call_function(torch.quantize_per_tensor): 2,
-            ns.call_function(torch.cat): 1,
-            ns.call_method("dequantize"): 2,
-        }
-        self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence)
-
-    def test_addmm(self):
- class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.weight = torch.randn(5, 5) - self.bias = torch.randn(5) - - def forward(self, x): - return torch.addmm(self.bias, x, self.weight) - - m = M().eval() - example_inputs = (torch.rand(1, 5),) - prepared = prepare_fx( - m, - {"": self.trt_qconfig}, - example_inputs, - backend_config=self.trt_backend_config_dict, - ) - node_occurrence = { - # weight - ns.call_module(torch.ao.quantization.MinMaxObserver): 1, - # activation - ns.call_module(torch.ao.quantization.HistogramObserver): 2, - } - self.checkGraphModuleNodes(prepared, expected_node_occurrence=node_occurrence) - quantized = convert_to_reference_fx( - prepared, - backend_config=self.trt_backend_config_dict, - ) - node_occurrence = { - # input activation, output activation and weight - ns.call_function(torch.quantize_per_tensor): 3, - ns.call_function(torch.addmm): 1, - ns.call_method("dequantize"): 3, - } - self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence) - - @unittest.skip( - "This is not supported yet, we can enable the test after it's supported" - ) - def test_conv_add(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 3, 3) - - def forward(self, x, y): - return self.conv(x) + y - - weighted_op_qint8_dtype_config = { - # optional, input activation dtype - "input_dtype": torch.qint8, - # optional, weight dtype - "weight_dtype": torch.qint8, - # optional, bias dtype - "bias_dtype": torch.float, - # optional, output activation dtype - "output_dtype": torch.qint8, - } - - def conv_add_root_node_getter(pattern): - (_, conv, _) = pattern - return conv - - def conv_add_extra_inputs_getter(pattern): - _, _, extra_input = pattern - return [extra_input] - - conv_add_config = { - "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, - "dtype_configs": [ - weighted_op_qint8_dtype_config, - ], - "root_node_getter": conv_add_root_node_getter, - "extra_inputs_getter": conv_add_extra_inputs_getter, - "root_module": torch.nn.Conv2d, - "reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d, - } - - if torch.__version__.startswith("1"): - conv_add_config["pattern"] = (operator.add, torch.nn.Conv2d, MatchAllNode) - else: - conv_add_config["pattern_complex_format"] = ( - operator.add, - torch.nn.Conv2d, - MatchAllNode, - ) - - m = M().eval() - modified_backend_config_dict = copy.deepcopy(self.trt_backend_config_dict) - modified_backend_config_dict["configs"].insert(0, conv_add_config) - example_inputs = (torch.rand(1, 3, 3, 3), torch.rand(1, 3, 1, 1)) - m = prepare_fx( - m, - {"": self.trt_qconfig}, - example_inputs, - backend_config=modified_backend_config_dict, - ) - node_occurrence = { - ns.call_module(torch.ao.quantization.HistogramObserver): 3, - } - self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) - m = convert_to_reference_fx(m, backend_config=modified_backend_config_dict) - node_occurrence = { - ns.call_function(torch.quantize_per_tensor): 3, - ns.call_method("dequantize"): 3, - } - self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) - - def test_conv_add_standalone_module(self): - class Standalone(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 3, 3) - self.relu = torch.nn.ReLU() - - def forward(self, x, y): - return self.relu(self.conv(x) + y) - - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 3, 3) 
- self.standalone = Standalone() - - def forward(self, x): - y = self.conv(x) - return self.standalone(x, y) - - from torch.ao.quantization.backend_config import ObservationType - - weighted_op_quint8_dtype_config = { - # optional, input activation dtype - # TODO: change back to torch.qint8 after input_quantized_idxs and output_quantized_idxs - # are more flexible - "input_dtype": torch.quint8, - # optional, weight dtype - "weight_dtype": torch.qint8, - # optional, bias dtype - "bias_dtype": torch.float, - # optional, output activation dtype - "output_dtype": torch.quint8, - } - - conv_add_config = { - "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, - "dtype_configs": [ - weighted_op_quint8_dtype_config, - ], - "root_module": torch.nn.Conv2d, - # "reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d, - } - - if torch.__version__.startswith("1"): - conv_add_config["pattern"] = ( - torch.nn.ReLU, - (operator.add, torch.nn.Conv2d, MatchAllNode), - ) - else: - conv_add_config["pattern_complex_format"] = ( - torch.nn.ReLU, - (operator.add, torch.nn.Conv2d, MatchAllNode), - ) - - conv_config = { - "pattern": torch.nn.Conv2d, - "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, - "dtype_configs": [ - weighted_op_quint8_dtype_config, - ], - "root_module": torch.nn.Conv2d, - # "reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d, - } - - m = M().eval() - backend_config_dict = { - "configs": [ - conv_add_config, - conv_config, - ] - } - sm_example_inputs = (torch.rand(1, 3, 3, 3), torch.rand(1, 3, 1, 1)) - prepare_custom_config_dict = { - "standalone_module_name": [ - ( - "standalone", - None, - sm_example_inputs, - {"input_quantized_idxs": [0, 1]}, - None, - ) - ] - } - # TODO: use self.trt_qconfig after input_quantized_idxs and output_quantized_idxs - # are more flexible - qconfig = torch.ao.quantization.QConfig( - activation=torch.ao.quantization.observer.HistogramObserver.with_args( - qscheme=torch.per_tensor_symmetric, dtype=torch.quint8 - ), - weight=torch.ao.quantization.default_weight_observer, - ) - example_inputs = (torch.rand(1, 3, 5, 5),) - m = prepare_fx( - m, - {"": qconfig}, - example_inputs, - prepare_custom_config=prepare_custom_config_dict, - backend_config=backend_config_dict, - ) - node_occurrence = { - # for input and output of conv, where input is used twice, once in conv and - # once in standalone module - ns.call_module(torch.ao.quantization.HistogramObserver): 2, - } - self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) - standalone_node_occurrence = { - # output of the standalone module - ns.call_module(torch.ao.quantization.HistogramObserver): 1, - } - self.checkGraphModuleNodes( - m.standalone, expected_node_occurrence=standalone_node_occurrence - ) - m = convert_to_reference_fx(m, backend_config=backend_config_dict) - node_occurrence = { - # two inputs for standalone module - ns.call_function(torch.quantize_per_tensor): 2, - ns.call_module(nn.Conv2d): 1, - ns.call_method("dequantize"): 1, - } - self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) - standalone_node_occurrence = { - # output for the pattern in standalone module - ns.call_function(torch.quantize_per_tensor): 1, - ns.call_module(nn.Conv2d): 1, - ns.call_module(torch.nn.ReLU): 1, - # two input and one output for the pattern in standalone module - ns.call_method("dequantize"): 3, - } - self.checkGraphModuleNodes( - m.standalone, 
expected_node_occurrence=standalone_node_occurrence - ) - - def test_quant_dequant_not_fold(self): - class LinearModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(5, 10).float() - self.relu = torch.nn.ReLU() - - def forward(self, x): - return self.relu(self.linear(x)) - - model = LinearModule().eval() - inputs = [torch.rand(8, 5)] - example_inputs = tuple(inputs) - prepared = prepare_fx( - model, - {"": self.trt_qconfig}, - example_inputs, - backend_config=self.trt_backend_config_dict, - ) - quantized = convert_to_reference_fx( - prepared, - backend_config=self.trt_backend_config_dict, - ) - - model = acc_tracer.trace(quantized, inputs) - model = run_const_fold(model) - - no_const = { - ns.call_function(acc_ops.quantize_per_tensor): 3, - ns.call_function(acc_ops.dequantize): 3, - } - self.checkGraphModuleNodes(model, expected_node_occurrence=no_const) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/tools/test_model_packager.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/tools/test_model_packager.py deleted file mode 100644 index 209181137e..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/tools/test_model_packager.py +++ /dev/null @@ -1,56 +0,0 @@ -import io -import unittest - -import torch -import torch.fx -from torch import nn -from torch.package import PackageImporter -from torch_tensorrt.dynamo.fx_ts_compat.tools.model_packager import ( - generate_standalone_repro, - ModelPackager, -) - - -class TestModel(nn.Module): - def __init__(self): - super().__init__() - self.a = torch.nn.Module() - self.b = torch.nn.Module() - self.a.weights = torch.nn.Parameter(torch.randn(1, 2)) - self.b.weights = torch.nn.Parameter( - torch.randn( - 1, - ) - ) - - def forward(self, x): - return x + self.a.weights + self.b.weights - - -class ModelPackagerTest(unittest.TestCase): - def test_text_repro_gen(self): - model = torch.fx.symbolic_trace(TestModel().eval()) - inputs = [torch.randn(1)] - _ = model(*inputs) - - string_io = io.StringIO() - generate_standalone_repro(model, string_io, "\n# hello") - string_io.seek(0) - exec(string_io.read()) - exported_model = locals()["ExportedModule"]() - _ = exported_model(*inputs) - - def test_package_model(self): - model = torch.fx.symbolic_trace(TestModel().eval()) - inputs = [torch.randn(1)] - _ = model(*inputs) - bytesIO = io.BytesIO() - ModelPackager.package_model(model, inputs, bytesIO) - bytesIO.seek(0) - pi = PackageImporter(bytesIO) - reload_model = pi.load_pickle("repro", "model") - reload_inputs = pi.load_pickle("repro", "inputs") - - torch.testing.assert_close(model(*inputs), reload_model(*reload_inputs)) - keys = dict(reload_model.named_children()).keys() - self.assertEqual(keys, {"_holder"}) diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/tracer/test_acc_shape_prop.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/tracer/test_acc_shape_prop.py deleted file mode 100644 index a2f842b722..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/tracer/test_acc_shape_prop.py +++ /dev/null @@ -1,98 +0,0 @@ -# Owner(s): ["oncall: fx"] - -import operator -import unittest - -import torch - -import torch_tensorrt.fx.tracer.acc_tracer.acc_shape_prop as acc_shape_prop -import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer -from parameterized import param, parameterized - -torch.manual_seed(0) - - -class AccShapePropTest(unittest.TestCase): - @parameterized.expand( - [ - param("fp32", dtype=torch.float32), - param("fp16", 
dtype=torch.float16), - ] - ) - def test_basic(self, _, dtype): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.attr = torch.nn.Parameter(torch.randn(3, 4)) - self.submod = torch.nn.Linear(4, 4) - - def forward(self, x): - return torch.neg(self.submod(x.relu() + self.attr)) - - m = TestModule() - if dtype == torch.float16: - m.half() - gm = acc_tracer.rewriter_base_trace(m, None, None) - inp = torch.rand(3, 4, dtype=dtype) - acc_shape_prop.AccShapeProp(gm).propagate(inp) - - for node in gm.graph.nodes: - self.assertEqual(node.meta["tensor_meta"].dtype, dtype) - - def test_mutli_dtype(self): - class TestModule(torch.nn.Module): - def forward(self, x, y): - return torch.relu(x * 2), torch.sigmoid(y + y) - - m = TestModule() - gm = acc_tracer.rewriter_base_trace(m, None, None) - # Note: One input is fp32, the other fp16. - x, y = torch.rand(3, 4), torch.rand(3, 4, dtype=torch.float16) - acc_shape_prop.AccShapeProp(gm).propagate(x, y) - - for node in gm.graph.nodes: - if (node.op == "placeholder" and node.target == "x") or ( - node.op == "call_function" and node.target in {operator.mul, torch.relu} - ): - self.assertEqual(node.meta["tensor_meta"].dtype, torch.float32) - elif node.op != "output": - self.assertEqual(node.meta["tensor_meta"].dtype, torch.float16) - else: - self.assertEqual(node.meta["tensor_meta"][0].dtype, torch.float32) - self.assertEqual(node.meta["tensor_meta"][1].dtype, torch.float16) - - def test_to_dtype(self): - class TestModule(torch.nn.Module): - def forward(self, x): - return x.to(dtype=torch.float32).to(dtype=torch.float16) - - m = TestModule() - gm = acc_tracer.rewriter_base_trace(m, None, None) - x = torch.rand(3, 4, dtype=torch.float16) - acc_shape_prop.AccShapeProp(gm).propagate(x) - ph = None - for node in gm.graph.nodes: - if node.op == "placeholder": - ph = node - self.assertEqual(node.meta["tensor_meta"].dtype, torch.float16) - elif node.all_input_nodes == [ph]: - self.assertEqual(node.meta["tensor_meta"].dtype, torch.float32) - else: - self.assertEqual(node.meta["tensor_meta"].dtype, torch.float16) - - def test_split(self): - class TestModule(torch.nn.Module): - def forward(self, x): - s = torch.tensor_split(x, 2) - return s[0].relu(), s[1].sigmoid() - - m = TestModule() - gm = acc_tracer.rewriter_base_trace(m, None, None) - x = torch.rand(2, 4, dtype=torch.float16) - acc_shape_prop.AccShapeProp(gm).propagate(x) - for node in gm.graph.nodes: - if node.target == torch.tensor_split or node.op == "output": - self.assertEqual(node.meta["tensor_meta"][0].dtype, torch.float16) - self.assertEqual(node.meta["tensor_meta"][1].dtype, torch.float16) - else: - self.assertEqual(node.meta["tensor_meta"].dtype, torch.float16) diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/tracer/test_acc_tracer.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/tracer/test_acc_tracer.py deleted file mode 100644 index 633359127f..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/tracer/test_acc_tracer.py +++ /dev/null @@ -1,2801 +0,0 @@ -# Owner(s): ["oncall: fx"] -import logging -import operator -import unittest -from typing import Callable, Dict, List, NamedTuple, Optional, Tuple - -import numpy as np -import torch -import torch.nn as nn - -import torch_tensorrt.fx.tracer.acc_tracer.acc_normalizer as acc_normalizer -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer -import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils -import torchvision -from 
parameterized import param, parameterized
-
-torch.manual_seed(0)
-
-_LOGGER: logging.Logger = logging.getLogger(__name__)
-
-torch.fx.wrap("len")
-
-
-class AccTracerTest(unittest.TestCase):
-    def _make_model_unit_test(
-        self,
-        model,
-        *args,
-        input_shape=None,
-        enable_allclose=False,
-        **kwargs,
-    ):
-        """
-        Test that the model can be traced correctly and produces correct results.
-        """
-        if input_shape is None:
-            input_shape = [1, 3, 224, 224]
-        input = torch.randn(input_shape)
-        traced = acc_tracer.trace(model, [input])
-        if enable_allclose:
-            torch.testing.assert_close(model(input), traced(input))
-        else:
-            self.assertTrue(torch.equal(model(input), traced(input)))
-        traced_again = acc_tracer.trace(traced, [input])
-        if enable_allclose:
-            torch.testing.assert_close(model(input), traced_again(input))
-        else:
-            self.assertTrue(torch.equal(model(input), traced_again(input)))
-
-    def _make_acc_op_function_test(
-        self,
-        acc_op: Callable,
-        torch_op,
-        *args,
-        input_shape=(2, 3),
-        validate_same_kwargs=True,
-        enable_allclose=False,
-        **kwargs,
-    ):
-        """
-        Test that torch_op is traced to the expected acc_op.
-        """
-
-        class TestModule(torch.nn.Module):
-            def __init__(self, torch_op, args, kwargs):
-                super().__init__()
-                self._torch_op = torch_op
-                self._args = args
-                self._kwargs = kwargs
-
-            def forward(self, a: torch.Tensor) -> torch.Tensor:
-                return self._torch_op(a, *self._args, **self._kwargs)
-
-        m = TestModule(torch_op, args, kwargs)
-        m.eval()
-        a = torch.randn(*input_shape)
-        traced = acc_tracer.trace(m, [a])
-        ph_a = acc_op_node = None
-        for node in traced.graph.nodes:
-            if node.op == "placeholder":
-                if str(node.target) == "a":
-                    ph_a = node
-            elif node.op == "call_function":
-                self.assertEqual(node.target, acc_op)
-                self.assertEqual(node.kwargs["input"], ph_a)
-                if validate_same_kwargs:
-                    for key, value in kwargs.items():
-                        self.assertEqual(node.kwargs[key], value)
-                acc_op_node = node
-            elif node.op == "output":
-                if acc_op is None:
-                    # If we expect no new acc_op after graph building,
-                    # the traced graph contains only the output node
-                    continue
-                self.assertEqual(acc_op_node, node.args[0])
-            else:
-                self.fail(f"Unexpected node: {node.format_node()}")
-
-        ref_outputs = m(a)
-        outputs = traced(a)
-        traced_again = acc_tracer.trace(traced, [a])
-        outputs_again = traced_again(a)
-        if isinstance(ref_outputs, torch.Tensor):
-            ref_outputs = [ref_outputs]
-            outputs = [outputs]
-            outputs_again = [outputs_again]
-
-        for ref_output, output, output_again in zip(
-            ref_outputs, outputs, outputs_again
-        ):
-            if enable_allclose:
-                torch.testing.assert_close(
-                    torch.nan_to_num(ref_output), torch.nan_to_num(output)
-                )
-                torch.testing.assert_close(
-                    torch.nan_to_num(ref_output), torch.nan_to_num(output_again)
-                )
-            else:
-                self.assertTrue(
-                    torch.equal(torch.nan_to_num(ref_output), torch.nan_to_num(output))
-                )
-                self.assertTrue(
-                    torch.equal(
-                        torch.nan_to_num(ref_output), torch.nan_to_num(output_again)
-                    )
-                )
-
-    def test_sum(self):
-        self._make_acc_op_function_test(acc_ops.sum, torch.sum)
-        self._make_acc_op_function_test(acc_ops.sum, torch.sum, dim=(1,), keepdim=True)
-
-    def test_prod(self):
-        self._make_acc_op_function_test(acc_ops.prod, torch.prod)
-        self._make_acc_op_function_test(acc_ops.prod, torch.prod, dim=1, keepdim=True)
-
-    def test_mean(self):
-        self._make_acc_op_function_test(acc_ops.mean, torch.mean)
-        self._make_acc_op_function_test(
-            acc_ops.mean, torch.mean, dim=(1,), keepdim=True
-        )
-
-    def test_pad(self):
-        self._make_acc_op_function_test(
-            acc_ops.pad,
torch.nn.functional.pad, pad=(2, 0) - ) - - def test_max(self): - def torch_max(x, *args, **kwargs): - return x.max(*args, **kwargs) - - self._make_acc_op_function_test(acc_ops.max_full_reduce, torch_max) - self._make_acc_op_function_test( - acc_ops.max_dim_reduce, torch_max, dim=1, keepdim=True - ) - self._make_acc_op_function_test( - acc_ops.max_dim_reduce, torch_max, input_shape=(1, 4), dim=1, keepdim=True - ) - self._make_acc_op_function_test( - acc_ops.max_dim_reduce, torch_max, input_shape=(3, 4, 3), dim=2 - ) - - @parameterized.expand( - [ - param("max_maximum", orig_op=torch.max, expected_op=acc_ops.maximum), - param( - "maximum_maximum", orig_op=torch.maximum, expected_op=acc_ops.maximum - ), - param("min_minimum", orig_op=torch.min, expected_op=acc_ops.minimum), - param( - "minimum_minimum", orig_op=torch.minimum, expected_op=acc_ops.minimum - ), - ] - ) - def test_maximum_minimum(self, _: str, orig_op, expected_op): - class TestModule(torch.nn.Module): - def __init__(self, orig_op): - super().__init__() - self.orig_op = orig_op - - def forward(self, input: torch.Tensor, other: torch.Tensor) -> torch.Tensor: - return self.orig_op(input, other) - - m = TestModule(orig_op) - input, other = torch.randn(2, 2), torch.randn(2, 2) - traced = acc_tracer.trace(m, [input, other]) - - ph_in = ph_oth = mxm = None - for node in traced.graph.nodes: - if node.op == "placeholder": - if str(node.target) == "other": - ph_oth = node - else: - self.assertTrue(str(node.target) == "input") - ph_in = node - elif node.op == "call_function": - if node.target == expected_op: - self.assertEqual(node.kwargs["input"], ph_in) - self.assertEqual(node.kwargs["other"], ph_oth) - mxm = node - elif node.op == "output": - self.assertEqual(mxm, node.args[0]) - else: - self.fail(f"Unexpected node: {node.format_node()}") - - self.assertTrue(torch.equal(m(input, other), traced(input, other))) - - def test_conv(self): - """ - Test that a conv is traced as expected. 
- """ - - class TestModule(nn.Module): - def __init__(self): - super().__init__() - self.conv = nn.Conv2d(8, 7, 3, stride=2) - - def forward(self, a: torch.Tensor) -> torch.Tensor: - return self.conv(a) - - m = TestModule() - input = torch.randn(3, 8, 10, 10) - traced = acc_tracer.trace(m, [input]) - - ph = weight_attr = bias_attr = conv = None - for node in traced.graph.nodes: - if node.op == "placeholder": - self.assertEqual(str(node.target), "a") - ph = node - elif node.op == "get_attr" and node.target == "conv.weight": - weight_attr = node - elif node.op == "get_attr" and node.target == "conv.bias": - bias_attr = node - elif node.op == "call_function": - self.assertEqual(node.target, acc_ops.conv2d) - self.assertEqual(node.kwargs["input"], ph) - self.assertEqual(node.kwargs["weight"], weight_attr) - self.assertEqual(node.kwargs["bias"], bias_attr) - self.assertEqual(node.kwargs["stride"], (2, 2)) - self.assertEqual(node.kwargs["padding"], (0, 0)) - self.assertEqual(node.kwargs["dilation"], (1, 1)) - self.assertEqual(node.kwargs["groups"], 1) - conv = node - elif node.op == "output": - self.assertEqual(conv, node.args[0]) - else: - self.fail(f"Unexpected node: {node.format_node()}") - - self.assertTrue(torch.equal(m(input), traced(input))) - - def test_quantized_conv2d(self): - class TestModule(nn.Module): - def __init__(self): - super().__init__() - self.conv = nn.quantized.Conv2d(3, 3, 1) - - def forward(self, a: torch.Tensor) -> torch.Tensor: - return self.conv(a) - - m = TestModule() - input = torch.quantize_per_tensor( - torch.randn(1, 3, 1, 1), scale=0.01, zero_point=3, dtype=torch.quint8 - ) - traced = acc_tracer.trace(m, [input]) - _LOGGER.info(traced.graph) - ph = weight_attr = bias_attr = conv = None - for node in traced.graph.nodes: - if node.op == "placeholder": - self.assertEqual(str(node.target), "a") - ph = node - elif node.op == "get_attr" and node.target == "conv_weight": - weight_attr = node - elif node.op == "get_attr" and node.target == "conv_bias": - bias_attr = node - elif node.op == "call_function": - self.assertEqual(node.target, acc_ops.quantized_conv2d) - self.assertEqual(node.kwargs["input"], ph) - self.assertEqual(node.kwargs["weight"], weight_attr) - self.assertEqual(node.kwargs["bias"], bias_attr) - conv = node - elif node.op == "output": - self.assertEqual(conv, node.args[0]) - else: - self.fail(f"Unexpected node: {node.format_node()}") - - self.assertTrue(torch.equal(m(input), traced(input))) - - def test_quantized_convrelu2d(self): - class TestModule(nn.Module): - def __init__(self): - super().__init__() - self.conv = nn.intrinsic.quantized.ConvReLU2d(3, 3, 1) - - def forward(self, a: torch.Tensor) -> torch.Tensor: - return self.conv(a) - - m = TestModule() - input = torch.quantize_per_tensor( - torch.randn(1, 3, 1, 1), scale=0.01, zero_point=3, dtype=torch.quint8 - ) - traced = acc_tracer.trace(m, [input]) - ph = weight_attr = bias_attr = conv = relu = None - for node in traced.graph.nodes: - if node.op == "placeholder": - self.assertEqual(str(node.target), "a") - ph = node - elif node.op == "get_attr" and node.target == "conv_weight": - weight_attr = node - elif node.op == "get_attr" and node.target == "conv_bias": - bias_attr = node - elif node.op == "call_function" and node.target == acc_ops.quantized_conv2d: - self.assertEqual(node.target, acc_ops.quantized_conv2d) - self.assertEqual(node.kwargs["input"], ph) - self.assertEqual(node.kwargs["weight"], weight_attr) - self.assertEqual(node.kwargs["bias"], bias_attr) - conv = node - elif node.op == 
"call_function" and node.target == acc_ops.relu: - self.assertEqual(node.target, acc_ops.relu) - self.assertEqual(node.kwargs["input"], conv) - relu = node - elif node.op == "output": - self.assertEqual(relu, node.args[0]) - else: - self.fail(f"Unexpected node: {node.format_node()}") - - self.assertTrue(torch.equal(m(input), traced(input))) - - def test_conv1d(self): - """ - Test that a conv is traced as expected. - """ - - class TestModule(nn.Module): - def __init__(self): - super().__init__() - self.conv = nn.Conv1d(8, 7, 3, stride=2) - - def forward(self, a: torch.Tensor) -> torch.Tensor: - return self.conv(a) - - m = TestModule() - input = torch.randn(3, 8, 8) - traced = acc_tracer.trace(m, [input]) - - ph = weight_attr = bias_attr = conv = None - for node in traced.graph.nodes: - if node.op == "placeholder": - self.assertEqual(str(node.target), "a") - ph = node - elif node.op == "get_attr" and node.target == "conv.weight": - weight_attr = node - elif node.op == "get_attr" and node.target == "conv.bias": - bias_attr = node - elif node.op == "call_function": - self.assertEqual(node.target, acc_ops.conv1d) - self.assertEqual(node.kwargs["input"], ph) - self.assertEqual(node.kwargs["weight"], weight_attr) - self.assertEqual(node.kwargs["bias"], bias_attr) - self.assertEqual(node.kwargs["stride"], (2,)) - self.assertEqual(node.kwargs["padding"], (0,)) - self.assertEqual(node.kwargs["dilation"], (1,)) - self.assertEqual(node.kwargs["groups"], 1) - conv = node - elif node.op == "output": - self.assertEqual(conv, node.args[0]) - else: - self.fail(f"Unexpected node: {node.format_node()}") - - self.assertTrue(torch.equal(m(input), traced(input))) - - def test_conv3d(self): - """ - Test that a conv is traced as expected. - """ - - class TestModule(nn.Module): - def __init__(self): - super().__init__() - self.conv = nn.Conv3d(8, 7, 3, stride=2) - - def forward(self, a: torch.Tensor) -> torch.Tensor: - return self.conv(a) - - m = TestModule() - input = torch.randn(3, 8, 8, 10, 10) - traced = acc_tracer.trace(m, [input]) - - ph = weight_attr = bias_attr = conv = None - for node in traced.graph.nodes: - if node.op == "placeholder": - self.assertEqual(str(node.target), "a") - ph = node - elif node.op == "get_attr" and node.target == "conv.weight": - weight_attr = node - elif node.op == "get_attr" and node.target == "conv.bias": - bias_attr = node - elif node.op == "call_function": - self.assertEqual(node.target, acc_ops.conv3d) - self.assertEqual(node.kwargs["input"], ph) - self.assertEqual(node.kwargs["weight"], weight_attr) - self.assertEqual(node.kwargs["bias"], bias_attr) - self.assertEqual(node.kwargs["stride"], (2, 2, 2)) - self.assertEqual(node.kwargs["padding"], (0, 0, 0)) - self.assertEqual(node.kwargs["dilation"], (1, 1, 1)) - self.assertEqual(node.kwargs["groups"], 1) - conv = node - elif node.op == "output": - self.assertEqual(conv, node.args[0]) - else: - self.fail(f"Unexpected node: {node.format_node()}") - - self.assertTrue(torch.equal(m(input), traced(input))) - - def test_conv_transpose2d(self): - """ - Test that a conv_transpose2d is traced as expected. 
- """ - - class TestModule(nn.Module): - def __init__(self): - super().__init__() - self.conv = nn.ConvTranspose2d(8, 7, 3, stride=2) - - def forward(self, a: torch.Tensor) -> torch.Tensor: - return self.conv(a) - - m = TestModule() - input = torch.randn(3, 8, 10, 10) - traced = acc_tracer.trace(m, [input]) - - ph = weight_attr = bias_attr = conv = None - for node in traced.graph.nodes: - if node.op == "placeholder": - self.assertEqual(str(node.target), "a") - ph = node - elif node.op == "get_attr" and node.target == "conv.weight": - weight_attr = node - elif node.op == "get_attr" and node.target == "conv.bias": - bias_attr = node - elif node.op == "call_function": - self.assertEqual(node.target, acc_ops.conv_transpose2d) - self.assertEqual(node.kwargs["input"], ph) - self.assertEqual(node.kwargs["weight"], weight_attr) - self.assertEqual(node.kwargs["bias"], bias_attr) - self.assertEqual(node.kwargs["stride"], (2, 2)) - self.assertEqual(node.kwargs["padding"], (0, 0)) - self.assertEqual(node.kwargs["output_padding"], (0, 0)) - self.assertEqual(node.kwargs["groups"], 1) - self.assertEqual(node.kwargs["dilation"], (1, 1)) - conv = node - elif node.op == "output": - self.assertEqual(conv, node.args[0]) - else: - self.fail(f"Unexpected node: {node.format_node()}") - - self.assertTrue(torch.equal(m(input), traced(input))) - - def test_conv_transpose3d(self): - """ - Test that a conv_transpose3d is traced as expected. - """ - - class TestModule(nn.Module): - def __init__(self): - super().__init__() - self.conv = nn.ConvTranspose3d(8, 7, 3, stride=2) - - def forward(self, a: torch.Tensor) -> torch.Tensor: - return self.conv(a) - - m = TestModule() - input = torch.randn(3, 8, 8, 10, 10) - traced = acc_tracer.trace(m, [input]) - - ph = weight_attr = bias_attr = conv = None - for node in traced.graph.nodes: - if node.op == "placeholder": - self.assertEqual(str(node.target), "a") - ph = node - elif node.op == "get_attr" and node.target == "conv.weight": - weight_attr = node - elif node.op == "get_attr" and node.target == "conv.bias": - bias_attr = node - elif node.op == "call_function": - self.assertEqual(node.target, acc_ops.conv_transpose3d) - self.assertEqual(node.kwargs["input"], ph) - self.assertEqual(node.kwargs["weight"], weight_attr) - self.assertEqual(node.kwargs["bias"], bias_attr) - self.assertEqual(node.kwargs["stride"], (2, 2, 2)) - self.assertEqual(node.kwargs["padding"], (0, 0, 0)) - self.assertEqual(node.kwargs["output_padding"], (0, 0, 0)) - self.assertEqual(node.kwargs["dilation"], (1, 1, 1)) - self.assertEqual(node.kwargs["groups"], 1) - conv = node - elif node.op == "output": - self.assertEqual(conv, node.args[0]) - else: - self.fail(f"Unexpected node: {node.format_node()}") - - self.assertTrue(torch.equal(m(input), traced(input))) - - def test_embedding_bag(self): - """ - Test that an embedding_bag is traced as expected. 
- """ - - class TestModule(nn.Module): - def __init__(self): - super().__init__() - self.eb = nn.EmbeddingBag(10, 3, mode="sum", include_last_offset=True) - - def forward(self, inp: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor: - return self.eb(inp, offsets) - - m = TestModule() - inp = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9]) - offsets = torch.LongTensor([0, 4]) - traced = acc_tracer.trace(m, [inp, offsets]) - - inp_node = offsets_node = weight_attr = eb_node = None - for node in traced.graph.nodes: - if node.op == "placeholder": - if str(node.target) == "inp": - inp_node = node - elif str(node.target) == "offsets": - offsets_node = node - else: - self.fail(f"Unexpected placeholder {node.target}.") - continue - elif node.op == "get_attr" and node.target == "eb.weight": - weight_attr = node - elif node.op == "call_function": - self.assertEqual(node.target, acc_ops.embedding_bag) - # Note: Normalization called from acc_tracer means we use all kwargs. - self.assertEqual(node.kwargs["input"], inp_node) - self.assertEqual(node.kwargs["offsets"], offsets_node) - self.assertEqual(node.kwargs["weight"], weight_attr) - self.assertEqual(node.kwargs["mode"], "sum") - self.assertEqual(node.kwargs["include_last_offset"], True) - # The rest of these were unspecified, so verify they fell back - # to their respective default values thanks to normalization. - self.assertEqual(node.kwargs["max_norm"], None) - self.assertEqual(node.kwargs["norm_type"], 2.0) - self.assertEqual(node.kwargs["scale_grad_by_freq"], False) - self.assertEqual(node.kwargs["sparse"], False) - self.assertEqual(node.kwargs["per_sample_weights"], None) - eb_node = node - elif node.op == "output": - self.assertEqual(eb_node, node.args[0]) - - self.assertTrue(torch.equal(m(inp, offsets), traced(inp, offsets))) - - def test_embedding_bag_byte_and_4bit_rowwise_offsets(self): - """ - Test that 4 bit quantized embedding_bag is traced as expected. - """ - - class TestModule(nn.Module): - def __init__( - self, - op, - q_weights, - per_index_weights, - ): - super().__init__() - self.emb = op - self.q_weights = q_weights - self.per_index_weights = per_index_weights - - def forward( - self, - indices, - offsets, - ): - return self.emb( - self.q_weights, - indices, - offsets, - mode=0, - per_sample_weights=self.per_index_weights, - include_last_offset=True, - ) - - def run_embedding_bag_test(is_4bit, use_weights): - # generate random indices, offsets, and weights. 
- num_embeddings = 16 - embedding_dim = 32 - num_lengths = 10 - - weights = torch.from_numpy( - (np.random.random_sample((num_embeddings, embedding_dim)) + 1).astype( - np.float32 - ) - ) - q_weights = ( - torch.ops.quantized.embedding_bag_4bit_prepack(weights) - if is_4bit - else torch.ops.quantized.embedding_bag_byte_prepack(weights) - ) - np_lengths = np.random.randint(0, num_lengths, size=10).astype(np.int32) - - num_lengths = np.sum(np_lengths) - indices = torch.from_numpy( - np.random.randint(low=0, high=num_embeddings, size=num_lengths) - ).int() - - lengths = torch.from_numpy(np_lengths) - offsets = torch.cat([torch.zeros([1]), torch.cumsum(lengths, 0)]).int() - - weights = torch.randint(low=0, high=4, size=indices.size()) - per_sample_weights = weights.to(torch.float32) - - indices = indices.to(torch.int32) - offsets = offsets.to(torch.int32) - inputs = [ - indices, - offsets, - ] - - op = ( - torch.ops.quantized.embedding_bag_4bit_rowwise_offsets - if is_4bit - else torch.ops.quantized.embedding_bag_byte_rowwise_offsets - ) - - m = TestModule( - op, - q_weights, - per_sample_weights, - ) - - traced = acc_tracer.trace(m, inputs) - _LOGGER.info(traced.graph) - - expected_target = ( - acc_ops.embedding_bag_4bit_rowwise_offsets - if is_4bit - else acc_ops.embedding_bag_byte_rowwise_offsets - ) - - for node in traced.graph.nodes: - if node.op == "placeholder": - if str(node.target) == "indices": - inp_node = node - elif str(node.target) == "offsets": - offsets_node = node - else: - self.fail(f"Unexpected placeholder {node.target}.") - continue - elif node.op == "get_attr" and node.target == "q_weights": - weight_attr = node - elif node.op == "call_function": - self.assertEqual(node.target, expected_target) - # Note: Normalization called from acc_tracer means we use all kwargs. - self.assertEqual(node.kwargs["indices"], inp_node) - self.assertEqual(node.kwargs["offsets"], offsets_node) - self.assertEqual(node.kwargs["weight"], weight_attr) - self.assertEqual(node.kwargs["mode"], 0) - self.assertEqual(node.kwargs["include_last_offset"], True) - # The rest of these were unspecified, so verify they fell back - # to their respective default values thanks to normalization. 
-                    eb_node = node
-                elif node.op == "output":
-                    self.assertEqual(eb_node, node.args[0])
-            self.assertTrue(torch.equal(m(indices, offsets), traced(indices, offsets)))
-
-        # test 8-bit
-        run_embedding_bag_test(is_4bit=False, use_weights=True)
-        # test 4-bit
-        run_embedding_bag_test(is_4bit=True, use_weights=True)
-
-    def test_quantized_batch_norm2d(self):
-        class TestModule(nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.bn = nn.quantized.BatchNorm2d(3)
-
-            def forward(self, a: torch.Tensor) -> torch.Tensor:
-                return self.bn(a)
-
-        m = TestModule()
-        m.eval()
-        input = torch.quantize_per_tensor(
-            torch.randn(1, 3, 1, 1), scale=0.01, zero_point=3, dtype=torch.quint8
-        )
-        traced = acc_tracer.trace(m, [input])
-        # Initialize every captured node (including scale/zero_point, which were
-        # previously left uninitialized) so an unexpected graph order fails the
-        # assertions below instead of raising NameError.
-        ph = weight_attr = bias_attr = bn_mean = bn_var = None
-        bn_scale = bn_zero_point = bn = None
-        for node in traced.graph.nodes:
-            if node.op == "placeholder":
-                self.assertEqual(str(node.target), "a")
-                ph = node
-            elif node.op == "get_attr" and node.target == "bn.weight":
-                weight_attr = node
-            elif node.op == "get_attr" and node.target == "bn.bias":
-                bias_attr = node
-            elif node.op == "get_attr" and node.target == "bn.running_mean":
-                bn_mean = node
-            elif node.op == "get_attr" and node.target == "bn.running_var":
-                bn_var = node
-            elif node.op == "get_attr" and node.target == "bn.scale":
-                bn_scale = node
-            elif node.op == "get_attr" and node.target == "bn.zero_point":
-                bn_zero_point = node
-            elif node.op == "call_function":
-                self.assertEqual(node.target, acc_ops.quantized_batch_norm2d)
-                self.assertEqual(node.kwargs["input"], ph)
-                self.assertEqual(node.kwargs["weight"], weight_attr)
-                self.assertEqual(node.kwargs["bias"], bias_attr)
-                self.assertEqual(node.kwargs["running_mean"], bn_mean)
-                self.assertEqual(node.kwargs["running_var"], bn_var)
-                self.assertEqual(node.kwargs["acc_out_ty"][6]["scale"], bn_scale)
-                self.assertEqual(
-                    node.kwargs["acc_out_ty"][6]["zero_point"], bn_zero_point
-                )
-                bn = node
-            elif node.op == "output":
-                self.assertEqual(bn, node.args[0])
-            else:
-                self.fail(f"Unexpected node: {node.format_node()}")
-
-        self.assertTrue(torch.equal(m(input), traced(input)))
-
-    def test_linear(self):
-        """
-        Test that a linear is traced as expected, i.e. to the functional level and with
-        kwarg normalization. Also verify that symbolic shape inference worked as part of
-        the acc_tracer.
-        """
-
-        class TestModule(nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.linear = nn.Linear(3, 5, bias=True)
-
-            def forward(self, a: torch.Tensor) -> torch.Tensor:
-                return self.linear(a)
-
-        m = TestModule()
-        test_input = torch.randn(1, 3)
-        traced = acc_tracer.trace(m, [test_input])
-        ph = weight_attr = bias_attr = linear = None
-        for node in traced.graph.nodes:
-            if node.op == "placeholder":
-                self.assertEqual(str(node.target), "a")
-                ph = node
-            elif node.op == "get_attr" and node.target == "linear.weight":
-                weight_attr = node
-            elif node.op == "get_attr" and node.target == "linear.bias":
-                bias_attr = node
-            elif node.op == "call_function":
-                self.assertEqual(node.target, acc_ops.linear)
-                self.assertEqual(node.kwargs["input"], ph)
-                self.assertEqual(node.kwargs["weight"], weight_attr)
-                self.assertEqual(node.kwargs["bias"], bias_attr)
-                linear = node
-            elif node.op == "output":
-                self.assertEqual(linear, node.args[0])
-            else:
-                self.fail(f"Unexpected node: {node.format_node()}")
-        self.assertTrue(torch.equal(m(test_input), traced(test_input)))
-
-    def test_quantized_linear(self):
-        class TestModule(nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.linear = nn.quantized.Linear(3, 5)
-
-            def forward(self, a: torch.Tensor) -> torch.Tensor:
-                return self.linear(a)
-
-        m = TestModule()
-        input = torch.quantize_per_tensor(
-            torch.randn(2, 3), scale=0.01, zero_point=3, dtype=torch.quint8
-        )
-        traced = acc_tracer.trace(m, [input])
-        ph = weight_attr = bias_attr = linear = None
-        for node in traced.graph.nodes:
-            if node.op == "placeholder":
-                self.assertEqual(str(node.target), "a")
-                ph = node
-            elif node.op == "get_attr" and node.target == "linear_weight":
-                weight_attr = node
-            elif node.op == "get_attr" and node.target == "linear_bias":
-                bias_attr = node
-            elif node.op == "call_function":
-                self.assertEqual(node.target, acc_ops.quantized_linear)
-                self.assertEqual(node.kwargs["input"], ph)
-                self.assertEqual(node.kwargs["weight"], weight_attr)
-                self.assertEqual(node.kwargs["bias"], bias_attr)
-                linear = node
-            elif node.op == "output":
-                self.assertEqual(linear, node.args[0])
-            else:
-                self.fail(f"Unexpected node: {node.format_node()}")
-
-        self.assertTrue(torch.equal(m(input), traced(input)))
-
-    @parameterized.expand(
-        [
-            param("remove_exceptions_false", remove_exceptions=False),
-            param("remove_exceptions_true", remove_exceptions=True),
-        ]
-    )
-    def test_batch_norm(self, _, remove_exceptions):
-        """
-        Test that a batch norm is traced as expected, i.e. to the functional level
-        and with kwarg normalization. Note that we also expect to see a
-        ConditionalExceptionWrapper in the graph that the AST rewriter converted
-        from `if x: raise y`.
-        """
-
-        class TestModule(nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.bn = torch.nn.BatchNorm2d(2)
-
-            def forward(self, a: torch.Tensor) -> torch.Tensor:
-                return self.bn(a)
-
-        m = TestModule()
-        input = torch.randn(2, 2, 1, 1)
-        # Note: remove_exceptions is parameterized here. When it is False, we
-        # verify below that the exception wrapper was found in the graph.
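-        # BatchNorm2d's forward raises ValueError on bad input dims; the AST
-        # rewriter converts that `if x: raise y` into the call_module
-        # bn._conditional_exception_wrapper_ValueError checked for below.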
- traced = acc_tracer.trace( - m, - [input], - remove_exceptions=remove_exceptions, - ) - - ph = exception_wrapper = weight = bias = mean = var = bn = None - for node in traced.graph.nodes: - if node.op == "placeholder": - self.assertEqual(str(node.target), "a") - ph = node - elif node.op == "get_attr" and node.target == "bn.weight": - weight = node - elif node.op == "get_attr" and node.target == "bn.bias": - bias = node - elif node.op == "get_attr" and node.target == "bn.running_mean": - mean = node - elif node.op == "get_attr" and node.target == "bn.running_var": - var = node - elif node.op == "call_function" and node.target == acc_ops.batch_norm: - # Note: Normalization called from acc_tracer means we use - # all kwargs. - self.assertEqual(node.kwargs["input"], ph) - self.assertEqual(node.kwargs["weight"], weight) - self.assertEqual(node.kwargs["bias"], bias) - self.assertEqual(node.kwargs["running_mean"], mean) - self.assertEqual(node.kwargs["running_var"], var) - bn = node - elif ( - node.op == "call_module" - and node.target == "bn._conditional_exception_wrapper_ValueError" - ): - exception_wrapper = node - elif node.op == "output": - self.assertEqual(bn, node.args[0]) - - self.assertTrue(remove_exceptions or exception_wrapper is not None) - - self.assertTrue(torch.equal(m(input), traced(input))) - - def test_remove_asserts(self): - """ - Test that a Module with asserts has the asserts automatically removed, as - well as calls to a class method that should be dead. - """ - - class TestModule(nn.Module): - def __init__(self): - super().__init__() - - def _test_method(self, a): - return a - - def forward(self, a: torch.Tensor) -> torch.Tensor: - assert torch.equal(self._test_method(a), a) - return a - - m = TestModule() - input = torch.randn(10) - traced = acc_tracer.trace(m, [input], ast_rewriter_allow_list={TestModule}) - # Check we have no call_functions. If remove asserts didn't work - # correctly we would see a call to torch._assert, _test_method, and - # torch.equal. - for node in traced.graph.nodes: - self.assertFalse(node.op == "call_function") - - self.assertTrue(torch.equal(m(input), traced(input))) - - def test_no_rewrite_leaf_module(self): - """ - Test that when we supply a leaf module, we don't rewrite it - """ - - class TestChildModule(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, a: torch.Tensor) -> torch.Tensor: - return a.relu() - - class TestModule(nn.Module): - def __init__(self): - super().__init__() - self.child = TestChildModule() - - def forward(self, a: torch.Tensor) -> torch.Tensor: - return self.child(a) + self.child(a) - - m = TestModule() - input = torch.randn(10) - traced = acc_tracer.trace(m, [input], leaf_module_list={TestChildModule}) - # trace it again just in case - traced = acc_tracer.trace(traced, [input], leaf_module_list={TestChildModule}) - - for _, m in traced.named_children(): - self.assertFalse("__AccRewrittenModule" in str(type(m)), str(type(m))) - - def test_sequential(self): - """ - Test that the tracer works for torch.nn.Sequential. 
-        """
-
-        class TestModule(nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.model = nn.Sequential(nn.Sigmoid(), nn.ReLU())
-
-            def forward(self, a: torch.Tensor) -> torch.Tensor:
-                return self.model(a)
-
-        m = TestModule()
-        input = torch.randn(10)
-        traced = acc_tracer.trace(m, [input])
-
-        for node in traced.graph.nodes:
-            if node.op == "call_function":
-                is_sigmoid = node.target == acc_ops.sigmoid
-                is_relu = node.target == acc_ops.relu
-                self.assertTrue(is_sigmoid or is_relu)
-            else:
-                self.assertTrue(node.op == "placeholder" or node.op == "output")
-
-        self.assertTrue(torch.equal(m(input), traced(input)))
-
-    def test_unsqueeze(self):
-        """
-        Test that torch.unsqueeze is traced correctly.
-        """
-        self._make_acc_op_function_test(
-            acc_ops.unsqueeze,
-            torch.unsqueeze,
-            validate_same_kwargs=False,
-            dim=1,
-        )
-
-    def test_stack(self):
-        """
-        Test that torch.stack is traced correctly.
-        """
-
-        class TestModule(torch.nn.Module):
-            def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
-                return torch.stack((a, b), dim=1)
-
-        a, b = torch.randn(4, 5, 6), torch.randn(4, 5, 6)
-        mod = TestModule()
-        traced = acc_tracer.trace(mod, [a, b])
-        self.assertTrue(torch.equal(mod(a, b), traced(a, b)))
-
-        ph_a = ph_b = unsqueeze_a = unsqueeze_b = cat_node = None
-        for node in traced.graph.nodes:
-            if node.op == "placeholder":
-                if str(node.target) == "a":
-                    ph_a = node
-                else:
-                    self.assertTrue(str(node.target) == "b")
-                    ph_b = node
-            elif node.op == "call_function":
-                if node.target == acc_ops.unsqueeze:
-                    if node.kwargs["input"] is ph_a:
-                        unsqueeze_a = node
-                    else:
-                        self.assertEqual(node.kwargs["input"], ph_b)
-                        unsqueeze_b = node
-                else:
-                    self.assertEqual(node.target, acc_ops.cat)
-                    self.assertEqual(node.kwargs["tensors"], [unsqueeze_a, unsqueeze_b])
-                    cat_node = node
-            elif node.op == "output":
-                self.assertEqual(cat_node, node.args[0])
-            else:
-                self.fail(f"Unexpected node: {node.format_node()}")
-
-    def test_no_raise(self):
-        """
-        Test that we can trace `if x: raise y(msg)` when the raise isn't executed.
-        """
-
-        class TestModule(nn.Module):
-            def __init__(self):
-                super().__init__()
-
-            def forward(self, a, b):
-                if torch.equal(a, b):
-                    raise AssertionError("a equaled b!")
-                return a
-
-        m = TestModule()
-        in_a, in_b = torch.randn(5), torch.randn(5)
-        traced = acc_tracer.trace(
-            m,
-            [in_a, in_b],
-            remove_exceptions=False,
-            use_acc_normalization=False,
-            ast_rewriter_allow_list={TestModule},
-        )
-
-        # Verify the structure of the graph, including the existence of the
-        # exception_wrapper.
-        ph_a = exception_wrapper = None
-        for node in traced.graph.nodes:
-            if node.op == "placeholder":
-                if str(node.target) == "a":
-                    ph_a = node
-                else:
-                    self.assertTrue(str(node.target) == "b")
-            elif node.op == "call_module":
-                self.assertEqual(
-                    node.target, "_conditional_exception_wrapper_AssertionError"
-                )
-                exception_wrapper = node
-            elif node.op == "output":
-                self.assertEqual(ph_a, node.args[0])
-
-        self.assertTrue(exception_wrapper is not None)
-
-        self.assertTrue(torch.equal(m(in_a, in_b), traced(in_a, in_b)))
-
-    def test_yes_raise(self):
-        """
-        Test that we can trace `if x: raise y(msg)` when the raise is executed.
-        """
-        err_str = "a equaled b!"
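-        # The message is stored on the module; the checks below verify that the
-        # rewritten exception wrapper re-raises with exactly this string.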
- - class TestModule(nn.Module): - def __init__(self): - super().__init__() - self.err_str = err_str - - def forward(self, a, b): - if torch.equal(a, b): - raise RuntimeError(self.err_str) - return a - - m = TestModule() - # Note: We must use different inputs here in order for shape_prop to work, as - # otherwise the exception is thrown (as expected/checked below). - in_a, in_b = torch.randn(5), torch.randn(5) - traced = acc_tracer.trace( - m, - [in_a, in_b], - remove_exceptions=False, - ast_rewriter_allow_list={TestModule}, - ) - - # Verify the structure of the graph, including the existence of the - # exception_wrapper. - ph_a = exception_wrapper = None - for node in traced.graph.nodes: - if node.op == "placeholder": - if str(node.target) == "a": - ph_a = node - else: - self.assertTrue(str(node.target) == "b") - elif node.op == "call_module": - self.assertEqual( - node.target, "_conditional_exception_wrapper_RuntimeError" - ) - exception_wrapper = node - elif node.op == "output": - self.assertEqual(ph_a, node.args[0]) - - self.assertTrue(exception_wrapper is not None) - - def test(mod): - try: - # Note: Use the same input here to ensure the exception is thrown. - mod(in_a, in_a) - self.fail("Shouldn't get here because exception should be thrown.") - except RuntimeError as e: - self.assertEqual(err_str, str(e)) - - test(m) - test(traced) - - def test_remove_raise(self): - """ - Test that we can trace `if x: raise y(msg)` and then remove the exception_wrapper. - """ - - class TestModule(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, a, b): - if torch.equal(a, b): - raise AssertionError("a equaled b!") - return a - - m = TestModule() - in_a, in_b = torch.randn(5), torch.randn(5) - traced = acc_tracer.trace( - m, - [in_a, in_b], - remove_exceptions=True, - ast_rewriter_allow_list={TestModule}, - ) - - # Verify the structure of the graph, including the existence of the - # exception_wrapper. - ph_a = None - for node in traced.graph.nodes: - if node.op == "placeholder": - if str(node.target) == "a": - ph_a = node - else: - self.assertTrue(str(node.target) == "b") - elif node.op == "output": - self.assertEqual(ph_a, node.args[0]) - else: - # Should not encounter any call_modules, e.g. to the - # exception_wrapper. - self.assertFalse(node.op == "call_module") - - # Note: Using input in_a twice for the tracer version, which would - # trigger the raise if it was still there. - self.assertTrue(torch.equal(m(in_a, in_b), traced(in_a, in_a))) - - def test_raise_no_message(self): - """ - Test that we can trace `if x: raise y` when `y` has no message. - """ - - class TestModule(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, a, b): - if torch.equal(a, b): - raise AssertionError - return a - - m = TestModule() - in_a, in_b = torch.randn(5), torch.randn(5) - traced = acc_tracer.trace( - m, - [in_a, in_b], - remove_exceptions=False, - use_acc_normalization=False, - ast_rewriter_allow_list={TestModule}, - ) - - # Verify the structure of the graph, including the existence of the - # exception_wrapper. 
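-        # Note: even though the raise has no message, the AST rewriter still
-        # emits a _conditional_exception_wrapper_AssertionError call_module,
-        # which is what the loop below looks for.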
- ph_a = exception_wrapper = None - for node in traced.graph.nodes: - if node.op == "placeholder": - if str(node.target) == "a": - ph_a = node - else: - self.assertTrue(str(node.target) == "b") - elif node.op == "call_module": - self.assertEqual( - node.target, "_conditional_exception_wrapper_AssertionError" - ) - exception_wrapper = node - elif node.op == "output": - self.assertEqual(ph_a, node.args[0]) - - self.assertTrue(exception_wrapper is not None) - self.assertTrue(torch.equal(m(in_a, in_b), traced(in_a, in_b))) - - def test_quantized_add(self): - """ - Test that a quantized_add and acc_ops.quantize_per_tensor are traced as expected, - verifying the acc_out_tys are set as expected. - """ - - class TestModule(nn.Module): - def __init__(self): - super().__init__() - self.q_input = torch.nn.quantized.Quantize( - scale=1.0 / 128, zero_point=5, dtype=torch.quint8 - ) - self.q_other = torch.nn.quantized.Quantize( - scale=1.0 / 128, zero_point=10, dtype=torch.quint8 - ) - - def forward(self, input: torch.Tensor, other: torch.Tensor) -> torch.Tensor: - return torch.ops.quantized.add( - self.q_input(input), - self.q_other(other), - scale=0.05, - zero_point=1, - ) - - m = TestModule() - input, other = torch.randn(2, 3, 4), torch.randn(2, 3, 4) - traced = acc_tracer.trace(m, [input, other]) - - input_ph = other_ph = q_input = q_other = q_add = None - for node in traced.graph.nodes: - if node.op == "placeholder": - if str(node.target) == "input": - input_ph = node - else: - self.assertTrue(str(node.target) == "other") - other_ph = node - elif ( - node.op == "call_function" - and node.target == acc_ops.quantize_per_tensor - ): - qparams = { - "scale": 1.0 / 128, - "zero_point": 5, - } - expected_md = acc_utils.build_raw_tensor_meta( - dtype=torch.quint8, - qparams=qparams, - ) - if node.kwargs["input"] == input_ph: - q_input = node - else: - self.assertTrue(node.kwargs["input"] == other_ph) - q_other = node - qparams_copy = qparams.copy() - qparams_copy["zero_point"] = 10 - expected_md = expected_md._replace(qparams=qparams_copy) - self.assertEqual(node.kwargs["acc_out_ty"], expected_md) - elif node.op == "call_function" and node.target == acc_ops.quantized_add: - self.assertEqual(node.kwargs["input"], q_input) - self.assertEqual(node.kwargs["other"], q_other) - qparams = { - "scale": 0.05, - "zero_point": 1, - } - expected_md = acc_utils.build_raw_tensor_meta(qparams=qparams) - self.assertEqual(node.kwargs["acc_out_ty"], expected_md) - q_add = node - elif node.op == "output": - self.assertEqual(q_add, node.args[0]) - else: - self.fail(f"Unexpected node: {node.format_node()}") - - self.assertTrue(torch.equal(m(input, other), traced(input, other))) - - def test_quantized_mul(self): - """ - Test that a quantized_mul and acc_ops.quantize_per_tensor are traced as expected, - verifying the acc_out_tys are set as expected. 
-        """
-
-        class TestModule(nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.q_input = torch.nn.quantized.Quantize(
-                    scale=1.0 / 128, zero_point=5, dtype=torch.quint8
-                )
-                self.q_other = torch.nn.quantized.Quantize(
-                    scale=1.0 / 128, zero_point=10, dtype=torch.quint8
-                )
-
-            def forward(self, input: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
-                return torch.ops.quantized.mul(
-                    self.q_input(input),
-                    self.q_other(other),
-                    scale=0.05,
-                    zero_point=1,
-                )
-
-        m = TestModule()
-        input, other = torch.randn(2, 3, 4), torch.randn(2, 3, 4)
-        traced = acc_tracer.trace(m, [input, other])
-
-        input_ph = other_ph = q_input = q_other = q_mul = None
-        for node in traced.graph.nodes:
-            if node.op == "placeholder":
-                if str(node.target) == "input":
-                    input_ph = node
-                else:
-                    self.assertTrue(str(node.target) == "other")
-                    other_ph = node
-            elif (
-                node.op == "call_function"
-                and node.target == acc_ops.quantize_per_tensor
-            ):
-                qparams = {
-                    "scale": 1.0 / 128,
-                    "zero_point": 5,
-                }
-                expected_md = acc_utils.build_raw_tensor_meta(
-                    dtype=torch.quint8,
-                    qparams=qparams,
-                )
-                if node.kwargs["input"] == input_ph:
-                    q_input = node
-                else:
-                    self.assertTrue(node.kwargs["input"] == other_ph)
-                    q_other = node
-                    qparams_copy = qparams.copy()
-                    qparams_copy["zero_point"] = 10
-                    expected_md = expected_md._replace(qparams=qparams_copy)
-                self.assertEqual(node.kwargs["acc_out_ty"], expected_md)
-            elif node.op == "call_function" and node.target == acc_ops.quantized_mul:
-                self.assertEqual(node.kwargs["input"], q_input)
-                self.assertEqual(node.kwargs["other"], q_other)
-                qparams = {
-                    "scale": 0.05,
-                    "zero_point": 1,
-                }
-                expected_md = acc_utils.build_raw_tensor_meta(qparams=qparams)
-                self.assertEqual(node.kwargs["acc_out_ty"], expected_md)
-                q_mul = node
-            elif node.op == "output":
-                self.assertEqual(q_mul, node.args[0])
-            else:
-                self.fail(f"Unexpected node: {node.format_node()}")
-
-        self.assertTrue(torch.equal(m(input, other), traced(input, other)))
-
-    def test_cat(self):
-        """
-        Test that torch.cat is traced correctly.
-        """
-
-        class TestModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-
-            def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
-                return torch.cat([a, a, b], 0)
-
-        m = TestModule()
-        a, b = torch.randn(2, 2), torch.randn(2, 2)
-        traced = acc_tracer.trace(m, (a, b))
-
-        ph_a = ph_b = cat = None
-        for node in traced.graph.nodes:
-            if node.op == "placeholder":
-                if str(node.target) == "a":
-                    ph_a = node
-                else:
-                    self.assertTrue(str(node.target) == "b")
-                    ph_b = node
-            elif node.op == "call_function":
-                self.assertEqual(node.target, acc_ops.cat)
-                self.assertEqual(node.kwargs["tensors"][0], ph_a)
-                self.assertEqual(node.kwargs["tensors"][1], ph_a)
-                self.assertEqual(node.kwargs["tensors"][2], ph_b)
-                self.assertEqual(node.kwargs["dim"], 0)
-                cat = node
-            elif node.op == "output":
-                self.assertEqual(cat, node.args[0])
-            else:
-                self.fail(f"Unexpected node: {node.format_node()}")
-
-        self.assertTrue(torch.equal(m(a, b), traced(a, b)))
-
-    def test_square(self):
-        """
-        Test that torch.square is traced correctly (it is normalized to acc_ops.mul).
-        """
-        self._make_acc_op_function_test(acc_ops.mul, torch.square)
-
-    def test_reshape(self):
-        """
-        Test that torch.reshape is traced correctly.
-        """
-        self._make_acc_op_function_test(acc_ops.reshape, torch.reshape, (1, -1))
-        # arg = (1, -1)
-        self._make_acc_op_function_test(acc_ops.reshape, lambda x: x.reshape(1, -1))
-        # arg = ((1, -1))
-        self._make_acc_op_function_test(acc_ops.reshape, lambda x: x.reshape((1, -1)))
-
-    def test_transpose(self):
-        """
-        Test that torch.transpose is traced correctly.
-        """
-        self._make_acc_op_function_test(
-            acc_ops.permute, lambda x: torch.transpose(x, 1, 0)
-        )
-
-        class TestModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-
-            def forward(self, a: torch.Tensor) -> torch.Tensor:
-                x = len(a.shape) - 2
-                y = len(a.shape) - 1
-                return a.transpose(x, y)
-
-        m = TestModule()
-        m.eval()
-
-        a = torch.randn(2, 3, 4, 5)
-        traced = acc_tracer.trace(m, [a])
-
-        ph_a = permute = None
-        for node in traced.graph.nodes:
-            if node.op == "placeholder":
-                ph_a = node
-            elif node.op == "call_function":
-                self.assertEqual(node.target, acc_ops.permute)
-                self.assertEqual(node.kwargs["input"], ph_a)
-                self.assertEqual(node.kwargs["permutation"], [0, 1, 3, 2])
-                permute = node
-            elif node.op == "output":
-                self.assertEqual(permute, node.args[0])
-            else:
-                self.fail(f"Unexpected node: {node.format_node()}")
-
-        self.assertTrue(torch.equal(m(a), traced(a)))
-
-    def test_permute(self):
-        """
-        Test that torch.permute is traced correctly.
-        """
-
-        def torch_permute(a, *dim):
-            return a.permute(*dim)
-
-        self._make_acc_op_function_test(acc_ops.permute, torch_permute, 1, 0)
-
-    def test_min_full_reduce(self):
-        """
-        Test that a full-reduce torch.min (no dim argument) is traced correctly.
-        """
-        self._make_acc_op_function_test(acc_ops.min_full_reduce, torch.min)
-
-    def test_matmul(self):
-        """
-        Test that torch.matmul is traced correctly.
-        """
-
-        class TestModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-
-            def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
-                return torch.matmul(a, b)
-
-        m = TestModule()
-        a, b = torch.randn(2, 2), torch.randn(2, 2)
-        traced = acc_tracer.trace(m, [a, b])
-
-        ph_a = ph_b = matmul = None
-        for node in traced.graph.nodes:
-            if node.op == "placeholder":
-                if str(node.target) == "a":
-                    ph_a = node
-                else:
-                    self.assertTrue(str(node.target) == "b")
-                    ph_b = node
-            elif node.op == "call_function":
-                self.assertEqual(node.target, acc_ops.matmul)
-                self.assertEqual(node.kwargs["input"], ph_a)
-                self.assertEqual(node.kwargs["other"], ph_b)
-                matmul = node
-            elif node.op == "output":
-                self.assertEqual(matmul, node.args[0])
-            else:
-                self.fail(f"Unexpected node: {node.format_node()}")
-
-        self.assertTrue(torch.equal(m(a, b), traced(a, b)))
-
-    def test_bmm(self):
-        self._make_acc_op_function_test(
-            acc_ops.matmul, lambda x: torch.bmm(x, x), input_shape=(2, 4, 4)
-        )
-
-    def test_tile(self):
-        return self._make_acc_op_function_test(
-            acc_ops.tile, lambda x: torch.tile(x, (2, 1, 2)), input_shape=(1, 2)
-        )
-
-    def test_dropout(self):
-        self._make_acc_op_function_test(
-            None,
-            lambda x: nn.functional.dropout(x, training=False),
-            input_shape=(1, 2, 3),
-        )
-
-    def test_stochastic_depth(self):
-        self._make_acc_op_function_test(
-            None,
-            lambda x, p, mode, training: torchvision.ops.stochastic_depth(
-                x, p=p, mode=mode, training=training
-            ),
-            input_shape=(1, 2, 3),
-            p=0.5,
-            mode="row",
-            training=False,
-        )
-
-    def test_hardsigmoid(self):
-        self._make_acc_op_function_test(
-            acc_ops.hardsigmoid,
-            lambda x: nn.functional.hardsigmoid(x),
-            input_shape=(3, 4, 5),
-        )
-
-    def test_hardtanh(self):
-        self._make_acc_op_function_test(
-            acc_ops.hardtanh,
-            lambda x: nn.functional.hardtanh(x),
-            input_shape=(3, 4, 5),
-        )
-
-    def test_hardswish(self):
-        class TestModule(nn.Module):
-            def forward(self, x: torch.Tensor) -> torch.Tensor:
-                y = nn.functional.hardswish(x)
-                return y
-
-        m = TestModule()
-        x = torch.randn(3, 4, 5)
-        traced = acc_tracer.trace(m, [x])
-        # hardswish(x) is decomposed by the tracer into hardsigmoid(x) * x.
-        ph_x = hardsigmoid_y = res_y = None
-        for node in traced.graph.nodes:
-            if node.op == "placeholder":
-                ph_x = node
-            elif node.op == "call_function" and node.target == acc_ops.hardsigmoid:
-                hardsigmoid_y = node
-                self.assertEqual(node.kwargs["input"], ph_x)
-            elif node.op == "call_function" and node.target == acc_ops.mul:
-                res_y = node
-                self.assertEqual(node.kwargs["input"], hardsigmoid_y)
-                self.assertEqual(node.kwargs["other"], ph_x)
-            elif node.op == "output":
-                self.assertEqual(node.args[0], res_y)
-            else:
-                self.fail(f"Unexpected node: {node.format_node()}")
-
-        ref = m(x)
-        res = traced(x)
-        torch.testing.assert_close(ref, res)
-
-    def test_add_with_alpha(self):
-        """
-        Test that normalization works for torch add with alpha, which requires special
-        normalization handling.
-        """
-
-        class TestModule(nn.Module):
-            def __init__(self):
-                super().__init__()
-
-            def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
-                a1 = torch.add(a, b)
-                a2 = torch.add(a, b, alpha=1.0)
-                a3 = torch.add(a, b, alpha=0.5)
-                return a1, a2, a3
-
-        m = TestModule()
-        input_a = torch.randn(2, 3)
-        input_b = torch.randn(2, 3)
-        traced = acc_tracer.trace(m, [input_a, input_b])
-
-        ph_a = ph_b = add_1 = add_2 = add_3 = mul = None
-        for node in traced.graph.nodes:
-            if node.op == "placeholder":
-                if str(node.target) == "a":
-                    ph_a = node
-                elif str(node.target) == "b":
-                    ph_b = node
-                else:
-                    self.fail(f"Unexpected placeholder {node.target}.")
-            elif node.op == "call_function" and node.target == acc_ops.mul:
-                mul = node
-                self.assertEqual(node.kwargs["input"], ph_b)
-                self.assertEqual(node.kwargs["other"], 0.5)
-            elif node.op == "call_function" and node.target == acc_ops.add:
-                if add_1 is None:
-                    add_1 = node
-                    self.assertEqual(node.kwargs["input"], ph_a)
-                    self.assertEqual(node.kwargs["other"], ph_b)
-                elif add_2 is None:
-                    add_2 = node
-                    self.assertEqual(node.kwargs["input"], ph_a)
-                    self.assertEqual(node.kwargs["other"], ph_b)
-                elif add_3 is None:
-                    add_3 = node
-                    self.assertEqual(node.kwargs["input"], ph_a)
-                    self.assertEqual(node.kwargs["other"], mul)
-                else:
-                    self.fail(f"Unexpected add: {node.format_node()}")
-            elif node.op == "output":
-                self.assertEqual(node.args[0][0], add_1)
-                self.assertEqual(node.args[0][1], add_2)
-                self.assertEqual(node.args[0][2], add_3)
-            else:
-                self.fail(f"Unexpected node: {node.format_node()}")
-
-        ref = m(input_a, input_b)
-        res = traced(input_a, input_b)
-        self.assertTrue(torch.equal(ref[0], res[0]))
-        self.assertTrue(torch.equal(ref[1], res[1]))
-        self.assertTrue(torch.equal(ref[2], res[2]))
-
-    def test_leaf_module_list(self):
-        """
-        Test leaf_module_list is working properly.
- """ - - class LeafModule(nn.Module): - def forward(self, x): - return x - - class TestModule(nn.Module): - def __init__(self): - super().__init__() - self.mod = LeafModule() - - def forward(self, x): - return self.mod(x) - - x = torch.randn(1, 1) - mod = TestModule() - acc_mod = acc_tracer.trace( - mod, - [x], - leaf_module_list={LeafModule}, - ) - ph = leaf_module = None - for node in acc_mod.graph.nodes: - if node.op == "placeholder": - ph = node - elif node.op == "call_module": - leaf_module = node - self.assertEqual(leaf_module.target, "mod") - self.assertEqual(leaf_module.args[0], ph) - elif node.op == "output": - self.assertEqual(node.args[0], leaf_module) - else: - self.fail(f"Unexpected node: {node.format_node()}") - self.assertTrue(torch.equal(mod(x), acc_mod(x))) - - def test_sign(self): - self._make_acc_op_function_test(acc_ops.sign, torch.sign) - - def test_relu(self): - self._make_acc_op_function_test(acc_ops.relu, torch.relu) - - def test_leaky_relu(self): - self._make_acc_op_function_test( - acc_ops.leaky_relu, torch.nn.functional.leaky_relu - ) - - def test_elu(self): - self._make_acc_op_function_test(acc_ops.elu, torch.nn.functional.elu) - - def test_selu(self): - self._make_acc_op_function_test(acc_ops.selu, torch.nn.functional.selu) - - def test_softsign(self): - self._make_acc_op_function_test(acc_ops.softsign, torch.nn.functional.softsign) - - def test_sigmoid(self): - self._make_acc_op_function_test(acc_ops.sigmoid, torch.sigmoid) - - def test_sin(self): - self._make_acc_op_function_test(acc_ops.sin, torch.sin) - - def test_cos(self): - self._make_acc_op_function_test(acc_ops.cos, torch.cos) - - def test_tan(self): - self._make_acc_op_function_test(acc_ops.tan, torch.tan) - - def test_sinh(self): - self._make_acc_op_function_test(acc_ops.sinh, torch.sinh) - - def test_cosh(self): - self._make_acc_op_function_test(acc_ops.cosh, torch.cosh) - - def test_tanh(self): - self._make_acc_op_function_test(acc_ops.tanh, torch.tanh) - - def test_asin(self): - self._make_acc_op_function_test(acc_ops.asin, torch.asin) - - def test_acos(self): - self._make_acc_op_function_test(acc_ops.acos, torch.acos) - - def test_atan(self): - self._make_acc_op_function_test(acc_ops.atan, torch.atan) - - def test_exp(self): - self._make_acc_op_function_test(acc_ops.exp, torch.exp) - - def test_log(self): - self._make_acc_op_function_test(acc_ops.log, torch.log) - - def test_sqrt(self): - self._make_acc_op_function_test(acc_ops.sqrt, torch.sqrt) - - def test_reciprocal(self): - self._make_acc_op_function_test(acc_ops.reciprocal, torch.reciprocal) - - def test_abs(self): - self._make_acc_op_function_test(acc_ops.abs, torch.abs) - - def test_neg(self): - self._make_acc_op_function_test(acc_ops.neg, torch.neg) - - def test_floor(self): - self._make_acc_op_function_test(acc_ops.floor, torch.floor) - - def test_ceil(self): - self._make_acc_op_function_test(acc_ops.ceil, torch.ceil) - - def test_softmax(self): - self._make_acc_op_function_test(acc_ops.softmax, torch.nn.functional.softmax) - - def test_tensor_squeeze(self): - self._make_acc_op_function_test(acc_ops.squeeze, lambda x: x.squeeze()) - - def test_torch_squeeze(self): - self._make_acc_op_function_test(acc_ops.squeeze, lambda x: torch.squeeze(x)) - - def test_operator_mul(self): - self._make_acc_op_function_test(acc_ops.mul, lambda x: x * 7) - - def test_torch_mul(self): - self._make_acc_op_function_test(acc_ops.mul, lambda x: torch.mul(x, 7)) - - def test_torch_isinf(self): - self._make_acc_op_function_test(acc_ops.isinf, torch.isinf) - - 
def test_torch_any(self): - self._make_acc_op_function_test(acc_ops.any, torch.any) - - def test_div(self): - self._make_acc_op_function_test(acc_ops.div, lambda x: torch.div(x, 2)) - self._make_acc_op_function_test(acc_ops.div, lambda x: x / 2) - - def test_fmod(self): - self._make_acc_op_function_test(acc_ops.fmod, lambda x: torch.fmod(x, 1.3)) - self._make_acc_op_function_test(acc_ops.fmod, lambda x: torch.fmod(x, -0.4)) - - def test_floor_div(self): - self._make_acc_op_function_test( - acc_ops.floor_div, lambda x: torch.div(x, 2, rounding_mode="floor") - ) - - def test_trunc_div(self): - self._make_acc_op_function_test( - acc_ops.trunc_div, lambda x: torch.div(x, 2, rounding_mode="trunc") - ) - # does not behave the same as floor_divide - # self._make_acc_op_function_test( - # acc_ops.trunc_div, lambda x: torch.floor_divide(x, 2) - # ) - - def test_view(self): - """ - Test that Tensor.view is traced correctly. - """ - - self._make_acc_op_function_test(acc_ops.reshape, lambda x: x.view(1, -1)) - self._make_acc_op_function_test(acc_ops.reshape, lambda x: x.view([1, -1])) - - def test_narrow(self): - """ - Test that torch.narrow is traced correctly. - """ - return self._make_acc_op_function_test( - acc_ops.slice_tensor, - torch.narrow, - validate_same_kwargs=False, - dim=1, - start=1, - length=2, - ) - - def test_pow(self): - self._make_acc_op_function_test(acc_ops.pow, torch.pow, exponent=2) - - def test_numel(self): - class TestModule(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, a): - return torch.numel(a) - - m = TestModule() - a = torch.randn(2, 1, 4) - traced = acc_tracer.trace(m, [a]) - - ph_a = numel = None - for node in traced.graph.nodes: - if node.op == "placeholder": - self.assertTrue(node.target == "a") - ph_a = node - elif node.op == "call_function" and node.target == acc_ops.numel: - numel = node - self.assertTrue(numel.kwargs["input"] is ph_a) - elif node.op == "output": - self.assertEqual(node.args[0], numel) - else: - self.fail(f"Unexpected node: {node.format_node()}") - - ref = m(a) - res = traced(a) - self.assertEqual(ref, res) - - def test_size(self): - class TestModule(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, a): - idx = a.size(1) - return a.shape[idx] - - m = TestModule() - a = torch.randn(2, 1, 4) - traced = acc_tracer.trace(m, [a]) - - ph_a = size_1 = size_2 = getitem_1 = getitem_2 = None - for node in traced.graph.nodes: - if node.op == "placeholder": - self.assertTrue(node.target == "a") - ph_a = node - elif node.op == "call_function" and node.target == acc_ops.size: - if size_1: - size_2 = node - self.assertTrue(size_2.kwargs["input"] is ph_a) - else: - size_1 = node - self.assertTrue(size_1.kwargs["input"] is ph_a) - elif node.op == "call_function" and node.target == acc_ops.getitem: - if getitem_1: - getitem_2 = node - self.assertTrue(getitem_2.kwargs["idx"] == getitem_1) - self.assertTrue(getitem_2.kwargs["input"] == size_2) - else: - getitem_1 = node - self.assertTrue(getitem_1.kwargs["idx"] == 1) - self.assertTrue(getitem_1.kwargs["input"] == size_1) - elif node.op == "output": - self.assertEqual(node.args[0], getitem_2) - else: - self.fail(f"Unexpected node: {node.format_node()}") - - ref = m(a) - res = traced(a) - self.assertEqual(ref, res) - - def test_getattr_named_tuple(self): - """ - Test that call_function getattr on namedtuples is - traced correctly. 
- """ - - class TestNamedTuple(NamedTuple): - foo: torch.Tensor - bar: torch.Tensor - - class TestModule(nn.Module): - def forward(self, a: TestNamedTuple): - return a.foo + a.bar - - m = TestModule() - a = TestNamedTuple(torch.randn(2, 2), torch.randn(2, 2)) - traced = acc_tracer.trace(m, [a]) - - ph_a = getitem_1 = getitem_2 = add = None - for node in traced.graph.nodes: - if node.op == "placeholder": - self.assertEqual(node.target, "a") - ph_a = node - - elif node.op == "call_function" and node.target == acc_ops.getitem: - if getitem_1: - getitem_2 = node - self.assertEqual(getitem_2.kwargs["idx"], 1) - else: - getitem_1 = node - self.assertEqual(getitem_1.kwargs["idx"], 0) - - self.assertEqual(node.kwargs["input"], ph_a) - - elif node.op == "call_function" and node.target == acc_ops.add: - self.assertEqual(node.kwargs["input"], getitem_1) - self.assertEqual(node.kwargs["other"], getitem_2) - add = node - - elif node.op == "output": - self.assertEqual(node.args[0], add) - - else: - self.fail(f"Unexpected node: {node.format_node()}") - - ref = m(a) - res = traced(a) - self.assertTrue(torch.equal(ref, res)) - - def test_flatten(self): - """ - Test that torch.flatten is traced correctly. - """ - self._make_acc_op_function_test( - acc_ops.flatten, torch.flatten, start_dim=1, end_dim=1 - ) - self._make_acc_op_function_test(acc_ops.flatten, lambda x: x.flatten()) - - def test_topk_multi_output(self): - """ - Test that torch.topk multi outputs work. - """ - - class TestModule(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, a: torch.Tensor) -> torch.Tensor: - return torch.topk(a, 3)[1] - - m = TestModule() - input_a = torch.randn(10) - traced = acc_tracer.trace(m, [input_a]) - - ph_a = topk = getitem = None - for node in traced.graph.nodes: - if node.op == "placeholder" and str(node.target) == "a": - ph_a = node - elif node.op == "call_function" and node.target == acc_ops.topk: - topk = node - self.assertEqual(node.kwargs["input"], ph_a) - self.assertEqual(node.kwargs["k"], 3) - elif node.op == "call_function" and node.target == acc_ops.getitem: - getitem = node - self.assertEqual(node.kwargs["input"], topk) - self.assertEqual(node.kwargs["idx"], 1) - elif node.op == "output": - self.assertEqual(node.args[0], getitem) - else: - self.fail(f"Unexpected node: {node.format_node()}") - - self.assertTrue(torch.equal(m(input_a), traced(input_a))) - - def test_addmm_with_alpha_beta(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward( - self, input: torch.Tensor, a: torch.Tensor, b: torch.Tensor - ) -> torch.Tensor: - return torch.addmm(input, a, b, alpha=1.2, beta=1.1) - - m = TestModule() - input, a, b = torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2) - traced = acc_tracer.trace(m, [input, a, b]) - - ph_in = ph_a = ph_b = mm = add = mm_mul = add_mul = None - for node in traced.graph.nodes: - if node.op == "placeholder": - if str(node.target) == "a": - ph_a = node - elif str(node.target) == "b": - ph_b = node - else: - self.assertTrue(str(node.target) == "input") - ph_in = node - elif node.op == "call_function": - if node.target == acc_ops.matmul: - self.assertEqual(node.kwargs["input"], ph_a) - self.assertEqual(node.kwargs["other"], ph_b) - mm = node - elif node.target == acc_ops.add: - self.assertEqual(node.kwargs["input"], mm_mul) - self.assertEqual(node.kwargs["other"], add_mul) - add = node - elif mm_mul: - self.assertEqual(node.kwargs["input"], ph_in) - self.assertEqual(node.kwargs["other"], 1.1) - add_mul = 
node - else: - self.assertEqual(node.kwargs["input"], mm) - self.assertEqual(node.kwargs["other"], 1.2) - mm_mul = node - elif node.op == "output": - self.assertEqual(add, node.args[0]) - else: - self.fail(f"Unexpected node: {node.format_node()}") - - torch.testing.assert_close(m(input, a, b), traced(input, a, b)) - - def test_log1p(self): - class TestModule(torch.nn.Module): - def forward(self, input: torch.Tensor) -> torch.Tensor: - return torch.log1p(input) - - m = TestModule().eval() - input = torch.tensor([[1.2, 0.3, -0.4]]) - traced = acc_tracer.trace(m, [input]) - - ph_in = add = log = None - for node in traced.graph.nodes: - if node.op == "placeholder": - self.assertTrue(str(node.target) == "input") - ph_in = node - elif node.op == "call_function": - if node.target == acc_ops.add: - self.assertEqual(node.kwargs["input"], ph_in) - self.assertEqual(node.kwargs["other"], 1) - add = node - else: - self.assertEqual(node.target, acc_ops.log) - self.assertEqual(node.kwargs["input"], add) - log = node - elif node.op == "output": - self.assertEqual(log, node.args[0]) - else: - self.fail(f"Unexpected node: {node.format_node()}") - - torch.testing.assert_close(m(input), traced(input)) - - @parameterized.expand([(torch.float,), (torch.float16,)]) - def test_addmm(self, dtype): - class TestModule(torch.nn.Module): - def forward( - self, input: torch.Tensor, a: torch.Tensor, b: torch.Tensor - ) -> torch.Tensor: - return torch.addmm(input, a, b) - - m = TestModule() - input, a, b = ( - torch.randn(2, 2, dtype=dtype), - torch.randn(2, 2, dtype=dtype), - torch.randn(2, 2, dtype=dtype), - ) - traced = acc_tracer.trace(m, [input, a, b]) - - ph_in = ph_a = ph_b = mm = add = None - for node in traced.graph.nodes: - if node.op == "placeholder": - if str(node.target) == "a": - ph_a = node - elif str(node.target) == "b": - ph_b = node - else: - self.assertTrue(str(node.target) == "input") - ph_in = node - elif node.op == "call_function": - if node.target == acc_ops.matmul: - self.assertEqual(node.kwargs["input"], ph_a) - self.assertEqual(node.kwargs["other"], ph_b) - mm = node - else: - self.assertEqual(node.target, acc_ops.add) - self.assertEqual(node.kwargs["input"], mm) - self.assertEqual(node.kwargs["other"], ph_in) - add = node - elif node.op == "output": - self.assertEqual(add, node.args[0]) - else: - self.fail(f"Unexpected node: {node.format_node()}") - - for node in [ph_in, ph_a, ph_b, mm, add]: - self.assertEqual(acc_utils.get_tensor_meta(node).dtype, dtype) - - if dtype == torch.float: - self.assertTrue(torch.equal(m(input, a, b), traced(input, a, b))) - - def test_gelu(self): - return self._make_acc_op_function_test(acc_ops.gelu, torch.nn.functional.gelu) - - @parameterized.expand( - [ - (1, True), - (1, False), - (None, False), - ] - ) - def test_argmin(self, dim, keepdim): - class TestModule(torch.nn.Module): - def __init__(self, dim, keepdim): - super().__init__() - self.dim = dim - self.keepdim = keepdim - - def forward(self, input: torch.Tensor) -> torch.Tensor: - return torch.argmin(input, dim=self.dim, keepdim=self.keepdim) - - m = TestModule(dim, keepdim) - input = torch.randn(2, 2) - traced = acc_tracer.trace(m, [input]) - - ph_in = flatten = topk = getitem = squeeze = None - for node in traced.graph.nodes: - if node.op == "placeholder": - self.assertTrue(str(node.target) == "input") - ph_in = node - elif node.op == "call_function": - if node.target == acc_ops.flatten: - self.assertEqual(node.kwargs["input"], ph_in) - flatten = node - elif node.target == acc_ops.topk: - 
self.assertEqual( - node.kwargs["input"], flatten if flatten else ph_in - ) - topk = node - elif node.target == acc_ops.getitem: - self.assertEqual(node.kwargs["input"], topk) - getitem = node - elif node.target == acc_ops.squeeze: - self.assertEqual(node.kwargs["input"], getitem) - squeeze = node - elif node.op == "output": - self.assertEqual(squeeze if squeeze else getitem, node.args[0]) - else: - self.fail(f"Unexpected node: {node.format_node()}") - if dim is None: - self.assertTrue(flatten is not None) - if not keepdim: - self.assertTrue(squeeze is not None) - self.assertTrue(torch.equal(m(input), traced(input))) - - def test_t(self): - """ - Test Tensor.t() is traced correctly. - """ - self._make_acc_op_function_test(acc_ops.permute, lambda x: x.t()) - self._make_acc_op_function_test( - acc_ops.permute, lambda x: x.t(), input_shape=(3,) - ) - - def test_split_size(self): - self._make_acc_op_function_test( - acc_ops.split, - torch.split, - validate_same_kwargs=False, - split_size_or_sections=2, - dim=1, - ) - - def test_split_sections(self): - class TestModule(torch.nn.Module): - def forward(self, input: torch.Tensor) -> torch.Tensor: - return torch.split(input, [2, 5, 3], 1) - - m = TestModule() - input = torch.randn(1, 10) - traced = acc_tracer.trace(m, [input]) - - ph_in = slice_node_0 = slice_node_1 = slice_node_2 = None - tuple_construct_node = None - for node in traced.graph.nodes: - if node.op == "placeholder": - self.assertTrue(str(node.target) == "input") - ph_in = node - elif node.op == "call_function": - if node.target == acc_ops.slice_tensor: - self.assertEqual(node.kwargs["input"], ph_in) - if slice_node_0: - if slice_node_1: - slice_node_2 = node - else: - slice_node_1 = node - else: - slice_node_0 = node - else: - self.assertEqual(node.target, acc_ops.tuple_construct) - self.assertEqual( - node.kwargs["tensors"], - (slice_node_0, slice_node_1, slice_node_2), - ) - tuple_construct_node = node - elif node.op == "output": - self.assertEqual(tuple_construct_node, node.args[0]) - else: - self.fail(f"Unexpected node: {node.format_node()}") - - ref_output = m(input) - output = traced(input) - for i, j in zip(ref_output, output): - self.assertTrue(torch.equal(i, j)) - - @parameterized.expand( - [ - ("neg_1", -1, 1, 3), - ("neg_2", -2, 1, 3), - ("neg_4", -4, 1, 1), - ] - ) - def test_negative_slicing(self, _, dim, start, length): - """ - Test that slicing with negative dims works. - """ - self._make_acc_op_function_test( - acc_ops.slice_tensor, - torch.narrow, - input_shape=(2, 3, 4, 5), - validate_same_kwargs=False, - dim=dim, - start=start, - length=length, - ) - - def test_list_input(self): - """ - Test that list inputs are traced correctly. 
- """ - - class TestModule(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, a: List[torch.Tensor]) -> torch.Tensor: - return a[0] + a[1] - - m = TestModule() - input = [torch.randn(2, 3), torch.randn(2, 3)] - traced = acc_tracer.trace(m, [input]) - - ph = getitem_0 = getitem_1 = add = None - for node in traced.graph.nodes: - if node.op == "placeholder": - self.assertEqual(str(node.target), "a") - ph = node - elif node.op == "call_function" and node.target == acc_ops.getitem: - self.assertTrue(node.kwargs["idx"] == 0 or node.kwargs["idx"] == 1) - if node.kwargs["idx"] == 0: - getitem_0 = node - else: - getitem_1 = node - elif node.op == "call_function": - self.assertEqual(node.target, acc_ops.add) - self.assertEqual(node.kwargs["input"], getitem_0) - self.assertEqual(node.kwargs["other"], getitem_1) - add = node - elif node.op == "output": - self.assertEqual(add, node.args[0]) - else: - self.fail(f"Unexpected node: {node.format_node()}") - - # Check the tensor ranks are correct given the input is a list. - self.assertIsInstance(ph.meta["tensor_rank"], list) - self.assertEqual(len(ph.meta["tensor_rank"]), 2) - self.assertEqual(getitem_0.meta["tensor_rank"], ph.meta["tensor_rank"][0]) - self.assertEqual(getitem_1.meta["tensor_rank"], ph.meta["tensor_rank"][1]) - - self.assertTrue(torch.equal(m(input), traced(input))) - - def test_dict_input(self): - """ - Test that dict inputs are traced correctly. - """ - - class TestModule(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, a: Dict[str, torch.Tensor]) -> torch.Tensor: - return a["foo"] + a["bar"] - - m = TestModule() - input = {"foo": torch.randn(2, 3), "bar": torch.randn(2, 3)} - traced = acc_tracer.trace(m, [input]) - - ph = getitem_0 = getitem_1 = add = None - for node in traced.graph.nodes: - if node.op == "placeholder": - self.assertEqual(str(node.target), "a") - ph = node - elif node.op == "call_function" and node.target == acc_ops.getitem: - self.assertTrue( - node.kwargs["idx"] == "foo" or node.kwargs["idx"] == "bar" - ) - if node.kwargs["idx"] == "foo": - getitem_0 = node - else: - getitem_1 = node - elif node.op == "call_function": - self.assertEqual(node.target, acc_ops.add) - self.assertEqual(node.kwargs["input"], getitem_0) - self.assertEqual(node.kwargs["other"], getitem_1) - add = node - elif node.op == "output": - self.assertEqual(add, node.args[0]) - else: - self.fail(f"Unexpected node: {node.format_node()}") - - # Check the tensor ranks are correct given the input is a dict. - self.assertIsInstance(ph.meta["tensor_rank"], dict) - self.assertEqual(len(ph.meta["tensor_rank"]), 2) - self.assertEqual(getitem_0.meta["tensor_rank"], ph.meta["tensor_rank"]["foo"]) - self.assertEqual(getitem_1.meta["tensor_rank"], ph.meta["tensor_rank"]["bar"]) - - self.assertTrue(torch.equal(m(input), traced(input))) - - def test_none_type_ret(self): - """ - Test that a NoneType is traced as expected. 
- """ - - class TestModule(nn.Module): - def __init__(self): - super().__init__() - - def forward( - self, a: torch.Tensor - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - return a + a, None - - m = TestModule() - input = torch.randn(1, 2, 3) - try: - traced = acc_tracer.trace( - m, - [input], - ) - except RuntimeError as e: - self.assertEqual( - "This error should not be triggered, as NoneType should be lowered without an issue", - str(e), - ) - ans1, _ = m(input) - ans2, _ = traced(input) - self.assertTrue(torch.equal(ans1, ans2)) - - def test_mobilenet_v3(self): - """ - Test that we can trace mobilenet v3 small and run/compare against the untraced version. - """ - m = torchvision.models.mobilenet_v3_small(pretrained=True) - self._make_model_unit_test(m, enable_allclose=True) - - def test_mobilenet_v2(self): - """ - Test that we can trace mobilenet v2 small and run/compare against the untraced version. - """ - m = torchvision.models.mobilenet_v2(pretrained=True) - self._make_model_unit_test(m) - - def test_vgg16(self): - """ - Test that we can trace vgg16 and run/compare against the untraced version. - """ - m = torchvision.models.vgg16(pretrained=True) - self._make_model_unit_test(m) - - def test_resnet18(self): - """ - Test that we can trace resnet18 and run/compare against the untraced version. - """ - m = torchvision.models.resnet18(pretrained=True) - self._make_model_unit_test(m) - - def test_resnext50_32x4d(self): - """ - Test that we can trace resnext and run/compare against the untraced version. - """ - m = torchvision.models.resnext50_32x4d(pretrained=True) - self._make_model_unit_test(m) - - def test_cumsum(self): - # Tests call_function version - self._make_acc_op_function_test(acc_ops.cumsum, torch.cumsum, dim=1) - self._make_acc_op_function_test( - acc_ops.cumsum, torch.cumsum, dim=1, dtype=torch.float - ) - - # Tests call_method version - class TestModule(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - - def forward(self, a: torch.Tensor) -> torch.Tensor: - return a.cumsum(dim=0) - - m = TestModule() - a = torch.rand(2, 2) - gm = acc_tracer.trace(m, [a]) - self.assertTrue(torch.equal(m(a), gm(a))) - - def test_chunk(self): - self._make_acc_op_function_test(acc_ops.chunk, torch.chunk, chunks=2, dim=0) - - def test_retrace_reshape(self): - """ - Retrace reshape to verify it's retraceable. 
-        """
-
-        class TestModule(torch.nn.Module):
-            def forward(self, a: torch.Tensor) -> torch.Tensor:
-                return a.reshape(a.size()[0], 1, 2)
-
-        m = TestModule()
-        a = torch.randn(2, 2)
-        gm = acc_tracer.trace(m, [a])
-        self.assertTrue(torch.equal(m(a), gm(a)))
-        gm_retrace = acc_tracer.trace(gm, [a])
-        self.assertTrue(torch.equal(m(a), gm_retrace(a)))
-
-    def test_index_select(self):
-        class TestModule(nn.Module):
-            def __init__(self, dim, index):
-                super().__init__()
-                self._dim = dim
-                self._index = index
-
-            def forward(self, a: torch.Tensor) -> torch.Tensor:
-                return torch.index_select(a, self._dim, self._index)
-
-        dim = 0
-        index = torch.tensor([1, 0])
-        m = TestModule(dim, index)
-        _input = [torch.randn(2, 3), torch.randn(2, 3)]
-        traced = acc_tracer.trace(m, _input)
-
-        # Use a distinct name for the captured get_attr node so it does not
-        # shadow the `index` tensor above.
-        ph = index_attr = index_select = None
-
-        for node in traced.graph.nodes:
-            if node.op == "placeholder":
-                self.assertEqual(str(node.target), "a")
-                ph = node
-            elif node.op == "call_function" and node.target == acc_ops.index_select:
-                self.assertTrue(node.kwargs["input"] == ph)
-                self.assertTrue(node.kwargs["index"] == index_attr)
-                self.assertTrue(node.kwargs["dim"] == dim)
-                index_select = node
-            elif node.op == "output":
-                self.assertEqual(index_select, node.args[0])
-            elif node.op == "get_attr":
-                # There should only be one const (get_attr) node.
-                self.assertTrue(index_attr is None)
-                index_attr = node
-            else:
-                self.fail(f"Unexpected node: {node.format_node()}")
-
-    def test_gather(self):
-        class TestModule(nn.Module):
-            def __init__(self, dim, index):
-                super().__init__()
-                self._dim = dim
-                self._index = index
-
-            def forward(self, a: torch.Tensor) -> torch.Tensor:
-                return torch.gather(a, self._dim, self._index)
-
-        dim = 0
-        index = torch.tensor([[1, 0], [0, 1]])
-        m = TestModule(dim, index)
-        _input = [torch.randn(2, 3), torch.randn(2, 3)]
-        traced = acc_tracer.trace(m, _input)
-
-        ph = index_attr = gather = None
-
-        for node in traced.graph.nodes:
-            if node.op == "placeholder":
-                self.assertEqual(str(node.target), "a")
-                ph = node
-            elif node.op == "call_function" and node.target == acc_ops.gather:
-                self.assertTrue(node.kwargs["input"] == ph)
-                self.assertTrue(node.kwargs["index"] == index_attr)
-                self.assertTrue(node.kwargs["dim"] == dim)
-                gather = node
-            elif node.op == "output":
-                self.assertEqual(gather, node.args[0])
-            elif node.op == "get_attr":
-                # There should only be one const (get_attr) node.
-                self.assertTrue(index_attr is None)
-                index_attr = node
-            else:
-                self.fail(f"Unexpected node: {node.format_node()}")
-
-    def test_where(self):
-        class TestModule(nn.Module):
-            def __init__(self):
-                super().__init__()
-
-            def forward(self, a, b, c):
-                return torch.where(a, b, c)
-
-        m = TestModule()
-        x = torch.randn(3, 2)
-        y = torch.ones(3, 2)
-        cond = x > 0
-        traced = acc_tracer.trace(m, [cond, x, y])
-
-        ph_a = where = None
-        ph_b = None
-        ph_c = None
-        for node in traced.graph.nodes:
-            if node.op == "placeholder":
-                if node.target == "a":
-                    ph_a = node
-                elif node.target == "b":
-                    ph_b = node
-                elif node.target == "c":
-                    ph_c = node
-            elif node.op == "call_function" and node.target == acc_ops.where:
-                where = node
-                self.assertTrue(where.kwargs["condition"] is ph_a)
-                self.assertTrue(where.kwargs["x"] is ph_b)
-                self.assertTrue(where.kwargs["y"] is ph_c)
-            elif node.op == "output":
-                self.assertEqual(node.args[0], where)
-            else:
-                self.fail(f"Unexpected node: {node.format_node()}")
-
-        ref = m(cond, x, y)
-        res = traced(cond, x, y)
-        self.assertTrue(torch.equal(ref, res))
-
-    @parameterized.expand(
-        [
-            ("sections divisible", 2, 0),
-            ("sections indivisible", 3, 0),
("indices list", [1, 3], 0), - ("indices tuple", (1, 3), 0), - ("indices tensor", torch.tensor([1, 3]), 0), - ("indices tensor dim1", torch.tensor([1, 3]), 1), - ("indices tensor dim2", torch.tensor([1, 3]), 2), - ("indices tensor long dim2", torch.tensor([1, 3, 5, 7]), 2), - ] - ) - def test_tensor_split(self, _, indices_or_sections, dim): - """ - Test that the tracer works for torch.tensor_split with indices and sections - """ - - class TestModule(nn.Module): - def __init__(self, indices_or_sections, dim): - super().__init__() - self._indices_or_sections = indices_or_sections - self._dim = dim - - def forward(self, a): - return torch.tensor_split(a, self._indices_or_sections, self._dim) - - m = TestModule(indices_or_sections, dim) - a = torch.randn(4, 8, 16) - traced = acc_tracer.trace(m, [a]) - - results = traced(a) - references = m(a) - for res, ref in zip(results, references): - self.assertTrue( - torch.equal(ref, res), f"Tensors at don't match {ref=} {res=}" - ) - - def test_inplace_raise(self): - """ - Test that encountering inplace is raised for exception - """ - - class TestModule(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, a): - a = a + 2 - a.sub_(3) - return a - - m = TestModule() - in_a = torch.randn(5) - try: - acc_tracer.trace( - m, - [in_a], - ) - self.fail("Shouldn't get here because exception should be thrown.") - except RuntimeError as e: - self.assertEqual( - "Tried to trace mutable operation sub_. FX only supports functional code", - str(e), - ) - - def test_repeat_interleave(self): - class TestModule(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return torch.repeat_interleave(x, 2, 1) - - # TODO: finish test later - m = TestModule() - x = torch.randn(3, 4) - traced = acc_tracer.trace(m, [x]) - ph_in = tile = size = getitem = unsqueeze = reshape = None - for node in traced.graph.nodes: - if node.op == "placeholder": - ph_in = node - elif node.op == "call_function": - if node.target == acc_ops.size: - self.assertEqual(node.kwargs["input"], ph_in) - size = node - elif node.target == acc_ops.getitem: - self.assertEqual(node.kwargs["input"], size) - getitem = node - elif node.target == acc_ops.reshape: - self.assertEqual(node.kwargs["input"], tile) - reshape = node - elif node.target == acc_ops.unsqueeze: - self.assertEqual(node.kwargs["input"], ph_in) - unsqueeze = node - elif node.target == acc_ops.tile: - self.assertEqual(node.kwargs["input"], unsqueeze) - tile = node - elif node.op == "output": - self.assertEqual(reshape, node.args[0]) - else: - self.fail(f"Unexpected node: {node.format_node()}") - if size is not None: - self.assertIsNotNone(getitem) - self.assertTrue(torch.equal(m(x), traced(x))) - - def test_acc_normalization_block_list(self): - class TestModule(nn.Module): - def forward(self, x: List[torch.Tensor]) -> torch.Tensor: - return x[0] + x[1] - - m = TestModule() - x = [torch.randn(1), torch.randn(1)] - traced = acc_tracer.trace( - m, [x], acc_normalization_block_list={("call_function", operator.getitem)} - ) - for node in traced.graph.nodes: - if "getitem" in node.name: - # Make sure we didn't convert to the acc version - self.assertEqual(node.target, operator.getitem) - - def test_detach(self): - class TestModule(nn.Module): - def forward(self, x: torch.Tensor) -> torch.Tensor: - return torch.detach(x) - - m = TestModule() - sample_inputs = [torch.randn(8)] - traced = acc_tracer.trace(m, sample_inputs) - - placeholder = output = None - for node in 
traced.graph.nodes: - if node.op == "placeholder": - assert placeholder is None - placeholder = node - elif node.op == "output": - assert output is None - output = node - else: - raise RuntimeError(f"Unexpected Node {node.format_node()}") - - self.assertIsNotNone(placeholder) - self.assertIsNotNone(output) - - self.assertTrue(torch.equal(m(*sample_inputs), traced(*sample_inputs))) - - def test_all_acc_ops_registered(self): - self.assertEqual( - acc_normalizer._acc_ops, - { - acc_ops.linear, - acc_ops.embedding, - acc_ops.max_pool1d, - acc_ops.max_pool2d, - acc_ops.max_pool3d, - acc_ops.flatten, - acc_ops.adaptive_avg_pool2d, - acc_ops.adaptive_avg_pool3d, - acc_ops.avg_pool1d, - acc_ops.avg_pool2d, - acc_ops.avg_pool3d, - acc_ops.add, - acc_ops.min_full_reduce, - acc_ops.min_dim_reduce, - acc_ops.minimum, - acc_ops.cat, - acc_ops.softmax, - acc_ops.sign, - acc_ops.permute, - acc_ops.matmul, - acc_ops.quantize_per_tensor, - acc_ops.quantize_per_channel, - acc_ops.quantized_add, - acc_ops.quantized_mul, - acc_ops.dequantize, - acc_ops.sub, - acc_ops.mul, - acc_ops.div, - acc_ops.fmod, - acc_ops.floor_div, - acc_ops.trunc_div, - acc_ops.pow, - acc_ops.relu, - acc_ops.prelu, - acc_ops.leaky_relu, - acc_ops.elu, - acc_ops.selu, - acc_ops.softsign, - acc_ops.tuple_construct, - acc_ops.unsqueeze, - acc_ops.sigmoid, - acc_ops.sum, - acc_ops.prod, - acc_ops.max_full_reduce, - acc_ops.max_dim_reduce, - acc_ops.maximum, - acc_ops.sinh, - acc_ops.cosh, - acc_ops.tanh, - acc_ops.asin, - acc_ops.acos, - acc_ops.atan, - acc_ops.exp, - acc_ops.log, - acc_ops.sqrt, - acc_ops.reciprocal, - acc_ops.abs, - acc_ops.neg, - acc_ops.floor, - acc_ops.ceil, - acc_ops.size, - acc_ops.split, - acc_ops.conv1d, - acc_ops.conv2d, - acc_ops.conv3d, - acc_ops.conv_transpose2d, - acc_ops.conv_transpose3d, - acc_ops.batch_norm, - acc_ops.embedding_bag, - acc_ops.embedding_bag_byte_rowwise_offsets, - acc_ops.embedding_bag_4bit_rowwise_offsets, - acc_ops.contiguous, - acc_ops.pad, - acc_ops.sin, - acc_ops.cos, - acc_ops.tan, - acc_ops.topk, - acc_ops.getitem, - acc_ops.squeeze, - acc_ops.tile, - acc_ops.reshape, - acc_ops.quantized_linear, - acc_ops.quantized_conv2d, - acc_ops.quantized_batch_norm2d, - acc_ops.to_dtype, - acc_ops.clamp, - acc_ops.layer_norm, - acc_ops.linalg_norm, - acc_ops.slice_tensor, - acc_ops.hardsigmoid, - acc_ops.mean, - acc_ops.hardtanh, - acc_ops.gelu, - acc_ops.cumsum, - acc_ops.chunk, - acc_ops.rescale_quantize_per_tensor, - acc_ops.rescale_quantize_per_channel, - acc_ops.nan_to_num, - acc_ops.expand, - acc_ops.masked_fill, - acc_ops.eq, - acc_ops.gt, - acc_ops.lt, - acc_ops.logical_or, - acc_ops.logical_xor, - acc_ops.gather, - acc_ops.index_select, - acc_ops.interpolate, - acc_ops.logical_and, - acc_ops.logical_not, - acc_ops.ne, - acc_ops.device, - acc_ops.numel, - acc_ops.where, - acc_ops.dtype, - acc_ops.isinf, - acc_ops.any, - acc_ops.tensor_split, - acc_ops.new_empty, - acc_ops.new_ones, - acc_ops.einsum, - acc_ops.as_strided, - acc_ops.var, - acc_ops.grid_sample, - acc_ops.xl_weight, - }, - ) diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/tracer/test_dispatch_tracer.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/tracer/test_dispatch_tracer.py deleted file mode 100644 index a066bc4413..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/tracer/test_dispatch_tracer.py +++ /dev/null @@ -1,245 +0,0 @@ -import copy -import unittest - -import torch -import torch._dynamo as torchdynamo - -import torch._dynamo.config -import torchvision -from functorch.experimental import 
functionalize -from torch._dynamo.optimizations import backends -from torch._dynamo.optimizations.normalize import normalize_ir - -from torch.library import Library -from torch_tensorrt.dynamo.fx_ts_compat.lower import compile -from torch_tensorrt.fx.tracer.dispatch_tracer.tracer import make_fx -from torch_tensorrt.dynamo.fx_ts_compat.utils import LowerPrecision, proxytensor_trace - -# TODO(ezyang): remove this after we properly support fake example inputs -torch._dynamo.config.DO_NOT_USE_legacy_non_fake_example_inputs = True - -torch.manual_seed(0) - -wrap_lib = Library("wrap", "DEF") -""" -There are two methods for setting leaf_module: leaf(op registration) and leaf(override call_module). -Only leaf(op registration) can work together with functionalize. -If you do not need functionalize, you can choose either of the leaf module methods. - -Test coverage: -ProxytensorTracerTest.test_leaf_operator_reg: python_key tracer + functionalize + leaf(op registration) -DispatchTracerTest.test_leaf_operator_reg: dispatch tracer + functionalize + leaf(op registration) -DispatchTracerTest.test_leaf: dispatch tracer + leaf(override call_module) -DispatchTracerTest.test_non_tensor_input: dispatch tracer -DispatchTracerTest.test_reference_copy: dispatch tracer + functionalize -DispatchTracerTest.test_reference_copy_torchdynamo: dispatch tracer + torchdynamo + functionalize -""" - - -class ProxytensorTracerTest(unittest.TestCase): - def test_leaf_operator_reg(self): - class Leaf(torch.nn.Module): - def forward(self, x, y): - return x + y + torch.nn.Parameter(torch.ones(5)) - - leaf = Leaf() - wrap_lib.define("wrapped_foo(Tensor x, Tensor y) -> Tensor") - wrap_lib.impl("wrapped_foo", leaf, "CPU") - - class Bar(torch.nn.Module): - def __init__(self): - super(Bar, self).__init__() - self.foo = torch.ops.wrap.wrapped_foo - self.other = torch.nn.Parameter(torch.ones(5)) - - def forward(self, x, y): - x = self.foo(x, y) - x = x + self.other - return x - - mod = Bar().eval() - inputs = [torch.ones(5), torch.ones(5)] - gm = proxytensor_trace(mod, inputs) - inputs_new = [torch.ones(5) + 5, torch.ones(5) + 8] - output = gm(*inputs_new) - ref_output = mod(*inputs_new) - torch.testing.assert_close(output, ref_output) - - def test_resnet18_dynamo(self): - mod = torchvision.models.resnet18() - mod = mod.cuda().half().eval() - - inputs = [torch.ones(32, 3, 224, 224)] - inputs = [i.cuda().half() for i in inputs] - ref_output = mod(*inputs) - - torchdynamo.reset() - dynamo_mod = torchdynamo.optimize(backends.fx2trt_compiler_fp16)(mod) - dynamo_output = dynamo_mod(*inputs) - cos_val = torch.nn.functional.cosine_similarity( - dynamo_output.flatten(), ref_output.flatten(), dim=0, eps=1e-4 - ) - self.assertTrue(cos_val.detach().cpu().numpy() > 0.999) - - -class DispatchTracerTest(unittest.TestCase): - def test_leaf_operator_reg(self): - class Leaf(torch.nn.Module): - def forward(self, x, y): - return x + y + torch.nn.Parameter(torch.ones(5)) - - leaf = Leaf() - wrap_lib.define("wrapped_leaf(Tensor x, Tensor y) -> Tensor") - wrap_lib.impl("wrapped_leaf", leaf, "CPU") - - class Bar(torch.nn.Module): - def __init__(self): - super(Bar, self).__init__() - self.leaf = torch.ops.wrap.wrapped_leaf - self.other = torch.nn.Parameter(torch.ones(5)) - - def forward(self, x, y): - x = self.leaf(x, y) - x = x + self.other - return x - - mod = Bar() - - def f(x, y): - return mod(x, y) - - gm = make_fx(functionalize(f))(torch.ones(5), torch.ones(5)) - inputs = [torch.ones(5) + 5, torch.ones(5) + 8] - output = gm(*inputs) - ref_output =
f(*inputs) - torch.testing.assert_close(output, ref_output) - # through the op registration method, the module is defined in a call_function - call_function_node = None - for node in gm.graph.nodes: - if ( - node.op == "call_function" - and node.target == torch.ops.wrap.wrapped_leaf - ): - call_function_node = node - self.assertIsNotNone(call_function_node) - - ## The test is broken on Aug 27 as the leaf node does not work. P525693772 - # def test_leaf(self): - # class TestModuleLeaf(torch.nn.Module): - # def __init__(self): - # super().__init__() - # self.conv = torch.nn.Conv2d(3, 10, 1) - # self.relu = torch.nn.ReLU(inplace=True) - - # def forward(self, x): - # x = self.conv(x) - # return self.relu(x) - - # class TestModule(torch.nn.Module): - # def __init__(self): - # super().__init__() - - # self.relu = torch.nn.ReLU(inplace=True) - # self.leaf = TestModuleLeaf() - - # def forward(self, x): - # x = self.leaf(x) - # return self.relu(x) - - # mod = TestModule() - - # def f(x): - # return mod(x) - - # a = torch.randn(1, 3, 1, 1) - # ref_output = f(a) - # func = make_fx(f, leaf_module_list={"test_dispatch_tracer.TestModuleLeaf"}) - # gm = func(a) - # output = gm(a) - # torch.testing.assert_close(output, ref_output) - # import pdb;pdb.set_trace() - # # There should be a call module node in the graph. - # call_module_node = None - # for node in gm.graph.nodes: - # if node.op == "call_module": - # call_module_node = node - # self.assertIsNotNone(call_module_node) - # self.assertEqual(call_module_node.target, "TestModuleLeaf_0") - - def test_non_tensor_input(self): - def foo(x): - a = x["a"] - b = x["b"] - return a + b - - x = {"a": torch.randn(2, 2), "b": torch.randn(2, 2)} - ref_output = foo(x) - func = make_fx(foo) - gm = func(x) - output = gm(x) - torch.testing.assert_close(output, ref_output) - - def test_reference_copy(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - y[:, 0] = x[:, 0] - return y - - mod = TestModule() - - def f(x, y): - return mod(x, y) - - a = torch.ones(2, 2) + 2 - b = torch.ones(2, 2) - b_copy = torch.ones(2, 2) - ref_output = f(a, b) - gm = make_fx(functionalize(f))(a, b) - output = gm(a, b_copy) - torch.testing.assert_close(output, ref_output) - - def test_reference_copy_torchdynamo(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.relu = torch.nn.ReLU(inplace=True) - - def forward(self, x, y): - y = y + 3 - y = self.relu(y) - y[:, 0] = x[:, 0] - return y - - mod = TestModule() - - def f(x, y): - return mod(x, y) - - a = torch.ones(2, 2) + 2 - b = torch.ones(2, 2) - inputs = [a, b] - ref_output = f(*inputs) - - def compile_dispatch(gm, example_inputs): - # after normalization, relu in-place is removed - gm = normalize_ir(gm, example_inputs) - # dispatch tracer - nargs = len(example_inputs) - - def fake_signature(fn, nargs): - """FX gets confused by varargs, de-confuse it""" - argnames = ",".join(f"arg{i}" for i in range(nargs)) - return eval(f"lambda {argnames}: fn({argnames})", {"fn": fn}) - - gm = make_fx(functionalize(fake_signature(gm, nargs)))(*example_inputs) - return gm - - optimized_mod = torchdynamo.optimize( - compile_dispatch, - nopython=True, - )(mod) - output = optimized_mod(*inputs) - torch.testing.assert_close(output, ref_output) diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/tracer/test_resnet.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/tracer/test_resnet.py deleted file mode 100644 index 1dfdfa7125..0000000000 --- 
a/py/torch_tensorrt/dynamo/fx_ts_compat/test/tracer/test_resnet.py +++ /dev/null @@ -1,86 +0,0 @@ -import unittest - -import torch - -import torch._dynamo.config -import torchvision -from torch_tensorrt.dynamo.fx_ts_compat.lower import compile -from torch_tensorrt.dynamo.fx_ts_compat.utils import LowerPrecision - - -class ResnetTest(unittest.TestCase): - def test_resnet18_aten(self): - mod = torchvision.models.resnet18() - mod = mod.cuda().half().eval() - - inputs = [torch.ones(32, 3, 224, 224)] - inputs = [i.cuda().half() for i in inputs] - - aten_mod = compile( - mod, - inputs, - enabled_precisions={torch.float16}, - verbose_log=False, - timing_cache_prefix="", - save_timing_cache=False, - cuda_graph_batch_size=-1, - is_aten=True, - ) - aten_output = aten_mod(*inputs) - aten_output = aten_output[0] - fx_mod = compile( - mod, - inputs, - enabled_precisions={torch.float16}, - verbose_log=False, - timing_cache_prefix="", - save_timing_cache=False, - cuda_graph_batch_size=-1, - is_aten=False, - ) - fx_output = fx_mod(*inputs) - # Kernel selection is tricky in TRT with big variance as shown below: - # Mismatched elements: 30816 / 32000 (96.3%) - # Greatest absolute difference: 0.05859375 at index (0, 499) (up to 1e-05 allowed) - # Greatest relative difference: 3.293713681986265 at index (0, 142) (up to 0.001 allowed) - # so we choose to use cosine similarity - cos_val = torch.nn.functional.cosine_similarity( - aten_output.flatten(), fx_output.flatten(), dim=0, eps=1e-4 - ) - self.assertTrue(cos_val.detach().cpu().numpy() > 0.999) - - def test_resnet18_aten_dynamic(self): - mod = torchvision.models.resnet18() - mod = mod.cuda().half().eval() - - inputs = [torch.ones(32, 3, 224, 224)] - inputs = [i.cuda().half() for i in inputs] - - aten_mod = compile( - mod, - inputs, - enabled_precisions={torch.float16}, - verbose_log=False, - timing_cache_prefix="", - save_timing_cache=False, - cuda_graph_batch_size=-1, - is_aten=True, - ) - aten_output = aten_mod(*inputs) - aten_output = aten_output[0] - fx_mod = compile( - mod, - inputs, - enabled_precisions={torch.float16}, - verbose_log=False, - timing_cache_prefix="", - save_timing_cache=False, - cuda_graph_batch_size=-1, - is_aten=False, - ) - fx_output = fx_mod(*inputs) - - cos_val = torch.nn.functional.cosine_similarity( - aten_output.flatten(), fx_output.flatten(), dim=0, eps=1e-4 - ) - self.assertTrue(cos_val.detach().cpu().numpy() > 0.999) diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/test_diagnostics.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/test_diagnostics.py deleted file mode 100644 index 3ce3b7ade8..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/test_diagnostics.py +++ /dev/null @@ -1,200 +0,0 @@ -# Owner(s): ["oncall: gpu_enablement"] -import functools -import glob -import logging -import os -import shutil -import tempfile -from typing import Union -from unittest import TestCase - -import torch_tensorrt.fx.diagnostics as diag - - -_LOGGER: logging.Logger = logging.getLogger(__name__) - - -def reset_diag(fn): - @functools.wraps(fn) - def reset(*a, **kw): - try: - tok1 = diag._CURRENT_COLLECTOR.set(None) - tok2 = diag._CURRENT_WRITER.set(None) - tok3 = diag._SUBSEQUENT_COLLECT_SUPPRESSED_BY.set(None) - return fn(*a, **kw) - finally: - diag._CURRENT_COLLECTOR.reset(tok1) - diag._CURRENT_WRITER.reset(tok2) - diag._SUBSEQUENT_COLLECT_SUPPRESSED_BY.reset(tok3) - - return reset - - -class Fx2trtDiagnosticsTest(TestCase): - @reset_diag - def test_diagnostics(self): - collector = 
diag.ZipDiagnosticsCollector(writer=diag.get_current_writer()) - - diag.set_current_collector(collector) - - try: - with diag.collect_when_fail(): - diag.write("aaa", "hello") - diag.write("bbb", lambda: "world") - diag.write("ccc", b"123") - diag.write("ddd", lambda: b"456") - - def boom() -> str: - raise AssertionError("Error generating diagnostics.") - - diag.write("eee", boom) - - diag.write("zzz", "done") - raise _UserDefinedError("Error while lowering") - except _UserDefinedError: - pass - - zip_fn = collector._last_zip_path_for_test - assert os.path.exists(zip_fn) - with tempfile.TemporaryDirectory() as tempdir: - _LOGGER.info(f"Unpacking into {tempdir}") - shutil.unpack_archive(zip_fn, tempdir) - _check_file(tempdir, "aaa", "hello") - _check_file(tempdir, "bbb", "world") - _check_file(tempdir, "ccc", b"123") - _check_file(tempdir, "ddd", b"456") - _check_file(tempdir, "zzz", "done") - # file eee should still exist to contain err msg - _check_file(tempdir, "eee", "") - - @reset_diag - def test_condition_func_name(self): - collector = diag.ZipDiagnosticsCollector(writer=diag.get_current_writer()) - diag.set_current_collector(collector) - - with diag.collect_when( - diag.CollectionConditions.when_called_by_function( - self.test_condition_func_name.__name__ - ) - ): - diag.write("aaa", "hello") - - zip_fn = collector._last_zip_path_for_test - assert os.path.exists(zip_fn) - with tempfile.TemporaryDirectory() as tempdir: - _LOGGER.info(f"Unpacking into {tempdir}") - shutil.unpack_archive(zip_fn, tempdir) - _check_file(tempdir, "aaa", "hello") - - @reset_diag - def test_write_without_collect(self): - collector = diag.ZipDiagnosticsCollector(writer=diag.get_current_writer()) - diag.set_current_collector(collector) - diag.write("aaa", "hello") - root_dir = diag.get_current_writer().root_dir() - res = glob.glob(f"{root_dir}/*") - assert not res # root dir should be empty - - def test_conditions(self): - - _test_cond( - diag.CollectionConditions.when_called_by_function( - self.test_conditions.__name__ - ), - should_collect=True, - ) - - _test_cond( - diag.CollectionConditions.when_called_by_function("moo_baa_la_la_la"), - should_collect=False, - ) - - _test_cond( - diag.CollectionConditions.any( - diag.CollectionConditions.never(), - diag.CollectionConditions.always(), - ), - True, - ) - - _test_cond( - diag.CollectionConditions.all( - diag.CollectionConditions.never(), - diag.CollectionConditions.always(), - ), - False, - ) - - _test_cond( - diag.CollectionConditions.not_( # returns False - diag.CollectionConditions.always(), # returns True - ), - False, - ) - - _test_cond( - diag.CollectionConditions.when_not_in_tests(), - False, # Yes we are in test right now - ) - - # nested - _test_cond( - diag.CollectionConditions.any( - diag.CollectionConditions.never(), - diag.CollectionConditions.any( - diag.CollectionConditions.always(), - ), - ), - True, - ) - - -@reset_diag -def _test_cond( - cond: diag.CollectionCondition, - should_collect: bool, -) -> None: - collector = diag.ZipDiagnosticsCollector(writer=diag.get_current_writer()) - diag.set_current_collector(collector) - - with diag.collect_when(cond): - diag.write("aaa", "hello") - - zip_fn = collector._last_zip_path_for_test - if should_collect: - assert os.path.exists(zip_fn) - with tempfile.TemporaryDirectory() as tempdir: - _LOGGER.info(f"Unpacking into {tempdir}") - shutil.unpack_archive(zip_fn, tempdir) - _check_file(tempdir, "aaa", "hello") - else: - assert not zip_fn, "the collection should not have triggered" - - -def 
_check_file(dir: str, fn: str, content: Union[str, bytes]): - fp = os.path.join(dir, fn) - res = glob.glob(f"{fp}*") - assert len(res) == 1 - fp = res[0] - if not os.path.exists(fp): - raise _CheckFileDoesNotExist(f"{fp} must exist") - if not content: - # don't check content then - return - if isinstance(content, bytes): - with open(fp, "rb") as f: - content_actual = f.read() - assert content == content_actual - else: - content: str - with open(fp, "r", encoding="utf-8") as f: - content_actual = f.read() - assert content == content_actual - - -class _UserDefinedError(Exception): - pass - - -class _CheckFileDoesNotExist(AssertionError): - pass diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/test_fx2trt_lower.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/test_fx2trt_lower.py deleted file mode 100644 index 4077fe1491..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/test_fx2trt_lower.py +++ /dev/null @@ -1,104 +0,0 @@ -# Owner(s): ["oncall: gpu_enablement"] - -import logging -import unittest - -import torch -import torch.fx as fx -import torch.nn as nn -from torch_tensorrt.dynamo.fx_ts_compat.lower import Lowerer, LowerSetting -from torch_tensorrt.fx.passes.lower_basic_pass import replace_mutable_op - -logger = logging.getLogger(__name__) - - -class Fx2trtLowerTests(unittest.TestCase): - def test_fx2trt_lower(self): - class _Mod(nn.Module): - def forward(self, x): - return (x, 2 * x) - - mod = _Mod() - mod_traced = fx.symbolic_trace(mod) - input = [torch.rand(4).cuda()] - lower = Lowerer.create(LowerSetting()) - lower(mod_traced, input) - - def test_lower_with_batchnorm_act_rewrite(self): - class MyBatchNorm(nn.BatchNorm2d): - def forward(self, x): - self._check_input_dim(x) - return x + 1 - - class TestModule(nn.Module): - def __init__(self): - super().__init__() - self.bn = MyBatchNorm(3) - - def forward(self, x): - return self.bn(x) - - module = TestModule() - inputs = [torch.randn(1, 3, 224, 224).cuda()] - lower = Lowerer.create(LowerSetting(ast_rewriter_allow_list={MyBatchNorm})) - lower(module, inputs) - - def test_lower_const_fold(self): - class TestModule(nn.Module): - def __init__(self): - super().__init__() - self.a = nn.Parameter(torch.randn(1)) - - def forward(self, x): - return (torch.sqrt(x), self.a) - - lower = Lowerer.create(LowerSetting()) - lower(TestModule(), [torch.randn([2, 2]).cuda()]) - - def test_replace_mutable_op(self): - class TestModule(torch.nn.Module): - def forward(self, x, y): - xf = x.fill_(100) - yf = y.fill_(200) - c = torch.cat([xf, yf], dim=1) - return c - - lower = Lowerer.create(LowerSetting()) - mod_traced = fx.symbolic_trace(TestModule()) - lower(mod_traced, [torch.randn(3, 4).cuda(), torch.randn(3, 4).cuda()]) - - def test_replace_mutable_op_dont_apply(self): - class TestModule(torch.nn.Module): - def forward(self, x): - s = x + 1 - t = s.fill_(5) - p = s + t - return p - - mod_traced = fx.symbolic_trace(TestModule()) - old_code = mod_traced.code - - transformed = replace_mutable_op(mod_traced) - new_code = transformed.code - - # s.fill_ shouldn't have been replaced - # because s is used later - self.assertEqual(old_code, new_code) - - def test_replace_mutable_op_do_apply(self): - class TestModule(torch.nn.Module): - def forward(self, x): - s = x + 1 - t = s.fill_(5) # s not used afterwards - p = x + t - return p - - mod_traced = fx.symbolic_trace(TestModule()) - old_code = mod_traced.code - - transformed = replace_mutable_op(mod_traced) - new_code = transformed.code - - # s.fill_ should have been 
replaced - # because s is not used afterwards - self.assertNotEqual(old_code, new_code) diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/test_observer.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/test_observer.py deleted file mode 100644 index 58f23c0a13..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/test_observer.py +++ /dev/null @@ -1,128 +0,0 @@ -# Owner(s): ["oncall: gpu_enablement"] -import functools -import logging -import typing as t -from contextlib import contextmanager -from unittest import TestCase - -import torch_tensorrt.fx.observer as ob -from torch_tensorrt.fx.observer import observable - -_LOGGER: logging.Logger = logging.getLogger(__name__) - - -def set_observer_callback_rethrow(fn): - """ - Specify that observer callback exceptions should be re-thrown (default - behavior is to swallow) Re-throw is only for test purpose. - """ - - @functools.wraps(fn) - def fn_(*args, **kwargs): - try: - ob.RETHROW_CALLBACK_EXCEPTION = True - return fn(*args, **kwargs) - finally: - ob.RETHROW_CALLBACK_EXCEPTION = False - - return fn_ - - -class ObserverTests(TestCase): - @set_observer_callback_rethrow - def test_basics(self): - @observable() - def foo(x, y, z): - return x + y + z - - with execution_verifier() as verify_execution: - - @verify_execution - def log_pre(ctx: ob.ObserveContext) -> None: - _LOGGER.info(f"calling log: {ctx}") - assert ctx.callable is foo.orig_func - assert ctx.args == (1, 2) - assert ctx.kwargs == {"z": 3} - assert not ctx.return_value - - @verify_execution - def log_post(ctx: ob.ObserveContext) -> None: - _LOGGER.info(f"calling log: {ctx}") - assert ctx.callable is foo.orig_func - assert ctx.args == (1, 2) - assert ctx.kwargs == {"z": 3} - assert ctx.return_value == 6 - - with foo.observers.pre.add(log_pre), foo.observers.post.add(log_post): - foo(1, 2, z=3) - - with execution_verifier() as verify_execution: - - @verify_execution - def log_pre(ctx: ob.ObserveContext) -> None: - _LOGGER.info(f"calling log: {ctx}") - - @verify_execution - def log_post(ctx: ob.ObserveContext) -> None: - _LOGGER.info(f"calling log: {ctx}") - - foo.observers.pre.add(log_pre) - foo.observers.post.add(log_post) - foo(1, 2, 3) - - with execution_verifier() as verify_execution: - - @verify_execution - def f1(ctx: ob.ObserveContext) -> None: - _LOGGER.info(f"calling f1: {ctx}") - - @verify_execution - def f2(ctx: ob.ObserveContext) -> None: - _LOGGER.info(f"calling f2: {ctx}") - - # Test that we can register the same observation point twice - with foo.observers.pre.add(f1): - with foo.observers.pre.add(f2): - foo(1, 2, z=3) - - def test_observer_callbacks_should_not_throw(self): - @observable() - def foo(x, y, z): - return x + y + z - - with execution_verifier() as verify_execution: - - @verify_execution - def log_pre(ctx: ob.ObserveContext) -> None: - _LOGGER.info(f"calling log: {ctx}") - raise CallbackError("TEST CALLBACK EXCEPTION") - - with foo.observers.pre.add(log_pre): - foo(1, 2, 3) - - -@contextmanager -def execution_verifier(): - _is_called: t.Dict[callable, bool] = {} - - def verify_executed(fn): - _is_called[fn] = False - - @functools.wraps(fn) - def fn_(*args, **kwargs): - _is_called[fn] = True - return fn(*args, **kwargs) - - return fn_ - - try: - yield verify_executed - except: # noqa: B001 - raise - else: - for fn, was_executed in _is_called.items(): - assert was_executed, f"{fn} was not executed" - - -class CallbackError(Exception): - pass diff --git 
a/py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/test_observer_gpu.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/test_observer_gpu.py deleted file mode 100644 index bd17f42e72..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/test_observer_gpu.py +++ /dev/null @@ -1,51 +0,0 @@ -# Owner(s): ["oncall: gpu_enablement"] -import functools -from unittest import TestCase - -import torch_tensorrt.fx.observer as ob -from test_observer import execution_verifier, set_observer_callback_rethrow -from torch_tensorrt.fx.passes.lower_basic_pass import fuse_permute_linear - - -class ObserverGPUTests(TestCase): - @set_observer_callback_rethrow - def test_observe_lowerer(self): - """ - Test that we can observe the execution of `fuse_permute_linear` during - lowering. - """ - from dataclasses import replace - - import torch - import torch.nn as nn - - import torch_tensorrt.dynamo.fx_ts_compat.lower as lower - from torch_tensorrt.dynamo.fx_ts_compat.lower_setting import LowerSetting - - class Model(nn.Module): - def forward(self, x, y): - return x + y - - mod = Model().cuda() - inp = [torch.rand(1, 10), torch.rand(1, 10)] - inp = [i.cuda() for i in inp] - mod(*inp) - - with execution_verifier() as verify_execution: - - lowerer = lower.Lowerer.create(lower_setting=LowerSetting(min_block_size=0)) - - @verify_execution - def observe_fuse_permute_linear_post(ctx: ob.ObserveContext): - """ - Called when fuse_permute_linear is executed. Decorated with - `verify_execution` so if this function is not executed, the - test fails. - """ - assert ctx.callable is fuse_permute_linear.orig_func - - # Register the observer callback and do the lowering - with fuse_permute_linear.observers.post.add( - observe_fuse_permute_linear_post - ): - lowerer(mod, inp) diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/trt_operator_supported_test.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/trt_operator_supported_test.py deleted file mode 100644 index ebccd3c08b..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/trt_operator_supported_test.py +++ /dev/null @@ -1,82 +0,0 @@ -# Owner(s): ["oncall: gpu_enablement"] - -import torch -import torch.fx -import torch.nn as nn -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops # noqa: F401 -from torch.testing._internal.common_utils import run_tests, TestCase -from torch_tensorrt.dynamo.fx_ts_compat.tools.trt_splitter import ( - create_trt_operator_support, -) -from torch_tensorrt.fx.tracer.acc_tracer import acc_ops, acc_tracer - - -class TestTRTOperatorSupport(TestCase): - def test_supported_node_target(self): - class TestModule(nn.Module): - def __init__(self): - super().__init__() - self.linear = nn.Linear(1, 1) - - def forward(self, x): - x = self.linear(x) - x = x + 1 - return torch.add(input=x, other=x) - - mod = TestModule() - traced_mod = acc_tracer.trace(mod, [torch.randn(1, 2, 1, 1)]) - op_support = create_trt_operator_support() - for node in traced_mod.graph.nodes: - self.assertTrue(op_support.is_node_supported(mod, node)) - - def test_unsupport_node_explicit_batch_dim(self): - class TestModule(nn.Module): - def forward(self, x): - y = torch.add(input=x, other=x) - return torch.max_pool1d(y, 1) - - mod = TestModule() - traced_mod = acc_tracer.trace(mod, [torch.randn(5, 2)]) - op_support = create_trt_operator_support(use_implicit_batch_dim=False) - - for node in traced_mod.graph.nodes: - if node.target == acc_ops.add: - self.assertTrue(op_support.is_node_supported(mod, node)) - elif node.target == 
acc_ops.split: - self.assertFalse(op_support.is_node_supported(mod, node)) - - def test_unsupport_node_implicit_batch_dim(self): - class TestModule(nn.Module): - def forward(self, x): - y = torch.add(input=x, other=x) - return nn.functional.gelu(y) - - mod = TestModule() - traced_mod = acc_tracer.trace(mod, [torch.randn(5, 2)]) - op_support = create_trt_operator_support(use_implicit_batch_dim=True) - - for node in traced_mod.graph.nodes: - if node.target == acc_ops.add: - self.assertTrue(op_support.is_node_supported(mod, node)) - elif node.target == acc_ops.gelu: - self.assertFalse(op_support.is_node_supported(mod, node)) - - def test_support_node_with_int_attr(self): - class TestModule(nn.Module): - def forward(self, x): - zeros = torch.randint(3, 5, (1,)) - zeros = zeros.to(torch.int64) - scale = torch.randn(1) - return torch.quantize_per_tensor(x, scale, zeros, torch.quint8) - - mod = TestModule() - traced_mod = acc_tracer.trace(mod, [torch.randn(5, 2)]) - op_support = create_trt_operator_support(use_implicit_batch_dim=True) - - for node in traced_mod.graph.nodes: - if node.target == acc_ops.quantize_per_tensor: - self.assertTrue(op_support.is_node_supported(mod, node)) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/trt_splitter_test.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/trt_splitter_test.py deleted file mode 100644 index 6421f662fc..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/trt_lower/trt_splitter_test.py +++ /dev/null @@ -1,1179 +0,0 @@ -# Owner(s): ["oncall: gpu_enablement"] - -import operator - -import torch # isort:skip -import torch.fx # isort:skip - -import torch.fx.passes.operator_support as op_support -import torch.fx.passes.shape_prop as shape_prop -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.fx.passes import splitter_base -from torch.testing._internal.common_utils import run_tests, TestCase -from torch_tensorrt.dynamo.fx_ts_compat.tools.trt_splitter import ( - TRTSplitter, - TRTSplitterSetting, -) -from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer - -ERROR_MSG_NO_ACC_MODULE = "FX split failed: Did not find any ACC submodule!" -ERROR_MSG_MULTI_ACC_MODULES = "FX split failed: Found more than one ACC submodules!" -ACC_SUBMODULE_PREFIX = "_run_on_acc_" - - -# Check if the split result has expected number of ACC submodule. 
If not, raise a runtime error. -def verify_split_model( - mod: torch.fx.GraphModule, - acc_submodule_keyword: str = ACC_SUBMODULE_PREFIX, - expected_number: int = 1, -) -> None: - acc_submodule_num = 0 - for name, _ in mod.named_children(): - if name.startswith(acc_submodule_keyword): - acc_submodule_num = acc_submodule_num + 1 - - if acc_submodule_num < expected_number: - raise RuntimeError(ERROR_MSG_NO_ACC_MODULE) - elif acc_submodule_num > expected_number: - raise RuntimeError(ERROR_MSG_MULTI_ACC_MODULES) - - -def find_inputs(module): - return [n for n in module.graph.nodes if n.op == "placeholder"] - - -def find_fun_calls(module, target): - return [ - n for n in module.graph.nodes if n.op == "call_function" and n.target == target - ] - - -def find_output(module): - return next(n for n in module.graph.nodes if n.op == "output") - - -TENSOR_SIZE_DUMMY = "tensor_size_dummy" - - -def find_call_targets(module: torch.fx.GraphModule): - result = set() - for n in module.graph.nodes: - n: torch.fx.Node - if n.op in {"call_module", "call_function", "call_method"}: - result.add(n.target) - return result - - -# We test both FxNetSplitOnly and FxNetSplitter here, since they share most -# functionalities. The only difference is that FxNetSplitOnly does not implement -# split_preview() related functions, while FxNetSplitter does. -class TestSplit(TestCase): - def test_demo(self): - """ - ==> b ==> - // \\ - a d - \\ // - ==> c ==> - """ - - class SimpleModule(torch.nn.Module): - def forward(self, a): - b = torch.sin(a) - c = torch.cos(a) - d = b + c - return d - - mod = acc_tracer.trace(SimpleModule(), [torch.randn(2, 3)]) - - # Making b and c run on ACC - splitter = TRTSplitter( - mod, - (torch.randn(2, 3),), - op_support_with_support_dict( - { - "acc_ops.sin": None, - "acc_ops.cos": None, - } - ), - ) - - st_split = splitter() - - [arg] = find_inputs(st_split) - - # First subgraph calculates b = sin(a) and c = cos(a) on ACC - [sin] = find_fun_calls(st_split._run_on_acc_0, acc_ops.sin) - self.assertEqual(arg.name, sin.kwargs["input"].name) - - [cos] = find_fun_calls(st_split._run_on_acc_0, acc_ops.cos) - self.assertEqual(arg.name, cos.kwargs["input"].name) - - # Second subgraph calculates d = b + c on CPU - [add] = find_fun_calls(st_split._run_on_gpu_1, acc_ops.add) - self.assertEqual(sin.name, add.kwargs["input"].name) - self.assertEqual(cos.name, add.kwargs["other"].name) - - def test_mod_with_getattr(self): - """ - CPU subgraph should have get_attr for self.a while ACC subgraph - should have get_attr for self.b. - """ - - class SimpleModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.a = torch.randn(1, 1, 1, 1) - self.b = torch.randn(1, 1, 1, 1) - self.conv = torch.nn.Conv2d(1, 1, 1) - self.linear = torch.nn.Linear(1, 1) - - def forward(self, x): - x = x + self.a - x = self.conv(x) - return self.linear(x - self.b) - - mod = acc_tracer.trace(SimpleModule(), [torch.randn(1, 1, 1, 1)]) - mod.eval() - - splitter = TRTSplitter( - mod, - (torch.randn(1, 1, 1, 1),), - op_support_with_support_dict( - { - "acc_ops.linear": None, - "acc_ops.sub": None, - } - ), - ) - - def test_splitter(splitter): - st_split = splitter() - verify_split_model(st_split) - # Should be "a", "conv.weight", "conv.bias". - get_attr_nodes = [ - node.target - for node in st_split._run_on_gpu_0.graph.nodes - if node.op == "get_attr" - ] - assert len(get_attr_nodes) == 3 and "a" in get_attr_nodes - - # Should be "b", "linear.weight", "linear.bias". 
- get_attr_nodes = [ - node.target - for node in st_split._run_on_acc_1.graph.nodes - if node.op == "get_attr" - ] - assert len(get_attr_nodes) == 3 and "b" in get_attr_nodes - - test_splitter(splitter) - - def test_nothing_to_split(self): - class SimpleModule(torch.nn.Module): - def forward(self, a): - return a - - mod = acc_tracer.trace(SimpleModule(), [torch.randn(2, 3)]) - - # Mark any operation as runnable on ACC - class CustomOpSupport(op_support.OperatorSupportBase): - def is_node_supported(self, submodules, node): - return True - - splitter = TRTSplitter(mod, (torch.randn(2, 3),), CustomOpSupport()) - - def test_splitter(splitter): - st_split = splitter() - try: - verify_split_model(st_split) - except RuntimeError as err: - self.assertEqual(str(err), ERROR_MSG_NO_ACC_MODULE) - self.assertEqual(splitter.module.__dict__.keys(), st_split.__dict__.keys()) - - test_splitter(splitter) - - def test_multi_output(self): - class MultiOutputModule(torch.nn.Module): - def forward(self, x): - res, ind = torch.topk(x, 3) - return torch.sigmoid(res), ind - - mod = acc_tracer.trace(MultiOutputModule(), [torch.randn(2, 3)]) - - # Mark any operation as runnable on ACC - class CustomOpSupport(op_support.OperatorSupportBase): - def is_node_supported(self, submodules, node): - return True - - splitter = TRTSplitter(mod, (torch.randn(2, 3),), CustomOpSupport()) - - def test_splitter(splitter): - st_split = splitter() - verify_split_model(st_split) - [arg] = find_inputs(st_split) - - # There is only one subgraph that executes topk and sigmoid on ACC - [topk] = find_fun_calls(st_split._run_on_acc_0, acc_ops.topk) - self.assertEqual(arg.name, topk.kwargs["input"].name) - self.assertEqual(3, topk.kwargs["k"]) - - [topk_res1, topk_res2] = find_fun_calls( - st_split._run_on_acc_0, acc_ops.getitem - ) - - [sigmoid] = find_fun_calls(st_split._run_on_acc_0, acc_ops.sigmoid) - self.assertIn( - sigmoid.kwargs["input"].name, {topk_res1.name, topk_res2.name} - ) - - # Main graph returns a tuple - output = find_output(st_split._run_on_acc_0) - self.assertLess( - {output.args[0][0].name, output.args[0][1].name}, - {topk_res1.name, topk_res2.name, sigmoid.name}, - ) - - test_splitter(splitter) - - def test_nested_modules(self): - """ - x - // \\ - // \\ - relu(x) sin(x) - \\ // - \\ // - relu(x) + sin(x) - """ - - class ReluModule(torch.nn.Module): - def forward(self, x): - return torch.relu(x) - - class SinModule(torch.nn.Module): - def forward(self, x): - return torch.sin(x) - - class TestModule3(torch.nn.Module): - def __init__(self, relu_module, sin_module): - super().__init__() - self.relu_module = relu_module - self.sin_module = sin_module - - def forward(self, x): - return self.relu_module(x) + self.sin_module(x) - - mod = acc_tracer.trace( - TestModule3(ReluModule(), SinModule()), [torch.randn(2, 3)] - ) - - # Making sin(x) run on ACC - splitter = TRTSplitter( - mod, - (torch.randn(2, 3),), - op_support_with_support_dict( - { - "acc_ops.sin": None, - } - ), - ) - - def test_splitter(splitter): - st_split = splitter() - verify_split_model(st_split) - [arg] = find_inputs(st_split) - - # First subgraph calculates relu(x) on CPU - [relu] = find_fun_calls(st_split._run_on_gpu_0, acc_ops.relu) - self.assertEqual(arg.name, relu.kwargs["input"].name) - - # Second subgraph calculates sin(x) on ACC - [sin] = find_fun_calls(st_split._run_on_acc_1, acc_ops.sin) - self.assertEqual(arg.name, sin.kwargs["input"].name) - - # Third subgraph calculates sum on CPU - [add] = find_fun_calls(st_split._run_on_gpu_2, acc_ops.add) - 
self.assertEqual(relu.name, add.kwargs["input"].name) - self.assertEqual(sin.name, add.kwargs["other"].name) - - # Check that the results of applying the split module are the same - tensor = torch.randn(5) - self.assertTrue(torch.equal(mod(tensor), st_split(tensor))) - - test_splitter(splitter) - - def test_longer_chain(self): - """ - sin relu cos sigmoid tanh - a ====> b =====> c ====> d ========> e =====> f - """ - - class TestModule(torch.nn.Module): - def forward(self, a): - b = torch.sin(a) - c = torch.relu(b) - d = torch.cos(c) - e = torch.sigmoid(d) - f = torch.tanh(e) - return f - - mod = acc_tracer.trace(TestModule(), [torch.randn(2, 3)]) - - # Making relu and sigmoid execute on ACC - splitter = TRTSplitter( - mod, - (torch.randn(2, 3),), - op_support_with_support_dict( - { - "acc_ops.relu": None, - "acc_ops.sigmoid": None, - } - ), - ) - - def test_splitter(splitter): - st_split = splitter() - try: - verify_split_model(st_split) - except RuntimeError as err: - self.assertEqual(str(err), ERROR_MSG_MULTI_ACC_MODULES) - [arg] = find_inputs(st_split) - - # First subgraph calculates b = sin(a) on CPU - [sin] = find_fun_calls(st_split._run_on_gpu_0, acc_ops.sin) - self.assertEqual(arg.name, sin.kwargs["input"].name) - - # Second subgraph calculates c = relu(b) on ACC - [relu] = find_fun_calls(st_split._run_on_acc_1, acc_ops.relu) - self.assertEqual(sin.name, relu.kwargs["input"].name) - - # Third subgraph calculates d = cos(c) on CPU - [cos] = find_fun_calls(st_split._run_on_gpu_2, acc_ops.cos) - self.assertEqual(relu.name, cos.kwargs["input"].name) - - # Fourth subgraph calculates e = sigmoid(d) on ACC - [sigmoid] = find_fun_calls(st_split._run_on_acc_3, acc_ops.sigmoid) - self.assertEqual(cos.name, sigmoid.kwargs["input"].name) - - # Fifth subgraph calculates f = tanh(e) on CPU - [tanh] = find_fun_calls(st_split._run_on_gpu_4, acc_ops.tanh) - self.assertEqual(sigmoid.name, tanh.kwargs["input"].name) - - test_splitter(splitter) - - def test_min_acc_module_size(self): - """ - sin relu cos sigmoid tanh - a ====> b =====> c ====> d ========> e =====> f - - We set sin, cos and tanh as acc nodes but also set min_acc_module_size to 2 - and expect the whole module to stay on CPU. 
- """ - - class TestModule(torch.nn.Module): - def forward(self, a): - b = torch.sin(a) - c = torch.relu(b) - d = torch.cos(c) - e = torch.sigmoid(d) - f = torch.tanh(e) - return f - - mod = acc_tracer.trace(TestModule(), [torch.randn(2, 3)]) - - # Set sin, cos and tanh as acc node and split with settings - class CustomOpSupport(op_support.OperatorSupport): - _support_dict = { - "acc_ops.sin": None, - "acc_ops.cos": None, - "acc_ops.tanh": None, - } - - # Create splitter setting and set min_acc_module_size to 2 - settings = splitter_base._SplitterSettingBase() - settings.min_acc_module_size = 2 - splitter = TRTSplitter( - mod, - (torch.randn(2, 3),), - op_support_with_support_dict( - { - "acc_ops.sin": None, - "acc_ops.cos": None, - "acc_ops.tanh": None, - } - ), - settings, - ) - - def test_splitter(splitter): - st_split = splitter() - try: - verify_split_model(st_split) - except RuntimeError as err: - self.assertEqual(str(err), ERROR_MSG_NO_ACC_MODULE) - modules = list(st_split.named_modules()) - # Main module and a submodule - assert len(modules) == 2 - - assert modules[1][0] == "_run_on_gpu_0" - - test_splitter(splitter) - - def test_extend_acc_subgraph_after_split(self): - class TestModule(torch.nn.Module): - r""" a (input) - | - b - / \ - c d - \ / - e - / \ - | (g1, g2, g3, g4) - \ / | - f | - \ | - h - - c and f are not runnable on acc while all other nodes are supported by acc. - g1, g2, g3 and g4 should be in a fusion group, let's call it g. - - After split we have 2 cpu subgraphs (c) and (f), 3 acc subgraphs (b, d), (e, g) and (h). - We expect 3 acc subgraphs (b), (d, e, g) and (h) after extend the second acc subgraph. - And expect acc subgraphs stay the same after extend the third acc subgraph because of - the unbreakable fusion group. - """ - - def forward(self, a: torch.Tensor): - b = a + a - c = b - b - d = b + b - e = c + d - - # These four nodes should be in a fusion group - g1 = e.size() - g2 = g1[0] - g3 = e + g2 - g4 = g3 + g2 - - f = e - g3 - h = f + g4 - return h - - a = torch.randn(2) - mod = acc_tracer.trace(TestModule(), (a,)) - - # Allow all nodes expect subtract run on accelerator - class CustomOpSupport(op_support.OperatorSupportBase): - def is_node_supported(self, submodules, node): - return op_support.get_node_target(submodules, node) != "acc_ops.sub" - - splitter = TRTSplitter(mod, (a,), CustomOpSupport()) - - def test_splitter(splitter): - # Manually tag nodes first in case split algorithm changes in the future - nodes = list(splitter.module.graph.nodes) - # b and d - nodes[1].tag = "acc_0" - nodes[3].tag = "acc_0" - # c - nodes[2].tag = "cpu_1" - # e and g - nodes[4].tag = "acc_2" - nodes[5].tag = "acc_2" - nodes[6].tag = "acc_2" - nodes[7].tag = "acc_2" - nodes[8].tag = "acc_2" - # f - nodes[9].tag = "cpu_3" - # h - nodes[10].tag = "acc_4" - - splitter.tags = ["acc_0", "cpu_1", "acc_2", "cpu_3", "acc_4"] - split_module = splitter.split() - try: - verify_split_model(split_module, "acc_") - except RuntimeError as err: - self.assertEqual(str(err), ERROR_MSG_MULTI_ACC_MODULES) - try: - verify_split_model(split_module) - except RuntimeError as err: - self.assertEqual(str(err), ERROR_MSG_NO_ACC_MODULE) - - module_names = [name for name, _ in split_module.named_modules()] - # Main module, 2 cpu submodules and 3 acc submodule - assert len(module_names) == 6 - - # 1 Placeholder, 2 Adds and 1 Output - assert len(split_module.acc_0.graph.nodes) == 4 - # 2 Placeholder, 3 Adds, 1 Size, 1 GetItem and 1 Output - assert len(split_module.acc_2.graph.nodes) == 8 - - # 
Extend the second acc subgraph - splitter.extend_acc_subgraph("acc_2") - extend_module = splitter.split() - try: - verify_split_model(extend_module, "acc_") - except RuntimeError as err: - self.assertEqual(str(err), ERROR_MSG_MULTI_ACC_MODULES) - - # 1 Placeholder, 1 Adds and 1 Output - assert len(extend_module.acc_0.graph.nodes) == 3 - # 2 Placeholder, 4 Adds 1 Size, 1 GetItem and 1 Output - assert len(extend_module.acc_2.graph.nodes) == 9 - - # Extend the third acc subgraph - splitter.extend_acc_subgraph("acc_4") - extend_module = splitter.split() - try: - verify_split_model(extend_module, "acc_") - except RuntimeError as err: - self.assertEqual(str(err), ERROR_MSG_MULTI_ACC_MODULES) - - assert len(extend_module.acc_2.graph.nodes) == 9 - # 2 Placeholder, 1 Adds and 1 Output - assert len(extend_module.acc_4.graph.nodes) == 4 - - test_splitter(splitter) - - def test_get_attr_into_output(self): - """ - Here we verify the case when get_attr node is consumed directly by the - output. We don't expect any split to happen in this test, just want to - make sure that the splitter code doesn't break. - """ - - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.a = torch.randn(2, 3) - - def forward(self, x): - return (x, self.a) - - # No need to put anything on ACC. - class TestOperatorSupport: - def is_node_supported(self, submodules, node): - return False - - module_original = acc_tracer.trace(TestModule(), [torch.randn(4, 5)]) - - splitter = TRTSplitter( - module=module_original, - sample_input=torch.randn(4, 5), - operator_support=TestOperatorSupport(), - ) - - def test_splitter(splitter): - module_split = splitter() - try: - verify_split_model(module_split) - except RuntimeError as err: - self.assertEqual(str(err), ERROR_MSG_NO_ACC_MODULE) - - output = find_output(module_split) - # Second argument of the output should be get_attr. - self.assertEqual("get_attr", output.args[0][1].op) - - # Check if modules are equivalent. - tensor = torch.randn(10, 20) - result_original = module_original(tensor) - result_split = module_split(tensor) - self.assertTrue(torch.equal(result_original[0], result_split[0])) - self.assertTrue(torch.equal(result_original[1], result_split[1])) - - test_splitter(splitter) - - def test_get_attr_into_starter_node(self): - """ - Here we verify the case when starter nodes depend on get_attr node only. - We don't expect any split to happen in this test, just want to make sure - that the splitter code doesn't break. - """ - - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.a = torch.randn(2, 3) - - def forward(self): - m = self.a + self.a - o = m + m - return o - - # No need to put anything on ACC. - class TestOperatorSupport: - def is_node_supported(self, submodules, node): - return False - - module_original = acc_tracer.trace(TestModule(), [torch.randn(2, 3)]) - - splitter = TRTSplitter( - module=module_original, - sample_input=torch.randn(2, 3), - operator_support=TestOperatorSupport(), - ) - - def test_splitter(splitter): - module_split = splitter() - try: - verify_split_model(module_split) - except RuntimeError as err: - self.assertEqual(str(err), ERROR_MSG_NO_ACC_MODULE) - - # Check if modules are equivalent. 
- result_original = module_original() - result_split = module_split() - self.assertTrue(torch.equal(result_original, result_split)) - - test_splitter(splitter) - - -class TestSplitComplexGraph(TestCase): - """ - a ====== - // \\ \\ - b c d - \\ // // - e // - \\ // - \\ // - f - """ - - class TestModule(torch.nn.Module): - def forward(self, a): - b = torch.sin(a) - c = torch.relu(a) - d = torch.cos(a) - e = b + c - f = e - d - return f - - def test_split_complex_graph_1(self): - mod = acc_tracer.trace(self.TestModule(), [torch.randn(2, 3)]) - - # Making 'c' and 'd' run on ACC - splitter = TRTSplitter( - mod, - (torch.randn(2, 3),), - op_support_with_support_dict( - { - "acc_ops.cos": None, - "acc_ops.relu": None, - } - ), - ) - - def test_splitter(splitter): - st_split = splitter() - verify_split_model(st_split) - - [arg] = find_inputs(st_split) - - # First subgraph calculates b = sin(a) on CPU - [sin] = find_fun_calls(st_split._run_on_gpu_0, acc_ops.sin) - self.assertEqual(arg.name, sin.kwargs["input"].name) - - # Second subgraph calculates c = relu(a) and d = cos(a) on ACC - [relu] = find_fun_calls(st_split._run_on_acc_1, acc_ops.relu) - self.assertEqual(arg.name, relu.kwargs["input"].name) - - [cos] = find_fun_calls(st_split._run_on_acc_1, acc_ops.cos) - self.assertEqual(arg.name, cos.kwargs["input"].name) - - # Third subgraph calculates e = b + c and f = e - d on CPU - [add] = find_fun_calls(st_split._run_on_gpu_2, acc_ops.add) - self.assertEqual(sin.name, add.kwargs["input"].name) - self.assertEqual(relu.name, add.kwargs["other"].name) - - [sub] = find_fun_calls(st_split._run_on_gpu_2, acc_ops.sub) - self.assertEqual(add.name, sub.kwargs["input"].name) - self.assertEqual(cos.name, sub.kwargs["other"].name) - - test_splitter(splitter) - - def test_split_complex_graph_2(self): - module_nn = self.TestModule() - module = acc_tracer.trace(module_nn, (torch.randn(2, 3),)) - - # Making 'c', 'd' and 'e' run on ACC - splitter = TRTSplitter( - module, - (torch.randn(2, 3),), - op_support_with_support_dict( - { - "acc_ops.cos": None, - "acc_ops.relu": None, - "acc_ops.add": None, - } - ), - ) - - def test_splitter(splitter): - module_fx_split = splitter() - verify_split_model(module_fx_split) - - [arg] = find_inputs(module) - - # First subgraph calculates b = sin(a) on CPU - [sin] = find_fun_calls(module_fx_split._run_on_gpu_0, acc_ops.sin) - self.assertEqual(arg.name, sin.kwargs["input"].name) - - # Second subgraph calculates c = relu(a), d = cos(a) and e = b + c on ACC - [relu] = find_fun_calls(module_fx_split._run_on_acc_1, acc_ops.relu) - self.assertEqual(arg.name, relu.kwargs["input"].name) - - [cos] = find_fun_calls(module_fx_split._run_on_acc_1, acc_ops.cos) - self.assertEqual(arg.name, cos.kwargs["input"].name) - - [add] = find_fun_calls(module_fx_split._run_on_acc_1, acc_ops.add) - self.assertEqual(sin.name, add.kwargs["input"].name) - self.assertEqual(relu.name, add.kwargs["other"].name) - - # Third subgraph calculates f = e - d on CPU - [sub] = find_fun_calls(module_fx_split._run_on_gpu_2, acc_ops.sub) - self.assertEqual(add.name, sub.kwargs["input"].name) - self.assertEqual(cos.name, sub.kwargs["other"].name) - - test_splitter(splitter) - - -class TestSplitNonTensorEdges(TestCase): - """ - a (relu) - // \\ - (b1,b2) c (cos) - \\ // - d (add) - || - e (sigmoid) - """ - - # Note non-tensor edge between b2 and d - class TestModule(torch.nn.Module): - def forward(self, x): - a = torch.relu(x) - - b1 = a.size() - b2 = b1[0] - - c = torch.cos(a) - - d = b2 + c - e = torch.sigmoid(d) - 
return e - - def test_split_non_tensor_edges_1(self): - test_data = torch.randn(2, 3) - - module_nn = acc_tracer.trace(self.TestModule(), (test_data,)) - - # Making 'a', 'b1', 'b2', 'd' and 'e' run on ACC - splitter = TRTSplitter( - module_nn, - (test_data,), - op_support_with_support_dict( - { - "acc_ops.relu": None, - "acc_ops.sigmoid": None, - "acc_ops.add": None, - "acc_ops.getitem": None, - "acc_ops.size": None, - } - ), - ) - - def test_splitter(splitter): - module_fx_split = splitter() - try: - verify_split_model(module_fx_split) - except RuntimeError as err: - self.assertEqual(str(err), ERROR_MSG_MULTI_ACC_MODULES) - - self.assertEqual( - {acc_ops.relu}, find_call_targets(module_fx_split._run_on_acc_0) - ) - - self.assertEqual( - {acc_ops.cos}, find_call_targets(module_fx_split._run_on_gpu_1) - ) - - self.assertEqual( - {acc_ops.size, acc_ops.getitem, acc_ops.add, acc_ops.sigmoid}, - find_call_targets(module_fx_split._run_on_acc_2), - ) - - # Make sure we can compile to TorchScript - module_jit = torch.jit.trace_module(module_fx_split, {"forward": test_data}) - self.assertTrue(torch.allclose(module_nn(test_data), module_jit(test_data))) - - test_splitter(splitter) - - def test_split_non_tensor_edges_2(self): - test_data = torch.randn(2, 3) - - module_nn = acc_tracer.trace(self.TestModule(), (test_data,)) - - # Making 'a', 'b1', 'b2', 'd' and 'e' run on ACC with limit on ACC - # subgraph size - settings = splitter_base._SplitterSettingBase() - settings.min_acc_module_size = 2 - splitter = TRTSplitter( - module_nn, - (test_data,), - op_support_with_support_dict( - { - "acc_ops.relu": None, - "acc_ops.sigmoid": None, - "acc_ops.add": None, - "acc_ops.getitem": None, - "acc_ops.size": None, - } - ), - settings, - ) - - def test_splitter(splitter): - module_fx_split = splitter() - verify_split_model(module_fx_split) - - self.assertEqual( - {acc_ops.relu, acc_ops.cos}, - find_call_targets(module_fx_split._run_on_gpu_0), - ) - - self.assertEqual( - {acc_ops.size, acc_ops.getitem, acc_ops.add, acc_ops.sigmoid}, - find_call_targets(module_fx_split._run_on_acc_1), - ) - - # Make sure we can compile to TorchScript - module_jit = torch.jit.trace_module(module_fx_split, {"forward": test_data}) - self.assertTrue(torch.allclose(module_nn(test_data), module_jit(test_data))) - - test_splitter(splitter) - - def test_split_non_tensor_edges_3(self): - test_data = torch.randn(2, 3) - - module_nn = acc_tracer.trace( - self.TestModule(), - (test_data,), - ) - - # Making 'a', 'c', 'd' and 'e' run on ACC - splitter = TRTSplitter( - module_nn, - (test_data,), - op_support_with_support_dict( - { - "acc_ops.relu": None, - "acc_ops.sigmoid": None, - "acc_ops.cos": None, - "acc_ops.add": None, - } - ), - ) - - def test_splitter(splitter): - module_fx_split = splitter() - try: - verify_split_model(module_fx_split) - except RuntimeError as err: - self.assertEqual(str(err), ERROR_MSG_MULTI_ACC_MODULES) - - self.assertEqual( - {acc_ops.relu, acc_ops.cos}, - find_call_targets(module_fx_split._run_on_acc_0), - ) - - self.assertEqual( - {acc_ops.size, acc_ops.getitem, acc_ops.add}, - find_call_targets(module_fx_split._run_on_gpu_1), - ) - - self.assertEqual( - {acc_ops.sigmoid}, - find_call_targets(module_fx_split._run_on_acc_2), - ) - - # Make sure we can compile to TorchScript - module_jit = torch.jit.trace_module(module_fx_split, {"forward": test_data}) - self.assertTrue(torch.allclose(module_nn(test_data), module_jit(test_data))) - - test_splitter(splitter) - - def test_split_non_tensor_edges_4(self): - test_data 
= torch.randn(2, 3) - - module_nn = acc_tracer.trace( - self.TestModule(), - (test_data,), - ) - - # Making 'a', 'c', 'd' and 'e' run on ACC with limit on ACC - # subgraph size - settings = splitter_base._SplitterSettingBase() - settings.min_acc_module_size = 2 - splitter = TRTSplitter( - module_nn, - (test_data,), - op_support_with_support_dict( - { - "acc_ops.relu": None, - "acc_ops.sigmoid": None, - "acc_ops.cos": None, - "acc_ops.add": None, - } - ), - settings, - ) - - def test_splitter(splitter): - module_fx_split = splitter() - verify_split_model(module_fx_split) - - self.assertEqual( - {acc_ops.relu, acc_ops.cos}, - find_call_targets(module_fx_split._run_on_acc_0), - ) - - self.assertEqual( - {acc_ops.size, acc_ops.getitem, acc_ops.add, acc_ops.sigmoid}, - find_call_targets(module_fx_split._run_on_gpu_1), - ) - - # Make sure we can compile to TorchScript - module_jit = torch.jit.trace_module(module_fx_split, {"forward": test_data}) - self.assertTrue(torch.allclose(module_nn(test_data), module_jit(test_data))) - - test_splitter(splitter) - - -class TestAccNodesFinder(TestCase): - def test_acc_nodes_finder_1(self): - """ - y -------------> - | - ----> b ----> - x ----> a d - ----> c ----> - | - z -------------> - """ - - # Make a return non-tensor data - class TestModule(torch.nn.Module): - def forward(self, x, y, z): - a1 = x.size() - a1 = a1[0] - - b = y + a1 - c = z - a1 - - d = b + c - - return d - - module_nn = TestModule() - module_fx = torch.fx.symbolic_trace(module_nn) - - # Make a and c lowerable to ACC - finder = torch.fx.passes.splitter_base.FxNetAccNodesFinder( - module_fx, - op_support_with_support_dict( - { - "acc_ops.sub": None, - "acc_ops.getitem": None, - "acc_ops.size": None, - } - ), - False, - ) - acc_nodes = finder() - self.assertEqual(set(), acc_nodes, "Shouldn't have ACC nodes") - - -class TestAccFusionsFinder(TestCase): - """ - x - / \\ - a b - / | \\ - / | a2 - a0 a1 | - | / | - c | - | | - d | - \\ / - e - """ - - class TestModule(torch.nn.Module): - def forward(self, x): - a = x.size() - b = x + x - - a0 = a[0] - a1 = a[1] - a2 = a[2] - c = x.view(a1, a0, -1) - - d = c + c - e = d + a2 - return b, e - - def test_acc_fusions_finder_1(self): - """ - Assume every node is acc node. We should have one fusion group - (a, a0, a1, a2, c, d, e). - """ - module_nn = self.TestModule() - module_fx = torch.fx.symbolic_trace(module_nn) - shape_prop.ShapeProp(module_fx).propagate(torch.randn(1, 1, 1)) - - acc_node = { - node - for node in module_fx.graph.nodes - if node.op in torch.fx.passes.tools_common.CALLABLE_NODE_OPS - } - - fusions_finder = torch.fx.passes.splitter_base.FxNetAccFusionsFinder( - module_fx, - acc_node, - ) - fusion_map = fusions_finder() - - self.assertEqual(len(fusion_map), 7) - for _, v in fusion_map.items(): - self.assertEqual(len(v), 7) - - def test_acc_fusions_finder_2(self): - """ - Let b and d be cpu nodes. After fusion all nodes should be cpu nodes - because d is included in the fusion group which force all other nodes - in the same fusion group to be on CPU too. 
- """ - module_nn = self.TestModule() - module_fx = torch.fx.symbolic_trace(module_nn) - shape_prop.ShapeProp(module_fx).propagate(torch.randn(1, 1, 1)) - - acc_node = { - node for node in module_fx.graph.nodes if node.target == operator.add - } - fusions_finder = torch.fx.passes.splitter_base.FxNetAccFusionsFinder( - module_fx, - acc_node, - ) - fusion_map = fusions_finder() - self.assertEqual(len(fusion_map), 0) - - def test_start_with_acc_module_(self): - """ - sin relu cos sigmoid tanh - a ====> b =====> c ====> d ========> e =====> f - - We set sin, relu and cos as acc node but also set min_acc_module_size to 2 - and expect the whole module stay on CPU. - """ - - class TestModule(torch.nn.Module): - def forward(self, a): - b = torch.sin(a) - c = torch.relu(b) - d = torch.cos(c) - e = torch.sigmoid(d) - f = torch.tanh(e) - return f - - mod = acc_tracer.trace(TestModule(), [torch.randn(2, 3)]) - - # Set sin, cos and tanh as acc node and split with settings - class CustomOpSupport(op_support.OperatorSupport): - _support_dict = { - "acc_ops.sin": None, - "acc_ops.cos": None, - "acc_ops.relu": None, - } - - # Create splitter setting and set min_acc_module_size to 2 - settings = splitter_base._SplitterSettingBase() - settings.min_acc_module_size = 2 - splitter = TRTSplitter( - mod, - (torch.randn(2, 3),), - op_support_with_support_dict( - { - "acc_ops.sin": None, - "acc_ops.cos": None, - "acc_ops.relu": None, - } - ), - settings, - ) - - def test_splitter(splitter): - st_split = splitter() - try: - verify_split_model(st_split) - except RuntimeError as err: - self.assertEqual(str(err), ERROR_MSG_NO_ACC_MODULE) - modules = list(st_split.named_modules()) - # Main module and a submodule - assert len(modules) == 3 - - assert modules[1][0] == "_run_on_acc_0" - assert modules[2][0] == "_run_on_gpu_1" - - test_splitter(splitter) - - def test_exclude_support_node_by_name(self): - class TestModule(torch.nn.Module): - def forward(self, a): - b = torch.sin(a) - c = torch.relu(b) - d = torch.cos(c) - e = torch.sigmoid(d) - f = torch.tanh(e) - return f - - mod = acc_tracer.trace(TestModule(), [torch.randn(2, 3)]) - - # Set sin, cos and tanh as acc node and split with settings - class CustomOpSupport(op_support.OperatorSupport): - _support_dict = { - "acc_ops.sin": None, - "acc_ops.cos": None, - "acc_ops.relu": None, - "acc_ops.sigmoid": None, - "acc_ops.tanh": None, - } - - # For unsupport relu node, this would cut graph into acc_0, gpu_1 and acc_2 - # as three sub graphs. 
- settings = TRTSplitterSetting() - settings.exclude_support_node_name = {"relu"} - splitter = TRTSplitter( - mod, - (torch.randn(2, 3),), - op_support_with_support_dict( - { - "acc_ops.sin": None, - "acc_ops.cos": None, - "acc_ops.relu": None, - } - ), - settings, - ) - res = splitter.generate_split_results() - self.assertTrue(len(res), 3) - - -def op_support_with_support_dict(support_dict: dict) -> op_support.OperatorSupportBase: - return op_support.OperatorSupport(support_dict) - - -if __name__ == "__main__": - run_tests() From a9b0711f5567787fc6cf65c06daa959d4a5c0970 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Fri, 14 Apr 2023 10:23:59 -0700 Subject: [PATCH 41/45] chore: Modify dynamo fx_ts_compat tests Signed-off-by: Dheeraj Peri --- .circleci/config.yml | 207 ------------------------------------------- 1 file changed, 207 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 35ff68eea3..fc4afdf1a8 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -727,148 +727,6 @@ commands: - store_artifacts: path: /tmp/testlogs - test-dynamo-fx_ts_converters_acc: - description: "Test the Dynamo acc converters" - steps: - - run: - name: Run FX converter tests - command: | - cd py/torch_tensorrt/dynamo/fx_ts_compat/test - pushd converters/acc_op/ - pytest --junitxml=/tmp/artifacts/test_results/dynamo/converters/acc_op/test_results.xml - popd - - - store_test_results: - path: /tmp/artifacts - - store_artifacts: - path: /tmp/testlogs - - test-dynamo-fx_ts_converters_aten: - description: "Test the dynamo aten converters" - steps: - - run: - name: Run dynamo converter tests - command: | - cd py/torch_tensorrt/dynamo/fx_ts_compat/test - pushd converters/aten_op/ - pytest --junitxml=/tmp/artifacts/test_results/dynamo/converters/aten_op/test_results.xml - popd - - - store_test_results: - path: /tmp/artifacts - - store_artifacts: - path: /tmp/testlogs - - test-dynamo-fx_ts_converters_vanilla: - description: "Test the dynamo vanilla converters" - steps: - - run: - name: Run dynamo converter tests - command: | - cd py/torch_tensorrt/dynamo/fx_ts_compat/test - pushd converters/vanilla/ - pytest --junitxml=/tmp/artifacts/test_results/dynamo/converters/vanilla/test_results.xml - popd - - - store_test_results: - path: /tmp/artifacts - - store_artifacts: - path: /tmp/testlogs - - test-dynamo-fx_ts_passes: - description: "Test the dynamo passes" - steps: - - run: - name: Run dynamo passes - command: | - cd py/torch_tensorrt/dynamo/fx_ts_compat/test - pushd passes - list_passes=$(ls | grep -v test_setitem*) - pytest $list_passes --junitxml=/tmp/artifacts/test_results/dynamo/passes/test_results.xml - popd - - store_test_results: - path: /tmp/artifacts - - store_artifacts: - path: /tmp/testlogs - - test-dynamo-fx_ts_tools: - description: "Test the dynamo tools" - steps: - - run: - name: Run dynamo tools - command: | - cd py/torch_tensorrt/dynamo/fx_ts_compat/test - pushd tools - pytest --junitxml=/tmp/artifacts/test_results/dynamo/tools/test_results.xml - popd - - store_test_results: - path: /tmp/artifacts - - store_artifacts: - path: /tmp/testlogs - - test-dynamo-fx_ts_trt_lower: - description: "Test the dynamo TRT lowering" - steps: - - run: - name: Run dynamo TRT lowering - command: | - cd py/torch_tensorrt/dynamo/fx_ts_compat/test - pushd trt_lower - pytest --junitxml=/tmp/artifacts/test_results/dynamo/trt_lower/test_results.xml - popd - - store_test_results: - path: /tmp/artifacts - - store_artifacts: - path: /tmp/testlogs - - test-dynamo-fx_ts_tracer: - description: "Test 
all dynamo tracers" - steps: - - run: - name: Run dynamo tracer - command: | - cd py/torch_tensorrt/dynamo/fx_ts_compat/test - pushd tracer - list_tracer=$(ls | grep -v test_dispatch_*) - pytest $list_tracer --junitxml=/tmp/artifacts/test_results/fx/tracer/test_results.xml - popd - - store_test_results: - path: /tmp/artifacts - - store_artifacts: - path: /tmp/testlogs - - test-dynamo-fx_ts_tracer_acc: - description: "Test the dynamo acc tracer only" - steps: - - run: - name: Run dynamo tracer - command: | - cd py/torch_tensorrt/dynamo/fx_ts_compat/test - pushd tracer - list_tracer=$(ls | grep test_acc) - pytest $list_tracer --junitxml=/tmp/artifacts/test_results/dynamo/tracer/test_results.xml - popd - - store_test_results: - path: /tmp/artifacts - - store_artifacts: - path: /tmp/testlogs - - test-dynamo-fx_ts_quant: - description: "Test the dynamo quant" - steps: - - run: - name: Run dynamo quant tests - command: | - cd py/torch_tensorrt/dynamo/fx_ts_compat/test - pushd quant/ - pytest --junitxml=/tmp/artifacts/test_results/dynamo/quant/test_results.xml - popd - - - store_test_results: - path: /tmp/artifacts - - store_artifacts: - path: /tmp/testlogs - test-dynamo-fx_ts: description: "Test the dynamo backend" steps: @@ -876,35 +734,7 @@ commands: name: Run dynamo tests command: | mkdir -p /tmp/artifacts/test_results - - test-dynamo-fx_ts_converters_acc - - test-dynamo-fx_ts_converters_aten - - test-dynamo-fx_ts_converters_vanilla - - test-dynamo-fx_ts_passes - - test-dynamo-fx_ts_tools - - test-dynamo-fx_ts_trt_lower - - test-dynamo-fx_ts_tracer - - test-dynamo-fx_ts_core - - test-dynamo-fx_ts_quant - - store_test_results: - path: /tmp/artifacts - - store_artifacts: - path: /tmp/testlogs - - test-dynamo-fx_ts-no-aten: - description: "Test the dynamo backend without aten operators" - steps: - - run: - name: Run dynamo tests without aten ops - command: | - mkdir -p /tmp/artifacts/test_results - - test-dynamo-fx_ts_converters_acc - - test-dynamo-fx_ts_converters_vanilla - - test-dynamo-fx_ts_passes - - test-dynamo-fx_ts_tools - - test-dynamo-fx_ts_trt_lower - - test-dynamo-fx_ts_tracer_acc - test-dynamo-fx_ts_core - - test-dynamo-fx_ts_quant - store_test_results: path: /tmp/artifacts - store_artifacts: @@ -1119,37 +949,6 @@ jobs: - dump-test-env - test-dynamo-fx_ts - test-py-dynamo-x86_64-linux-no-aten: - parameters: - torch-build: - type: string - torch-build-index: - type: string - trt-version-long: - type: string - machine: - image: ubuntu-2004-cuda-11.4:202110-01 - resource_class: gpu.nvidia.large - steps: - - checkout - - attach_workspace: - at: /tmp/dist/ - - install-torch-from-index: - torch-build: << parameters.torch-build >> - torch-build-index: << parameters.torch-build-index >> - - create-py-env: - trt-version-long: << parameters.trt-version-long >> - - install-cudnn - # - run: - # name: "Set LD_LIBRARY_PATH path to include the installed CUDNN" - # command: export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu/:$LD_LIBRARY_PATH - - run: - name: "Install torch-tensorrt" - command: pip3 install --pre /tmp/dist/x86_64-linux/*cp39-cp39*.whl - # We install torch after torch-trt because pip automatically enforces the version constraint otherwise - - dump-test-env - - test-dynamo-fx_ts-no-aten - package-x86_64-linux: parameters: enabled: @@ -1565,12 +1364,6 @@ workflows: requires: - build-x86_64-linux-legacy - - test-py-dynamo-x86_64-linux-no-aten: - torch-build: << pipeline.parameters.torch-build-legacy >> - torch-build-index: << pipeline.parameters.torch-build-index-legacy >> - 
trt-version-long: << pipeline.parameters.trt-version-long >> - requires: - - build-x86_64-linux-legacy release: when: << pipeline.parameters.enable-packaging >> jobs: From 13caff0eb5affd82b66eea45ee2f86d15a76a5c4 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 18 Apr 2023 19:16:29 -0700 Subject: [PATCH 42/45] chore: Refactor code Signed-off-by: Dheeraj Peri --- .circleci/config.yml | 23 +- py/torch_tensorrt/_Input.py | 23 +- .../fx_ts_compat/Dynamic_Shape_Support.md | 137 ----------- .../dynamo/fx_ts_compat/fx2trt.py | 2 +- .../dynamo/fx_ts_compat/input_tensor_spec.py | 4 +- .../dynamo/fx_ts_compat/lower.py | 31 ++- .../dynamo/fx_ts_compat/lower_setting.py | 6 +- .../passes/lower_pass_manager_builder.py | 2 +- .../test/core/test_import_fx2trt.py | 18 -- .../test/core/test_input_tensor_spec.py | 14 ++ .../fx_ts_compat/test/core/test_trt_module.py | 145 ------------ .../fx_ts_compat/tools/common_fx2trt.py | 2 +- .../tools/engine_layer_visualize.py | 217 ------------------ .../dynamo/fx_ts_compat/tools/graph_util.py | 78 ------- .../fx_ts_compat/tools/model_packager.py | 126 ---------- .../fx_ts_compat/tools/node_profiler.py | 53 ----- .../dynamo/fx_ts_compat/tools/tensor_prop.py | 33 --- .../fx_ts_compat/tools/timing_cache_utils.py | 39 ---- .../fx_ts_compat/tools/trt_profiler_sorted.py | 58 ----- .../dynamo/fx_ts_compat/tools/trt_splitter.py | 138 ----------- .../dynamo/fx_ts_compat/types.py | 24 -- .../dynamo/fx_ts_compat/utils.py | 140 ----------- 22 files changed, 71 insertions(+), 1242 deletions(-) delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/Dynamic_Shape_Support.md delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_import_fx2trt.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_trt_module.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/tools/engine_layer_visualize.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/tools/graph_util.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/tools/model_packager.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/tools/node_profiler.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/tools/tensor_prop.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/tools/timing_cache_utils.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/tools/trt_profiler_sorted.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/tools/trt_splitter.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/types.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/utils.py diff --git a/.circleci/config.yml b/.circleci/config.yml index efcad6d6bd..88b547729e 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -711,15 +711,15 @@ commands: # =================== FX tests end ======================== # # =================== Dynamo tests start ======================== # - test-dynamo-fx_ts_core: - description: "Test the Dynamo core" + test-dynamo-fx_ts: + description: "Test the Dynamo fx_ts_compat path" steps: - run: - name: Run Dynamo core tests + name: Run Dynamo fx_ts_compat core tests command: | cd py/torch_tensorrt/dynamo/fx_ts_compat/test pushd core/ - pytest --junitxml=/tmp/artifacts/test_results/dynamo/core/test_results.xml + pytest --junitxml=/tmp/artifacts/test_results/dynamo/fx_ts_compat/test_results.xml popd - store_test_results: @@ -737,7 +737,7 @@ commands: pushd test/ pip3 install timm pip3 install transformers - pytest --junitxml=/tmp/artifacts/test_results/dynamo/test_results.xml --ir torch_compile + pytest 
--junitxml=/tmp/artifacts/test_results/dynamo/torch_compile/test_results.xml --ir torch_compile popd - store_test_results: @@ -745,19 +745,6 @@ commands: - store_artifacts: path: /tmp/testlogs - test-dynamo-fx_ts: - description: "Test the dynamo backend" - steps: - - run: - name: Run dynamo tests - command: | - mkdir -p /tmp/artifacts/test_results - - test-dynamo-fx_ts_core - - store_test_results: - path: /tmp/artifacts - - store_artifacts: - path: /tmp/testlogs - # =================== Dynamo tests end ======================== # # Define a job to be invoked later in a workflow. diff --git a/py/torch_tensorrt/_Input.py b/py/torch_tensorrt/_Input.py index 5558dc5450..33dc23b4e5 100644 --- a/py/torch_tensorrt/_Input.py +++ b/py/torch_tensorrt/_Input.py @@ -40,7 +40,7 @@ class _ShapeMode(Enum): DOMAIN_OFFSET = 2.0 low_tensor_domain_incl = 0.0 high_tensor_domain_excl = low_tensor_domain_incl + DOMAIN_OFFSET - torch_dtype = None + torch_dtype = torch.float32 def __init__(self, *args, **kwargs): """__init__ Method for torch_tensorrt.Input @@ -142,6 +142,7 @@ def __init__(self, *args, **kwargs): self.torch_dtype = kwargs["dtype"] self.dtype = Input._parse_dtype(kwargs["dtype"]) + self.torch_dtype = Input._to_torch_dtype(self.dtype) self._explicit_set_dtype = True if "format" in kwargs: @@ -215,6 +216,22 @@ def _parse_dtype(dtype: Any) -> _enums.dtype: + str(type(dtype)) ) + @staticmethod + def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype: + if dtype == _enums.dtype.long: + return torch.long + elif dtype == _enums.dtype.int32: + return torch.int32 + elif dtype == _enums.dtype.half: + return torch.half + elif dtype == _enums.dtype.float: + return torch.float + elif dtype == _enums.dtype.bool: + return torch.bool + else: + # Default torch_dtype used in FX path + return torch.float32 + def is_trt_dtype(self) -> bool: return self.dtype != _enums.dtype.long @@ -368,9 +385,9 @@ def example_tensor(self, optimization_profile_field: str = None) -> torch.Tensor if self.shape_mode == Input._ShapeMode.STATIC: return torch.rand(self.shape).to( - dtype=self.dtype if not self.torch_dtype else self.torch_dtype + dtype=self.torch_dtype ) else: return torch.rand(self.shape[optimization_profile_field]).to( - dtype=self.dtype if not self.torch_dtype else self.torch_dtype + dtype=self.torch_dtype ) diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/Dynamic_Shape_Support.md b/py/torch_tensorrt/dynamo/fx_ts_compat/Dynamic_Shape_Support.md deleted file mode 100644 index eb4454340e..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/Dynamic_Shape_Support.md +++ /dev/null @@ -1,137 +0,0 @@ -# PyTorch Operations Dynamic Shape Support Summary - - - - | Operation | Test Method | Supports Dynamic Shape | Shape | Num of dimensions | Reason | -| --- | --- | --- | --- | --- | --- | -| adaptive_avgpool | | partially | (-1, -1, 256, 256) | 2 | AdaptiveAvgPool2d and AdaptiveAvgPool3d currently doesn't support dynamic shapes for last two dims. | -| any | | no | | | torch.zeros(tuple(\[*input_t.shape\])). Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] | -| as_strided | | no | | | RuntimeError: setStorage: sizes \[2, 3\], strides \[1, 2\], storage offset 0, and itemsize 8 requiring a storage size of 48 are out of bounds for storage of size 16 | -| avg_pool | avg_pool2d | yes | (-1,-,1,-1,-1) | 4 | | -| | avg_pool1d | partially | (-1, 3, 3) | 1 | | -| batchnorm | | partially | (-1, 3, -1, -1) | 3 | "Channel dim can't be dynamic for batch norm." 
| -| binary_ops | | yes | (-1,-,1,-1,-1) | 4 | | -| cat | | yes | (-1,-,1,-1,-1) | 4 | | -| chunk | | partially | (-1, 1, 3, -1) | any (not chunk dim) | AssertionError: Can't chunk on dynamic shape dimension! | -| clamp | | yes | (-1,-,1,-1,-1) | | | -| convolution | conv2d | partially | (-1, 3, -1, -1) | 3 | AssertionError: Channel dim can't be dynamic for convolution. | -| | conv1d | partially | (-1, 3, 3) | 1 | | -| | conv3d | partially | (-1,-,1,-1,-1) | 4 | AssertionError: Channel dim can't be dynamic for convolution. | -| dequantize | | yes | (-1,-,1,-1,-1) | 4 | | -| eimsum | | yes | (-1,-,1,-1,-1) | 4 | | -| elu | | yes | (-1,-,1,-1,-1) | 4 | | -| embedding | | yes | (-1,-,1,-1,-1) | 4 | | -| eq | SimpleConverter | yes | (-1,-,1,-1,-1) | 4 | | -| | ConstInputConverter | yes | (-1,-,1,-1,-1) | 4 | | -| | EqMethodConverter | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] | -| | EqOperatorConverter | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] | -| | EqOperatorConstant | partially | (3,-1) | 1 | | -| | EqConverter | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] | -| expand | | no | | | Dynamic shape is not suitable for the expand operation. | -| flatten | | yes | (-1, -1, -1, -1, -1) | 5 | | -| gelu | | yes | (-1,-,1,-1,-1) | 4 | | -| getitem | | yes | (-1,-,1,-1,-1) | 4 | | -| gt | EqOperatorSimpleConverter | yes | (-1,-,1,-1,-1) | 4 | | -| | ConstInputConverter | yes | (-1,-,1,-1,-1) | 4 | | -| | GtConverter | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] | -| | GtMethodConverter | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] | -| | GtOperator | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] | -| | EqOperator | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] | -| hardsigmoid | | yes | (-1,-,1,-1,-1) | 4 | | -| hardtanh | | yes | (-1,-,1,-1,-1) | 4 | | -| interpolate | | yes | (-1,-,1,-1,-1) | 4 | | -| isinf | | yes | (-1,-,1,-1,-1) | 4 | | -| leaky_relu | | yes | (-1,-,1,-1,-1) | 4 | | -| linear | | partially | (-1, 3, 5) | 1 | AssertionError: Currently we only support one dynmaic dim for linear and it can't be the last dim. 
| -| logical_and | | yes | (-1, -1, -1, -1) | 4 | | -| logical_or | | yes | (-1, -1, -1, -1) | 4 | | -| logical_xor | | yes | (-1, -1, -1, -1) | 4 | | -| lt | | yes | (-1, -1, -1, -1) | 4 | | -| masked_fill | | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] | -| mat_mul | | yes | batch dim | | | -| max | MaxFullReduce | yes | (-1, -1, -1, -1) | 4 | | -| | MaxDimReduce | yes | (-1, -1, -1, -1) | 4 | | -| | MaxMethod | yes | (-1, -1, -1, -1) | 4 | | -| maximum | | yes | (-1, -1, -1, -1) | 4 | | -| maxpool | max_pool1d | partially | (1, 1, -1) | 1 | shape is not set to (-1, -1, -1) as reshape dimension with, more than one -1 wildcard is not allowed while adding unsqueeze layer | -| | max_pool2d | yes | (-1, -1, -1, -1) | 4 | | -| | max_pool3d | yes | (-1, -1, -1, -1, -1) | 5 | | -| min | MinFullReduce | yes | (-1, -1, -1, -1) | 4 | | -| | MinDimReduce | yes | (-1, -1, -1, -1) | 4 | | -| | MinMethod | yes | (-1, -1, -1, -1) | 4 | | -| minimum | | yes | (-1, -1, -1, -1) | 4 | | -| narrow | | partially | (-1, 3, -1, -1) | 3 | AssertionError: Can't chunk on dynamic shape dimension! | -| ne | NeFunctionConverter | yes | (-1, -1, -1, -1) | 4 | | -| | NeMethodConverter | yes | (-1, -1, -1, -1) | 4 | | -| | NeOperatorConverter | yes | (-1, -1, -1, -1) | 4 | | -| | ConstInputConverter | yes | (-1, -1, -1, -1) | 4 | | -| | NeOperatorConstantConverter | partially | (3, -1) | 1 | | -| new_ones | | yes | (-1, -1, -1, -1) | 4 | | -| numel | | no | limitation in converter | | RuntimeError: numel does not support dynamic shapes. | -| pad | | no | limitation in converter | | test\_pad\_with\_dynamic\_shape\_four\_dimensions\_0\_2d (deeplearning.trt.torch\_tensorrt.py.torch\_tensorrt.fx.test.converters.acc\_op.test\_pad.TestPadConverter) ... \[07/15/2022-09:23:18\] \[TRT\] \[E\] 2: \[intInterval.cpp::max::26\] Error Code 2: Internal Error (Assertion !empty() failed. | -| permute | | yes | (-1, -1, -1, -1) | 4 | | -| prod | | yes | (-1, -1, -1, -1) | 4 | | -| quantize\_per\_tensor | | yes | (-1, -1, -1, -1) | 4 | | -| reduce op | | yes | (-1, -1, -1, -1) | 4 | | -| relu | | yes | (-1, -1, -1, -1) | 4 | | -| repeat interleave | | partially | (-1, 3, 2) | 1 | AssertionError: Currently we don't support unsqueeze with more than one dynamic dims. | -| reshape | | yes | (-1, -1, -1, -1) | 4 | | -| selu | | yes | (-1, -1, -1, -1) | 4 | | -| sigmoid | | yes | (-1,-,1,-1,-1) | 4 | | -| silu | | yes | (-1,-,1,-1,-1) | 4 | | -| size | | yes | (-1, -1, -1, -1) | 4 | | -| softmax | | yes | (-1, -1, -1, -1) | 4 | | -| softsign | | yes | (-1, -1, -1, -1) | 4 | | -| split | | partially | (-1, 10, -1) | 2 | AssertionError: Can't chunk on dynamic shape dimension! | -| squeeze | | partially | (1, -1, 2) | 1 | AssertionError: Currently more than one dynamic dim for input to squeeze is not supported. 
| -| std | | yes | (-1, -1, -1, -1) | 4 | | -| tanh | | yes | (-1, -1, -1, -1) | 4 | | -| tile | | yes | (-1, -1, -1, -1) | 4 | | -| to_dtype | int | yes | (-1, -1, -1, -1) | 4 | | -| | float | yes | (-1, -1, -1, -1) | 4 | | -| topk | | yes | (-1, -1, -1, -1) | 4 | | -| transpose_convolution | conv_transpose2d | partially | (-1, 3, -1, -1) | 3 | | -| | conv_transpose3d | partially | (-1, 3, -1, -1, -1) | 4 | | -| type_as | | yes | (-1, -1, -1, -1) | 4 | RuntimeError: ShapeProp error for: node=%type\_1 : \[#users=1\] = call\_method\[target=type\](args = (%input_1,), kwargs = {dtype: torch.float32}) with meta={} | -| unary ops | | yes | (-1, -1, -1, -1) | 4 | | -| unsqueeze | | partially | (-1, 2, 3) | 1 | AssertionError: Currently we don't support unsqueeze with more than one dynamic dims. | -| where | | no | limitation in converter | | torch.broadcast_shape can not handle -1 dimension in shape \[-1, 2, 2\] | - - - -Binary Ops Include following operations: -|Binary Ops | -|----------| -|add | -|sub | -|div | -|mul | -|floor_div | -|fmod | -|floor_divide| -|pow | - - -Unary Ops Include following operations: -|Unary Ops | -|----------| -|rsqrt | -|sin | -|cos | -|tan | -|sinh | -|cosh | -|asin | -|acos | -|atan | -|abs | -|neg | -|reciprocal| -|sqrt | -|log | -|exp | -|floor | -|ceil | -|sign | - -Note: For more information about the test method, please refer to the operation test files. Additionally, test files include information about errors encountered during dynamic shape testing. diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py index 18e514d83d..b5165c6f2d 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py @@ -16,7 +16,7 @@ from torch_tensorrt.dynamo.fx_ts_compat import CONVERTERS from .input_tensor_spec import InputTensorSpec from torch_tensorrt.fx.observer import Observer -from .utils import get_dynamic_dims, LowerPrecision, torch_dtype_to_trt +from torch_tensorrt.fx.utils import get_dynamic_dims, LowerPrecision, torch_dtype_to_trt _LOGGER: logging.Logger = logging.getLogger(__name__) diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/input_tensor_spec.py b/py/torch_tensorrt/dynamo/fx_ts_compat/input_tensor_spec.py index 3eb9a115af..7f67e8abbf 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/input_tensor_spec.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/input_tensor_spec.py @@ -2,8 +2,8 @@ import torch -from .types import Shape, ShapeRange -from .utils import get_dynamic_dims +from torch_tensorrt.fx.types import Shape, ShapeRange +from torch_tensorrt.fx.utils import get_dynamic_dims from torch_tensorrt._Input import Input diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/lower.py b/py/torch_tensorrt/dynamo/fx_ts_compat/lower.py index d16eb008c1..60ace0f12a 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/lower.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/lower.py @@ -14,12 +14,12 @@ from .lower_setting import LowerSetting from .passes.lower_pass_manager_builder import LowerPassManagerBuilder from .passes.pass_utils import PassFunc, validate_inference -from .tools.timing_cache_utils import TimingCacheManager -from .tools.trt_splitter import TRTSplitter, TRTSplitterSetting +from torch_tensorrt.fx.tools.timing_cache_utils import TimingCacheManager +from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer from torch_tensorrt.fx.trt_module import TRTModule -from .utils import 
LowerPrecision
+from torch_tensorrt.fx.utils import LowerPrecision
 from torch_tensorrt._Device import Device
 
 logger = logging.getLogger(__name__)
@@ -36,12 +36,23 @@ def compile(
     enabled_precisions=set(),
     min_block_size: int = 3,
     workspace_size=0,
-    verbose_log=False,
+    dla_sram_size=1048576,
+    dla_local_dram_size=1073741824,
+    dla_global_dram_size=536870912,
+    calibrator=None,
+    truncate_long_and_double=False,
+    require_full_compilation=False,
+    debug=False,
+    refit=False,
     timing_cache_prefix="",
     save_timing_cache=False,
     cuda_graph_batch_size=-1,
     is_aten=False,
     use_experimental_fx_rt=False,
+    num_avg_timing_iters=1,
+    torch_executed_ops=[],
+    torch_executed_modules=[],
+    **kwargs,
 ) -> nn.Module:
     """
     Takes in original module, input and lowering setting, run lowering workflow to turn module
@@ -52,7 +63,7 @@ def compile(
         input: Input for module.
         min_block_size: Minimal number of nodes for an accelerated submodule
         workspace_size: Maximum size of workspace given to TensorRT.
-        verbose_log: Enable verbose log for TensorRT if set True.
+        debug: Enable verbose log for TensorRT if set True.
         timing_cache_prefix: Timing cache file name for timing cache used by fx2trt.
         save_timing_cache: Update timing cache with current timing cache data if set to True.
         cuda_graph_batch_size: Cuda graph batch size, default to be -1.
@@ -65,6 +76,12 @@ def compile(
             "The experimental unifed runtime only supports explicit batch. Please make sure to set explicit_batch_dimension=True when use_experimental_fx_rt=True"
         )
 
+    logger.warn(
+        "For ir=fx_ts_compat backend only the "
+        + "following arguments are supported: "
+        + "{enabled_precisions, debug, workspace_size, device, disable_tf32, sparse_weights, min_block_size}"
+    )
+
     # Parse precision into LowerPrecision
     lower_precision = LowerPrecision.FP32
     if torch.float16 in enabled_precisions:
@@ -100,7 +117,7 @@ def compile(
         sparse_weights=sparse_weights,
         workspace_size=workspace_size,
         lower_precision=lower_precision,
-        verbose_log=verbose_log,
+        debug=debug,
         timing_cache_prefix=timing_cache_prefix,
         save_timing_cache=save_timing_cache,
         cuda_graph_batch_size=cuda_graph_batch_size,
@@ -148,7 +165,7 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
             explicit_batch_dimension=self.lower_setting.explicit_batch_dimension,
             explicit_precision=self.lower_setting.explicit_precision,
             logger_level=trt.Logger.VERBOSE
-            if self.lower_setting.verbose_log
+            if self.lower_setting.debug
             else trt.Logger.WARNING,
         )
 
diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py b/py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py
index be0ada8b55..9008bbe8e9 100644
--- a/py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py
@@ -9,7 +9,7 @@
     fuse_permute_linear,
     fuse_permute_matmul,
 )
-from .utils import LowerPrecision
+from torch_tensorrt.fx.utils import LowerPrecision
 
 
 @dc.dataclass
@@ -54,7 +54,7 @@ class LowerSetting(LowerSettingBasic):
     as (a->b->c->d)=>(e). Current basic fuse patterns are:
     permute->linear
     permute->matmul
-    verbose_log: Enable TensorRT engine verbose log mode.
+    debug: Enable TensorRT engine verbose log mode.
     algo_selector: Enable TensorRT algorithm selector at execution time.
     timing_cache_prefix: TensorRT timing cache file path. TensorRT engine will use timing cache file
     at execution time if valid timing cache file is provided.
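The `verbose_log` to `debug` rename above flows through to the interpreter, where `debug=True` selects trt.Logger.VERBOSE and `debug=False` selects trt.Logger.WARNING. As a minimal sketch of how a caller might exercise the widened signature after this change (the `fx_ts_compat.compile` entry point, the toy model, and the Input shapes are illustrative assumptions, not part of this diff):

    import torch
    import torch_tensorrt
    from torch_tensorrt.dynamo import fx_ts_compat

    # Toy module; any traceable nn.Module would do.
    model = torch.nn.Sequential(torch.nn.Linear(3, 3), torch.nn.ReLU()).cuda().eval()

    # Dynamic-shape Input spec; the first dim ranges over the optimization profile.
    inputs = [
        torch_tensorrt.Input(
            min_shape=(2, 3), opt_shape=(4, 3), max_shape=(8, 3), dtype=torch.float
        )
    ]

    trt_module = fx_ts_compat.compile(
        model,
        inputs,
        enabled_precisions={torch.float},
        debug=True,  # formerly verbose_log=True
        workspace_size=1 << 30,
        min_block_size=1,
    )

Per the warning added above, arguments outside {enabled_precisions, debug, workspace_size, device, disable_tf32, sparse_weights, min_block_size} (for example refit or num_avg_timing_iters) are accepted only for signature parity with the TorchScript frontend and are not acted on by this path.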
@@ -85,7 +85,7 @@ class LowerSetting(LowerSettingBasic): [fuse_permute_matmul, fuse_permute_linear] ) ) - verbose_log: bool = False + debug: bool = False algo_selector = None timing_cache_prefix: str = "" save_timing_cache: bool = False diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/passes/lower_pass_manager_builder.py b/py/torch_tensorrt/dynamo/fx_ts_compat/passes/lower_pass_manager_builder.py index cb012c4f4e..0fd3777254 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/passes/lower_pass_manager_builder.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/passes/lower_pass_manager_builder.py @@ -8,7 +8,7 @@ from torch.fx.passes.pass_manager import inplace_wrapper, PassManager from torch.fx.passes.shape_prop import ShapeProp from torch.fx.passes.splitter_base import generate_inputs_for_submodules, SplitResult -from torch_tensorrt.dynamo.fx_ts_compat.utils import LowerPrecision +from torch_tensorrt.fx.utils import LowerPrecision from torch_tensorrt import _Input from ..input_tensor_spec import InputTensorSpec diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_import_fx2trt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_import_fx2trt.py deleted file mode 100644 index 12e47ef112..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_import_fx2trt.py +++ /dev/null @@ -1,18 +0,0 @@ -#!/usr/bin/env python3 -# (c) Facebook, Inc. and its affiliates. Confidential and proprietary. - -# Owner(s): ["oncall: gpu_enablement"] - -# Test that this import should not trigger any error when run -# in non-GPU hosts, or in any build mode. -import torch_tensorrt.dynamo.fx_ts_compat.lower as fxl # noqa: F401 -from torch.testing._internal.common_utils import run_tests, TestCase - - -class MainTests(TestCase): - def test_1(self): - pass - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_input_tensor_spec.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_input_tensor_spec.py index 7794b1bac8..b22986a0c5 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_input_tensor_spec.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_input_tensor_spec.py @@ -3,6 +3,7 @@ from typing import List, Optional import torch +import torch_tensorrt from torch.testing._internal.common_utils import run_tests, TestCase from torch_tensorrt.dynamo.fx_ts_compat import InputTensorSpec, LowerSetting @@ -63,6 +64,19 @@ def test_from_tensors_with_dynamic_batch_size_different_batch_dims(self): tensor_shape[i] = batch_size self.assertSequenceEqual(tensor_shape, shape) + def test_from_static_input(self): + tensors = [torch.randn(1, 2, 3), torch.randn(2, 1, 4)] + inputs = torch_tensorrt.Input.from_tensors(tensors) + specs = [InputTensorSpec.from_input(input) for input in inputs] + for spec, tensor in zip(specs, tensors): + self._validate_spec(spec, tensor) + + def test_from_dynamic_input(self): + inputs = torch_tensorrt.Input(min_shape=(2, 2, 3), opt_shape=(4, 2, 3), max_shape=(8, 2, 3)) + example_tensor = inputs.example_tensor(optimization_profile_field="opt_shape") + spec = InputTensorSpec.from_input(inputs) + self._validate_spec(spec, example_tensor, dynamic_dims=[0]) + if __name__ == "__main__": run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_trt_module.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_trt_module.py deleted file mode 100644 index 8043b753ac..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_trt_module.py +++ /dev/null @@ -1,145 +0,0 @@ -# 
Owner(s): ["oncall: gpu_enablement"] - -import io -import os - -import torch -import torch.fx - -import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer -from torch.testing._internal.common_utils import run_tests, TestCase -from torch_tensorrt.fx import TRTModule -from torch_tensorrt.dynamo.fx_ts_compat import InputTensorSpec, TRTInterpreter -from torch_tensorrt.dynamo.fx_ts_compat.utils import LowerPrecision - - -class TestTRTModule(TestCase): - def test_save_and_load_trt_module(self): - class TestModule(torch.nn.Module): - def forward(self, x): - return x + x - - inputs = [torch.randn(1, 1)] - mod = TestModule().eval() - ref_output = mod(*inputs) - - mod = acc_tracer.trace(mod, inputs) - interp = TRTInterpreter(mod, input_specs=InputTensorSpec.from_tensors(inputs)) - trt_mod = TRTModule(*interp.run(lower_precision=LowerPrecision.FP32)) - torch.save(trt_mod, "trt.pt") - reload_trt_mod = torch.load("trt.pt") - - torch.testing.assert_close( - reload_trt_mod(inputs[0].cuda()).cpu(), ref_output, rtol=1e-04, atol=1e-04 - ) - os.remove(f"{os.getcwd()}/trt.pt") - - def test_save_and_load_state_dict(self): - class TestModule(torch.nn.Module): - def forward(self, x): - return x + x - - inputs = [torch.randn(1, 1)] - mod = TestModule().eval() - ref_output = mod(*inputs) - - mod = acc_tracer.trace(mod, inputs) - interp = TRTInterpreter(mod, input_specs=InputTensorSpec.from_tensors(inputs)) - trt_mod = TRTModule(*interp.run(lower_precision=LowerPrecision.FP32)) - st = trt_mod.state_dict() - - new_trt_mod = TRTModule() - new_trt_mod.load_state_dict(st) - - torch.testing.assert_close( - new_trt_mod(inputs[0].cuda()).cpu(), ref_output, rtol=1e-04, atol=1e-04 - ) - - -# TODO add unittest.skip later -# class TestTRTModuleNext(TestCase): -# def test_save_and_load_trt_module(self): -# class TestModule(torch.nn.Module): -# def forward(self, x): -# return x + x - -# inputs = [torch.randn(1, 1)] -# mod = TestModule().eval() -# ref_output = mod(*inputs) - -# mod = acc_tracer.trace(mod, inputs) - -# interp = TRTInterpreter( -# mod, -# input_specs=InputTensorSpec.from_tensors(inputs), -# explicit_batch_dimension=True, -# ) -# interp_res = interp.run(lower_precision=LowerPrecision.FP32) - -# with io.BytesIO() as engine_bytes: -# engine_bytes.write(interp_res.engine.serialize()) -# engine_str = engine_bytes.getvalue() - -# trt_mod = TRTModuleNext( -# name="TestModule", -# serialized_engine=engine_str, -# input_binding_names=interp_res.input_names, -# output_binding_names=interp_res.output_names, -# target_device=Device(f"cuda:{torch.cuda.current_device()}"), -# ) - -# torch.save(trt_mod, "trt.pt") -# reload_trt_mod = torch.load("trt.pt") - -# torch.testing.assert_allclose( -# reload_trt_mod(inputs[0].cuda()).cpu().reshape_as(ref_output), -# ref_output, -# rtol=1e-04, -# atol=1e-04, -# ) -# os.remove(f"{os.getcwd()}/trt.pt") - -# def test_save_and_load_state_dict(self): -# class TestModule(torch.nn.Module): -# def forward(self, x): -# return x + x - -# inputs = [torch.randn(1, 1)] -# mod = TestModule().eval() -# ref_output = mod(*inputs) - -# mod = acc_tracer.trace(mod, inputs) -# interp = TRTInterpreter( -# mod, -# input_specs=InputTensorSpec.from_tensors(inputs), -# explicit_batch_dimension=True, -# ) -# interp_res = interp.run(lower_precision=LowerPrecision.FP32) - -# with io.BytesIO() as engine_bytes: -# engine_bytes.write(interp_res.engine.serialize()) -# engine_str = engine_bytes.getvalue() - -# trt_mod = TRTModuleNext( -# name="TestModule", -# serialized_engine=engine_str, -# 
input_binding_names=interp_res.input_names, -# output_binding_names=interp_res.output_names, -# target_device=Device(f"cuda:{torch.cuda.current_device()}"), -# ) - -# st = trt_mod.state_dict() - -# new_trt_mod = TRTModuleNext() -# new_trt_mod.load_state_dict(st) - -# torch.testing.assert_allclose( -# new_trt_mod(inputs[0].cuda()).cpu().reshape_as(ref_output), -# ref_output, -# rtol=1e-04, -# atol=1e-04, -# ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/common_fx2trt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/common_fx2trt.py index 5c0d8bbc76..334243fef4 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/common_fx2trt.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/common_fx2trt.py @@ -27,7 +27,7 @@ run_const_fold, ) from torch_tensorrt.dynamo.fx_ts_compat.passes.pass_utils import chain_passes -from torch_tensorrt.dynamo.fx_ts_compat.utils import LowerPrecision, proxytensor_trace +from torch_tensorrt.fx.utils import LowerPrecision, proxytensor_trace _LOGGER: logging.Logger = logging.getLogger(__name__) diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/engine_layer_visualize.py b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/engine_layer_visualize.py deleted file mode 100644 index cecd1ecb20..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/engine_layer_visualize.py +++ /dev/null @@ -1,217 +0,0 @@ -import argparse -import logging -import re -from typing import Any, Dict, List, NamedTuple, Optional, Tuple - -import pydot - -_LOGGER: logging.Logger = logging.getLogger(__name__) -""" -log_file is generated by tensorrt verbose logger during building engine. -profile_file is generated by tensorrt profiler. - -Curretnly we support processing multiple logs in one log_file, which -would generate multiple dot graphs. However, multiple engine profiles are not -supported. - -Usage: - python torch_tensorrt.fx/tools/engine_layer_visualize.py --log_file aaa --profile_file bbb -""" - -parser = argparse.ArgumentParser() -parser.add_argument( - "--log_file", - type=str, - default="", - help="TensorRT VERBOSE logging when building engines.", -) -parser.add_argument( - "--profile_file", - type=str, - default="", - help="TensorRT execution context profiler output.", -) -args = parser.parse_args() - - -class LayerInfo(NamedTuple): - kernel_name: str - layer_name: str - tactic: str - input_names: Optional[List[str]] - input_types: Optional[List[str]] - output_name: str - output_type: str - time: str - - @classmethod - def from_string(cls, string, tactic_names, layer_times=None): - input_names = [] - input_types = [] - kernel_name, layer_name, tactic, inputs, output_name, output_type = re.findall( - "Layer\\((.+)\\): (.+), Tactic: (-?\\d+), (.+)? 
-> (.+)\\[(.+)\\]", string - )[0] - - if kernel_name != "Constant": - inputs = re.findall( - "[, ]*(.+?)\\[([Half|Float|Int8]+\\(\\d[,\\d]*\\))\\]", inputs - ) - for input_name, input_type in inputs: - input_names.append(input_name) - input_types.append(input_type) - - if layer_name in tactic_names: - kernel_name = tactic_names[layer_name] - else: - input_names = input_types = None # type:ignore[assignment] - - return cls( - kernel_name, - layer_name, - tactic, - input_names, - input_types, - output_name, - output_type, - layer_times[layer_name] if layer_times else "NA", - ) - - -def build_node(layer): - layer_name = layer.layer_name.replace("|", "\\|") - label = f"{{{layer_name}|kernel: {layer.kernel_name}\\l|tactic: {layer.tactic}\\l|time: {layer.time}\\l}}" - label = label.replace(">", "\\>") - return pydot.Node(layer.layer_name, label=label, **style) - - -def build_edge(layer, graph, reformat_layers, output_name2node, layer_name2node): - if layer.input_names is None: - return - - for input_name, input_type in zip(layer.input_names, layer.input_types): - if input_name not in output_name2node: - if input_name in reformat_layers: - from_node = pydot.Node( - input_name, - label="{reformatter|kernel: Reformat\\l|tactic: 0\\l}", - **style, - ) - graph.add_node(from_node) - if reformat_layers[input_name][0] in output_name2node: - graph.add_edge( - pydot.Edge( - output_name2node[reformat_layers[input_name][0]], - from_node, - label=f"{reformat_layers[input_name][0]}\\l{reformat_layers[input_name][1]}\\l", - ) - ) - else: - _LOGGER.info(f"Missing node {input_name}") - from_node = input_name - else: - from_node = output_name2node[input_name] - - edge_name = input_name.replace(">", "\\>") - graph.add_edge( - pydot.Edge( - from_node, - layer_name2node[layer.layer_name], - label=f"{edge_name}\\l{input_type}\\l", - ) - ) - - -if args.profile_file != "": - layer_times = {} - with open(args.profile_file) as f: - times = f.readlines() - - for t in times: - t = t.strip("\n").split(": ") # type: ignore[assignment] - layer_times[": ".join(t[:-1])] = t[-1] -else: - layer_times = None # type: ignore[assignment] - -if args.log_file != "": - with open(args.log_file) as f: - lines = f.readlines() - - graphs = [] - layers = [] - reformat_layers: Dict[str, Tuple[str, str]] = {} - tactic_names: Dict[str, str] = {} - layer_info_start = False - tactic_name_start = False - - for line in lines: - line = line.strip("\n") - - if layer_info_start: - if "Layer(" in line: - layers.append(LayerInfo.from_string(line, tactic_names, layer_times)) - else: - layer_info_start = False - graphs.append((layers, reformat_layers)) - layers = [] - reformat_layers = {} - tactic_names = {} - - if tactic_name_start and "Set Tactic Name:" in line: - layer_name, kernel_name, _ = re.findall( - "VERBOSE: (.*) Set Tactic Name: (.*) Tactic: (.*)$", line - )[0] - tactic_names[layer_name] = kernel_name - - # Some reformat layers aren't displayed in Engine Layer Information - if "Adding reformat layer" in line: - output_name, input_name, from_type, to_type = re.findall( - "reformat layer: (.+) \\((.+)\\) from (.+) to (.+)", line - )[0] - reformat_layers[output_name] = (input_name, from_type) - - if "Total Activation Memory:" in line: - tactic_name_start = True - - if "Engine Layer Information" in line: - layer_info_start = True - tactic_name_start = False - - style = { - "shape": "record", - "fillcolor": "Salmon", - "style": '"filled,rounded"', - "fontcolor": "#000000", - } - - dot_graphs: List[Any] = [] - i = 0 - for layers, reformat_layers in 
graphs: - output_name2node = {} - layer_name2node = {} - dot_graph = pydot.Dot("Layer Graph") - - for layer in layers: - node = build_node(layer) - dot_graph.add_node(node) - output_name2node[layer.output_name] = node - layer_name2node[layer.layer_name] = node - - for layer in layers: - build_edge( - layer, dot_graph, reformat_layers, output_name2node, layer_name2node - ) - - dot_graph.write_raw(f"EngineLayers_{i}.dot") - i += 1 - -if args.profile_file != "": - est_reformat_time = 0.0 - est_total_time = 0.0 - - for layer in layers: - if layer.kernel_name == "Reformat": - est_reformat_time += float(layer.time[:-2]) - est_total_time += float(layer.time[:-2]) - - _LOGGER.info(f"Time Cost on Reformatting: {est_reformat_time} ms") - _LOGGER.info(f"Total Time Cost: {est_total_time} ms") diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/graph_util.py b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/graph_util.py deleted file mode 100644 index 5d07f76641..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/graph_util.py +++ /dev/null @@ -1,78 +0,0 @@ -import graphviz # type: ignore[import] - - -def get_layer_name_type(layer): - return "\n".join(f"{i}" for i in [layer.name, layer.type]) - - -def trt_network_to_dot_graph(network): - dot = graphviz.Digraph(comment="Network") - - # add nodes (layers) - for i in range(network.num_layers): - layer = network.get_layer(i) - dot.node(get_layer_name_type(layer)) - - # add nodes (inputs) - for i in range(network.num_inputs): - dot.node(network.get_input(i).name) - - # add nodes (outputs) - for i in range(network.num_outputs): - dot.node(network.get_output(i).name) - - # add layer->layer edges - for a in range(network.num_layers): - layer_a = network.get_layer(a) - - for b in range(network.num_layers): - layer_b = network.get_layer(b) - - for i in range(layer_a.num_outputs): - output_i = layer_a.get_output(i) - - for j in range(layer_b.num_inputs): - input_j = layer_b.get_input(j) - - if output_i == input_j: - dot.edge( - get_layer_name_type(layer_a), - get_layer_name_type(layer_b), - label=str(input_j.shape), - ) - - # add input->layer edges - for i in range(network.num_inputs): - input_i = network.get_input(i) - - for b in range(network.num_layers): - layer_b = network.get_layer(b) - - for j in range(layer_b.num_inputs): - input_j = layer_b.get_input(j) - - if input_i == input_j: - dot.edge( - input_i.name, - get_layer_name_type(layer_b), - label=str(input_j.shape), - ) - - # add layer->output edges - for i in range(network.num_outputs): - input_i = network.get_output(i) - - for b in range(network.num_layers): - layer_b = network.get_layer(b) - - for j in range(layer_b.num_outputs): - input_j = layer_b.get_output(j) - - if input_i == input_j: - dot.edge( - get_layer_name_type(layer_b), - input_i.name, - label=str(input_j.shape), - ) - - return dot diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/model_packager.py b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/model_packager.py deleted file mode 100644 index 0ef0ff05a4..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/model_packager.py +++ /dev/null @@ -1,126 +0,0 @@ -from pathlib import Path -from typing import BinaryIO, Sequence, TextIO, Union - -import torch -from torch.fx.passes.split_utils import getattr_recursive -from torch.package import PackageExporter - -""" -A tool to package acc submodule as a torch package. The packaged model can be loaded -with just PyTorch library. 
-""" - - -def flatten_model(model: torch.fx.GraphModule) -> torch.fx.GraphModule: - """ - Remove all original modules with an attr holder module so that all original modules - and names are not present. - """ - holder_module = torch.nn.Module() - model._holder = holder_module - attr_id = 0 - - for node in model.graph.nodes: - assert node.op != "call_module" - if node.op == "get_attr": - attr = getattr_recursive(model, node.target) - setattr(holder_module, f"_attr_{attr_id}", attr) - with model.graph.inserting_before(node): - new_node = model.graph.get_attr(f"_holder._attr_{attr_id}") - node.replace_all_uses_with(new_node) - attr_id += 1 - - model.graph.eliminate_dead_code() - model.recompile() - model.delete_all_unused_submodules() - return model - - -def generate_standalone_repro( - model: torch.fx.GraphModule, output: Union[str, Path, TextIO], prelude: str = "" -) -> None: - """ - Generate a standalone python file for the model where weights are randomized - and the model flattened. - This only works if leaf nodes are only torch.nn modules. - """ - model = flatten_model(model) - - INDENT = " " - lines = [ - "", - "import torch", - "from torch import nn", - "", - "", - "class ExportedModule(nn.Module):", - f"{INDENT}def __init__(self):", - f"{INDENT * 2}super().__init__()", - ] - for k, v in model._holder.named_parameters(): - shape = ", ".join([str(i) for i in v.shape]) - rand_func = "randn" if torch.is_floating_point(v) else "randint" - int_range = "" if torch.is_floating_point(v) else "0, 5, " - lines.append( - f"{INDENT * 2}self.{k} = nn.Parameter(torch.{rand_func}({int_range}{shape}, dtype={v.dtype}))" - ) - code = str(model.code) - - def dump(f): - f.write(prelude) - f.write("\n".join(lines)) - f.write( - "\n".join( - [ - INDENT + line.replace("self._holder.", "self.") - for line in code.split("\n") - ] - ) - ) - f.write("\n") - - if isinstance(output, (Path, str)): - with open(str(output), "w") as f: - dump(f) - else: - dump(output) - - -class ModelPackager: - @classmethod - def set_extern_modules(cls, pe: PackageExporter) -> None: - pe.extern( - [ - "builtins", - "sys", - "torch.**", - ] - ) - - @classmethod - def set_mocked_modules(cls, pe: PackageExporter): - pe.mock( - "**", - exclude=[ - "torch_tensorrt.fx.tracer.acc_tracer.acc_ops", - "torch_tensorrt.fx.tracer.acc_tracer.acc_normalizer", - "torch_tensorrt.fx.tracer.acc_tracer.acc_op_properties", - ], - ) - - @classmethod - def package_model( - cls, - model: torch.nn.Module, - model_inputs: Sequence[torch.Tensor], - output: Union[str, Path, BinaryIO], - preserve_model_structure: bool = False, - ) -> None: - if not preserve_model_structure: - model = flatten_model(model) - with PackageExporter(output) as pe: - cls.set_extern_modules(pe) - cls.set_mocked_modules(pe) - pe.intern("**") - pe.save_pickle("repro", "model", model) - pe.save_pickle("repro", "inputs", model_inputs) diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/node_profiler.py b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/node_profiler.py deleted file mode 100644 index 1a37c27197..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/node_profiler.py +++ /dev/null @@ -1,53 +0,0 @@ -from typing import Any - -import torch -from torch import fx - - -class NodeProfiler(fx.Interpreter): - """ - This is basically a variant of shape prop in - https://github.com/pytorch/pytorch/blob/74849d9188de30d93f7c523d4eeceeef044147a9/torch/fx/passes/shape_prop.py#L65. - Instead of propagating just the shape, we record all the intermediate node Tensor values. 
- This is useful to debug some of lowering pass issue where we want to check a specific - tensor value. Note that output value can be tuple(Tensor) as well as Tensor. - """ - - def __init__(self, module: fx.GraphModule): - super().__init__(module) - self.execution_time = {} - self.node_map = {} - self.iter = 100 - - def run_node(self, n: fx.Node) -> Any: - result = super().run_node(n) - if n.op not in {"call_function", "call_method", "call_module"}: - return result - - torch.cuda.synchronize() - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - start_event.record() - - for _ in range(self.iter): - result = super().run_node(n) - - end_event.record() - torch.cuda.synchronize() - - self.execution_time[f"{n.name}"] = ( - start_event.elapsed_time(end_event) / self.iter - ) - self.node_map[n.name] = n - return result - - def propagate(self, *args): - """ - Run `module` via interpretation and return the result and - record the shape and type of each node. - Args: - *args (Tensor): the sample input. - Returns: - Any: The value returned from executing the Module - """ - return super().run(*args) diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/tensor_prop.py b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/tensor_prop.py deleted file mode 100644 index a52e0a3929..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/tensor_prop.py +++ /dev/null @@ -1,33 +0,0 @@ -from typing import Any - -from torch import fx - - -class TensorProp(fx.Interpreter): - """ - This is basically a variant of shape prop in - https://github.com/pytorch/pytorch/blob/74849d9188de30d93f7c523d4eeceeef044147a9/torch/fx/passes/shape_prop.py#L65. - Instead of propagating just the shape, we record all the intermediate node Tensor values. - This is useful to debug some of lowering pass issue where we want to check a specific - tensor value. Note that output value can be tuple(Tensor) as well as Tensor. - """ - - def __init__(self, module: fx.GraphModule): - super().__init__(module) - self.tensor_map = {} - - def run_node(self, n: fx.Node) -> Any: - result = super().run_node(n) - self.tensor_map[n.name] = result - return result - - def propagate(self, *args): - """ - Run `module` via interpretation and return the result and - record the shape and type of each node. - Args: - *args (Tensor): the sample input. 
- Returns: - Any: The value returned from executing the Module - """ - return super().run(*args) diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/timing_cache_utils.py b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/timing_cache_utils.py deleted file mode 100644 index 4580843e98..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/timing_cache_utils.py +++ /dev/null @@ -1,39 +0,0 @@ -import logging -import os - -logger = logging.getLogger(__name__) - - -class TimingCacheManager: - def __init__(self, timing_cache_prefix: str = "", save_timing_cache=False): - # Setting timing cache for TRTInterpreter - tc = os.environ.get("TRT_TIMING_CACHE_PREFIX", "") - timing_cache_prefix_name = timing_cache_prefix - if not timing_cache_prefix and tc: - timing_cache_prefix_name = tc - - self.timing_cache_prefix_name = timing_cache_prefix_name - self.save_timing_cache = save_timing_cache - - def get_file_full_name(self, name: str): - return f"{self.timing_cache_prefix_name}_{name}.npy" - - def get_timing_cache_trt(self, timing_cache_file: str) -> bytearray: - timing_cache_file = self.get_file_full_name(timing_cache_file) - try: - with open(timing_cache_file, "rb") as raw_cache: - cache_data = raw_cache.read() - return bytearray(cache_data) - except Exception: - return None - - def update_timing_cache( - self, timing_cache_file: str, serilized_cache: bytearray - ) -> None: - if not self.save_timing_cache: - return - timing_cache_file = self.get_file_full_name(timing_cache_file) - with open(timing_cache_file, "wb") as local_cache: - local_cache.seek(0) - local_cache.write(serilized_cache) - local_cache.truncate() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/trt_profiler_sorted.py b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/trt_profiler_sorted.py deleted file mode 100644 index 48293773c4..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/trt_profiler_sorted.py +++ /dev/null @@ -1,58 +0,0 @@ -import json -import logging -import operator -from typing import List, Mapping, Optional - -import torch - -from tensorrt import tensorrt as trt - -from .. 
import TRTModule - -_LOGGER: logging.Logger = logging.getLogger(__name__) - - -class SortedTRTProfiler(trt.IProfiler): - def __init__(self): - super().__init__() - self.layers = {} - - def report_layer_time(self, layer_name: str, ms: int) -> None: - self.layers[layer_name] = ms - - def print_sorted_profile( - self, additional_info: Optional[Mapping[str, str]] - ) -> None: - additional_info = {} if additional_info is None else additional_info - for k, v in sorted(self.layers.items(), key=operator.itemgetter(1)): - additional_str = additional_info.get(k, "") - _LOGGER.info(f"{k} {additional_str}: {v}ms") - - -def profile_trt_module( - name: str, trt_mod: TRTModule, mod_input: List[torch.Tensor] -) -> None: - """ - Provide per layer timing and shape info - """ - layer_info = json.loads(trt_mod.get_layer_info()) # pyre-ignore[29] - shape_map = {} - for layer in layer_info["Layers"]: - # if type is str, it means verbose_profile is off in interpreter.run() - # Theorectically, we can print profiling information without shape information - # but we choose to not print profiling information so we can use verbose_profile to control it - if type(layer) is str: - return - name = layer["Name"] - input_str = ", ".join( - [str(x.get("Dimensions", "[]")) for x in layer.get("Inputs", [])] - ) - output_str = ", ".join( - [str(x.get("Dimensions", "[]")) for x in layer.get("Outputs", [])] - ) - shape_map[name] = f"({input_str}) -> ({output_str})" - - trt_mod.enable_profiling(profiler=SortedTRTProfiler()) # pyre-ignore[29] - _ = trt_mod(*mod_input) - trt_mod.context.profiler.print_sorted_profile(shape_map) # pyre-ignore[16] - trt_mod.disable_profiling() # pyre-ignore[29] diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/trt_splitter.py b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/trt_splitter.py deleted file mode 100644 index c48f6d4e7d..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/trt_splitter.py +++ /dev/null @@ -1,138 +0,0 @@ -from typing import Any, Dict, Iterable, Sequence - -import torch -import torch.fx.passes.operator_support as ops -import torch.fx.passes.splitter_base as splitter_base -from torch.fx.passes.tools_common import get_acc_ops_name, Tensors - -from .. import ( - CONVERTERS, - InputTensorSpec, - NO_EXPLICIT_BATCH_DIM_SUPPORT, - NO_IMPLICIT_BATCH_DIM_SUPPORT, - TRTInterpreter, -) -from torch_tensorrt.fx import TRTModule -from ..tools.trt_minimizer import TensorRTMinimizer - - -def create_trt_operator_support( - use_implicit_batch_dim=True, - exclude_support_node_name: set = (), -) -> ops.OperatorSupportBase: - """Creates an `OperatorSupportBase` instance used for TRT splitting purpose.""" - # Create an `OperatorSupport` that declares a node supported if it - # finds a registered TRT converter. - support_dict: Dict[str, None] = {} - for k in CONVERTERS.keys(): - if use_implicit_batch_dim: - if k not in NO_IMPLICIT_BATCH_DIM_SUPPORT.keys(): - support_dict[get_acc_ops_name(k)] = None - elif k not in NO_EXPLICIT_BATCH_DIM_SUPPORT.keys(): - support_dict[get_acc_ops_name(k)] = None - supported_if_converter_registered = ops.OperatorSupport(support_dict=support_dict) - - return ops.chain( - ops.OpSupports.decline_if_node_in_names(exclude_support_node_name), - # 1. Node is not supported if it has args with int64 dtype: - ops.OpSupports.decline_if_input_dtype(torch.int64), - # 2. 
Node is supported if it has TRT converter: - supported_if_converter_registered, - ) - - -class TRTSplitterSetting(splitter_base._SplitterSettingBase): - def __init__(self): - super().__init__() - - # Determines what batch mode we'll use for lowering. - # During split, we'll split out the operators that - # don't support the batch dim. - self.use_implicit_batch_dim: bool = True - self.exclude_support_node_name: set = set() - self.use_experimental_rt: bool = False - - if self.use_experimental_rt and self.use_implicit_batch_dim: - raise ValueError( - "The experimental unifed runtime only supports explicit batch. Please make sure to set use_implicit_batch_dim=False when use_experimental_rt=True" - ) - - -class TRTSplitter(splitter_base._SplitterBase): - def __init__( - self, - module: torch.fx.GraphModule, - sample_input: Sequence[Any], - operator_support: ops.OperatorSupportBase = None, - settings: TRTSplitterSetting = None, - ): - if not settings: - settings = TRTSplitterSetting() - if not operator_support: - operator_support = create_trt_operator_support( - settings.use_implicit_batch_dim, settings.exclude_support_node_name - ) - super().__init__( - module, - sample_input, - operator_support, - settings, - non_acc_submodule_name="_run_on_gpu_", - ) - - def _lower_model_to_backend( - self, mod: torch.fx.GraphModule, inputs: Iterable[torch.Tensor] - ): - """ - Lower a GraphModule `mod` to TensorRT with `inputs`. - """ - # Current code for lowering is place-holder, subject to future change - # based on feeds model's actual status - interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs)) - interpreter_result = interp.run(*inputs) - if self.settings.use_experimental_rt: - import io - - from torch_tensorrt._Device import Device - from torch_tensorrt._TRTModuleNext import TRTModuleNext - - with io.BytesIO() as engine_bytes: - engine_bytes.write(interpreter_result.engine.serialize()) - engine_str = engine_bytes.getvalue() - - return TRTModuleNext( - engine_str, - name=str(type(mod)), - input_binding_names=interpreter_result.input_names, - output_binding_names=interpreter_result.output_names, - target_device=Device(f"cuda:{torch.cuda.current_device()}"), - # cuda_graph_batch_size=lower_setting.cuda_graph_batch_size, # NOTE: Not sure what this is supposed to do - ) - else: - return TRTModule( - interpreter_result.engine, - interpreter_result.input_names, - interpreter_result.output_names, - ) - - def _find_culprit(self, mod: torch.fx.GraphModule, inputs: Tensors): - """ - This function serves the preview functionality in Splitter. When previewing - splitting result, if something wrong happens during lowering model to TensorRT - or running a TensorRT model, this function will be called to find any culprit - that is responsible for the error. - """ - # Since we don't care about accuracy here, we pass in a dummy compare function. 
- minimizer = TensorRTMinimizer(mod, inputs, lambda a, b, c: (1, True)) - minimizer.settings.traverse_method = "sequential" - minimizer.settings.find_all = True - culprits = minimizer.minimize() - - if len(culprits) == 0: - reports = "Unable to find a culprit!\n" - else: - reports = "Found some problematic nodes:\n" - for node in culprits: - reports += f"{node.format_node()}\n" - - return reports diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/types.py b/py/torch_tensorrt/dynamo/fx_ts_compat/types.py deleted file mode 100644 index f233f8dd9c..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/types.py +++ /dev/null @@ -1,24 +0,0 @@ -from typing import Sequence, Tuple - -# @manual=//deeplearning/trt/python:py_tensorrt -import tensorrt as trt - -if hasattr(trt, "__version__"): - TRTNetwork = trt.INetworkDefinition - TRTTensor = trt.tensorrt.ITensor - TRTLayer = trt.ILayer - TRTPluginFieldCollection = trt.PluginFieldCollection - TRTPlugin = trt.IPluginV2 - TRTDataType = trt.DataType - TRTElementWiseOp = trt.ElementWiseOperation -else: - TRTNetwork = "trt.INetworkDefinition" - TRTTensor = "trt.tensorrt.ITensor" - TRTLayer = "trt.ILayer" - TRTPluginFieldCollection = "trt.PluginFieldCollection" - TRTPlugin = "trt.IPluginV2" - TRTDataType = "trt.DataType" - TRTElementWiseOp = "trt.ElementWiseOperation" - -Shape = Sequence[int] -ShapeRange = Tuple[Shape, Shape, Shape] diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/utils.py b/py/torch_tensorrt/dynamo/fx_ts_compat/utils.py deleted file mode 100644 index 79779f604e..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/utils.py +++ /dev/null @@ -1,140 +0,0 @@ -from enum import Enum -from typing import List, Callable -from packaging import version - -# @manual=//deeplearning/trt/python:py_tensorrt -import tensorrt as trt -import torch -from functorch import make_fx -from functorch.experimental import functionalize -from torch_tensorrt.fx.passes.lower_basic_pass import ( - replace_op_with_indices, - run_const_fold, -) - -from .types import Shape, TRTDataType - - -class LowerPrecision(Enum): - FP32 = "fp32" - FP16 = "fp16" - INT8 = "int8" - - -def torch_dtype_to_trt(dtype: torch.dtype) -> TRTDataType: - """ - Convert PyTorch data types to TensorRT data types. - - Args: - dtype (torch.dtype): A PyTorch data type. - - Returns: - The equivalent TensorRT data type. - """ - if trt.__version__ >= "7.0" and dtype == torch.bool: - return trt.bool - elif dtype == torch.int8: - return trt.int8 - elif dtype == torch.int32: - return trt.int32 - elif dtype == torch.float16: - return trt.float16 - elif dtype == torch.float32: - return trt.float32 - else: - raise TypeError("%s is not supported by tensorrt" % dtype) - - -def torch_dtype_from_trt(dtype: TRTDataType) -> torch.dtype: - """ - Convert TensorRT data types to PyTorch data types. - - Args: - dtype (TRTDataType): A TensorRT data type. - - Returns: - The equivalent PyTorch data type. - """ - if dtype == trt.int8: - return torch.int8 - elif trt.__version__ >= "7.0" and dtype == trt.bool: - return torch.bool - elif dtype == trt.int32: - return torch.int32 - elif dtype == trt.float16: - return torch.float16 - elif dtype == trt.float32: - return torch.float32 - else: - raise TypeError("%s is not supported by torch" % dtype) - - -def get_dynamic_dims(shape: Shape) -> List[int]: - """ - This function finds the dynamic dimensions in the given - shape. A dimension is dynamic if it's -1. - - Args: - shape (Shape): A sequence of integer that represents - the shape of a tensor. 
diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/utils.py b/py/torch_tensorrt/dynamo/fx_ts_compat/utils.py
deleted file mode 100644
index 79779f604e..0000000000
--- a/py/torch_tensorrt/dynamo/fx_ts_compat/utils.py
+++ /dev/null
@@ -1,140 +0,0 @@
-from enum import Enum
-from typing import List, Callable
-from packaging import version
-
-# @manual=//deeplearning/trt/python:py_tensorrt
-import tensorrt as trt
-import torch
-from functorch import make_fx
-from functorch.experimental import functionalize
-from torch_tensorrt.fx.passes.lower_basic_pass import (
-    replace_op_with_indices,
-    run_const_fold,
-)
-
-from .types import Shape, TRTDataType
-
-
-class LowerPrecision(Enum):
-    FP32 = "fp32"
-    FP16 = "fp16"
-    INT8 = "int8"
-
-
-def torch_dtype_to_trt(dtype: torch.dtype) -> TRTDataType:
-    """
-    Convert PyTorch data types to TensorRT data types.
-
-    Args:
-        dtype (torch.dtype): A PyTorch data type.
-
-    Returns:
-        The equivalent TensorRT data type.
-    """
-    if trt.__version__ >= "7.0" and dtype == torch.bool:
-        return trt.bool
-    elif dtype == torch.int8:
-        return trt.int8
-    elif dtype == torch.int32:
-        return trt.int32
-    elif dtype == torch.float16:
-        return trt.float16
-    elif dtype == torch.float32:
-        return trt.float32
-    else:
-        raise TypeError("%s is not supported by tensorrt" % dtype)
-
-
-def torch_dtype_from_trt(dtype: TRTDataType) -> torch.dtype:
-    """
-    Convert TensorRT data types to PyTorch data types.
-
-    Args:
-        dtype (TRTDataType): A TensorRT data type.
-
-    Returns:
-        The equivalent PyTorch data type.
-    """
-    if dtype == trt.int8:
-        return torch.int8
-    elif trt.__version__ >= "7.0" and dtype == trt.bool:
-        return torch.bool
-    elif dtype == trt.int32:
-        return torch.int32
-    elif dtype == trt.float16:
-        return torch.float16
-    elif dtype == trt.float32:
-        return torch.float32
-    else:
-        raise TypeError("%s is not supported by torch" % dtype)
-
-
-def get_dynamic_dims(shape: Shape) -> List[int]:
-    """
-    This function finds the dynamic dimensions in the given
-    shape. A dimension is dynamic if it's -1.
-
-    Args:
-        shape (Shape): A sequence of integer that represents
-            the shape of a tensor.
-
-    Returns:
-        A list of integers contains all the dynamic dimensions
-        in the given shape
-    """
-    dynamic_dims = []
-
-    for i, s in enumerate(shape):
-        if s == -1:
-            dynamic_dims.append(i)
-
-    return dynamic_dims
-
-
-def proxytensor_trace(mod, inputs):
-
-    mod.eval()
-
-    def f(*inp):
-        return mod(*inp)
-
-    mod = make_fx(functionalize(f))(*inputs)
-
-    # Remove const operation. For ex, nn.Linear has transpose operation on weight
-    mod.graph.eliminate_dead_code()
-    mod = run_const_fold(mod)
-    mod = replace_op_with_indices(mod)
-    return mod
-
-
-def req_torch_version(min_torch_version: str = "2.dev"):
-    """
-    Create a decorator which verifies the Torch version installed
-    against a specified version range
-
-    Args:
-        min_torch_version (str): The minimum required Torch version
-            for the decorated function to work properly
-
-    Returns:
-        A decorator which raises a descriptive error message if
-        an unsupported Torch version is used
-    """
-
-    def nested_decorator(f: Callable):
-        def function_wrapper(*args, **kwargs):
-            # Parse minimum and current Torch versions
-            min_version = version.parse(min_torch_version)
-            current_version = version.parse(torch.__version__)
-
-            if current_version < min_version:
-                raise AssertionError(
-                    f"Expected Torch version {min_torch_version} or greater, "
-                    + f"when calling {f}. Detected version {torch.__version__}"
-                )
-            else:
-                return f(*args, **kwargs)
-
-        return function_wrapper
-
-    return nested_decorator

From 2addf5ef4a290f1a99667878420cb9bc9e40c315 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Tue, 18 Apr 2023 23:53:58 -0700
Subject: [PATCH 43/45] chore: Linter fixes

Signed-off-by: Dheeraj Peri
---
 py/torch_tensorrt/_Input.py                                 | 4 +---
 .../dynamo/fx_ts_compat/test/core/test_input_tensor_spec.py | 4 +++-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/py/torch_tensorrt/_Input.py b/py/torch_tensorrt/_Input.py
index 33dc23b4e5..e76817e041 100644
--- a/py/torch_tensorrt/_Input.py
+++ b/py/torch_tensorrt/_Input.py
@@ -384,9 +384,7 @@ def example_tensor(self, optimization_profile_field: str = None) -> torch.Tensor
             )
 
         if self.shape_mode == Input._ShapeMode.STATIC:
-            return torch.rand(self.shape).to(
-                dtype=self.torch_dtype
-            )
+            return torch.rand(self.shape).to(dtype=self.torch_dtype)
         else:
             return torch.rand(self.shape[optimization_profile_field]).to(
                 dtype=self.torch_dtype
diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_input_tensor_spec.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_input_tensor_spec.py
index b22986a0c5..0761b964f8 100644
--- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_input_tensor_spec.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_input_tensor_spec.py
@@ -72,7 +72,9 @@ def test_from_static_input(self):
         self._validate_spec(spec, tensor)
 
     def test_from_dynamic_input(self):
-        inputs = torch_tensorrt.Input(min_shape=(2, 2, 3), opt_shape=(4, 2, 3), max_shape=(8, 2, 3))
+        inputs = torch_tensorrt.Input(
+            min_shape=(2, 2, 3), opt_shape=(4, 2, 3), max_shape=(8, 2, 3)
+        )
         example_tensor = inputs.example_tensor(optimization_profile_field="opt_shape")
         spec = InputTensorSpec.from_input(inputs)
         self._validate_spec(spec, example_tensor, dynamic_dims=[0])
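
Reviewer note: the reflowed test above exercises the dynamic-shape path of `Input.example_tensor`. For reference, the call pattern it covers, lifted directly from the hunk (assumes torch_tensorrt is importable):

    import torch_tensorrt

    inputs = torch_tensorrt.Input(
        min_shape=(2, 2, 3), opt_shape=(4, 2, 3), max_shape=(8, 2, 3)
    )
    # Draws a random tensor of the selected profile's shape, here (4, 2, 3)
    example = inputs.example_tensor(optimization_profile_field="opt_shape")
    assert example.shape == (4, 2, 3)
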
From fb41cf7ebb4a88342d5111892f1bcc897a7e071d Mon Sep 17 00:00:00 2001
From: George S <113141689+gs-olive@users.noreply.github.com>
Date: Wed, 26 Apr 2023 14:33:28 -0700
Subject: [PATCH 44/45] fix: Add test suite for torch.compile backend (#1849)

---
 .circleci/config.yml                          | 17 ++++
 py/torch_tensorrt/dynamo/test/utils.py        | 39 --------
 .../torch_compile/test/test_compiler_utils.py | 57 +++++++++++
 .../torch_compile/test/test_lowering.py       | 54 +++++++++++
 .../torch_compile/test/test_partitioning.py   | 68 ++++++++++++++
 .../dynamo/torch_compile/test/utils.py        | 94 +++++++++++++++++++
 .../dynamo/torch_compile/utils.py             |  2 +
 7 files changed, 292 insertions(+), 39 deletions(-)
 create mode 100644 py/torch_tensorrt/dynamo/torch_compile/test/test_compiler_utils.py
 create mode 100644 py/torch_tensorrt/dynamo/torch_compile/test/test_lowering.py
 create mode 100644 py/torch_tensorrt/dynamo/torch_compile/test/test_partitioning.py
 create mode 100644 py/torch_tensorrt/dynamo/torch_compile/test/utils.py

diff --git a/.circleci/config.yml b/.circleci/config.yml
index 88b547729e..1604bea3df 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -727,6 +727,22 @@ commands:
       - store_artifacts:
           path: /tmp/testlogs
 
+  test-dynamo-torch_compile-core:
+    description: "Test the Dynamo torch_compile path"
+    steps:
+      - run:
+          name: Run Dynamo torch_compile core tests
+          command: |
+            cd py/torch_tensorrt/dynamo/torch_compile
+            pushd test/
+            pytest --junitxml=/tmp/artifacts/test_results/dynamo/torch_compile/test_results.xml
+            popd
+
+      - store_test_results:
+          path: /tmp/artifacts
+      - store_artifacts:
+          path: /tmp/testlogs
+
   test-dynamo-torch_compile:
     description: "Test the Dynamo torch_compile path"
     steps:
@@ -953,6 +969,7 @@ jobs:
       # We install torch after torch-trt because pip automatically enforces the version constraint otherwise
       - dump-test-env
       - test-dynamo-torch_compile
+      - test-dynamo-torch_compile-core
      - test-dynamo-fx_ts
 
   package-x86_64-linux:
diff --git a/py/torch_tensorrt/dynamo/test/utils.py b/py/torch_tensorrt/dynamo/test/utils.py
index ff6bc39158..b1e6632ec3 100644
--- a/py/torch_tensorrt/dynamo/test/utils.py
+++ b/py/torch_tensorrt/dynamo/test/utils.py
@@ -13,42 +13,3 @@ def cosine_similarity(gt_tensor, pred_tensor):
     res = res.cpu().detach().item()
 
     return res
-
-
-def same_output_format(trt_output, torch_output):
-    # For each encountered collection type, ensure the torch and trt outputs agree
-    # on type and size, checking recursively through all member elements.
-    if isinstance(trt_output, tuple):
-        return (
-            isinstance(torch_output, tuple)
-            and (len(trt_output) == len(torch_output))
-            and all(
-                same_output_format(trt_entry, torch_entry)
-                for trt_entry, torch_entry in zip(trt_output, torch_output)
-            )
-        )
-    elif isinstance(trt_output, list):
-        return (
-            isinstance(torch_output, list)
-            and (len(trt_output) == len(torch_output))
-            and all(
-                same_output_format(trt_entry, torch_entry)
-                for trt_entry, torch_entry in zip(trt_output, torch_output)
-            )
-        )
-    elif isinstance(trt_output, dict):
-        return (
-            isinstance(torch_output, dict)
-            and (len(trt_output) == len(torch_output))
-            and (trt_output.keys() == torch_output.keys())
-            and all(
-                same_output_format(trt_output[key], torch_output[key])
-                for key in trt_output.keys()
-            )
-        )
-    elif isinstance(trt_output, set) or isinstance(trt_output, frozenset):
-        raise AssertionError(
-            "Unsupported output type 'set' encountered in output format check."
-        )
-    else:
-        return type(trt_output) is type(torch_output)
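
Reviewer note: `same_output_format` is not dropped by this commit — it is re-added under py/torch_tensorrt/dynamo/torch_compile/test/utils.py later in this patch, where it gains an `enforce_tensor_type` flag. A quick sketch of the recursive check's expected behavior, assuming the function above is in scope:

    import torch

    # Matching nesting and types at every level -> True
    trt_out = (torch.ones(2), [torch.zeros(3), torch.zeros(3)])
    torch_out = (torch.ones(2), [torch.ones(3), torch.ones(3)])
    assert same_output_format(trt_out, torch_out)

    # A tuple compared against a list of the same length -> False
    assert not same_output_format((torch.ones(2),), [torch.ones(2)])
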
diff --git a/py/torch_tensorrt/dynamo/torch_compile/test/test_compiler_utils.py b/py/torch_tensorrt/dynamo/torch_compile/test/test_compiler_utils.py
new file mode 100644
index 0000000000..da7157c3e5
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/torch_compile/test/test_compiler_utils.py
@@ -0,0 +1,57 @@
+from torch_tensorrt.dynamo.torch_compile.utils import prepare_device, prepare_inputs
+from utils import same_output_format
+import torch_tensorrt
+import unittest
+import torch
+
+
+class TestPrepareDevice(unittest.TestCase):
+    def test_prepare_cuda_device(self):
+        gpu_id = 0
+        device = torch.device(f"cuda:{gpu_id}")
+        prepared_device = prepare_device(device)
+        self.assertTrue(isinstance(prepared_device, torch.device))
+        self.assertTrue(prepared_device.index == gpu_id)
+
+    def test_prepare_trt_device(self):
+        gpu_id = 4
+        device = torch_tensorrt.Device(gpu_id=gpu_id)
+        prepared_device = prepare_device(device)
+        self.assertTrue(isinstance(prepared_device, torch.device))
+        self.assertTrue(prepared_device.index == gpu_id)
+
+
+class TestPrepareInputs(unittest.TestCase):
+    def test_prepare_single_tensor_input(self):
+        inputs = [torch.ones((4, 4))]
+        prepared_inputs = prepare_inputs(inputs)
+        self.assertTrue(
+            same_output_format(inputs, prepared_inputs, enforce_tensor_type=False)
+        )
+
+    def test_prepare_trt_input(self):
+        inputs = [torch_tensorrt.Input(shape=(4, 3), dtype=torch.float)]
+        prepared_inputs = prepare_inputs(inputs)
+        self.assertTrue(
+            same_output_format(inputs, prepared_inputs, enforce_tensor_type=False)
+        )
+
+    def test_prepare_mixed_type_compound_tensor_input(self):
+        inputs = {
+            "first": [
+                torch.ones((4, 4)),
+                torch_tensorrt.Input(shape=(4, 3), dtype=torch.float),
+            ],
+            "second": (
+                torch.rand((5, 1)),
+                (torch.rand((5, 1)), torch_tensorrt.Input(shape=(2, 3))),
+            ),
+        }
+        prepared_inputs = prepare_inputs(inputs)
+        self.assertTrue(
+            same_output_format(inputs, prepared_inputs, enforce_tensor_type=False)
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/py/torch_tensorrt/dynamo/torch_compile/test/test_lowering.py b/py/torch_tensorrt/dynamo/torch_compile/test/test_lowering.py
new file mode 100644
index 0000000000..d14acb815b
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/torch_compile/test/test_lowering.py
@@ -0,0 +1,54 @@
+from functools import partial
+from utils import fx_dynamo_testing_backend
+from torch.testing._internal.common_utils import run_tests, TestCase
+import torch
+
+
+class TestLowering(TestCase):
+    def test_lowering_inplace_op(self):
+        class FullySupported(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x, y):
+                x = torch.ops.aten.add_.Tensor(x, y)
+                x = torch.ops.aten.relu_.default(x)
+                return x
+
+        # Operations expected to be included in the traced graph after decompositions
+        expected_ops = {torch.ops.aten.add.Tensor, torch.ops.aten.relu.default}
+
+        # Trace module and set up custom backend to track intermediate graphs
+        fx_graph = torch.fx.symbolic_trace(FullySupported())
+        partitioned_graphs = []
+        custom_backend = partial(
+            fx_dynamo_testing_backend,
+            store_intermediate_graphs=partitioned_graphs,
+        )
+
+        # Invoke compilation
+        compiled_graph = torch.compile(fx_graph, backend=custom_backend)
+        compiled_graph(
+            torch.rand(
+                5,
+            ).cuda(),
+            torch.rand(
+                5,
+            ).cuda(),
+        )
+
+        # Iterate over intermediate graphs, attempt to match nodes
+        for fx_module in partitioned_graphs:
+            for _, submodule in fx_module.named_children():
+                for node in submodule.graph.nodes:
+
+                    if node.op == "call_function" and node.target in expected_ops:
+                        expected_ops.remove(node.target)
+
+        self.assertEqual(
+            len(expected_ops), 0, "All operators should have been decomposed"
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
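
Reviewer note: the test above relies on AOTAutograd decompositions to rewrite the in-place `add_`/`relu_` variants into their out-of-place aten forms before the backend sees the graph. Its node-matching idiom works on any torch.fx GraphModule and can be reused standalone; a self-contained sketch:

    import torch

    def collect_call_targets(gm: torch.fx.GraphModule):
        # Gather every operator invoked via a call_function node in the graph
        return [node.target for node in gm.graph.nodes if node.op == "call_function"]

    class AddRelu(torch.nn.Module):
        def forward(self, x, y):
            return torch.relu(x + y)

    traced = torch.fx.symbolic_trace(AddRelu())
    # Prints the add and relu targets recorded during tracing
    print(collect_call_targets(traced))
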
diff --git a/py/torch_tensorrt/dynamo/torch_compile/test/test_partitioning.py b/py/torch_tensorrt/dynamo/torch_compile/test/test_partitioning.py
new file mode 100644
index 0000000000..b068f9c413
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/torch_compile/test/test_partitioning.py
@@ -0,0 +1,68 @@
+from torch_tensorrt.dynamo.torch_compile.lowering import partition
+from torch.testing._internal.common_utils import run_tests, TestCase
+import torch
+from copy import deepcopy
+import numpy as np
+
+
+class TestPartitioning(TestCase):
+    def test_partition_fully_supported_one_op(self):
+        class FullySupportedOneOp(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x, y):
+                return torch.ops.aten.add.Tensor(x, y)
+
+        fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp())
+        partitioned_graph = partition(deepcopy(fx_graph))
+        self.assertEqual(
+            len(list(partitioned_graph.named_children())),
+            0,
+            "Single operators should not be segmented",
+        )
+
+    def test_partition_fully_supported_multi_op(self):
+        class FullySupportedMultiOp(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x, y):
+                sum_ = torch.ops.aten.sub.Tensor(x, y)
+                concat_ = torch.ops.aten.cat.default(x, sum_)
+                relu_ = torch.ops.aten.relu.default(concat_)
+                pow_ = torch.ops.aten.pow.Tensor_Scalar(relu_, 2)
+                return pow_
+
+        fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp())
+        partitioned_graph = partition(deepcopy(fx_graph))
+        self.assertEqual(
+            len(list(partitioned_graph.named_children())),
+            1,
+            "All operators are supported, there should be one segment",
+        )
+
+    def test_partition_partially_supported_multi_op(self):
+        class PartiallySupportedMultiOp(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x, y):
+                sum_1 = torch.ops.aten.add.Tensor(x, y)
+                sum_2 = torch.ops.aten.add.Tensor(x, sum_1)
+                sum_ = np.sum(sum_1) + np.sum(sum_2)
+                relu_ = torch.ops.aten.relu.default(sum_)
+                pow_ = torch.ops.aten.pow.Tensor_Scalar(relu_, 2)
+                return pow_
+
+        fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp())
+        partitioned_graph = partition(deepcopy(fx_graph))
+        self.assertEqual(
+            len(list(partitioned_graph.named_children())),
+            2,
+            "Unsupported operators interleave supported ones, expected 2 segments",
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
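
Reviewer note: the partitioning tests assert on segment counts by listing `named_children()` of the partitioned module. A condensed usage sketch of that same API surface (assumes the torch_compile package introduced by this patch is importable):

    import torch
    from copy import deepcopy
    from torch_tensorrt.dynamo.torch_compile.lowering import partition

    class TwoOps(torch.nn.Module):
        def forward(self, x, y):
            return torch.ops.aten.relu.default(torch.ops.aten.add.Tensor(x, y))

    fx_graph = torch.fx.symbolic_trace(TwoOps())
    partitioned = partition(deepcopy(fx_graph))
    # Each child module is one accelerated segment; per the tests above,
    # a fully supported multi-op graph should yield exactly one
    print(len(list(partitioned.named_children())))
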
diff --git a/py/torch_tensorrt/dynamo/torch_compile/test/utils.py b/py/torch_tensorrt/dynamo/torch_compile/test/utils.py
new file mode 100644
index 0000000000..bdcbbfcc4a
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/torch_compile/test/utils.py
@@ -0,0 +1,94 @@
+from copy import deepcopy
+from functools import partial
+from typing import List, Sequence
+import torch
+from torch_tensorrt.dynamo.torch_compile.lowering._decompositions import (
+    get_decompositions,
+)
+from torch_tensorrt.dynamo.torch_compile.lowering._partition import (
+    partition,
+)
+
+from torch._dynamo.backends.common import fake_tensor_unsupported
+
+from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
+
+
+@fake_tensor_unsupported
+def fx_dynamo_testing_backend(
+    gm: torch.fx.GraphModule,
+    sample_inputs: Sequence[torch.Tensor],
+    *,
+    store_intermediate_graphs: List,
+):
+    """Helper Dynamo backend exclusively for testing"""
+    custom_backend = partial(
+        compile_module_testing,
+        store_intermediate_graphs=store_intermediate_graphs,
+    )
+
+    # Invoke AOTAutograd to translate operators to aten
+    return aot_module_simplified(
+        gm,
+        sample_inputs,
+        fw_compiler=make_boxed_compiler(custom_backend),
+        decompositions=get_decompositions(),
+    )
+
+
+def compile_module_testing(
+    gm: torch.fx.GraphModule,
+    example_inputs: Sequence[torch.Tensor],
+    *,
+    store_intermediate_graphs: List,
+) -> torch.fx.GraphModule:
+    """Helper compiler exclusively for testing"""
+    partitioned_module = partition(gm)
+
+    # Store intermediate graph from partitioned module
+    store_intermediate_graphs.append(deepcopy(partitioned_module))
+
+    return partitioned_module
+
+
+def same_output_format(trt_output, torch_output, enforce_tensor_type=True):
+    # For each encountered collection type, ensure the torch and trt outputs agree
+    # on type and size, checking recursively through all member elements.
+    if isinstance(trt_output, tuple):
+        return (
+            isinstance(torch_output, tuple)
+            and (len(trt_output) == len(torch_output))
+            and all(
+                same_output_format(trt_entry, torch_entry, enforce_tensor_type)
+                for trt_entry, torch_entry in zip(trt_output, torch_output)
+            )
+        )
+    elif isinstance(trt_output, list):
+        return (
+            isinstance(torch_output, list)
+            and (len(trt_output) == len(torch_output))
+            and all(
+                same_output_format(trt_entry, torch_entry, enforce_tensor_type)
+                for trt_entry, torch_entry in zip(trt_output, torch_output)
+            )
+        )
+    elif isinstance(trt_output, dict):
+        return (
+            isinstance(torch_output, dict)
+            and (len(trt_output) == len(torch_output))
+            and (trt_output.keys() == torch_output.keys())
+            and all(
+                same_output_format(
+                    trt_output[key], torch_output[key], enforce_tensor_type
+                )
+                for key in trt_output.keys()
+            )
+        )
+    elif isinstance(trt_output, set) or isinstance(trt_output, frozenset):
+        raise AssertionError(
+            "Unsupported output type 'set' encountered in output format check."
+        )
+    elif enforce_tensor_type:
+        return type(trt_output) is type(torch_output)
+    else:
+        return True
diff --git a/py/torch_tensorrt/dynamo/torch_compile/utils.py b/py/torch_tensorrt/dynamo/torch_compile/utils.py
index c096eb9397..ba76536338 100644
--- a/py/torch_tensorrt/dynamo/torch_compile/utils.py
+++ b/py/torch_tensorrt/dynamo/torch_compile/utils.py
@@ -64,3 +64,5 @@ def prepare_device(device: Union[Device, torch.device]) -> torch.device:
         raise ValueError(
             "Invalid device provided. Supported options: torch.device | torch_tensorrt.Device"
         )
+
+    return device
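
Reviewer note: the two-line utils.py hunk fixes a fall-through in `prepare_device`, which appears to have validated and converted the device but implicitly returned None. With the fix, both accepted input types round-trip to a `torch.device`, matching the tests earlier in this patch (sketch; requires torch_tensorrt and the torch_compile package):

    import torch
    import torch_tensorrt
    from torch_tensorrt.dynamo.torch_compile.utils import prepare_device

    # Both a native torch.device and a torch_tensorrt.Device normalize
    # to a torch.device carrying the requested GPU index
    assert prepare_device(torch.device("cuda:0")).index == 0
    assert isinstance(prepare_device(torch_tensorrt.Device(gpu_id=1)), torch.device)
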
From b8cd7c39f7951f8023bee65448be4c218a94f32a Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Tue, 2 May 2023 14:54:39 -0700
Subject: [PATCH 45/45] Improve warning wording

---
 py/torch_tensorrt/_Device.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/py/torch_tensorrt/_Device.py b/py/torch_tensorrt/_Device.py
index 4a53ad8885..3eaa5aad4e 100644
--- a/py/torch_tensorrt/_Device.py
+++ b/py/torch_tensorrt/_Device.py
@@ -9,8 +9,7 @@
     from torch_tensorrt import _C
 except:
     warnings.warn(
-        "Unable to import _C extension of Torch-TensorRT. Some methods might be unavailable. You can ignore this error if you're \
-            not using any functions dependent on internal C++ APIs"
+        "Unable to import torchscript frontend core and torch-tensorrt runtime. Some dependent features may be unavailable."
    )
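
Reviewer note: the reworded warning belongs to the optional-extension import guard in _Device.py, which keeps the pure-Python API usable when the compiled runtime is absent. The general pattern, as a self-contained sketch (the source uses a bare `except:`; a narrower `ImportError` is shown here as a deliberate tightening):

    import warnings

    try:
        from torch_tensorrt import _C  # compiled runtime; optional at import time
    except ImportError:
        _C = None
        warnings.warn(
            "Unable to import torchscript frontend core and torch-tensorrt runtime. "
            "Some dependent features may be unavailable."
        )
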