Skip to content

Commit d0274d0

Browse files
sutaakar authored and openshift-merge-bot[bot] committed
Prepare Pytorch MNIST test image for disconnected testing
1 parent aa9c7ff commit d0274d0

File tree

7 files changed

+249
-0
lines changed

7 files changed

+249
-0
lines changed
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Builds the MNIST job test image and pushes it to the project-codeflare image registry.

name: MNIST Job Test Image

# Runs on manual dispatch, and on pushes to main that touch the image sources.
on:
  workflow_dispatch:
  push:
    branches: [main]
    paths: ['test/pytorch_mnist_image/**']

jobs:
  push:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - name: Set Go
        uses: actions/setup-go@v3
        with:
          go-version: v1.20

      # Authenticate podman against quay.io with the repository secrets.
      - name: Login to Quay.io
        uses: redhat-actions/podman-login@v1
        with:
          username: ${{ secrets.QUAY_ID }}
          password: ${{ secrets.QUAY_TOKEN }}
          registry: quay.io

      # The Makefile push target depends on the build target, so this builds and pushes.
      - name: Image Build and Push
        run: |
          make image-mnist-job-test-push

Makefile

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,10 @@ ENVTEST_K8S_VERSION = 1.24.2
8585
# used to build the manifests.
8686
ENV ?= default
8787

88+
# Version tag and full image reference of the MNIST job test image
MNIST_JOB_TEST_VERSION ?= v0.0.2
MNIST_JOB_TEST_IMG ?= $(IMAGE_ORG_BASE)/mnist-job-test:$(MNIST_JOB_TEST_VERSION)
91+
8892
# Get the currently used golang install path (in GOPATH/bin, unless GOBIN is set)
8993
ifeq (,$(shell go env GOBIN))
9094
GOBIN=$(shell go env GOPATH)/bin
@@ -383,3 +387,11 @@ imports: openshift-goimports ## Organize imports in go files using openshift-goi
383387
.PHONY: verify-imports
384388
verify-imports: openshift-goimports ## Run import verifications.
385389
./hack/verify-imports.sh $(OPENSHIFT-GOIMPORTS)
390+
391+
.PHONY: image-mnist-job-test-build
image-mnist-job-test-build: ## Build container image with the MNIST job.
	podman build --tag $(MNIST_JOB_TEST_IMG) ./test/pytorch_mnist_image

# Pushing always rebuilds first because of the prerequisite on the build target.
.PHONY: image-mnist-job-test-push
image-mnist-job-test-push: image-mnist-job-test-build ## Push container image with the MNIST job.
	podman push $(MNIST_JOB_TEST_IMG)

test/pytorch_mnist_image/Dockerfile

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Image running the PyTorch MNIST training job used by the tests.
# The MNIST dataset is downloaded while the image is built (see the
# download_dataset.py step below) so it is already present in the image.
FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime

WORKDIR /test
COPY entrypoint.sh entrypoint.sh

# Install MNIST requirements; --no-cache-dir keeps the pip cache out of the image layer
COPY mnist_pip_requirements.txt requirements.txt
RUN pip install --no-cache-dir --requirement requirements.txt

# Prepare MNIST script
COPY mnist.py mnist.py
COPY download_dataset.py download_dataset.py
# download_dataset.py stores the dataset under PATH_DATASETS, which defaults to
# the current directory (/test here) - the same default location mnist.py reads from
RUN torchrun download_dataset.py

# Run the job as a non-root user from a separate working directory
USER 65532:65532
WORKDIR /workdir
ENTRYPOINT ["/test/entrypoint.sh"]
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright 2022 IBM, Red Hat
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
17+
from torchvision.datasets import MNIST
18+
19+
# Root directory the dataset is stored under; defaults to the current directory.
PATH_DATASETS = os.getenv("PATH_DATASETS", ".")

# Fetch the training split first, then the test split.
for is_train_split in (True, False):
    MNIST(PATH_DATASETS, train=is_train_split, download=True)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
#!/bin/sh

# Replace the shell with torchrun instead of running it as a child process, so
# that signals sent to the container's main process (e.g. SIGTERM on stop)
# reach the training process directly.
exec torchrun /test/mnist.py

test/pytorch_mnist_image/mnist.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
# Copyright 2022 IBM, Red Hat
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
17+
import torch
18+
import requests
19+
from pytorch_lightning import LightningModule, Trainer
20+
from pytorch_lightning.callbacks.progress import TQDMProgressBar
21+
from torch import nn
22+
from torch.nn import functional as F
23+
from torch.utils.data import DataLoader, random_split
24+
from torchmetrics import Accuracy
25+
from torchvision import transforms
26+
from torchvision.datasets import MNIST
27+
28+
# Locations used by the job, each overridable through an environment variable.
PATH_WORKDIR = os.getenv("PATH_WORKDIR", ".")
PATH_DATASETS = os.getenv("PATH_DATASETS", "/test")

