Skip to content

Commit 96a2712

Browse files
committed
Refactor smoke tests to configure module included in the release
1 parent 9b31a47 commit 96a2712

File tree

2 files changed

+53
-107
lines changed

2 files changed

+53
-107
lines changed

.github/workflows/validate-nightly-binaries.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ on:
2525
- .github/workflows/validate-macos-binaries.yml
2626
- .github/workflows/validate-macos-arm64-binaries.yml
2727
- test/smoke_test/*
28-
2928
jobs:
3029
nightly:
3130
uses: ./.github/workflows/validate-binaries.yml

test/smoke_test/smoke_test.py

Lines changed: 53 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import argparse
66
import torch
77
import platform
8+
import importlib
9+
import subprocess
810

911
gpu_arch_ver = os.getenv("GPU_ARCH_VER")
1012
gpu_arch_type = os.getenv("GPU_ARCH_TYPE")
@@ -14,6 +16,21 @@
1416
SCRIPT_DIR = Path(__file__).parent
1517
NIGHTLY_ALLOWED_DELTA = 3
1618

19+
MODULES = [
20+
{
21+
"name": "torchvision",
22+
"repo": "https://github.com/pytorch/vision.git",
23+
"smoke_test": "python ./vision/test/smoke_test.py",
24+
"extension": "extension",
25+
},
26+
{
27+
"name": "torchaudio",
28+
"repo": "https://github.com/pytorch/audio.git",
29+
"smoke_test": "python ./audio/test/smoke_test/smoke_test.py --no-ffmpeg",
30+
"extension": "_extension",
31+
},
32+
]
33+
1734
def check_nightly_binaries_date(package: str) -> None:
1835
from datetime import datetime, timedelta
1936
format_dt = '%Y%m%d'
@@ -27,33 +44,16 @@ def check_nightly_binaries_date(package: str) -> None:
2744
)
2845

2946
if(package == "all"):
30-
import torchaudio
31-
import torchvision
32-
ta_str = torchaudio.__version__
33-
tv_str = torchvision.__version__
34-
date_ta_str = re.findall("dev\d+", torchaudio.__version__)
35-
date_tv_str = re.findall("dev\d+", torchvision.__version__)
36-
date_ta_delta = datetime.now() - datetime.strptime(date_ta_str[0][3:], format_dt)
37-
date_tv_delta = datetime.now() - datetime.strptime(date_tv_str[0][3:], format_dt)
38-
39-
# check that the above three lists are equal and none of them is empty
40-
if date_ta_delta.days > NIGHTLY_ALLOWED_DELTA or date_tv_delta.days > NIGHTLY_ALLOWED_DELTA:
41-
raise RuntimeError(
42-
f"Expected torchaudio, torchvision to be less then {NIGHTLY_ALLOWED_DELTA} days. But they are from {date_ta_str}, {date_tv_str} respectively"
43-
)
44-
45-
def check_cuda_version(version: str, dlibary: str):
46-
version = torch.ops.torchaudio.cuda_version()
47-
if version is not None and torch.version.cuda is not None:
48-
version_str = str(version)
49-
ta_version = f"{version_str[:-3]}.{version_str[-2]}"
50-
t_version = torch.version.cuda.split(".")
51-
t_version = f"{t_version[0]}.{t_version[1]}"
52-
if ta_version != t_version:
53-
raise RuntimeError(
54-
"Detected that PyTorch and {dlibary} were compiled with different CUDA versions. "
55-
f"PyTorch has CUDA version {t_version} whereas {dlibary} has CUDA version {ta_version}. "
56-
)
47+
for module in MODULES:
48+
imported_module = importlib.import_module(module["name"])
49+
module_version = imported_module.__version__
50+
date_m_str = re.findall("dev\d+", module_version)
51+
date_m_delta = datetime.now() - datetime.strptime(date_m_str[0][3:], format_dt)
52+
print(f"Nightly date check for {module['name']} version {module_version}")
53+
if date_m_delta.days > NIGHTLY_ALLOWED_DELTA:
54+
raise RuntimeError(
55+
f"Expected {module['name']} to be less then {NIGHTLY_ALLOWED_DELTA} days. But its {date_m_delta}"
56+
)
5757

5858
def smoke_test_cuda(package: str) -> None:
5959
if not torch.cuda.is_available() and is_cuda_system:
@@ -69,12 +69,15 @@ def smoke_test_cuda(package: str) -> None:
6969
print(f"cuDNN enabled? {torch.backends.cudnn.enabled}")
7070

7171
if(package == 'all' and is_cuda_system):
72-
import torchaudio
73-
import torchvision
74-
print(f"torchvision cuda: {torch.ops.torchvision._cuda_version()}")
75-
print(f"torchaudio cuda: {torch.ops.torchaudio.cuda_version()}")
76-
check_cuda_version(torch.ops.torchvision._cuda_version(), "TorchVision")
77-
check_cuda_version(torch.ops.torchaudio.cuda_version(), "TorchAudio")
72+
for module in MODULES:
73+
imported_module = importlib.import_module(module["name"])
74+
# TBD for vision move extension module to private so it will
75+
# be _extention. For audio add version return from the check
76+
if module["extension"] == "extension":
77+
version = imported_module.extension._check_cuda_version()
78+
print(f"{module['name']} CUDA: {version}")
79+
else:
80+
imported_module._extension._check_cuda_version()
7881

7982

8083
def smoke_test_conv2d() -> None:
@@ -97,67 +100,20 @@ def smoke_test_conv2d() -> None:
97100
out = conv(x)
98101

99102

100-
def smoke_test_torchvision() -> None:
101-
print(
102-
"Is torchvision useable?",
103-
all(
104-
x is not None
105-
for x in [torch.ops.image.decode_png, torch.ops.torchvision.roi_align]
106-
),
107-
)
108-
109-
110-
def smoke_test_torchvision_read_decode() -> None:
111-
from torchvision.io import read_image
112-
113-
img_jpg = read_image(str(SCRIPT_DIR / "assets" / "rgb_pytorch.jpg"))
114-
if img_jpg.ndim != 3 or img_jpg.numel() < 100:
115-
raise RuntimeError(f"Unexpected shape of img_jpg: {img_jpg.shape}")
116-
img_png = read_image(str(SCRIPT_DIR / "assets" / "rgb_pytorch.png"))
117-
if img_png.ndim != 3 or img_png.numel() < 100:
118-
raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}")
119-
120-
121-
def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None:
122-
from torchvision.io import read_image
123-
from torchvision.models import resnet50, ResNet50_Weights
124-
125-
img = read_image(str(SCRIPT_DIR / "assets" / "dog2.jpg")).to(device)
126-
127-
# Step 1: Initialize model with the best available weights
128-
weights = ResNet50_Weights.DEFAULT
129-
model = resnet50(weights=weights).to(device)
130-
model.eval()
131-
132-
# Step 2: Initialize the inference transforms
133-
preprocess = weights.transforms()
134-
135-
# Step 3: Apply inference preprocessing transforms
136-
batch = preprocess(img).unsqueeze(0)
137-
138-
# Step 4: Use the model and print the predicted category
139-
prediction = model(batch).squeeze(0).softmax(0)
140-
class_id = prediction.argmax().item()
141-
score = prediction[class_id].item()
142-
category_name = weights.meta["categories"][class_id]
143-
expected_category = "German shepherd"
144-
print(f"{category_name}: {100 * score:.1f}%")
145-
if category_name != expected_category:
146-
raise RuntimeError(
147-
f"Failed ResNet50 classify {category_name} Expected: {expected_category}"
148-
)
149-
150-
151-
def smoke_test_torchaudio() -> None:
152-
import torchaudio
153-
import torchaudio.compliance.kaldi # noqa: F401
154-
import torchaudio.datasets # noqa: F401
155-
import torchaudio.functional # noqa: F401
156-
import torchaudio.models # noqa: F401
157-
import torchaudio.pipelines # noqa: F401
158-
import torchaudio.sox_effects # noqa: F401
159-
import torchaudio.transforms # noqa: F401
160-
import torchaudio.utils # noqa: F401
103+
def smoke_test_modules():
104+
for module in MODULES:
105+
if module["repo"]:
106+
subprocess.check_output(f"git clone --depth 1 {module['repo']}", stderr=subprocess.STDOUT, shell=True)
107+
try:
108+
output = subprocess.check_output(
109+
module["smoke_test"], stderr=subprocess.STDOUT, shell=True,
110+
universal_newlines=True)
111+
except subprocess.CalledProcessError as exc:
112+
raise RuntimeError(
113+
f"Module {module['name']} FAIL: {exc.returncode} Output: {exc.output}"
114+
)
115+
else:
116+
print("Output: \n{}\n".format(output))
161117

162118

163119
def main() -> None:
@@ -171,25 +127,16 @@ def main() -> None:
171127
)
172128
options = parser.parse_args()
173129
print(f"torch: {torch.__version__}")
174-
175130
smoke_test_cuda(options.package)
176131
smoke_test_conv2d()
177132

133+
if options.package == "all":
134+
smoke_test_modules()
135+
178136
# only makes sense to check nightly package where dates are known
179137
if installation_str.find("nightly") != -1:
180138
check_nightly_binaries_date(options.package)
181139

182-
if options.package == "all":
183-
import torchaudio
184-
import torchvision
185-
print(f"torchvision: {torchvision.__version__}")
186-
print(f"torchaudio: {torchaudio.__version__}")
187-
smoke_test_torchaudio()
188-
smoke_test_torchvision()
189-
smoke_test_torchvision_read_decode()
190-
smoke_test_torchvision_resnet50_classify()
191-
if torch.cuda.is_available():
192-
smoke_test_torchvision_resnet50_classify("cuda")
193140

194141
if __name__ == "__main__":
195142
main()

0 commit comments

Comments
 (0)