Skip to content

Commit 0cf73e8

Browse files
sutaakar and openshift-merge-bot[bot]
authored and committed
Download MNIST dataset from specific location
1 parent 89efe67 commit 0cf73e8

File tree

6 files changed

+61
-26
lines changed

6 files changed

+61
-26
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@ go 1.20
55
require (
66
github.com/onsi/gomega v1.27.10
77
github.com/openshift/api v0.0.0-20230213134911-7ba313770556
8-
github.com/project-codeflare/codeflare-common v0.0.0-20231129165224-988ba1da9069
8+
github.com/project-codeflare/codeflare-common v0.0.0-20240111082724-8f0684651717
99
github.com/project-codeflare/instascale v0.4.0
1010
github.com/project-codeflare/multi-cluster-app-dispatcher v1.39.0
1111
github.com/ray-project/kuberay/ray-operator v1.0.0

go.sum

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -386,8 +386,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
386386
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
387387
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
388388
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
389-
github.com/project-codeflare/codeflare-common v0.0.0-20231129165224-988ba1da9069 h1:81+ma1mchF/LtAGsf+poAt50kJ/fLYjoTAcZOxci1Yc=
390-
github.com/project-codeflare/codeflare-common v0.0.0-20231129165224-988ba1da9069/go.mod h1:zdi2GCYJX+QyxFWyCLMoTme3NMz/aucWDJWMqKfigxk=
389+
github.com/project-codeflare/codeflare-common v0.0.0-20240111082724-8f0684651717 h1:knUKEKvfEzVuSwQ4NAe2+I/Oxo4WztU5rYR8d/F66Lw=
390+
github.com/project-codeflare/codeflare-common v0.0.0-20240111082724-8f0684651717/go.mod h1:2Ck9LC+6Xi4jTDSlCJoP00tCzSrxek0roLsjvUgL2gY=
391391
github.com/project-codeflare/instascale v0.4.0 h1:l/cb+x4FrJ2bN9wXjv1mCngy77tVw0CLMiqJovTAflo=
392392
github.com/project-codeflare/instascale v0.4.0/go.mod h1:CpduFXKeuqYW4Ph1CPOJV6dpAdpebOxhbU4CmccZWSo=
393393
github.com/project-codeflare/multi-cluster-app-dispatcher v1.39.0 h1:zoS7pEAWK6eGELPCIIHB3W8Zb/a27Rf55ChYso7EV3o=

test/e2e/mnist.py

Lines changed: 31 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
import os
1616

1717
import torch
18+
import requests
1819
from pytorch_lightning import LightningModule, Trainer
1920
from pytorch_lightning.callbacks.progress import TQDMProgressBar
2021
from torch import nn
@@ -32,6 +33,8 @@
3233
print("MASTER_ADDR: is ", os.getenv("MASTER_ADDR"))
3334
print("MASTER_PORT: is ", os.getenv("MASTER_PORT"))
3435

36+
print("MNIST_DATASET_URL: is ", os.getenv("MNIST_DATASET_URL"))
37+
MNIST_DATASET_URL = os.getenv("MNIST_DATASET_URL")
3538

3639
class LitMNIST(LightningModule):
3740
def __init__(self, data_dir=PATH_DATASETS, hidden_size=64, learning_rate=2e-4):
@@ -110,8 +113,34 @@ def configure_optimizers(self):
110113
####################
111114

112115
def prepare_data(self):
113-
# download
114-
print("Downloading MNIST dataset...")
116+
datasetFiles = [
117+
"t10k-images-idx3-ubyte",
118+
"t10k-labels-idx1-ubyte",
119+
"train-images-idx3-ubyte",
120+
"train-labels-idx1-ubyte"
121+
]
122+
123+
# Create required folder structure
124+
downloadLocation = os.path.join(PATH_DATASETS, "MNIST", "raw")
125+
os.makedirs(downloadLocation, exist_ok=True)
126+
print(f"{downloadLocation} folder_path created!")
127+
128+
for file in datasetFiles:
129+
print(f"Downloading MNIST dataset {file}... to path : {downloadLocation}")
130+
response = requests.get(f"{MNIST_DATASET_URL}{file}", stream=True)
131+
filePath = os.path.join(downloadLocation, file)
132+
133+
#to download dataset file
134+
try:
135+
if response.status_code == 200:
136+
open(filePath, 'wb').write(response.content)
137+
print(f"{file}: Downloaded and saved zipped file to path - {filePath}")
138+
else:
139+
print(f"Failed to download file {file}")
140+
except Exception as e:
141+
print(e)
142+
print(f"Downloaded MNIST dataset to... {downloadLocation}")
143+
115144
MNIST(self.data_dir, train=True, download=True)
116145
MNIST(self.data_dir, train=False, download=True)
117146

test/e2e/mnist_pytorch_mcad_job_test.go

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -79,7 +79,8 @@ func TestMNISTPyTorchMCAD(t *testing.T) {
7979
Name: "job",
8080
Image: GetPyTorchImage(),
8181
Env: []corev1.EnvVar{
82-
corev1.EnvVar{Name: "PYTHONUSERBASE", Value: "/workdir"},
82+
{Name: "PYTHONUSERBASE", Value: "/workdir"},
83+
{Name: "MNIST_DATASET_URL", Value: GetMnistDatasetURL()},
8384
},
8485
Command: []string{"/bin/sh", "-c", "pip install -r /test/requirements.txt && torchrun /test/mnist.py"},
8586
VolumeMounts: []corev1.VolumeMount{

test/e2e/mnist_rayjob_mcad_raycluster_test.go

Lines changed: 23 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,6 @@ limitations under the License.
1717
package e2e
1818

1919
import (
20-
"encoding/base64"
2120
"testing"
2221

2322
. "github.com/onsi/gomega"
@@ -143,13 +142,6 @@ func TestMNISTRayJobMCADRayCluster(t *testing.T) {
143142
RayStartParams: map[string]string{},
144143
Template: corev1.PodTemplateSpec{
145144
Spec: corev1.PodSpec{
146-
InitContainers: []corev1.Container{
147-
{
148-
Name: "init-myservice",
149-
Image: "busybox:1.28",
150-
Command: []string{"sh", "-c", "until nslookup $RAY_IP.$(cat /var/run/secrets/kubernetes.io/serviceaccount/namespace).svc.cluster.local; do echo waiting for myservice; sleep 2; done"},
151-
},
152-
},
153145
Containers: []corev1.Container{
154146
{
155147
Name: "ray-worker",
@@ -230,21 +222,29 @@ func TestMNISTRayJobMCADRayCluster(t *testing.T) {
230222
},
231223
Spec: rayv1.RayJobSpec{
232224
Entrypoint: "python /home/ray/jobs/mnist.py",
233-
RuntimeEnv: base64.StdEncoding.EncodeToString([]byte(`
234-
{
235-
"pip": [
236-
"pytorch_lightning==1.5.10",
237-
"torchmetrics==0.9.1",
238-
"torchvision==0.12.0"
239-
],
240-
"env_vars": {
241-
}
242-
}
243-
`)),
225+
RuntimeEnvYAML: `
226+
pip:
227+
- pytorch_lightning==1.5.10
228+
- torchmetrics==0.9.1
229+
- torchvision==0.12.0
230+
env_vars:
231+
MNIST_DATASET_URL: "` + GetMnistDatasetURL() + `"
232+
`,
244233
ClusterSelector: map[string]string{
245234
RayJobDefaultClusterSelectorKey: rayCluster.Name,
246235
},
247236
ShutdownAfterJobFinishes: false,
237+
SubmitterPodTemplate: &corev1.PodTemplateSpec{
238+
Spec: corev1.PodSpec{
239+
RestartPolicy: corev1.RestartPolicyNever,
240+
Containers: []corev1.Container{
241+
{
242+
Image: GetRayImage(),
243+
Name: "rayjob-submitter-pod",
244+
},
245+
},
246+
},
247+
},
248248
},
249249
}
250250
rayJob, err = test.Client().Ray().RayV1().RayJobs(namespace.Name).Create(test.Ctx(), rayJob, metav1.CreateOptions{})
@@ -256,6 +256,10 @@ func TestMNISTRayJobMCADRayCluster(t *testing.T) {
256256
test.T().Logf("Connecting to Ray cluster at: %s", rayDashboardURL.String())
257257
rayClient := NewRayClusterClient(rayDashboardURL)
258258

259+
// Wait for Ray job id to be available, this value is needed for writing logs in defer
260+
test.Eventually(RayJob(test, rayJob.Namespace, rayJob.Name), TestTimeoutShort).
261+
Should(WithTransform(RayJobId, Not(BeEmpty())))
262+
259263
// Retrieving the job logs once it has completed or timed out
260264
defer WriteRayJobAPILogs(test, rayClient, GetRayJobId(test, rayJob.Namespace, rayJob.Name))
261265

test/upgrade/olm_upgrade_test.go

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -96,7 +96,8 @@ func TestMNISTCreateAppWrapper(t *testing.T) {
9696
Name: "job",
9797
Image: GetPyTorchImage(),
9898
Env: []corev1.EnvVar{
99-
corev1.EnvVar{Name: "PYTHONUSERBASE", Value: "/workdir"},
99+
{Name: "PYTHONUSERBASE", Value: "/workdir"},
100+
{Name: "MNIST_DATASET_URL", Value: GetMnistDatasetURL()},
100101
},
101102
Command: []string{"/bin/sh", "-c", "pip install -r /test/requirements.txt && torchrun /test/mnist.py"},
102103
VolumeMounts: []corev1.VolumeMount{

0 commit comments

Comments
 (0)