
Commit ffdfc08

Merge pull request #85 from runame/pytorch-speedups
Add DDP to WMT + faster ImageNet data loading
2 parents: 465abed + 1317227

File tree: 20 files changed, +270 −195 lines

.gitignore
Lines changed: 2 additions & 0 deletions

@@ -7,3 +7,5 @@ env/
 venv/
 workdir/
 makefile
+*.out
+*.sh

README.md
Lines changed: 9 additions & 1 deletion

@@ -136,7 +136,7 @@ Docker is the easiest way to enable PyTorch/JAX GPU support on Linux since only
 ```bash
 python3 submission_runner.py \
 --framework=jax \
---workload=mnist |\
+--workload=mnist \
 --submission_path=baselines/mnist/mnist_jax/submission.py \
 --tuning_search_space=baselines/mnist/tuning_search_space.json
 ```
@@ -151,6 +151,14 @@ python3 submission_runner.py \
 --tuning_search_space=baselines/mnist/tuning_search_space.json
 ```
 
+When using multiple GPUs on a single node it is recommended to use PyTorch's
+[distributed data parallel](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html).
+To do so, simply replace `python3` by
+```bash
+torchrun --standalone --nnodes=1 --nproc_per_node=N_GPUS
+```
+where `N_GPUS` is the number of available GPUs on the node.
+
 ## Rules
 
 The rules for the MLCommons Algorithmic Efficency benchmark can be found in the seperate [rules document](RULES.md). Suggestions, clarifications and questions can be raised via pull requests.
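For reference, the combined command could look like the sketch below. This is not taken from the diff: the `--nproc_per_node` value, the `--framework=pytorch` flag, and the `baselines/mnist/mnist_pytorch/submission.py` path are assumptions patterned on the JAX example above.

```bash
torchrun --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py \
--framework=pytorch \
--workload=mnist \
--submission_path=baselines/mnist/mnist_pytorch/submission.py \
--tuning_search_space=baselines/mnist/tuning_search_space.json
```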

algorithmic_efficiency/data_utils.py
Lines changed: 103 additions & 5 deletions

@@ -1,11 +1,14 @@
 import jax
+import numpy as np
 import torch
 import torch.distributed as dist
+from torch.utils.data import DataLoader
+from torch.utils.data import DistributedSampler
 from torch.utils.data import Sampler
 
 
 def shard_numpy_ds(xs):
-  """Prepare tf data for JAX
+  """Prepare tf data for JAX or PyTorch DDP.
 
   Convert an input batch from tf Tensors to numpy arrays and reshape it to be
   sharded across devices.
@@ -16,8 +19,9 @@ def _prepare(x):
     # Use _numpy() for zero-copy conversion between TF and NumPy.
     x = x._numpy()  # pylint: disable=protected-access
 
-    # reshape (host_batch_size, height, width, 3) to
-    # (local_devices, device_batch_size, height, width, 3)
+    # Reshape (global_batch_size, ...) to
+    # (local_device_count, per_device_batch_size, ...).
+    # Assumes that `global_batch_size % local_device_count == 0`.
     return x.reshape((local_device_count, -1) + x.shape[1:])
 
   return jax.tree_map(_prepare, xs)
@@ -33,7 +37,7 @@ def cycle(iterable, keys=('inputs', 'targets'), custom_sampler=False):
       assert len(keys) == len(batch)
       yield dict(zip(keys, batch))
     except StopIteration:
-      if custom_sampler:
+      if custom_sampler and isinstance(iterable, DataLoader):
        epoch += 1
        iterable.sampler.set_epoch(epoch)
       iterator = iter(iterable)
@@ -54,7 +58,7 @@ class DistributedEvalSampler(Sampler):
   Sampler that restricts data loading to a subset of the dataset.
   It is especially useful in conjunction with
   :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
-  process can pass a :class`~torch.utils.data.DistributedSampler` instance as
+  process can pass a :class`~DistributedEvalSampler` instance as
   a :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the
   original dataset that is exclusive to it.
   .. note::
@@ -144,3 +148,97 @@ def set_epoch(self, epoch):
       epoch (int): _epoch number.
     """
     self.epoch = epoch
+
+
+# github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/Classification/
+# ConvNets/image_classification/dataloaders.py
+def fast_collate(batch, memory_format=torch.contiguous_format):
+  imgs = [img[0] for img in batch]
+  targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)
+  w = imgs[0].size[0]
+  h = imgs[0].size[1]
+  tensor = torch.zeros(
+      (len(imgs), 3, h, w),
+      dtype=torch.uint8).contiguous(memory_format=memory_format)
+  for i, img in enumerate(imgs):
+    nump_array = np.asarray(img, dtype=np.uint8)
+    if nump_array.ndim < 3:
+      nump_array = np.expand_dims(nump_array, axis=-1)
+    nump_array = np.rollaxis(nump_array, 2)
+    tensor[i] += torch.from_numpy(nump_array.copy())
+  return tensor, targets
+
+
+# Inspired by
+# github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/Classification/
+# ConvNets/image_classification/dataloaders.py
+class PrefetchedWrapper:
+
+  def __init__(self, dataloader, device, mean, std, start_epoch=0):
+    self.dataloader = dataloader
+    self.epoch = start_epoch
+    self.device = device
+    self.data_mean = torch.tensor([i / 255 for i in mean],
+                                  device=device).view(1, 3, 1, 1)
+    self.data_std = torch.tensor([i / 255 for i in std],
+                                 device=device).view(1, 3, 1, 1)
+
+  def __len__(self):
+    return len(self.dataloader)
+
+  def __iter__(self):
+    if isinstance(self.dataloader.sampler,
+                  (DistributedSampler, DistributedEvalSampler)):
+      self.dataloader.sampler.set_epoch(self.epoch)
+    self.epoch += 1
+    return self.prefetched_loader()
+
+  def prefetched_loader(self):
+    stream = torch.cuda.Stream()
+    first = True
+
+    for next_inputs, next_targets in self.dataloader:
+      with torch.cuda.stream(stream):
+        next_inputs = next_inputs.to(
+            self.device, dtype=torch.float,
+            non_blocking=True).sub(self.data_mean).div(self.data_std)
+        next_targets = next_targets.to(self.device, non_blocking=True)
+
+      if not first:
+        yield inputs, targets
+      else:
+        first = False
+
+      torch.cuda.current_stream().wait_stream(stream)
+      inputs = next_inputs
+      targets = next_targets
+
+    yield inputs, targets
+
+
+# Inspired by github.com/PetrochukM/PyTorch-NLP/blob/master/torchnlp/samplers/
+# distributed_sampler.py
+class TFDistributedSampler:
+
+  def __init__(self, iterator, device='cuda:0', rank=None):
+    self.iterator = iterator
+    self.device = device
+    self.rank = rank
+    if rank is None:
+      if not torch.distributed.is_initialized():
+        raise RuntimeError('Requires `torch.distributed` to be initialized.')
+      self.rank = torch.distributed.get_rank()
+
+  def __iter__(self):
+    return self
+
+  def __next__(self):
+    batch = next(self.iterator)
+    batch = {
+        # Assumes that len(value) > self.rank, i.e. there needs to be data for
+        # each rank/GPU.
+        key: torch.as_tensor(
+            value[self.rank], device=self.device, dtype=torch.int64) for key,
+        value in batch.items()
+    }
+    return batch
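To make the intent of the new helpers concrete, here is a rough usage sketch, not part of the diff, mirroring how the ImageNet PyTorch workload below wires them together. The `from algorithmic_efficiency import data_utils` import path, the `FakeData` stand-in dataset, and the mean/std values are assumptions for illustration, and a CUDA device is required because `PrefetchedWrapper` uses a side CUDA stream.

```python
import torch
import torchvision

from algorithmic_efficiency import data_utils  # assumed import path

# FakeData yields (PIL image, label) pairs, which is the input fast_collate expects.
dataset = torchvision.datasets.FakeData(size=64, image_size=(3, 224, 224))
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=16,
    num_workers=4,
    pin_memory=True,
    collate_fn=data_utils.fast_collate,  # batches stay uint8, NCHW, unnormalized
    drop_last=True)

# Illustrative per-channel statistics; PrefetchedWrapper divides the given
# mean/std by 255 internally and applies sub/div on the GPU while the next
# batch is still loading on a side CUDA stream.
mean, std = (123.7, 116.3, 103.5), (58.4, 57.1, 57.4)
loader = data_utils.PrefetchedWrapper(loader, torch.device('cuda:0'), mean, std)

for inputs, targets in loader:  # inputs: float32 on GPU, targets: int64 on GPU
  pass
```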

algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py
Lines changed: 0 additions & 2 deletions

@@ -214,8 +214,6 @@ def create_split(split,
                  area_range=(0.08, 1.0)):
   """Creates a split from the ImageNet dataset using TensorFlow Datasets."""
   del num_batches
-  if split == 'eval_train':
-    split = 'train[:50000]'
 
   shuffle_rng, preprocess_rng = jax.random.split(rng, 2)
 
algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py
Lines changed: 2 additions & 0 deletions

@@ -58,6 +58,8 @@ def _build_dataset(self,
     ds_builder = tfds.builder('imagenet2012:5.*.*', data_dir=data_dir)
     ds_builder.download_and_prepare()
     train = split == 'train'
+    if split == 'eval_train':
+      split = f'train[:{self.num_eval_train_examples}]'
     ds = input_pipeline.create_input_iter(
         split,
         ds_builder,
algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py
Lines changed: 14 additions & 23 deletions

@@ -22,8 +22,8 @@
 from algorithmic_efficiency.workloads.imagenet_resnet.workload import \
     BaseImagenetResNetWorkload
 
-PYTORCH_DDP = 'LOCAL_RANK' in os.environ
-RANK = int(os.environ['LOCAL_RANK']) if PYTORCH_DDP else 0
+USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ
+RANK = int(os.environ['LOCAL_RANK']) if USE_PYTORCH_DDP else 0
 DEVICE = torch.device(f'cuda:{RANK}' if torch.cuda.is_available() else 'cpu')
 N_GPUS = torch.cuda.device_count()
 
@@ -54,12 +54,7 @@ def build_input_queue(self,
                         split: str,
                         data_dir: str,
                         global_batch_size: int):
-    it = self._build_dataset(data_rng, split, data_dir, global_batch_size)
-    for batch in it:
-      yield {
-          'inputs': batch['inputs'].float().to(DEVICE, non_blocking=True),
-          'targets': batch['targets'].to(DEVICE, non_blocking=True),
-      }
+    return self._build_dataset(data_rng, split, data_dir, global_batch_size)
 
   def _build_dataset(self,
                      data_rng: spec.RandomState,
@@ -69,16 +64,9 @@ def _build_dataset(self,
     del data_rng
     is_train = split == 'train'
 
-    normalize = transforms.Compose([
-        transforms.ToTensor(),
-        transforms.Normalize(
-            mean=[i / 255 for i in self.train_mean],
-            std=[i / 255 for i in self.train_stddev])
-    ])
     eval_transform_config = transforms.Compose([
         transforms.Resize(self.resize_size),
         transforms.CenterCrop(self.center_crop_size),
-        normalize
     ])
     transform_config = {
         'train':
@@ -88,7 +76,6 @@ def _build_dataset(self,
                     scale=self.scale_ratio_range,
                     ratio=self.aspect_ratio_range),
                 transforms.RandomHorizontalFlip(),
-                normalize
             ]),
         'eval_train':
             eval_transform_config,
@@ -108,7 +95,7 @@ def _build_dataset(self,
                                         range(self.num_eval_train_examples))
 
     sampler = None
-    if PYTORCH_DDP:
+    if USE_PYTORCH_DDP:
       if is_train:
         sampler = torch.utils.data.distributed.DistributedSampler(
             dataset, num_replicas=N_GPUS, rank=RANK, shuffle=True)
@@ -119,13 +106,17 @@ def _build_dataset(self,
     dataloader = torch.utils.data.DataLoader(
         dataset,
         batch_size=batch_size,
-        shuffle=not PYTORCH_DDP and is_train,
+        shuffle=not USE_PYTORCH_DDP and is_train,
         sampler=sampler,
-        num_workers=0,
+        num_workers=4,
         pin_memory=True,
+        collate_fn=data_utils.fast_collate,
         drop_last=is_train)
-
-    dataloader = data_utils.cycle(dataloader, custom_sampler=PYTORCH_DDP)
+    dataloader = data_utils.PrefetchedWrapper(dataloader,
+                                              DEVICE,
+                                              self.train_mean,
+                                              self.train_stddev)
+    dataloader = data_utils.cycle(dataloader, custom_sampler=USE_PYTORCH_DDP)
 
     return dataloader
 
@@ -137,7 +128,7 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
     }
     model.to(DEVICE)
     if N_GPUS > 1:
-      if PYTORCH_DDP:
+      if USE_PYTORCH_DDP:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
         model = DDP(model, device_ids=[RANK], output_device=RANK)
       else:
@@ -245,7 +236,7 @@ def _eval_model_on_split(self,
       total_metrics = {
          k: v + batch_metrics[k] for k, v in total_metrics.items()
      }
-    if PYTORCH_DDP:
+    if USE_PYTORCH_DDP:
      for metric in total_metrics.values():
        dist.all_reduce(metric)
    return {k: float(v.item() / num_examples) for k, v in total_metrics.items()}
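The workloads detect DDP purely via the `LOCAL_RANK` environment variable that `torchrun` sets for each worker process. Below is a minimal sketch of the process-group setup this convention implies; in the benchmark it presumably lives in `submission_runner.py`, which is not part of the hunks shown here, so treat the exact placement and the `nccl` backend choice as assumptions.

```python
import os

import torch
import torch.distributed as dist

USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ  # set by torchrun per worker
RANK = int(os.environ['LOCAL_RANK']) if USE_PYTORCH_DDP else 0

if USE_PYTORCH_DDP:
  # One process per GPU: pin this process to its device, then join the
  # default process group (torchrun provides MASTER_ADDR/PORT via env://).
  torch.cuda.set_device(RANK)
  dist.init_process_group(backend='nccl')
```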

algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py
Lines changed: 3 additions & 3 deletions

@@ -17,8 +17,8 @@
 from algorithmic_efficiency.workloads.imagenet_vit.workload import \
     decode_variant
 
-PYTORCH_DDP = 'LOCAL_RANK' in os.environ
-RANK = int(os.environ['LOCAL_RANK']) if PYTORCH_DDP else 0
+USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ
+RANK = int(os.environ['LOCAL_RANK']) if USE_PYTORCH_DDP else 0
 DEVICE = torch.device(f'cuda:{RANK}' if torch.cuda.is_available() else 'cpu')
 N_GPUS = torch.cuda.device_count()
 
@@ -35,7 +35,7 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
     }
     model.to(DEVICE)
     if N_GPUS > 1:
-      if PYTORCH_DDP:
+      if USE_PYTORCH_DDP:
        model = DDP(model, device_ids=[RANK], output_device=RANK)
      else:
        model = torch.nn.DataParallel(model)

algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py
Lines changed: 6 additions & 1 deletion

@@ -1,5 +1,6 @@
 """MNIST workload implemented in Jax."""
 import functools
+import itertools
 from typing import Any, Dict, Tuple
 
 from flax import jax_utils
@@ -102,7 +103,11 @@ def build_input_queue(self,
                         split: str,
                         data_dir: str,
                         global_batch_size: int) -> Dict[str, Any]:
-    return self._build_dataset(data_rng, split, data_dir, global_batch_size)
+    ds = self._build_dataset(data_rng, split, data_dir, global_batch_size)
+    if split != 'train':
+      # Note that this stores the entire eval dataset in memory.
+      ds = itertools.cycle(ds)
+    return ds
 
   def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
     init_val = jnp.ones((1, 28, 28, 1), jnp.float32)
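Why `itertools.cycle` here: a plain eval iterator would be exhausted after the first evaluation, so cycling it makes repeated evals cheap at the cost of keeping every eval batch in memory, as the in-diff comment warns. A tiny self-contained illustration, with a hypothetical stand-in for the tf.data iterator:

```python
import itertools

# Stand-in for the eval iterator; a plain iterator can be consumed only once.
eval_batches = iter([{'inputs': 0}, {'inputs': 1}])
ds = itertools.cycle(eval_batches)

# The cycled iterator keeps serving batches across repeated evaluations.
assert [next(ds)['inputs'] for _ in range(4)] == [0, 1, 0, 1]
```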

algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py
Lines changed: 8 additions & 14 deletions

@@ -17,8 +17,8 @@
 from algorithmic_efficiency import spec
 from algorithmic_efficiency.workloads.mnist.workload import BaseMnistWorkload
 
-PYTORCH_DDP = 'LOCAL_RANK' in os.environ
-RANK = int(os.environ['LOCAL_RANK']) if PYTORCH_DDP else 0
+USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ
+RANK = int(os.environ['LOCAL_RANK']) if USE_PYTORCH_DDP else 0
 DEVICE = torch.device(f'cuda:{RANK}' if torch.cuda.is_available() else 'cpu')
 N_GPUS = torch.cuda.device_count()
 
@@ -78,7 +78,7 @@ def _build_dataset(self,
     is_train = split == 'train'
 
     sampler = None
-    if PYTORCH_DDP:
+    if USE_PYTORCH_DDP:
       if is_train:
         sampler = torch.utils.data.distributed.DistributedSampler(
             dataset, num_replicas=N_GPUS, rank=RANK, shuffle=True)
@@ -89,13 +89,12 @@ def _build_dataset(self,
     dataloader = torch.utils.data.DataLoader(
         dataset,
         batch_size=batch_size,
-        shuffle=not PYTORCH_DDP and is_train,
+        shuffle=not USE_PYTORCH_DDP and is_train,
         sampler=sampler,
         num_workers=0,
         pin_memory=True,
         drop_last=is_train)
-    if is_train:
-      dataloader = data_utils.cycle(dataloader, custom_sampler=PYTORCH_DDP)
+    dataloader = data_utils.cycle(dataloader, custom_sampler=USE_PYTORCH_DDP)
 
     return dataloader
 
@@ -118,14 +117,9 @@ def build_input_queue(self,
                         global_batch_size: int) -> Dict[str, Any]:
     it = self._build_dataset(data_rng, split, data_dir, global_batch_size)
     for batch in it:
-      if isinstance(batch, dict):
-        inputs = batch['inputs']
-        targets = batch['targets']
-      else:
-        inputs, targets = batch
       yield {
-          'inputs': inputs.to(DEVICE, non_blocking=True),
-          'targets': targets.to(DEVICE, non_blocking=True),
+          'inputs': batch['inputs'].to(DEVICE, non_blocking=True),
+          'targets': batch['targets'].to(DEVICE, non_blocking=True),
       }
 
   def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
@@ -136,7 +130,7 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
     }
     model.to(DEVICE)
     if N_GPUS > 1:
-      if PYTORCH_DDP:
+      if USE_PYTORCH_DDP:
        model = DDP(model, device_ids=[RANK], output_device=RANK)
      else:
        model = torch.nn.DataParallel(model)
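A side note on the `custom_sampler` flag: `DistributedSampler` only reshuffles when `set_epoch()` is called, which is what `data_utils.cycle` now does on each `StopIteration`. The snippet below is an illustrative, hand-rolled equivalent, not taken from the diff, showing the behavior being relied on.

```python
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(8).float(), torch.arange(8))
# Explicit num_replicas/rank, so no process group is required for this demo.
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=2, sampler=sampler)

for epoch in range(2):
  # Without set_epoch the shuffled order would repeat every epoch; cycle()
  # calls sampler.set_epoch(...) automatically when the loader is exhausted.
  sampler.set_epoch(epoch)
  for inputs, targets in loader:
    pass
```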