# Use the larger batch size only when a CUDA device is available.
if torch.cuda.is_available():
    BATCH_SIZE = 256
else:
    BATCH_SIZE = 64
# %%

# Echo the rendezvous address/port found in the environment before training starts.
print("prior to running the trainer")
print("MASTER_ADDR: is ", os.getenv("MASTER_ADDR"))
print("MASTER_PORT: is ", os.getenv("MASTER_PORT"))
36+
37+
class LitMNIST(LightningModule):
    """Small fully connected MNIST classifier that also owns its data hooks.

    The module bundles the network, the optimizer and the dataset handling
    (``prepare_data`` / ``setup`` / ``*_dataloader``), so ``trainer.fit(model)``
    can be called without passing any dataloaders.

    Args:
        data_dir: Stored as ``self.data_dir``. Nothing in this class reads it;
            the datasets are always loaded from the module-level
            ``PATH_DATASETS``.
        hidden_size: Width of the two hidden linear layers.
        learning_rate: Learning rate handed to the Adam optimizer.
    """

    # Seed for the train/validation split. ``setup`` runs in every process of
    # a distributed job, so the split has to be reproducible to be identical
    # on all ranks.
    split_seed = 42

    def __init__(self, data_dir=PATH_WORKDIR, hidden_size=64, learning_rate=2e-4):

        super().__init__()

        # Set our init args as class attributes
        self.data_dir = data_dir
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate

        # Hardcode some dataset specific attributes
        self.num_classes = 10
        self.dims = (1, 28, 28)
        channels, width, height = self.dims
        # NOTE(review): the normalization constants are presumably the MNIST
        # mean / standard deviation - confirm before reusing elsewhere.
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

        # Define PyTorch model: flatten the image, two hidden layers with
        # ReLU + dropout, then one output per class.
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, self.num_classes),
        )

        # Separate metric objects so validation and test accuracy do not mix.
        self.val_accuracy = Accuracy()
        self.test_accuracy = Accuracy()

    def forward(self, x):
        """Return per-class log-probabilities (log-softmax over dim 1) for ``x``."""
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        """Return the negative log-likelihood loss of one ``(x, y)`` batch."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` and ``val_acc`` for one ``(x, y)`` batch."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.val_accuracy.update(preds, y)

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.val_accuracy, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Log ``test_loss`` and ``test_acc`` for one ``(x, y)`` batch."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.test_accuracy.update(preds, y)

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", self.test_accuracy, prog_bar=True)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with ``self.learning_rate``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self):
        """Instantiate both MNIST splits under ``PATH_DATASETS`` with ``download=True``."""
        MNIST(PATH_DATASETS, train=True, download=True)
        MNIST(PATH_DATASETS, train=False, download=True)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage`` ("fit", "test", or ``None`` for both)."""

        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(PATH_DATASETS, train=True, transform=self.transform)
            # Seed the split explicitly: with an unseeded split every process
            # would draw its own partition, so samples held out for validation
            # on one rank could be part of the training set on another.
            split_rng = torch.Generator().manual_seed(self.split_seed)
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [55000, 5000], generator=split_rng
            )

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(
                PATH_DATASETS, train=False, transform=self.transform
            )

    def train_dataloader(self):
        """Return a loader over the training subset created by ``setup``."""
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        """Return a loader over the validation subset created by ``setup``."""
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        """Return a loader over the test dataset created by ``setup``."""
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)
138+
139+
140+
# Build the model; it provides its own dataloaders through its data hooks.
model = LitMNIST()

# Node count and processes per node, read once from the environment
# (default 1 each) and reused for both the log output and the trainer.
group_world_size = int(os.environ.get("GROUP_WORLD_SIZE", 1))
print("GROUP: ", group_world_size)
local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
print("LOCAL: ", local_world_size)

# Initialize a trainer
progress_bar = TQDMProgressBar(refresh_rate=20)
trainer = Trainer(
    accelerator="auto",
    strategy="ddp",
    num_nodes=group_world_size,
    devices=local_world_size,
    max_epochs=5,
    callbacks=[progress_bar],
)

# Train the model ⚡
trainer.fit(model)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
pytorch_lightning==1.5.10
2+
torchmetrics==0.9.1
3+
torchvision==0.12.0

0 commit comments

Comments
 (0)