import os
import tensorflow as tf
import torch
import gzip
from functools import partial
from collections import namedtuple
import argparse as ap
import multiprocessing as mp


"""
This script converts the tfrecords of the DeepMind GQN dataset to gzip files. Each tfrecord is
converted to a single gzip file (561-of-900.tfrecord -> 561-of-900.pt.gz).

Each gzip file contains a list of tuples, where each tuple is (images, poses).
For example, when converting the shepard_metzler_5_parts dataset with a batch_size of 32, the
gzip file contains a list of length 32; each tuple holds images of shape (15, 64, 64, 3) and
poses of shape (15, 5), where 15 is the sequence length.
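
For illustration, a converted file can be read back like this (the filename is only an example;
the tuples contain numpy arrays):

    with gzip.open('561-of-900.pt.gz', 'rb') as f:
        scenes = torch.load(f)
    images, poses = scenes[0]  # (15, 64, 64, 3) and (15, 5) for shepard_metzler_5_parts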

In the original implementation, each sequence is converted to its own gzip file, which results in
more than 800K small files on disk. Here we instead pack multiple sequences into one gzip file,
avoiding an excessive number of small files. Note that the GQN implementation from wohlert
(https://github.com/wohlert/generative-query-network-pytorch) works with the original format. To
make it work with the new format, simply change (in the wohlert gqn) batch_size to 1 and squeeze
the batch dimension after the loader, as sketched below.
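
A minimal sketch of that change (`dataset` stands for whatever torch Dataset the wohlert code
builds over these files; since each file already holds a full batch, the loader's own batch
dimension is a dummy one):

    loader = torch.utils.data.DataLoader(dataset, batch_size=1)
    for images, poses in loader:
        images, poses = images.squeeze(0), poses.squeeze(0)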

It is also recommended to drop the first 500 records of both Shepard-Metzler datasets, as they
contain only 20 sequences each, compared to the last 400 records, which contain 2000 sequences each.
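One way to do so (a sketch, assuming the sorted filenames come out in record order) is to slice
the record list before converting:

    records = collect_files(data_dir, '.tfrecord')[500:]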

Example:
    Convert all records with all sequences in sm5 train (400 records, 2000 sequences each):
        python convert2file.py ~/gqn_dataset shepard_metzler_5_parts

    Convert the first 20 records with a batch size of 128 in sm5 test:
        python convert2file.py ~/gqn_dataset shepard_metzler_5_parts -n 20 -b 128 -m test
"""

tf.logging.set_verbosity(tf.logging.ERROR)  # disable annoying logging
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"   # disable gpu

DatasetInfo = namedtuple('DatasetInfo', ['image_size', 'seq_length'])

all_datasets = dict(
    jaco=DatasetInfo(image_size=64, seq_length=11),
    mazes=DatasetInfo(image_size=84, seq_length=300),
    rooms_free_camera_with_object_rotations=DatasetInfo(image_size=128, seq_length=10),
    rooms_ring_camera=DatasetInfo(image_size=64, seq_length=10),
    rooms_free_camera_no_object_rotations=DatasetInfo(image_size=64, seq_length=10),
    shepard_metzler_5_parts=DatasetInfo(image_size=64, seq_length=15),
    shepard_metzler_7_parts=DatasetInfo(image_size=64, seq_length=15)
)

_pose_dim = 5


def collect_files(path, ext=None, key=None):
    """Return the sorted full paths of files under `path`, optionally filtered by extension."""
    if key is None:
        files = sorted(os.listdir(path))
    else:
        files = sorted(os.listdir(path), key=key)

    if ext is not None:
        files = [f for f in files if os.path.splitext(f)[-1] == ext]

    return [os.path.join(path, fname) for fname in files]


def convert_record(record, info, batch_size=None):
    print(record)

    path, filename = os.path.split(record)
    basename = os.path.splitext(filename)[0]
    scenes = process_record(record, info, batch_size)
    # scenes is a list of tuples (image_seq, pose_seq)
    out = os.path.join(path, f'{basename}.pt.gz')
    save_to_disk(scenes, out)


def save_to_disk(scenes, path):
    with gzip.open(path, 'wb') as f:
        torch.save(scenes, f)


def process_record(record, info, batch_size=None):
    engine = tf.python_io.tf_record_iterator(record)

    scenes = []
    for i, data in enumerate(engine):
        if i == batch_size:  # with batch_size=None, every sequence in the record is converted
            break
        scene = convert_to_numpy(data, info)
        scenes.append(scene)

    return scenes


def process_images(example, seq_length, image_size):
    """Instantiates the ops used to preprocess the frames data."""
    images = tf.concat(example['frames'], axis=0)
    images = tf.map_fn(tf.image.decode_jpeg, tf.reshape(images, [-1]),
                       dtype=tf.uint8, back_prop=False)
    shape = (image_size, image_size, 3)
    images = tf.reshape(images, (-1, seq_length) + shape)
    return images


def process_poses(example, seq_length):
    """Instantiates the ops used to preprocess the cameras data."""
    poses = example['cameras']
    poses = tf.reshape(poses, (-1, seq_length, _pose_dim))
    return poses


def convert_to_numpy(raw_data, info):
    seq_length = info.seq_length
    image_size = info.image_size

    feature = {'frames': tf.FixedLenFeature(shape=seq_length, dtype=tf.string),
               'cameras': tf.FixedLenFeature(shape=seq_length * _pose_dim, dtype=tf.float32)}
    example = tf.parse_single_example(raw_data, feature)

    images = process_images(example, seq_length, image_size)
    poses = process_poses(example, seq_length)

    # eager execution lets us pull the tensors straight into numpy
    return images.numpy().squeeze(), poses.numpy().squeeze()


if __name__ == '__main__':
    tf.enable_eager_execution()
    parser = ap.ArgumentParser(description='Convert gqn tfrecords to gzip files.')
    parser.add_argument('base_dir', nargs=1,
                        help='base directory of gqn dataset')
    parser.add_argument('dataset', nargs=1,
                        help='dataset to convert, e.g. shepard_metzler_5_parts')
    parser.add_argument('-b', '--batch-size', type=int, default=None,
                        help='number of sequences in each output file (default: all)')
    parser.add_argument('-n', '--first-n', type=int, default=None,
                        help='convert only the first n tfrecords if given')
    parser.add_argument('-m', '--mode', type=str, default='train',
                        help='whether to convert train or test')
    args = parser.parse_args()

    base_dir = os.path.expanduser(args.base_dir[0])
    dataset = args.dataset[0]

    print(f'base_dir: {base_dir}')
    print(f'dataset: {dataset}')

    info = all_datasets[dataset]
    data_dir = os.path.join(base_dir, dataset)
    records = collect_files(os.path.join(data_dir, args.mode), '.tfrecord')

    if args.first_n is not None:
        records = records[:args.first_n]

    num_proc = mp.cpu_count()
    print(f'converting {len(records)} records in {dataset}/{args.mode}, with {num_proc} processes')

    with mp.Pool(processes=num_proc) as pool:
        f = partial(convert_record, info=info, batch_size=args.batch_size)
        pool.map(f, records)

    print('Done')