1
1
//! A module for writing TFRecords, Tensorflow's preferred on-disk data format.
2
2
//!
3
3
//! See the [tensorflow docs](https://www.tensorflow.org/api_guides/python/python_io#tfrecords-format-details) for details of this format.
4
+ use byteorder:: { ByteOrder , LittleEndian , ReadBytesExt , WriteBytesExt } ;
4
5
5
- use self :: byteorder:: WriteBytesExt ;
6
- use byteorder;
7
6
use crc:: crc32;
8
- use std:: io;
9
- use std:: io:: Write ;
7
+ use std:: {
8
+ fmt, io,
9
+ io:: { Read , Write } ,
10
+ } ;
11
+
12
+ fn mask ( crc : u32 ) -> u32 {
13
+ ( ( crc >> 15 ) | ( crc << 17 ) ) . wrapping_add ( 0xa282ead8u32 )
14
+ }
10
15
11
16
/// A type for writing bytes in the TFRecords format.
12
17
#[ derive( Debug ) ]
@@ -32,30 +37,181 @@ where
32
37
uint32 masked_crc32_of_length
33
38
byte data[length]
34
39
uint32 masked_crc32_of_data
35
- and the records are concatenated together to produce the file. CRCs are described here [1], and the mask of a CRC is
36
- [1] https://en.wikipedia.org/wiki/Cyclic_redundancy_check
40
+ and the records are concatenated together to produce the file. CRCs are described here [1],
41
+ and the mask of a CRC is :
37
42
masked_crc = ((crc >> 15) | (crc << 17)) + 0xa282ead8ul
43
+
44
+ [1] https://en.wikipedia.org/wiki/Cyclic_redundancy_check
38
45
*/
39
46
let mut len_bytes = [ 0u8 ; 8 ] ;
40
- ( & mut len_bytes[ ..] ) . write_u64 :: < byteorder :: LittleEndian > ( bytes. len ( ) as u64 ) ?;
47
+ ( & mut len_bytes[ ..] ) . write_u64 :: < LittleEndian > ( bytes. len ( ) as u64 ) ?;
41
48
42
- let masked_len_crc32c = Self :: mask ( crc32:: checksum_castagnoli ( & len_bytes) ) ;
49
+ let masked_len_crc32c = mask ( crc32:: checksum_castagnoli ( & len_bytes) ) ;
43
50
let mut len_crc32_bytes = [ 0u8 ; 4 ] ;
44
- ( & mut len_crc32_bytes[ ..] ) . write_u32 :: < byteorder :: LittleEndian > ( masked_len_crc32c) ?;
51
+ ( & mut len_crc32_bytes[ ..] ) . write_u32 :: < LittleEndian > ( masked_len_crc32c) ?;
45
52
46
- let masked_bytes_crc32c = Self :: mask ( crc32:: checksum_castagnoli ( & bytes) ) ;
53
+ let masked_bytes_crc32c = mask ( crc32:: checksum_castagnoli ( & bytes) ) ;
47
54
let mut bytes_crc32_bytes = [ 0u8 ; 4 ] ;
48
- ( & mut bytes_crc32_bytes[ ..] ) . write_u32 :: < byteorder :: LittleEndian > ( masked_bytes_crc32c) ?;
55
+ ( & mut bytes_crc32_bytes[ ..] ) . write_u32 :: < LittleEndian > ( masked_bytes_crc32c) ?;
49
56
50
57
self . writer . write ( & len_bytes) ?;
51
58
self . writer . write ( & len_crc32_bytes) ?;
52
59
self . writer . write ( bytes) ?;
53
60
self . writer . write ( & bytes_crc32_bytes) ?;
54
61
Ok ( ( ) )
55
62
}
63
+ }
64
+
65
+ #[ derive( Debug ) ]
66
+ /// The possible errors from a record read attempt
67
+ pub enum RecordReadError {
68
+ /// Either the length of the content checksum didn't match - indicates data corruption
69
+ InvalidChecksum ,
70
+ /// There was an underlying io error
71
+ IoError { source : io:: Error } ,
72
+ /// The supplied buffer was too short to contain the next record
73
+ BufferTooShort { needed : u64 , had : u64 } ,
74
+ }
75
+ impl From < io:: Error > for RecordReadError {
76
+ fn from ( from : io:: Error ) -> RecordReadError {
77
+ RecordReadError :: IoError { source : from }
78
+ }
79
+ }
80
+ impl std:: error:: Error for RecordReadError {
81
+ fn description ( & self ) -> & str {
82
+ "There was an error reading the next record."
83
+ }
84
+ }
85
+
86
+ impl fmt:: Display for RecordReadError {
87
+ fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
88
+ write ! ( f, "{:?}" , self )
89
+ }
90
+ }
91
+
92
+ /// A type for deserializing TFRecord formats
93
+ #[ derive( Debug ) ]
94
+ pub struct RecordReader < R : Read > {
95
+ reader : R ,
96
+ }
97
+
98
+ impl < R > RecordReader < R >
99
+ where
100
+ R : Read ,
101
+ {
102
+ /// Construct a new RecordReader from an underlying Read.
103
+ pub fn new ( reader : R ) -> Self {
104
+ RecordReader { reader }
105
+ }
106
+
107
+ fn read_next_len ( & mut self ) -> Result < Option < u64 > , RecordReadError > {
108
+ let len = match self . reader . read_u64 :: < LittleEndian > ( ) {
109
+ Err ( e) => {
110
+ if e. kind ( ) == io:: ErrorKind :: UnexpectedEof {
111
+ return Ok ( None ) ;
112
+ }
113
+ return Err ( e. into ( ) ) ;
114
+ }
115
+ Ok ( val) => val,
116
+ } ;
117
+
118
+ let mut len_bytes = [ 0u8 ; 8 ] ;
119
+ LittleEndian :: write_u64 ( & mut len_bytes, len) ;
120
+
121
+ let expected_len_crc32 = self . reader . read_u32 :: < LittleEndian > ( ) ?;
122
+ let actual_len_crc32 = mask ( crc32:: checksum_castagnoli ( & len_bytes) ) ;
123
+ if expected_len_crc32 != actual_len_crc32 {
124
+ return Err ( RecordReadError :: InvalidChecksum ) ;
125
+ }
126
+ Ok ( Some ( len) )
127
+ }
128
+
129
+ fn checksum_bytes ( & mut self , bytes : & [ u8 ] ) -> Result < ( ) , RecordReadError > {
130
+ let actual_bytes_crc32 = mask ( crc32:: checksum_castagnoli ( & bytes) ) ;
131
+ let expected_bytes_crc32 = self . reader . read_u32 :: < LittleEndian > ( ) ?;
132
+ if actual_bytes_crc32 != expected_bytes_crc32 {
133
+ return Err ( RecordReadError :: InvalidChecksum ) ;
134
+ }
135
+ Ok ( ( ) )
136
+ }
137
+ fn read_bytes_exact ( & mut self , len : u64 , buf : & mut [ u8 ] ) -> Result < ( ) , RecordReadError > {
138
+ if ( buf. len ( ) as u64 ) < len {
139
+ return Err ( RecordReadError :: BufferTooShort {
140
+ needed : len,
141
+ had : buf. len ( ) as u64 ,
142
+ } ) ;
143
+ }
144
+ self . reader . read_exact ( buf) ?;
145
+ self . checksum_bytes ( buf) ?;
146
+ Ok ( ( ) )
147
+ }
148
+ /// Read the next record into a byte slice.
149
+ /// Returns the number of bytes read, if successful.
150
+ /// Returns None, if it could read exactly 0 bytes (indicating EOF)
151
+ /// let file = std::fs::OpenOptions::new().read(true).open(path).unwrap();
152
+ /// // Optional, but we probably want to buffer our reads.
153
+ /// let file = std::io::BufReader::new(file);
154
+ /// let records = RecordReader::new(file);
155
+ /// // If we know in advance the maximum length of the records, we can avoid heap allocs
156
+ /// let mut buf = [ 0u8; 12000 ];
157
+ /// loop {
158
+ /// let next = records.read_next(&mut buf);
159
+ /// match next {
160
+ /// Ok(res) => match res {
161
+ /// Some(len) => train(&buf[0..len]),
162
+ /// None => break,
163
+ /// }
164
+ /// Err(e) => { warn!("{:?}", e); break }
165
+ /// }
166
+ /// }
167
+ pub fn read_next ( & mut self , buf : & mut [ u8 ] ) -> Result < Option < u64 > , RecordReadError > {
168
+ let len = match self . read_next_len ( ) ? {
169
+ Some ( len) => len,
170
+ None => return Ok ( None ) ,
171
+ } ;
172
+
173
+ let slice = & mut buf[ 0 ..len as usize ] ;
174
+ self . read_bytes_exact ( len, slice) ?;
175
+
176
+ Ok ( Some ( len) )
177
+ }
178
+ /// Allocate a Vec<u8> on the heap and read the next record into it.
179
+ /// Returns the filled Vec, if successful.
180
+ /// Returns None, if it could read exactly 0 bytes (indicating EOF)
181
+ pub fn read_next_owned ( & mut self ) -> Result < Option < Vec < u8 > > , RecordReadError > {
182
+ let len = match self . read_next_len ( ) ? {
183
+ Some ( len) => len,
184
+ None => return Ok ( None ) ,
185
+ } ;
186
+ let mut vec = vec ! [ 0u8 ; len as usize ] ;
187
+ self . read_bytes_exact ( len, & mut vec) ?;
188
+
189
+ Ok ( Some ( vec) )
190
+ }
191
+ /// Convert the Reader into an Iterator<Item = Result<Vec<u8>, RecordReadError>, which iterates
192
+ /// the whole Read.
193
+ /// let file = std::fs::OpenOptions::new().read(true).open(path).unwrap();
194
+ /// // Optional, but we probably want to buffer our reads from a system file.
195
+ /// let file = std::io::BufReader::new(file);
196
+ /// for tfrecord in RecordReader::new(file).into_iter_owned() {
197
+ /// match tfrecord {
198
+ /// Ok(record) => train(record),
199
+ /// Err(e) => warn!("Corrupted data?")
200
+ /// }
201
+ /// }
202
+ pub fn into_iter_owned ( self ) -> impl Iterator < Item = Result < Vec < u8 > , RecordReadError > > {
203
+ RecordOwnedIterator { records : self }
204
+ }
205
+ }
206
+
207
+ struct RecordOwnedIterator < R : Read > {
208
+ records : RecordReader < R > ,
209
+ }
56
210
57
- fn mask ( crc : u32 ) -> u32 {
58
- ( ( crc >> 15 ) | ( crc << 17 ) ) . wrapping_add ( 0xa282ead8u32 )
211
+ impl < R : Read > Iterator for RecordOwnedIterator < R > {
212
+ type Item = Result < Vec < u8 > , RecordReadError > ;
213
+ fn next ( & mut self ) -> Option < Self :: Item > {
214
+ self . records . read_next_owned ( ) . transpose ( )
59
215
}
60
216
}
61
217
@@ -84,14 +240,47 @@ mod tests {
84
240
. write_record ( "The Quick Brown Fox" . as_bytes ( ) )
85
241
. unwrap ( ) ;
86
242
}
243
+ {
244
+ let mut af = File :: open ( actual_filename) . unwrap ( ) ;
245
+ let mut ef = File :: open ( expected_filename) . unwrap ( ) ;
246
+ let mut actual = vec ! [ 0 ; 0 ] ;
247
+ let mut expected = vec ! [ 0 ; 0 ] ;
248
+ af. read_to_end ( & mut actual) . unwrap ( ) ;
249
+ ef. read_to_end ( & mut expected) . unwrap ( ) ;
87
250
88
- let mut af = File :: open ( actual_filename) . unwrap ( ) ;
89
- let mut ef = File :: open ( expected_filename) . unwrap ( ) ;
90
- let mut actual = vec ! [ 0 ; 0 ] ;
91
- let mut expected = vec ! [ 0 ; 0 ] ;
92
- af. read_to_end ( & mut actual) . unwrap ( ) ;
93
- ef. read_to_end ( & mut expected) . unwrap ( ) ;
251
+ assert_eq ! ( actual, expected) ;
252
+ }
253
+
254
+ let _ = std:: fs:: remove_file ( actual_filename) ;
255
+ }
94
256
95
- assert_eq ! ( actual, expected) ;
257
+ #[ test]
258
+ fn read_and_write_roundtrip ( ) {
259
+ let records = vec ! [ "Foo bar baz" , "boom bing bang" , "sum soup shennaninganner" ] ;
260
+ let path = "test_resources/io/roundtrip.tfrecord" ;
261
+ let out = :: std:: fs:: OpenOptions :: new ( )
262
+ . write ( true )
263
+ . create ( true )
264
+ . open ( path)
265
+ . unwrap ( ) ;
266
+ {
267
+ let mut writer = RecordWriter :: new ( out) ;
268
+ for rec in records. iter ( ) {
269
+ writer. write_record ( rec. as_bytes ( ) ) . unwrap ( ) ;
270
+ }
271
+ }
272
+ {
273
+ let actual = :: std:: fs:: OpenOptions :: new ( ) . read ( true ) . open ( path) . unwrap ( ) ;
274
+ let reader = RecordReader :: new ( actual) ;
275
+ for ( actual, expected) in reader. into_iter_owned ( ) . zip ( records) {
276
+ assert_eq ! ( actual. unwrap( ) , expected. as_bytes( ) ) ;
277
+ }
278
+ }
279
+ {
280
+ let actual = :: std:: fs:: OpenOptions :: new ( ) . read ( true ) . open ( path) . unwrap ( ) ;
281
+ let reader = RecordReader :: new ( actual) ;
282
+ assert_eq ! ( reader. into_iter_owned( ) . count( ) , 3 ) ;
283
+ }
284
+ let _ = std:: fs:: remove_file ( path) ;
96
285
}
97
286
}
0 commit comments