diff --git a/.github/scripts/install-torch-tensorrt.sh b/.github/scripts/install-torch-tensorrt.sh index 2930421d5b..69a318c179 100644 --- a/.github/scripts/install-torch-tensorrt.sh +++ b/.github/scripts/install-torch-tensorrt.sh @@ -4,6 +4,6 @@ set -eou pipefail source ${BUILD_ENV_FILE} ${CONDA_RUN} ${PIP_INSTALL_TORCH} torchvision pyyaml export TRT_VERSION=$(${CONDA_RUN} python -c "import versions; versions.tensorrt_version()") -${CONDA_RUN} python -m pip install /opt/torch-tensorrt-builds/torch_tensorrt*+${CU_VERSION}*.whl tensorrt~=${TRT_VERSION} tensorrt-bindings~=${TRT_VERSION} --extra-index-url=https://pypi.ngc.nvidia.com +${CONDA_RUN} python -m pip install /opt/torch-tensorrt-builds/torch_tensorrt*+${CU_VERSION}*.whl tensorrt~=${TRT_VERSION} tensorrt-bindings~=${TRT_VERSION} --extra-index-url=https://pypi.nvidia.com -echo -e "Running test script"; \ No newline at end of file +echo -e "Running test script"; diff --git a/README.md b/README.md index 875b640304..78602255f1 100644 --- a/README.md +++ b/README.md @@ -119,7 +119,7 @@ These are the following dependencies used to verify the testcases. Torch-TensorR - Libtorch 2.3.0.dev (latest nightly) (built with CUDA 12.1) - CUDA 12.1 - cuDNN 8.9.5 -- TensorRT 8.6.1 +- TensorRT 9.2.0 ## Prebuilt Binaries and Wheel files diff --git a/WORKSPACE b/WORKSPACE index bbc1803296..dc1bdb50d8 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -81,10 +81,10 @@ http_archive( http_archive( name = "tensorrt", build_file = "@//third_party/tensorrt/archive:BUILD", - sha256 = "0f8157a5fc5329943b338b893591373350afa90ca81239cdadd7580cd1eba254", - strip_prefix = "TensorRT-8.6.1.6", + sha256 = "3dd505a9e0d0adf9257080b543f51d91df736dbd1f75417b9dde1a7b7a5d87f2", + strip_prefix = "tensorrt-9.2.0.5", urls = [ - "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/secure/8.6.1/tars/TensorRT-8.6.1.6.Linux.x86_64-gnu.cuda-12.0.tar.gz", + "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/9.2.0/tensorrt-9.2.0.5.linux.x86_64-gnu.cuda-12.2.tar.gz", ], ) diff --git a/dev_dep_versions.yml b/dev_dep_versions.yml index 442485474c..a4cc2974f1 100644 --- a/dev_dep_versions.yml +++ b/dev_dep_versions.yml @@ -1,4 +1,4 @@ __version__: "2.3.0.dev0" __cuda_version__: "12.1" __cudnn_version__: "8.9" -__tensorrt_version__: "8.6" +__tensorrt_version__: "9.2.0.post12.dev5" diff --git a/docker/README.md b/docker/README.md index 9f83f25134..0b4799a6a5 100644 --- a/docker/README.md +++ b/docker/README.md @@ -17,14 +17,14 @@ Note: By default the container uses the `pre-cxx11-abi` version of Torch + Torch ### Instructions -- The example below uses CUDNN 8.9 and TensorRT 8.6 +- The example below uses CUDNN 8.9 and TensorRT 9.2 - See dependencies for a list of current default dependencies. > From root of Torch-TensorRT repo Build: ``` -DOCKER_BUILDKIT=1 docker build --build-arg TENSORRT_VERSION=8.6 --build-arg CUDNN_VERSION=8.9 -f docker/Dockerfile -t torch_tensorrt:latest . +DOCKER_BUILDKIT=1 docker build --build-arg TENSORRT_VERSION=9.2 --build-arg CUDNN_VERSION=8.9 -f docker/Dockerfile -t torch_tensorrt:latest . ``` Run: diff --git a/packaging/pre_build_script.sh b/packaging/pre_build_script.sh index 18cd5d9fe2..51b1bcfdd7 100755 --- a/packaging/pre_build_script.sh +++ b/packaging/pre_build_script.sh @@ -2,15 +2,35 @@ # Install dependencies python3 -m pip install pyyaml -TRT_VERSION=$(python3 -c "import versions; versions.tensorrt_version()") +CUDNN_VERSION=$(python3 -c "import versions; print(versions.__cudnn_version__.split('.')[0])") yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo yum check-update -yum install -y ninja-build gettext tensorrt-${TRT_VERSION}.* +yum install -y ninja-build gettext libcudnn${CUDNN_VERSION} libcudnn${CUDNN_VERSION}-devel wget https://github.com/bazelbuild/bazelisk/releases/download/v1.17.0/bazelisk-linux-amd64 \ && mv bazelisk-linux-amd64 /usr/bin/bazel \ && chmod +x /usr/bin/bazel +wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/9.2.0/tensorrt-9.2.0.5.linux.x86_64-gnu.cuda-12.2.tar.gz +mkdir -p /usr/tensorrt +tar -xzvf tensorrt-9.2.0.5.linux.x86_64-gnu.cuda-12.2.tar.gz -C /usr/tensorrt --strip-components=1 +mkdir -p /usr/lib +cp /usr/tensorrt/lib/* /usr/lib/ || : +mkdir -p /usr/lib64 +cp /usr/tensorrt/lib/* /usr/lib64/ || : +mkdir -p /usr/include +cp /usr/tensorrt/include/* /usr/include/ || : + +mkdir -p /usr/lib/x86_64-linux-gnu +cp /usr/tensorrt/targets/x86_64-linux-gnu/lib/* /usr/lib/x86_64-linux-gnu/ || : +mkdir -p /usr/include/x86_64-linux-gnu +cp /usr/tensorrt/targets/x86_64-linux-gnu/include/* /usr/include/x86_64-linux-gnu/ || : + +rm tensorrt-9.2.0.5.linux.x86_64-gnu.cuda-12.2.tar.gz +rm -rf /usr/tensorrt + export TORCH_BUILD_NUMBER=$(python -c "import torch, urllib.parse as ul; print(ul.quote_plus(torch.__version__))") cat toolchains/ci_workspaces/WORKSPACE.x86_64.release.rhel.tmpl | envsubst > WORKSPACE export CI_BUILD=1 + +python -m pip config set global.extra-index-url "https://pypi.nvidia.com" diff --git a/py/ci/soname_excludes.params b/py/ci/soname_excludes.params index a5eecb7c9a..c92ceb123f 100644 --- a/py/ci/soname_excludes.params +++ b/py/ci/soname_excludes.params @@ -24,11 +24,13 @@ --exclude libcudart.so.11 --exclude libcudart.so.11.7.60 --exclude libnvrtc.so.11.2 +--exclude libnvinfer_plugin.so.9 --exclude libnvinfer_plugin.so.8 --exclude libcublas.so.11 --exclude libcuda.so.1 --exclude libcuda.so.515 --exclude libcublasLt.so.11 +--exclude libnvinfer.so.9 --exclude libnvinfer.so.8 --exclude libcudnn.so.8 --exclude libcublas.so.12 @@ -36,4 +38,4 @@ --exclude libcublas.so.12.1.3.1 --exclude libcublasLt.so.12.1.3.1 --exclude libcudart.so.11.8.89 ---exclude libcudart.so.11 \ No newline at end of file +--exclude libcudart.so.11 diff --git a/py/requirements.txt b/py/requirements.txt index cd52d32436..e187213370 100644 --- a/py/requirements.txt +++ b/py/requirements.txt @@ -4,6 +4,6 @@ pybind11==2.6.2 --extra-index-url https://download.pytorch.org/whl/nightly/cu121 torch>=2.3.0.dev,<2.4.0 torchvision>=0.18.0.dev,<0.19.0 ---extra-index-url https://pypi.ngc.nvidia.com -tensorrt==8.6.1 +--extra-index-url https://pypi.nvidia.com +tensorrt==9.2.0.post12.dev5 pyyaml diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 45949a1c8d..32ef60c604 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -229,26 +229,7 @@ def aten_ops_cat( ) -def embedding_param_validator(embedding_node: Node) -> bool: - scale_grad_by_freq = args_bounds_check(embedding_node.args, 3) - sparse = args_bounds_check(embedding_node.args, 4) - - if scale_grad_by_freq is not None: - _LOGGER.debug( - f"Currently we don't support specifying scale gradient by word frequency, got {scale_grad_by_freq}." - ) - return False - - if sparse is not None: - _LOGGER.debug(f"Currently we don't support sparse gradient, got {sparse}.") - return False - - return True - - -@dynamo_tensorrt_converter( - torch.ops.aten.embedding.default, capability_validator=embedding_param_validator -) +@dynamo_tensorrt_converter(torch.ops.aten.embedding.default) def aten_ops_embedding( ctx: ConversionContext, target: Target, @@ -263,22 +244,18 @@ def aten_ops_embedding( name, input=args[1], weight=args[0], - # args[2] is the padding index, which is useful for training only - scale_grad_by_freq=args_bounds_check(args, 3), - sparse=args_bounds_check(args, 4), + # args[2, 3, 4] are useful for training only + padding_idx=args_bounds_check(args, 2, -1), + scale_grad_by_freq=args_bounds_check(args, 3, False), + sparse=args_bounds_check(args, 4, False), ) def embedding_bag_validator(node: Node) -> bool: - mode = args_bounds_check(node.args, 4, 0) indices = node.args[1].meta.get("tensor_meta") if indices is None: return False - return ( - bool(node.args[2].op == "get_attr") - and (mode == 0 or mode == 1 or mode == 2) - and len(indices.shape) == 1 - ) + return len(indices.shape) == 1 # currently only support 1D indices @dynamo_tensorrt_converter( @@ -291,7 +268,6 @@ def embedding_bag_validator(node: Node) -> bool: { 0: (TRTTensor,), 1: (TRTTensor,), - 2: (np.ndarray, torch.Tensor), } ) def aten_ops_embedding_bag( @@ -309,12 +285,13 @@ def aten_ops_embedding_bag( weight=args[0], indices=args[1], offsets=args[2], - scale_grad_by_freq=args_bounds_check(args, 3, False), mode=args_bounds_check(args, 4, 0), - sparse=args_bounds_check(args, 5, False), per_sample_weights=args_bounds_check(args, 6, None), include_last_offset=args_bounds_check(args, 7, False), - # padding index is useful for training only + # scale_grad_by_freq, sparse, and padding_idx are useful for training only + scale_grad_by_freq=args_bounds_check(args, 3, False), + sparse=args_bounds_check(args, 5, False), + padding_idx=args_bounds_check(args, 8, -1), ) diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index f9d14917f1..4464264f40 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -6,6 +6,7 @@ import numpy as np import tensorrt as trt import torch +import torch_tensorrt.dynamo.conversion.impl as impl from torch import SymBool, SymFloat, SymInt from torch.fx.node import Argument, Target from torch_tensorrt.dynamo._SourceIR import SourceIR @@ -541,3 +542,73 @@ def flatten_dims( new_shape = tuple(shape[:start_dim]) + (num_elements,) + tuple(shape[end_dim + 1 :]) return new_shape + + +def append( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + original_tensor: TRTTensor, + new_value: Union[TRTTensor, int, float, torch.Tensor, np.ndarray], +) -> TRTTensor: + if isinstance(new_value, (int, float)): + new_value = np.array([new_value]) + + return impl.cat.cat( + ctx, + target, + source_ir, + f"{name}_concat", + [original_tensor, new_value], + 0, + ) + + +def set_item( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + original_tensor: TRTTensor, + index: int, + new_value: Union[TRTTensor, int, float, torch.Tensor, np.ndarray], +) -> TRTTensor: + if isinstance(new_value, (int, float)): + new_value = np.array([new_value]) + + len_original_tensor = original_tensor.shape[0] + index = get_positive_dim(index, len_original_tensor) + + front_tensor = impl.slice.slice_op( + ctx, + target, + source_ir, + f"{name}_slice_front", + original_tensor, + dim=0, + start=0, + stop=index, + step=1, + ) + rear_tensor = impl.slice.slice_op( + ctx, + target, + source_ir, + f"{name}_slice_rear", + original_tensor, + dim=0, + start=index + 1, + stop=len_original_tensor, + step=1, + ) + + ans = impl.cat.cat( + ctx, + target, + source_ir, + f"{name}_concat", + [front_tensor, new_value, rear_tensor], + 0, + ) + return ans diff --git a/py/torch_tensorrt/dynamo/conversion/impl/embedding.py b/py/torch_tensorrt/dynamo/conversion/impl/embedding.py index ac9faf9f4d..8533d76b9b 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/embedding.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/embedding.py @@ -7,7 +7,12 @@ from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext -from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor, to_numpy +from torch_tensorrt.dynamo.conversion.converter_utils import ( + append, + get_trt_tensor, + set_item, + to_numpy, +) from torch_tensorrt.fx.converters.converter_utils import set_layer_name from torch_tensorrt.fx.types import TRTTensor @@ -19,6 +24,7 @@ def embedding( name: str, input: TRTTensor, weight: TRTTensor, + padding_idx: int, scale_grad_by_freq: bool, sparse: bool, ) -> TRTTensor: @@ -31,15 +37,17 @@ def embedding( indices_tensor = get_trt_tensor(ctx, indices_tensor, f"{name}_indices_tensor") embedding_tensor = get_trt_tensor(ctx, embedding_tensor, f"{name}_embedding_tensor") # unsupported parameters - # ignore padding_idx since it is meaningful for training only + # ignore padding_idx, scale_grad_by_freq, and sparse + # since they are meaningful for training only - if scale_grad_by_freq: - raise RuntimeError( - "Currently we don't support scale gradient by word frequency." - ) + # useful for training only + # if scale_grad_by_freq: + # raise RuntimeError( + # "Currently we don't support scale gradient by word frequency." + # ) - if sparse: - raise RuntimeError("Currently we don't support sparse gradient.") + # if sparse: + # raise RuntimeError("Currently we don't support sparse gradient.") # Implement embedding lookup with gather layer gather_layer = ctx.net.add_gather(embedding_tensor, indices_tensor, axis=0) @@ -47,34 +55,16 @@ def embedding( return gather_layer.get_output(0) -def embedding_bag( +def embedding_bag_with_traversable_offsets( ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, - weight: TRTTensor, - indices: TRTTensor, - offsets: Union[torch.Tensor, np.ndarray, Sequence[int]], - scale_grad_by_freq: bool, + embed: TRTTensor, + offsets_list: Union[torch.Tensor, np.ndarray, Sequence[int]], mode: int, - sparse: bool, - per_sample_weights: Optional[TRTTensor], include_last_offset: bool, ) -> Tuple[TRTTensor, TRTTensor, TRTTensor, TRTTensor]: - """ - This function is for calculating embedding bags. - - In PyTorch, `offsets` is only used when input is 1D. If input is 2D of shape (B, N), - it will be treated as B bags (sequences) each of fixed length N, and this will return - B values aggregated in a way depending on the mode. `offsets` is ignored and required - to be None in this case. - - However, according to the schema, `offsets` is required for input with any dimensions. - Accordingly, this function flattens N-D input to 1D and then to calculate embedding bags. - """ - - # TODO: support 2D inputs - # indices = impl.shuffle.reshape(ctx, target, source_ir, f"{name}_reshape_indices", indices, (-1,)) reduce_name = "" if mode == 0: # sum reduce_op = functools.partial( @@ -96,6 +86,260 @@ def embedding_bag( ) reduce_name = "max" + offsets: np.ndarray = to_numpy(offsets_list) + len_embed = embed.shape[0] + + if include_last_offset: + # modify the last index of offsets to the end index + # however, pytorch doc says if `include_last_offset` is True, the size of offsets + # is equal to the number of bags + 1. The last element is the size of the input, + # or the ending index position of the last bag (sequence). + offsets.itemset(-1, len_embed) + else: + # add the end index to offsets + offsets = np.append(offsets, len_embed) + + zero_tensor = get_trt_tensor( + ctx, np.zeros((1, embed.shape[1]), dtype=np.float32), f"{name}_zero_tensor" + ) + + # separately reduce embeddings for different bags + reduced_embed_bags = [] + len_offsets = offsets.shape[0] + for i in range(len_offsets - 1): + if offsets[i] < offsets[i + 1]: + sliced_embed = impl.slice.slice_op( + ctx, + target, + source_ir, + f"{name}_slice_embed_{i}", + embed, + 0, + int(offsets[i]), + int(offsets[i + 1]), + 1, + ) + reduced_one_bag = reduce_op( + name=f"{name}_{reduce_name}_{i}", + input_val=sliced_embed, + dim=0, + keepdim=True, + ) + reduced_embed_bags.append(reduced_one_bag) + else: # offsets[i] == offsets[i + 1] + reduced_embed_bags.append(zero_tensor) + + out = impl.cat.cat(ctx, target, source_ir, f"{name}_cat", reduced_embed_bags, 0) + return out, None, None, None + + +def embedding_bag_with_ITensor_offsets( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + embed: TRTTensor, + offsets: TRTTensor, + mode: int, + include_last_offset: bool, +) -> Tuple[TRTTensor, TRTTensor, TRTTensor, TRTTensor]: + len_embed = embed.shape[0] + + if include_last_offset: + # modify the last index of offsets to the end index + # however, pytorch doc says if `include_last_offset` is True, the size of offsets + # is equal to the number of bags + 1. The last element is the size of the input, + # or the ending index position of the last bag (sequence). + offsets = set_item( + ctx, target, source_ir, f"{name}_set_item", offsets, -1, len_embed + ) + else: + # add the end index to `offsets` + offsets = append(ctx, target, source_ir, f"{name}_append", offsets, len_embed) + + reduced_embed_bags = [] + # get the first item in offsets + start = ctx.net.add_gather( + offsets, get_trt_tensor(ctx, 0, f"{name}_tensor_0"), 0 + ).get_output(0) + + # create a placeholder tensor, whose shape is the same as an embedding + # if mode is 0 (sum) or 1 (mean), the placeholder tensor is filled with zeros + # if mode is 2 (max), the placeholder tensor is filled with negative infinity + zero_tensor = get_trt_tensor( + ctx, np.zeros((1, embed.shape[1]), dtype=np.float32), f"{name}_zero_tensor" + ) + placeholder_tensor = ( + get_trt_tensor( + ctx, + np.full((1, embed.shape[1]), -np.inf, dtype=np.float32), + f"{name}_negative_inf_tensor", + ) + if mode == 2 + else zero_tensor + ) + + # create a list of constant ITensor for reuse + incremental_tensor_list = [] + for i in range(0, len_embed): + incremental_tensor_list.append( + get_trt_tensor(ctx, i, f"incremental_tensor_{i}") + ) + + # traverse offsets to calculate the embedding of each bag + for i in range(1, offsets.shape[0]): + end = ctx.net.add_gather(offsets, incremental_tensor_list[i], 0).get_output(0) + + one_bag_list = [] + # traverse the constant list to see if the index is in the range of the current bag + for j in range(0, len_embed): + j_tensor = incremental_tensor_list[j] + + # create a TRT conditional layer + conditional_layer = ctx.net.add_if_conditional() + # two conditions + cond1 = impl.elementwise.ge( + ctx, target, source_ir, f"{name}_ge_{i}_{j}", j_tensor, start + ) + cond2 = impl.elementwise.lt( + ctx, target, source_ir, f"{name}_lt_{i}_{j}", j_tensor, end + ) + condition = impl.elementwise.logical_and( + ctx, target, source_ir, f"{name}_and_{i}_{j}", cond1, cond2 + ) + condition = impl.shuffle.reshape( + ctx, + target, + source_ir, + f"{name}_reshape_condition_{i}_{j}", + condition, + [], + ) + # set the combined condition to the conditional layer + conditional_layer.set_condition(condition) + # if true, run this subgraph + one_piece_embed = impl.select.index( + ctx, target, source_ir, f"{name}_index_{i}_{j}", embed, [j_tensor] + ) + true_sg = conditional_layer.add_input(one_piece_embed) + # if false, run this subgraph + false_sg = conditional_layer.add_input(placeholder_tensor) + + cond_output_layer = conditional_layer.add_output( + true_sg.get_output(0), false_sg.get_output(0) + ) + one_bag_list.append(cond_output_layer.get_output(0)) + + # concat the one_bag_list along the first dimension + one_bag = impl.cat.cat( + ctx, + target, + source_ir, + f"{name}_concat_bag{i}", + one_bag_list, + dim=0, + ) + + # reduce the one_bag along the first dimension, the result of which is an embedding of each bag + if mode == 0: # sum + reduced_one_bag = impl.reduce.sum( + ctx, + target, + source_ir, + name=f"{name}_sum_bag{i}", + input_val=one_bag, + dim=0, + keepdim=True, + ) + + # Since one_bag includes many zeros, directly calculating mean will cause results incorrect + elif mode == 1: # mean + reduced_one_bag = impl.reduce.sum( + ctx, + target, + source_ir, + name=f"{name}_sum_bag{i}", + input_val=one_bag, + dim=0, + keepdim=True, + ) + diff = impl.elementwise.sub( + ctx, target, source_ir, f"{name}_diff_bag{i}", end, start + ) + reduced_one_bag = impl.elementwise.div( + ctx, target, source_ir, f"{name}_div_bag{i}", reduced_one_bag, diff + ) + + elif mode == 2: # max + reduced_one_bag = impl.reduce.max( + ctx, + target, + source_ir, + name=f"{name}_max_bag{i}", + input_val=one_bag, + dim=0, + keepdim=True, + return_indices=False, + ) + + # create a TRT conditional layer + conditional_layer = ctx.net.add_if_conditional() + # two conditions + condition = impl.elementwise.eq( + ctx, target, source_ir, f"{name}_eq_{i}", start, end + ) + condition = impl.shuffle.reshape( + ctx, target, source_ir, f"{name}_reshape_condition_eq_{i}", condition, [] + ) + # set the combined condition to the conditional layer + conditional_layer.set_condition(condition) + # if true, run this subgraph + true_sg = conditional_layer.add_input(zero_tensor) + # if false, run this subgraph + false_sg = conditional_layer.add_input(reduced_one_bag) + + reduced_one_bag_layer = conditional_layer.add_output( + true_sg.get_output(0), false_sg.get_output(0) + ) + + reduced_embed_bags.append(reduced_one_bag_layer.get_output(0)) + start = end + + # concat the reduced_embed_bags along the first dimension + out = impl.cat.cat(ctx, target, source_ir, f"{name}_cat", reduced_embed_bags, 0) + return out, None, None, None + + +def embedding_bag( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + weight: TRTTensor, + indices: TRTTensor, + offsets: TRTTensor, + scale_grad_by_freq: bool, + mode: int, + sparse: bool, + per_sample_weights: Optional[TRTTensor], # for sum mode only + include_last_offset: bool, + padding_idx: int, +) -> Tuple[TRTTensor, TRTTensor, TRTTensor, TRTTensor]: + """ + This function is for calculating embedding bags. + + In PyTorch, `offsets` is only used when input is 1D. If input is 2D of shape (B, N), + it will be treated as B bags (sequences) each of fixed length N, and this will return + B values aggregated in a way depending on the mode. `offsets` is ignored and required + to be None in this case. + + However, according to the schema, `offsets` is required for input with any dimensions. + Accordingly, this function flattens N-D input to 1D and then to calculate embedding bags. + """ + + # TODO: support 2D inputs + # indices = impl.shuffle.reshape(ctx, target, source_ir, f"{name}_reshape_indices", indices, (-1,)) + # calculate embedding embed = embedding( ctx, @@ -104,6 +348,7 @@ def embedding_bag( f"{name}_embedding", indices, weight, + padding_idx, scale_grad_by_freq, sparse, ) @@ -133,43 +378,12 @@ def embedding_bag( per_sample_weights, ) - offsets = to_numpy(offsets) - - if include_last_offset is False: - # add the end index to offsets - offsets = np.append(offsets, indices.shape[0]) + if isinstance(offsets, TRTTensor): + return embedding_bag_with_ITensor_offsets( + ctx, target, source_ir, name, embed, offsets, mode, include_last_offset + ) else: - # modify the last index of offsets to the end index - # however, pytorch doc says if `include_last_offset` is True, the size of offsets - # is equal to the number of bags + 1. The last element is the size of the input, - # or the ending index position of the last bag (sequence). - offsets[-1] = indices.shape[0] - - # separately reduce embeddings for different bags - reduced_embed = [] - len_offsets = len(offsets) - for i in range(len_offsets - 1): - if offsets[i] < offsets[i + 1]: - sliced_embed = impl.slice.slice_op( - ctx, - target, - source_ir, - f"{name}_slice_embed_{i}", - embed, - 0, - int(offsets[i]), - int(offsets[i + 1]), - 1, - ) - reduced_sliced_embed = reduce_op( - name=f"{name}_{reduce_name}_{i}", - input_val=sliced_embed, - dim=0, - keepdim=True, - ) - reduced_embed.append(reduced_sliced_embed) - - out = impl.cat.cat(ctx, target, source_ir, f"{name}_cat", reduced_embed, 0) - # out = reduce_op(input_val=embed, dim=1, keepdim=False) # Note: This implementation doesn't work for N-dim - - return out, None, None, None + # this branch has less time complexity + return embedding_bag_with_traversable_offsets( + ctx, target, source_ir, name, embed, offsets, mode, include_last_offset + ) diff --git a/pyproject.toml b/pyproject.toml index 5c42700ef8..62d2ad4bf6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,8 +8,8 @@ requires = [ "cffi>=1.15.1", "typing-extensions>=4.7.0", "future>=0.18.3", - "tensorrt>=8.6,<8.7", "torch >=2.3.0.dev,<2.4.0", + "tensorrt==9.2.0.post12.dev5", "pybind11==2.6.2", "numpy", ] @@ -42,7 +42,7 @@ requires-python = ">=3.8" keywords = ["pytorch", "torch", "tensorrt", "trt", "ai", "artificial intelligence", "ml", "machine learning", "dl", "deep learning", "compiler", "dynamo", "torchscript", "inference"] dependencies = [ "torch >=2.3.0.dev,<2.4.0", - "tensorrt>=8.6,<8.7", + "tensorrt==9.2.0.post12.dev5", "packaging>=23", "numpy", "typing-extensions>=4.7.0", diff --git a/tests/py/dynamo/conversion/test_embedding_aten.py b/tests/py/dynamo/conversion/test_embedding_aten.py index 0ce4c5b49b..c04d89ff9e 100644 --- a/tests/py/dynamo/conversion/test_embedding_aten.py +++ b/tests/py/dynamo/conversion/test_embedding_aten.py @@ -1,5 +1,4 @@ import torch -import torch.nn as nn from parameterized import param, parameterized from torch.testing._internal.common_utils import run_tests from torch_tensorrt import Input @@ -14,11 +13,13 @@ class TestEmbeddingConverter(DispatchTestCase): test_name="1d_indices", indices_tensor=torch.tensor([3, 1, 2], dtype=torch.int32), weights_tensor=torch.randn((5, 10), dtype=torch.float32), + sparse=False, ), param( test_name="2d_indices", indices_tensor=torch.tensor([[3, 1, 2], [4, 1, 3]], dtype=torch.int32), weights_tensor=torch.randn((5, 10), dtype=torch.float32), + sparse=True, ), param( test_name="3d_indices", @@ -26,6 +27,7 @@ class TestEmbeddingConverter(DispatchTestCase): [[[0, 1], [2, 3]], [[3, 4], [4, 0]]], dtype=torch.int32 ), weights_tensor=torch.randn((5, 10), dtype=torch.float32), + sparse=True, ), ] ) @@ -38,7 +40,7 @@ def test_embedding( max_norm=None, norm_type=2.0, scale_grad_by_freq=None, - sparse=None, + sparse=False, ): class TestEmbedding(torch.nn.Module): def forward(self, indices, weights): diff --git a/tests/py/dynamo/conversion/test_embedding_bag_aten.py b/tests/py/dynamo/conversion/test_embedding_bag_aten.py index 6d7b05f0e1..7649e17dac 100644 --- a/tests/py/dynamo/conversion/test_embedding_bag_aten.py +++ b/tests/py/dynamo/conversion/test_embedding_bag_aten.py @@ -8,36 +8,221 @@ class TestEmbeddingBagConverter(DispatchTestCase): @parameterized.expand( [ + # mode=0: sum, mode=1: mean, mode=2: max # 1D input param( test_name="1d_indices_1", - weight=torch.randn((10, 3), dtype=torch.float32), - indices=torch.tensor([1, 2, 4, 5, 4, 3], dtype=torch.int32), - offsets=torch.tensor([0, 3], dtype=torch.int32), + weight=torch.randn((10, 2), dtype=torch.float32), + indices=torch.tensor( + [1, 2, 4, 5, 4, 3, 2, 6, 8, 1, 2], dtype=torch.int32 + ), + offsets=torch.tensor([0, 2, 4], dtype=torch.int32), + scale_grad_by_freq=False, + mode=0, + sparse=True, + per_sample_weights=None, + include_last_offset=False, + padding_idx=-1, + ), + param( + test_name="1d_indices_2", + weight=torch.randn((10, 2), dtype=torch.float32), + indices=torch.tensor( + [1, 2, 4, 5, 4, 3, 2, 6, 8, 1, 2], dtype=torch.int32 + ), + offsets=torch.tensor([0, 2, 4], dtype=torch.int32), scale_grad_by_freq=False, mode=1, + sparse=True, + per_sample_weights=None, + include_last_offset=True, + padding_idx=-1, + ), + param( + test_name="1d_indices_3", + weight=torch.randn((10, 4), dtype=torch.float32), + indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32), + offsets=torch.tensor([0, 2, 8], dtype=torch.int32), + scale_grad_by_freq=False, + mode=2, sparse=False, per_sample_weights=None, + include_last_offset=False, + padding_idx=-1, + ), + param( + test_name="1d_indices_4", + weight=torch.randn((10, 4), dtype=torch.float32), + indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32), + offsets=torch.tensor([0, 2, 8], dtype=torch.int32), + scale_grad_by_freq=False, + mode=0, + sparse=False, + per_sample_weights=torch.randn((8,)), include_last_offset=True, padding_idx=-1, ), param( - test_name="1d_indices_2", - weight=torch.randn((10, 3), dtype=torch.float32), - indices=torch.tensor([1, 2, 4, 5, 4, 3], dtype=torch.int32), - offsets=torch.tensor([0, 5], dtype=torch.int32), + test_name="1d_indices_5", + weight=torch.randn((10, 4), dtype=torch.float32), + indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32), + offsets=torch.tensor([0, 5, 5], dtype=torch.int32), + scale_grad_by_freq=False, + mode=1, + sparse=False, + per_sample_weights=None, + include_last_offset=True, + padding_idx=-1, + ), + param( + test_name="1d_indices_6", + weight=torch.randn((10, 4), dtype=torch.float32), + indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32), + offsets=torch.tensor([0, 5, 5], dtype=torch.int32), + scale_grad_by_freq=False, + mode=2, + sparse=False, + per_sample_weights=None, + include_last_offset=False, + padding_idx=-1, + ), + param( + test_name="1d_indices_7", + weight=torch.randn((10, 4), dtype=torch.float32), + indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32), + offsets=torch.tensor([0, 8, 8], dtype=torch.int32), scale_grad_by_freq=False, mode=0, sparse=False, - per_sample_weights=torch.randn((6,)), + per_sample_weights=None, + include_last_offset=True, + padding_idx=-1, + ), + param( + test_name="1d_indices_8", + weight=torch.randn((10, 4), dtype=torch.float32), + indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32), + offsets=torch.tensor([0, 8, 8], dtype=torch.int32), + scale_grad_by_freq=False, + mode=1, + sparse=False, + per_sample_weights=None, + include_last_offset=False, + padding_idx=-1, + ), + ] + ) + def test_embedding_bag_with_traversable_offsets( + self, + test_name, + weight, + indices, + offsets, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, + ): + class TestEmbeddingBag(torch.nn.Module): + def forward(self, weight, indices): + return torch.ops.aten._embedding_bag.default( + weight, + indices, + offsets, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, + )[0] + + self.run_test( + TestEmbeddingBag(), + inputs=[weight, indices], + # use_dynamo_tracer=True, + enable_passes=True, + ) + + @parameterized.expand( + [ + # mode=0: sum, mode=1: mean, mode=2: max + # 1D input + param( + test_name="1d_indices_1", + weight=torch.randn((10, 2), dtype=torch.float32), + indices=torch.tensor( + [1, 2, 4, 5, 4, 3, 2, 6, 8, 1, 2], dtype=torch.int32 + ), + offsets=torch.tensor([0, 2, 4], dtype=torch.int32), + scale_grad_by_freq=False, + mode=0, + sparse=True, + per_sample_weights=None, include_last_offset=False, padding_idx=-1, ), + # TODO: BUG! outputs of tensor and ITensor not matched + # param( + # test_name="1d_indices_2", + # # weight=torch.randn((10, 2), dtype=torch.float32), + # # indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 6, 8, 1, 2], dtype=torch.int32), + # weight=torch.arange(12, dtype=torch.float32).reshape(6, 2), + # indices=torch.tensor([0,1, 2,3, 4,5], dtype=torch.int32), + # offsets=torch.tensor([0, 2, 4], dtype=torch.int32), + # scale_grad_by_freq=False, + # mode=0, + # sparse=True, + # per_sample_weights=None, + # include_last_offset=True, + # padding_idx=-1, + # ), param( test_name="1d_indices_3", - weight=torch.randn((10, 3), dtype=torch.float32), + weight=torch.randn((10, 4), dtype=torch.float32), indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32), - offsets=torch.tensor([0, 2, 4], dtype=torch.int32), + offsets=torch.tensor([0, 2, 8], dtype=torch.int32), + scale_grad_by_freq=False, + mode=2, + sparse=False, + per_sample_weights=None, + include_last_offset=False, + padding_idx=-1, + ), + param( + test_name="1d_indices_4", + weight=torch.randn((10, 4), dtype=torch.float32), + indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32), + offsets=torch.tensor([0, 2, 8], dtype=torch.int32), + scale_grad_by_freq=False, + mode=0, + sparse=False, + per_sample_weights=torch.randn((8,)), + include_last_offset=True, + padding_idx=-1, + ), + # TODO: BUG! outputs of tensor and ITensor not matched + # param( + # test_name="1d_indices_5", + # # weight=torch.randn((10, 4), dtype=torch.float32), + # # indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32), + # weight=torch.arange(12, dtype=torch.float32).reshape(6, 2), + # indices=torch.tensor([0,1,2, 3,4,5], dtype=torch.int32), + # offsets=torch.tensor([0, 3, 3], dtype=torch.int32), + # scale_grad_by_freq=False, + # mode=1, + # sparse=False, + # per_sample_weights=None, + # include_last_offset=True, + # padding_idx=-1, + # ), + param( + test_name="1d_indices_6", + weight=torch.randn((10, 4), dtype=torch.float32), + indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32), + offsets=torch.tensor([0, 5, 5], dtype=torch.int32), scale_grad_by_freq=False, mode=2, sparse=False, @@ -45,6 +230,30 @@ class TestEmbeddingBagConverter(DispatchTestCase): include_last_offset=False, padding_idx=-1, ), + param( + test_name="1d_indices_7", + weight=torch.randn((10, 4), dtype=torch.float32), + indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32), + offsets=torch.tensor([0, 8, 8], dtype=torch.int32), + scale_grad_by_freq=False, + mode=0, + sparse=False, + per_sample_weights=None, + include_last_offset=True, + padding_idx=-1, + ), + param( + test_name="1d_indices_8", + weight=torch.randn((10, 4), dtype=torch.float32), + indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32), + offsets=torch.tensor([0, 8, 8], dtype=torch.int32), + scale_grad_by_freq=False, + mode=1, + sparse=False, + per_sample_weights=None, + include_last_offset=False, + padding_idx=-1, + ), # 2D input # param( # test_name="2d_indices_1", @@ -103,7 +312,7 @@ class TestEmbeddingBagConverter(DispatchTestCase): # ), ] ) - def test_embedding_bag( + def test_embedding_bag_with_ITensor_offsets( self, test_name, weight, @@ -117,7 +326,7 @@ def test_embedding_bag( padding_idx, ): class TestEmbeddingBag(torch.nn.Module): - def forward(self, weight, indices): + def forward(self, weight, indices, offsets): return torch.ops.aten._embedding_bag.default( weight, indices, @@ -132,7 +341,8 @@ def forward(self, weight, indices): self.run_test( TestEmbeddingBag(), - inputs=[weight, indices], + inputs=[weight, indices, offsets], + # use_dynamo_tracer=True, enable_passes=True, ) diff --git a/tests/py/ts/models/hw_compat.ts b/tests/py/ts/models/hw_compat.ts index ab43e5e040..63b8ec6325 100644 Binary files a/tests/py/ts/models/hw_compat.ts and b/tests/py/ts/models/hw_compat.ts differ diff --git a/third_party/cudnn/local/BUILD b/third_party/cudnn/local/BUILD index d83ac2ec16..ec9494f1ce 100644 --- a/third_party/cudnn/local/BUILD +++ b/third_party/cudnn/local/BUILD @@ -28,8 +28,8 @@ config_setting( cc_library( name = "cudnn_headers", - hdrs = ["include/cudnn.h"] + glob([ - "include/cudnn_*.h", + hdrs = glob([ + "include/cudnn*.h", ]), includes = ["include/"], visibility = ["//visibility:private"], diff --git a/third_party/tensorrt/local/BUILD b/third_party/tensorrt/local/BUILD index 9cbe98a41e..e6bcbe70c4 100644 --- a/third_party/tensorrt/local/BUILD +++ b/third_party/tensorrt/local/BUILD @@ -40,9 +40,7 @@ cc_library( "include/aarch64-linux-gnu/NvInferPluginUtils.h", ], ), - ":ci_rhel_x86_64_linux": [ - "include/NvUtils.h", - ] + glob( + ":ci_rhel_x86_64_linux": glob( [ "include/NvInfer*.h", ], diff --git a/toolchains/legacy/pyproject.toml b/toolchains/legacy/pyproject.toml index ce9e6423cb..5606eaf214 100644 --- a/toolchains/legacy/pyproject.toml +++ b/toolchains/legacy/pyproject.toml @@ -8,7 +8,7 @@ requires = [ "cffi>=1.15.1", "typing-extensions>=4.7.0", "future>=0.18.3", - "tensorrt>=8.6,<8.7", + "tensorrt==9.2.0.post12.dev5", "torch>=1.13.0,<2.0", "pybind11==2.6.2", "numpy", @@ -42,7 +42,7 @@ requires-python = ">=3.8" keywords = ["pytorch", "torch", "tensorrt", "trt", "ai", "artificial intelligence", "ml", "machine learning", "dl", "deep learning", "compiler", "dynamo", "torchscript", "inference"] dependencies = [ "torch>=1.13.0,<2.0", - "tensorrt>=8.6,<8.7", + "tensorrt==9.2.0.post12.dev5", "packaging>=23", "numpy", "typing-extensions>=4.7.0",