Skip to content

Commit 2d40091

Browse files
committed
fix small bugs
1 parent 645513e commit 2d40091

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

py/torch_tensorrt/dynamo/conversion/impl/embedding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ def embedding_bag(
360360
weight,
361361
)
362362
embed = cast_trt_tensor(
363-
ctx, embed, torch.float, f"{name}_cast_embed_to_fp16", target, source_ir
363+
ctx, embed, torch.float, f"{name}_cast_embed_to_fp32", target, source_ir
364364
)
365365

366366
# give weights to embedding

tests/py/dynamo/conversion/harness.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,14 +92,14 @@ def run_test(
9292
):
9393
ref_outputs = [ref_outputs]
9494
for out, ref in zip(outputs, ref_outputs):
95+
ref = ref.cpu() # to_dtype test has cases with gpu output
9596
if not isinstance(ref, torch.Tensor):
9697
ref = torch.tensor([ref])
97-
ref = ref.cpu() # to_dtype test has cases with gpu output
9898
if ref.dtype == torch.int64:
9999
ref = ref.int() # convert torch.max's index output tensor to int32
100100
torch.testing.assert_close(
101101
out.cpu(),
102-
ref.cpu(),
102+
ref,
103103
rtol=rtol,
104104
atol=atol,
105105
equal_nan=True,

0 commit comments

Comments
 (0)