Skip to content

Commit 2c04a2e

Browse files
committed
Add macros for building TLV (de)serializers.
There's quite a bit of machinery included here, but it neatly avoids any dynamic allocation during TLV deserialization, and the calling side looks nice and simple. There's a few new state-tracking read/write streams, but they should be pretty cheap (just a few increments/decrements per read/write. The macro-generated code is pretty nice, though has some redundant if statements (I haven't checked if they get optimized out yet, but I can't imagine they don't).
1 parent 5c5b074 commit 2c04a2e

File tree

2 files changed

+199
-0
lines changed

2 files changed

+199
-0
lines changed

lightning/src/util/ser.rs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use std::io::{Read, Write};
66
use std::collections::HashMap;
77
use std::hash::Hash;
88
use std::sync::Mutex;
9+
use std::cmp;
910

1011
use secp256k1::Signature;
1112
use secp256k1::key::{PublicKey, SecretKey};
@@ -67,6 +68,46 @@ impl Writer for VecWriter {
6768
}
6869
}
6970

71+
pub(crate) struct LengthCalculatingWriter(pub usize);
72+
impl Writer for LengthCalculatingWriter {
73+
#[inline]
74+
fn write_all(&mut self, buf: &[u8]) -> Result<(), ::std::io::Error> {
75+
self.0 += buf.len();
76+
Ok(())
77+
}
78+
#[inline]
79+
fn size_hint(&mut self, _size: usize) {}
80+
}
81+
82+
/// Essentially std::io::Take but a bit simpler and exposing the amount read at the end, cause we
83+
/// may need to skip ahead that much at the end.
84+
pub(crate) struct FixedLengthReader<R: Read> {
85+
pub read: R,
86+
pub read_len: u64,
87+
pub max_len: u64,
88+
}
89+
impl<R: Read> FixedLengthReader<R> {
90+
pub fn eat_remaining(&mut self) -> Result<(), ::std::io::Error> {
91+
while self.read_len != self.max_len {
92+
debug_assert!(self.read_len < self.max_len);
93+
let mut buf = [0; 1024];
94+
let readsz = cmp::min(1024, self.max_len - self.read_len) as usize;
95+
self.read_exact(&mut buf[0..readsz])?;
96+
}
97+
Ok(())
98+
}
99+
}
100+
impl<R: Read> Read for FixedLengthReader<R> {
101+
fn read(&mut self, dest: &mut [u8]) -> Result<usize, ::std::io::Error> {
102+
if dest.len() as u64 > self.max_len - self.read_len {
103+
Ok(0)
104+
} else {
105+
self.read_len += dest.len() as u64;
106+
self.read.read(dest)
107+
}
108+
}
109+
}
110+
70111
/// A trait that various rust-lightning types implement allowing them to be written out to a Writer
71112
pub trait Writeable {
72113
/// Writes self out to the given Writer

lightning/src/util/ser_macros.rs

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,91 @@
1+
macro_rules! encode_tlv {
2+
($stream: expr, {$(($type: expr, $field: expr)),*}) => { {
3+
use bitcoin::consensus::Encodable;
4+
use bitcoin::consensus::encode::{Error, VarInt};
5+
use util::ser::{WriterWriteAdaptor, LengthCalculatingWriter};
6+
$(
7+
VarInt($type).consensus_encode(WriterWriteAdaptor($stream))
8+
.map_err(|e| if let Error::Io(ioe) = e { ioe } else { unreachable!() })?;
9+
let mut len_calc = LengthCalculatingWriter(0);
10+
$field.write(&mut len_calc)?;
11+
VarInt(len_calc.0 as u64).consensus_encode(WriterWriteAdaptor($stream))
12+
.map_err(|e| if let Error::Io(ioe) = e { ioe } else { unreachable!() })?;
13+
$field.write($stream)?;
14+
)*
15+
} }
16+
}
17+
18+
macro_rules! encode_varint_length_prefixed_tlv {
19+
($stream: expr, {$(($type: expr, $field: expr)),*}) => { {
20+
use bitcoin::consensus::Encodable;
21+
use bitcoin::consensus::encode::{Error, VarInt};
22+
use util::ser::{WriterWriteAdaptor, LengthCalculatingWriter};
23+
let mut len = LengthCalculatingWriter(0);
24+
encode_tlv!(&mut len, {
25+
$(($type, $field)),*
26+
});
27+
VarInt(len.0 as u64).consensus_encode(WriterWriteAdaptor($stream))
28+
.map_err(|e| if let Error::Io(ioe) = e { ioe } else { unreachable!() })?;
29+
encode_tlv!($stream, {
30+
$(($type, $field)),*
31+
});
32+
} }
33+
}
34+
35+
macro_rules! decode_tlv {
36+
($stream: expr, {$(($reqtype: expr, $reqfield: ident)),*}, {$(($type: expr, $field: ident)),*}) => { {
37+
use ln::msgs::DecodeError;
38+
let mut max_type: u64 = 0;
39+
'tlv_read: loop {
40+
use bitcoin::consensus::encode;
41+
use util::ser;
42+
use std;
43+
44+
let typ: encode::VarInt = match encode::Decodable::consensus_decode($stream) {
45+
Err(encode::Error::Io(ref ioe)) if ioe.kind() == std::io::ErrorKind::UnexpectedEof
46+
=> break 'tlv_read,
47+
Err(encode::Error::Io(ioe)) => Err(DecodeError::from(ioe))?,
48+
Err(_) => Err(DecodeError::InvalidValue)?,
49+
Ok(t) => t,
50+
};
51+
if typ.0 == std::u64::MAX || typ.0 + 1 <= max_type {
52+
Err(DecodeError::InvalidValue)?
53+
}
54+
$(if max_type < $reqtype + 1 && typ.0 > $reqtype {
55+
Err(DecodeError::InvalidValue)?
56+
})*
57+
max_type = typ.0 + 1;
58+
59+
let length: encode::VarInt = encode::Decodable::consensus_decode($stream)
60+
.map_err(|e| match e {
61+
encode::Error::Io(ioe) => DecodeError::from(ioe),
62+
_ => DecodeError::InvalidValue
63+
})?;
64+
let mut s = ser::FixedLengthReader {
65+
read: $stream,
66+
read_len: 0,
67+
max_len: length.0,
68+
};
69+
match typ.0 {
70+
$($reqtype => {
71+
$reqfield = ser::Readable::read(&mut s)?;
72+
},)*
73+
$($type => {
74+
$field = Some(ser::Readable::read(&mut s)?);
75+
},)*
76+
x if x % 2 == 0 => {
77+
Err(DecodeError::UnknownRequiredFeature)?
78+
},
79+
_ => {},
80+
}
81+
s.eat_remaining().map_err(|_| DecodeError::ShortRead)?;
82+
}
83+
$(if max_type < $reqtype + 1 {
84+
Err(DecodeError::InvalidValue)?
85+
})*
86+
} }
87+
}
88+
189
macro_rules! impl_writeable {
290
($st:ident, $len: expr, {$($field:ident),*}) => {
391
impl ::util::ser::Writeable for $st {
@@ -40,3 +128,73 @@ macro_rules! impl_writeable_len_match {
40128
}
41129
}
42130
}
131+
132+
#[cfg(test)]
133+
mod tests {
134+
use std::io::Cursor;
135+
use ln::msgs::DecodeError;
136+
137+
fn tlv_reader(s: &[u8]) -> Result<(u64, u32, Option<u32>), DecodeError> {
138+
let mut s = Cursor::new(s);
139+
let mut a: u64 = 0;
140+
let mut b: u32 = 0;
141+
let mut c: Option<u32> = None;
142+
decode_tlv!(&mut s, {(2, a), (3, b)}, {(4, c)});
143+
Ok((a, b, c))
144+
}
145+
#[test]
146+
fn test_tlv() {
147+
// Value for 3 is longer than we expect, but that's ok...
148+
assert_eq!(tlv_reader(&::hex::decode(
149+
concat!("0100", "0208deadbeef1badbeef", "0308deadbeef1badf00d")
150+
).unwrap()[..]).unwrap(),
151+
(0xdeadbeef1badbeef, 0xdeadbeef, None));
152+
// ...even if there's something afterwards
153+
assert_eq!(tlv_reader(&::hex::decode(
154+
concat!("0100", "0208deadbeef1badbeef", "0308deadbeef1badf00d", "0404ffffffff")
155+
).unwrap()[..]).unwrap(),
156+
(0xdeadbeef1badbeef, 0xdeadbeef, Some(0xffffffff)));
157+
// ...but not if that extra length is missing
158+
if let Err(DecodeError::ShortRead) = tlv_reader(&::hex::decode(
159+
concat!("0100", "0208deadbeef1badbeef", "0308deadbeef")
160+
).unwrap()[..]) {
161+
} else { panic!(); }
162+
163+
// If they're out of order that's also bad
164+
if let Err(DecodeError::InvalidValue) = tlv_reader(&::hex::decode(
165+
concat!("0100", "0304deadbeef", "0208deadbeef1badbeef")
166+
).unwrap()[..]) {
167+
} else { panic!(); }
168+
// ...even if its some field we don't understand
169+
if let Err(DecodeError::InvalidValue) = tlv_reader(&::hex::decode(
170+
concat!("0208deadbeef1badbeef", "0100", "0304deadbeef")
171+
).unwrap()[..]) {
172+
} else { panic!(); }
173+
174+
// It's also bad if they included even fields we don't understand
175+
if let Err(DecodeError::UnknownRequiredFeature) = tlv_reader(&::hex::decode(
176+
concat!("0100", "0208deadbeef1badbeef", "0304deadbeef", "0600")
177+
).unwrap()[..]) {
178+
} else { panic!(); }
179+
// ... or if they're missing fields we need
180+
if let Err(DecodeError::InvalidValue) = tlv_reader(&::hex::decode(
181+
concat!("0100", "0208deadbeef1badbeef")
182+
).unwrap()[..]) {
183+
} else { panic!(); }
184+
// ... even if that field is even
185+
if let Err(DecodeError::InvalidValue) = tlv_reader(&::hex::decode(
186+
concat!("0304deadbeef", "0500")
187+
).unwrap()[..]) {
188+
} else { panic!(); }
189+
190+
// But usually things are pretty much what we expect:
191+
assert_eq!(tlv_reader(&::hex::decode(
192+
concat!("0208deadbeef1badbeef", "03041bad1dea")
193+
).unwrap()[..]).unwrap(),
194+
(0xdeadbeef1badbeef, 0x1bad1dea, None));
195+
assert_eq!(tlv_reader(&::hex::decode(
196+
concat!("0208deadbeef1badbeef", "03041bad1dea", "040401020304")
197+
).unwrap()[..]).unwrap(),
198+
(0xdeadbeef1badbeef, 0x1bad1dea, Some(0x01020304)));
199+
}
200+
}

0 commit comments

Comments
 (0)