Skip to content

Commit 0a62d46

Browse files
committed
Download MNIST dataset from specific location
1 parent 89efe67 commit 0a62d46

File tree

4 files changed

+58
-23
lines changed

4 files changed

+58
-23
lines changed

test/e2e/mnist.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import os
1616

1717
import torch
18+
import requests
1819
from pytorch_lightning import LightningModule, Trainer
1920
from pytorch_lightning.callbacks.progress import TQDMProgressBar
2021
from torch import nn
@@ -32,6 +33,8 @@
3233
print("MASTER_ADDR: is ", os.getenv("MASTER_ADDR"))
3334
print("MASTER_PORT: is ", os.getenv("MASTER_PORT"))
3435

36+
print("MNIST_DATASET_URL: is ", os.getenv("MNIST_DATASET_URL"))
37+
MNIST_DATASET_URL = os.getenv("MNIST_DATASET_URL")
3538

3639
class LitMNIST(LightningModule):
3740
def __init__(self, data_dir=PATH_DATASETS, hidden_size=64, learning_rate=2e-4):
@@ -110,8 +113,34 @@ def configure_optimizers(self):
110113
####################
111114

112115
def prepare_data(self):
113-
# download
114-
print("Downloading MNIST dataset...")
116+
datasetFiles = [
117+
"t10k-images-idx3-ubyte",
118+
"t10k-labels-idx1-ubyte",
119+
"train-images-idx3-ubyte",
120+
"train-labels-idx1-ubyte"
121+
]
122+
123+
# Create required folder structure
124+
downloadLocation = os.path.join(PATH_DATASETS, "MNIST", "raw")
125+
os.makedirs(downloadLocation, exist_ok=True)
126+
print(f"{downloadLocation} folder_path created!")
127+
128+
for file in datasetFiles:
129+
print(f"Downloading MNIST dataset {file}... to path : {downloadLocation}")
130+
response = requests.get(f"{MNIST_DATASET_URL}{file}", stream=True)
131+
filePath = os.path.join(downloadLocation, file)
132+
133+
#to download dataset file
134+
try:
135+
if response.status_code == 200:
136+
open(filePath, 'wb').write(response.content)
137+
print(f"{file}: Downloaded and saved zipped file to path - {filePath}")
138+
else:
139+
print(f"Failed to download file {file}")
140+
except Exception as e:
141+
print(e)
142+
print(f"Downloaded MNIST dataset to... {downloadLocation}")
143+
115144
MNIST(self.data_dir, train=True, download=True)
116145
MNIST(self.data_dir, train=False, download=True)
117146

test/e2e/mnist_pytorch_mcad_job_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ func TestMNISTPyTorchMCAD(t *testing.T) {
7979
Name: "job",
8080
Image: GetPyTorchImage(),
8181
Env: []corev1.EnvVar{
82-
corev1.EnvVar{Name: "PYTHONUSERBASE", Value: "/workdir"},
82+
{Name: "PYTHONUSERBASE", Value: "/workdir"},
83+
{Name: "MNIST_DATASET_URL", Value: GetMnistDatasetURL()},
8384
},
8485
Command: []string{"/bin/sh", "-c", "pip install -r /test/requirements.txt && torchrun /test/mnist.py"},
8586
VolumeMounts: []corev1.VolumeMount{

test/e2e/mnist_rayjob_mcad_raycluster_test.go

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ limitations under the License.
1717
package e2e
1818

1919
import (
20-
"encoding/base64"
2120
"testing"
2221

2322
. "github.com/onsi/gomega"
@@ -143,13 +142,6 @@ func TestMNISTRayJobMCADRayCluster(t *testing.T) {
143142
RayStartParams: map[string]string{},
144143
Template: corev1.PodTemplateSpec{
145144
Spec: corev1.PodSpec{
146-
InitContainers: []corev1.Container{
147-
{
148-
Name: "init-myservice",
149-
Image: "busybox:1.28",
150-
Command: []string{"sh", "-c", "until nslookup $RAY_IP.$(cat /var/run/secrets/kubernetes.io/serviceaccount/namespace).svc.cluster.local; do echo waiting for myservice; sleep 2; done"},
151-
},
152-
},
153145
Containers: []corev1.Container{
154146
{
155147
Name: "ray-worker",
@@ -230,21 +222,29 @@ func TestMNISTRayJobMCADRayCluster(t *testing.T) {
230222
},
231223
Spec: rayv1.RayJobSpec{
232224
Entrypoint: "python /home/ray/jobs/mnist.py",
233-
RuntimeEnv: base64.StdEncoding.EncodeToString([]byte(`
234-
{
235-
"pip": [
236-
"pytorch_lightning==1.5.10",
237-
"torchmetrics==0.9.1",
238-
"torchvision==0.12.0"
239-
],
240-
"env_vars": {
241-
}
242-
}
243-
`)),
225+
RuntimeEnvYAML: `
226+
pip:
227+
- pytorch_lightning==1.5.10
228+
- torchmetrics==0.9.1
229+
- torchvision==0.12.0
230+
env_vars:
231+
MNIST_DATASET_URL: "` + GetMnistDatasetURL() + `"
232+
`,
244233
ClusterSelector: map[string]string{
245234
RayJobDefaultClusterSelectorKey: rayCluster.Name,
246235
},
247236
ShutdownAfterJobFinishes: false,
237+
SubmitterPodTemplate: &corev1.PodTemplateSpec{
238+
Spec: corev1.PodSpec{
239+
RestartPolicy: corev1.RestartPolicyNever,
240+
Containers: []corev1.Container{
241+
{
242+
Image: GetRayImage(),
243+
Name: "rayjob-submitter-pod",
244+
},
245+
},
246+
},
247+
},
248248
},
249249
}
250250
rayJob, err = test.Client().Ray().RayV1().RayJobs(namespace.Name).Create(test.Ctx(), rayJob, metav1.CreateOptions{})
@@ -256,6 +256,10 @@ func TestMNISTRayJobMCADRayCluster(t *testing.T) {
256256
test.T().Logf("Connecting to Ray cluster at: %s", rayDashboardURL.String())
257257
rayClient := NewRayClusterClient(rayDashboardURL)
258258

259+
// Wait for Ray job id to be available, this value is needed for writing logs in defer
260+
test.Eventually(RayJob(test, rayJob.Namespace, rayJob.Name), TestTimeoutShort).
261+
Should(WithTransform(RayJobId, Not(BeEmpty())))
262+
259263
// Retrieving the job logs once it has completed or timed out
260264
defer WriteRayJobAPILogs(test, rayClient, GetRayJobId(test, rayJob.Namespace, rayJob.Name))
261265

test/upgrade/olm_upgrade_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@ func TestMNISTCreateAppWrapper(t *testing.T) {
9696
Name: "job",
9797
Image: GetPyTorchImage(),
9898
Env: []corev1.EnvVar{
99-
corev1.EnvVar{Name: "PYTHONUSERBASE", Value: "/workdir"},
99+
{Name: "PYTHONUSERBASE", Value: "/workdir"},
100+
{Name: "MNIST_DATASET_URL", Value: GetMnistDatasetURL()},
100101
},
101102
Command: []string{"/bin/sh", "-c", "pip install -r /test/requirements.txt && torchrun /test/mnist.py"},
102103
VolumeMounts: []corev1.VolumeMount{

0 commit comments

Comments
 (0)