Skip to content

Commit c8d3f17

Browse files
committed
RecordReader for TFRecords
1 parent 9ba59ed commit c8d3f17

File tree

2 files changed

+211
-21
lines changed

2 files changed

+211
-21
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ target
55
travis-ci/travis_rsa
66
**/*.iml
77
.idea
8-
test_resources/io/actual.tfrecord
8+
test_resources/io/actual.tfrecord
9+
test_resources/io/roundtrip.tfrecord

src/io.rs

Lines changed: 209 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
//! A module for reading and writing TFRecords, Tensorflow's preferred on-disk data format.
22
//!
33
//! See the [tensorflow docs](https://www.tensorflow.org/api_guides/python/python_io#tfrecords-format-details) for details of this format.
4+
use byteorder::{ByteOrder, LittleEndian, ReadBytesExt, WriteBytesExt};
45

5-
use self::byteorder::WriteBytesExt;
6-
use byteorder;
76
use crc::crc32;
8-
use std::io;
9-
use std::io::Write;
7+
use std::{
8+
fmt, io,
9+
io::{Read, Write},
10+
};
11+
12+
/// Applies the TFRecord mask to a raw CRC-32C value.
///
/// The mask rotates the checksum right by 15 bits and then adds the constant `0xa282ead8`,
/// wrapping on overflow, i.e. `((crc >> 15) | (crc << 17)) + 0xa282ead8`.
fn mask(crc: u32) -> u32 {
    // For a u32, rotating right by 15 is exactly `(crc >> 15) | (crc << 17)`.
    crc.rotate_right(15).wrapping_add(0xa282_ead8)
}
1015

1116
/// A type for writing bytes in the TFRecords format.
1217
#[derive(Debug)]
@@ -32,30 +37,181 @@ where
3237
uint32 masked_crc32_of_length
3338
byte data[length]
3439
uint32 masked_crc32_of_data
35-
and the records are concatenated together to produce the file. CRCs are described here [1], and the mask of a CRC is
36-
[1] https://en.wikipedia.org/wiki/Cyclic_redundancy_check
40+
and the records are concatenated together to produce the file. CRCs are described here [1],
41+
and the mask of a CRC is :
3742
masked_crc = ((crc >> 15) | (crc << 17)) + 0xa282ead8ul
43+
44+
[1] https://en.wikipedia.org/wiki/Cyclic_redundancy_check
3845
*/
3946
let mut len_bytes = [0u8; 8];
40-
(&mut len_bytes[..]).write_u64::<byteorder::LittleEndian>(bytes.len() as u64)?;
47+
(&mut len_bytes[..]).write_u64::<LittleEndian>(bytes.len() as u64)?;
4148

42-
let masked_len_crc32c = Self::mask(crc32::checksum_castagnoli(&len_bytes));
49+
let masked_len_crc32c = mask(crc32::checksum_castagnoli(&len_bytes));
4350
let mut len_crc32_bytes = [0u8; 4];
44-
(&mut len_crc32_bytes[..]).write_u32::<byteorder::LittleEndian>(masked_len_crc32c)?;
51+
(&mut len_crc32_bytes[..]).write_u32::<LittleEndian>(masked_len_crc32c)?;
4552

46-
let masked_bytes_crc32c = Self::mask(crc32::checksum_castagnoli(&bytes));
53+
let masked_bytes_crc32c = mask(crc32::checksum_castagnoli(&bytes));
4754
let mut bytes_crc32_bytes = [0u8; 4];
48-
(&mut bytes_crc32_bytes[..]).write_u32::<byteorder::LittleEndian>(masked_bytes_crc32c)?;
55+
(&mut bytes_crc32_bytes[..]).write_u32::<LittleEndian>(masked_bytes_crc32c)?;
4956

