diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index bcbbc04ca9..1705dd06db 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -231,26 +231,7 @@ def aten_ops_cat(
     )
 
 
-def embedding_param_validator(embedding_node: Node) -> bool:
-    scale_grad_by_freq = args_bounds_check(embedding_node.args, 3)
-    sparse = args_bounds_check(embedding_node.args, 4)
-
-    if scale_grad_by_freq is not None:
-        _LOGGER.debug(
-            f"Currently we don't support specifying scale gradient by word frequency, got {scale_grad_by_freq}."
-        )
-        return False
-
-    if sparse is not None:
-        _LOGGER.debug(f"Currently we don't support sparse gradient, got {sparse}.")
-        return False
-
-    return True
-
-
-@dynamo_tensorrt_converter(
-    torch.ops.aten.embedding.default, capability_validator=embedding_param_validator
-)
+@dynamo_tensorrt_converter(torch.ops.aten.embedding.default)
 def aten_ops_embedding(
     ctx: ConversionContext,
     target: Target,
@@ -265,22 +246,19 @@ def aten_ops_embedding(
         name,
         input=args[1],
         weight=args[0],
-        # args[2] is the padding index, which is useful for training only
-        scale_grad_by_freq=args_bounds_check(args, 3),
-        sparse=args_bounds_check(args, 4),
     )
 
 
 def embedding_bag_validator(node: Node) -> bool:
-    mode = args_bounds_check(node.args, 4, 0)
-    indices = node.args[1].meta.get("tensor_meta")
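+    # only the first output of _embedding_bag is converted (the others are
+    # returned as None), so the node is required to have a single user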
+    if not one_user_validator(node):
+        return False
+    meta = node.args[1].meta
+    indices = meta.get("tensor_meta")
+    if indices is None:
+        indices = meta.get("val")
     if indices is None:
         return False
-    return (
-        bool(node.args[2].op == "get_attr")
-        and (mode == 0 or mode == 1 or mode == 2)
-        and len(indices.shape) == 1
-    )
+    return len(indices.shape) == 1  # currently only 1D indices are supported
 
 
 @dynamo_tensorrt_converter(
@@ -293,7 +271,6 @@ def embedding_bag_validator(node: Node) -> bool:
     {
         0: (TRTTensor,),
         1: (TRTTensor,),
-        2: (np.ndarray, torch.Tensor),
     }
 )
 def aten_ops_embedding_bag(
@@ -311,12 +288,9 @@ def aten_ops_embedding_bag(
         weight=args[0],
         indices=args[1],
         offsets=args[2],
-        scale_grad_by_freq=args_bounds_check(args, 3, False),
         mode=args_bounds_check(args, 4, 0),
-        sparse=args_bounds_check(args, 5, False),
         per_sample_weights=args_bounds_check(args, 6, None),
         include_last_offset=args_bounds_check(args, 7, False),
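+        # scale_grad_by_freq (args[3]) and sparse (args[5]) only affect training and are not forwarded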
-        # padding index is useful for training only
     )
 
 
diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
index 8f11e7fb91..a263440128 100644
--- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py
+++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -5,6 +5,7 @@
 
 import numpy as np
 import torch
+import torch_tensorrt.dynamo.conversion.impl as impl
 from torch import SymBool, SymFloat, SymInt
 from torch.fx.node import Argument, Target
 from torch_tensorrt import _enums
@@ -530,3 +531,111 @@ def flatten_dims(
     new_shape = tuple(shape[:start_dim]) + (num_elements,) + tuple(shape[end_dim + 1 :])
 
     return new_shape
+
+
+def append(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    original_tensor: TRTTensor,
+    new_value: Union[TRTTensor, int, float, torch.Tensor, np.ndarray],
+    dim: int = 0,
+) -> TRTTensor:
+    """
+    Append a new value to the end of the original tensor along the specified dimension (default 0).
+    For example, if the original tensor is [1, 2, 3], the new value is 4, and the dim is 0,
+    the new tensor will be [1, 2, 3, 4].
+
+    Args:
+        ctx (ConversionContext): A ConversionContext containing the TensorRT network
+        target (Target): Target of calling node
+        source_ir (Optional[SourceIR]): SourceIR of calling converter
+        name (str): Name of the calling layer
+        original_tensor (TRTTensor): A TRTTensor to append the new value to
+        new_value (Union[TRTTensor, int, float, torch.Tensor, np.ndarray]): A new value to append
+        dim (int, optional): Dimension along which to append the new value. Defaults to 0.
+
+    Returns:
+        TRTTensor: A new TRTTensor that is the result of appending the new value to the original tensor
+    """
+    if isinstance(new_value, (int, float)):
+        new_value = np.array([new_value])
+    new_value = get_trt_tensor(ctx, new_value, name, original_tensor.dtype)
+
+    return impl.cat.cat(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_concat",
+        [original_tensor, new_value],
+        get_positive_dim(dim, len(original_tensor.shape)),
+    )
+
+
+def set_item(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    original_tensor: TRTTensor,
+    index: int,
+    new_value: Union[TRTTensor, int, float, torch.Tensor, np.ndarray],
+) -> TRTTensor:
+    """
+    Set a new value in the original tensor at the specified index. For example,
+    if the original tensor is [1, 2, 3], the new value is 4, and the index is 1,
+    the new tensor will be [1, 4, 3].
+    If the index is out of bounds, the new value will be appended to the end.
+
+    Args:
+        ctx (ConversionContext): A ConversionContext containing the TensorRT network
+        target (Target): Target of calling node
+        source_ir (Optional[SourceIR]): SourceIR of calling converter
+        name (str): Name of the calling layer
+        original_tensor (TRTTensor): A TRTTensor to set the new value to
+        index (int): The index at which to set the new value
+        new_value (Union[TRTTensor, int, float, torch.Tensor, np.ndarray]): A new value to set
+
+    Returns:
+        TRTTensor: A new TRTTensor that is the result of setting the new value to the original tensor
+    """
+    if isinstance(new_value, (int, float)):
+        new_value = np.array([new_value])
+    new_value = get_trt_tensor(ctx, new_value, name, original_tensor.dtype)
+
+    len_original_tensor = original_tensor.shape[0]
+    index = get_positive_dim(index, len_original_tensor)
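+    # split the tensor into [0, index) and [index + 1, len) and rebuild it as front + new_value + rear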
+
+    front_tensor = impl.slice.slice_op(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_slice_front",
+        original_tensor,
+        dim=0,
+        start=0,
+        stop=index,
+        step=1,
+    )
+    rear_tensor = impl.slice.slice_op(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_slice_rear",
+        original_tensor,
+        dim=0,
+        start=index + 1,
+        stop=len_original_tensor,
+        step=1,
+    )
+
+    ans = impl.cat.cat(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_concat",
+        [front_tensor, new_value, rear_tensor],
+        0,
+    )
+    return ans
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/embedding.py b/py/torch_tensorrt/dynamo/conversion/impl/embedding.py
index ee3354ae08..f4e98ac3ee 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/embedding.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/embedding.py
@@ -1,16 +1,23 @@
 import functools
+import time
 from typing import Optional, Sequence, Tuple, Union
 
 import numpy as np
+import tensorrt as trt
 import torch
 import torch_tensorrt.dynamo.conversion.impl as impl
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor, to_numpy
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    append,
+    cast_trt_tensor,
+    get_trt_tensor,
+    set_item,
+    to_numpy,
+)
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
-
-import tensorrt as trt
+from torch_tensorrt.fx.types import TRTTensor
 
 
 def embedding(
@@ -18,25 +25,29 @@ def embedding(
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
-    input: trt.ITensor,
-    weight: trt.ITensor,
-    scale_grad_by_freq: bool,
-    sparse: bool,
-) -> trt.ITensor:
+    input: TRTTensor,
+    weight: TRTTensor,
+) -> TRTTensor:
     indices_tensor = input
     embedding_tensor = weight
+    if isinstance(indices_tensor, torch.Tensor) and indices_tensor.dtype == torch.int64:
+        raise RuntimeError(
+            "The `embedding` op has indices_tensor dtype=int64. This is incorrect since it has to be int32 to run on TRT."
+        )
     indices_tensor = get_trt_tensor(ctx, indices_tensor, f"{name}_indices_tensor")
     embedding_tensor = get_trt_tensor(ctx, embedding_tensor, f"{name}_embedding_tensor")
     # unsupported parameters
-    # ignore padding_idx since it is meaningful for training only
+    # ignore padding_idx, scale_grad_by_freq, and sparse
+    # since they are meaningful for training only
 
-    if scale_grad_by_freq:
-        raise RuntimeError(
-            "Currently we don't support scale gradient by word frequency."
-        )
-
-    if scale_grad_by_freq:
-        raise RuntimeError(
-            "Currently we don't support scale gradient by word frequency."
-        )
-
-    if sparse:
-        raise RuntimeError("Currently we don't support sparse gradient.")
 
     # Implement embedding lookup with gather layer
     gather_layer = ctx.net.add_gather(embedding_tensor, indices_tensor, axis=0)
@@ -44,34 +55,16 @@ def embedding(
     return gather_layer.get_output(0)
 
 
-def embedding_bag(
+def embedding_bag_with_traversable_offsets(
     ctx: ConversionContext,
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
-    weight: trt.ITensor,
-    indices: trt.ITensor,
-    offsets: Union[torch.Tensor, np.ndarray, Sequence[int]],
-    scale_grad_by_freq: bool,
+    embed: TRTTensor,
+    offsets_list: Union[torch.Tensor, np.ndarray, Sequence[int]],
     mode: int,
-    sparse: bool,
-    per_sample_weights: Optional[trt.ITensor],
     include_last_offset: bool,
-) -> Tuple[trt.ITensor, trt.ITensor, trt.ITensor, trt.ITensor]:
-    """
-    This function is for calculating embedding bags.
-
-    In PyTorch, `offsets` is only used when input is 1D. If input is 2D of shape (B, N),
-    it will be treated as B bags (sequences) each of fixed length N, and this will return
-    B values aggregated in a way depending on the mode. `offsets` is ignored and required
-    to be None in this case.
-
-    However, according to the schema, `offsets` is required for input with any dimensions.
-    Accordingly, this function flattens N-D input to 1D and then to calculate embedding bags.
-    """
-
-    # TODO: support 2D inputs
-    # indices = impl.shuffle.reshape(ctx, target, source_ir, f"{name}_reshape_indices", indices, (-1,))
+) -> Tuple[TRTTensor, TRTTensor, TRTTensor, TRTTensor]:
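+    # offsets are build-time constants here, so each bag can be sliced out and reduced directly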
     reduce_name = ""
     if mode == 0:  # sum
         reduce_op = functools.partial(
@@ -93,6 +86,270 @@ def embedding_bag(
         )
         reduce_name = "max"
 
+    offsets: np.ndarray = to_numpy(offsets_list)
+    len_embed = embed.shape[0]
+
+    if include_last_offset:
+        # per the pytorch docs, when `include_last_offset` is True the size of offsets
+        # equals the number of bags + 1 and its last element is the size of the input
+        # (the ending index of the last bag), so overwrite it with the end index
+        offsets[-1] = len_embed
+    else:
+        # add the end index to offsets
+        offsets = np.append(offsets, len_embed)
+
+    zero_tensor = get_trt_tensor(
+        ctx, np.zeros((1, embed.shape[1]), dtype=np.float32), f"{name}_zero_tensor"
+    )
+
+    # separately reduce embeddings for different bags
+    reduced_embed_bags = []
+    len_offsets = offsets.shape[0]
+    for i in range(len_offsets - 1):
+        if offsets[i] < offsets[i + 1]:
+            sliced_embed = impl.slice.slice_op(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_slice_embed_{i}",
+                embed,
+                0,
+                int(offsets[i]),
+                int(offsets[i + 1]),
+                1,
+            )
+            reduced_one_bag = reduce_op(
+                name=f"{name}_{reduce_name}_{i}",
+                input_val=sliced_embed,
+                dim=0,
+                keepdim=True,
+            )
+            reduced_embed_bags.append(reduced_one_bag)
+        else:  # offsets[i] == offsets[i + 1]
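+            # an empty bag contributes a row of zeros, as PyTorch does for empty bags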
+            reduced_embed_bags.append(zero_tensor)
+
+    out = impl.cat.cat(ctx, target, source_ir, f"{name}_cat", reduced_embed_bags, 0)
+    return out, None, None, None
+
+
+def embedding_bag_with_ITensor_offsets(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    embed: TRTTensor,
+    offsets: TRTTensor,
+    mode: int,
+    include_last_offset: bool,
+) -> Tuple[TRTTensor, TRTTensor, TRTTensor, TRTTensor]:
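+    # offsets are only known at run time (ITensor), so bags are formed with
+    # TRT loop, select, and conditional layers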
+    len_embed = embed.shape[0]
+
+    if include_last_offset:
+        # per the pytorch docs, when `include_last_offset` is True the size of offsets
+        # equals the number of bags + 1 and its last element is the size of the input
+        # (the ending index of the last bag), so overwrite it with the end index
+        offsets = set_item(
+            ctx, target, source_ir, f"{name}_set_item", offsets, -1, len_embed
+        )
+    else:
+        # add the end index to `offsets`
+        offsets = append(ctx, target, source_ir, f"{name}_append", offsets, len_embed)
+
+    # create a placeholder tensor, whose shape is the same as an embedding
+    # if mode is 0 (sum) or 1 (mean), the placeholder tensor is filled with zeros
+    # if mode is 2 (max), the placeholder tensor is filled with negative infinity
+    placeholder_tensor = (
+        get_trt_tensor(
+            ctx,
+            np.full(embed.shape, -np.inf, dtype=np.float32),
+            f"{name}_negative_inf_tensor",
+        )
+        if mode == 2
+        else get_trt_tensor(
+            ctx, np.zeros(embed.shape, dtype=np.float32), f"{name}_zero_tensors"
+        )
+    )
+
+    # prepare some tensors for future use
+    zero_tensor = get_trt_tensor(
+        ctx, np.zeros((embed.shape[1],), dtype=np.float32), f"{name}_zero_tensor"
+    )
+    constant_0 = get_trt_tensor(ctx, 0, f"{name}_constant_tensor_0")
+    constant_1 = get_trt_tensor(ctx, 1, f"{name}_constant_tensor_1")
+
+    # Use two for loops to calculate the embedding of each bag
+    ###### Outer loop: traverse offsets ######
+    loop1 = ctx.net.add_loop()
+    trip_limit1 = ctx.net.add_constant(
+        shape=(),
+        weights=trt.Weights(np.array([offsets.shape[0] - 1], dtype=np.dtype("i"))),
+    ).get_output(0)
+    loop1.add_trip_limit(trip_limit1, trt.TripLimit.COUNT)
+
+    rec1_i_tensor = loop1.add_recurrence(constant_1)
+    set_layer_name(rec1_i_tensor, target, f"{name}_rec1_i_tensor", source_ir)
+    i_tensor = rec1_i_tensor.get_output(0)
+
+    start = ctx.net.add_gather(offsets, constant_0, 0).get_output(0)
+    rec1_start = loop1.add_recurrence(start)
+    set_layer_name(rec1_start, target, f"{name}_rec1_start", source_ir)
+    start = rec1_start.get_output(0)
+
+    end = ctx.net.add_gather(offsets, constant_1, 0).get_output(0)
+    rec1_end = loop1.add_recurrence(end)
+    set_layer_name(rec1_end, target, f"{name}_rec1_end", source_ir)
+    end = rec1_end.get_output(0)
+
+    ###### Inner loop: traverse indices ######
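+    # for every embedding row index j, emit True if j lies in [start, end), i.e. the
+    # row belongs to the current bag; the concatenated booleans form a per-row mask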
+    loop2 = ctx.net.add_loop()
+    trip_limit2 = ctx.net.add_constant(
+        shape=(), weights=trt.Weights(np.array([len_embed], dtype=np.dtype("i")))
+    ).get_output(0)
+    loop2.add_trip_limit(trip_limit2, trt.TripLimit.COUNT)
+    rec2_j_tensor = loop2.add_recurrence(constant_0)
+    set_layer_name(rec2_j_tensor, target, f"{name}_rec2_j_tensor", source_ir)
+    j_tensor = rec2_j_tensor.get_output(0)
+
+    # create a TRT Select layer
+    cond1 = impl.elementwise.ge(
+        ctx, target, source_ir, f"{name}_ge_{time.time()}", j_tensor, start
+    )
+    cond2 = impl.elementwise.lt(
+        ctx, target, source_ir, f"{name}_lt_{time.time()}", j_tensor, end
+    )
+    condition1 = impl.elementwise.logical_and(
+        ctx, target, source_ir, f"{name}_and_{time.time()}", cond1, cond2
+    )
+    next_j = impl.elementwise.add(
+        ctx, target, source_ir, f"{name}_j_tensor_add_1_{time.time()}", j_tensor, 1
+    )
+    rec2_j_tensor.set_input(1, next_j)
+    loop_out2 = loop2.add_loop_output(condition1, trt.LoopOutput.CONCATENATE)
+    loop_out2.set_input(1, trip_limit2)
+    ####### Inner loop end #######
+
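+    # keep the rows selected by the mask and replace the rest with the placeholder
+    # (zeros or -inf) so they do not affect the following reduction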
+    select_layer1 = ctx.net.add_select(
+        loop_out2.get_output(0), embed, placeholder_tensor
+    )
+    one_bag = select_layer1.get_output(0)
+
+    # reduce the one_bag along the dim=0, the result of which is an embedding of each bag
+    if mode == 0:  # sum
+        reduced_one_bag = impl.reduce.sum(
+            ctx,
+            target,
+            source_ir,
+            name=f"{name}_sum_bag{time.time()}",
+            input_val=one_bag,
+            dim=0,
+            keepdim=False,
+        )
+
+    # Since one_bag includes many zero rows, taking the mean directly would give
+    # incorrect results, so sum the bag and divide by its length (end - start) instead
+    elif mode == 1:  # mean
+        reduced_one_bag = impl.reduce.sum(
+            ctx,
+            target,
+            source_ir,
+            name=f"{name}_sum_bag{time.time()}",
+            input_val=one_bag,
+            dim=0,
+            keepdim=False,
+        )
+        diff = impl.elementwise.sub(
+            ctx, target, source_ir, f"{name}_diff_bag{time.time()}", end, start
+        )
+        reduced_one_bag = impl.elementwise.div(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_div_bag{time.time()}",
+            reduced_one_bag,
+            diff,
+        )
+
+    elif mode == 2:  # max
+        reduced_one_bag = impl.reduce.max(
+            ctx,
+            target,
+            source_ir,
+            name=f"{name}_max_bag{time.time()}",
+            input_val=one_bag,
+            dim=0,
+            keepdim=False,
+            return_indices=False,
+        )
+
+    # create a TRT conditional layer
+    conditional_layer1 = ctx.net.add_if_conditional()
+    condition2 = impl.elementwise.eq(
+        ctx, target, source_ir, f"{name}_condition2_eq_{time.time()}", start, end
+    )
+    condition2 = impl.shuffle.reshape(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_reshape_condition2_eq_{time.time()}",
+        condition2,
+        [],
+    )
+    # set the combined condition to the conditional layer
+    conditional_layer1.set_condition(condition2)
+    # if true, run this subgraph
+    true_sg = conditional_layer1.add_input(zero_tensor)
+    # if false, run this subgraph
+    false_sg = conditional_layer1.add_input(reduced_one_bag)
+
+    reduced_one_bag_layer = conditional_layer1.add_output(
+        true_sg.get_output(0), false_sg.get_output(0)
+    )
+
+    # reset the variables for the next iteration of the outer loop
+    next_i = impl.elementwise.add(
+        ctx, target, source_ir, f"{name}_i_tensor_add_1_{time.time()}", i_tensor, 1
+    )
+    rec1_i_tensor.set_input(1, next_i)
+    rec1_start.set_input(1, end)
+    rec1_end.set_input(1, ctx.net.add_gather(offsets, next_i, 0).get_output(0))
+
+    loop_out1 = loop1.add_loop_output(
+        reduced_one_bag_layer.get_output(0), trt.LoopOutput.CONCATENATE
+    )
+    loop_out1.set_input(1, trip_limit1)
+    reduced_embed_bags = loop_out1.get_output(0)
+    ####### Outer loop end #######
+    return reduced_embed_bags, None, None, None
+
+
+def embedding_bag(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    weight: TRTTensor,
+    indices: TRTTensor,
+    offsets: TRTTensor,
+    mode: int,
+    per_sample_weights: Optional[TRTTensor],  # for sum mode only
+    include_last_offset: bool,
+) -> Tuple[TRTTensor, TRTTensor, TRTTensor, TRTTensor]:
+    """
+    This function is for calculating embedding bags.
+
+    In PyTorch, `offsets` is only used when input is 1D. If input is 2D of shape (B, N),
+    it will be treated as B bags (sequences) each of fixed length N, and this will return
+    B values aggregated in a way depending on the mode. `offsets` is ignored and required
+    to be None in this case.
+
+    However, according to the schema, `offsets` is required for inputs of any dimension.
+    Accordingly, this function flattens N-D input to 1D and then calculates embedding bags.
+    """
+
+    # TODO: support 2D inputs
+    # indices = impl.shuffle.reshape(ctx, target, source_ir, f"{name}_reshape_indices", indices, (-1,))
+
     # calculate embedding
     embed = embedding(
         ctx,
@@ -101,8 +358,9 @@ def embedding_bag(
         f"{name}_embedding",
         indices,
         weight,
-        scale_grad_by_freq,
-        sparse,
+    )
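+    # the zero / -inf placeholder constants created below are fp32, so cast the
+    # embedding to fp32 to keep dtypes consistent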
+    embed = cast_trt_tensor(
+        ctx, embed, torch.float, f"{name}_cast_embed_to_fp32", target, source_ir
     )
 
     # give weights to embedding
@@ -130,43 +388,12 @@ def embedding_bag(
             per_sample_weights,
         )
 
-    offsets = to_numpy(offsets)
-
-    if include_last_offset is False:
-        # add the end index to offsets
-        offsets = np.append(offsets, indices.shape[0])
+    if isinstance(offsets, TRTTensor):
+        return embedding_bag_with_ITensor_offsets(
+            ctx, target, source_ir, name, embed, offsets, mode, include_last_offset
+        )
     else:
-        # modify the last index of offsets to the end index
-        # however, pytorch doc says if `include_last_offset` is True, the size of offsets
-        # is equal to the number of bags + 1. The last element is the size of the input,
-        # or the ending index position of the last bag (sequence).
-        offsets[-1] = indices.shape[0]  # type: ignore[index]
-
-    # separately reduce embeddings for different bags
-    reduced_embed = []
-    len_offsets = len(offsets)
-    for i in range(len_offsets - 1):
-        if offsets[i] < offsets[i + 1]:
-            sliced_embed = impl.slice.slice_op(
-                ctx,
-                target,
-                source_ir,
-                f"{name}_slice_embed_{i}",
-                embed,
-                0,
-                int(offsets[i]),
-                int(offsets[i + 1]),
-                1,
-            )
-            reduced_sliced_embed = reduce_op(
-                name=f"{name}_{reduce_name}_{i}",
-                input_val=sliced_embed,
-                dim=0,
-                keepdim=True,
-            )
-            reduced_embed.append(reduced_sliced_embed)
-
-    out = impl.cat.cat(ctx, target, source_ir, f"{name}_cat", reduced_embed, 0)
-    # out = reduce_op(input_val=embed, dim=1, keepdim=False)  # Note: This implementation doesn't work for N-dim
-
-    return out, None, None, None
+        # this branch has less time complexity
+        return embedding_bag_with_traversable_offsets(
+            ctx, target, source_ir, name, embed, offsets, mode, include_last_offset
+        )
diff --git a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
index 84d9af5939..bfb3d9545c 100644
--- a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
+++ b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
@@ -35,7 +35,6 @@
     aten.diagonal_backward,
     aten.dot,
     aten.elu_backward,
-    aten._embedding_bag,
     aten.embedding_dense_backward,
     aten.empty_like,
     aten._euclidean_dist.default,
diff --git a/tests/py/dynamo/conversion/test_embedding_aten.py b/tests/py/dynamo/conversion/test_embedding_aten.py
index 0ce4c5b49b..c04d89ff9e 100644
--- a/tests/py/dynamo/conversion/test_embedding_aten.py
+++ b/tests/py/dynamo/conversion/test_embedding_aten.py
@@ -1,5 +1,4 @@
 import torch
-import torch.nn as nn
 from parameterized import param, parameterized
 from torch.testing._internal.common_utils import run_tests
 from torch_tensorrt import Input
@@ -14,11 +13,13 @@ class TestEmbeddingConverter(DispatchTestCase):
                 test_name="1d_indices",
                 indices_tensor=torch.tensor([3, 1, 2], dtype=torch.int32),
                 weights_tensor=torch.randn((5, 10), dtype=torch.float32),
+                sparse=False,
             ),
             param(
                 test_name="2d_indices",
                 indices_tensor=torch.tensor([[3, 1, 2], [4, 1, 3]], dtype=torch.int32),
                 weights_tensor=torch.randn((5, 10), dtype=torch.float32),
+                sparse=True,
             ),
             param(
                 test_name="3d_indices",
@@ -26,6 +27,7 @@ class TestEmbeddingConverter(DispatchTestCase):
                     [[[0, 1], [2, 3]], [[3, 4], [4, 0]]], dtype=torch.int32
                 ),
                 weights_tensor=torch.randn((5, 10), dtype=torch.float32),
+                sparse=True,
             ),
         ]
     )
@@ -38,7 +40,7 @@ def test_embedding(
         max_norm=None,
         norm_type=2.0,
         scale_grad_by_freq=None,
-        sparse=None,
+        sparse=False,
     ):
         class TestEmbedding(torch.nn.Module):
             def forward(self, indices, weights):
diff --git a/tests/py/dynamo/conversion/test_embedding_bag_aten.py b/tests/py/dynamo/conversion/test_embedding_bag_aten.py
index 6d7b05f0e1..2154937b43 100644
--- a/tests/py/dynamo/conversion/test_embedding_bag_aten.py
+++ b/tests/py/dynamo/conversion/test_embedding_bag_aten.py
@@ -8,12 +8,65 @@
 class TestEmbeddingBagConverter(DispatchTestCase):
     @parameterized.expand(
         [
+            # mode=0: sum, mode=1: mean, mode=2: max
             # 1D input
             param(
                 test_name="1d_indices_1",
-                weight=torch.randn((10, 3), dtype=torch.float32),
-                indices=torch.tensor([1, 2, 4, 5, 4, 3], dtype=torch.int32),
-                offsets=torch.tensor([0, 3], dtype=torch.int32),
+                weight=torch.randn((10, 2), dtype=torch.float16),
+                indices=torch.tensor(
+                    [1, 2, 4, 5, 4, 3, 2, 6, 8, 1, 2], dtype=torch.int32
+                ),
+                offsets=torch.tensor([0, 2, 4], dtype=torch.int32),
+                scale_grad_by_freq=False,
+                mode=0,
+                sparse=True,
+                per_sample_weights=None,
+                include_last_offset=False,
+                padding_idx=-1,
+            ),
+            param(
+                test_name="1d_indices_2",
+                weight=torch.randn((10, 2), dtype=torch.float16),
+                indices=torch.tensor(
+                    [1, 2, 4, 5, 4, 3, 2, 6, 8, 1, 2], dtype=torch.int32
+                ),
+                offsets=torch.tensor([0, 2, 4], dtype=torch.int32),
+                scale_grad_by_freq=False,
+                mode=1,
+                sparse=True,
+                per_sample_weights=None,
+                include_last_offset=True,
+                padding_idx=-1,
+            ),
+            param(
+                test_name="1d_indices_3",
+                weight=torch.randn((10, 4), dtype=torch.float16),
+                indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
+                offsets=torch.tensor([0, 2, 8], dtype=torch.int32),
+                scale_grad_by_freq=False,
+                mode=2,
+                sparse=False,
+                per_sample_weights=None,
+                include_last_offset=False,
+                padding_idx=-1,
+            ),
+            param(
+                test_name="1d_indices_4",
+                weight=torch.randn((10, 4), dtype=torch.float16),
+                indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
+                offsets=torch.tensor([0, 2, 8], dtype=torch.int32),
+                scale_grad_by_freq=False,
+                mode=0,
+                sparse=False,
+                per_sample_weights=torch.randn((8,), dtype=torch.float16),
+                include_last_offset=True,
+                padding_idx=-1,
+            ),
+            param(
+                test_name="1d_indices_5",
+                weight=torch.randn((10, 4), dtype=torch.float32),
+                indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
+                offsets=torch.tensor([0, 5, 5], dtype=torch.int32),
                 scale_grad_by_freq=False,
                 mode=1,
                 sparse=False,
@@ -22,22 +75,150 @@ class TestEmbeddingBagConverter(DispatchTestCase):
                 padding_idx=-1,
             ),
             param(
-                test_name="1d_indices_2",
-                weight=torch.randn((10, 3), dtype=torch.float32),
-                indices=torch.tensor([1, 2, 4, 5, 4, 3], dtype=torch.int32),
-                offsets=torch.tensor([0, 5], dtype=torch.int32),
+                test_name="1d_indices_6",
+                weight=torch.randn((10, 4), dtype=torch.float32),
+                indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
+                offsets=torch.tensor([0, 5, 5], dtype=torch.int32),
+                scale_grad_by_freq=False,
+                mode=2,
+                sparse=False,
+                per_sample_weights=None,
+                include_last_offset=False,
+                padding_idx=-1,
+            ),
+            param(
+                test_name="1d_indices_7",
+                weight=torch.randn((10, 4), dtype=torch.float32),
+                indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
+                offsets=torch.tensor([0, 8, 8], dtype=torch.int32),
                 scale_grad_by_freq=False,
                 mode=0,
                 sparse=False,
-                per_sample_weights=torch.randn((6,)),
+                per_sample_weights=None,
+                include_last_offset=True,
+                padding_idx=-1,
+            ),
+            param(
+                test_name="1d_indices_8",
+                weight=torch.randn((10, 4), dtype=torch.float32),
+                indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
+                offsets=torch.tensor([0, 8, 8], dtype=torch.int32),
+                scale_grad_by_freq=False,
+                mode=1,
+                sparse=False,
+                per_sample_weights=None,
+                include_last_offset=False,
+                padding_idx=-1,
+            ),
+        ]
+    )
+    def test_embedding_bag_with_traversable_offsets(
+        self,
+        test_name,
+        weight,
+        indices,
+        offsets,
+        scale_grad_by_freq,
+        mode,
+        sparse,
+        per_sample_weights,
+        include_last_offset,
+        padding_idx,
+    ):
+        class TestEmbeddingBag(torch.nn.Module):
+            def forward(self, weight, indices):
+                return torch.ops.aten._embedding_bag.default(
+                    weight,
+                    indices,
+                    offsets,
+                    scale_grad_by_freq,
+                    mode,
+                    sparse,
+                    per_sample_weights,
+                    include_last_offset,
+                    padding_idx,
+                )[0]
+
+        self.run_test(
+            TestEmbeddingBag(),
+            inputs=[weight, indices],
+            precision=weight.dtype,
+            enable_passes=True,
+        )
+
+    @parameterized.expand(
+        [
+            # mode=0: sum, mode=1: mean, mode=2: max
+            # 1D input
+            param(
+                test_name="1d_indices_1",
+                weight=torch.randn((10, 2), dtype=torch.float32),
+                indices=torch.tensor(
+                    [1, 2, 4, 5, 4, 3, 2, 6, 8, 1, 2], dtype=torch.int32
+                ),
+                offsets=torch.tensor([0, 2, 4], dtype=torch.int32),
+                scale_grad_by_freq=False,
+                mode=0,
+                sparse=True,
+                per_sample_weights=None,
                 include_last_offset=False,
                 padding_idx=-1,
             ),
+            param(
+                test_name="1d_indices_2",
+                weight=torch.randn((10, 2), dtype=torch.float32),
+                indices=torch.tensor(
+                    [1, 2, 4, 5, 4, 3, 2, 6, 8, 1, 2], dtype=torch.int32
+                ),
+                offsets=torch.tensor([0, 2, 4], dtype=torch.int32),
+                scale_grad_by_freq=False,
+                mode=1,
+                sparse=True,
+                per_sample_weights=None,
+                include_last_offset=True,
+                padding_idx=-1,
+            ),
             param(
                 test_name="1d_indices_3",
-                weight=torch.randn((10, 3), dtype=torch.float32),
+                weight=torch.randn((10, 4), dtype=torch.float32),
                 indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
-                offsets=torch.tensor([0, 2, 4], dtype=torch.int32),
+                offsets=torch.tensor([0, 2, 8], dtype=torch.int32),
+                scale_grad_by_freq=False,
+                mode=2,
+                sparse=False,
+                per_sample_weights=None,
+                include_last_offset=False,
+                padding_idx=-1,
+            ),
+            param(
+                test_name="1d_indices_4",
+                weight=torch.randn((10, 4), dtype=torch.float32),
+                indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
+                offsets=torch.tensor([0, 2, 8], dtype=torch.int32),
+                scale_grad_by_freq=False,
+                mode=0,
+                sparse=False,
+                per_sample_weights=torch.randn((8,), dtype=torch.float32),
+                include_last_offset=True,
+                padding_idx=-1,
+            ),
+            param(
+                test_name="1d_indices_5",
+                weight=torch.randn((10, 4), dtype=torch.float16),
+                indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
+                offsets=torch.tensor([0, 5, 5], dtype=torch.int32),
+                scale_grad_by_freq=False,
+                mode=1,
+                sparse=False,
+                per_sample_weights=None,
+                include_last_offset=True,
+                padding_idx=-1,
+            ),
+            param(
+                test_name="1d_indices_6",
+                weight=torch.randn((10, 4), dtype=torch.float16),
+                indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
+                offsets=torch.tensor([0, 5, 5], dtype=torch.int32),
                 scale_grad_by_freq=False,
                 mode=2,
                 sparse=False,
@@ -45,6 +226,30 @@ class TestEmbeddingBagConverter(DispatchTestCase):
                 include_last_offset=False,
                 padding_idx=-1,
             ),
+            param(
+                test_name="1d_indices_7",
+                weight=torch.randn((10, 4), dtype=torch.float16),
+                indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
+                offsets=torch.tensor([0, 8, 8], dtype=torch.int32),
+                scale_grad_by_freq=False,
+                mode=0,
+                sparse=False,
+                per_sample_weights=None,
+                include_last_offset=True,
+                padding_idx=-1,
+            ),
+            param(
+                test_name="1d_indices_8",
+                weight=torch.randn((10, 4), dtype=torch.float16),
+                indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
+                offsets=torch.tensor([0, 8, 8], dtype=torch.int32),
+                scale_grad_by_freq=False,
+                mode=1,
+                sparse=False,
+                per_sample_weights=None,
+                include_last_offset=False,
+                padding_idx=-1,
+            ),
             # 2D input
             # param(
             #     test_name="2d_indices_1",
@@ -103,7 +308,7 @@ class TestEmbeddingBagConverter(DispatchTestCase):
             # ),
         ]
     )
-    def test_embedding_bag(
+    def test_embedding_bag_with_ITensor_offsets(
         self,
         test_name,
         weight,
@@ -117,7 +322,7 @@ def test_embedding_bag(
         padding_idx,
     ):
         class TestEmbeddingBag(torch.nn.Module):
-            def forward(self, weight, indices):
+            def forward(self, weight, indices, offsets):
                 return torch.ops.aten._embedding_bag.default(
                     weight,
                     indices,
@@ -132,7 +337,71 @@ def forward(self, weight, indices):
 
         self.run_test(
             TestEmbeddingBag(),
-            inputs=[weight, indices],
+            inputs=[weight, indices, offsets],
+            precision=weight.dtype,
+            enable_passes=True,
+        )
+
+    @parameterized.expand(
+        [
+            param(
+                test_name="dynamic_offsets_1",
+                weight=torch.arange(0, 30, dtype=torch.float32).reshape(15, 2),
+                indices=torch.tensor([i for i in range(15)], dtype=torch.int32),
+                offsets=torch.tensor([0, 2], dtype=torch.int32),
+                scale_grad_by_freq=False,
+                mode=0,
+                sparse=False,
+                per_sample_weights=None,
+                include_last_offset=False,
+                padding_idx=-1,
+            ),
+        ]
+    )
+    def test_embedding_bag_with_dynamic_offsets(
+        self,
+        test_name,
+        weight,
+        indices,
+        offsets,
+        scale_grad_by_freq,
+        mode,
+        sparse,
+        per_sample_weights,
+        include_last_offset,
+        padding_idx,
+    ):
+        class TestEmbeddingBag(torch.nn.Module):
+            def forward(self, weight, indices, offsets):
+                offsets_list = []
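+                # build offsets with in-graph ops of data-dependent length so the
+                # converter receives ITensor offsets rather than constants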
+                end = torch.randint(8, 14, (1,))[0]
+                for i in range(3, 0, -1):
+                    rand_tensor = torch.arange(5, end, step=i, dtype=torch.int32)
+                    offsets_list.append(
+                        torch.ops.aten.cat.default((offsets, rand_tensor))
+                    )
+
+                res = []
+                for one_offsets in offsets_list:
+                    output = torch.ops.aten._embedding_bag.default(
+                        weight,
+                        indices,
+                        one_offsets,
+                        scale_grad_by_freq,
+                        mode,
+                        sparse,
+                        per_sample_weights,
+                        include_last_offset,
+                        padding_idx,
+                    )[0]
+                    res.append(output)
+
+                return res
+
+        self.run_test(
+            TestEmbeddingBag(),
+            inputs=[weight, indices, offsets],
+            precision=weight.dtype,
             enable_passes=True,
         )