Skip to content

Commit d232200

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] improve smoke test (#7550)
Reviewed By: vmoens Differential Revision: D45522829 fbshipit-source-id: aba08b74b26aa59111f68e50e11ab3a36abe1980
1 parent 0127d8e commit d232200

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

test/smoke_test.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
"""Run smoke tests"""
22

3-
import os
43
import sys
54
from pathlib import Path
65

76
import torch
8-
import torch.nn as nn
97
import torchvision
10-
from torchvision.io import read_image
8+
from torchvision.io import decode_jpeg, read_file, read_image
119
from torchvision.models import resnet50, ResNet50_Weights
1210

1311
SCRIPT_DIR = Path(__file__).parent
@@ -22,13 +20,20 @@ def smoke_test_torchvision() -> None:
2220

2321
def smoke_test_torchvision_read_decode() -> None:
2422
img_jpg = read_image(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg"))
25-
if img_jpg.ndim != 3 or img_jpg.numel() < 100:
23+
if img_jpg.shape != (3, 606, 517):
2624
raise RuntimeError(f"Unexpected shape of img_jpg: {img_jpg.shape}")
2725
img_png = read_image(str(SCRIPT_DIR / "assets" / "interlaced_png" / "wizard_low.png"))
28-
if img_png.ndim != 3 or img_png.numel() < 100:
26+
if img_png.shape != (4, 471, 354):
2927
raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}")
3028

3129

30+
def smoke_test_torchvision_decode_jpeg_cuda():
31+
img_jpg_data = read_file(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg"))
32+
img_jpg = decode_jpeg(img_jpg_data, device="cuda")
33+
if img_jpg.shape != (3, 606, 517):
34+
raise RuntimeError(f"Unexpected shape of img_jpg: {img_jpg.shape}")
35+
36+
3237
def smoke_test_compile() -> None:
3338
try:
3439
model = resnet50().cuda()
@@ -77,6 +82,7 @@ def main() -> None:
7782
smoke_test_torchvision_read_decode()
7883
smoke_test_torchvision_resnet50_classify()
7984
if torch.cuda.is_available():
85+
smoke_test_torchvision_decode_jpeg_cuda()
8086
smoke_test_torchvision_resnet50_classify("cuda")
8187
smoke_test_compile()
8288

0 commit comments

Comments
 (0)