5057
self.writer.write(&len_bytes)?;
5158
self.writer.write(&len_crc32_bytes)?;
5259
self.writer.write(bytes)?;
5360
self.writer.write(&bytes_crc32_bytes)?;
5461
Ok(())
5562
}
63+
}
64+
65+
#[derive(Debug)]
66+
/// The possible errors from an attempt to read a record.
67+
pub enum RecordReadError {
68+
/// Either the length checksum or the content checksum didn't match - indicates data corruption.
69+
InvalidChecksum,
70+
/// There was an underlying io error; `source` is the original `io::Error`.
71+
IoError { source: io::Error },
72+
/// The supplied buffer was too short for the next record: `needed` is the record length, `had` is the buffer length.
73+
BufferTooShort { needed: u64, had: u64 },
74+
}
75+
impl From<io::Error> for RecordReadError {
76+
fn from(from: io::Error) -> RecordReadError {
77+
RecordReadError::IoError { source: from }
78+
}
79+
}
80+
impl std::error::Error for RecordReadError {
81+
fn description(&self) -> &str {
82+
"There was an error reading the next record."
83+
}
84+
}
85+
86+
impl fmt::Display for RecordReadError {
87+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88+
write!(f, "{:?}", self)
89+
}
90+
}
91+
92+
/// A type for deserializing the TFRecord format: wraps any `Read` and pulls records from it.
93+
#[derive(Debug)]
94+
pub struct RecordReader<R: Read> {
95+
reader: R,
96+
}
97+
98+
impl<R> RecordReader<R>
99+
where
100+
R: Read,
101+
{
102+
/// Construct a new RecordReader from an underlying Read.
103+
pub fn new(reader: R) -> Self {
104+
RecordReader { reader }
105+
}
106+
107+
fn read_next_len(&mut self) -> Result<Option<u64>, RecordReadError> {
108+
let len = match self.reader.read_u64::<LittleEndian>() {
109+
Err(e) => {
110+
if e.kind() == io::ErrorKind::UnexpectedEof {
111+
return Ok(None);
112+
}
113+
return Err(e.into());
114+
}
115+
Ok(val) => val,
116+
};
117+
118+
let mut len_bytes = [0u8; 8];
119+
LittleEndian::write_u64(&mut len_bytes, len);
120+
121+
let expected_len_crc32 = self.reader.read_u32::<LittleEndian>()?;
122+
let actual_len_crc32 = mask(crc32::checksum_castagnoli(&len_bytes));
123+
if expected_len_crc32 != actual_len_crc32 {
124+
return Err(RecordReadError::InvalidChecksum);
125+
}
126+
Ok(Some(len))
127+
}
128+
129+
fn checksum_bytes(&mut self, bytes: &[u8]) -> Result<(), RecordReadError> {
130+
let actual_bytes_crc32 = mask(crc32::checksum_castagnoli(&bytes));
131+
let expected_bytes_crc32 = self.reader.read_u32::<LittleEndian>()?;
132+
if actual_bytes_crc32 != expected_bytes_crc32 {
133+
return Err(RecordReadError::InvalidChecksum);
134+
}
135+
Ok(())
136+
}
137+
fn read_bytes_exact(&mut self, len: u64, buf: &mut [u8]) -> Result<(), RecordReadError> {
138+
if (buf.len() as u64) < len {
139+
return Err(RecordReadError::BufferTooShort {
140+
needed: len,
141+
had: buf.len() as u64,
142+
});
143+
}
144+
self.reader.read_exact(buf)?;
145+
self.checksum_bytes(buf)?;
146+
Ok(())
147+
}
148+
/// Read the next record into a byte slice.
149+
/// Returns the number of bytes read, if successful.
150+
/// Returns None, if it could read exactly 0 bytes (indicating EOF)
151+
/// let file = std::fs::OpenOptions::new().read(true).open(path).unwrap();
152+
/// // Optional, but we probably want to buffer our reads.
153+
/// let file = std::io::BufReader::new(file);
154+
/// let records = RecordReader::new(file);
155+
/// // If we know in advance the maximum length of the records, we can avoid heap allocs
156+
/// let mut buf = [ 0u8; 12000 ];
157+
/// loop {
158+
/// let next = records.read_next(&mut buf);
159+
/// match next {
160+
/// Ok(res) => match res {
161+
/// Some(len) => train(&buf[0..len]),
162+
/// None => break,
163+
/// }
164+
/// Err(e) => { warn!("{:?}", e); break }
165+
/// }
166+
/// }
167+
pub fn read_next(&mut self, buf: &mut [u8]) -> Result<Option<u64>, RecordReadError> {
168+
let len = match self.read_next_len()? {
169+
Some(len) => len,
170+
None => return Ok(None),
171+
};
172+
173+
let slice = &mut buf[0..len as usize];
174+
self.read_bytes_exact(len, slice)?;
175+
176+
Ok(Some(len))
177+
}
178+
/// Allocate a Vec<u8> on the heap and read the next record into it.
179+
/// Returns the filled Vec, if successful.
180+
/// Returns None, if it could read exactly 0 bytes (indicating EOF)
181+
pub fn read_next_owned(&mut self) -> Result<Option<Vec<u8>>, RecordReadError> {
182+
let len = match self.read_next_len()? {
183+
Some(len) => len,
184+
None => return Ok(None),
185+
};
186+
let mut vec = vec![0u8; len as usize];
187+
self.read_bytes_exact(len, &mut vec)?;
188+
189+
Ok(Some(vec))
190+
}
191+
/// Convert the Reader into an Iterator<Item = Result<Vec<u8>, RecordReadError>, which iterates
192+
/// the whole Read.
193+
/// let file = std::fs::OpenOptions::new().read(true).open(path).unwrap();
194+
/// // Optional, but we probably want to buffer our reads from a system file.
195+
/// let file = std::io::BufReader::new(file);
196+
/// for tfrecord in RecordReader::new(file).into_iter_owned() {
197+
/// match tfrecord {
198+
/// Ok(record) => train(record),
199+
/// Err(e) => warn!("Corrupted data?")
200+
/// }
201+
/// }
202+
pub fn into_iter_owned(self) -> impl Iterator<Item = Result<Vec<u8>, RecordReadError>> {
203+
RecordOwnedIterator { records: self }
204+
}
205+
}
206+
207+
/// Iterator adapter that owns a `RecordReader` and hands out each record as an owned `Vec<u8>`.
struct RecordOwnedIterator<R: Read> {
208+
records: RecordReader<R>,
209+
}
56210

