diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp
index 3e10fd7c7d..5006c43378 100644
--- a/core/runtime/TRTEngine.cpp
+++ b/core/runtime/TRTEngine.cpp
@@ -30,6 +30,27 @@ std::vector<std::string> split(const std::string& str, char delim) {
   return strings;
 }
 
+DynamicOutputAllocator::DynamicOutputAllocator(const std::unordered_map<std::string, at::ScalarType>& output_dtypes)
+    : dtypes(output_dtypes) {}
+
+void* DynamicOutputAllocator::reallocateOutputAsync(
+    char const* tensorName,
+    void* currentMemory,
+    uint64_t size,
+    uint64_t alignment,
+    cudaStream_t stream) {
+  std::vector<int64_t> shape = {static_cast<int64_t>(size)};
+  auto it = buffers.find(tensorName);
+  if (it == buffers.end() || it->second.sizes() != shape) {
+    buffers[tensorName] = at::empty(shape, at::TensorOptions().dtype(dtypes.at(tensorName)).device(c10::kCUDA));
+  }
+  return buffers[tensorName].data_ptr();
+}
+
+void DynamicOutputAllocator::notifyShape(char const* tensorName, nvinfer1::Dims const& dims) noexcept {
+  shapes[tensorName] = dims;
+}
+
 TRTEngine::TRTEngine(
     const std::string& serialized_engine,
     const RTDevice& cuda_device,
@@ -137,7 +158,6 @@ TRTEngine::TRTEngine(
     in_binding_names.resize(inputs);
     input_buffers.resize(inputs);
     out_binding_names.resize(outputs);
-    output_buffers.resize(outputs);
     for (int64_t x = 0; x < cuda_engine->getNbIOTensors(); x++) {
       std::string bind_name = cuda_engine->getIOTensorName(x);
       if (cuda_engine->getTensorIOMode(bind_name.c_str()) == nvinfer1::TensorIOMode::kINPUT) {
@@ -179,7 +199,6 @@ TRTEngine::TRTEngine(
 
     uint64_t outputs = _out_binding_names.size();
     out_binding_names.resize(outputs);
-    output_buffers.resize(outputs);
     for (size_t pyt_idx = 0; pyt_idx < outputs; pyt_idx++) {
       auto binding_name = _out_binding_names[pyt_idx];
       // Check if the binding name provided is in the list of engine's bindings
diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h
index e1d8ba5471..1a60c1b049 100644
--- a/core/runtime/TRTEngine.h
+++ b/core/runtime/TRTEngine.h
@@ -69,11 +69,39 @@ struct TorchTRTRuntimeStates {
   }
 };
 
+class DynamicOutputAllocator : public nvinfer1::IOutputAllocator {
+ public:
+  DynamicOutputAllocator(const std::unordered_map<std::string, at::ScalarType>& output_dtypes);
+
+  void* reallocateOutputAsync(
+      char const* tensorName,
+      void* currentMemory,
+      uint64_t size,
+      uint64_t alignment,
+      cudaStream_t stream) override;
+
+  void notifyShape(char const* tensorName, nvinfer1::Dims const& dims) noexcept override;
+
+  const std::unordered_map<std::string, at::Tensor>& getBuffers() const {
+    return buffers;
+  }
+
+  const std::unordered_map<std::string, nvinfer1::Dims>& getShapes() const {
+    return shapes;
+  }
+
+ private:
+  std::unordered_map<std::string, at::ScalarType> dtypes;
+  std::unordered_map<std::string, at::Tensor> buffers;
+  std::unordered_map<std::string, nvinfer1::Dims> shapes;
+};
+
 struct TRTEngine : torch::CustomClassHolder {
   // Each engine needs it's own runtime object
   std::shared_ptr<nvinfer1::IRuntime> rt;
   std::shared_ptr<nvinfer1::ICudaEngine> cuda_engine;
   std::shared_ptr<nvinfer1::IExecutionContext> exec_ctx;
+  std::shared_ptr<DynamicOutputAllocator> output_allocator;
   std::pair<uint64_t, uint64_t> num_io;
   std::string name;
   RTDevice device_info;
@@ -141,7 +169,6 @@ struct TRTEngine : torch::CustomClassHolder {
   at::cuda::CUDAStream engine_stream = c10::cuda::getDefaultCUDAStream();
   at::cuda::CUDAStream caller_stream = c10::cuda::getDefaultCUDAStream();
   std::vector<at::Tensor> input_buffers = {};
-  std::vector<at::Tensor> output_buffers = {};
   std::string shape_key = "None";
   bool use_pre_allocated_outputs = false;
   std::vector<at::Tensor> pre_allocated_outputs;
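For readers who have not used TensorRT's `IOutputAllocator` before, the contract the new `DynamicOutputAllocator` implements is small: while the engine executes, TensorRT calls `reallocateOutputAsync` with the number of bytes it needs for a data-dependent output and hands back the final dimensions through `notifyShape`. A minimal Python sketch of the same idea, for illustration only (it assumes the `tensorrt` and `torch` packages and a CUDA device; the diff further below adds an equivalent class to `_PythonTorchTensorRTModule.py`):

```python
import tensorrt as trt
import torch


class SketchOutputAllocator(trt.IOutputAllocator):  # illustrative, not part of this diff
    def __init__(self, dtypes):
        trt.IOutputAllocator.__init__(self)
        self.dtypes = dtypes  # tensor name -> torch dtype
        self.buffers = {}     # tensor name -> flat device buffer
        self.shapes = {}      # tensor name -> dims reported by TensorRT

    def reallocate_output_async(self, tensor_name, memory, size, alignment, stream):
        # TensorRT asks for at least `size` bytes; return a device pointer it can write to.
        if tensor_name not in self.buffers or self.buffers[tensor_name].shape != (size,):
            self.buffers[tensor_name] = torch.empty(
                (size,), dtype=self.dtypes[tensor_name], device="cuda"
            )
        return self.buffers[tensor_name].data_ptr()

    def notify_shape(self, tensor_name, shape):
        # Called once the engine knows the real output dimensions.
        self.shapes[tensor_name] = tuple(shape)
```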
diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp
index 5348ade8c4..8721d3a3a9 100644
--- a/core/runtime/execute_engine.cpp
+++ b/core/runtime/execute_engine.cpp
@@ -163,22 +163,23 @@ void setup_input_tensors(
     }
   }
 }
 
-std::vector<at::Tensor> create_output_tensors(c10::intrusive_ptr<TRTEngine> compiled_engine) {
-  std::vector<at::Tensor> outputs(compiled_engine->num_io.second);
-  for (auto output_indices : compiled_engine->out_binding_map) {
-    // out_binding_map stores TRT_IDX: PYT_IDX
-    auto pyt_idx = output_indices.second;
-
-    std::string name = compiled_engine->out_binding_names[pyt_idx];
-    auto out_shape = compiled_engine->exec_ctx->getTensorShape(name.c_str());
-    LOG_DEBUG("Output Name: " << name << " Shape: " << out_shape);
-
-    auto dims = core::util::toVec(out_shape);
-    auto type = util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str()));
-    outputs[pyt_idx] = std::move(at::empty(dims, {at::kCUDA}).to(type).contiguous());
+
+void setup_output_allocator(c10::intrusive_ptr<TRTEngine> compiled_engine) {
+  if (compiled_engine->output_allocator == nullptr) {
+    std::unordered_map<std::string, at::ScalarType> output_dtypes_dict;
+    for (size_t o = 0; o < compiled_engine->out_binding_names.size(); ++o) {
+      auto name = compiled_engine->out_binding_names[o];
+      output_dtypes_dict[name] =
+          util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str()));
+    }
+    compiled_engine->output_allocator = std::make_shared<DynamicOutputAllocator>(output_dtypes_dict);
   }
-  return outputs;
+  for (const auto& output_name : compiled_engine->out_binding_names) {
+    if (!compiled_engine->exec_ctx->setOutputAllocator(output_name.c_str(), compiled_engine->output_allocator.get())) {
+      throw std::runtime_error("Failed to set output allocator for " + output_name);
+    }
+  }
 }
 
 std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
@@ -218,7 +219,6 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
   }
 
   // Intialize inputs and outputs to be available throughout the succeeding scopes
-  std::vector<at::Tensor> outputs(compiled_engine->num_io.second);
 
   if (MULTI_DEVICE_SAFE_MODE) {
     std::unique_ptr<torch::autograd::profiler::RecordProfile> device_profiler_guard;
@@ -287,44 +287,20 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
        << " cannot be inferred. This could happen if the input tensor addresses/shapes haven't been configured correctly");
  }
 
-  { // Output Setup
-    std::unique_ptr<torch::autograd::profiler::RecordProfile> output_profiler_guard;
+  { // OutputAllocator Setup
+    std::unique_ptr<torch::autograd::profiler::RecordProfile> output_allocator_profiler_guard;
     if (compiled_engine->profile_execution) {
-      output_profiler_guard =
+      output_allocator_profiler_guard =
           std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->output_profile_path);
     }
-    if (can_use_pre_allocated_outputs) {
-      outputs = compiled_engine->pre_allocated_outputs;
-    } else {
-      outputs = create_output_tensors(compiled_engine);
-    }
-
-    for (auto output_indices : compiled_engine->out_binding_map) {
-      auto pyt_idx = output_indices.second;
-      std::string name = compiled_engine->out_binding_names[pyt_idx];
-      if (need_cudagraphs_record) {
-        // If we are recording the cuda graph then we need to update the persistent output buffer
-        compiled_engine->output_buffers[pyt_idx] = std::move(outputs[pyt_idx].clone());
-      }
-
-      if (cudagraphs_enabled) {
-        TORCHTRT_CHECK(
-            compiled_engine->exec_ctx->setTensorAddress(
-                name.c_str(), compiled_engine->output_buffers[pyt_idx].data_ptr()),
-            "Error while setting the output tensor address");
-      } else {
-        TORCHTRT_CHECK(
-            compiled_engine->exec_ctx->setTensorAddress(name.c_str(), outputs[pyt_idx].data_ptr()),
-            "Error while setting the output tensor address");
-      }
-    }
+    setup_output_allocator(compiled_engine);
  }
 
  auto current_device_id = -1;
  if (inputs.size() > 0) {
    current_device_id = inputs[0].device().index(); // Done this way to avoid a call to cudart
-  } else if (outputs.size() > 0) {
-    current_device_id = outputs[0].device().index(); // Done this way to avoid a call to cudart
+  } else {
+    current_device_id = c10::cuda::current_device();
  }
 
  compiled_engine->caller_stream = c10::cuda::getCurrentCUDAStream(current_device_id);
@@ -368,21 +344,32 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
    }
  } // End engine exeuction (resets to caller stream)
 
-  // Create output buffer for next execution of graph or trt context.
-  if (compiled_engine->use_pre_allocated_outputs) {
-    compiled_engine->pre_allocated_outputs = create_output_tensors(compiled_engine);
-  }
-
  // Block caller stream until engine execution is complete
  at::cuda::CUDAEvent trt_exec_complete;
  trt_exec_complete.record(compiled_engine->engine_stream);
  trt_exec_complete.block(compiled_engine->caller_stream);
 
-  if (cudagraphs_enabled) {
-    // If in CUDAGraph mode, results need to be copied to the result buffers (on caller stream)
-    for (size_t o = 0; o < compiled_engine->output_buffers.size(); o++) {
-      outputs[o].copy_(compiled_engine->output_buffers[o], false);
+  std::unique_ptr<torch::autograd::profiler::RecordProfile> output_profiler_guard;
+  if (compiled_engine->profile_execution) {
+    output_profiler_guard =
+        std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->output_profile_path);
+  }
+  std::vector<at::Tensor> outputs;
+  for (size_t i = 0; i < compiled_engine->out_binding_names.size(); i++) {
+    auto name = compiled_engine->out_binding_names[i];
+    auto dims = compiled_engine->output_allocator->getShapes().at(name);
+    auto dtype = util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str()));
+    at::Tensor output = compiled_engine->output_allocator->getBuffers().at(name).clone().detach();
+    int64_t prod = 1;
+    for (int i = 0; i < dims.nbDims; ++i) {
+      prod *= dims.d[i];
+    }
+    std::vector<int64_t> dims_vec(dims.nbDims);
+    for (int i = 0; i < dims.nbDims; ++i) {
+      dims_vec[i] = dims.d[i];
    }
+    output = output.reshape(-1).view(dtype).slice(0, 0, prod).reshape(dims_vec);
+    outputs.push_back(output);
  }
 
  if (compiled_engine->profile_execution) {
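The execute path above no longer pre-sizes outputs; once the engine has run, it reinterprets the flat buffer the allocator handed to TensorRT using the dimensions reported through `notifyShape`. A small PyTorch sketch of that reshaping recipe, with hypothetical sizes standing in for the allocator's bookkeeping (CPU tensors used only for illustration):

```python
import torch

# Stand-ins for what the allocator recorded while the engine executed:
# a flat, possibly over-allocated buffer and the dims from notifyShape().
raw_buffer = torch.zeros(16, dtype=torch.int64)
reported_dims = (5, 2)

num_elements = 1
for d in reported_dims:
    num_elements *= d  # number of valid elements in the buffer

# Same recipe as the new loop above: flatten, reinterpret as the output dtype,
# keep only the valid prefix, then reshape to the reported dimensions.
output = raw_buffer.reshape(-1).view(torch.int64)[:num_elements].reshape(reported_dims)
assert output.shape == torch.Size([5, 2])
```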
diff --git a/examples/dynamo/pre_allocated_output_example.py b/examples/dynamo/pre_allocated_output_example.py
deleted file mode 100644
index d938034758..0000000000
--- a/examples/dynamo/pre_allocated_output_example.py
+++ /dev/null
@@ -1,113 +0,0 @@
-"""
-.. _pre_allocated_output_example:
-
-Pre-allocated output buffer
-======================================================
-
-The TensorRT runtime module acts as a wrapper around a PyTorch model (or subgraph) that has been compiled and optimized into a TensorRT engine.
-
-When the compiled module is executed, input and output tensors are set to TensorRT context for processing.
-If output buffer allocation is moved after the execution of the TensorRT context and used it for next inference, GPU tasks and memory allocation tasks can operate concurrently. This overlap allows for more efficient use of GPU resources, potentially improving the performance of inference.
-
-This optimization is particularly effective in below cases
-
-1. Small inference time
-    - The allocation of output buffers typically requires minimal CPU cycles, as the caching mechanism efficiently handles memory reuse. The time taken for this allocation is relatively constant compared to the overall inference time, leading to noticeable performance improvements, especially in scenarios involving small inference workloads. This is because the reduced allocation time contributes to faster execution when the computational workload is not large enough to overshadow these savings.
-2. Multiple graph breaks
-    - If the module contains operations that are not supported by TensorRT, the unsupported parts are handled by PyTorch and this fallback results in a graph break. The cumulative effect of optimized buffer allocations across multiple subgraphs can enhance overall inference performance.
- - While optimizing output buffers can mitigate some of this overhead, reducing or removing graph breaks should be prioritized as it enables more comprehensive optimizations -3. Static input or infrequent input shape change - - If shape is changed, pre-allocated buffer cannot be used for next inference and there will new allocation before executing the TensorRT context. This feature is not suitable for use cases with frequent input shape changes -""" - -# %% -# Imports and Model Definition -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -import timeit - -import numpy as np -import torch -import torch_tensorrt -from transformers import BertModel - -# %% -# Define function to measure inference performance -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - - -def test_module_perf(model, *input): - timings = [] - - # Warm-up phase to ensure consistent and accurate performance measurements. - with torch.no_grad(): - for _ in range(3): - model(*input) - torch.cuda.synchronize() - - # Timing phase to measure inference performance - with torch.no_grad(): - for i in range(10): - start_time = timeit.default_timer() - model(*input) - torch.cuda.synchronize() - end_time = timeit.default_timer() - timings.append(end_time - start_time) - times = np.array(timings) - time_med = np.median(times) - - # Return the median time as a representative performance metric - return time_med - - -# %% -# Load model and compile -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -# Load bert model -model = ( - BertModel.from_pretrained("bert-base-uncased", torchscript=True) - .eval() - .half() - .to("cuda") -) -# Define sample inputs -inputs = [ - torch.randint(0, 5, (1, 128), dtype=torch.int32).to("cuda"), - torch.randint(0, 5, (1, 128), dtype=torch.int32).to("cuda"), -] -# Next, we compile the model using torch_tensorrt.compile -optimized_model = torch_tensorrt.compile( - model, - ir="dynamo", - enabled_precisions={torch.half}, - inputs=inputs, -) - -# %% -# Enable/Disable pre-allocated output buffer feature using runtime api -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -# Enable pre-allocated output buffer using a context manager -with torch_tensorrt.runtime.enable_pre_allocated_outputs(optimized_model): - out_trt = optimized_model(*inputs) - # Subsequent inferences can use the pre-allocated output buffer (no shape change) - out_trt = optimized_model(*inputs) - -# Alternatively, we can enable the feature using a context object -pre_allocated_output_ctx = torch_tensorrt.runtime.enable_pre_allocated_outputs( - optimized_model -) -pre_allocated_output_ctx.set_pre_allocated_output(True) -time_opt = test_module_perf(optimized_model, *inputs) - -# Disable the pre-allocated output buffer feature and perform inference normally -pre_allocated_output_ctx.set_pre_allocated_output(False) -out_trt = optimized_model(*inputs) -time_normal = test_module_perf(optimized_model, *inputs) - -time_opt_ms = time_opt * 1000 -time_normal_ms = time_normal * 1000 - -print(f"normal trt model time: {time_normal_ms:.3f} ms") -print(f"pre-allocated output buffer model time: {time_opt_ms:.3f} ms") diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 4f2d168d29..68e695dddf 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -3582,3 +3582,20 @@ def aten_ops_full( fill_value=args[1], dtype=kwargs.get("dtype", None), ) + + +@dynamo_tensorrt_converter(torch.ops.aten.nonzero.default) +def aten_ops_nonzero( + ctx: ConversionContext, + 
    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.unary.nonzero(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
index 34b667acf1..89e490392d 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
@@ -624,3 +624,18 @@ def native_dropout(
     mask = np.ones(input_val.shape, dtype=bool)
     mask = get_trt_tensor(ctx, mask, f"{name}_mask")
     return identity_layer.get_output(0), mask
+
+
+def nonzero(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+) -> TRTTensor:
+    non_zero_layer = ctx.net.add_non_zero(input_val)
+    set_layer_name(non_zero_layer, target, f"{name}_non_zero", source_ir)
+    shuffle_layer = ctx.net.add_shuffle(non_zero_layer.get_output(0))
+    shuffle_layer.first_transpose = trt.Permutation([1, 0])
+    set_layer_name(shuffle_layer, target, f"{name}_transpose", source_ir)
+    return shuffle_layer.get_output(0)
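The converter above maps `aten.nonzero` onto TensorRT's non-zero layer followed by a shuffle. The permutation is there because, as the `(1, 0)` transpose suggests, the TensorRT layer reports indices per dimension (a `[rank, count]` layout) while ATen returns one row per non-zero element (`[count, rank]`). A small eager-mode illustration with a hypothetical input, using plain PyTorch ops as stand-ins for the TRT layers:

```python
import torch

x = torch.tensor([[0, 3], [4, 0]])

aten_result = torch.nonzero(x)  # shape [count, rank]: one row per non-zero element
trt_style = aten_result.t()     # stand-in for the non-zero layer's [rank, count] layout

# The shuffle layer's first_transpose = (1, 0) undoes exactly this transpose,
# so the converted graph matches ATen semantics.
assert torch.equal(trt_style.t(), aten_result)
```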
diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
index 9086de657f..5c24f96be1 100644
--- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -12,7 +12,6 @@
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import Platform, dtype
 from torch_tensorrt.dynamo._settings import CompilationSettings
-from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
 from torch_tensorrt.logging import TRT_LOGGER
 from torch_tensorrt.runtime._utils import (
     _is_switch_required,
@@ -23,25 +22,56 @@
 logger = logging.getLogger(__name__)
 
 
+class DynamicOutputAllocator(trt.IOutputAllocator):  # type: ignore[misc]
+    def __init__(self, output_dtypes: Dict[str, torch.dtype]) -> None:
+        trt.IOutputAllocator.__init__(self)
+        self.buffers: Dict[str, torch.Tensor] = {}
+        self.shapes: Dict[str, Tuple[int, ...]] = {}
+        self.dtypes: Dict[str, torch.dtype] = output_dtypes
+
+    def reallocate_output_async(
+        self,
+        tensor_name: str,
+        memory: int,
+        size: int,
+        alignment: int,
+        stream: torch.cuda.Stream,
+    ) -> Any:
+        shape = (size,)
+        if tensor_name not in self.buffers:
+            self.buffers[tensor_name] = torch.empty(
+                shape,
+                dtype=self.dtypes[tensor_name],
+                device=torch.cuda.current_device(),
+            )
+        else:
+            if self.buffers[tensor_name].shape != shape:
+                self.buffers[tensor_name] = torch.empty(
+                    shape,
+                    dtype=self.dtypes[tensor_name],
+                    device=torch.cuda.current_device(),
+                )
+        return self.buffers[tensor_name].data_ptr()
+
+    def notify_shape(self, tensor_name: str, shape: Tuple[int, ...]) -> None:
+        self.shapes[tensor_name] = tuple(shape)
+
+
 class TorchTRTRuntimeStates:
     def __init__(self, new_cudagraphs: bool):
         # Indicates whether CUDAGraphs were enabled in the previous execute_engine
         self.old_cudagraphs = new_cudagraphs
-        # Indicates whether pre-allocated output was enabled in the previous execute_engine
-        self.old_pre_allocated_outputs = False
         # Indicates whether context has changed
         self.context_changed = False
 
     def set_runtime_states(
         self,
         new_cudagraphs: bool,
-        new_pre_allocated_output: bool,
-        shape_changed: bool,
-    ) -> Tuple[bool, bool, bool]:
-        # Evaluates whether certain conditions are met to enable CUDA Graph recording or to use pre-allocated outputs
-        # based on the current and previous states, as well as input shape has changed
+        input_shape_changed: bool,
+    ) -> Tuple[bool, bool]:
+        # Evaluates whether certain conditions are met to enable CUDA Graph recording
+        # based on the current and previous states and whether context or input shape changed
         need_cudagraphs_record = False
-        can_use_pre_allocated_outputs = False
         need_cudagraphs_reset = False
 
         # CUDA Graph recording is needed if CUDA graphs is enabled and:
@@ -49,29 +79,19 @@ def set_runtime_states(
         # - or the shape has changed
         # - or the execution context has changed (e.g., weight streaming)
         if new_cudagraphs and (
-            not self.old_cudagraphs or shape_changed or self.context_changed
+            not self.old_cudagraphs or input_shape_changed or self.context_changed
         ):
             need_cudagraphs_record = True
 
-        # Pre-allocated output can be used when previous and current state are true without shape change
-        if (
-            self.old_pre_allocated_outputs
-            and new_pre_allocated_output
-            and (not shape_changed)
-        ):
-            can_use_pre_allocated_outputs = True
-
-        if not new_cudagraphs or shape_changed or self.context_changed:
+        if not new_cudagraphs or input_shape_changed or self.context_changed:
             need_cudagraphs_reset = True
 
         self.old_cudagraphs = new_cudagraphs
-        self.old_pre_allocated_outputs = new_pre_allocated_output
         # reset flag
         self.context_changed = False
 
         return (
             need_cudagraphs_record,
-            can_use_pre_allocated_outputs,
             need_cudagraphs_reset,
         )
 
@@ -128,7 +148,6 @@ def __init__(
         self.name = name
 
         self._input_buffers: List[torch.Tensor] = []
-        self._output_buffers: List[torch.Tensor] = []
         self.cudagraph: Optional[torch.cuda.CUDAGraph] = None
         self._caller_stream: Optional[torch.cuda.Stream] = None
         self._engine_stream: Optional[torch.cuda.Stream] = None
@@ -147,6 +166,8 @@ def __init__(
         self.output_names = (
             output_binding_names if output_binding_names is not None else []
         )
+        self.output_allocator: Optional[DynamicOutputAllocator] = None
+
         self.initialized = False
         self.target_device_id = (
             settings.device.gpu_id
@@ -164,8 +185,6 @@ def __init__(
         self.runtime_states = TorchTRTRuntimeStates(
             torch_tensorrt.runtime.get_cudagraphs_mode()
         )
-        self.pre_allocated_outputs: List[torch.Tensor] = []
-        self.use_pre_allocated_outputs = False
 
         if self.serialized_engine is not None and not self.settings.lazy_engine_init:
             self.setup_engine()
@@ -233,10 +252,6 @@ def setup_engine(self) -> None:
             dtype._from(self.engine.get_tensor_dtype(output_name)).to(torch.dtype)
             for output_name in self.output_names
         ]
-        self.output_shapes = [
-            self.engine.get_tensor_shape(output_name)
-            for output_name in self.output_names
-        ]
 
         if torch_tensorrt.runtime.get_cudagraphs_mode():
             self.cudagraph = torch.cuda.CUDAGraph()
@@ -320,7 +335,7 @@ def setup_input_tensors(
                     # Clone is required to avoid re-using user-provided GPU memory
                     self._input_buffers[i] = contiguous_inputs[i].clone()
 
-            # For shape tensors, we use CPU pointers and for data tensors, we use GPU pointers
+            # For shape tensors, we use CPU pointers; for data tensors, we use GPU pointers
             # as per TensorRT requirements
             if self.engine.is_shape_inference_io(input_name):
                 # Shape tensor inputs are casted to int64 explicitly
@@ -342,21 +357,18 @@ def setup_input_tensors(
                     input_name, contiguous_inputs[i].data_ptr()
                 )
 
-    def create_output_tensors(self) -> List[torch.Tensor]:
-        # create output tensors
-        outputs: List[torch.Tensor] = []
-
-        for o, _ in enumerate(self.output_names):
-            output = torch.empty(
-                size=self.output_shapes[o],
-                dtype=self.output_dtypes[o],
-                device=torch.cuda.current_device(),
-            )
-            outputs.append(output)
-        return outputs
+    def setup_output_allocator(self) -> None:
+        if self.output_allocator is None:
+            output_dtypes_dict = {}
+            for o, output_name in enumerate(self.output_names):
+                output_dtypes_dict[output_name] = self.output_dtypes[o]
+            self.output_allocator = DynamicOutputAllocator(output_dtypes_dict)
 
-    def set_pre_allocated_outputs(self, enable: bool) -> None:
-        self.use_pre_allocated_outputs = enable
+        for output_name in self.output_names:
+            if not self.context.set_output_allocator(
+                output_name, self.output_allocator
+            ):
+                raise RuntimeError(f"Failed to set output allocator for {output_name}")
 
     def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
         # Ensure inputs are available in all scopes and cast symbolic integers to Tensors
@@ -372,13 +384,12 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
             self._check_initialized()
 
             cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
-            shape_changed = self.validate_input_shapes(inputs)
+            input_shape_changed = self.validate_input_shapes(inputs)
             (
                 need_cudagraphs_record,
-                can_use_pre_allocated_outputs,
                 need_cudagraphs_reset,
             ) = self.runtime_states.set_runtime_states(
-                cudagraphs_enabled, self.use_pre_allocated_outputs, shape_changed
+                cudagraphs_enabled, input_shape_changed
             )
 
             if need_cudagraphs_reset and self.cudagraph:
@@ -387,7 +398,6 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
 
             if need_cudagraphs_record:
                 self._input_buffers = [None] * len(self.input_names)
-                self._output_buffers = [None] * len(self.output_names)
 
             # If in safe mode, check at each iteration for whether a switch is required
             if (
@@ -436,7 +446,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
                 contiguous_inputs, cudagraphs_enabled, need_cudagraphs_record
             )
 
-            if shape_changed:
+            if input_shape_changed:
                 # Check if input shapes can be inferred.
                 uninferred_input_names = self.context.infer_shapes()
                 if uninferred_input_names:
@@ -447,36 +457,12 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
 
             with (
                 torch.autograd.profiler.record_function(
-                    "PythonTorchTensorRTModule:ProcessOutputs"
+                    "PythonTorchTensorRTModule:ProcessOutputAllocators"
                 )
                 if self.profiling_enabled
                 else nullcontext()
             ):
-                if can_use_pre_allocated_outputs:
-                    outputs = self.pre_allocated_outputs
-                else:
-                    self.output_shapes = [
-                        tuple(self.context.get_tensor_shape(output_name))
-                        for output_name in self.output_names
-                    ]
-                    if DYNAMIC_DIM in self.output_shapes:
-                        raise ValueError(
-                            "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
-                        )
-                    outputs = self.create_output_tensors()
-
-                for o, output_name in enumerate(self.output_names):
-                    if need_cudagraphs_record:
-                        self._output_buffers[o] = outputs[o].clone()
-
-                    if cudagraphs_enabled:
-                        self.context.set_tensor_address(
-                            output_name, self._output_buffers[o].data_ptr()
-                        )
-                    else:
-                        self.context.set_tensor_address(
-                            output_name, outputs[o].data_ptr()
-                        )
+                self.setup_output_allocator()
 
             with (
                 torch.autograd.profiler.record_function(
@@ -495,6 +481,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
                     self._engine_stream.wait_stream(self._caller_stream)
 
                 with torch.cuda.stream(self._engine_stream):
+
                     if cudagraphs_enabled:
                         if need_cudagraphs_record:
                             self.cudagraph = torch.cuda.CUDAGraph()
@@ -507,7 +494,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
                     ):
                         self.context.execute_async_v3(
                             self._engine_stream.cuda_stream
-                        )
+                        )  # The OutputAllocator is called by execute_async_v3()
 
                 if self.profiling_enabled:
                     import tempfile
@@ -524,12 +511,26 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
 
             self._caller_stream.wait_stream(self._engine_stream)
 
-            if self.use_pre_allocated_outputs:
-                self.pre_allocated_outputs = self.create_output_tensors()
-
-            if cudagraphs_enabled:
-                for idx, o in enumerate(outputs):
-                    o.copy_(self._output_buffers[idx])
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:ProcessOutputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
+                outputs = []
+                for o, output_name in enumerate(self.output_names):
+                    assert self.output_allocator is not None
+                    shape = self.output_allocator.shapes.get(output_name, None)
+                    dtype = self.output_dtypes[o]
+                    output = (
+                        self.output_allocator.buffers.get(output_name, None)
+                        .clone()
+                        .detach()
+                    )
+                    prod = int(torch.prod(torch.tensor(shape)))
+                    output = output.reshape(-1).view(dtype)[:prod].reshape(shape)
+                    outputs.append(output)
 
             if len(outputs) == 1:
                 return outputs[0]
diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
index b809e70ddf..d7cfc6608b 100644
--- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
@@ -272,9 +272,6 @@ def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None:
         self.input_binding_names = state[2]
         self.output_binding_names = state[3]
 
-    def set_pre_allocated_outputs(self, enable: bool) -> None:
-        self.engine.use_pre_allocated_outputs = enable
-
     def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
         """Implementation of the forward pass for a TensorRT engine
 
diff --git a/py/torch_tensorrt/runtime/__init__.py b/py/torch_tensorrt/runtime/__init__.py
index 470074a377..09a478e807 100644
--- a/py/torch_tensorrt/runtime/__init__.py
+++ b/py/torch_tensorrt/runtime/__init__.py
@@ -9,5 +9,4 @@
     set_cudagraphs_mode,
 )
 from torch_tensorrt.runtime._multi_device_safe_mode import set_multi_device_safe_mode
-from torch_tensorrt.runtime._pre_allocated_outputs import enable_pre_allocated_outputs
 from torch_tensorrt.runtime._weight_streaming import weight_streaming
diff --git a/py/torch_tensorrt/runtime/_pre_allocated_outputs.py b/py/torch_tensorrt/runtime/_pre_allocated_outputs.py
deleted file mode 100644
index c392c38838..0000000000
--- a/py/torch_tensorrt/runtime/_pre_allocated_outputs.py
+++ /dev/null
@@ -1,41 +0,0 @@
-import logging
-from typing import Any
-
-import torch
-from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
-
-logger = logging.getLogger(__name__)
-
-
-class _PreAllocatedOutputContextManager(object):
-    """
-    Helper class used to enable pre-allocated output feature in runtime module
-    """
-
-    def __init__(self, module: torch.fx.GraphModule) -> None:
-        rt_mods = []
-        for name, rt_mod in module.named_children():
-            if "_run_on_acc" in name and isinstance(
-                rt_mod, (PythonTorchTensorRTModule, TorchTensorRTModule)
-            ):
-                rt_mods.append(rt_mod)
-        self.rt_mods = rt_mods
-
-    def
set_pre_allocated_output(self, enable: bool) -> None: - for mod in self.rt_mods: - mod.set_pre_allocated_outputs(enable) - - def __enter__(self) -> "_PreAllocatedOutputContextManager": - # Enable pre-allocated output - self.set_pre_allocated_output(True) - return self - - def __exit__(self, *args: Any) -> None: - # Disable pre-allocated output - self.set_pre_allocated_output(False) - - -def enable_pre_allocated_outputs( - module: torch.fx.GraphModule, -) -> _PreAllocatedOutputContextManager: - return _PreAllocatedOutputContextManager(module) diff --git a/tests/py/dynamo/conversion/test_nonzero_aten.py b/tests/py/dynamo/conversion/test_nonzero_aten.py new file mode 100644 index 0000000000..c75bd4d4a9 --- /dev/null +++ b/tests/py/dynamo/conversion/test_nonzero_aten.py @@ -0,0 +1,74 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input + +from .harness import DispatchTestCase + + +class TestNonZeroConverter(DispatchTestCase): + @parameterized.expand( + [ + ((10,), torch.int), + ((1, 20), torch.int32), + ((2, 3), torch.int64), + ((2, 3, 4), torch.float), + ((2, 3, 4, 5), torch.float), + ] + ) + def test_non_zero(self, input_shape, dtype): + class NonZero(nn.Module): + def forward(self, input): + return torch.ops.aten.nonzero.default(input) + + inputs = [torch.randint(low=0, high=3, size=input_shape, dtype=dtype)] + self.run_test( + NonZero(), + inputs, + ) + + @parameterized.expand( + [ + ( + "1d", + (1,), + (10,), + (100,), + torch.int32, + ), + ( + "2d", + (1, 2), + (5, 10), + (20, 40), + torch.float16, + ), + ( + "3d", + (1, 2, 3), + (5, 10, 20), + (30, 40, 50), + torch.float, + ), + ] + ) + def test_nonzero_dynamic_shape(self, _, min_shape, opt_shape, max_shape, dtype): + class NonZero(nn.Module): + def forward(self, input): + return torch.ops.aten.nonzero.default(input) + + input_specs = [ + Input( + min_shape=min_shape, + opt_shape=opt_shape, + max_shape=max_shape, + dtype=dtype, + ), + ] + + self.run_test_with_dynamic_shape(NonZero(), input_specs) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/runtime/test_004_weight_streaming.py b/tests/py/dynamo/runtime/test_004_weight_streaming.py index 78522388d1..4041574b67 100644 --- a/tests/py/dynamo/runtime/test_004_weight_streaming.py +++ b/tests/py/dynamo/runtime/test_004_weight_streaming.py @@ -343,9 +343,9 @@ def forward(self, x): use_python_runtime=use_python_runtime, ) - # List of tuples representing different configurations for three features: - # Cuda graphs, pre-allocated output buffer, weight streaming change - states = list(itertools.product((True, False), repeat=3)) + # List of tuples representing different configurations for two features: + # Cuda graphs and weight streaming change + states = list(itertools.product((True, False), repeat=2)) # Create pairs of configurations representing an initial state and a changed state states_permutations = itertools.permutations(states, 2) @@ -368,16 +368,11 @@ def test_trt_model(enable_weight_streaming, optimized_model, input_list): for i in range(len(input_list)): ref_out_list.append(model(input_list[i])) - pre_allocated_output_ctx = torchtrt.runtime.enable_pre_allocated_outputs( - optimized_model - ) - for init_state, changed_state in states_permutations: - for cuda_graphs, pre_allocated_output, weight_streaming in [ + for cuda_graphs, weight_streaming in [ init_state, changed_state, ]: - 
pre_allocated_output_ctx.set_pre_allocated_output(pre_allocated_output) if cuda_graphs: with torchtrt.runtime.enable_cudagraphs( optimized_model diff --git a/tests/py/dynamo/runtime/test_pre_allocated_outputs.py b/tests/py/dynamo/runtime/test_pre_allocated_outputs.py deleted file mode 100644 index b8c7b61fb3..0000000000 --- a/tests/py/dynamo/runtime/test_pre_allocated_outputs.py +++ /dev/null @@ -1,130 +0,0 @@ -import torch -import torch_tensorrt as torchtrt -from parameterized import parameterized -from torch.testing._internal.common_utils import TestCase, run_tests - -INPUT_SIZE = (3, 16, 16) -TRIALS = 5 - - -class TestPreAllocatedOutputs(TestCase): - @parameterized.expand( - [ - ("python_runtime", True), - ("cpp_runtime", False), - ] - ) - def test_pre_allocated_outputs_default(self, _, use_python_runtime): - class SampleModel(torch.nn.Module): - def forward(self, x): - return torch.softmax((x + 2) * 7, dim=0) - - model = SampleModel().eval().cuda() - inputs = [torch.randn(*INPUT_SIZE).cuda() for _ in range(TRIALS)] - fx_graph = torch.fx.symbolic_trace(model) - - # Validate that the results between Torch and Torch-TRT are similar - optimized_model = torchtrt.compile( - fx_graph, - "torch_compile", - inputs[0], - min_block_size=1, - pass_through_build_failures=True, - use_python_runtime=use_python_runtime, - ) - - ref_out_list = [] - trt_out_list = [] - with torchtrt.runtime.enable_pre_allocated_outputs(optimized_model): - for i in inputs: - ref_out_list.append(fx_graph(i).detach().cpu()) - trt_out_list.append(optimized_model(i).detach().cpu()) - - for torch_model_results, optimized_model_results in zip( - ref_out_list, trt_out_list - ): - torch.testing.assert_close( - torch_model_results, - optimized_model_results, - rtol=5e-03, - atol=5e-03, - equal_nan=True, - check_dtype=True, - ) - - torch._dynamo.reset() - - @parameterized.expand( - [ - ("python_runtime", True), - ("cpp_runtime", False), - ] - ) - def test_pre_allocated_outputs_dynamic(self, _, use_python_runtime): - class SampleModel(torch.nn.Module): - def forward(self, x): - return torch.relu((x + 2) * 0.5) - - inputs = torchtrt.Input( - min_shape=(1, 3, 128, 224), - opt_shape=(8, 3, 192, 224), - max_shape=(16, 3, 224, 224), - dtype=torch.float, - name="x", - ) - fx_graph = torch.fx.symbolic_trace(SampleModel()) - - optimized_model = torchtrt.compile( - fx_graph, - "dynamo", - inputs, - min_block_size=1, - pass_through_build_failures=True, - torch_executed_ops={"torch.ops.aten.mul.Tensor"}, - use_python_runtime=use_python_runtime, - ) - - input_list = [] - ref_out_list = [] - trt_out_list = [] - # Alternating cuda_graphs enable and input shapes at every five iterations. 
- for i in [1, 3, 8, 11, 16]: - for j in [128, 128, 222, 222, 224]: - input_list.append(torch.randn((i, 3, j, 224)).cuda()) - - pre_allocated_output_ctx = torchtrt.runtime.enable_pre_allocated_outputs( - optimized_model - ) - pre_allocated_output = False - for enable_cuda_graphs in [False, True]: - for i in range(len(input_list)): - # Toggles cuda graph at all index in TRIALS - if i % TRIALS == i // TRIALS: - cuda_graphs = enable_cuda_graphs - else: - cuda_graphs = not enable_cuda_graphs - if i % 3 == 0: - pre_allocated_output = not pre_allocated_output - - torchtrt.runtime.set_cudagraphs_mode(cuda_graphs) - pre_allocated_output_ctx.set_pre_allocated_output(pre_allocated_output) - - ref_out_list.append(fx_graph(input_list[i])) - trt_out_list.append(optimized_model(input_list[i])) - - for torch_model_results, optimized_model_results in zip( - ref_out_list, trt_out_list - ): - torch.testing.assert_close( - torch_model_results, - optimized_model_results, - rtol=5e-03, - atol=5e-03, - equal_nan=True, - check_dtype=True, - ) - torch._dynamo.reset() - - -if __name__ == "__main__": - run_tests()
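Taken together, the allocator changes and the new converter make data-dependent output shapes usable end to end. A minimal usage sketch, assuming a CUDA device and a Torch-TensorRT build that contains the changes in this diff (the module and input below are hypothetical, not taken from the test suite):

```python
import torch
import torch_tensorrt


class NonZeroModel(torch.nn.Module):
    def forward(self, x):
        return torch.ops.aten.nonzero.default(x)


model = NonZeroModel().eval().cuda()
inputs = [torch.randint(0, 2, (4, 8), dtype=torch.int32).cuda()]

# The number of rows in the output is only known after the engine runs;
# the DynamicOutputAllocator path sizes and reshapes the result at runtime.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, min_block_size=1)

torch.testing.assert_close(
    trt_model(*inputs).cpu(), torch.nonzero(inputs[0]).cpu(), check_dtype=False
)
```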