Skip to content

Commit 527a920

Browse files
committed
New imagenet data loader that uses the pytorch dataset/dataloader abstractions
1 parent da371e0 commit 527a920

File tree

1 file changed

+84
-0
lines changed

1 file changed

+84
-0
lines changed

imagenet/data2.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import math
2+
import random
3+
4+
import torch
5+
6+
import torch.utils.data as data
7+
8+
import os.path
9+
import torchfile
10+
import numpy
11+
from PIL import Image
12+
13+
###########################################################
14+
# This is the main imagenet loading logic
15+
###########################################################
16+
class ImagenetDataset(data.Dataset):
17+
18+
def __init__(self, path, jitter):
19+
self.path = path
20+
self.data = torchfile.load(path)
21+
self.jitter = jitter
22+
self.res = 256
23+
24+
def __len__(self):
25+
return 1000
26+
#return len(self.data.imagePath)
27+
28+
def __getitem__(self, i):
29+
imagePath = self.data.imagePath[i].tobytes()
30+
try:
31+
# remove the null-terminators
32+
imagePath = imagePath[:imagePath.index('\0')]
33+
except:
34+
pass
35+
pic = Image.open(imagePath)
36+
pic = pic.convert('RGB')
37+
if pic.size[0] > pic.size[1]:
38+
pic.resize((self.res * pic.size[0]/pic.size[1], self.res), Image.BILINEAR)
39+
else:
40+
pic.resize((self.res, self.res * pic.size[1]/pic.size[0]), Image.BILINEAR)
41+
42+
h1 = None
43+
w1 = None
44+
if self.jitter:
45+
# random crop
46+
h1 = math.ceil(random.uniform(1e-2, pic.size[0] - self.res))
47+
w1 = math.ceil(random.uniform(1e-2, pic.size[1] - self.res))
48+
else:
49+
# center crop
50+
w1 = math.ceil(pic.size[0] - self.res)/2
51+
h1 = math.ceil(img.size[1] - self.res)/2
52+
53+
pic = pic.crop((w1, h1, w1 + self.res, h1 + self.res))
54+
55+
if self.jitter and random.uniform(0, 1) > 0.5:
56+
pic = pic.transpose(Image.FLIP_LEFT_RIGHT)
57+
58+
img = torch.ByteTensor(numpy.asarray(pic))
59+
img = img.view(pic.size[0], pic.size[1], 3)
60+
img = img.transpose(0,2).transpose(1,2).contiguous() # put it in CHW format
61+
62+
# lets wait until we have Python bindings for torch.image to do scale/crop
63+
return img, torch.IntTensor(1).fill_(self.data.imageClass[i])
64+
65+
class Foo():
66+
def __del__(self):
67+
print("deleted foo")
68+
69+
# demo
70+
if __name__ == "__main__":
71+
import time
72+
num_workers = 8
73+
dataset = ImagenetDataset('/mnt/vol/gfsai-east/ai-group/datasets/imagenet/trainCache.t7', True)
74+
loader = data.DataLoader(dataset, batch_size=64, shuffle=True, num_workers=num_workers)
75+
76+
start = time.time()
77+
i = 0
78+
a = Foo()
79+
for images, labels in loader:
80+
print("{}/{}, time= {:.04f} s".format(i, len(dataset), time.time() - start))
81+
i += images.size(0)
82+
start = time.time()
83+
84+
print("done")

0 commit comments

Comments
 (0)