Skip to content

Commit da371e0

Browse files
committed
imagenet data loader
1 parent 77a6ec7 commit da371e0

File tree

1 file changed

+159
-0
lines changed

1 file changed

+159
-0
lines changed

imagenet/data.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
import math
2+
import random
3+
4+
import torch
5+
import torch.multiprocessing as multiprocessing
6+
7+
import os.path
8+
import torchfile
9+
import numpy
10+
from PIL import Image
11+
12+
###########################################################
13+
# These widgets go in some dataset library
14+
###########################################################
15+
class Dataset(object):
    """Minimal dataset interface: a sized, randomly-accessible sample store.

    Concrete datasets implement `size` and `get`; the base class only
    declares the contract.
    """

    def size(self):
        """Return the total number of samples."""
        raise NotImplementedError()

    def get(self, i):
        """Return the i-th sample."""
        raise NotImplementedError()
22+
23+
class PermutedDataset(Dataset):
    """View of a dataset visited in a (random or supplied) permuted order.

    Args:
        dataset: wrapped Dataset.
        perm: optional index permutation over [0, dataset.size()); a fresh
            random permutation is drawn when omitted.
    """

    def __init__(self, dataset, perm=None):
        self.dataset = dataset
        # The original used `perm or torch.randperm(...)`, which calls
        # bool() on the tensor and raises for tensors with more than one
        # element; compare against None explicitly instead.
        self.perm = perm if perm is not None else torch.randperm(dataset.size())

    def size(self):
        """Number of samples (unchanged by permutation)."""
        return self.dataset.size()

    def get(self, i):
        """Return the sample at permuted position i."""
        return self.dataset.get(int(self.perm[i]))
34+
35+
class PartitionedDataset(Dataset):
    """Contiguous slice (partition `part` of `nPart`) of a dataset.

    Args:
        dataset: wrapped Dataset.
        part: zero-based partition index.
        nPart: total number of partitions.
    """

    def __init__(self, dataset, part, nPart):
        self.dataset = dataset
        # Floor division: under Python 3, `/` yields floats, which break
        # range() and integer indexing downstream.
        self.start = dataset.size() * part // nPart
        self.end = dataset.size() * (part + 1) // nPart

    def size(self):
        """Number of samples in this partition."""
        return self.end - self.start

    def get(self, i):
        """Return the i-th sample of this partition."""
        return self.dataset.get(self.start + i)
