From 89408545f9b7acd8dd6d1d47ba52b76f0f7274a3 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Fri, 27 Oct 2023 13:00:11 -0700 Subject: [PATCH] fix type error --- py/torch_tensorrt/dynamo/conversion/impl/embedding.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/embedding.py b/py/torch_tensorrt/dynamo/conversion/impl/embedding.py index 145ec663de..ac9faf9f4d 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/embedding.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/embedding.py @@ -75,7 +75,7 @@ def embedding_bag( # TODO: support 2D inputs # indices = impl.shuffle.reshape(ctx, target, source_ir, f"{name}_reshape_indices", indices, (-1,)) - + reduce_name = "" if mode == 0: # sum reduce_op = functools.partial( impl.reduce.sum, ctx=ctx, target=target, source_ir=source_ir @@ -143,7 +143,6 @@ def embedding_bag( # however, pytorch doc says if `include_last_offset` is True, the size of offsets # is equal to the number of bags + 1. The last element is the size of the input, # or the ending index position of the last bag (sequence). - offsets[-1] = indices.shape[0] # separately reduce embeddings for different bags @@ -158,8 +157,8 @@ def embedding_bag( f"{name}_slice_embed_{i}", embed, 0, - offsets[i], - offsets[i + 1], + int(offsets[i]), + int(offsets[i + 1]), 1, ) reduced_sliced_embed = reduce_op(