Commit 1b5505f

Merge pull request #81 from runame/pytorch-speedups
Fix bugs and implement DDP
2 parents 987a47e + c4a07ac · commit 1b5505f

16 files changed: +361 −142 lines
Lines changed: 146 additions & 0 deletions
@@ -0,0 +1,146 @@ (new file)

import jax
import torch
import torch.distributed as dist
from torch.utils.data import Sampler


def shard_numpy_ds(xs):
  """Prepare tf data for JAX.

  Convert an input batch from tf Tensors to numpy arrays and reshape it to be
  sharded across devices.
  """
  local_device_count = jax.local_device_count()

  def _prepare(x):
    # Use _numpy() for zero-copy conversion between TF and NumPy.
    x = x._numpy()  # pylint: disable=protected-access

    # Reshape (host_batch_size, height, width, 3) to
    # (local_devices, device_batch_size, height, width, 3).
    return x.reshape((local_device_count, -1) + x.shape[1:])

  return jax.tree_map(_prepare, xs)


# github.com/pytorch/pytorch/issues/23900#issuecomment-518858050
def cycle(iterable, keys=('inputs', 'targets'), custom_sampler=False):
  """Infinitely cycle over a DataLoader, yielding each batch as a keyed dict.

  If `custom_sampler` is True, `iterable.sampler.set_epoch()` is called at
  every epoch boundary so per-epoch shuffling stays in sync across restarts.
  """
  iterator = iter(iterable)
  epoch = 0
  while True:
    try:
      batch = next(iterator)
      assert len(keys) == len(batch)
      yield dict(zip(keys, batch))
    except StopIteration:
      if custom_sampler:
        epoch += 1
        iterable.sampler.set_epoch(epoch)
      iterator = iter(iterable)


# github.com/SeungjunNah/DeepDeblur-PyTorch/blob/master/src/data/sampler.py
class DistributedEvalSampler(Sampler):
  r"""Sampler that restricts data loading to a subset of the dataset.

  DistributedEvalSampler is different from DistributedSampler: it does NOT add
  extra samples to make the dataset evenly divisible across processes. It
  should therefore NOT be used for training, since the distributed processes
  could hang forever. See https://github.com/pytorch/pytorch/issues/22584 for
  details.

  Shuffling is disabled by default. DistributedEvalSampler is intended for
  evaluation, where synchronization does not happen every epoch;
  synchronization should be done outside the dataloader loop.

  It is especially useful in conjunction with
  :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
  process can pass a sampler instance as a
  :class:`~torch.utils.data.DataLoader` sampler and load a subset of the
  original dataset that is exclusive to it.

  .. note::
    The dataset is assumed to be of constant size.

  Arguments:
    dataset: Dataset used for sampling.
    num_replicas (int, optional): Number of processes participating in
      distributed training. By default, the world size is retrieved from the
      current distributed group.
    rank (int, optional): Rank of the current process within
      :attr:`num_replicas`. By default, :attr:`rank` is retrieved from the
      current distributed group.
    shuffle (bool, optional): If ``True``, the sampler will shuffle the
      indices. Default: ``False``.
    seed (int, optional): Random seed used to shuffle the sampler if
      :attr:`shuffle=True`. This number should be identical across all
      processes in the distributed group. Default: ``0``.

  .. warning::
    In distributed mode, calling the :meth:`set_epoch(epoch) <set_epoch>`
    method at the beginning of each epoch **before** creating the
    :class:`DataLoader` iterator is necessary to make shuffling work properly
    across multiple epochs. Otherwise, the same ordering will always be used.

  Example::

    >>> sampler = DistributedEvalSampler(dataset) if is_distributed else None
    >>> loader = DataLoader(dataset, shuffle=(sampler is None),
    ...                     sampler=sampler)
    >>> for epoch in range(start_epoch, n_epochs):
    ...   if is_distributed:
    ...     sampler.set_epoch(epoch)
    ...   evaluate(loader)
  """

  def __init__(self,
               dataset,
               num_replicas=None,
               rank=None,
               shuffle=False,
               seed=0):
    if num_replicas is None:
      if not dist.is_available():
        raise RuntimeError('Requires distributed package to be available.')
      num_replicas = dist.get_world_size()
    if rank is None:
      if not dist.is_available():
        raise RuntimeError('Requires distributed package to be available.')
      rank = dist.get_rank()
    self.dataset = dataset
    self.num_replicas = num_replicas
    self.rank = rank
    self.epoch = 0
    # True value without extra samples.
    self.total_size = len(self.dataset)
    indices = list(range(self.total_size))
    indices = indices[self.rank:self.total_size:self.num_replicas]
    # True value without extra samples.
    self.num_samples = len(indices)

    self.shuffle = shuffle
    self.seed = seed

  def __iter__(self):
    if self.shuffle:
      # Deterministically shuffle based on epoch and seed.
      g = torch.Generator()
      g.manual_seed(self.seed + self.epoch)
      indices = torch.randperm(len(self.dataset), generator=g).tolist()
    else:
      indices = list(range(len(self.dataset)))

    # Subsample: every `num_replicas`-th index, starting at this rank.
    indices = indices[self.rank:self.total_size:self.num_replicas]
    assert len(indices) == self.num_samples

    return iter(indices)

  def __len__(self):
    return self.num_samples

  def set_epoch(self, epoch):
    r"""Sets the epoch for this sampler.

    When :attr:`shuffle=True`, this ensures all replicas use a different
    random ordering for each epoch. Otherwise, the next iteration of this
    sampler will yield the same ordering.

    Arguments:
      epoch (int): Epoch number.
    """
    self.epoch = epoch
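As a usage illustration only (not part of the commit), the pieces above might be wired together roughly as follows; the dummy datasets, batch sizes, and the pairing of `cycle` with a shuffling `DistributedSampler` for training are assumptions, not something this diff prescribes. It assumes `torch.distributed` has already been initialized.

import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

# Hypothetical datasets; any map-style dataset yielding (inputs, targets) works.
train_ds = TensorDataset(torch.randn(1024, 3, 224, 224), torch.randint(0, 10, (1024,)))
eval_ds = TensorDataset(torch.randn(100, 3, 224, 224), torch.randint(0, 10, (100,)))

# Training: the standard DistributedSampler pads samples so every rank sees the
# same number of batches; cycle(..., custom_sampler=True) calls
# sampler.set_epoch() at each epoch boundary so shuffling differs per epoch.
train_sampler = DistributedSampler(train_ds, shuffle=True)
train_loader = DataLoader(train_ds, batch_size=32, sampler=train_sampler)
train_iter = cycle(train_loader, keys=('inputs', 'targets'), custom_sampler=True)
batch = next(train_iter)  # {'inputs': ..., 'targets': ...}

# Evaluation: DistributedEvalSampler assigns each example to exactly one rank
# and never pads, so per-rank batch counts may differ; synchronize metrics
# afterwards (e.g. with torch.distributed.all_reduce), not inside the loop.
eval_sampler = DistributedEvalSampler(eval_ds, shuffle=False)
eval_loader = DataLoader(eval_ds, batch_size=32, sampler=eval_sampler)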

algorithmic_efficiency/spec.py

Lines changed: 4 additions & 1 deletion
@@ -106,7 +106,10 @@ def build_input_queue(self,
                         data_rng: RandomState,
                         split: str,
                         data_dir: str,
-                        global_batch_size: int) -> Dict[str, Any]:
+                        global_batch_size: int,
+                        cache: Optional[bool] = None,
+                        repeat_final_dataset: Optional[bool] = None,
+                        num_batches: Optional[int] = None) -> Dict[str, Any]:
    """Build the input queue for the workload data.

    This is the only function that is NOT allowed to be called by submitters.
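Purely as an illustration of the extended signature (the workload object, RNG, and data directory below are hypothetical, not taken from the diff):

# Hypothetical call site; `workload` is some Workload implementation and
# `data_rng` a framework-appropriate RandomState.
input_queue = workload.build_input_queue(
    data_rng=data_rng,
    split='eval_train',
    data_dir='/path/to/dataset',   # hypothetical path
    global_batch_size=1024,
    cache=True,                    # optionally cache the decoded split
    repeat_final_dataset=True,     # repeat so evaluation never exhausts the queue
    num_batches=None)              # or cap the number of eval batches
batch = next(input_queue)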

algorithmic_efficiency/workloads/imagenet/imagenet_jax/input_pipeline.py

Lines changed: 17 additions & 36 deletions
@@ -9,6 +9,8 @@
 import tensorflow as tf
 import tensorflow_datasets as tfds
 
+from algorithmic_efficiency import data_utils
+
 IMAGE_SIZE = 224
 RESIZE_SIZE = 256
 MEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255]
@@ -156,16 +158,18 @@ def preprocess_for_train(image_bytes,
   Returns:
     A preprocessed image `Tensor`.
   """
-  crop_rng, flip_rng = tf.random.experimental.stateless_split(rng, 2)
+  # Note (runame): Cannot be done in graph mode, i.e. during ds.map().
+  # Alternative?
+  # crop_rng, flip_rng = tf.random.experimental.stateless_split(rng, 2)
 
   image = _decode_and_random_crop(image_bytes,
-                                  crop_rng,
+                                  rng,
                                   image_size,
                                   aspect_ratio_range,
                                   area_range,
                                   resize_size)
   image = tf.reshape(image, [image_size, image_size, 3])
-  image = tf.image.stateless_random_flip_left_right(image, seed=flip_rng)
+  image = tf.image.stateless_random_flip_left_right(image, seed=rng)
   image = normalize_image(image, mean_rgb, stddev_rgb)
   image = tf.image.convert_image_dtype(image, dtype=dtype)
   return image
@@ -209,22 +213,19 @@ def create_split(split,
                  aspect_ratio_range=(0.75, 4.0 / 3.0),
                  area_range=(0.08, 1.0)):
   """Creates a split from the ImageNet dataset using TensorFlow Datasets."""
+  del num_batches
   if split == 'eval_train':
-    split = 'train'
+    split = 'train[:50000]'
 
   shuffle_rng, preprocess_rng = jax.random.split(rng, 2)
 
-  def decode_example(example):
+  def decode_example(example_index, example):
     dtype = tf.float32
     if train:
-      # We call ds.enumerate() to get a globally unique per-example, per-step
-      # index that we can fold into the RNG seed.
-      (example_index, example) = example
       per_step_preprocess_rng = tf.random.experimental.stateless_fold_in(
          tf.cast(preprocess_rng, tf.int64), example_index)
       image = preprocess_for_train(example['image'],
                                    per_step_preprocess_rng,
-                                   example_index,
                                    mean_rgb,
                                    stddev_rgb,
                                    aspect_ratio_range,
@@ -246,7 +247,7 @@ def decode_example(example):
       'image': tfds.decode.SkipDecoding(),
   })
   options = tf.data.Options()
-  options.experimental_threading.private_threadpool_size = 48
+  options.threading.private_threadpool_size = 48
   ds = ds.with_options(options)
 
   if cache:
@@ -256,11 +257,11 @@ def decode_example(example):
     ds = ds.repeat()
     ds = ds.shuffle(16 * global_batch_size, seed=shuffle_rng[0])
 
+  # We call ds.enumerate() to get a globally unique per-example, per-step
+  # index that we can fold into the RNG seed.
+  ds = ds.enumerate()
   ds = ds.map(decode_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
-  ds = ds.batch(global_batch_size, drop_remainder=True)
-
-  if num_batches is not None:
-    ds = ds.take(num_batches)
+  ds = ds.batch(global_batch_size, drop_remainder=train)
 
   if repeat_final_dataset:
     ds = ds.repeat()
@@ -270,25 +271,6 @@ def decode_example(example):
   return ds
 
 
-def shard_numpy_ds(xs):
-  """Prepare tf data for JAX
-
-  Convert an input batch from tf Tensors to numpy arrays and reshape it to be
-  sharded across devices.
-  """
-  local_device_count = jax.local_device_count()
-
-  def _prepare(x):
-    # Use _numpy() for zero-copy conversion between TF and NumPy.
-    x = x._numpy()  # pylint: disable=protected-access
-
-    # reshape (host_batch_size, height, width, 3) to
-    # (local_devices, device_batch_size, height, width, 3)
-    return x.reshape((local_device_count, -1) + x.shape[1:])
-
-  return jax.tree_map(_prepare, xs)
-
-
 def create_input_iter(split,
                       dataset_builder,
                       rng,
@@ -309,7 +291,6 @@ def create_input_iter(split,
       rng,
       global_batch_size,
       train=train,
-      dtype=tf.float32,
       image_size=image_size,
       resize_size=resize_size,
       mean_rgb=mean_rgb,
@@ -319,9 +300,9 @@ def create_input_iter(split,
       num_batches=num_batches,
       aspect_ratio_range=aspect_ratio_range,
       area_range=area_range)
-  it = map(shard_numpy_ds, ds)
+  it = map(data_utils.shard_numpy_ds, ds)
 
   # Note(Dan S): On a Nvidia 2080 Ti GPU, this increased GPU utilization by 10%.
   it = jax_utils.prefetch_to_device(it, 2)
 
-  return it
+  return iter(it)
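The key pattern in create_split above is pairing ds.enumerate() with stateless_fold_in, so that every example gets its own deterministic augmentation seed even though the map function runs in graph mode. A standalone sketch of that pattern, with a made-up base seed and dummy data for illustration:

import tensorflow as tf

base_seed = tf.constant([0, 42], dtype=tf.int64)  # hypothetical base RNG seed

def _augment(example_index, image):
  # Fold the globally unique example index into the base seed so each example
  # draws from an independent, reproducible stateless RNG stream.
  per_example_seed = tf.random.experimental.stateless_fold_in(
      base_seed, example_index)
  return tf.image.stateless_random_flip_left_right(image, seed=per_example_seed)

ds = tf.data.Dataset.from_tensor_slices(tf.zeros([8, 32, 32, 3]))
ds = ds.enumerate()  # yields (index, image) pairs
ds = ds.map(_augment, num_parallel_calls=tf.data.experimental.AUTOTUNE)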

algorithmic_efficiency/workloads/imagenet/imagenet_jax/workload.py

Lines changed: 19 additions & 7 deletions
@@ -32,9 +32,17 @@ def build_input_queue(self,
                         data_rng: spec.RandomState,
                         split: str,
                         data_dir: str,
-                        global_batch_size: int):
-    return iter(
-        self._build_dataset(data_rng, split, data_dir, global_batch_size))
+                        global_batch_size: int,
+                        cache: Optional[bool] = None,
+                        repeat_final_dataset: Optional[bool] = None,
+                        num_batches: Optional[int] = None):
+    return self._build_dataset(data_rng,
+                               split,
+                               data_dir,
+                               global_batch_size,
+                               cache,
+                               repeat_final_dataset,
+                               num_batches)
 
   def _build_dataset(self,
                      data_rng: spec.RandomState,
@@ -144,6 +152,8 @@ def model_fn(
       mode: spec.ForwardPassMode,
       rng: spec.RandomState,
       update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+    del mode
+    del rng
     variables = {'params': params, **model_state}
     if update_batch_norm:
       logits, new_model_state = self._model.apply(
@@ -171,13 +181,14 @@ def loss_fn(self, label_batch: spec.Tensor,
     return xentropy
 
   def _compute_metrics(self, logits, labels):
-    loss = jnp.mean(self.loss_fn(labels, logits))
-    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
+    loss = jnp.sum(self.loss_fn(labels, logits))
+    # Not accuracy, but the number of correct predictions.
+    accuracy = jnp.sum(jnp.argmax(logits, -1) == labels)
     metrics = {
         'loss': loss,
         'accuracy': accuracy,
     }
-    metrics = lax.pmean(metrics, axis_name='batch')
+    metrics = lax.psum(metrics, axis_name='batch')
     return metrics
 
   def _eval_model_on_split(self,
@@ -213,5 +224,6 @@ def _eval_model_on_split(self,
         eval_metrics[metric_name] = 0.0
       eval_metrics[metric_name] += metric_value
 
-    eval_metrics = jax.tree_map(lambda x: x / num_examples, eval_metrics)
+    eval_metrics = jax.tree_map(lambda x: float(x[0] / num_examples),
+                                eval_metrics)
     return eval_metrics