Description
🐛 Describe the bug
When using the Arm (aarch64) CPU packages of torch (2.2.1/2.2.2/2.3.0), the time taken by data loading increases intermittently, with large random spikes (like "noise"). Bisecting the nightly builds shows that the regression first appears in torch-2.3.0.dev20240207, while torch-2.3.0.dev20240206 behaves normally.
However, none of the pytorch commits in that window appear to touch data loading, so perhaps the change to how the Arm packages are built (pytorch/builder#1696) caused this issue? I have no idea which underlying dependency is responsible and need help.
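As a first isolation step (not part of the original measurements, just a hedged suggestion): since a build change could have switched the OpenMP runtime the wheel links against, it may be worth ruling out thread oversubscription between the 64 worker processes and the intra-op thread pool each worker inherits. torch.set_num_threads and torch.__config__.parallel_info are existing torch APIs:

import torch

# Hedged diagnostic, not from the original report: pin torch's intra-op
# thread pool before creating the DataLoader, so that 64 workers x N OpenMP
# threads cannot oversubscribe the 64 cores.
torch.set_num_threads(1)
print(torch.get_num_threads())           # should now print 1
print(torch.__config__.parallel_info())  # shows which parallel backend / OpenMP the wheel links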
dataset : imagenet2012
batch_size : 512
num_workers : 64
Pillow : 10.3.0
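Since the spikes happen while waiting for batches, it may also help to time the Pillow decode + transform path on its own, outside the DataLoader, to check whether image decoding itself is the slow part. A rough sketch (the image path below is a placeholder, not from the original setup):

import time
from PIL import Image
import torchvision.transforms as transforms

# Rough timing of decode + transform for a single image, outside the
# DataLoader. The path is a placeholder, not from the original report.
transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
])
start = time.time()
for _ in range(100):
    with Image.open("/data/imagenet2012/train/n01440764/some_image.JPEG") as img:
        transform(img.convert("RGB"))
print(f"Avg decode+transform: {(time.time() - start) / 100:.6f} s")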
Test Results (only the torch build changes)
For dev20240206, time taken per step over 30 steps:
Step 0 : 4.522608757019043
Step 1 : 0.0013849735260009766
Step 2 : 0.24074482917785645
Step 3 : 0.00135040283203125
Step 4 : 0.0012285709381103516
Step 5 : 0.0013418197631835938
... ...
Step 25 : 0.0010256767272949219
Step 26 : 0.001039743423461914
Step 27 : 0.001051187515258789
Step 28 : 0.0011143684387207031
Step 29 : 0.06728792190551758
Step Avg : 0.16862444082895914
For 2.3.0, time taken per step over 30 steps:
Step 0 : 186.62519240379333
Step 1 : 5.451301097869873
Step 2 : 0.0008325576782226562
Step 3 : 0.0007758140563964844
Step 4 : 0.20022034645080566
Step 5 : 0.0008080005645751953
Step 6 : 4.120985746383667
Step 7 : 0.0008554458618164062
Step 8 : 0.0008018016815185547
Step 9 : 0.0008161067962646484
Step 10 : 0.000774383544921875
Step 11 : 0.0008170604705810547
Step 12 : 0.0007812976837158203
Step 13 : 0.0008800029754638672
Step 14 : 8.32148790359497
Step 15 : 0.0008358955383300781
Step 16 : 0.0007970333099365234
Step 17 : 0.0008027553558349609
Step 18 : 0.0007946491241455078
Step 19 : 8.685199737548828
Step 20 : 0.0008993148803710938
Step 21 : 0.0008280277252197266
Step 22 : 0.0007817745208740234
Step 23 : 0.00079345703125
Step 24 : 0.0008122920989990234
Step 25 : 0.0007855892181396484
Step 26 : 0.0007987022399902344
Step 27 : 0.0008292198181152344
Step 28 : 0.0007941722869873047
Step 29 : 0.0008187294006347656
Step Avg : 7.114162381490072
For dev20240207, time taken per step over 30 steps:
Step 0 : 191.67563939094543
Step 1 : 0.0010831356048583984
Step 2 : 17.258564472198486
Step 3 : 0.0008544921875
Step 4 : 0.0008475780487060547
Step 5 : 0.000812530517578125
... ...
Step 11 : 0.0007905960083007812
Step 12 : 0.0008392333984375
Step 13 : 19.682594776153564
Step 14 : 0.00084686279296875
Step 15 : 0.0007865428924560547
... ...
Step 25 : 0.0008213520050048828
Step 26 : 0.0008218288421630859
Step 27 : 0.0008070468902587891
Step 28 : 0.00078582763671875
Step 29 : 0.0007913112640380859
Step Avg : 7.621339146296183
Test Demo
import os
import time

import numpy as np
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms


class MultiEpochsDataLoader(torch.utils.data.DataLoader):
    """DataLoader that keeps its worker processes alive across epochs
    instead of respawning them every epoch."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Temporarily unset the name-mangled "initialized" flag so the
        # batch_sampler can be swapped for a repeating one.
        self._DataLoader__initialized = False
        self.batch_sampler = _RepeatSampler(self.batch_sampler)
        self._DataLoader__initialized = True
        self.iterator = super().__iter__()

    def __len__(self):
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        for _ in range(len(self)):
            yield next(self.iterator)


class _RepeatSampler(object):
    """Sampler that repeats forever, so the underlying iterator (and its
    workers) never terminates."""

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)


def fast_collate(batch):
    """Collate a batch of (PIL image, target) pairs into a uint8 NCHW
    tensor without normalization."""
    images = [img[0] for img in batch]
    targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)
    w = images[0].size[0]
    h = images[0].size[1]
    tensor = torch.zeros((len(images), 3, h, w), dtype=torch.uint8)
    for i, img in enumerate(images):
        np_array = np.asarray(img, dtype=np.uint8)
        if np_array.ndim < 3:
            # Grayscale image: add a channel dimension.
            np_array = np.expand_dims(np_array, axis=-1)
        np_array = np.rollaxis(np_array, 2)  # HWC -> CHW
        tensor[i] += torch.from_numpy(np_array.copy())
    return tensor, targets


def load():
    train_dir = os.path.join("/data/imagenet2012", "train")
    train_dataset = datasets.ImageFolder(
        train_dir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
        ]))
    train_sampler = None
    batch_size = 512
    workers = 64
    train_loader = MultiEpochsDataLoader(
        train_dataset, batch_size=batch_size, shuffle=(train_sampler is None),
        num_workers=workers, pin_memory=True, sampler=train_sampler,
        collate_fn=fast_collate, drop_last=True)
    start_time = time.time()
    previous_timestamp = start_time
    for i, (images, target) in enumerate(train_loader):
        # Time spent waiting for each batch to arrive from the loader.
        print(f"Step {i} : {time.time() - previous_timestamp}")
        previous_timestamp = time.time()
        if i == 29:
            print(f"Step Avg : {(previous_timestamp - start_time) / 30}")
            break


if __name__ == "__main__":
    load()
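For completeness, the same timing could be repeated with the stock torch.utils.data.DataLoader, to check whether the MultiEpochsDataLoader / _RepeatSampler wrapper contributes to the spikes. A minimal hedged variant, reusing fast_collate and the dataset setup from the demo above:

def load_stock():
    # Variant of load() using the stock DataLoader (not part of the original
    # report), to rule out the worker-reuse wrapper as the cause.
    train_dir = os.path.join("/data/imagenet2012", "train")
    train_dataset = datasets.ImageFolder(
        train_dir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
        ]))
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=512, shuffle=True, num_workers=64,
        pin_memory=True, collate_fn=fast_collate, drop_last=True)
    previous_timestamp = time.time()
    for i, (images, target) in enumerate(train_loader):
        print(f"Step {i} : {time.time() - previous_timestamp}")
        previous_timestamp = time.time()
        if i == 29:
            break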
Versions
PyTorch version: 2.3.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: CentOS Linux 7 (AltArch) (aarch64)
GCC version: (GCC) 10.2.1 20210130 (Red Hat 10.2.1-11)
Clang version: Could not collect
CMake version: version 3.29.0
Libc version: glibc-2.17
Python version: 3.8.19 (default, Apr 2 2024, 06:27:46) [GCC 10.2.1 20210130 (Red Hat 10.2.1-11)] (64-bit runtime)
Python platform: Linux-4.18.0-80.7.2.el7.aarch64-aarch64-with-glibc2.17
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: aarch64
Byte Order: Little Endian
CPU(s): 64
On-line CPU(s) list: 0-63
Thread(s) per core: 1
Core(s) per socket: 32
Socket(s): 2
NUMA node(s): 2
Model: 0
CPU max MHz: 2400.0000
CPU min MHz: 2400.0000
BogoMIPS: 200.00
L1d cache: 64K
L1i cache: 64K
L2 cache: 512K
L3 cache: 32768K
NUMA node0 CPU(s): 0-31
NUMA node1 CPU(s): 32-63
Flags: fp asimd evtstrm aes pmull sha1 sha2 crc32 atomics fphp asimdhp cpuid asimdrdm jscvt fcma dcpop asimddp asimdfhm
Versions of relevant libraries:
[pip3] numpy==1.24.4
[pip3] torch==2.3.0
[pip3] torchvision==0.18.0
[conda] Could not collect