Adding tutorial for data loaders in multi device setups #110

Open · wants to merge 16 commits into `main`
3 changes: 3 additions & 0 deletions docs/source/conf.py
@@ -67,10 +67,12 @@
'JAX_transformer_text_classification.md',
'data_loaders_on_cpu_with_jax.md',
'data_loaders_on_gpu_with_jax.md',
'data_loaders_for_multi_device_setups_with_jax.md',
]

suppress_warnings = [
'misc.highlighting_failure', # Suppress warning in exception in digits_vae
'mystnb.unknown_mime_type', # Suppress warning for unknown mime type (e.g. colab-display-data+json)
]

# -- Options for myst ----------------------------------------------
@@ -104,4 +106,5 @@
'JAX_transformer_text_classification.ipynb',
'data_loaders_on_cpu_with_jax.ipynb',
'data_loaders_on_gpu_with_jax.ipynb',
'data_loaders_for_multi_device_setups_with_jax.ipynb',
]

Large diffs are not rendered by default.

docs/source/data_loaders_for_multi_device_setups_with_jax.md
@@ -13,40 +13,26 @@ kernelspec:

+++ {"id": "PUFGZggH49zp"}

# Introduction to Data Loaders on CPU with JAX
# Introduction to Data Loaders for Multi-Device Training with JAX

+++ {"id": "3ia4PKEV5Dr8"}

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/data_loaders_on_cpu_with_jax.ipynb)
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/data_loaders_for_multi_device_setups_with_jax.ipynb)

This tutorial explores different data loading strategies for using **JAX** on a single [**CPU**](https://jax.readthedocs.io/en/latest/glossary.html#term-CPU). While JAX doesn't include a built-in data loader, it seamlessly integrates with popular data loading libraries, including:
This tutorial explores various data loading strategies for **JAX** in **multi-device distributed** environments, leveraging [**TPUs**](https://jax.readthedocs.io/en/latest/pallas/tpu/details.html#what-is-a-tpu). While JAX doesn't include a built-in data loader, it seamlessly integrates with popular data loading libraries, including:
* [**PyTorch DataLoader**](https://github.com/pytorch/data)
* [**TensorFlow Datasets (TFDS)**](https://github.com/tensorflow/datasets)
* [**Grain**](https://github.com/google/grain)
* [**Hugging Face**](https://huggingface.co/docs/datasets/en/use_with_jax#data-loading)

- [**PyTorch DataLoader**](https://github.com/pytorch/data)
- [**TensorFlow Datasets (TFDS)**](https://github.com/tensorflow/datasets)
- [**Grain**](https://github.com/google/grain)
- [**Hugging Face**](https://huggingface.co/docs/datasets/en/use_with_jax#data-loading)
You'll learn how to use each of these libraries to efficiently load data for an image classification task using the MNIST dataset.

In this tutorial, you'll learn how to efficiently load data using these libraries for a simple image classification task based on the MNIST dataset.
Building on the [Data Loaders on GPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_gpu_with_jax.html) tutorial, this guide covers advanced strategies for multi-device setups, such as data sharding with `Mesh` and `NamedSharding` to partition and synchronize data across devices. These techniques help you manage the complexities of distributed systems while optimizing resource usage for large-scale datasets.

Compared to GPU or multi-device setups, CPU-based data loading is straightforward as it avoids challenges like GPU memory management and data synchronization across devices. This makes it ideal for smaller-scale tasks or scenarios where data resides exclusively on the CPU.
If you're looking for CPU-specific data loading advice, see [Data Loaders on CPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_cpu_with_jax.html).

If you're looking for GPU-specific data loading advice, see [Data Loaders on GPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_gpu_with_jax.html).

If you're looking for a multi-device data loading strategy, see [Data Loaders on Multi-Device Setups](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_for_multi_device_setups_with_jax.html).

+++ {"id": "pEsb135zE-Jo"}

## Setting JAX to Use CPU Only

First, you'll restrict JAX to use only the CPU, even if a GPU is available. This ensures consistency and allows you to focus on CPU-based data loading.

```{code-cell}
:id: vqP6xyObC0_9
import os
os.environ['JAX_PLATFORM_NAME'] = 'cpu'
```

+++ {"id": "-rsMgVtO6asW"}

Import the JAX API
@@ -56,19 +42,20 @@ Import JAX API
import jax
import jax.numpy as jnp
from jax import random, grad, jit, vmap
from jax import grad, jit, vmap, random, device_put
from jax.sharding import Mesh, PartitionSpec, NamedSharding
```

+++ {"id": "TsFdlkSZKp9S"}

### CPU Setup Verification
## Checking TPU Availability for JAX

```{code-cell}
---
colab:
base_uri: https://localhost:8080/
id: N3sqvaF3KJw1
outputId: 449c83d9-d050-4b15-9a8d-f71e340501f2
outputId: ee3286d0-b75f-46c5-8548-b57e3d895dd7
---
jax.devices()
```
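
If you want to confirm the runtime actually exposes multiple devices before running the sharded code below, a small check helps (a hedged sketch; the counts assume a Colab TPU v2-8 runtime):

```python
# Sketch: verify multiple devices are visible before relying on sharding.
# On a Colab TPU v2-8 this reports 8 devices; on a CPU runtime it reports 1.
available = jax.devices()
print(f"{len(available)} device(s):", sorted({d.platform for d in available}))
```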
@@ -94,18 +81,18 @@ def init_network_params(sizes, key):
return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]
layer_sizes = [784, 512, 512, 10] # Layers of the network
step_size = 0.01 # Learning rate for optimization
step_size = 0.01 # Learning rate
num_epochs = 8 # Number of training epochs
batch_size = 128 # Batch size for training
n_targets = 10 # Number of classes (digits 0-9)
num_pixels = 28 * 28 # Input size (MNIST images are 28x28 pixels)
num_pixels = 28 * 28 # Each MNIST image is 28x28 pixels
data_dir = '/tmp/mnist_dataset' # Directory for storing the dataset
# Initialize network parameters using the defined layer sizes and a random seed
params = init_network_params(layer_sizes, random.PRNGKey(0))
```

+++ {"id": "6Ci_CqW7q6XM"}
+++ {"id": "rHLdqeI7D2WZ"}

## Model Prediction with Auto-Batching

@@ -122,7 +109,7 @@ def relu(x):
return jnp.maximum(0, x)
def predict(params, image):
# per-example prediction
# per-example predictions
activations = image
for w, b in params[:-1]:
outputs = jnp.dot(w, activations) + b
@@ -136,21 +123,39 @@ def predict(params, image):
batched_predict = vmap(predict, in_axes=(None, 0))
```

+++ {"id": "niTSr34_sDZi"}
+++ {"id": "AMWmxjVEpH2D"}

## Multi-Device Setup Using a Mesh of Devices

```{code-cell}
:id: 4Jc5YLFnpE-_
# Get the number of available devices (GPUs/TPUs) for sharding
num_devices = len(jax.devices())
# Multi-device setup using a Mesh of devices
devices = jax.devices()
mesh = Mesh(devices, ('device',))
# Define the sharding specification - split the data along the first axis (batch)
sharding_spec = PartitionSpec('device')
```
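
To see how a batch will be laid out under this specification, you can place a dummy array with `device_put` and inspect the result (a quick sanity check, not part of the PR; it assumes the `mesh` and `sharding_spec` defined above and the `batch_size`/`num_pixels` constants from earlier):

```python
# Shard a dummy batch along its leading (batch) axis and inspect placement.
dummy = jnp.zeros((batch_size, num_pixels))
sharded = device_put(dummy, NamedSharding(mesh, sharding_spec))
jax.debug.visualize_array_sharding(sharded)  # one block of rows per device
print(sharded.sharding)  # NamedSharding(..., spec=PartitionSpec('device',))
```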

+++ {"id": "rLqfeORsERek"}

## Utility and Loss Functions

You'll now define utility functions for:

- One-hot encoding: Converts class indices to binary vectors.
- Accuracy calculation: Measures the performance of the model on the dataset.
- Loss computation: Calculates the difference between predictions and targets.

To optimize performance:

- [`grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html#jax.grad) is used to compute gradients of the loss function with respect to network parameters.
- [`jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) compiles the update function, enabling faster execution by leveraging JAX's [XLA](https://openxla.org/xla) compilation.
- [`device_put`](https://jax.readthedocs.io/en/latest/_autosummary/jax.device_put.html) distributes batches across TPU devices before each update.
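
The definitions themselves are collapsed in this diff. For reference, here is a minimal sketch of the standard versions used in JAX's MNIST examples (an assumption; the names match those referenced later: `one_hot`, `accuracy`, `loss`, and `update`):

```python
def one_hot(x, k, dtype=jnp.float32):
    """One-hot encode a vector of class indices x into length-k vectors."""
    return jnp.array(x[:, None] == jnp.arange(k), dtype)

def accuracy(params, images, targets):
    """Fraction of examples whose arg-max prediction matches the label."""
    target_class = jnp.argmax(targets, axis=1)
    predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
    return jnp.mean(predicted_class == target_class)

def loss(params, images, targets):
    """Cross-entropy between log-probability predictions and one-hot targets."""
    preds = batched_predict(params, images)
    return -jnp.mean(preds * targets)

@jit
def update(params, x, y):
    """One SGD step; jit-compiled so XLA fuses the whole update."""
    grads = grad(loss)(params, x, y)
    return [(w - step_size * dw, b - step_size * db)
            for (w, b), (dw, db) in zip(params, grads)]
```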

```{code-cell}
:id: sA0a06raEQfS
@@ -185,15 +190,16 @@ def reshape_and_one_hot(x, y):
return x, y
def train_model(num_epochs, params, training_generator, data_loader_type='streamed'):
"""Train the model for a given number of epochs."""
"""Train the model for a given number of epochs and device_put for TPU transfer."""
for epoch in range(num_epochs):
start_time = time.time()
for x, y in training_generator() if data_loader_type == 'streamed' else training_generator:
x, y = reshape_and_one_hot(x, y)
x, y = device_put(x, NamedSharding(mesh, sharding_spec)), device_put(y, NamedSharding(mesh, sharding_spec))
params = update(params, x, y)
print(f"Epoch {epoch + 1} in {time.time() - start_time:.2f} sec: "
f"Train Accuracy: {accuracy(params, train_images, train_labels):.4f}, "
f"Train Accuracy: {accuracy(params, train_images, train_labels):.4f},"
f"Test Accuracy: {accuracy(params, test_images, test_labels):.4f}")
```

@@ -207,8 +213,8 @@ This section shows how to load the MNIST dataset using PyTorch's DataLoader, con
---
colab:
base_uri: https://localhost:8080/
id: jmsfrWrHxIhC
outputId: 33dfeada-a763-4d26-f778-a27966e34d55
id: 33Wyf77WzNjA
outputId: a2378431-79f2-4dc4-aa1a-d98704657d26
---
!pip install torch torchvision
```
@@ -226,21 +232,34 @@ from torchvision.datasets import MNIST
:id: 6f6qU8PCc143
def numpy_collate(batch):
"""Convert a batch of PyTorch data to NumPy arrays."""
"""Collate function to convert a batch of PyTorch data into NumPy arrays."""
return tree_map(np.asarray, data.default_collate(batch))
class NumpyLoader(data.DataLoader):
"""Custom DataLoader to return NumPy arrays from a PyTorch Dataset."""
def __init__(self, dataset, batch_size=1, shuffle=False, **kwargs):
super().__init__(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=numpy_collate, **kwargs)
def __init__(self, dataset, batch_size=1,
shuffle=False, sampler=None,
batch_sampler=None, num_workers=0,
pin_memory=False, drop_last=False,
timeout=0, worker_init_fn=None):
super().__init__(dataset,
batch_size=batch_size,
shuffle=shuffle,
sampler=sampler,
batch_sampler=batch_sampler,
num_workers=num_workers,
collate_fn=numpy_collate,
pin_memory=pin_memory,
drop_last=drop_last,
timeout=timeout,
worker_init_fn=worker_init_fn)
class FlattenAndCast(object):
"""Transform class to flatten and cast images to float32."""
def __call__(self, pic):
return np.ravel(np.array(pic, dtype=jnp.float32))
```
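
Once `mnist_dataset` is created below, a quick smoke test (hypothetical, not part of the PR) confirms the loader yields NumPy arrays that JAX can consume directly:

```python
# Hypothetical smoke test: fetch one batch and check its type and shape.
loader = NumpyLoader(mnist_dataset, batch_size=4)
x, y = next(iter(loader))
print(type(x), x.shape, x.dtype)  # <class 'numpy.ndarray'> (4, 784) float32
```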

+++ {"id": "mfSnfJND6I8G"}
+++ {"id": "ec-MHhv6hYsK"}

### Load Dataset with Transformations

@@ -250,8 +269,8 @@ Standardize the data by flattening the images, casting them to `float32`, and en
---
colab:
base_uri: https://localhost:8080/
id: Kxbl6bcx6crv
outputId: 372bbf4c-3ad5-4fd8-cc5d-27b50f5e4f38
id: nSviwX9ohhUh
outputId: 0bb3bc04-11ac-4fb6-8854-76a3f5e725a5
---
mnist_dataset = MNIST(data_dir, download=True, transform=FlattenAndCast())
```
@@ -288,22 +307,24 @@ test_labels = one_hot(np.array(mnist_dataset_test.targets), n_targets)
colab:
base_uri: https://localhost:8080/
id: Oz-UVnCxG5E8
outputId: abbaa26d-491a-4e63-e8c9-d3c571f53a28
outputId: 0f44cb63-b12c-47a7-8bd5-ed773e2b2ec5
---
print('Train:', train_images.shape, train_labels.shape)
print('Test:', test_images.shape, test_labels.shape)
```

+++ {"id": "m3zfxqnMiCbm"}
+++ {"id": "mfSnfJND6I8G"}

### Training Data Generator

Define a generator function using PyTorch's DataLoader for batch training. Setting `num_workers > 0` enables multi-process data loading, which can accelerate data loading for larger datasets or intensive preprocessing tasks. Experiment with different values to find the optimal setting for your hardware and workload.
Define a generator function using PyTorch's DataLoader for batch training.
Setting `num_workers > 0` enables multi-process data loading, which can accelerate data loading for larger datasets or intensive preprocessing tasks. Experiment with different values to find the optimal setting for your hardware and workload.

Note: When setting `num_workers > 0`, you may see the following `RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.` This warning can be safely ignored since data loaders do not use JAX within the forked processes.
Note: When setting `num_workers > 0`, you may see the following `RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.`
This warning can be safely ignored since data loaders do not use JAX within the forked processes.

```{code-cell}
:id: B-fES82EiL6Z
:id: Kxbl6bcx6crv
def pytorch_training_generator(mnist_dataset):
return NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0)
@@ -319,29 +340,40 @@ The training loop uses the PyTorch DataLoader to iterate through batches and upd
---
colab:
base_uri: https://localhost:8080/
id: vtUjHsh-rJs8
outputId: 4766333e-4366-493b-995a-102778d1345a
id: MUrJxpjvUyOm
outputId: 629a19b1-acba-418a-f04b-3b78d7909de1
---
train_model(num_epochs, params, pytorch_training_generator(mnist_dataset), data_loader_type='iterable')
```

+++ {"id": "Nm45ZTo6yrf5"}
+++ {"id": "ACy1PoSVa3zH"}

## Loading Data with TensorFlow Datasets (TFDS)

This section demonstrates how to load the MNIST dataset using TFDS, fetch the full dataset for evaluation, and define a training generator for batch processing. GPU usage is explicitly disabled for TensorFlow.

+++ {"id": "tcJRzpyOveWK"}

Ensure you have the latest versions of both TensorFlow and TensorFlow Datasets:

```{code-cell}
---
colab:
base_uri: https://localhost:8080/
height: 1000
id: _f55HPGAZu6P
outputId: 838c8f76-aa07-49d5-986d-3c88ed516b22
---
!pip install --upgrade tensorflow tensorflow-datasets
```

```{code-cell}
:id: sGaQAk1DHMUx
import tensorflow_datasets as tfds
import tensorflow as tf
# Ensure CPU-only execution by hiding any GPUs from TensorFlow (if present)
tf.config.set_visible_devices([], device_type='GPU')
```
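
A hedged way to verify the setting took effect:

```python
# TensorFlow should now report no visible GPUs, so it won't compete with JAX
# for accelerator memory.
print(tf.config.get_visible_devices('GPU'))  # expected: []
```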

+++ {"id": "3xdQY7H6wr3n"}
+++ {"id": "F6OlzaDqwe4p"}

### Fetch Full Dataset for Evaluation

@@ -352,12 +384,12 @@ Load the dataset with `tfds.load`, convert it to NumPy arrays, and process it fo
colab:
base_uri: https://localhost:8080/
height: 104
referenced_widgets: [b8cdabf5c05848f38f03850cab08b56f, a8b76d5f93004c089676e5a2a9b3336c,
119ac8428f9441e7a25eb0afef2fbb2a, 76a9815e5c2b4764a13409cebaf66821, 45ce8dd5c4b949afa957ec8ffb926060,
05b7145fd62d4581b2123c7680f11cdd, b96267f014814ec5b96ad7e6165104b1, bce34bdbfbd64f1f8353a4e8515cee0b,
93b8206f8c5841a692cdce985ae301d8, c95f592620c64da595cc787567b2c4db, 8a97071f862c4ec3b4b4140d2e34eda2]
referenced_widgets: [43d95e3e6b704cb5ae941541862e35fe, fca543b71352477db00545b3990d44fa,
d3c971a3507249c9a22cad026e46d739, 6da776e94f7740b9aae06f298c1e03cd, b4aec5e3895e4a19912c74777e9ea835,
ef4dc5b756d74129bd2d643d99a1ab2e, 30243b81748e497eb526b25404e95826, 3bb9b93e595d4a0ca973ded476c0a5d0,
b770951ecace4b02ad1575fe9eb9e640, 79009c4ea2bf46b1a3a2c6558fa6ec2f, 5cb081d3a038482583350d018a768bd4]
id: 1hOamw_7C8Pb
outputId: ca166490-22db-4732-b29f-866b7593e489
outputId: 0e3805dc-1bfd-4222-9052-0b2111ea3091
---
# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1)
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
@@ -380,13 +412,13 @@ test_labels = one_hot(test_labels, n_targets)
colab:
base_uri: https://localhost:8080/
id: Td3PiLdmEf7z
outputId: 96403b0f-6079-43ce-df16-d4583f09906b
outputId: 464da4f6-f028-4667-889d-a812382739b0
---
print('Train:', train_images.shape, train_labels.shape)
print('Test:', test_images.shape, test_labels.shape)
```

+++ {"id": "UWRSaalfdyDX"}
+++ {"id": "yy9PunCJdI-G"}

### Define the Training Generator
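
The generator body is collapsed in this diff. For orientation, a minimal TFDS-based generator consistent with how `train_model` consumes it might look like this (an assumption; the PR's actual implementation may differ):

```python
def training_generator():
    # Stream one pass of shuffled MNIST batches as NumPy arrays.
    ds = tfds.load(name="mnist", split="train", as_supervised=True,
                   data_dir=data_dir)
    ds = ds.shuffle(buffer_size=1024).batch(batch_size, drop_remainder=True)
    return tfds.as_numpy(ds)  # converts tf.Tensors to NumPy for JAX
```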

@@ -414,8 +446,8 @@ Use the training generator in a custom training loop.
---
colab:
base_uri: https://localhost:8080/
id: h2sO13XDGvq1
outputId: a150246e-ceb5-46ac-db71-2a8177a9d04d
id: AsFKboVRaV6r
outputId: 9cb33f79-1b17-439d-88d3-61cd984124f6
---
train_model(num_epochs, params, training_generator)
```
@@ -435,7 +467,7 @@ Install Grain
colab:
base_uri: https://localhost:8080/
id: L78o7eeyGvn5
outputId: 76d16565-0d9e-4f5f-c6b1-4cf4a683d0e7
outputId: 8f32bb0f-9a73-48a9-dbcd-4eb93ba3f606
---
!pip install grain
```
@@ -468,6 +500,7 @@ class Dataset:
self.load_data()
def load_data(self):
# Load the MNIST dataset using torchvision
self.dataset = MNIST(self.data_dir, download=True, train=self.train)
def __len__(self):
@@ -495,12 +528,12 @@ mnist_dataset = Dataset(data_dir)
```{code-cell}
:id: f1VnTuX3u_kL
# Convert training data to JAX arrays and encode labels as one-hot vectors
train_images = jnp.array([mnist_dataset[i][0] for i in range(len(mnist_dataset))], dtype=jnp.float32)
train_labels = one_hot(np.array([mnist_dataset[i][1] for i in range(len(mnist_dataset))]), n_targets)
# Load test dataset and process it
mnist_dataset_test = MNIST(data_dir, download=True, train=False)
# Convert test images to JAX arrays and encode test labels as one-hot vectors
test_images = jnp.array([np.ravel(np.array(mnist_dataset_test[i][0], dtype=np.float32)) for i in range(len(mnist_dataset_test))], dtype=jnp.float32)
test_labels = one_hot(np.array([mnist_dataset_test[i][1] for i in range(len(mnist_dataset_test))]), n_targets)
```
@@ -510,27 +543,24 @@ test_labels = one_hot(np.array([mnist_dataset_test[i][1] for i in range(len(mnis
colab:
base_uri: https://localhost:8080/
id: a2NHlp9klrQL
outputId: 14be58c0-851e-4a44-dfcc-d02f0718dab5
outputId: cc9e0958-8484-4669-a2d1-abac36a3097f
---
print("Train:", train_images.shape, train_labels.shape)
print("Test:", test_images.shape, test_labels.shape)
```

+++ {"id": "fETnWRo2crhf"}
+++ {"id": "1QPbXt7O0JN-"}

### Initialize PyGrain DataLoader

Set up a PyGrain DataLoader for sequential batch sampling.

```{code-cell}
:id: 9RuFTcsCs2Ac
sampler = pygrain.SequentialSampler(
num_records=len(mnist_dataset),
shard_options=pygrain.NoSharding()) # Single-device, no sharding
shard_options=pygrain.ShardByJaxProcess()) # Shard the data across JAX processes
def pygrain_training_generator():
"""Grain DataLoader generator for training."""
return pygrain.DataLoader(
data_source=mnist_dataset,
sampler=sampler,
@@ -549,7 +579,7 @@ Run the training loop using the Grain DataLoader.
colab:
base_uri: https://localhost:8080/
id: cjxJRtiTadEI
outputId: 3f624366-b683-4d20-9d0a-777d345b0e21
outputId: a620e9f7-7a01-4ba8-fe16-6f988401c7c1
---
train_model(num_epochs, params, pygrain_training_generator)
```
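
The `pygrain.DataLoader` construction above is truncated by the diff. A complete call typically also specifies batching operations and a worker count (a sketch under assumed PyGrain defaults, not the PR's exact code):

```python
# Sketch: full DataLoader construction with assumed arguments.
loader = pygrain.DataLoader(
    data_source=mnist_dataset,
    sampler=sampler,
    operations=[pygrain.Batch(batch_size=batch_size, drop_remainder=True)],
    worker_count=0,  # in-process loading; increase for multi-process workers
)
```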
@@ -569,15 +599,11 @@ Install the Hugging Face `datasets` library.
colab:
base_uri: https://localhost:8080/
id: 19ipxPhI6oSN
outputId: 684e445f-d23e-4924-9e76-2c2c9359f0be
outputId: e0d52dfb-6c60-4539-a043-574d2533a744
---
!pip install datasets
```

+++ {"id": "be0h_dZv0593"}

Import Library

```{code-cell}
:id: 8v1N59p76zn0
@@ -586,68 +612,72 @@ from datasets import load_dataset

+++ {"id": "8Gaj11tO7C86"}

### Load and Format MNIST Dataset

Load the MNIST dataset from Hugging Face and format it as `numpy` arrays for quick access (pass `"jax"` instead to get JAX arrays directly).

```{code-cell}
---
colab:
base_uri: https://localhost:8080/
height: 301
referenced_widgets: [32f6132a31aa4c508d3c3c5ef70348bb, d7c2ffa6b143463c91cbf8befca6ca01,
fd964ecd3926419d92927c67f955d5d0, 60feca3fde7c4447ad8393b0542eb999, 3354a0baeca94d18bc6b2a8b8b465b58,
a0d0d052772b46deac7657ad052991a4, fb34783b9cba462e9b690e0979c4b07a, 8d8170c1ed99490589969cd753c40748,
f1ecb6db00a54e088f1e09164222d637, 3cf5dd8d29aa4619b39dc2542df7e42e, 2e5d42ca710441b389895f2d3b611d0a,
5d8202da24244dc896e9a8cba6a4ed4f, a6d64c953631412b8bd8f0ba53ae4d32, 69240c5cbfbb4e91961f5b49812a26f0,
865f38532b784a7c971f5d33b87b443e, ceb1c004191947cdaa10af9b9c03c80d, 64c6041037914779b5e8e9cf5a80ad04,
562fa6a0e7b846a180ac4b423c5511c5, b3b922288f9c4df2a4088279ff6d1531, 75a1a8ffda554318890cf74c345ed9a9,
3bae06cacf394a5998c2326199da94f5, ff6428a3daa5496c81d5e664aba01f97, 1ba3f86870724f55b94a35cb6b4173af,
b3e163fd8b8a4f289d5a25611cb66d23, abd2daba215e4f7c9ddabde04d6eb382, e22ee019049144d5aba573cdf4dbe4fc,
6ac765dac67841a69218140785f024c6, 7b057411a54e434fb74804b90daa8d44, 563f71b3c67d47c3ab1100f5dc1b98f3,
d81a657361ab4bba8bcc0cf309d2ff64, 20316312ab88471ba90cbb954be3e964, 698fda742f834473a23fb7e5e4cf239c,
289b52c5a38146b8b467a5f4678f6271, d07c2f37cf914894b1551a8104e6cb70, 5b55c73d551d483baaa6a1411c2597b1,
2308f77723f54ac898588f48d1853b65, 54d2589714d04b2e928b816258cb0df4, f84b795348c04c7a950165301a643671,
bc853a4a8d3c4dbda23d183f0a3b4f27, 1012ddc0343842d8b913a7d85df8ab8f, 771a73a8f5084a57afc5654d72e022f0,
311a43449f074841b6df4130b0871ac9, cd4d29cb01134469b52d6936c35eb943, 013cf89ee6174d29bb3f4fdff7b36049,
9237d877d84e4b3ab69698ecf56915bb, 337ef4d37e6b4ff6bf6e8bd4ca93383f, b4096d3837b84ccdb8f1186435c87281,
7259d3b7e11b4736b4d2aa8e9c55e994, 1ad1f8e99a864fc4a2bc532d9a4ff110, b2b50451eabd40978ef46db5e7dd08c4,
2dad5c5541e243128e23c3dd3e420ac2, a3de458b61e5493081d6bb9cf7e923db, 37760f8a7b164e6f9c1a23d621e9fe6b,
745a2aedcfab491fb9cffba19958b0c5, 2f6c670640d048d2af453638cfde3a1e]
referenced_widgets: [86617153e14143c6900da3535b74ef07, 8de57c9ecba14aa5b1f642af5c7e9094,
515fe154b1b74ed981e877aef503aa99, 4e291a8b028847328ea1d9a650c20beb, 87a0c8cdc0ad423daba7082b985cbd2b,
4764b5b806b94734b760cf6cc2fc224d, 5307bf3142804235bb688694c517d80c, 6a2fd6755667443abe7710ad607a79cc,
91bc1755904e40db8d758db4d09754e3, 69c38d75960542fb83fa087cae761957, dc31cb349c9b4c3580b2b77cbad1325c,
d451224a0ce540648b0c28d433d85803, 52f2f12dcffe4507ab92286fd3810db6, 6ab919475c80413e94afa66304b05338,
305d05093c6e411cb438a0bbf122d574, aa11f21e68994a8d9ddead215f2f4920, 59a7233abf61461b8b3feeb31b2f544f,
9d909399be9a4fa48bc3d781905c7f5a, 5b6172eb4e0541a3b07d4f82de77a303, bc3bec617b0040f487f80134537a3068,
9fe417f8159244f8ac808f2844922cf3, c4748e35e8574bb286a527295df98c8e, f50572e8058c4864bb8143c364d191f9,
436955f611674e27b4ddf3e040cc5ce9, 048231bf788c447091b8ef0174101f42, 97009f7e20d84c7c9d89f7497efc494c,
84e2844437884f6c89683e6545a2262e, df3019cc6aa44a4cbcb62096444769a7, ce17fe81850c49cd924297d21ecda621,
422117e32e0b4a95bed7925c99fd9f78, 56ab1fa0212a43a4a70838e440be0e9c, 1c5483472cea483bbf2a8fe2a9182ce0,
00034cb6a66143d8a87922befb1da7a6, 368b51d79aed4184854f155e2951da81, eb9de18be48d4a0db1034a38a0287ea6,
dbec1d9b196849a5ad79a5f083dbe64e, 66db6915d27b4fb49e1b44f70cb61654, 80f3e3a30dc24d3fa54bb72dc1c60182,
c320096ba1e74c7bbbd9509cc11c22e9, a664dd9c446040e8b175bb91d1c051db, 66c7826ff9b4455db9f7e9717a432f73,
74ec8cec0f3c4c04b76f5fb87ea2d9bb, ea4537aef1e247378de1935ad50ef76c, a9cffb2f5e194dfaba516bb4c8c47e3f,
4f17b7ab6ae94ce3b122561bcd8d4427, 3c0bdc06fe07412bacc00daa6f1eec34, 1ba273ced1484bcf9855366ff0dc3645,
7413d8bab616446ba6b820a3f874f6a0, 53c160c26c634b53a914be18ed91016c, ebc4ad2fae264e72a5307a0481a97ab3,
83ab5e7617fb45898c259bc20f71e958, 21f1138e807e4946953e3074d72d9a27, 86d7357878634706b9e214103efa262a,
3713a0e1880a43bc8b23225dbb8b4c45, f9f85ce1cbf34a7da27804ce7cc6444e]
id: a22kTvgk6_fJ
outputId: 35fc38b9-a6ab-4b02-ffa4-ab27fac69df4
outputId: 53e1d208-5360-479b-c097-0c03c7fac3e8
---
mnist_dataset = load_dataset("mnist").with_format("numpy")
mnist_dataset = load_dataset("mnist", cache_dir=data_dir).with_format("numpy")
```
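
If you prefer JAX arrays end to end, the `datasets` library can return them directly (a hedged alternative; arrays are created on access):

```python
# Alternative: request JAX arrays instead of NumPy from datasets.
mnist_jax = load_dataset("mnist", cache_dir=data_dir).with_format("jax")
print(type(mnist_jax["train"][:1]["image"]))  # jax.Array
```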

+++ {"id": "IFjTyGxY19b0"}
+++ {"id": "tgI7dIaX7JzM"}

### Extract images and labels

Get image shape and flatten for model input
Get image shape and flatten for model input.

```{code-cell}
---
colab:
base_uri: https://localhost:8080/
id: NHrKatD_7HbH
outputId: deec1739-2fc0-4e71-8567-f2e0c9db198b
---
:id: NHrKatD_7HbH
train_images = mnist_dataset["train"]["image"]
train_labels = mnist_dataset["train"]["label"]
test_images = mnist_dataset["test"]["image"]
test_labels = mnist_dataset["test"]["label"]
# Flatten images and one-hot encode labels
# Extract image shape
image_shape = train_images.shape[1:]
num_features = image_shape[0] * image_shape[1]
# Flatten the images
train_images = train_images.reshape(-1, num_features)
test_images = test_images.reshape(-1, num_features)
# One-hot encode the labels
train_labels = one_hot(train_labels, n_targets)
test_labels = one_hot(test_labels, n_targets)
```

```{code-cell}
---
colab:
base_uri: https://localhost:8080/
id: dITh435Z7Nwb
outputId: cd77ebf6-7d44-420f-f8d8-4357f915c956
---
print('Train:', train_images.shape, train_labels.shape)
print('Test:', test_images.shape, test_labels.shape)
```
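
The `hf_training_generator` definition is collapsed in this diff; a minimal version would stream batches with `Dataset.iter` (an assumption about the PR's code, mirroring the CPU tutorial):

```python
def hf_training_generator():
    """Yield MNIST batches for training."""
    for batch in mnist_dataset["train"].iter(batch_size):
        yield batch["image"], batch["label"]
```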
@@ -678,14 +708,16 @@ Run the training loop using the Hugging Face training generator.
---
colab:
base_uri: https://localhost:8080/
id: RhloYGsw6nPf
outputId: d49c1cd2-a546-46a6-84fb-d9507c38f4ca
id: Ui6aLiZP7aLe
outputId: 48347baf-30f2-443d-b3bf-b12100d96b8f
---
train_model(num_epochs, params, hf_training_generator)
```

+++ {"id": "qXylIOwidWI3"}
+++ {"id": "_JR0V1Aix9Id"}

## Summary

This notebook has introduced efficient strategies for data loading on a CPU with JAX, demonstrating how to integrate popular libraries like PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets. Each library offers distinct advantages, enabling you to streamline the data loading process for machine learning tasks. By understanding the strengths of these methods, you can select the approach that best suits your project's specific requirements.
This notebook introduced efficient methods for multi-device distributed data loading on TPUs with JAX. You explored how to leverage popular libraries like PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets to optimize the data loading process for machine learning tasks. Each library offers unique advantages, enabling you to choose the best approach based on your project’s requirements.

For more in-depth strategies on distributed data loading with JAX, including global data pipelines and per-device processing, refer to the [Distributed Data Loading Guide](https://jax.readthedocs.io/en/latest/distributed_data_loading.html).
1 change: 1 addition & 0 deletions docs/source/tutorials.md
@@ -25,6 +25,7 @@ JAX_time_series_classification
JAX_transformer_text_classification
data_loaders_on_cpu_with_jax
data_loaders_on_gpu_with_jax
data_loaders_for_multi_device_setups_with_jax
```

Once you've gone through this content, you can refer to package-specific