Skip to content

Commit a4fd812

Browse files
committed
add gqn_dataset for visualization
1 parent c68e267 commit a4fd812

File tree

1 file changed

+57
-0
lines changed

1 file changed

+57
-0
lines changed

gqn_dataset.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import os
2+
import gzip
3+
import torch
4+
5+
def collect_files(path, ext=None, key=None):
6+
if key is None:
7+
files = sorted(os.listdir(path))
8+
else:
9+
files = sorted(os.listdir(path), key=key)
10+
11+
if ext is not None:
12+
files = [f for f in files if os.path.splitext(f)[-1] == ext]
13+
14+
return [os.path.join(path, fname) for fname in files]
15+
16+
_base_dir = os.path.expanduser('~/Workspace/dataset/gqn_dataset')
17+
18+
19+
class GQNDataset:
20+
def __init__(self, base_dir=_base_dir, scene='shepard_metzler_5_parts',
21+
mode='train', transform=None):
22+
self.base_dir = os.path.expanduser(base_dir)
23+
self.data_dir = os.path.join(self.base_dir, scene, mode)
24+
self.filenames = collect_files(self.data_dir, ext='.gz')
25+
self.transform = transform
26+
27+
def __len__(self):
28+
return len(self.filenames)
29+
30+
def __getitem__(self, i):
31+
filename = self.filenames[i]
32+
33+
with gzip.open(filename, 'rb') as f:
34+
data = torch.load(f)
35+
36+
images_list, poses_list = list(zip(*data))
37+
images_seqs = np.array(images_list)
38+
poses_seqs = np.array(poses_list)
39+
40+
return images_seqs
41+
42+
43+
if __name__ == '__main__':
44+
import matplotlib.pyplot as plt
45+
import numpy as np
46+
47+
ds = GQNDataset(mode='train')
48+
images_list = ds[0]
49+
50+
n = 6
51+
f = plt.figure(figsize=(12, 8))
52+
axes = f.subplots(nrows=n, ncols=1, sharex=True, sharey=True)
53+
for i in range(n):
54+
images = images_list[i]
55+
grid = np.hstack(images[:10])
56+
axes[i].imshow(grid)
57+
plt.show()

0 commit comments

Comments
 (0)