3
3
https://pytorch.org/tutorials/beginner/introyt/trainingyt.html
4
4
"""
5
5
6
+ import functools
7
+ from torch_xla2 import train , interop
6
8
import torch
7
9
from torch .utils import _pytree as pytree
8
10
import torchvision
17
19
from torch .utils .tensorboard import SummaryWriter
18
20
from datetime import datetime
19
21
22
+ env = torch_xla2 .enable_globally ()
23
+
20
24
21
25
transform = transforms .Compose (
22
26
[transforms .ToTensor (),
38
42
print ('Training set has {} instances' .format (len (training_set )))
39
43
print ('Validation set has {} instances' .format (len (validation_set )))
40
44
41
- import matplotlib .pyplot as plt
42
45
import numpy as np
43
-
44
- # Helper function for inline image display
45
- def matplotlib_imshow (img , one_channel = False ):
46
- if one_channel :
47
- img = img .mean (dim = 0 )
48
- img = img / 2 + 0.5 # unnormalize
49
- npimg = img .numpy ()
50
- if one_channel :
51
- plt .imshow (npimg , cmap = "Greys" )
52
- else :
53
- plt .imshow (np .transpose (npimg , (1 , 2 , 0 )))
54
-
55
- dataiter = iter (training_loader )
56
- images , labels = next (dataiter )
57
-
58
- # Create a grid from the images and show them
59
- img_grid = torchvision .utils .make_grid (images )
60
- matplotlib_imshow (img_grid , one_channel = True )
61
- print (' ' .join (classes [labels [j ]] for j in range (4 )))
62
-
63
-
64
46
import torch .nn as nn
65
47
import torch .nn .functional as F
66
48
@@ -83,62 +65,55 @@ def forward(self, x):
83
65
model = GarmentClassifier ()
84
66
loss_fn = torch .nn .CrossEntropyLoss ()
85
67
86
- jax_weights , jax_func = torch_xla2 .extract_jax (model )
87
- jax_func = jax .jit (jax_func , inline = True )
88
68
jax_optimizer = optax .adam (0.01 )
89
- opt_state = jax_optimizer .init (jax_weights )
90
69
70
+ model .to ('jax' ) # move the model to jax device
71
+ model_jittable = interop .JittableModule (model )
72
+ weights = model_jittable .params # these are trainable parameters
73
+ buffers = model_jittable .buffers # these are non-trainable parameters
91
74
92
- def jax_loss (weights , data , label ):
93
- pred = jax_func (weights , data )
94
- loss = torch_xla2 .interop .call_torch (loss_fn , pred , label )
95
- return loss
75
+ opt_state = interop .call_jax (jax_optimizer .init , weights )
76
+ model_fn = functools .partial (model_jittable .functional_call , 'forward' )
96
77
97
- grad_fn = jax . jit ( jax . value_and_grad ( jax_loss ) )
78
+ train_step = train . make_train_step ( model_fn , loss_fn , jax_optimizer )
98
79
80
+ train_step = interop .jax_jit (train_step , kwargs_for_jax_jit = {'donate_argnums' : (0 , 2 )})
99
81
100
82
# NB: Loss functions expect data in batches, so we're creating batches of 4
101
83
# Represents the model's confidence in each of the 10 classes for a given input
102
- dummy_outputs = torch .rand (4 , 10 )
84
+ dummy_inputs = torch .rand (4 , 28 , 28 ).to ('jax' )
85
+ dummy_outputs = torch .rand (4 , 10 ).to ('jax' )
103
86
# Represents the correct class among the 10 being tested
104
- dummy_labels = torch .tensor ([1 , 5 , 3 , 7 ])
105
-
106
- print (dummy_outputs )
107
- print (dummy_labels )
108
-
109
- loss = loss_fn (dummy_outputs , dummy_labels )
110
- print ('Total loss for this batch: {}' .format (loss .item ()))
111
-
87
+ dummy_labels = torch .tensor ([1 , 5 , 3 , 7 ]).to ('jax' )
112
88
113
- def train_one_epoch ( jax_weights , opt_state , epoch_index , tb_writer ):
89
+ # test train_step
114
90
91
+ def train_one_epoch (weights , buffers , opt_state , epoch_index , tb_writer ):
115
92
running_loss = 0.
116
93
last_loss = 0.
117
94
118
95
# Here, we use enumerate(training_loader) instead of
119
96
# iter(training_loader) so that we can track the batch
120
97
# index and do some intra-epoch reporting
121
98
for i , data in enumerate (training_loader ):
122
- # Every data instance is an input + label pair
123
- # NEW: Move model to XLA device
124
- data = pytree .tree_map_only (torch .Tensor ,
125
- torch_xla2 .tensor .t2j , data )
126
99
inputs , labels = data
127
100
128
- val , grads = grad_fn (jax_weights , (inputs , ), labels )
129
- updates , opt_state = jax_optimizer .update (grads , opt_state )
130
- jax_weights = optax .apply_updates (jax_weights , updates )
101
+ inputs = inputs .to ('jax' )
102
+ labels = labels .to ('jax' )
103
+
104
+ loss , weights , opt_state = train_step (
105
+ weights , buffers , opt_state , inputs , labels )
131
106
132
107
# Gather data and report
133
- running_loss += val .item ()
108
+ running_loss += loss .item ()
134
109
if i % 1000 == 999 :
135
110
last_loss = running_loss / 1000 # loss per batch
136
111
print (' batch {} loss: {}' .format (i + 1 , last_loss ))
137
112
tb_x = epoch_index * len (training_loader ) + i + 1
138
113
tb_writer .add_scalar ('Loss/train' , last_loss , tb_x )
139
114
running_loss = 0.
140
115
141
- return last_loss , opt_state
116
+ return last_loss , weights , opt_state
142
117
143
118
144
119
@@ -152,39 +127,5 @@ def train_one_epoch(jax_weights, opt_state, epoch_index, tb_writer):
152
127
for epoch in range (EPOCHS ):
153
128
print ('EPOCH {}:' .format (epoch_number + 1 ))
154
129
155
- # Make sure gradient tracking is on, and do a pass over the data
156
- model .train (True )
157
-
158
- avg_loss , opt_state = train_one_epoch (jax_weights , opt_state , epoch_number , writer )
159
-
160
- running_vloss = 0.0
161
- # Set the model to evaluation mode, disabling dropout and using population
162
- # statistics for batch normalization.
163
- model .eval ()
164
-
165
- # Disable gradient computation and reduce memory consumption.
166
- with torch .no_grad ():
167
- for i , vdata in enumerate (validation_loader ):
168
-
169
- vinputs , vlabels = pytree .tree_map_only (torch .Tensor , torch_xla2 .tensor .t2j , vdata )
170
- voutputs = jax_func (jax_weights , (vinputs , )) # call model's forward
171
- vloss = torch_xla2 .interop .call_torch (loss_fn , voutputs , vlabels )
172
- running_vloss += vloss
173
-
174
- avg_vloss = running_vloss / (i + 1 )
175
- print ('LOSS train {} valid {}' .format (avg_loss , avg_vloss ))
176
-
177
- # Log the running loss averaged per batch
178
- # for both training and validation
179
- writer .add_scalars ('Training vs. Validation Loss' ,
180
- { 'Training' : np .asarray (avg_loss ), 'Validation' : np .asarray (avg_vloss ) },
181
- epoch_number + 1 )
182
- writer .flush ()
183
-
184
- # Track best performance, and save the model's state
185
- if avg_vloss < best_vloss :
186
- best_vloss = avg_vloss
187
- model_path = 'model_{}_{}' .format (timestamp , epoch_number )
188
- torch .save (model .state_dict (), model_path )
189
-
190
- epoch_number += 1
130
+ avg_loss , weights , opt_state = train_one_epoch (weights , buffers , opt_state , epoch_number , writer )
131
+ print (avg_loss )
0 commit comments