1
1
"""Run smoke tests"""
2
2
3
- import os
4
3
import sys
5
4
from pathlib import Path
6
5
7
6
import torch
8
- import torch .nn as nn
9
7
import torchvision
10
- from torchvision .io import read_image
8
+ from torchvision .io import decode_jpeg , read_file , read_image
11
9
from torchvision .models import resnet50 , ResNet50_Weights
12
10
13
11
SCRIPT_DIR = Path (__file__ ).parent
@@ -22,13 +20,20 @@ def smoke_test_torchvision() -> None:
22
20
23
21
def smoke_test_torchvision_read_decode () -> None :
24
22
img_jpg = read_image (str (SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg" ))
25
- if img_jpg .ndim != 3 or img_jpg . numel () < 100 :
23
+ if img_jpg .shape != ( 3 , 606 , 517 ) :
26
24
raise RuntimeError (f"Unexpected shape of img_jpg: { img_jpg .shape } " )
27
25
img_png = read_image (str (SCRIPT_DIR / "assets" / "interlaced_png" / "wizard_low.png" ))
28
- if img_png .ndim != 3 or img_png . numel () < 100 :
26
+ if img_png .shape != ( 4 , 471 , 354 ) :
29
27
raise RuntimeError (f"Unexpected shape of img_png: { img_png .shape } " )
30
28
31
29
30
+ def smoke_test_torchvision_decode_jpeg_cuda ():
31
+ img_jpg_data = read_file (str (SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg" ))
32
+ img_jpg = decode_jpeg (img_jpg_data , device = "cuda" )
33
+ if img_jpg .shape != (3 , 606 , 517 ):
34
+ raise RuntimeError (f"Unexpected shape of img_jpg: { img_jpg .shape } " )
35
+
36
+
32
37
def smoke_test_compile () -> None :
33
38
try :
34
39
model = resnet50 ().cuda ()
@@ -77,6 +82,7 @@ def main() -> None:
77
82
smoke_test_torchvision_read_decode ()
78
83
smoke_test_torchvision_resnet50_classify ()
79
84
if torch .cuda .is_available ():
85
+ smoke_test_torchvision_decode_jpeg_cuda ()
80
86
smoke_test_torchvision_resnet50_classify ("cuda" )
81
87
smoke_test_compile ()
82
88
0 commit comments