57-
fn mask(crc: u32) -> u32 {
58-
((crc >> 15) | (crc << 17)).wrapping_add(0xa282ead8u32)
211+
impl<R: Read> Iterator for RecordOwnedIterator<R> {
212+
type Item = Result<Vec<u8>, RecordReadError>;
213+
fn next(&mut self) -> Option<Self::Item> {
214+
self.records.read_next_owned().transpose()
59215
}
60216
}
61217

@@ -84,14 +240,47 @@ mod tests {
84240
.write_record("The Quick Brown Fox".as_bytes())
85241
.unwrap();
86242
}
243+
{
244+
let mut af = File::open(actual_filename).unwrap();
245+
let mut ef = File::open(expected_filename).unwrap();
246+
let mut actual = vec![0; 0];
247+
let mut expected = vec![0; 0];
248+
af.read_to_end(&mut actual).unwrap();
249+
ef.read_to_end(&mut expected).unwrap();
87250

88-
let mut af = File::open(actual_filename).unwrap();
89-
let mut ef = File::open(expected_filename).unwrap();
90-
let mut actual = vec![0; 0];
91-
let mut expected = vec![0; 0];
92-
af.read_to_end(&mut actual).unwrap();
93-
ef.read_to_end(&mut expected).unwrap();
251+
assert_eq!(actual, expected);
252+
}
253+
254+
let _ = std::fs::remove_file(actual_filename);
255+
}
94256

95-
assert_eq!(actual, expected);
257+
/// Writes a handful of records to disk, reads them back, and checks that both the contents and
/// the number of records survive the round trip.
#[test]
fn read_and_write_roundtrip() {
    let records = vec!["Foo bar baz", "boom bing bang", "sum soup shennaninganner"];
    let path = "test_resources/io/roundtrip.tfrecord";
    // `truncate(true)` guarantees that bytes left behind by an earlier, aborted run can never
    // trail the freshly written records.
    let out = ::std::fs::OpenOptions::new()
        .write(true)
        .create(true)
        .truncate(true)
        .open(path)
        .unwrap();
    {
        let mut writer = RecordWriter::new(out);
        for rec in records.iter() {
            writer.write_record(rec.as_bytes()).unwrap();
        }
    }
    {
        let actual = ::std::fs::OpenOptions::new().read(true).open(path).unwrap();
        let reader = RecordReader::new(actual);
        // Compare whole collections rather than zipping: `zip` stops at the shorter side and
        // would silently pass if the reader produced too few records.
        let actual: Vec<Vec<u8>> = reader.into_iter_owned().map(|rec| rec.unwrap()).collect();
        let expected: Vec<Vec<u8>> = records.iter().map(|rec| rec.as_bytes().to_vec()).collect();
        assert_eq!(actual, expected);
    }
    {
        let actual = ::std::fs::OpenOptions::new().read(true).open(path).unwrap();
        let reader = RecordReader::new(actual);
        assert_eq!(reader.into_iter_owned().count(), 3);
    }
    let _ = std::fs::remove_file(path);
}
97286
}

0 commit comments

Comments
 (0)