Skip to content

Commit 44da562

Browse files
authored
Merge pull request #4 from pytorch/numpy
fix ToTensor to handle numpy
2 parents 63dabca + e659e27 commit 44da562

File tree

2 files changed

+40
-11
lines changed

2 files changed

+40
-11
lines changed

test/cifar.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,35 @@
11
import torch
22
import torchvision.datasets as dset
3+
import torchvision.transforms as transforms
34

4-
print('\n\nCifar 10')
5-
a = dset.CIFAR10(root="abc/def/ghi", download=True)
5+
# print('\n\nCifar 10')
6+
# a = dset.CIFAR10(root="abc/def/ghi", download=True)
67

7-
print(a[3])
8+
# print(a[3])
89

9-
print('\n\nCifar 100')
10-
a = dset.CIFAR100(root="abc/def/ghi", download=True)
10+
# print('\n\nCifar 100')
11+
# a = dset.CIFAR100(root="abc/def/ghi", download=True)
1112

12-
print(a[3])
13+
# print(a[3])
14+
15+
16+
dataset = dset.CIFAR10(root='cifar', download=True, transform=transforms.ToTensor())
17+
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32,
18+
shuffle=True, num_workers=2)
19+
20+
21+
# miter = dataloader.__iter__()
22+
# def getBatch():
23+
# global miter
24+
# try:
25+
# return miter.next()
26+
# except StopIteration:
27+
# miter = dataloader.__iter__()
28+
# return miter.next()
29+
30+
# i=0
31+
# while True:
32+
# print(i)
33+
# img, target = getBatch()
34+
# i+=1
35+

torchvision/transforms.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import math
33
import random
44
from PIL import Image
5+
import numpy as np
56

67

78
class Compose(object):
@@ -16,11 +17,16 @@ def __call__(self, img):
1617

1718
class ToTensor(object):
1819
def __call__(self, pic):
19-
img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
20-
img = img.view(pic.size[0], pic.size[1], 3)
21-
# put it in CHW format
22-
# yikes, this transpose takes 80% of the loading time/CPU
23-
img = img.transpose(0, 2).transpose(1, 2).contiguous()
20+
if isinstance(pic, np.ndarray):
21+
# handle numpy array
22+
img = torch.from_numpy(pic)
23+
else:
24+
# handle PIL Image
25+
img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
26+
img = img.view(pic.size[0], pic.size[1], 3)
27+
# put it in CHW format
28+
# yikes, this transpose takes 80% of the loading time/CPU
29+
img = img.transpose(0, 2).transpose(1, 2).contiguous()
2430
return img.float()
2531

2632
class Normalize(object):

0 commit comments

Comments
 (0)