Skip to content

Commit f5b33c5

Browse files
authored
training UX: automatic generating make_train_step (#8495)
1 parent bbd66b9 commit f5b33c5

File tree

12 files changed

+1178
-263
lines changed

12 files changed

+1178
-263
lines changed

experimental/torch_xla2/examples/basic_training_jax.py

Lines changed: 26 additions & 85 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,8 @@
33
https://pytorch.org/tutorials/beginner/introyt/trainingyt.html
44
"""
55

6+
import functools
7+
from torch_xla2 import train, interop
68
import torch
79
from torch.utils import _pytree as pytree
810
import torchvision
@@ -17,6 +19,8 @@
1719
from torch.utils.tensorboard import SummaryWriter
1820
from datetime import datetime
1921

22+
env = torch_xla2.enable_globally()
23+
2024

2125
transform = transforms.Compose(
2226
[transforms.ToTensor(),
@@ -38,29 +42,7 @@
3842
print('Training set has {} instances'.format(len(training_set)))
3943
print('Validation set has {} instances'.format(len(validation_set)))
4044

41-
import matplotlib.pyplot as plt
4245
import numpy as np
43-
44-
# Helper function for inline image display
45-
def matplotlib_imshow(img, one_channel=False):
46-
if one_channel:
47-
img = img.mean(dim=0)
48-
img = img / 2 + 0.5 # unnormalize
49-
npimg = img.numpy()
50-
if one_channel:
51-
plt.imshow(npimg, cmap="Greys")
52-
else:
53-
plt.imshow(np.transpose(npimg, (1, 2, 0)))
54-
55-
dataiter = iter(training_loader)
56-
images, labels = next(dataiter)
57-
58-
# Create a grid from the images and show them
59-
img_grid = torchvision.utils.make_grid(images)
60-
matplotlib_imshow(img_grid, one_channel=True)
61-
print(' '.join(classes[labels[j]] for j in range(4)))
62-
63-
6446
import torch.nn as nn
6547
import torch.nn.functional as F
6648

@@ -83,62 +65,55 @@ def forward(self, x):
8365
model = GarmentClassifier()
8466
loss_fn = torch.nn.CrossEntropyLoss()
8567

86-
jax_weights, jax_func = torch_xla2.extract_jax(model)
87-
jax_func = jax.jit(jax_func, inline=True)
8868
jax_optimizer = optax.adam(0.01)
89-
opt_state = jax_optimizer.init(jax_weights)
9069

70+
model.to('jax') # move the model to jax device
71+
model_jittable = interop.JittableModule(model)
72+
weights = model_jittable.params # these are trainable parameters
73+
buffers = model_jittable.buffers # these are non-trainable parameters
9174

92-
def jax_loss(weights, data, label):
93-
pred = jax_func(weights, data)
94-
loss = torch_xla2.interop.call_torch(loss_fn, pred, label)
95-
return loss
75+
opt_state = interop.call_jax(jax_optimizer.init, weights)
76+
model_fn = functools.partial(model_jittable.functional_call, 'forward')
9677

97-
grad_fn = jax.jit(jax.value_and_grad(jax_loss))
78+
train_step = train.make_train_step(model_fn, loss_fn, jax_optimizer)
9879

80+
train_step = interop.jax_jit(train_step, kwargs_for_jax_jit={'donate_argnums': (0, 2)})
9981

10082
# NB: Loss functions expect data in batches, so we're creating batches of 4
10183
# Represents the model's confidence in each of the 10 classes for a given input
102-
dummy_outputs = torch.rand(4, 10)
84+
dummy_inputs = torch.rand(4, 28, 28).to('jax')
85+
dummy_outputs = torch.rand(4, 10).to('jax')
10386
# Represents the correct class among the 10 being tested
104-
dummy_labels = torch.tensor([1, 5, 3, 7])
105-
106-
print(dummy_outputs)
107-
print(dummy_labels)
108-
109-
loss = loss_fn(dummy_outputs, dummy_labels)
110-
print('Total loss for this batch: {}'.format(loss.item()))
111-
87+
dummy_labels = torch.tensor([1, 5, 3, 7]).to('jax')
11288

113-
def train_one_epoch(jax_weights, opt_state, epoch_index, tb_writer):
89+
# test train_step
11490

91+
def train_one_epoch(weights, buffers, opt_state, epoch_index, tb_writer):
11592
running_loss = 0.
11693
last_loss = 0.
11794

11895
# Here, we use enumerate(training_loader) instead of
11996
# iter(training_loader) so that we can track the batch
12097
# index and do some intra-epoch reporting
12198
for i, data in enumerate(training_loader):
122-
# Every data instance is an input + label pair
123-
# NEW: Move model to XLA device
124-
data = pytree.tree_map_only(torch.Tensor,
125-
torch_xla2.tensor.t2j, data)
12699
inputs, labels = data
127100

128-
val, grads = grad_fn(jax_weights, (inputs, ), labels)
129-
updates, opt_state = jax_optimizer.update(grads, opt_state)
130-
jax_weights = optax.apply_updates(jax_weights, updates)
101+
inputs = inputs.to('jax')
102+
labels = labels.to('jax')
103+
104+
loss, weights, opt_state = train_step(
105+
weights, buffers, opt_state, inputs, labels)
131106

132107
# Gather data and report
133-
running_loss += val.item()
108+
running_loss += loss.item()
134109
if i % 1000 == 999:
135110
last_loss = running_loss / 1000 # loss per batch
136111
print(' batch {} loss: {}'.format(i + 1, last_loss))
137112
tb_x = epoch_index * len(training_loader) + i + 1
138113
tb_writer.add_scalar('Loss/train', last_loss, tb_x)
139114
running_loss = 0.
140115

141-
return last_loss, opt_state
116+
return last_loss, weights, opt_state
142117

143118

144119

@@ -152,39 +127,5 @@ def train_one_epoch(jax_weights, opt_state, epoch_index, tb_writer):
152127
for epoch in range(EPOCHS):
153128
print('EPOCH {}:'.format(epoch_number + 1))
154129

155-
# Make sure gradient tracking is on, and do a pass over the data
156-
model.train(True)
157-
158-
avg_loss, opt_state = train_one_epoch(jax_weights, opt_state, epoch_number, writer)
159-
160-
running_vloss = 0.0
161-
# Set the model to evaluation mode, disabling dropout and using population
162-
# statistics for batch normalization.
163-
model.eval()
164-
165-
# Disable gradient computation and reduce memory consumption.
166-
with torch.no_grad():
167-
for i, vdata in enumerate(validation_loader):
168-
169-
vinputs, vlabels = pytree.tree_map_only(torch.Tensor, torch_xla2.tensor.t2j, vdata)
170-
voutputs = jax_func(jax_weights, (vinputs, )) # call model's forward
171-
vloss = torch_xla2.interop.call_torch(loss_fn, voutputs, vlabels)
172-
running_vloss += vloss
173-
174-
avg_vloss = running_vloss / (i + 1)
175-
print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))
176-
177-
# Log the running loss averaged per batch
178-
# for both training and validation
179-
writer.add_scalars('Training vs. Validation Loss',
180-
{ 'Training' : np.asarray(avg_loss), 'Validation' : np.asarray(avg_vloss) },
181-
epoch_number + 1)
182-
writer.flush()
183-
184-
# Track best performance, and save the model's state
185-
if avg_vloss < best_vloss:
186-
best_vloss = avg_vloss
187-
model_path = 'model_{}_{}'.format(timestamp, epoch_number)
188-
torch.save(model.state_dict(), model_path)
189-
190-
epoch_number += 1
130+
avg_loss, weights, opt_state = train_one_epoch(weights, buffers, opt_state, epoch_number, writer)
131+
print(avg_loss)
Lines changed: 34 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,34 @@
1+
# syntax=docker/dockerfile:experimental
2+
# Use Python 3.10 as the base image
3+
FROM python:3.10-slim-bullseye
4+
5+
# Install system dependencies
6+
RUN apt-get update && apt-get upgrade -y
7+
RUN apt-get update && apt-get install -y curl gnupg
8+
9+
# Add the Google Cloud SDK package repository
10+
RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list
11+
RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add -
12+
13+
# Install the Google Cloud SDK
14+
RUN apt-get update && apt-get install -y google-cloud-sdk git
15+
16+
# Set the default Python version to 3.10
17+
RUN update-alternatives --install /usr/bin/python3 python3 /usr/local/bin/python3.10 1
18+
RUN pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
19+
RUN pip install optax fire tensorflow tensorboard-plugin-profile
20+
RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
21+
22+
WORKDIR /
23+
RUN git clone https://github.com/pytorch/torchtitan.git
24+
WORKDIR /torchtitan
25+
RUN pip install -r requirements.txt
26+
RUN pip install .
27+
28+
WORKDIR /
29+
RUN git clone https://github.com/pytorch/xla.git
30+
WORKDIR xla/experimental/torch_xla2
31+
RUN pip install -e .
32+
33+
ENTRYPOINT ["python", "examples/train_llama_torchtitan/train_llama.py"]
34+
CMD ["--batch_size=8", "--seqlen=2048"]

0 commit comments

Comments
 (0)