47+
48+
###########################################################
49+
# This is the main imagenet loading logic
50+
###########################################################
51+
class ImagenetDataset(Dataset):
    """Imagenet samples loaded lazily from a Torch7 cache file.

    Each `get` decodes the image at the cached path, scales it so the
    short side equals `self.res`, takes a res x res crop (random when
    jittering, centered otherwise), optionally mirrors it, and returns a
    CHW ByteTensor plus the class label.
    """

    def __init__(self, path, jitter):
        self.path = path
        # NOTE(review): assumes the torchfile cache exposes `imagePath`
        # (null-padded byte strings) and `imageClass` arrays -- confirm
        # against the cache writer.
        self.data = torchfile.load(path)
        self.jitter = jitter  # True => random crop + random horizontal flip
        self.res = 256        # output crop resolution in pixels

    def size(self):
        """Number of cached samples."""
        return len(self.data.imagePath)

    def get(self, i):
        """Load, scale, crop, and (maybe) flip sample i -> (CHW tensor, label)."""
        imagePath = self.data.imagePath[i].tobytes()
        # Paths are fixed-width and null-padded; trim at the first null.
        # (The original searched for the str '\0' in bytes, which raises
        # under Python 3 and was silently swallowed by a bare except.)
        nul = imagePath.find(b'\0')
        if nul != -1:
            imagePath = imagePath[:nul]

        pic = Image.open(imagePath).convert('RGB')
        # Scale so the SHORT side equals self.res. Image.resize returns a
        # new image -- the original discarded the result. Use integer
        # dimensions (/` gives floats under Python 3).
        if pic.size[0] > pic.size[1]:
            pic = pic.resize(
                (self.res * pic.size[0] // pic.size[1], self.res),
                Image.BILINEAR)
        else:
            pic = pic.resize(
                (self.res, self.res * pic.size[1] // pic.size[0]),
                Image.BILINEAR)

        if self.jitter:
            # Random crop offsets. crop() takes (left, upper, right,
            # lower), so the horizontal offset must come from the width
            # (size[0]) -- the original had w1/h1 swapped here.
            w1 = int(math.ceil(random.uniform(1e-2, pic.size[0] - self.res)))
            h1 = int(math.ceil(random.uniform(1e-2, pic.size[1] - self.res)))
        else:
            # Center crop. (The original referenced an undefined `img`
            # on the height line, a NameError on every non-jitter call.)
            w1 = (pic.size[0] - self.res) // 2
            h1 = (pic.size[1] - self.res) // 2

        pic = pic.crop((w1, h1, w1 + self.res, h1 + self.res))

        if self.jitter and random.uniform(0, 1) > 0.5:
            pic = pic.transpose(Image.FLIP_LEFT_RIGHT)

        # PIL -> numpy is HWC with shape (height, width, 3), i.e.
        # (size[1], size[0], 3); the crop is square so this also matches
        # the original's swapped view.
        img = torch.ByteTensor(numpy.asarray(pic))
        img = img.view(pic.size[1], pic.size[0], 3)
        img = img.transpose(0, 2).transpose(1, 2).contiguous()  # HWC -> CHW

        return img, self.data.imageClass[i]
99+
100+
101+
###########################################################
102+
# Where does this widget go?
103+
###########################################################
104+
class MultiQueueIterator(object):
    """Iterator draining a queue fed by N producers.

    Each producer pushes items and finishes with one `sentinel`;
    iteration stops once N sentinels have been consumed.

    Args:
        queue: object with a blocking .get() (e.g. multiprocessing.Queue).
        N: number of producers, i.e. sentinels to await.
        sentinel: end-of-stream marker each producer sends (default None).
    """

    def __init__(self, queue, N, sentinel=None):
        self.queue = queue
        self.N = N              # sentinels expected
        self.i = 0              # sentinels seen so far
        self.sentinel = sentinel

    def __iter__(self):
        return self

    def __next__(self):
        # Block until a real item arrives or every producer has signalled
        # completion.
        while self.i < self.N:
            e = self.queue.get()
            if e == self.sentinel:
                self.i += 1
            else:
                return e
        raise StopIteration()

    # The original defined only `next` (the Python 2 protocol), which
    # breaks for-loops under Python 3; keep it as an alias.
    next = __next__
123+
124+
125+
###########################################################
126+
# Shim that runs in each process
127+
###########################################################
128+
def _dataLoader(queue, dataset):
129+
batchSize = 64
130+
for i in range(0, dataset.size(), batchSize):
131+
batch = [dataset.get(x) for x in range(i, i + batchSize) if x < dataset.size()]
132+
queue.put(zip(*batch))
133+
queue.put(None)
134+
135+
136+
###########################################################
137+
# This is what's called externally
138+
###########################################################
139+
def makeDataIterator(datasetPath, isTest, nProc):
    """Start nProc loader processes and return (dataset, batch iterator).

    Args:
        datasetPath: path to the torchfile imagenet cache.
        isTest: when True, disables jitter (random crop/flip).
        nProc: number of worker processes; each loads one partition.

    Returns:
        (dataset, iterator) -- iterating yields batches until every
        worker has pushed its None sentinel.
    """
    dataset = PermutedDataset(ImagenetDataset(datasetPath, not isTest))
    queue = multiprocessing.Queue()
    # Keep real Process handles: Process(...).start() returns None, so
    # the original built a list of Nones and could never join/monitor
    # its workers.
    processes = []
    for i in range(nProc):
        p = multiprocessing.Process(
            target=_dataLoader,
            args=(queue, PartitionedDataset(dataset, i, nProc)))
        p.start()
        processes.append(p)
    return dataset, MultiQueueIterator(queue, nProc)
145+
146+
# demo
147+
if __name__ == "__main__":
148+
import time
149+
nDonkeys = 8
150+
dataset, dataIterator = makeDataIterator(
151+
'/mnt/vol/gfsai-east/ai-group/datasets/imagenet/trainCache.t7',
152+
False, nDonkeys)
153+
154+
start = time.time()
155+
i = 0
156+
for images, labels in dataIterator:
157+
print("{}/{}, time= {:.04f} s".format(i, dataset.size(), time.time() - start))
158+
i += len(images)
159+
start = time.time()

0 commit comments

Comments
 (0)