Skip to content

Commit df90668

Browse files
committed
Add macros for building TLV (de)serializers.
There's quite a bit of machinery included here, but it neatly avoids any dynamic allocation during TLV deserialization, and the calling side looks nice and simple. There's a few new state-tracking read/write streams, but they should be pretty cheap (just a few increments/decrements per read/write. The macro-generated code is pretty nice, though has some redundant if statements (I haven't checked if they get optimized out yet, but I can't imagine they don't).
1 parent 3c9538f commit df90668

File tree

2 files changed

+212
-0
lines changed

2 files changed

+212
-0
lines changed

lightning/src/util/ser.rs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use std::io::{Read, Write};
66
use std::collections::HashMap;
77
use std::hash::Hash;
88
use std::sync::Mutex;
9+
use std::cmp;
910

1011
use secp256k1::Signature;
1112
use secp256k1::key::{PublicKey, SecretKey};
@@ -67,6 +68,46 @@ impl Writer for VecWriter {
6768
}
6869
}
6970

71+
pub(crate) struct LengthCalculatingWriter(pub usize);
72+
impl Writer for LengthCalculatingWriter {
73+
#[inline]
74+
fn write_all(&mut self, buf: &[u8]) -> Result<(), ::std::io::Error> {
75+
self.0 += buf.len();
76+
Ok(())
77+
}
78+
#[inline]
79+
fn size_hint(&mut self, _size: usize) {}
80+
}
81+
82+
/// Essentially std::io::Take but a bit simpler and exposing the amount read at the end, cause we
83+
/// may need to skip ahead that much at the end.
84+
pub(crate) struct FixedLengthReader<R: Read> {
85+
pub read: R,
86+
pub read_len: u64,
87+
pub max_len: u64,
88+
}
89+
impl<R: Read> FixedLengthReader<R> {
90+
pub fn eat_remaining(&mut self) -> Result<(), ::std::io::Error> {
91+
while self.read_len != self.max_len {
92+
debug_assert!(self.read_len < self.max_len);
93+
let mut buf = [0; 1024];
94+
let readsz = cmp::min(1024, self.max_len - self.read_len) as usize;
95+
self.read_exact(&mut buf[0..readsz])?;
96+
}
97+
Ok(())
98+
}
99+
}
100+
impl<R: Read> Read for FixedLengthReader<R> {
101+
fn read(&mut self, dest: &mut [u8]) -> Result<usize, ::std::io::Error> {
102+
if dest.len() as u64 > self.max_len - self.read_len {
103+
Ok(0)
104+
} else {
105+
self.read_len += dest.len() as u64;
106+
self.read.read(dest)
107+
}
108+
}
109+
}
110+
70111
/// A trait that various rust-lightning types implement allowing them to be written out to a Writer
71112
pub trait Writeable {
72113
/// Writes self out to the given Writer

lightning/src/util/ser_macros.rs

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,104 @@
1+
macro_rules! encode_tlv {
2+
($stream: expr, {$(($type: expr, $field: expr)),*}) => { {
3+
use bitcoin::consensus::Encodable;
4+
use bitcoin::consensus::encode::{Error, VarInt};
5+
use util::ser::{WriterWriteAdaptor, LengthCalculatingWriter};
6+
$(
7+
VarInt($type).consensus_encode(WriterWriteAdaptor($stream))
8+
.map_err(|e| if let Error::Io(ioe) = e { ioe } else { unreachable!() })?;
9+
let mut len_calc = LengthCalculatingWriter(0);
10+
$field.write(&mut len_calc)?;
11+
VarInt(len_calc.0 as u64).consensus_encode(WriterWriteAdaptor($stream))
12+
.map_err(|e| if let Error::Io(ioe) = e { ioe } else { unreachable!() })?;
13+
$field.write($stream)?;
14+
)*
15+
} }
16+
}
17+
18+
macro_rules! encode_varint_length_prefixed_tlv {
19+
($stream: expr, {$(($type: expr, $field: expr)),*}) => { {
20+
use bitcoin::consensus::Encodable;
21+
use bitcoin::consensus::encode::{Error, VarInt};
22+
use util::ser::{WriterWriteAdaptor, LengthCalculatingWriter};
23+
let mut len = LengthCalculatingWriter(0);
24+
{
25+
$(
26+
VarInt($type).consensus_encode(WriterWriteAdaptor(&mut len))
27+
.map_err(|e| if let Error::Io(ioe) = e { ioe } else { unreachable!() })?;
28+
let mut field_len = LengthCalculatingWriter(0);
29+
$field.write(&mut field_len)?;
30+
VarInt(field_len.0 as u64).consensus_encode(WriterWriteAdaptor(&mut len))
31+
.map_err(|e| if let Error::Io(ioe) = e { ioe } else { unreachable!() })?;
32+
len.0 += field_len.0;
33+
)*
34+
}
35+
36+
VarInt(len.0 as u64).consensus_encode(WriterWriteAdaptor($stream))
37+
.map_err(|e| if let Error::Io(ioe) = e { ioe } else { unreachable!() })?;
38+
encode_tlv!($stream, {
39+
$(($type, $field)),*
40+
});
41+
} }
42+
}
43+
44+
macro_rules! decode_tlv {
45+
($stream: expr, {$(($reqtype: expr, $reqfield: ident)),*}, {$(($type: expr, $field: ident)),*}) => { {
46+
use ln::msgs::DecodeError;
47+
let mut last_seen_type: Option<u64> = None;
48+
'tlv_read: loop {
49+
use bitcoin::consensus::encode;
50+
use util::ser;
51+
use std;
52+
53+
let typ: encode::VarInt = match encode::Decodable::consensus_decode($stream) {
54+
Err(encode::Error::Io(ref ioe)) if ioe.kind() == std::io::ErrorKind::UnexpectedEof
55+
=> break 'tlv_read,
56+
Err(encode::Error::Io(ioe)) => Err(DecodeError::from(ioe))?,
57+
Err(_) => Err(DecodeError::InvalidValue)?,
58+
Ok(t) => t,
59+
};
60+
61+
match last_seen_type {
62+
Some(t) if typ.0 <= t => {
63+
Err(DecodeError::InvalidValue)?
64+
},
65+
_ => {},
66+
}
67+
$(if (last_seen_type.is_none() || last_seen_type.unwrap() < $reqtype) && typ.0 > $reqtype {
68+
Err(DecodeError::InvalidValue)?
69+
})*
70+
last_seen_type = Some(typ.0);
71+
72+
let length: encode::VarInt = encode::Decodable::consensus_decode($stream)
73+
.map_err(|e| match e {
74+
encode::Error::Io(ioe) => DecodeError::from(ioe),
75+
_ => DecodeError::InvalidValue
76+
})?;
77+
let mut s = ser::FixedLengthReader {
78+
read: $stream,
79+
read_len: 0,
80+
max_len: length.0,
81+
};
82+
match typ.0 {
83+
$($reqtype => {
84+
$reqfield = ser::Readable::read(&mut s)?;
85+
},)*
86+
$($type => {
87+
$field = Some(ser::Readable::read(&mut s)?);
88+
},)*
89+
x if x % 2 == 0 => {
90+
Err(DecodeError::UnknownRequiredFeature)?
91+
},
92+
_ => {},
93+
}
94+
s.eat_remaining().map_err(|_| DecodeError::ShortRead)?;
95+
}
96+
$(if last_seen_type.is_none() || last_seen_type.unwrap() < $reqtype {
97+
Err(DecodeError::InvalidValue)?
98+
})*
99+
} }
100+
}
101+
1102
macro_rules! impl_writeable {
2103
($st:ident, $len: expr, {$($field:ident),*}) => {
3104
impl ::util::ser::Writeable for $st {
@@ -40,3 +141,73 @@ macro_rules! impl_writeable_len_match {
40141
}
41142
}
42143
}
144+
145+
#[cfg(test)]
146+
mod tests {
147+
use std::io::Cursor;
148+
use ln::msgs::DecodeError;
149+
150+
fn tlv_reader(s: &[u8]) -> Result<(u64, u32, Option<u32>), DecodeError> {
151+
let mut s = Cursor::new(s);
152+
let mut a: u64 = 0;
153+
let mut b: u32 = 0;
154+
let mut c: Option<u32> = None;
155+
decode_tlv!(&mut s, {(2, a), (3, b)}, {(4, c)});
156+
Ok((a, b, c))
157+
}
158+
#[test]
159+
fn test_tlv() {
160+
// Value for 3 is longer than we expect, but that's ok...
161+
assert_eq!(tlv_reader(&::hex::decode(
162+
concat!("0100", "0208deadbeef1badbeef", "0308deadbeef1badf00d")
163+
).unwrap()[..]).unwrap(),
164+
(0xdeadbeef1badbeef, 0xdeadbeef, None));
165+
// ...even if there's something afterwards
166+
assert_eq!(tlv_reader(&::hex::decode(
167+
concat!("0100", "0208deadbeef1badbeef", "0308deadbeef1badf00d", "0404ffffffff")
168+
).unwrap()[..]).unwrap(),
169+
(0xdeadbeef1badbeef, 0xdeadbeef, Some(0xffffffff)));
170+
// ...but not if that extra length is missing
171+
if let Err(DecodeError::ShortRead) = tlv_reader(&::hex::decode(
172+
concat!("0100", "0208deadbeef1badbeef", "0308deadbeef")
173+
).unwrap()[..]) {
174+
} else { panic!(); }
175+
176+
// If they're out of order that's also bad
177+
if let Err(DecodeError::InvalidValue) = tlv_reader(&::hex::decode(
178+
concat!("0100", "0304deadbeef", "0208deadbeef1badbeef")
179+
).unwrap()[..]) {
180+
} else { panic!(); }
181+
// ...even if its some field we don't understand
182+
if let Err(DecodeError::InvalidValue) = tlv_reader(&::hex::decode(
183+
concat!("0208deadbeef1badbeef", "0100", "0304deadbeef")
184+
).unwrap()[..]) {
185+
} else { panic!(); }
186+
187+
// It's also bad if they included even fields we don't understand
188+
if let Err(DecodeError::UnknownRequiredFeature) = tlv_reader(&::hex::decode(
189+
concat!("0100", "0208deadbeef1badbeef", "0304deadbeef", "0600")
190+
).unwrap()[..]) {
191+
} else { panic!(); }
192+
// ... or if they're missing fields we need
193+
if let Err(DecodeError::InvalidValue) = tlv_reader(&::hex::decode(
194+
concat!("0100", "0208deadbeef1badbeef")
195+
).unwrap()[..]) {
196+
} else { panic!(); }
197+
// ... even if that field is even
198+
if let Err(DecodeError::InvalidValue) = tlv_reader(&::hex::decode(
199+
concat!("0304deadbeef", "0500")
200+
).unwrap()[..]) {
201+
} else { panic!(); }
202+
203+
// But usually things are pretty much what we expect:
204+
assert_eq!(tlv_reader(&::hex::decode(
205+
concat!("0208deadbeef1badbeef", "03041bad1dea")
206+
).unwrap()[..]).unwrap(),
207+
(0xdeadbeef1badbeef, 0x1bad1dea, None));
208+
assert_eq!(tlv_reader(&::hex::decode(
209+
concat!("0208deadbeef1badbeef", "03041bad1dea", "040401020304")
210+
).unwrap()[..]).unwrap(),
211+
(0xdeadbeef1badbeef, 0x1bad1dea, Some(0x01020304)));
212+
}
213+
}

0 commit comments

Comments
 (0)