Skip to content

Commit b25d7c2

Browse files
committed
Fix tracing of generator in BERT model on Windows
1 parent 5fac363 commit b25d7c2

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

tests/py/dynamo/models/test_models_export.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44
import pytest
55
import timm
66
import torch
7-
import torch_tensorrt as torchtrt
87
import torchvision.models as models
98
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
109
from transformers import BertModel
1110
from transformers.utils.fx import symbolic_trace as transformers_trace
1211

12+
import torch_tensorrt as torchtrt
13+
1314
assertions = unittest.TestCase()
1415

1516

@@ -108,7 +109,9 @@ def test_efficientnet_b0(ir):
108109

109110
@pytest.mark.unit
110111
def test_bert_base_uncased(ir):
111-
model = BertModel.from_pretrained("bert-base-uncased").cuda().eval()
112+
model = (
113+
BertModel.from_pretrained("bert-base-uncased", return_dict=False).cuda().eval()
114+
)
112115
input = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")
113116
input2 = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")
114117

@@ -139,8 +142,8 @@ def test_bert_base_uncased(ir):
139142
msg=f"Number of outputs for BERT model compilation is different with Pytorch {len(model_outputs)} and TensorRT {len(trt_model_outputs)}. Please check the compilation.",
140143
)
141144

142-
for key, _ in model_outputs.items():
143-
out, trt_out = model_outputs[key], trt_model_outputs[key]
145+
for index in range(len(model_outputs)):
146+
out, trt_out = model_outputs[index], trt_model_outputs[index]
144147
cos_sim = cosine_similarity(out, trt_out)
145148
assertions.assertTrue(
146149
cos_sim > COSINE_THRESHOLD,

0 commit comments

Comments
 (0)