Commit c68e267

Add convert2file.py, which uses eager execution and multiprocessing to
speed up conversion
1 parent 4ffd852 commit c68e267

File tree

1 file changed: +163 -0 lines

convert2file.py

Lines changed: 163 additions & 0 deletions
@@ -0,0 +1,163 @@
import os
import tensorflow as tf
import torch
import gzip
from functools import partial
from collections import namedtuple
import argparse as ap
import multiprocessing as mp


"""
12+
This file converts tfrecords in deepmind gqn dataset to gzip files. Each tfrecord will be converted
13+
to a single gzip file (561-of-900.tfrecord -> 561-of-900.pt.gz).
14+
15+
Each gzip file contains a list of tuples, where each tuple is of (images, poses)
16+
For example, when converting the shepard_metzler_5_parts dataset with batch_size of 32, the gzip
17+
file contains a list of length 32, each tuple contains images (15,64,64,3) and poses (15,5), where
18+
15 is the sequence length.
19+
20+
In the original implementation, each sequence is converted to a gzip file, this results in more than
21+
800K small files on the disk. Here we choose to pack multiple sequences into one gzip file, thus
22+
avoiding having too many small files. Note that the gqn implementation from wohlert
23+
(https://github.com/wohlert/generative-query-network-pytorch) works with the original version. In
24+
order for it to work with the new format, one can simply change (in wohlert gqn) batch_size to 1
25+
and do a squeeze after the loader.
26+
27+
It is also recommended to remove the first 500 records of both shepard metzler dataset as they
28+
only contain 20 sequences, compared to the last 400 records which contain 2000 sequences.
29+
30+
Example:
31+
convert all records with all sequences in sm5 train (400 records, 2000 seq each)
32+
python convert2file.py ~/gqn_dataset shepard_metzler_5_parts
33+
34+
Convert first 20 records with batch size of 128 in sm5 test
35+
python convert2file.py ~/gqn_dataset shepard_metzler_5_parts -n 20 -b 128 -m test
36+
"""

tf.logging.set_verbosity(tf.logging.ERROR)  # disable annoying logging
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # disable gpu

DatasetInfo = namedtuple('DatasetInfo', ['image_size', 'seq_length'])

all_datasets = dict(
    jaco=DatasetInfo(image_size=64, seq_length=11),
    mazes=DatasetInfo(image_size=84, seq_length=300),
    rooms_free_camera_with_object_rotations=DatasetInfo(image_size=128, seq_length=10),
    rooms_ring_camera=DatasetInfo(image_size=64, seq_length=10),
    rooms_free_camera_no_object_rotations=DatasetInfo(image_size=64, seq_length=10),
    shepard_metzler_5_parts=DatasetInfo(image_size=64, seq_length=15),
    shepard_metzler_7_parts=DatasetInfo(image_size=64, seq_length=15)
)

_pose_dim = 5


def collect_files(path, ext=None, key=None):
    if key is None:
        files = sorted(os.listdir(path))
    else:
        files = sorted(os.listdir(path), key=key)

    if ext is not None:
        files = [f for f in files if os.path.splitext(f)[-1] == ext]

    return [os.path.join(path, fname) for fname in files]


def convert_record(record, info, batch_size=None):
    print(record)

    path, filename = os.path.split(record)
    basename = os.path.splitext(filename)[0]
    scenes = process_record(record, info, batch_size)
    # scenes is a list of tuples (image_seq, pose_seq)
    out = os.path.join(path, f'{basename}.pt.gz')
    save_to_disk(scenes, out)


def save_to_disk(scenes, path):
    with gzip.open(path, 'wb') as f:
        torch.save(scenes, f)


def process_record(record, info, batch_size=None):
    engine = tf.python_io.tf_record_iterator(record)

    scenes = []
    for i, data in enumerate(engine):
        if i == batch_size:
            break
        scene = convert_to_numpy(data, info)
        scenes.append(scene)

    return scenes


def process_images(example, seq_length, image_size):
    """Instantiates the ops used to preprocess the frames data."""
    images = tf.concat(example['frames'], axis=0)
    images = tf.map_fn(tf.image.decode_jpeg, tf.reshape(images, [-1]),
                       dtype=tf.uint8, back_prop=False)
    shape = (image_size, image_size, 3)
    images = tf.reshape(images, (-1, seq_length) + shape)
    return images


def process_poses(example, seq_length):
    """Instantiates the ops used to preprocess the cameras data."""
    poses = example['cameras']
    poses = tf.reshape(poses, (-1, seq_length, _pose_dim))
    return poses


def convert_to_numpy(raw_data, info):
    seq_length = info.seq_length
    image_size = info.image_size

    feature = {'frames': tf.FixedLenFeature(shape=seq_length, dtype=tf.string),
               'cameras': tf.FixedLenFeature(shape=seq_length * _pose_dim, dtype=tf.float32)}
    example = tf.parse_single_example(raw_data, feature)

    images = process_images(example, seq_length, image_size)
    poses = process_poses(example, seq_length)

    return images.numpy().squeeze(), poses.numpy().squeeze()


if __name__ == '__main__':
    tf.enable_eager_execution()
    parser = ap.ArgumentParser(description='Convert gqn tfrecords to gzip files.')
    parser.add_argument('base_dir', nargs=1,
                        help='base directory of gqn dataset')
    parser.add_argument('dataset', nargs=1,
                        help='dataset to convert, e.g. shepard_metzler_5_parts')
    parser.add_argument('-b', '--batch-size', type=int, default=None,
                        help='number of sequences in each output file')
    parser.add_argument('-n', '--first-n', type=int, default=None,
                        help='convert only the first n tfrecords if given')
    parser.add_argument('-m', '--mode', type=str, default='train',
                        help='whether to convert train or test')
    args = parser.parse_args()

    base_dir = os.path.expanduser(args.base_dir[0])
    dataset = args.dataset[0]

    print(f'base_dir: {base_dir}')
    print(f'dataset: {dataset}')

    info = all_datasets[dataset]
    data_dir = os.path.join(base_dir, dataset)
    records = collect_files(os.path.join(data_dir, args.mode), '.tfrecord')

    if args.first_n is not None:
        records = records[:args.first_n]

    num_proc = mp.cpu_count()
    print(f'converting {len(records)} records in {dataset}/{args.mode}, with {num_proc} processes')

    with mp.Pool(processes=num_proc) as pool:
        f = partial(convert_record, info=info, batch_size=args.batch_size)
        pool.map(f, records)

    print('Done')
