Replies: 4 comments 4 replies
-
@apbose is there a way to create a hard subset of complex that we can support easily and grow from there?
-
There needs to be more detail here. For example, there are many possible designs: do we need to change the graph signature, do we put the unpacking code in the graph, or does the runtime need to detect and insert decompositions for complex inputs at the runtime level?
-
There is likely some sort of data structure we need here that marks insertion points so we can replace the graph later, and that holds the original graph and, later on, the new graph.
-
@apbose write some standalone cases where we have the original complex graph and the expected target graph
-
Complex number handling in Torch-TensorRT
TL;DR
This RFC proposes the addition of complex number support in Torch-TensorRT. TensorRT does not support complex numbers, but with the use of rotary embeddings as positional embeddings, complex numbers play an important role in how these embeddings are applied.
Goal
To support the multi-GPU example of the Llama 3 model running end to end
Use case
Through this feature we intend to demonstrate the end-to-end forward pass of a Torch-TensorRT compiled Llama 3 distributed model on multiple GPUs. Below illustrates how complex numbers are inputs to the Llama 3 model:
The query and key vectors are viewed as complex, while the freq vectors are computed in polar form with complex frequency.
The reason we encounter this only in distributed examples is that when we compile the model using
torch.compile(distributed_model, backend="torch_tensorrt")
the distributed tensors are hoisted to inputs when the model is wrapped with
aot_autograd
leading to complex inputs to the Torch-TensorRT compiled graph. Ref: pytorch/pytorch#136289
Implementation Stages
Complex unpacking
Convert the complex numbers into a tuple of real and imaginary parts: a complex number denoted by x + iy should be provided as input in the form (x, y).
This involves modifying the metadata shape and data type of the complex nodes, as well as the subsequent operations that take these complex numbers as input.
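The unpacking described above can be sketched with PyTorch's built-in view functions, which expose a complex tensor as a real tensor whose last dimension holds the (real, imag) pair:

```python
import torch

# A complex64 tensor shaped like a rotary-embedding input to the Llama 3 model.
x = torch.randn(1, 512, 1, 64, dtype=torch.complex64)

# view_as_real exposes x + iy as a real tensor with a trailing dimension of 2:
# element [..., 0] is the real part, element [..., 1] the imaginary part.
unpacked = torch.view_as_real(x)
print(unpacked.shape, unpacked.dtype)  # torch.Size([1, 512, 1, 64, 2]) torch.float32

# view_as_complex is the exact inverse, so no information is lost.
repacked = torch.view_as_complex(unpacked)
assert torch.equal(repacked, x)
```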
Numeric truncation
In the above, complex64 should be unpacked to a tuple of float32. Similarly, complex128 should also end up as a tuple of float32; since it naturally unpacks to float64, the truncate_flag has to be used to cast the parts down.
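The truncation step can be illustrated as follows (a sketch only; the exact flag plumbing inside Torch-TensorRT is not shown here):

```python
import torch

# complex128 unpacks to float64 parts...
x128 = torch.randn(4, 8, dtype=torch.complex128)
parts = torch.view_as_real(x128)
assert parts.dtype == torch.float64

# ...so an explicit cast down to float32 is needed; this is the effect the
# truncate flag would trigger (illustrative, not the actual Torch-TensorRT code).
truncated = parts.to(torch.float32)
assert truncated.dtype == torch.float32
```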
Function signature modification
Identify the boundary of the operations affected by the complex inputs. Below is an example of what this looks like in the Llama 3 model for the rotary embedding operation.
eg:
The signature of these complex operations needs to be modified so that there are no graph breaks, and so that the complex unpacking is handled as well.
Unification of pre_lowering and post_lowering pass for distributed and non distributed
The
pre_lowering
and
post_lowering
passes need to be uniform across both the distributed and non-distributed cases.
Diagram
In the above, there has to be additional handling in the Torch-TensorRT runtime. All the above will be called via an API in the post-lowering passes.
API changes
We will discuss these APIs citing the example of rotary embeddings in Llama3 model
Detection stage
torch_tensorrt/dynamo/lowering/passes/pass_utils.py
The above API should return the subgraph. The meta-info of the subgraph can be captured in a class named complexSubGraphInfo
With respect to Llama 3, it is the subgraph denoted in the figure below. This graph broadly captures the operations in the rotary embeddings of the query and key vectors. Please note that the freq node is denoted by
reshape_default_12
, which is the same for both the query and key vectors. The rotated query and key vector embeddings from the view_as_real nodes are then inputs, together with the value vector, to the scaled_dot_product_attention torch node. That means for the design below, each subgraph can have multiple anchor nodes. With n_heads = 32 and n_kv_heads = 8, there are 32/8 = 4 such subgraphs to be captured.
anchor_node = [view_as_real, view_as_real_1], subgraph_nodes = [view_as_complex, mul_2, slice_1, reshape_default_12, mul_3, view_as_complex_1], input_nodes = [reshape_default_10, arg2_1, reshape_default_11]
Below elaborates on the detector class which returns the subgraph guided by the anchor nodes.
It can be called as
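Since the detector itself is not reproduced here, the following is a hypothetical sketch of what complexSubGraphInfo and the detector could look like; the names ComplexSubGraphInfo, ComplexOpDetector, and find_complex_subgraphs are illustrative, not the actual Torch-TensorRT API. It walks backwards from each view_as_real anchor node to the complex placeholder inputs:

```python
from dataclasses import dataclass

import torch
from torch.fx import symbolic_trace


@dataclass
class ComplexSubGraphInfo:
    anchor_nodes: list    # e.g. the view_as_real nodes
    subgraph_nodes: list  # operations between the complex inputs and the anchors
    input_nodes: list     # placeholders feeding the complex subgraph


class ComplexOpDetector:
    def is_anchor(self, node) -> bool:
        # Match both the public op and its aten overload (seen in post-AOT graphs).
        return node.target in (torch.view_as_real, torch.ops.aten.view_as_real.default)

    def find_complex_subgraphs(self, gm) -> list:
        subgraphs = []
        for node in gm.graph.nodes:
            if not self.is_anchor(node):
                continue
            nodes, inputs, seen = [], [], set()
            stack = list(node.all_input_nodes)
            while stack:  # backward traversal from the anchor to the placeholders
                n = stack.pop()
                if n in seen:
                    continue
                seen.add(n)
                if n.op == "placeholder":
                    inputs.append(n)
                else:
                    nodes.append(n)
                    stack.extend(n.all_input_nodes)
            subgraphs.append(ComplexSubGraphInfo([node], nodes, inputs))
        return subgraphs


# Minimal usage on a toy rotary-style graph:
class Rotary(torch.nn.Module):
    def forward(self, x, freqs):
        return torch.view_as_real(torch.view_as_complex(x) * freqs)


gm = symbolic_trace(Rotary())
subgraphs = ComplexOpDetector().find_complex_subgraphs(gm)
```

For the toy graph this yields one subgraph with the single view_as_real anchor, the mul and view_as_complex nodes in between, and the two placeholders as input nodes.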
Decomposition stage
Here we can decompose the complex input into input.real and input.imag, and concatenate along the last dimension for torch_inputs in
torch_tensorrt/dynamo/backend/backends.py
. These would be at the indices returned above.
Graph Rewrite stage
torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite
Modifying the graph nodes and their signature can be done through
torch.fx.subgraph_rewriter.replace_pattern_with_filters()
with appropriate match filters. The above is explained in the diagram below with respect to the Llama 3 example.
Blue arrows denote the graph modifications/rewrite passes.
Below represents the modified target graph
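The rewrite described above can be sketched with torch.fx's subgraph rewriter. This is a minimal sketch, not the actual pass: the pattern covers only a bare complex multiply, and the replacement assumes freqs already arrives decomposed as a real tensor with a trailing dimension of 2 (as the decomposition stage would produce); the real pass would cover the full rotary-embedding subgraph.

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.subgraph_rewriter import replace_pattern_with_filters


# Pattern: a complex multiply expressed through view_as_complex / view_as_real.
def pattern(x, freqs):
    return torch.view_as_real(torch.view_as_complex(x) * freqs)


# Replacement: the same product on real tensors, using
# (a + bi)(c + di) = (ac - bd) + (ad + bc)i.
# freqs is assumed already decomposed to a real tensor with last dim 2.
def replacement(x, freqs):
    real = x[..., 0] * freqs[..., 0] - x[..., 1] * freqs[..., 1]
    imag = x[..., 0] * freqs[..., 1] + x[..., 1] * freqs[..., 0]
    return torch.stack([real, imag], dim=-1)


class Rotary(torch.nn.Module):
    def forward(self, x, freqs):
        return torch.view_as_real(torch.view_as_complex(x) * freqs)


gm = symbolic_trace(Rotary())
# match_filters can constrain matches (e.g. by node metadata); empty here.
matches = replace_pattern_with_filters(gm, pattern, replacement, match_filters=[])

# After the rewrite, gm takes only real tensors (last dim = 2) as inputs.
x = torch.randn(2, 4, 2)
f = torch.randn(2, 4, 2)
out = gm(x, f)
```

The rewritten graph produces the same values as the original complex computation, which can be checked by packing the real inputs back into complex tensors and comparing.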
All the above need to be called sequentially in
torch_tensorrt/dynamo/backend/backends.py
Further to be explored are the changes in the runtimes in
_PythonTorchTensorRTModule.py
,
_TorchTensorRTModule.py
, and
_CudaGraphsTorchTensorRTModule.py
since we are modifying the inputs.
Runtime changes
In all the above runtimes, the input is now processed so that the partitioned graph module receives a reshaped input of real and imaginary parts, with the last dimension equal to 2 (e.g. 1, 512, 1, 64 -> 1, 512, 1, 64, 2). The input tensors therefore need to be reshaped and fed to the corresponding runtimes.
This can be done at the runtime level, where we can insert the decomposition.
In
_PythonTorchTensorRTModule.py
this can be done in
setup_input_tensors
. When the dtype or shape of contiguous_inputs[i] differs from the compiled engine's input dtype or shape, we change the shape and dtype of the input. We could use the same API as in the decomposition stage. A similar analysis can be done for _TorchTensorRTModule.py
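A minimal sketch of that runtime-side adaptation, using a hypothetical helper (adapt_complex_input is illustrative, not an actual Torch-TensorRT function): when a runtime input is complex, expose it to the engine as a real tensor with a trailing dimension of 2.

```python
import torch


def adapt_complex_input(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: reshape a complex runtime input so the compiled
    engine sees real parts, with the last dimension equal to 2."""
    if t.is_complex():
        # view_as_real requires the complex tensor to be contiguous.
        return torch.view_as_real(t.contiguous())
    return t


# e.g. (1, 512, 1, 64) complex64 -> (1, 512, 1, 64, 2) float32
inp = torch.randn(1, 512, 1, 64, dtype=torch.complex64)
adapted = adapt_complex_input(inp)
```

Real-valued inputs pass through unchanged, so the same code path can serve both complex and non-complex engines.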