Skip to content

Commit 338dda9

Browse files
committed
Apply cuda fix
1 parent ee468d6 commit 338dda9

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

test/smoke_test.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pathlib import Path
55

66
import torch
7+
import torch.nn as nn
78
import torchvision
89
from torchvision.io import read_image
910
from torchvision.models import resnet50, ResNet50_Weights
@@ -27,10 +28,9 @@ def smoke_test_torchvision_read_decode() -> None:
2728
raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}")
2829

2930
def smoke_test_compile() -> None:
30-
import torch.nn as nn
3131
model = resnet50().cuda()
3232
model = torch.compile(model)
33-
x = torch.randn(1, 3, 224, 224).cuda()
33+
x = torch.randn(1, 3, 224, 224, device="cuda")
3434
out = model(x)
3535
print(out.shape)
3636

@@ -66,12 +66,10 @@ def main() -> None:
6666
smoke_test_torchvision_resnet50_classify()
6767
if torch.cuda.is_available():
6868
smoke_test_torchvision_resnet50_classify("cuda")
69-
<<<<<<< HEAD
69+
smoke_test_compile()
7070
if torch.backends.mps.is_available():
7171
smoke_test_torchvision_resnet50_classify("mps")
72-
=======
73-
smoke_test_compile()
74-
>>>>>>> 2b8667d9a4 (Add smoke test Using a simple RN50 with torch.compile)
72+
7573

7674

7775
if __name__ == "__main__":

0 commit comments

Comments
 (0)