From e9767f2b60e67cd0bcb686c6434f431d5f452814 Mon Sep 17 00:00:00 2001 From: Trevor McCulloch Date: Mon, 24 Feb 2025 19:44:55 -0800 Subject: [PATCH 1/6] len serializer --- proptest-regressions/raw/test/mod.txt | 8 + src/raw/test/append.rs | 19 +- src/ser/mod.rs | 6 +- src/ser/raw/document_serializer.rs | 22 +- src/ser/raw/len_serializer.rs | 1430 +++++++++++++++++++++++++ src/ser/raw/mod.rs | 40 +- src/ser/raw/value_serializer.rs | 15 +- 7 files changed, 1467 insertions(+), 73 deletions(-) create mode 100644 proptest-regressions/raw/test/mod.txt create mode 100644 src/ser/raw/len_serializer.rs diff --git a/proptest-regressions/raw/test/mod.txt b/proptest-regressions/raw/test/mod.txt new file mode 100644 index 00000000..f8ee2aa2 --- /dev/null +++ b/proptest-regressions/raw/test/mod.txt @@ -0,0 +1,8 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc 7d265dcc185e765bd763a321052fec8f67887f8f16dca9781d0161bbb0f8fdb0 # shrinks to bson = Document({"": String("")}) +cc b931d167d43e92047967875bd28287e133c1464fcdae96025b5345e959f097fb # shrinks to bson = JavaScriptCodeWithScope { code: "", scope: Document({"": Document({"": Binary { subtype: BinaryOld, bytes: [0, 0, 0, 0] }})}) } diff --git a/src/raw/test/append.rs b/src/raw/test/append.rs index 147fa152..7894f164 100644 --- a/src/raw/test/append.rs +++ b/src/raw/test/append.rs @@ -1,22 +1,9 @@ use std::iter::FromIterator; use crate::{ - oid::ObjectId, - raw::RawJavaScriptCodeWithScope, - spec::BinarySubtype, - tests::LOCK, - Binary, - Bson, - DateTime, - DbPointer, - Decimal128, - Document, - JavaScriptCodeWithScope, - RawArrayBuf, - RawBson, - RawDocumentBuf, - Regex, - Timestamp, + oid::ObjectId, raw::RawJavaScriptCodeWithScope, spec::BinarySubtype, tests::LOCK, Binary, Bson, + DateTime, DbPointer, Decimal128, Document, JavaScriptCodeWithScope, RawArrayBuf, RawBson, + RawDocumentBuf, Regex, Timestamp, }; use pretty_assertions::assert_eq; diff --git a/src/ser/mod.rs b/src/ser/mod.rs index 10fb8037..67ca4f16 100644 --- a/src/ser/mod.rs +++ b/src/ser/mod.rs @@ -201,7 +201,11 @@ pub fn to_vec(value: &T) -> Result> where T: Serialize, { - let mut serializer = raw::Serializer::new(); + let mut len_serializer = raw::len_serializer::Serializer::new(); + value.serialize(&mut len_serializer)?; + let lens = len_serializer.into_lens(); + println!("lens={:?}", &lens); + let mut serializer = raw::Serializer::new(lens.into_iter()); #[cfg(feature = "serde_path_to_error")] { serde_path_to_error::serialize(value, &mut serializer).map_err(Error::with_path)?; diff --git a/src/ser/raw/document_serializer.rs b/src/ser/raw/document_serializer.rs index 114a11e9..cc3caad4 100644 --- a/src/ser/raw/document_serializer.rs +++ b/src/ser/raw/document_serializer.rs @@ -1,32 +1,24 @@ use serde::{ser::Impossible, Serialize}; use crate::{ - ser::{write_cstring, write_i32, Error, Result}, - to_bson, - Bson, + ser::{write_cstring, Error, Result}, + to_bson, Bson, }; use super::Serializer; -pub(crate) struct DocumentSerializationResult<'a> { - pub(crate) root_serializer: &'a mut Serializer, -} - /// Serializer used to serialize document or array bodies. pub(crate) struct DocumentSerializer<'a> { root_serializer: &'a mut Serializer, num_keys_serialized: usize, - start: usize, } impl<'a> DocumentSerializer<'a> { pub(crate) fn start(rs: &'a mut Serializer) -> crate::ser::Result { - let start = rs.bytes.len(); - write_i32(&mut rs.bytes, 0)?; + rs.write_next_len()?; Ok(Self { root_serializer: rs, num_keys_serialized: 0, - start, }) } @@ -56,13 +48,9 @@ impl<'a> DocumentSerializer<'a> { Ok(()) } - pub(crate) fn end_doc(self) -> crate::ser::Result> { + pub(crate) fn end_doc(self) -> crate::ser::Result<()> { self.root_serializer.bytes.push(0); - let length = (self.root_serializer.bytes.len() - self.start) as i32; - self.root_serializer.replace_i32(self.start, length); - Ok(DocumentSerializationResult { - root_serializer: self.root_serializer, - }) + Ok(()) } } diff --git a/src/ser/raw/len_serializer.rs b/src/ser/raw/len_serializer.rs new file mode 100644 index 00000000..cd9bfece --- /dev/null +++ b/src/ser/raw/len_serializer.rs @@ -0,0 +1,1430 @@ +use serde::{ + ser::{Error as SerdeError, SerializeMap, SerializeStruct}, + Serialize, +}; + +use crate::{ + raw::{RAW_ARRAY_NEWTYPE, RAW_DOCUMENT_NEWTYPE}, + ser::{Error, Result}, + serde_helpers::HUMAN_READABLE_NEWTYPE, + spec::{BinarySubtype, ElementType}, + uuid::UUID_NEWTYPE_NAME, +}; + +/// Serializer used to convert a type `T` into raw BSON bytes. +pub(crate) struct Serializer { + /// Length of all documents visited by the serializer in the order in which they are serialized. + /// The length of the root document will always appear at index zero. + lens: Vec, + + /// Index of each document and sub-document we are computing the length of. + /// For well-formed serialization requests this will always contain at least one element. + lens_stack: Vec, + + /// Hint provided by the type being serialized. + hint: SerializerHint, + + human_readable: bool, +} + +/// Various bits of information that the serialized type can provide to the serializer to +/// inform the purpose of the next serialization step. +#[derive(Debug, Clone, Copy)] +enum SerializerHint { + None, + + /// The next call to `serialize_bytes` is for the purposes of serializing a UUID. + Uuid, + + /// The next call to `serialize_bytes` is for the purposes of serializing a raw document. + RawDocument, + + /// The next call to `serialize_bytes` is for the purposes of serializing a raw array. + RawArray, +} + +impl SerializerHint { + fn take(&mut self) -> SerializerHint { + std::mem::replace(self, SerializerHint::None) + } +} + +impl Serializer { + pub(crate) fn new() -> Self { + Self { + lens: vec![], + lens_stack: vec![], + hint: SerializerHint::None, + human_readable: false, + } + } + + pub(crate) fn into_lens(self) -> Vec { + assert!(self.lens_stack.is_empty()); + self.lens + } + + #[inline] + fn enter_doc(&mut self) { + let index = self.lens.len(); + self.lens.push(0); + self.lens_stack.push(index); + } + + #[inline] + fn exit_doc(&mut self) { + let index = self + .lens_stack + .pop() + .expect("document enter and exit are paired"); + self.lens[index] += 4 + 1; // i32 doc len + null terminator. + let len = self.lens[index]; + if let Some(parent_index) = self.lens_stack.last() { + // propagate length back up to parent, if present. + self.lens[*parent_index] += len; + } + } + + #[inline] + fn add_bytes(&mut self, bytes: i32) -> Result<()> { + if let Some(index) = self.lens_stack.last() { + self.lens[*index] += bytes; + Ok(()) + } else { + Err(Error::custom(format!( + "attempted to encode a non-document type at the top level", + ))) + } + } + + #[inline] + fn add_element_name_and_type(&mut self, len: usize) -> Result<()> { + // type + length + null terminator. + self.add_bytes(1 + len as i32 + 1) + } + + #[inline] + fn add_cstr_bytes(&mut self, len: usize) -> Result<()> { + self.add_bytes(len as i32 + 1) + } + + #[inline] + fn add_bin_bytes(&mut self, len: usize, subtype: BinarySubtype) -> Result<()> { + let total_len = if subtype == BinarySubtype::BinaryOld { + 4 + 1 + 4 + len as i32 + } else { + 4 + 1 + len as i32 + }; + self.add_bytes(total_len) + } + + #[inline] + fn add_str_bytes(&mut self, len: usize) -> Result<()> { + self.add_bytes(4 + len as i32 + 1) + } +} + +impl<'a> serde::Serializer for &'a mut Serializer { + type Ok = (); + type Error = Error; + + type SerializeSeq = DocumentSerializer<'a>; + type SerializeTuple = DocumentSerializer<'a>; + type SerializeTupleStruct = DocumentSerializer<'a>; + type SerializeTupleVariant = VariantSerializer<'a>; + type SerializeMap = DocumentSerializer<'a>; + type SerializeStruct = StructSerializer<'a>; + type SerializeStructVariant = VariantSerializer<'a>; + + fn is_human_readable(&self) -> bool { + self.human_readable + } + + #[inline] + fn serialize_bool(self, _v: bool) -> Result { + self.add_bytes(1) + } + + #[inline] + fn serialize_i8(self, v: i8) -> Result { + self.serialize_i32(v.into()) + } + + #[inline] + fn serialize_i16(self, v: i16) -> Result { + self.serialize_i32(v.into()) + } + + #[inline] + fn serialize_i32(self, _v: i32) -> Result { + self.add_bytes(4) + } + + #[inline] + fn serialize_i64(self, _v: i64) -> Result { + self.add_bytes(8) + } + + #[inline] + fn serialize_u8(self, v: u8) -> Result { + self.serialize_i32(v.into()) + } + + #[inline] + fn serialize_u16(self, v: u16) -> Result { + self.serialize_i32(v.into()) + } + + #[inline] + fn serialize_u32(self, v: u32) -> Result { + self.serialize_i64(v.into()) + } + + #[inline] + fn serialize_u64(self, _v: u64) -> Result { + self.add_bytes(8) + } + + #[inline] + fn serialize_f32(self, v: f32) -> Result { + self.serialize_f64(v.into()) + } + + #[inline] + fn serialize_f64(self, _v: f64) -> Result { + self.add_bytes(8) + } + + #[inline] + fn serialize_char(self, v: char) -> Result { + let mut s = String::new(); + s.push(v); + self.serialize_str(&s) + } + + #[inline] + fn serialize_str(self, v: &str) -> Result { + self.add_str_bytes(v.len()) + } + + #[inline] + fn serialize_bytes(self, v: &[u8]) -> Result { + match self.hint.take() { + SerializerHint::RawDocument | SerializerHint::RawArray => { + if self.lens_stack.is_empty() { + // The root document is raw in this case. + self.enter_doc(); + let result = self.add_bytes(v.len() as i32); + self.exit_doc(); + result + } else { + // We don't record these as docs as the lengths aren't computed from multiple inputs. + self.add_bytes(v.len() as i32) + } + } + // NB: in this path we would never emit BinaryOld. + _ => self.add_bin_bytes(v.len(), BinarySubtype::Generic), + } + } + + #[inline] + fn serialize_none(self) -> Result { + // this writes an ElementType::Null, which records 0 following bytes for the value. + Ok(()) + } + + #[inline] + fn serialize_some(self, value: &T) -> Result + where + T: serde::Serialize + ?Sized, + { + value.serialize(self) + } + + #[inline] + fn serialize_unit(self) -> Result { + self.serialize_none() + } + + #[inline] + fn serialize_unit_struct(self, _name: &'static str) -> Result { + self.serialize_unit() + } + + #[inline] + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + ) -> Result { + self.serialize_str(variant) + } + + #[inline] + fn serialize_newtype_struct(self, name: &'static str, value: &T) -> Result + where + T: serde::Serialize + ?Sized, + { + match name { + UUID_NEWTYPE_NAME => self.hint = SerializerHint::Uuid, + RAW_DOCUMENT_NEWTYPE => self.hint = SerializerHint::RawDocument, + RAW_ARRAY_NEWTYPE => self.hint = SerializerHint::RawArray, + HUMAN_READABLE_NEWTYPE => { + let old = self.human_readable; + self.human_readable = true; + let result = value.serialize(&mut *self); + self.human_readable = old; + return result; + } + _ => {} + } + value.serialize(self) + } + + #[inline] + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + value: &T, + ) -> Result + where + T: serde::Serialize + ?Sized, + { + let mut d = DocumentSerializer::start(&mut *self)?; + d.serialize_entry(variant, value)?; + d.end_doc()?; + Ok(()) + } + + #[inline] + fn serialize_seq(self, _len: Option) -> Result { + DocumentSerializer::start(&mut *self) + } + + #[inline] + fn serialize_tuple(self, len: usize) -> Result { + self.serialize_seq(Some(len)) + } + + #[inline] + fn serialize_tuple_struct( + self, + _name: &'static str, + len: usize, + ) -> Result { + self.serialize_seq(Some(len)) + } + + #[inline] + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + _len: usize, + ) -> Result { + VariantSerializer::start(&mut *self, variant) + } + + #[inline] + fn serialize_map(self, _len: Option) -> Result { + DocumentSerializer::start(&mut *self) + } + + #[inline] + fn serialize_struct(self, name: &'static str, _len: usize) -> Result { + let value_type = match name { + "$oid" => Some(ValueType::ObjectId), + "$date" => Some(ValueType::DateTime), + "$binary" => Some(ValueType::Binary), + "$timestamp" => Some(ValueType::Timestamp), + "$minKey" => Some(ValueType::MinKey), + "$maxKey" => Some(ValueType::MaxKey), + "$code" => Some(ValueType::JavaScriptCode), + "$codeWithScope" => Some(ValueType::JavaScriptCodeWithScope), + "$symbol" => Some(ValueType::Symbol), + "$undefined" => Some(ValueType::Undefined), + "$regularExpression" => Some(ValueType::RegularExpression), + "$dbPointer" => Some(ValueType::DbPointer), + "$numberDecimal" => Some(ValueType::Decimal128), + _ => None, + }; + + match value_type { + Some(vt) => Ok(StructSerializer::Value(ValueSerializer::new(self, vt))), + None => Ok(StructSerializer::Document(DocumentSerializer::start(self)?)), + } + } + + #[inline] + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + _len: usize, + ) -> Result { + VariantSerializer::start(&mut *self, variant) + } +} + +pub(crate) enum StructSerializer<'a> { + /// Serialize a BSON value currently represented in serde as a struct (e.g. ObjectId) + Value(ValueSerializer<'a>), + + /// Serialize the struct as a document. + Document(DocumentSerializer<'a>), +} + +impl SerializeStruct for StructSerializer<'_> { + type Ok = (); + type Error = Error; + + #[inline] + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<()> + where + T: Serialize + ?Sized, + { + match self { + StructSerializer::Value(ref mut v) => (&mut *v).serialize_field(key, value), + StructSerializer::Document(d) => d.serialize_field(key, value), + } + } + + #[inline] + fn end(self) -> Result { + match self { + StructSerializer::Document(d) => SerializeStruct::end(d), + StructSerializer::Value(mut v) => v.end(), + } + } +} + +/// Serializer used for enum variants, including both tuple (e.g. Foo::Bar(1, 2, 3)) and +/// struct (e.g. Foo::Bar { a: 1 }). +pub(crate) struct VariantSerializer<'a> { + root_serializer: &'a mut Serializer, + + /// How many elements have been serialized in the inner document / array so far. + num_elements_serialized: usize, +} + +impl<'a> VariantSerializer<'a> { + fn start(rs: &'a mut Serializer, variant: &'static str) -> Result { + rs.enter_doc(); // outer doc for variant + rs.add_element_name_and_type(variant.len())?; + + rs.enter_doc(); // inner doc/array containing variant doc/tuple. + Ok(Self { + root_serializer: rs, + num_elements_serialized: 0, + }) + } + + #[inline] + fn serialize_element(&mut self, k: &str, v: &T) -> Result<()> + where + T: Serialize + ?Sized, + { + self.root_serializer.add_element_name_and_type(k.len())?; + v.serialize(&mut *self.root_serializer)?; + self.num_elements_serialized += 1; + Ok(()) + } + + #[inline] + fn end_both(self) -> Result<()> { + self.root_serializer.exit_doc(); // close variant doc/array + self.root_serializer.exit_doc(); // close variant wrapper. + Ok(()) + } +} + +impl serde::ser::SerializeTupleVariant for VariantSerializer<'_> { + type Ok = (); + + type Error = Error; + + #[inline] + fn serialize_field(&mut self, value: &T) -> Result<()> + where + T: Serialize + ?Sized, + { + self.serialize_element(format!("{}", self.num_elements_serialized).as_str(), value) + } + + #[inline] + fn end(self) -> Result { + self.end_both() + } +} + +impl serde::ser::SerializeStructVariant for VariantSerializer<'_> { + type Ok = (); + + type Error = Error; + + #[inline] + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<()> + where + T: Serialize + ?Sized, + { + self.serialize_element(key, value) + } + + #[inline] + fn end(self) -> Result { + self.end_both() + } +} + +use serde::ser::Impossible; + +use crate::{to_bson, Bson}; + +/// Serializer used to serialize document or array bodies. +pub(crate) struct DocumentSerializer<'a> { + root_serializer: &'a mut Serializer, + num_keys_serialized: usize, +} + +impl<'a> DocumentSerializer<'a> { + pub(crate) fn start(rs: &'a mut Serializer) -> crate::ser::Result { + rs.enter_doc(); + Ok(Self { + root_serializer: rs, + num_keys_serialized: 0, + }) + } + + /// Serialize a document key using the provided closure. + fn serialize_doc_key_custom Result<()>>( + &mut self, + f: F, + ) -> Result<()> { + f(self.root_serializer)?; + self.num_keys_serialized += 1; + Ok(()) + } + + /// Serialize a document key to string using [`KeySerializer`]. + fn serialize_doc_key(&mut self, key: &T) -> Result<()> + where + T: serde::Serialize + ?Sized, + { + self.serialize_doc_key_custom(|rs| { + key.serialize(KeySerializer { + root_serializer: rs, + })?; + Ok(()) + })?; + Ok(()) + } + + pub(crate) fn end_doc(self) -> crate::ser::Result<&'a mut Serializer> { + self.root_serializer.exit_doc(); + Ok(self.root_serializer) + } +} + +impl serde::ser::SerializeSeq for DocumentSerializer<'_> { + type Ok = (); + type Error = Error; + + #[inline] + fn serialize_element(&mut self, value: &T) -> Result<()> + where + T: serde::Serialize + ?Sized, + { + let index = self.num_keys_serialized; + self.serialize_doc_key_custom(|rs| rs.add_element_name_and_type(index.to_string().len()))?; + value.serialize(&mut *self.root_serializer) + } + + #[inline] + fn end(self) -> Result { + self.end_doc().map(|_| ()) + } +} + +impl serde::ser::SerializeMap for DocumentSerializer<'_> { + type Ok = (); + + type Error = Error; + + #[inline] + fn serialize_key(&mut self, key: &T) -> Result<()> + where + T: serde::Serialize + ?Sized, + { + self.serialize_doc_key(key) + } + + #[inline] + fn serialize_value(&mut self, value: &T) -> Result<()> + where + T: serde::Serialize + ?Sized, + { + value.serialize(&mut *self.root_serializer) + } + + fn end(self) -> Result { + self.end_doc().map(|_| ()) + } +} + +impl serde::ser::SerializeStruct for DocumentSerializer<'_> { + type Ok = (); + + type Error = Error; + + #[inline] + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<()> + where + T: serde::Serialize + ?Sized, + { + self.serialize_doc_key(key)?; + value.serialize(&mut *self.root_serializer) + } + + #[inline] + fn end(self) -> Result { + self.end_doc().map(|_| ()) + } +} + +impl serde::ser::SerializeTuple for DocumentSerializer<'_> { + type Ok = (); + + type Error = Error; + + #[inline] + fn serialize_element(&mut self, value: &T) -> Result<()> + where + T: serde::Serialize + ?Sized, + { + self.serialize_doc_key(&self.num_keys_serialized.to_string())?; + value.serialize(&mut *self.root_serializer) + } + + #[inline] + fn end(self) -> Result { + self.end_doc().map(|_| ()) + } +} + +impl serde::ser::SerializeTupleStruct for DocumentSerializer<'_> { + type Ok = (); + + type Error = Error; + + #[inline] + fn serialize_field(&mut self, value: &T) -> Result<()> + where + T: serde::Serialize + ?Sized, + { + self.serialize_doc_key(&self.num_keys_serialized.to_string())?; + value.serialize(&mut *self.root_serializer) + } + + #[inline] + fn end(self) -> Result { + self.end_doc().map(|_| ()) + } +} + +/// Serializer used specifically for serializing document keys. +/// Only keys that serialize to strings will be accepted. +struct KeySerializer<'a> { + root_serializer: &'a mut Serializer, +} + +impl KeySerializer<'_> { + fn invalid_key(v: T) -> Error { + Error::InvalidDocumentKey(to_bson(&v).unwrap_or(Bson::Null)) + } +} + +impl serde::Serializer for KeySerializer<'_> { + type Ok = (); + + type Error = Error; + + type SerializeSeq = Impossible<(), Error>; + type SerializeTuple = Impossible<(), Error>; + type SerializeTupleStruct = Impossible<(), Error>; + type SerializeTupleVariant = Impossible<(), Error>; + type SerializeMap = Impossible<(), Error>; + type SerializeStruct = Impossible<(), Error>; + type SerializeStructVariant = Impossible<(), Error>; + + #[inline] + fn serialize_bool(self, v: bool) -> Result { + Err(Self::invalid_key(v)) + } + + #[inline] + fn serialize_i8(self, v: i8) -> Result { + Err(Self::invalid_key(v)) + } + + #[inline] + fn serialize_i16(self, v: i16) -> Result { + Err(Self::invalid_key(v)) + } + + #[inline] + fn serialize_i32(self, v: i32) -> Result { + Err(Self::invalid_key(v)) + } + + #[inline] + fn serialize_i64(self, v: i64) -> Result { + Err(Self::invalid_key(v)) + } + + #[inline] + fn serialize_u8(self, v: u8) -> Result { + Err(Self::invalid_key(v)) + } + + #[inline] + fn serialize_u16(self, v: u16) -> Result { + Err(Self::invalid_key(v)) + } + + #[inline] + fn serialize_u32(self, v: u32) -> Result { + Err(Self::invalid_key(v)) + } + + #[inline] + fn serialize_u64(self, v: u64) -> Result { + Err(Self::invalid_key(v)) + } + + #[inline] + fn serialize_f32(self, v: f32) -> Result { + Err(Self::invalid_key(v)) + } + + #[inline] + fn serialize_f64(self, v: f64) -> Result { + Err(Self::invalid_key(v)) + } + + #[inline] + fn serialize_char(self, v: char) -> Result { + Err(Self::invalid_key(v)) + } + + #[inline] + fn serialize_str(self, v: &str) -> Result { + self.root_serializer.add_element_name_and_type(v.len()) + } + + #[inline] + fn serialize_bytes(self, v: &[u8]) -> Result { + Err(Self::invalid_key(v)) + } + + #[inline] + fn serialize_none(self) -> Result { + Err(Self::invalid_key(Bson::Null)) + } + + #[inline] + fn serialize_some(self, value: &T) -> Result + where + T: Serialize + ?Sized, + { + value.serialize(self) + } + + #[inline] + fn serialize_unit(self) -> Result { + Err(Self::invalid_key(Bson::Null)) + } + + #[inline] + fn serialize_unit_struct(self, _name: &'static str) -> Result { + Err(Self::invalid_key(Bson::Null)) + } + + #[inline] + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + ) -> Result { + self.serialize_str(variant) + } + + #[inline] + fn serialize_newtype_struct(self, _name: &'static str, value: &T) -> Result + where + T: Serialize + ?Sized, + { + value.serialize(self) + } + + #[inline] + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + value: &T, + ) -> Result + where + T: Serialize + ?Sized, + { + Err(Self::invalid_key(value)) + } + + #[inline] + fn serialize_seq(self, _len: Option) -> Result { + Err(Self::invalid_key(Bson::Array(vec![]))) + } + + #[inline] + fn serialize_tuple(self, _len: usize) -> Result { + Err(Self::invalid_key(Bson::Array(vec![]))) + } + + #[inline] + fn serialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result { + Err(Self::invalid_key(Bson::Document(doc! {}))) + } + + #[inline] + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + Err(Self::invalid_key(Bson::Array(vec![]))) + } + + #[inline] + fn serialize_map(self, _len: Option) -> Result { + Err(Self::invalid_key(Bson::Document(doc! {}))) + } + + #[inline] + fn serialize_struct(self, _name: &'static str, _len: usize) -> Result { + Err(Self::invalid_key(Bson::Document(doc! {}))) + } + + #[inline] + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + Err(Self::invalid_key(Bson::Document(doc! {}))) + } +} + +use crate::{base64, RawDocument, RawJavaScriptCodeWithScopeRef}; + +/// A serializer used specifically for serializing the serde-data-model form of a BSON type (e.g. +/// [`Binary`]) to raw bytes. +pub(crate) struct ValueSerializer<'a> { + root_serializer: &'a mut Serializer, + state: SerializationStep, +} + +/// State machine used to track which step in the serialization of a given type the serializer is +/// currently on. +#[derive(Debug)] +enum SerializationStep { + Oid, + + DateTime, + DateTimeNumberLong, + + Binary, + /// This step can either transition to the raw or base64 steps depending + /// on whether a string or bytes are serialized. + BinaryBytes, + BinarySubType { + base64: String, + }, + RawBinarySubType { + bytes: Vec, + }, + + Symbol, + + RegEx, + RegExPattern, + RegExOptions, + + Timestamp, + TimestampTime, + TimestampIncrement, + + DbPointer, + DbPointerRef, + DbPointerId, + + Code, + + CodeWithScopeCode, + CodeWithScopeScope { + code: String, + raw: bool, + }, + + MinKey, + + MaxKey, + + Undefined, + + Decimal128, + Decimal128Value, + + Done, +} + +/// Enum of BSON "value" types that this serializer can serialize. +#[derive(Debug, Clone, Copy)] +pub(super) enum ValueType { + DateTime, + Binary, + ObjectId, + Symbol, + RegularExpression, + Timestamp, + DbPointer, + JavaScriptCode, + JavaScriptCodeWithScope, + MinKey, + MaxKey, + Decimal128, + Undefined, +} + +impl From for ElementType { + fn from(vt: ValueType) -> Self { + match vt { + ValueType::Binary => ElementType::Binary, + ValueType::DateTime => ElementType::DateTime, + ValueType::DbPointer => ElementType::DbPointer, + ValueType::Decimal128 => ElementType::Decimal128, + ValueType::Symbol => ElementType::Symbol, + ValueType::RegularExpression => ElementType::RegularExpression, + ValueType::Timestamp => ElementType::Timestamp, + ValueType::JavaScriptCode => ElementType::JavaScriptCode, + ValueType::JavaScriptCodeWithScope => ElementType::JavaScriptCodeWithScope, + ValueType::MaxKey => ElementType::MaxKey, + ValueType::MinKey => ElementType::MinKey, + ValueType::Undefined => ElementType::Undefined, + ValueType::ObjectId => ElementType::ObjectId, + } + } +} + +impl<'a> ValueSerializer<'a> { + pub(super) fn new(rs: &'a mut Serializer, value_type: ValueType) -> Self { + let state = match value_type { + ValueType::DateTime => SerializationStep::DateTime, + ValueType::Binary => SerializationStep::Binary, + ValueType::ObjectId => SerializationStep::Oid, + ValueType::Symbol => SerializationStep::Symbol, + ValueType::RegularExpression => SerializationStep::RegEx, + ValueType::Timestamp => SerializationStep::Timestamp, + ValueType::DbPointer => SerializationStep::DbPointer, + ValueType::JavaScriptCode => SerializationStep::Code, + ValueType::JavaScriptCodeWithScope => SerializationStep::CodeWithScopeCode, + ValueType::MinKey => SerializationStep::MinKey, + ValueType::MaxKey => SerializationStep::MaxKey, + ValueType::Decimal128 => SerializationStep::Decimal128, + ValueType::Undefined => SerializationStep::Undefined, + }; + Self { + root_serializer: rs, + state, + } + } + + fn invalid_step(&self, primitive_type: &'static str) -> Error { + Error::custom(format!( + "cannot serialize {} at step {:?}", + primitive_type, self.state + )) + } +} + +impl<'b> serde::Serializer for &'b mut ValueSerializer<'_> { + type Ok = (); + type Error = Error; + + type SerializeSeq = Impossible<(), Error>; + type SerializeTuple = Impossible<(), Error>; + type SerializeTupleStruct = Impossible<(), Error>; + type SerializeTupleVariant = Impossible<(), Error>; + type SerializeMap = CodeWithScopeSerializer<'b>; + type SerializeStruct = Self; + type SerializeStructVariant = Impossible<(), Error>; + + #[inline] + fn serialize_bool(self, _v: bool) -> Result { + Err(self.invalid_step("bool")) + } + + #[inline] + fn serialize_i8(self, _v: i8) -> Result { + Err(self.invalid_step("i8")) + } + + #[inline] + fn serialize_i16(self, _v: i16) -> Result { + Err(self.invalid_step("i16")) + } + + #[inline] + fn serialize_i32(self, _v: i32) -> Result { + Err(self.invalid_step("i32")) + } + + #[inline] + fn serialize_i64(self, _v: i64) -> Result { + match self.state { + SerializationStep::TimestampTime => { + self.state = SerializationStep::TimestampIncrement; + Ok(()) + } + SerializationStep::TimestampIncrement => self.root_serializer.add_bytes(8), + _ => Err(self.invalid_step("i64")), + } + } + + #[inline] + fn serialize_u8(self, v: u8) -> Result { + match self.state { + SerializationStep::RawBinarySubType { ref bytes } => { + self.root_serializer.add_bin_bytes(bytes.len(), v.into())?; + self.state = SerializationStep::Done; + Ok(()) + } + _ => Err(self.invalid_step("u8")), + } + } + + #[inline] + fn serialize_u16(self, _v: u16) -> Result { + Err(self.invalid_step("u16")) + } + + #[inline] + fn serialize_u32(self, _v: u32) -> Result { + Err(self.invalid_step("u32")) + } + + #[inline] + fn serialize_u64(self, _v: u64) -> Result { + Err(self.invalid_step("u64")) + } + + #[inline] + fn serialize_f32(self, _v: f32) -> Result { + Err(self.invalid_step("f32")) + } + + #[inline] + fn serialize_f64(self, _v: f64) -> Result { + Err(self.invalid_step("f64")) + } + + #[inline] + fn serialize_char(self, _v: char) -> Result { + Err(self.invalid_step("char")) + } + + fn serialize_str(self, v: &str) -> Result { + match &self.state { + SerializationStep::DateTimeNumberLong => { + self.root_serializer.add_bytes(8)?; + } + SerializationStep::Oid => { + self.root_serializer.add_bytes(12)?; + } + SerializationStep::BinaryBytes => { + self.state = SerializationStep::BinarySubType { + base64: v.to_string(), + }; + } + SerializationStep::BinarySubType { base64 } => { + let subtype_byte = hex::decode(v).map_err(Error::custom)?; + let subtype: BinarySubtype = subtype_byte[0].into(); + let bytes = base64::decode(base64.as_str()).map_err(Error::custom)?; + self.root_serializer.add_bin_bytes(bytes.len(), subtype)?; + } + SerializationStep::Symbol | SerializationStep::DbPointerRef => { + self.root_serializer.add_str_bytes(v.len())?; + } + SerializationStep::RegExPattern => { + self.root_serializer.add_cstr_bytes(v.len())?; + } + SerializationStep::RegExOptions => { + self.root_serializer.add_cstr_bytes(v.len())?; + } + SerializationStep::Code => { + self.root_serializer.add_str_bytes(v.len())?; + } + SerializationStep::CodeWithScopeCode => { + self.state = SerializationStep::CodeWithScopeScope { + code: v.to_string(), + raw: false, + }; + } + s => { + return Err(Error::custom(format!( + "can't serialize string for step {:?}", + s + ))) + } + } + Ok(()) + } + + #[inline] + fn serialize_bytes(self, v: &[u8]) -> Result { + match self.state { + SerializationStep::Decimal128Value => self.root_serializer.add_bytes(16), + SerializationStep::BinaryBytes => { + self.state = SerializationStep::RawBinarySubType { bytes: v.to_vec() }; + Ok(()) + } + SerializationStep::CodeWithScopeScope { ref code, raw } if raw => { + let raw = RawJavaScriptCodeWithScopeRef { + code, + scope: RawDocument::from_bytes(v).map_err(Error::custom)?, + }; + self.root_serializer.add_bytes(4)?; + self.root_serializer.add_str_bytes(code.len())?; + self.root_serializer.add_bytes(raw.len())?; + self.state = SerializationStep::Done; + Ok(()) + } + _ => Err(self.invalid_step("&[u8]")), + } + } + + #[inline] + fn serialize_none(self) -> Result { + Err(self.invalid_step("none")) + } + + #[inline] + fn serialize_some(self, _value: &T) -> Result + where + T: Serialize + ?Sized, + { + Err(self.invalid_step("some")) + } + + #[inline] + fn serialize_unit(self) -> Result { + Err(self.invalid_step("unit")) + } + + #[inline] + fn serialize_unit_struct(self, _name: &'static str) -> Result { + Err(self.invalid_step("unit_struct")) + } + + #[inline] + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + ) -> Result { + Err(self.invalid_step("unit_variant")) + } + + #[inline] + fn serialize_newtype_struct(self, name: &'static str, value: &T) -> Result + where + T: Serialize + ?Sized, + { + match (&mut self.state, name) { + ( + SerializationStep::CodeWithScopeScope { + code: _, + ref mut raw, + }, + RAW_DOCUMENT_NEWTYPE, + ) => { + *raw = true; + value.serialize(self) + } + _ => Err(self.invalid_step("newtype_struct")), + } + } + + #[inline] + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _value: &T, + ) -> Result + where + T: Serialize + ?Sized, + { + Err(self.invalid_step("newtype_variant")) + } + + #[inline] + fn serialize_seq(self, _len: Option) -> Result { + Err(self.invalid_step("seq")) + } + + #[inline] + fn serialize_tuple(self, _len: usize) -> Result { + Err(self.invalid_step("newtype_tuple")) + } + + #[inline] + fn serialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result { + Err(self.invalid_step("tuple_struct")) + } + + #[inline] + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + Err(self.invalid_step("tuple_variant")) + } + + #[inline] + fn serialize_map(self, _len: Option) -> Result { + match self.state { + SerializationStep::CodeWithScopeScope { ref code, raw } if !raw => { + CodeWithScopeSerializer::start(code.as_str(), self.root_serializer) + } + _ => Err(self.invalid_step("map")), + } + } + + #[inline] + fn serialize_struct(self, _name: &'static str, _len: usize) -> Result { + Ok(self) + } + + #[inline] + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + Err(self.invalid_step("struct_variant")) + } + + fn is_human_readable(&self) -> bool { + false + } +} + +impl SerializeStruct for &mut ValueSerializer<'_> { + type Ok = (); + type Error = Error; + + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<()> + where + T: Serialize + ?Sized, + { + match (&self.state, key) { + (SerializationStep::DateTime, "$date") => { + self.state = SerializationStep::DateTimeNumberLong; + value.serialize(&mut **self)?; + } + (SerializationStep::DateTimeNumberLong, "$numberLong") => { + value.serialize(&mut **self)?; + self.state = SerializationStep::Done; + } + (SerializationStep::Oid, "$oid") => { + value.serialize(&mut **self)?; + self.state = SerializationStep::Done; + } + (SerializationStep::Binary, "$binary") => { + self.state = SerializationStep::BinaryBytes; + value.serialize(&mut **self)?; + } + (SerializationStep::BinaryBytes, key) if key == "bytes" || key == "base64" => { + // state is updated in serialize + value.serialize(&mut **self)?; + } + (SerializationStep::RawBinarySubType { .. }, "subType") => { + value.serialize(&mut **self)?; + self.state = SerializationStep::Done; + } + (SerializationStep::BinarySubType { .. }, "subType") => { + value.serialize(&mut **self)?; + self.state = SerializationStep::Done; + } + (SerializationStep::Symbol, "$symbol") => { + value.serialize(&mut **self)?; + self.state = SerializationStep::Done; + } + (SerializationStep::RegEx, "$regularExpression") => { + self.state = SerializationStep::RegExPattern; + value.serialize(&mut **self)?; + } + (SerializationStep::RegExPattern, "pattern") => { + value.serialize(&mut **self)?; + self.state = SerializationStep::RegExOptions; + } + (SerializationStep::RegExOptions, "options") => { + value.serialize(&mut **self)?; + self.state = SerializationStep::Done; + } + (SerializationStep::Timestamp, "$timestamp") => { + self.state = SerializationStep::TimestampTime; + value.serialize(&mut **self)?; + } + (SerializationStep::TimestampTime, "t") => { + // state is updated in serialize + value.serialize(&mut **self)?; + } + (SerializationStep::TimestampIncrement { .. }, "i") => { + value.serialize(&mut **self)?; + self.state = SerializationStep::Done; + } + (SerializationStep::DbPointer, "$dbPointer") => { + self.state = SerializationStep::DbPointerRef; + value.serialize(&mut **self)?; + } + (SerializationStep::DbPointerRef, "$ref") => { + value.serialize(&mut **self)?; + self.state = SerializationStep::DbPointerId; + } + (SerializationStep::DbPointerId, "$id") => { + self.state = SerializationStep::Oid; + value.serialize(&mut **self)?; + } + (SerializationStep::Code, "$code") => { + value.serialize(&mut **self)?; + self.state = SerializationStep::Done; + } + (SerializationStep::CodeWithScopeCode, "$code") => { + // state is updated in serialize + value.serialize(&mut **self)?; + } + (SerializationStep::CodeWithScopeScope { .. }, "$scope") => { + value.serialize(&mut **self)?; + self.state = SerializationStep::Done; + } + (SerializationStep::MinKey, "$minKey") => { + self.state = SerializationStep::Done; + } + (SerializationStep::MaxKey, "$maxKey") => { + self.state = SerializationStep::Done; + } + (SerializationStep::Undefined, "$undefined") => { + self.state = SerializationStep::Done; + } + (SerializationStep::Decimal128, "$numberDecimal") + | (SerializationStep::Decimal128, "$numberDecimalBytes") => { + self.state = SerializationStep::Decimal128Value; + value.serialize(&mut **self)?; + } + (SerializationStep::Decimal128Value, "$numberDecimal") => { + value.serialize(&mut **self)?; + self.state = SerializationStep::Done; + } + (SerializationStep::Done, k) => { + return Err(Error::custom(format!( + "expected to end serialization of type, got extra key \"{}\"", + k + ))); + } + (state, k) => { + return Err(Error::custom(format!( + "mismatched serialization step and next key: {:?} + \"{}\"", + state, k + ))); + } + } + + Ok(()) + } + + #[inline] + fn end(self) -> Result { + Ok(()) + } +} + +pub(crate) struct CodeWithScopeSerializer<'a> { + doc: DocumentSerializer<'a>, +} + +impl<'a> CodeWithScopeSerializer<'a> { + #[inline] + fn start(code: &str, rs: &'a mut Serializer) -> Result { + rs.enter_doc(); + rs.add_str_bytes(code.len())?; + + let doc = DocumentSerializer::start(rs)?; + Ok(Self { doc }) + } +} + +impl SerializeMap for CodeWithScopeSerializer<'_> { + type Ok = (); + type Error = Error; + + #[inline] + fn serialize_key(&mut self, key: &T) -> Result<()> + where + T: Serialize + ?Sized, + { + self.doc.serialize_key(key) + } + + #[inline] + fn serialize_value(&mut self, value: &T) -> Result<()> + where + T: Serialize + ?Sized, + { + self.doc.serialize_value(value) + } + + #[inline] + fn end(self) -> Result { + let rs = self.doc.end_doc()?; + // code with scope does not have an additional null terminator. + rs.add_bytes(-1)?; + rs.exit_doc(); + Ok(()) + } +} diff --git a/src/ser/raw/mod.rs b/src/ser/raw/mod.rs index 69b7320e..9be0bcbd 100644 --- a/src/ser/raw/mod.rs +++ b/src/ser/raw/mod.rs @@ -1,4 +1,5 @@ mod document_serializer; +pub(super) mod len_serializer; mod value_serializer; use std::io::Write; @@ -22,6 +23,8 @@ use document_serializer::DocumentSerializer; /// Serializer used to convert a type `T` into raw BSON bytes. pub(crate) struct Serializer { + lens: std::vec::IntoIter, + bytes: Vec, /// The index into `bytes` where the current element type will need to be stored. @@ -58,8 +61,9 @@ impl SerializerHint { } impl Serializer { - pub(crate) fn new() -> Self { + pub(crate) fn new(lens: std::vec::IntoIter) -> Self { Self { + lens, bytes: Vec::new(), type_index: 0, hint: SerializerHint::None, @@ -72,6 +76,11 @@ impl Serializer { self.bytes } + #[inline] + fn write_next_len(&mut self) -> Result<()> { + write_i32(&mut self.bytes, self.lens.next().expect("pre-recorded len")) + } + /// Reserve a spot for the element type to be set retroactively via `update_element_type`. #[inline] fn reserve_element_type(&mut self) { @@ -97,13 +106,6 @@ impl Serializer { self.bytes[self.type_index] = t as u8; Ok(()) } - - /// Replace an i32 value at the given index with the given value. - #[inline] - fn replace_i32(&mut self, at: usize, with: i32) { - let portion = &mut self.bytes[at..at + 4]; - portion.copy_from_slice(&with.to_le_bytes()); - } } impl<'a> serde::Serializer for &'a mut Serializer { @@ -425,13 +427,6 @@ enum VariantInnerType { pub(crate) struct VariantSerializer<'a> { root_serializer: &'a mut Serializer, - /// Variants are serialized as documents of the form `{ : }`, - /// and `doc_start` indicates the index at which the outer document begins. - doc_start: usize, - - /// `inner_start` indicates the index at which the inner document or array begins. - inner_start: usize, - /// How many elements have been serialized in the inner document / array so far. num_elements_serialized: usize, } @@ -442,9 +437,7 @@ impl<'a> VariantSerializer<'a> { variant: &'static str, inner_type: VariantInnerType, ) -> Result { - let doc_start = rs.bytes.len(); - // write placeholder length for document, will be updated at end - write_i32(&mut rs.bytes, 0)?; + rs.write_next_len()?; let inner = match inner_type { VariantInnerType::Struct => ElementType::EmbeddedDocument, @@ -452,15 +445,12 @@ impl<'a> VariantSerializer<'a> { }; rs.bytes.push(inner as u8); write_cstring(&mut rs.bytes, variant)?; - let inner_start = rs.bytes.len(); // write placeholder length for inner, will be updated at end - write_i32(&mut rs.bytes, 0)?; + rs.write_next_len()?; Ok(Self { root_serializer: rs, num_elements_serialized: 0, - doc_start, - inner_start, }) } @@ -481,14 +471,8 @@ impl<'a> VariantSerializer<'a> { fn end_both(self) -> Result<()> { // null byte for the inner self.root_serializer.bytes.push(0); - let arr_length = (self.root_serializer.bytes.len() - self.inner_start) as i32; - self.root_serializer - .replace_i32(self.inner_start, arr_length); - // null byte for document self.root_serializer.bytes.push(0); - let doc_length = (self.root_serializer.bytes.len() - self.doc_start) as i32; - self.root_serializer.replace_i32(self.doc_start, doc_length); Ok(()) } } diff --git a/src/ser/raw/value_serializer.rs b/src/ser/raw/value_serializer.rs index 8c0b2215..ddb74536 100644 --- a/src/ser/raw/value_serializer.rs +++ b/src/ser/raw/value_serializer.rs @@ -11,8 +11,7 @@ use crate::{ raw::RAW_DOCUMENT_NEWTYPE, ser::{write_binary, write_cstring, write_i32, write_i64, write_string, Error, Result}, spec::{BinarySubtype, ElementType}, - RawDocument, - RawJavaScriptCodeWithScopeRef, + RawDocument, RawJavaScriptCodeWithScopeRef, }; use super::{document_serializer::DocumentSerializer, Serializer}; @@ -582,19 +581,17 @@ impl SerializeStruct for &mut ValueSerializer<'_> { } pub(crate) struct CodeWithScopeSerializer<'a> { - start: usize, doc: DocumentSerializer<'a>, } impl<'a> CodeWithScopeSerializer<'a> { #[inline] fn start(code: &str, rs: &'a mut Serializer) -> Result { - let start = rs.bytes.len(); - write_i32(&mut rs.bytes, 0)?; // placeholder length + rs.write_next_len()?; write_string(&mut rs.bytes, code); let doc = DocumentSerializer::start(rs)?; - Ok(Self { start, doc }) + Ok(Self { doc }) } } @@ -620,10 +617,6 @@ impl SerializeMap for CodeWithScopeSerializer<'_> { #[inline] fn end(self) -> Result { - let result = self.doc.end_doc()?; - - let total_len = (result.root_serializer.bytes.len() - self.start) as i32; - result.root_serializer.replace_i32(self.start, total_len); - Ok(()) + self.doc.end_doc().map(|_| ()) } } From dbc2f2fa36a720955186fa96f76fb41f9097a421 Mon Sep 17 00:00:00 2001 From: Trevor McCulloch Date: Tue, 25 Feb 2025 10:30:21 -0800 Subject: [PATCH 2/6] set keys in addition to dealing with writes --- src/ser/raw/document_serializer.rs | 13 ++++--- src/ser/raw/mod.rs | 55 +++++++++++++++++++++++++++--- 2 files changed, 60 insertions(+), 8 deletions(-) diff --git a/src/ser/raw/document_serializer.rs b/src/ser/raw/document_serializer.rs index cc3caad4..f60d2751 100644 --- a/src/ser/raw/document_serializer.rs +++ b/src/ser/raw/document_serializer.rs @@ -5,7 +5,7 @@ use crate::{ to_bson, Bson, }; -use super::Serializer; +use super::{Key, Serializer}; /// Serializer used to serialize document or array bodies. pub(crate) struct DocumentSerializer<'a> { @@ -64,6 +64,7 @@ impl serde::ser::SerializeSeq for DocumentSerializer<'_> { T: serde::Serialize + ?Sized, { let index = self.num_keys_serialized; + self.root_serializer.set_next_key(Key::Index(index)); // XXX must increment num_keys_serialized. self.serialize_doc_key_custom(|rs| { use std::io::Write; write!(&mut rs.bytes, "{}", index)?; @@ -89,7 +90,7 @@ impl serde::ser::SerializeMap for DocumentSerializer<'_> { where T: serde::Serialize + ?Sized, { - self.serialize_doc_key(key) + self.serialize_doc_key(key) // XXX this may result in a new copy. } #[inline] @@ -115,7 +116,8 @@ impl serde::ser::SerializeStruct for DocumentSerializer<'_> { where T: serde::Serialize + ?Sized, { - self.serialize_doc_key(key)?; + self.root_serializer.set_next_key(Key::Static(key)); + self.serialize_doc_key(key)?; // XXX remove, this does not need to go through KeySerializer value.serialize(&mut *self.root_serializer) } @@ -135,7 +137,9 @@ impl serde::ser::SerializeTuple for DocumentSerializer<'_> { where T: serde::Serialize + ?Sized, { - self.serialize_doc_key(&self.num_keys_serialized.to_string())?; + self.root_serializer + .set_next_key(Key::Index(self.num_keys_serialized)); + self.serialize_doc_key(&self.num_keys_serialized.to_string())?; // XXX increment num_keys_serialized instead value.serialize(&mut *self.root_serializer) } @@ -252,6 +256,7 @@ impl serde::Serializer for KeySerializer<'_> { #[inline] fn serialize_str(self, v: &str) -> Result { + self.root_serializer.set_next_key(Key::Owned(v.to_owned())); write_cstring(&mut self.root_serializer.bytes, v) } diff --git a/src/ser/raw/mod.rs b/src/ser/raw/mod.rs index 9be0bcbd..a15b1e78 100644 --- a/src/ser/raw/mod.rs +++ b/src/ser/raw/mod.rs @@ -27,6 +27,8 @@ pub(crate) struct Serializer { bytes: Vec, + next_key: Option, + /// The index into `bytes` where the current element type will need to be stored. /// This needs to be set retroactively because in BSON, the element type comes before the key, /// but in serde, the serializer learns of the type after serializing the key. @@ -54,6 +56,12 @@ enum SerializerHint { RawArray, } +enum Key { + Static(&'static str), + Owned(String), + Index(usize), +} + impl SerializerHint { fn take(&mut self) -> SerializerHint { std::mem::replace(self, SerializerHint::None) @@ -65,6 +73,7 @@ impl Serializer { Self { lens, bytes: Vec::new(), + next_key: None, type_index: 0, hint: SerializerHint::None, human_readable: false, @@ -81,6 +90,43 @@ impl Serializer { write_i32(&mut self.bytes, self.lens.next().expect("pre-recorded len")) } + #[inline] + fn set_next_key(&mut self, key: Key) { + self.next_key = Some(key); + } + + #[inline] + fn write_key(&mut self, t: ElementType) -> Result<()> { + if let Some(key) = self.next_key.take() { + self.bytes.push(t as u8); + self.write_key_string(&key) + } else { + if self.bytes.is_empty() && t == ElementType::EmbeddedDocument { + // don't need to set the element type for the top level document + Ok(()) + } else { + Err(Error::custom(format!( + "attempted to encode a non-document type at the top level: {:?}", + t + ))) + } + } + } + + #[inline] + fn write_key_string(&mut self, key: &Key) -> Result<()> { + match key { + Key::Static(k) => write_cstring(&mut self.bytes, k), + Key::Owned(k) => write_cstring(&mut self.bytes, &k), + Key::Index(i) => { + use std::io::Write; + write!(&mut self.bytes, "{}", i)?; + self.bytes.push(0); + Ok(()) + } + } + } + /// Reserve a spot for the element type to be set retroactively via `update_element_type`. #[inline] fn reserve_element_type(&mut self) { @@ -455,12 +501,13 @@ impl<'a> VariantSerializer<'a> { } #[inline] - fn serialize_element(&mut self, k: &str, v: &T) -> Result<()> + fn serialize_element(&mut self, k: Key, v: &T) -> Result<()> where T: Serialize + ?Sized, { self.root_serializer.reserve_element_type(); - write_cstring(&mut self.root_serializer.bytes, k)?; + self.root_serializer.write_key_string(&k)?; + self.root_serializer.set_next_key(k); v.serialize(&mut *self.root_serializer)?; self.num_elements_serialized += 1; @@ -487,7 +534,7 @@ impl serde::ser::SerializeTupleVariant for VariantSerializer<'_> { where T: Serialize + ?Sized, { - self.serialize_element(format!("{}", self.num_elements_serialized).as_str(), value) + self.serialize_element(Key::Index(self.num_elements_serialized), value) } #[inline] @@ -506,7 +553,7 @@ impl serde::ser::SerializeStructVariant for VariantSerializer<'_> { where T: Serialize + ?Sized, { - self.serialize_element(key, value) + self.serialize_element(Key::Static(key), value) } #[inline] From a8902408016b07d551de992b03a98d5c9c3e1abd Mon Sep 17 00:00:00 2001 From: Trevor McCulloch Date: Tue, 25 Feb 2025 11:06:42 -0800 Subject: [PATCH 3/6] key serialization moved to new method --- src/ser/mod.rs | 1 - src/ser/raw/document_serializer.rs | 38 +++--------- src/ser/raw/mod.rs | 93 +++++++++--------------------- 3 files changed, 37 insertions(+), 95 deletions(-) diff --git a/src/ser/mod.rs b/src/ser/mod.rs index 67ca4f16..61ab9cc9 100644 --- a/src/ser/mod.rs +++ b/src/ser/mod.rs @@ -204,7 +204,6 @@ where let mut len_serializer = raw::len_serializer::Serializer::new(); value.serialize(&mut len_serializer)?; let lens = len_serializer.into_lens(); - println!("lens={:?}", &lens); let mut serializer = raw::Serializer::new(lens.into_iter()); #[cfg(feature = "serde_path_to_error")] { diff --git a/src/ser/raw/document_serializer.rs b/src/ser/raw/document_serializer.rs index f60d2751..f3e9719b 100644 --- a/src/ser/raw/document_serializer.rs +++ b/src/ser/raw/document_serializer.rs @@ -1,7 +1,7 @@ use serde::{ser::Impossible, Serialize}; use crate::{ - ser::{write_cstring, Error, Result}, + ser::{Error, Result}, to_bson, Bson, }; @@ -22,29 +22,15 @@ impl<'a> DocumentSerializer<'a> { }) } - /// Serialize a document key using the provided closure. - fn serialize_doc_key_custom Result<()>>( - &mut self, - f: F, - ) -> Result<()> { - // push a dummy element type for now, will update this once we serialize the value - self.root_serializer.reserve_element_type(); - f(self.root_serializer)?; - self.num_keys_serialized += 1; - Ok(()) - } - /// Serialize a document key to string using [`KeySerializer`]. fn serialize_doc_key(&mut self, key: &T) -> Result<()> where T: serde::Serialize + ?Sized, { - self.serialize_doc_key_custom(|rs| { - key.serialize(KeySerializer { - root_serializer: rs, - })?; - Ok(()) + key.serialize(KeySerializer { + root_serializer: &mut self.root_serializer, })?; + self.num_keys_serialized += 1; Ok(()) } @@ -63,14 +49,9 @@ impl serde::ser::SerializeSeq for DocumentSerializer<'_> { where T: serde::Serialize + ?Sized, { - let index = self.num_keys_serialized; - self.root_serializer.set_next_key(Key::Index(index)); // XXX must increment num_keys_serialized. - self.serialize_doc_key_custom(|rs| { - use std::io::Write; - write!(&mut rs.bytes, "{}", index)?; - rs.bytes.push(0); - Ok(()) - })?; + self.root_serializer + .set_next_key(Key::Index(self.num_keys_serialized)); + self.num_keys_serialized += 1; value.serialize(&mut *self.root_serializer) } @@ -117,7 +98,6 @@ impl serde::ser::SerializeStruct for DocumentSerializer<'_> { T: serde::Serialize + ?Sized, { self.root_serializer.set_next_key(Key::Static(key)); - self.serialize_doc_key(key)?; // XXX remove, this does not need to go through KeySerializer value.serialize(&mut *self.root_serializer) } @@ -139,7 +119,7 @@ impl serde::ser::SerializeTuple for DocumentSerializer<'_> { { self.root_serializer .set_next_key(Key::Index(self.num_keys_serialized)); - self.serialize_doc_key(&self.num_keys_serialized.to_string())?; // XXX increment num_keys_serialized instead + self.num_keys_serialized += 1; value.serialize(&mut *self.root_serializer) } @@ -257,7 +237,7 @@ impl serde::Serializer for KeySerializer<'_> { #[inline] fn serialize_str(self, v: &str) -> Result { self.root_serializer.set_next_key(Key::Owned(v.to_owned())); - write_cstring(&mut self.root_serializer.bytes, v) + Ok(()) } #[inline] diff --git a/src/ser/raw/mod.rs b/src/ser/raw/mod.rs index a15b1e78..af35ecdd 100644 --- a/src/ser/raw/mod.rs +++ b/src/ser/raw/mod.rs @@ -29,11 +29,6 @@ pub(crate) struct Serializer { next_key: Option, - /// The index into `bytes` where the current element type will need to be stored. - /// This needs to be set retroactively because in BSON, the element type comes before the key, - /// but in serde, the serializer learns of the type after serializing the key. - type_index: usize, - /// Hint provided by the type being serialized. hint: SerializerHint, @@ -56,6 +51,7 @@ enum SerializerHint { RawArray, } +#[derive(Debug, Clone)] enum Key { Static(&'static str), Owned(String), @@ -74,7 +70,6 @@ impl Serializer { lens, bytes: Vec::new(), next_key: None, - type_index: 0, hint: SerializerHint::None, human_readable: false, } @@ -98,11 +93,21 @@ impl Serializer { #[inline] fn write_key(&mut self, t: ElementType) -> Result<()> { if let Some(key) = self.next_key.take() { + println!("write_key({:?}) key={:?}", t, key); self.bytes.push(t as u8); - self.write_key_string(&key) + match key { + Key::Static(k) => write_cstring(&mut self.bytes, k), + Key::Owned(k) => write_cstring(&mut self.bytes, &k), + Key::Index(i) => { + use std::io::Write; + write!(&mut self.bytes, "{}", i)?; + self.bytes.push(0); + Ok(()) + } + } } else { if self.bytes.is_empty() && t == ElementType::EmbeddedDocument { - // don't need to set the element type for the top level document + // don't need to write element type and key for top-level document. Ok(()) } else { Err(Error::custom(format!( @@ -112,46 +117,6 @@ impl Serializer { } } } - - #[inline] - fn write_key_string(&mut self, key: &Key) -> Result<()> { - match key { - Key::Static(k) => write_cstring(&mut self.bytes, k), - Key::Owned(k) => write_cstring(&mut self.bytes, &k), - Key::Index(i) => { - use std::io::Write; - write!(&mut self.bytes, "{}", i)?; - self.bytes.push(0); - Ok(()) - } - } - } - - /// Reserve a spot for the element type to be set retroactively via `update_element_type`. - #[inline] - fn reserve_element_type(&mut self) { - self.type_index = self.bytes.len(); // record index - self.bytes.push(0); // push temporary placeholder - } - - /// Retroactively set the element type of the most recently serialized element. - #[inline] - fn update_element_type(&mut self, t: ElementType) -> Result<()> { - if self.type_index == 0 { - if matches!(t, ElementType::EmbeddedDocument) { - // don't need to set the element type for the top level document - return Ok(()); - } else { - return Err(Error::custom(format!( - "attempted to encode a non-document type at the top level: {:?}", - t - ))); - } - } - - self.bytes[self.type_index] = t as u8; - Ok(()) - } } impl<'a> serde::Serializer for &'a mut Serializer { @@ -172,7 +137,7 @@ impl<'a> serde::Serializer for &'a mut Serializer { #[inline] fn serialize_bool(self, v: bool) -> Result { - self.update_element_type(ElementType::Boolean)?; + self.write_key(ElementType::Boolean)?; self.bytes.push(v as u8); Ok(()) } @@ -189,14 +154,14 @@ impl<'a> serde::Serializer for &'a mut Serializer { #[inline] fn serialize_i32(self, v: i32) -> Result { - self.update_element_type(ElementType::Int32)?; + self.write_key(ElementType::Int32)?; write_i32(&mut self.bytes, v)?; Ok(()) } #[inline] fn serialize_i64(self, v: i64) -> Result { - self.update_element_type(ElementType::Int64)?; + self.write_key(ElementType::Int64)?; write_i64(&mut self.bytes, v)?; Ok(()) } @@ -233,7 +198,7 @@ impl<'a> serde::Serializer for &'a mut Serializer { #[inline] fn serialize_f64(self, v: f64) -> Result { - self.update_element_type(ElementType::Double)?; + self.write_key(ElementType::Double)?; write_f64(&mut self.bytes, v) } @@ -246,7 +211,7 @@ impl<'a> serde::Serializer for &'a mut Serializer { #[inline] fn serialize_str(self, v: &str) -> Result { - self.update_element_type(ElementType::String)?; + self.write_key(ElementType::String)?; write_string(&mut self.bytes, v); Ok(()) } @@ -255,15 +220,15 @@ impl<'a> serde::Serializer for &'a mut Serializer { fn serialize_bytes(self, v: &[u8]) -> Result { match self.hint.take() { SerializerHint::RawDocument => { - self.update_element_type(ElementType::EmbeddedDocument)?; + self.write_key(ElementType::EmbeddedDocument)?; self.bytes.write_all(v)?; } SerializerHint::RawArray => { - self.update_element_type(ElementType::Array)?; + self.write_key(ElementType::Array)?; self.bytes.write_all(v)?; } hint => { - self.update_element_type(ElementType::Binary)?; + self.write_key(ElementType::Binary)?; let subtype = if matches!(hint, SerializerHint::Uuid) { BinarySubtype::Uuid @@ -279,7 +244,7 @@ impl<'a> serde::Serializer for &'a mut Serializer { #[inline] fn serialize_none(self) -> Result { - self.update_element_type(ElementType::Null)?; + self.write_key(ElementType::Null)?; Ok(()) } @@ -343,7 +308,7 @@ impl<'a> serde::Serializer for &'a mut Serializer { where T: serde::Serialize + ?Sized, { - self.update_element_type(ElementType::EmbeddedDocument)?; + self.write_key(ElementType::EmbeddedDocument)?; let mut d = DocumentSerializer::start(&mut *self)?; d.serialize_entry(variant, value)?; d.end_doc()?; @@ -352,7 +317,7 @@ impl<'a> serde::Serializer for &'a mut Serializer { #[inline] fn serialize_seq(self, _len: Option) -> Result { - self.update_element_type(ElementType::Array)?; + self.write_key(ElementType::Array)?; DocumentSerializer::start(&mut *self) } @@ -378,13 +343,13 @@ impl<'a> serde::Serializer for &'a mut Serializer { variant: &'static str, _len: usize, ) -> Result { - self.update_element_type(ElementType::EmbeddedDocument)?; + self.write_key(ElementType::EmbeddedDocument)?; VariantSerializer::start(&mut *self, variant, VariantInnerType::Tuple) } #[inline] fn serialize_map(self, _len: Option) -> Result { - self.update_element_type(ElementType::EmbeddedDocument)?; + self.write_key(ElementType::EmbeddedDocument)?; DocumentSerializer::start(&mut *self) } @@ -407,7 +372,7 @@ impl<'a> serde::Serializer for &'a mut Serializer { _ => None, }; - self.update_element_type( + self.write_key( value_type .map(Into::into) .unwrap_or(ElementType::EmbeddedDocument), @@ -426,7 +391,7 @@ impl<'a> serde::Serializer for &'a mut Serializer { variant: &'static str, _len: usize, ) -> Result { - self.update_element_type(ElementType::EmbeddedDocument)?; + self.write_key(ElementType::EmbeddedDocument)?; VariantSerializer::start(&mut *self, variant, VariantInnerType::Struct) } } @@ -505,8 +470,6 @@ impl<'a> VariantSerializer<'a> { where T: Serialize + ?Sized, { - self.root_serializer.reserve_element_type(); - self.root_serializer.write_key_string(&k)?; self.root_serializer.set_next_key(k); v.serialize(&mut *self.root_serializer)?; From 084d6bc1e67e87bcd8149d69d3fa6f3c3221bcd5 Mon Sep 17 00:00:00 2001 From: Trevor McCulloch Date: Tue, 25 Feb 2025 12:57:55 -0800 Subject: [PATCH 4/6] to_buf_mut --- Cargo.toml | 1 + src/ser/mod.rs | 71 +++++--------- src/ser/raw/document_serializer.rs | 29 +++--- src/ser/raw/mod.rs | 150 ++++++++++++++++++----------- src/ser/raw/value_serializer.rs | 61 ++++++------ 5 files changed, 169 insertions(+), 143 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index bec0c3db..a4b58c56 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -73,6 +73,7 @@ serde_with-3 = { package = "serde_with", version = "3.1.0", optional = true } time = { version = "0.3.9", features = ["formatting", "parsing", "macros", "large-dates"] } bitvec = "1.0.1" serde_path_to_error = { version = "0.1.16", optional = true } +bytes = "1.10.0" [target.'cfg(all(target_arch = "wasm32", target_os = "unknown"))'.dependencies] js-sys = "0.3" diff --git a/src/ser/mod.rs b/src/ser/mod.rs index 61ab9cc9..b1bbf5fa 100644 --- a/src/ser/mod.rs +++ b/src/ser/mod.rs @@ -34,18 +34,19 @@ use std::io::Write; use crate::{ bson::{Bson, Document}, - de::MAX_BSON_SIZE, - spec::BinarySubtype, RawDocumentBuf, }; use ::serde::{ser::Error as SerdeError, Serialize}; +use bytes::BufMut; +// XXX remove pub(crate) fn write_string(buf: &mut Vec, s: &str) { buf.extend(&(s.len() as i32 + 1).to_le_bytes()); buf.extend(s.as_bytes()); buf.push(0); } +// XXX remove pub(crate) fn write_cstring(buf: &mut Vec, s: &str) -> Result<()> { if s.contains('\0') { return Err(Error::InvalidCString(s.into())); @@ -55,6 +56,7 @@ pub(crate) fn write_cstring(buf: &mut Vec, s: &str) -> Result<()> { Ok(()) } +// XXX remove #[inline] pub(crate) fn write_i32(writer: &mut W, val: i32) -> Result<()> { writer @@ -63,47 +65,6 @@ pub(crate) fn write_i32(writer: &mut W, val: i32) -> Result<( .map_err(From::from) } -#[inline] -fn write_i64(writer: &mut W, val: i64) -> Result<()> { - writer - .write_all(&val.to_le_bytes()) - .map(|_| ()) - .map_err(From::from) -} - -#[inline] -fn write_f64(writer: &mut W, val: f64) -> Result<()> { - writer - .write_all(&val.to_le_bytes()) - .map(|_| ()) - .map_err(From::from) -} - -#[inline] -fn write_binary(mut writer: W, bytes: &[u8], subtype: BinarySubtype) -> Result<()> { - let len = if let BinarySubtype::BinaryOld = subtype { - bytes.len() + 4 - } else { - bytes.len() - }; - - if len > MAX_BSON_SIZE as usize { - return Err(Error::custom(format!( - "binary length {} exceeded maximum size", - bytes.len() - ))); - } - - write_i32(&mut writer, len as i32)?; - writer.write_all(&[subtype.into()])?; - - if let BinarySubtype::BinaryOld = subtype { - write_i32(&mut writer, len as i32 - 4)?; - }; - - writer.write_all(bytes).map_err(From::from) -} - /// Encode a `T` Serializable into a [`Bson`] value. /// /// The [`Serializer`] used by this function presents itself as human readable, whereas the @@ -200,11 +161,29 @@ where pub fn to_vec(value: &T) -> Result> where T: Serialize, +{ + to_buf_mut(value, |len| Vec::with_capacity(len)) +} + +#[inline] +pub fn to_buf_mut(value: &T, create: F) -> Result +where + T: Serialize, + F: Fn(usize) -> B, + B: BufMut, { let mut len_serializer = raw::len_serializer::Serializer::new(); - value.serialize(&mut len_serializer)?; + #[cfg(feature = "serde_path_to_error")] + { + serde_path_to_error::serialize(value, &mut len_serializer).map_err(Error::with_path)?; + } + #[cfg(not(feature = "serde_path_to_error"))] + { + value.serialize(&mut len_serializer)?; + } let lens = len_serializer.into_lens(); - let mut serializer = raw::Serializer::new(lens.into_iter()); + let buf = create(*lens.first().expect("root document must have length") as usize); + let mut serializer = raw::Serializer::new(buf, lens.into_iter()); #[cfg(feature = "serde_path_to_error")] { serde_path_to_error::serialize(value, &mut serializer).map_err(Error::with_path)?; @@ -213,7 +192,7 @@ where { value.serialize(&mut serializer)?; } - Ok(serializer.into_vec()) + Ok(serializer.into_buf()) } /// Serialize the given `T` as a [`RawDocumentBuf`]. diff --git a/src/ser/raw/document_serializer.rs b/src/ser/raw/document_serializer.rs index f3e9719b..25914174 100644 --- a/src/ser/raw/document_serializer.rs +++ b/src/ser/raw/document_serializer.rs @@ -1,3 +1,4 @@ +use bytes::BufMut; use serde::{ser::Impossible, Serialize}; use crate::{ @@ -8,13 +9,13 @@ use crate::{ use super::{Key, Serializer}; /// Serializer used to serialize document or array bodies. -pub(crate) struct DocumentSerializer<'a> { - root_serializer: &'a mut Serializer, +pub(crate) struct DocumentSerializer<'a, B> { + root_serializer: &'a mut Serializer, num_keys_serialized: usize, } -impl<'a> DocumentSerializer<'a> { - pub(crate) fn start(rs: &'a mut Serializer) -> crate::ser::Result { +impl<'a, B: BufMut> DocumentSerializer<'a, B> { + pub(crate) fn start(rs: &'a mut Serializer) -> crate::ser::Result { rs.write_next_len()?; Ok(Self { root_serializer: rs, @@ -35,12 +36,12 @@ impl<'a> DocumentSerializer<'a> { } pub(crate) fn end_doc(self) -> crate::ser::Result<()> { - self.root_serializer.bytes.push(0); + self.root_serializer.buf.put_u8(0); Ok(()) } } -impl serde::ser::SerializeSeq for DocumentSerializer<'_> { +impl serde::ser::SerializeSeq for DocumentSerializer<'_, B> { type Ok = (); type Error = Error; @@ -61,7 +62,7 @@ impl serde::ser::SerializeSeq for DocumentSerializer<'_> { } } -impl serde::ser::SerializeMap for DocumentSerializer<'_> { +impl serde::ser::SerializeMap for DocumentSerializer<'_, B> { type Ok = (); type Error = Error; @@ -87,7 +88,7 @@ impl serde::ser::SerializeMap for DocumentSerializer<'_> { } } -impl serde::ser::SerializeStruct for DocumentSerializer<'_> { +impl serde::ser::SerializeStruct for DocumentSerializer<'_, B> { type Ok = (); type Error = Error; @@ -107,7 +108,7 @@ impl serde::ser::SerializeStruct for DocumentSerializer<'_> { } } -impl serde::ser::SerializeTuple for DocumentSerializer<'_> { +impl serde::ser::SerializeTuple for DocumentSerializer<'_, B> { type Ok = (); type Error = Error; @@ -129,7 +130,7 @@ impl serde::ser::SerializeTuple for DocumentSerializer<'_> { } } -impl serde::ser::SerializeTupleStruct for DocumentSerializer<'_> { +impl serde::ser::SerializeTupleStruct for DocumentSerializer<'_, B> { type Ok = (); type Error = Error; @@ -151,17 +152,17 @@ impl serde::ser::SerializeTupleStruct for DocumentSerializer<'_> { /// Serializer used specifically for serializing document keys. /// Only keys that serialize to strings will be accepted. -struct KeySerializer<'a> { - root_serializer: &'a mut Serializer, +struct KeySerializer<'a, B> { + root_serializer: &'a mut Serializer, } -impl KeySerializer<'_> { +impl KeySerializer<'_, B> { fn invalid_key(v: T) -> Error { Error::InvalidDocumentKey(to_bson(&v).unwrap_or(Bson::Null)) } } -impl serde::Serializer for KeySerializer<'_> { +impl serde::Serializer for KeySerializer<'_, B> { type Ok = (); type Error = Error; diff --git a/src/ser/raw/mod.rs b/src/ser/raw/mod.rs index af35ecdd..a6d0ccb5 100644 --- a/src/ser/raw/mod.rs +++ b/src/ser/raw/mod.rs @@ -2,8 +2,7 @@ mod document_serializer; pub(super) mod len_serializer; mod value_serializer; -use std::io::Write; - +use bytes::BufMut; use serde::{ ser::{Error as SerdeError, SerializeMap, SerializeStruct}, Serialize, @@ -11,8 +10,8 @@ use serde::{ use self::value_serializer::{ValueSerializer, ValueType}; -use super::{write_binary, write_cstring, write_f64, write_i32, write_i64, write_string}; use crate::{ + de::MAX_BSON_SIZE, raw::{RAW_ARRAY_NEWTYPE, RAW_DOCUMENT_NEWTYPE}, ser::{Error, Result}, serde_helpers::HUMAN_READABLE_NEWTYPE, @@ -22,10 +21,11 @@ use crate::{ use document_serializer::DocumentSerializer; /// Serializer used to convert a type `T` into raw BSON bytes. -pub(crate) struct Serializer { - lens: std::vec::IntoIter, +pub(crate) struct Serializer { + buf: B, - bytes: Vec, + lens: std::vec::IntoIter, + started: bool, next_key: Option, @@ -64,11 +64,12 @@ impl SerializerHint { } } -impl Serializer { - pub(crate) fn new(lens: std::vec::IntoIter) -> Self { +impl Serializer { + pub(crate) fn new(buf: B, lens: std::vec::IntoIter) -> Self { Self { + buf, lens, - bytes: Vec::new(), + started: false, next_key: None, hint: SerializerHint::None, human_readable: false, @@ -76,13 +77,17 @@ impl Serializer { } /// Convert this serializer into the vec of the serialized bytes. - pub(crate) fn into_vec(self) -> Vec { - self.bytes + pub(crate) fn into_buf(self) -> B { + self.buf } + // XXX fix sig, this is not falliable. #[inline] fn write_next_len(&mut self) -> Result<()> { - write_i32(&mut self.bytes, self.lens.next().expect("pre-recorded len")) + self.buf + .put_i32_le(self.lens.next().expect("pre-recorded len")); + self.started = true; + Ok(()) } #[inline] @@ -93,20 +98,14 @@ impl Serializer { #[inline] fn write_key(&mut self, t: ElementType) -> Result<()> { if let Some(key) = self.next_key.take() { - println!("write_key({:?}) key={:?}", t, key); - self.bytes.push(t as u8); + self.buf.put_u8(t as u8); match key { - Key::Static(k) => write_cstring(&mut self.bytes, k), - Key::Owned(k) => write_cstring(&mut self.bytes, &k), - Key::Index(i) => { - use std::io::Write; - write!(&mut self.bytes, "{}", i)?; - self.bytes.push(0); - Ok(()) - } + Key::Static(k) => self.write_cstring(k), + Key::Owned(k) => self.write_cstring(&k), + Key::Index(i) => self.write_cstring(&i.to_string()), } } else { - if self.bytes.is_empty() && t == ElementType::EmbeddedDocument { + if !self.started && t == ElementType::EmbeddedDocument { // don't need to write element type and key for top-level document. Ok(()) } else { @@ -117,19 +116,62 @@ impl Serializer { } } } + + #[inline] + fn write_cstring(&mut self, s: &str) -> Result<()> { + if s.contains('\0') { + return Err(Error::InvalidCString(s.into())); + } + self.buf.put_slice(s.as_bytes()); + self.buf.put_u8(0); + Ok(()) + } + + #[inline] + fn write_string(&mut self, s: &str) { + self.buf.put_i32_le(s.len() as i32 + 1); + self.buf.put_slice(s.as_bytes()); + self.buf.put_u8(0); + } + + #[inline] + fn write_binary(&mut self, bytes: &[u8], subtype: BinarySubtype) -> Result<()> { + let len = if let BinarySubtype::BinaryOld = subtype { + bytes.len() + 4 + } else { + bytes.len() + }; + + if len > MAX_BSON_SIZE as usize { + return Err(Error::custom(format!( + "binary length {} exceeded maximum size", + bytes.len() + ))); + } + + self.buf.put_i32_le(len as i32); + self.buf.put_u8(subtype.into()); + + if let BinarySubtype::BinaryOld = subtype { + self.buf.put_i32_le(len as i32 - 4); + }; + + self.buf.put_slice(bytes); + Ok(()) + } } -impl<'a> serde::Serializer for &'a mut Serializer { +impl<'a, B: BufMut> serde::Serializer for &'a mut Serializer { type Ok = (); type Error = Error; - type SerializeSeq = DocumentSerializer<'a>; - type SerializeTuple = DocumentSerializer<'a>; - type SerializeTupleStruct = DocumentSerializer<'a>; - type SerializeTupleVariant = VariantSerializer<'a>; - type SerializeMap = DocumentSerializer<'a>; - type SerializeStruct = StructSerializer<'a>; - type SerializeStructVariant = VariantSerializer<'a>; + type SerializeSeq = DocumentSerializer<'a, B>; + type SerializeTuple = DocumentSerializer<'a, B>; + type SerializeTupleStruct = DocumentSerializer<'a, B>; + type SerializeTupleVariant = VariantSerializer<'a, B>; + type SerializeMap = DocumentSerializer<'a, B>; + type SerializeStruct = StructSerializer<'a, B>; + type SerializeStructVariant = VariantSerializer<'a, B>; fn is_human_readable(&self) -> bool { self.human_readable @@ -138,7 +180,7 @@ impl<'a> serde::Serializer for &'a mut Serializer { #[inline] fn serialize_bool(self, v: bool) -> Result { self.write_key(ElementType::Boolean)?; - self.bytes.push(v as u8); + self.buf.put_u8(v as u8); Ok(()) } @@ -155,14 +197,14 @@ impl<'a> serde::Serializer for &'a mut Serializer { #[inline] fn serialize_i32(self, v: i32) -> Result { self.write_key(ElementType::Int32)?; - write_i32(&mut self.bytes, v)?; + self.buf.put_i32_le(v); Ok(()) } #[inline] fn serialize_i64(self, v: i64) -> Result { self.write_key(ElementType::Int64)?; - write_i64(&mut self.bytes, v)?; + self.buf.put_i64_le(v); Ok(()) } @@ -199,7 +241,8 @@ impl<'a> serde::Serializer for &'a mut Serializer { #[inline] fn serialize_f64(self, v: f64) -> Result { self.write_key(ElementType::Double)?; - write_f64(&mut self.bytes, v) + self.buf.put_f64_le(v); + Ok(()) } #[inline] @@ -212,7 +255,7 @@ impl<'a> serde::Serializer for &'a mut Serializer { #[inline] fn serialize_str(self, v: &str) -> Result { self.write_key(ElementType::String)?; - write_string(&mut self.bytes, v); + self.write_string(v); Ok(()) } @@ -221,11 +264,11 @@ impl<'a> serde::Serializer for &'a mut Serializer { match self.hint.take() { SerializerHint::RawDocument => { self.write_key(ElementType::EmbeddedDocument)?; - self.bytes.write_all(v)?; + self.buf.put_slice(v); } SerializerHint::RawArray => { self.write_key(ElementType::Array)?; - self.bytes.write_all(v)?; + self.buf.put_slice(v); } hint => { self.write_key(ElementType::Binary)?; @@ -236,7 +279,7 @@ impl<'a> serde::Serializer for &'a mut Serializer { BinarySubtype::Generic }; - write_binary(&mut self.bytes, v, subtype)?; + self.write_binary(v, subtype)?; } }; Ok(()) @@ -396,15 +439,15 @@ impl<'a> serde::Serializer for &'a mut Serializer { } } -pub(crate) enum StructSerializer<'a> { +pub(crate) enum StructSerializer<'a, B> { /// Serialize a BSON value currently represented in serde as a struct (e.g. ObjectId) - Value(ValueSerializer<'a>), + Value(ValueSerializer<'a, B>), /// Serialize the struct as a document. - Document(DocumentSerializer<'a>), + Document(DocumentSerializer<'a, B>), } -impl SerializeStruct for StructSerializer<'_> { +impl SerializeStruct for StructSerializer<'_, B> { type Ok = (); type Error = Error; @@ -435,16 +478,16 @@ enum VariantInnerType { /// Serializer used for enum variants, including both tuple (e.g. Foo::Bar(1, 2, 3)) and /// struct (e.g. Foo::Bar { a: 1 }). -pub(crate) struct VariantSerializer<'a> { - root_serializer: &'a mut Serializer, +pub(crate) struct VariantSerializer<'a, B> { + root_serializer: &'a mut Serializer, /// How many elements have been serialized in the inner document / array so far. num_elements_serialized: usize, } -impl<'a> VariantSerializer<'a> { +impl<'a, B: BufMut> VariantSerializer<'a, B> { fn start( - rs: &'a mut Serializer, + rs: &'a mut Serializer, variant: &'static str, inner_type: VariantInnerType, ) -> Result { @@ -454,9 +497,8 @@ impl<'a> VariantSerializer<'a> { VariantInnerType::Struct => ElementType::EmbeddedDocument, VariantInnerType::Tuple => ElementType::Array, }; - rs.bytes.push(inner as u8); - write_cstring(&mut rs.bytes, variant)?; - // write placeholder length for inner, will be updated at end + rs.buf.put_u8(inner as u8); + rs.write_cstring(&variant)?; rs.write_next_len()?; Ok(Self { @@ -480,14 +522,14 @@ impl<'a> VariantSerializer<'a> { #[inline] fn end_both(self) -> Result<()> { // null byte for the inner - self.root_serializer.bytes.push(0); + self.root_serializer.buf.put_u8(0); // null byte for document - self.root_serializer.bytes.push(0); + self.root_serializer.buf.put_u8(0); Ok(()) } } -impl serde::ser::SerializeTupleVariant for VariantSerializer<'_> { +impl serde::ser::SerializeTupleVariant for VariantSerializer<'_, B> { type Ok = (); type Error = Error; @@ -506,7 +548,7 @@ impl serde::ser::SerializeTupleVariant for VariantSerializer<'_> { } } -impl serde::ser::SerializeStructVariant for VariantSerializer<'_> { +impl serde::ser::SerializeStructVariant for VariantSerializer<'_, B> { type Ok = (); type Error = Error; diff --git a/src/ser/raw/value_serializer.rs b/src/ser/raw/value_serializer.rs index ddb74536..2a2c1361 100644 --- a/src/ser/raw/value_serializer.rs +++ b/src/ser/raw/value_serializer.rs @@ -1,5 +1,6 @@ -use std::{convert::TryFrom, io::Write}; +use std::convert::TryFrom; +use bytes::BufMut; use serde::{ ser::{Error as SerdeError, Impossible, SerializeMap, SerializeStruct}, Serialize, @@ -9,7 +10,7 @@ use crate::{ base64, oid::ObjectId, raw::RAW_DOCUMENT_NEWTYPE, - ser::{write_binary, write_cstring, write_i32, write_i64, write_string, Error, Result}, + ser::{Error, Result}, spec::{BinarySubtype, ElementType}, RawDocument, RawJavaScriptCodeWithScopeRef, }; @@ -18,8 +19,8 @@ use super::{document_serializer::DocumentSerializer, Serializer}; /// A serializer used specifically for serializing the serde-data-model form of a BSON type (e.g. /// [`Binary`]) to raw bytes. -pub(crate) struct ValueSerializer<'a> { - root_serializer: &'a mut Serializer, +pub(crate) struct ValueSerializer<'a, B> { + root_serializer: &'a mut Serializer, state: SerializationStep, } @@ -117,8 +118,8 @@ impl From for ElementType { } } -impl<'a> ValueSerializer<'a> { - pub(super) fn new(rs: &'a mut Serializer, value_type: ValueType) -> Self { +impl<'a, B> ValueSerializer<'a, B> { + pub(super) fn new(rs: &'a mut Serializer, value_type: ValueType) -> Self { let state = match value_type { ValueType::DateTime => SerializationStep::DateTime, ValueType::Binary => SerializationStep::Binary, @@ -148,7 +149,7 @@ impl<'a> ValueSerializer<'a> { } } -impl<'b> serde::Serializer for &'b mut ValueSerializer<'_> { +impl<'b, B: BufMut> serde::Serializer for &'b mut ValueSerializer<'_, B> { type Ok = (); type Error = Error; @@ -156,7 +157,7 @@ impl<'b> serde::Serializer for &'b mut ValueSerializer<'_> { type SerializeTuple = Impossible<(), Error>; type SerializeTupleStruct = Impossible<(), Error>; type SerializeTupleVariant = Impossible<(), Error>; - type SerializeMap = CodeWithScopeSerializer<'b>; + type SerializeMap = CodeWithScopeSerializer<'b, B>; type SerializeStruct = Self; type SerializeStructVariant = Impossible<(), Error>; @@ -191,8 +192,8 @@ impl<'b> serde::Serializer for &'b mut ValueSerializer<'_> { let t = u32::try_from(time).map_err(Error::custom)?; let i = u32::try_from(v).map_err(Error::custom)?; - write_i32(&mut self.root_serializer.bytes, i as i32)?; - write_i32(&mut self.root_serializer.bytes, t as i32)?; + self.root_serializer.buf.put_i32_le(i as i32); + self.root_serializer.buf.put_i32_le(t as i32); Ok(()) } _ => Err(self.invalid_step("i64")), @@ -203,7 +204,8 @@ impl<'b> serde::Serializer for &'b mut ValueSerializer<'_> { fn serialize_u8(self, v: u8) -> Result { match self.state { SerializationStep::RawBinarySubType { ref bytes } => { - write_binary(&mut self.root_serializer.bytes, bytes.as_slice(), v.into())?; + self.root_serializer + .write_binary(bytes.as_slice(), v.into())?; self.state = SerializationStep::Done; Ok(()) } @@ -245,11 +247,11 @@ impl<'b> serde::Serializer for &'b mut ValueSerializer<'_> { match &self.state { SerializationStep::DateTimeNumberLong => { let millis: i64 = v.parse().map_err(Error::custom)?; - write_i64(&mut self.root_serializer.bytes, millis)?; + self.root_serializer.buf.put_i64_le(millis); } SerializationStep::Oid => { let oid = ObjectId::parse_str(v).map_err(Error::custom)?; - self.root_serializer.bytes.write_all(&oid.bytes())?; + self.root_serializer.buf.put_slice(&oid.bytes()); } SerializationStep::BinaryBytes => { self.state = SerializationStep::BinarySubType { @@ -262,23 +264,24 @@ impl<'b> serde::Serializer for &'b mut ValueSerializer<'_> { let bytes = base64::decode(base64.as_str()).map_err(Error::custom)?; - write_binary(&mut self.root_serializer.bytes, bytes.as_slice(), subtype)?; + self.root_serializer + .write_binary(bytes.as_slice(), subtype)?; } SerializationStep::Symbol | SerializationStep::DbPointerRef => { - write_string(&mut self.root_serializer.bytes, v); + self.root_serializer.write_string(v); } SerializationStep::RegExPattern => { - write_cstring(&mut self.root_serializer.bytes, v)?; + self.root_serializer.write_cstring(v)?; } SerializationStep::RegExOptions => { let mut chars: Vec<_> = v.chars().collect(); chars.sort_unstable(); let sorted = chars.into_iter().collect::(); - write_cstring(&mut self.root_serializer.bytes, sorted.as_str())?; + self.root_serializer.write_cstring(sorted.as_str())?; } SerializationStep::Code => { - write_string(&mut self.root_serializer.bytes, v); + self.root_serializer.write_string(v); } SerializationStep::CodeWithScopeCode => { self.state = SerializationStep::CodeWithScopeScope { @@ -300,7 +303,7 @@ impl<'b> serde::Serializer for &'b mut ValueSerializer<'_> { fn serialize_bytes(self, v: &[u8]) -> Result { match self.state { SerializationStep::Decimal128Value => { - self.root_serializer.bytes.write_all(v)?; + self.root_serializer.buf.put_slice(v); Ok(()) } SerializationStep::BinaryBytes => { @@ -312,9 +315,9 @@ impl<'b> serde::Serializer for &'b mut ValueSerializer<'_> { code, scope: RawDocument::from_bytes(v).map_err(Error::custom)?, }; - write_i32(&mut self.root_serializer.bytes, raw.len())?; - write_string(&mut self.root_serializer.bytes, code); - self.root_serializer.bytes.write_all(v)?; + self.root_serializer.buf.put_i32_le(raw.len() as i32); + self.root_serializer.write_string(code); + self.root_serializer.buf.put_slice(v); self.state = SerializationStep::Done; Ok(()) } @@ -450,7 +453,7 @@ impl<'b> serde::Serializer for &'b mut ValueSerializer<'_> { } } -impl SerializeStruct for &mut ValueSerializer<'_> { +impl SerializeStruct for &mut ValueSerializer<'_, B> { type Ok = (); type Error = Error; @@ -580,22 +583,22 @@ impl SerializeStruct for &mut ValueSerializer<'_> { } } -pub(crate) struct CodeWithScopeSerializer<'a> { - doc: DocumentSerializer<'a>, +pub(crate) struct CodeWithScopeSerializer<'a, B> { + doc: DocumentSerializer<'a, B>, } -impl<'a> CodeWithScopeSerializer<'a> { +impl<'a, B: BufMut> CodeWithScopeSerializer<'a, B> { #[inline] - fn start(code: &str, rs: &'a mut Serializer) -> Result { + fn start(code: &str, rs: &'a mut Serializer) -> Result { rs.write_next_len()?; - write_string(&mut rs.bytes, code); + rs.write_string(code); let doc = DocumentSerializer::start(rs)?; Ok(Self { doc }) } } -impl SerializeMap for CodeWithScopeSerializer<'_> { +impl SerializeMap for CodeWithScopeSerializer<'_, B> { type Ok = (); type Error = Error; From 7f0df9dd6fa5e8533d23791a81ddb9a8690d1061 Mon Sep 17 00:00:00 2001 From: Trevor McCulloch Date: Thu, 13 Mar 2025 21:54:07 -0700 Subject: [PATCH 5/6] WIP BufMut len_serializer --- src/ser/mod.rs | 4 +- src/ser/raw/document_serializer.rs | 21 ++-- src/ser/raw/mod.rs | 176 ++++++++++++++++++++++++----- src/ser/raw/value_serializer.rs | 15 ++- 4 files changed, 167 insertions(+), 49 deletions(-) diff --git a/src/ser/mod.rs b/src/ser/mod.rs index b1bbf5fa..02a45fb7 100644 --- a/src/ser/mod.rs +++ b/src/ser/mod.rs @@ -183,7 +183,7 @@ where } let lens = len_serializer.into_lens(); let buf = create(*lens.first().expect("root document must have length") as usize); - let mut serializer = raw::Serializer::new(buf, lens.into_iter()); + let mut serializer = raw::Serializer::new(raw::LenReplayingDocumentBufMut::new(buf, lens)); #[cfg(feature = "serde_path_to_error")] { serde_path_to_error::serialize(value, &mut serializer).map_err(Error::with_path)?; @@ -192,7 +192,7 @@ where { value.serialize(&mut serializer)?; } - Ok(serializer.into_buf()) + Ok(serializer.into_buf().into_inner()) } /// Serialize the given `T` as a [`RawDocumentBuf`]. diff --git a/src/ser/raw/document_serializer.rs b/src/ser/raw/document_serializer.rs index 25914174..7c21d9ba 100644 --- a/src/ser/raw/document_serializer.rs +++ b/src/ser/raw/document_serializer.rs @@ -1,4 +1,3 @@ -use bytes::BufMut; use serde::{ser::Impossible, Serialize}; use crate::{ @@ -6,7 +5,7 @@ use crate::{ to_bson, Bson, }; -use super::{Key, Serializer}; +use super::{DocumentBufMut, Key, Serializer}; /// Serializer used to serialize document or array bodies. pub(crate) struct DocumentSerializer<'a, B> { @@ -14,9 +13,9 @@ pub(crate) struct DocumentSerializer<'a, B> { num_keys_serialized: usize, } -impl<'a, B: BufMut> DocumentSerializer<'a, B> { +impl<'a, B: DocumentBufMut> DocumentSerializer<'a, B> { pub(crate) fn start(rs: &'a mut Serializer) -> crate::ser::Result { - rs.write_next_len()?; + rs.buf.begin_doc()?; Ok(Self { root_serializer: rs, num_keys_serialized: 0, @@ -36,12 +35,12 @@ impl<'a, B: BufMut> DocumentSerializer<'a, B> { } pub(crate) fn end_doc(self) -> crate::ser::Result<()> { - self.root_serializer.buf.put_u8(0); + self.root_serializer.buf.end_doc()?; Ok(()) } } -impl serde::ser::SerializeSeq for DocumentSerializer<'_, B> { +impl serde::ser::SerializeSeq for DocumentSerializer<'_, B> { type Ok = (); type Error = Error; @@ -62,7 +61,7 @@ impl serde::ser::SerializeSeq for DocumentSerializer<'_, B> { } } -impl serde::ser::SerializeMap for DocumentSerializer<'_, B> { +impl serde::ser::SerializeMap for DocumentSerializer<'_, B> { type Ok = (); type Error = Error; @@ -88,7 +87,7 @@ impl serde::ser::SerializeMap for DocumentSerializer<'_, B> { } } -impl serde::ser::SerializeStruct for DocumentSerializer<'_, B> { +impl serde::ser::SerializeStruct for DocumentSerializer<'_, B> { type Ok = (); type Error = Error; @@ -108,7 +107,7 @@ impl serde::ser::SerializeStruct for DocumentSerializer<'_, B> { } } -impl serde::ser::SerializeTuple for DocumentSerializer<'_, B> { +impl serde::ser::SerializeTuple for DocumentSerializer<'_, B> { type Ok = (); type Error = Error; @@ -130,7 +129,7 @@ impl serde::ser::SerializeTuple for DocumentSerializer<'_, B> { } } -impl serde::ser::SerializeTupleStruct for DocumentSerializer<'_, B> { +impl serde::ser::SerializeTupleStruct for DocumentSerializer<'_, B> { type Ok = (); type Error = Error; @@ -162,7 +161,7 @@ impl KeySerializer<'_, B> { } } -impl serde::Serializer for KeySerializer<'_, B> { +impl serde::Serializer for KeySerializer<'_, B> { type Ok = (); type Error = Error; diff --git a/src/ser/raw/mod.rs b/src/ser/raw/mod.rs index a6d0ccb5..f0b9491a 100644 --- a/src/ser/raw/mod.rs +++ b/src/ser/raw/mod.rs @@ -20,12 +20,145 @@ use crate::{ }; use document_serializer::DocumentSerializer; -/// Serializer used to convert a type `T` into raw BSON bytes. -pub(crate) struct Serializer { - buf: B, +// XXX begin_doc and end_doc appear to be infalliable. +pub(crate) trait DocumentBufMut: BufMut { + /// Track/record information related to the document started at this point. + fn begin_doc(&mut self) -> Result<()>; + /// Track/record any information related to the end of the current document. + fn end_doc(&mut self) -> Result<()>; + /// Return true if begin_doc() has been called at least once. + fn in_document(&self) -> bool; +} + +pub(crate) struct LenRecordingDocumentBufMut { + stream_len: usize, + lens: Vec, + stack: Vec<(usize, usize)>, +} + +impl LenRecordingDocumentBufMut { + fn new() -> Self { + Self { + stream_len: 0, + lens: vec![], + stack: vec![], + } + } + + fn into_lens(self) -> Vec { + assert!(self.stack.is_empty()); + self.lens + } +} + +impl DocumentBufMut for LenRecordingDocumentBufMut { + fn begin_doc(&mut self) -> Result<()> { + if self.stack.is_empty() && self.stream_len > 0 { + panic!("must begin stream with a document.") + } + let index = self.lens.len(); + self.lens.push(0); + self.stack.push((index, self.stream_len)); + self.stream_len += 4; // length value that will be written to the stream. + Ok(()) + } + + fn end_doc(&mut self) -> Result<()> { + self.stream_len += 1; // null terminator + let (index, doc_begin) = self.stack.pop().unwrap(); + self.lens[index] = self.stream_len as i32 - doc_begin as i32; + Ok(()) + } + fn in_document(&self) -> bool { + !self.stack.is_empty() + } +} + +unsafe impl BufMut for LenRecordingDocumentBufMut { + fn remaining_mut(&self) -> usize { + 0 + } + + unsafe fn advance_mut(&mut self, cnt: usize) { + self.stream_len += cnt; + } + + fn chunk_mut(&mut self) -> &mut bytes::buf::UninitSlice { + bytes::buf::UninitSlice::new(&mut []) + } + + fn put(&mut self, src: T) + where + Self: Sized, + { + self.stream_len += src.remaining() + } + + fn put_slice(&mut self, src: &[u8]) { + self.stream_len += src.len(); + } + + fn put_bytes(&mut self, _val: u8, cnt: usize) { + self.stream_len += cnt; + } +} + +pub(crate) struct LenReplayingDocumentBufMut { + buf: B, lens: std::vec::IntoIter, started: bool, +} + +impl LenReplayingDocumentBufMut { + pub(crate) fn new(buf: B, lens: Vec) -> Self { + Self { + buf, + lens: lens.into_iter(), + started: false, + } + } + + pub(crate) fn into_inner(self) -> B { + self.buf + } +} + +impl DocumentBufMut for LenReplayingDocumentBufMut { + fn begin_doc(&mut self) -> Result<()> { + let len = self.lens.next().unwrap(); + self.buf.put_i32_le(len); + self.started = true; + Ok(()) + } + + fn end_doc(&mut self) -> Result<()> { + self.buf.put_u8(0); + Ok(()) + } + + fn in_document(&self) -> bool { + self.started + } +} + +unsafe impl BufMut for LenReplayingDocumentBufMut { + fn remaining_mut(&self) -> usize { + self.buf.remaining_mut() + } + + unsafe fn advance_mut(&mut self, cnt: usize) { + self.buf.advance_mut(cnt); + } + + fn chunk_mut(&mut self) -> &mut bytes::buf::UninitSlice { + self.buf.chunk_mut() + } +} + +/// Serializer used to convert a type `T` into raw BSON bytes. +pub(crate) struct Serializer { + buf: B, next_key: Option, @@ -64,12 +197,10 @@ impl SerializerHint { } } -impl Serializer { - pub(crate) fn new(buf: B, lens: std::vec::IntoIter) -> Self { +impl Serializer { + pub(crate) fn new(buf: B) -> Self { Self { buf, - lens, - started: false, next_key: None, hint: SerializerHint::None, human_readable: false, @@ -81,15 +212,6 @@ impl Serializer { self.buf } - // XXX fix sig, this is not falliable. - #[inline] - fn write_next_len(&mut self) -> Result<()> { - self.buf - .put_i32_le(self.lens.next().expect("pre-recorded len")); - self.started = true; - Ok(()) - } - #[inline] fn set_next_key(&mut self, key: Key) { self.next_key = Some(key); @@ -105,7 +227,7 @@ impl Serializer { Key::Index(i) => self.write_cstring(&i.to_string()), } } else { - if !self.started && t == ElementType::EmbeddedDocument { + if !self.buf.in_document() && t == ElementType::EmbeddedDocument { // don't need to write element type and key for top-level document. Ok(()) } else { @@ -161,7 +283,7 @@ impl Serializer { } } -impl<'a, B: BufMut> serde::Serializer for &'a mut Serializer { +impl<'a, B: DocumentBufMut> serde::Serializer for &'a mut Serializer { type Ok = (); type Error = Error; @@ -447,7 +569,7 @@ pub(crate) enum StructSerializer<'a, B> { Document(DocumentSerializer<'a, B>), } -impl SerializeStruct for StructSerializer<'_, B> { +impl SerializeStruct for StructSerializer<'_, B> { type Ok = (); type Error = Error; @@ -485,13 +607,13 @@ pub(crate) struct VariantSerializer<'a, B> { num_elements_serialized: usize, } -impl<'a, B: BufMut> VariantSerializer<'a, B> { +impl<'a, B: DocumentBufMut> VariantSerializer<'a, B> { fn start( rs: &'a mut Serializer, variant: &'static str, inner_type: VariantInnerType, ) -> Result { - rs.write_next_len()?; + rs.buf.begin_doc()?; let inner = match inner_type { VariantInnerType::Struct => ElementType::EmbeddedDocument, @@ -499,7 +621,7 @@ impl<'a, B: BufMut> VariantSerializer<'a, B> { }; rs.buf.put_u8(inner as u8); rs.write_cstring(&variant)?; - rs.write_next_len()?; + rs.buf.begin_doc()?; Ok(Self { root_serializer: rs, @@ -521,15 +643,13 @@ impl<'a, B: BufMut> VariantSerializer<'a, B> { #[inline] fn end_both(self) -> Result<()> { - // null byte for the inner - self.root_serializer.buf.put_u8(0); - // null byte for document - self.root_serializer.buf.put_u8(0); + self.root_serializer.buf.end_doc()?; + self.root_serializer.buf.end_doc()?; Ok(()) } } -impl serde::ser::SerializeTupleVariant for VariantSerializer<'_, B> { +impl serde::ser::SerializeTupleVariant for VariantSerializer<'_, B> { type Ok = (); type Error = Error; @@ -548,7 +668,7 @@ impl serde::ser::SerializeTupleVariant for VariantSerializer<'_, B> { } } -impl serde::ser::SerializeStructVariant for VariantSerializer<'_, B> { +impl serde::ser::SerializeStructVariant for VariantSerializer<'_, B> { type Ok = (); type Error = Error; diff --git a/src/ser/raw/value_serializer.rs b/src/ser/raw/value_serializer.rs index 2a2c1361..2633545e 100644 --- a/src/ser/raw/value_serializer.rs +++ b/src/ser/raw/value_serializer.rs @@ -1,6 +1,5 @@ use std::convert::TryFrom; -use bytes::BufMut; use serde::{ ser::{Error as SerdeError, Impossible, SerializeMap, SerializeStruct}, Serialize, @@ -15,7 +14,7 @@ use crate::{ RawDocument, RawJavaScriptCodeWithScopeRef, }; -use super::{document_serializer::DocumentSerializer, Serializer}; +use super::{document_serializer::DocumentSerializer, DocumentBufMut, Serializer}; /// A serializer used specifically for serializing the serde-data-model form of a BSON type (e.g. /// [`Binary`]) to raw bytes. @@ -149,7 +148,7 @@ impl<'a, B> ValueSerializer<'a, B> { } } -impl<'b, B: BufMut> serde::Serializer for &'b mut ValueSerializer<'_, B> { +impl<'b, B: DocumentBufMut> serde::Serializer for &'b mut ValueSerializer<'_, B> { type Ok = (); type Error = Error; @@ -453,7 +452,7 @@ impl<'b, B: BufMut> serde::Serializer for &'b mut ValueSerializer<'_, B> { } } -impl SerializeStruct for &mut ValueSerializer<'_, B> { +impl SerializeStruct for &mut ValueSerializer<'_, B> { type Ok = (); type Error = Error; @@ -587,10 +586,10 @@ pub(crate) struct CodeWithScopeSerializer<'a, B> { doc: DocumentSerializer<'a, B>, } -impl<'a, B: BufMut> CodeWithScopeSerializer<'a, B> { +impl<'a, B: DocumentBufMut> CodeWithScopeSerializer<'a, B> { #[inline] fn start(code: &str, rs: &'a mut Serializer) -> Result { - rs.write_next_len()?; + rs.buf.begin_doc()?; rs.write_string(code); let doc = DocumentSerializer::start(rs)?; @@ -598,7 +597,7 @@ impl<'a, B: BufMut> CodeWithScopeSerializer<'a, B> { } } -impl SerializeMap for CodeWithScopeSerializer<'_, B> { +impl SerializeMap for CodeWithScopeSerializer<'_, B> { type Ok = (); type Error = Error; @@ -620,6 +619,6 @@ impl SerializeMap for CodeWithScopeSerializer<'_, B> { #[inline] fn end(self) -> Result { - self.doc.end_doc().map(|_| ()) + self.doc.end_doc() } } From 7f65d537816be84409191be09151aeff28516f92 Mon Sep 17 00:00:00 2001 From: Trevor McCulloch Date: Wed, 19 Mar 2025 21:10:47 -0700 Subject: [PATCH 6/6] refactor length serialization to significantly reduce code size --- src/ser/mod.rs | 6 +- src/ser/raw/document_serializer.rs | 12 +- src/ser/raw/len_serializer.rs | 1430 ---------------------------- src/ser/raw/mod.rs | 64 +- src/ser/raw/value_serializer.rs | 7 +- 5 files changed, 56 insertions(+), 1463 deletions(-) delete mode 100644 src/ser/raw/len_serializer.rs diff --git a/src/ser/mod.rs b/src/ser/mod.rs index 02a45fb7..a0639c51 100644 --- a/src/ser/mod.rs +++ b/src/ser/mod.rs @@ -172,7 +172,7 @@ where F: Fn(usize) -> B, B: BufMut, { - let mut len_serializer = raw::len_serializer::Serializer::new(); + let mut len_serializer = raw::Serializer::new(raw::LenRecordingDocumentBufMut::new()); #[cfg(feature = "serde_path_to_error")] { serde_path_to_error::serialize(value, &mut len_serializer).map_err(Error::with_path)?; @@ -181,8 +181,8 @@ where { value.serialize(&mut len_serializer)?; } - let lens = len_serializer.into_lens(); - let buf = create(*lens.first().expect("root document must have length") as usize); + let lens = len_serializer.into_buf().into_lens(); + let buf = create(*lens.first().unwrap_or(&5) as usize); let mut serializer = raw::Serializer::new(raw::LenReplayingDocumentBufMut::new(buf, lens)); #[cfg(feature = "serde_path_to_error")] { diff --git a/src/ser/raw/document_serializer.rs b/src/ser/raw/document_serializer.rs index 7c21d9ba..325a44c5 100644 --- a/src/ser/raw/document_serializer.rs +++ b/src/ser/raw/document_serializer.rs @@ -2,6 +2,7 @@ use serde::{ser::Impossible, Serialize}; use crate::{ ser::{Error, Result}, + spec::ElementType, to_bson, Bson, }; @@ -14,8 +15,11 @@ pub(crate) struct DocumentSerializer<'a, B> { } impl<'a, B: DocumentBufMut> DocumentSerializer<'a, B> { - pub(crate) fn start(rs: &'a mut Serializer) -> crate::ser::Result { - rs.buf.begin_doc()?; + pub(crate) fn start( + rs: &'a mut Serializer, + doc_type: ElementType, + ) -> crate::ser::Result { + rs.buf.begin_doc(doc_type)?; Ok(Self { root_serializer: rs, num_keys_serialized: 0, @@ -34,9 +38,9 @@ impl<'a, B: DocumentBufMut> DocumentSerializer<'a, B> { Ok(()) } - pub(crate) fn end_doc(self) -> crate::ser::Result<()> { + pub(crate) fn end_doc(self) -> crate::ser::Result<&'a mut Serializer> { self.root_serializer.buf.end_doc()?; - Ok(()) + Ok(self.root_serializer) } } diff --git a/src/ser/raw/len_serializer.rs b/src/ser/raw/len_serializer.rs deleted file mode 100644 index cd9bfece..00000000 --- a/src/ser/raw/len_serializer.rs +++ /dev/null @@ -1,1430 +0,0 @@ -use serde::{ - ser::{Error as SerdeError, SerializeMap, SerializeStruct}, - Serialize, -}; - -use crate::{ - raw::{RAW_ARRAY_NEWTYPE, RAW_DOCUMENT_NEWTYPE}, - ser::{Error, Result}, - serde_helpers::HUMAN_READABLE_NEWTYPE, - spec::{BinarySubtype, ElementType}, - uuid::UUID_NEWTYPE_NAME, -}; - -/// Serializer used to convert a type `T` into raw BSON bytes. -pub(crate) struct Serializer { - /// Length of all documents visited by the serializer in the order in which they are serialized. - /// The length of the root document will always appear at index zero. - lens: Vec, - - /// Index of each document and sub-document we are computing the length of. - /// For well-formed serialization requests this will always contain at least one element. - lens_stack: Vec, - - /// Hint provided by the type being serialized. - hint: SerializerHint, - - human_readable: bool, -} - -/// Various bits of information that the serialized type can provide to the serializer to -/// inform the purpose of the next serialization step. -#[derive(Debug, Clone, Copy)] -enum SerializerHint { - None, - - /// The next call to `serialize_bytes` is for the purposes of serializing a UUID. - Uuid, - - /// The next call to `serialize_bytes` is for the purposes of serializing a raw document. - RawDocument, - - /// The next call to `serialize_bytes` is for the purposes of serializing a raw array. - RawArray, -} - -impl SerializerHint { - fn take(&mut self) -> SerializerHint { - std::mem::replace(self, SerializerHint::None) - } -} - -impl Serializer { - pub(crate) fn new() -> Self { - Self { - lens: vec![], - lens_stack: vec![], - hint: SerializerHint::None, - human_readable: false, - } - } - - pub(crate) fn into_lens(self) -> Vec { - assert!(self.lens_stack.is_empty()); - self.lens - } - - #[inline] - fn enter_doc(&mut self) { - let index = self.lens.len(); - self.lens.push(0); - self.lens_stack.push(index); - } - - #[inline] - fn exit_doc(&mut self) { - let index = self - .lens_stack - .pop() - .expect("document enter and exit are paired"); - self.lens[index] += 4 + 1; // i32 doc len + null terminator. - let len = self.lens[index]; - if let Some(parent_index) = self.lens_stack.last() { - // propagate length back up to parent, if present. - self.lens[*parent_index] += len; - } - } - - #[inline] - fn add_bytes(&mut self, bytes: i32) -> Result<()> { - if let Some(index) = self.lens_stack.last() { - self.lens[*index] += bytes; - Ok(()) - } else { - Err(Error::custom(format!( - "attempted to encode a non-document type at the top level", - ))) - } - } - - #[inline] - fn add_element_name_and_type(&mut self, len: usize) -> Result<()> { - // type + length + null terminator. - self.add_bytes(1 + len as i32 + 1) - } - - #[inline] - fn add_cstr_bytes(&mut self, len: usize) -> Result<()> { - self.add_bytes(len as i32 + 1) - } - - #[inline] - fn add_bin_bytes(&mut self, len: usize, subtype: BinarySubtype) -> Result<()> { - let total_len = if subtype == BinarySubtype::BinaryOld { - 4 + 1 + 4 + len as i32 - } else { - 4 + 1 + len as i32 - }; - self.add_bytes(total_len) - } - - #[inline] - fn add_str_bytes(&mut self, len: usize) -> Result<()> { - self.add_bytes(4 + len as i32 + 1) - } -} - -impl<'a> serde::Serializer for &'a mut Serializer { - type Ok = (); - type Error = Error; - - type SerializeSeq = DocumentSerializer<'a>; - type SerializeTuple = DocumentSerializer<'a>; - type SerializeTupleStruct = DocumentSerializer<'a>; - type SerializeTupleVariant = VariantSerializer<'a>; - type SerializeMap = DocumentSerializer<'a>; - type SerializeStruct = StructSerializer<'a>; - type SerializeStructVariant = VariantSerializer<'a>; - - fn is_human_readable(&self) -> bool { - self.human_readable - } - - #[inline] - fn serialize_bool(self, _v: bool) -> Result { - self.add_bytes(1) - } - - #[inline] - fn serialize_i8(self, v: i8) -> Result { - self.serialize_i32(v.into()) - } - - #[inline] - fn serialize_i16(self, v: i16) -> Result { - self.serialize_i32(v.into()) - } - - #[inline] - fn serialize_i32(self, _v: i32) -> Result { - self.add_bytes(4) - } - - #[inline] - fn serialize_i64(self, _v: i64) -> Result { - self.add_bytes(8) - } - - #[inline] - fn serialize_u8(self, v: u8) -> Result { - self.serialize_i32(v.into()) - } - - #[inline] - fn serialize_u16(self, v: u16) -> Result { - self.serialize_i32(v.into()) - } - - #[inline] - fn serialize_u32(self, v: u32) -> Result { - self.serialize_i64(v.into()) - } - - #[inline] - fn serialize_u64(self, _v: u64) -> Result { - self.add_bytes(8) - } - - #[inline] - fn serialize_f32(self, v: f32) -> Result { - self.serialize_f64(v.into()) - } - - #[inline] - fn serialize_f64(self, _v: f64) -> Result { - self.add_bytes(8) - } - - #[inline] - fn serialize_char(self, v: char) -> Result { - let mut s = String::new(); - s.push(v); - self.serialize_str(&s) - } - - #[inline] - fn serialize_str(self, v: &str) -> Result { - self.add_str_bytes(v.len()) - } - - #[inline] - fn serialize_bytes(self, v: &[u8]) -> Result { - match self.hint.take() { - SerializerHint::RawDocument | SerializerHint::RawArray => { - if self.lens_stack.is_empty() { - // The root document is raw in this case. - self.enter_doc(); - let result = self.add_bytes(v.len() as i32); - self.exit_doc(); - result - } else { - // We don't record these as docs as the lengths aren't computed from multiple inputs. - self.add_bytes(v.len() as i32) - } - } - // NB: in this path we would never emit BinaryOld. - _ => self.add_bin_bytes(v.len(), BinarySubtype::Generic), - } - } - - #[inline] - fn serialize_none(self) -> Result { - // this writes an ElementType::Null, which records 0 following bytes for the value. - Ok(()) - } - - #[inline] - fn serialize_some(self, value: &T) -> Result - where - T: serde::Serialize + ?Sized, - { - value.serialize(self) - } - - #[inline] - fn serialize_unit(self) -> Result { - self.serialize_none() - } - - #[inline] - fn serialize_unit_struct(self, _name: &'static str) -> Result { - self.serialize_unit() - } - - #[inline] - fn serialize_unit_variant( - self, - _name: &'static str, - _variant_index: u32, - variant: &'static str, - ) -> Result { - self.serialize_str(variant) - } - - #[inline] - fn serialize_newtype_struct(self, name: &'static str, value: &T) -> Result - where - T: serde::Serialize + ?Sized, - { - match name { - UUID_NEWTYPE_NAME => self.hint = SerializerHint::Uuid, - RAW_DOCUMENT_NEWTYPE => self.hint = SerializerHint::RawDocument, - RAW_ARRAY_NEWTYPE => self.hint = SerializerHint::RawArray, - HUMAN_READABLE_NEWTYPE => { - let old = self.human_readable; - self.human_readable = true; - let result = value.serialize(&mut *self); - self.human_readable = old; - return result; - } - _ => {} - } - value.serialize(self) - } - - #[inline] - fn serialize_newtype_variant( - self, - _name: &'static str, - _variant_index: u32, - variant: &'static str, - value: &T, - ) -> Result - where - T: serde::Serialize + ?Sized, - { - let mut d = DocumentSerializer::start(&mut *self)?; - d.serialize_entry(variant, value)?; - d.end_doc()?; - Ok(()) - } - - #[inline] - fn serialize_seq(self, _len: Option) -> Result { - DocumentSerializer::start(&mut *self) - } - - #[inline] - fn serialize_tuple(self, len: usize) -> Result { - self.serialize_seq(Some(len)) - } - - #[inline] - fn serialize_tuple_struct( - self, - _name: &'static str, - len: usize, - ) -> Result { - self.serialize_seq(Some(len)) - } - - #[inline] - fn serialize_tuple_variant( - self, - _name: &'static str, - _variant_index: u32, - variant: &'static str, - _len: usize, - ) -> Result { - VariantSerializer::start(&mut *self, variant) - } - - #[inline] - fn serialize_map(self, _len: Option) -> Result { - DocumentSerializer::start(&mut *self) - } - - #[inline] - fn serialize_struct(self, name: &'static str, _len: usize) -> Result { - let value_type = match name { - "$oid" => Some(ValueType::ObjectId), - "$date" => Some(ValueType::DateTime), - "$binary" => Some(ValueType::Binary), - "$timestamp" => Some(ValueType::Timestamp), - "$minKey" => Some(ValueType::MinKey), - "$maxKey" => Some(ValueType::MaxKey), - "$code" => Some(ValueType::JavaScriptCode), - "$codeWithScope" => Some(ValueType::JavaScriptCodeWithScope), - "$symbol" => Some(ValueType::Symbol), - "$undefined" => Some(ValueType::Undefined), - "$regularExpression" => Some(ValueType::RegularExpression), - "$dbPointer" => Some(ValueType::DbPointer), - "$numberDecimal" => Some(ValueType::Decimal128), - _ => None, - }; - - match value_type { - Some(vt) => Ok(StructSerializer::Value(ValueSerializer::new(self, vt))), - None => Ok(StructSerializer::Document(DocumentSerializer::start(self)?)), - } - } - - #[inline] - fn serialize_struct_variant( - self, - _name: &'static str, - _variant_index: u32, - variant: &'static str, - _len: usize, - ) -> Result { - VariantSerializer::start(&mut *self, variant) - } -} - -pub(crate) enum StructSerializer<'a> { - /// Serialize a BSON value currently represented in serde as a struct (e.g. ObjectId) - Value(ValueSerializer<'a>), - - /// Serialize the struct as a document. - Document(DocumentSerializer<'a>), -} - -impl SerializeStruct for StructSerializer<'_> { - type Ok = (); - type Error = Error; - - #[inline] - fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<()> - where - T: Serialize + ?Sized, - { - match self { - StructSerializer::Value(ref mut v) => (&mut *v).serialize_field(key, value), - StructSerializer::Document(d) => d.serialize_field(key, value), - } - } - - #[inline] - fn end(self) -> Result { - match self { - StructSerializer::Document(d) => SerializeStruct::end(d), - StructSerializer::Value(mut v) => v.end(), - } - } -} - -/// Serializer used for enum variants, including both tuple (e.g. Foo::Bar(1, 2, 3)) and -/// struct (e.g. Foo::Bar { a: 1 }). -pub(crate) struct VariantSerializer<'a> { - root_serializer: &'a mut Serializer, - - /// How many elements have been serialized in the inner document / array so far. - num_elements_serialized: usize, -} - -impl<'a> VariantSerializer<'a> { - fn start(rs: &'a mut Serializer, variant: &'static str) -> Result { - rs.enter_doc(); // outer doc for variant - rs.add_element_name_and_type(variant.len())?; - - rs.enter_doc(); // inner doc/array containing variant doc/tuple. - Ok(Self { - root_serializer: rs, - num_elements_serialized: 0, - }) - } - - #[inline] - fn serialize_element(&mut self, k: &str, v: &T) -> Result<()> - where - T: Serialize + ?Sized, - { - self.root_serializer.add_element_name_and_type(k.len())?; - v.serialize(&mut *self.root_serializer)?; - self.num_elements_serialized += 1; - Ok(()) - } - - #[inline] - fn end_both(self) -> Result<()> { - self.root_serializer.exit_doc(); // close variant doc/array - self.root_serializer.exit_doc(); // close variant wrapper. - Ok(()) - } -} - -impl serde::ser::SerializeTupleVariant for VariantSerializer<'_> { - type Ok = (); - - type Error = Error; - - #[inline] - fn serialize_field(&mut self, value: &T) -> Result<()> - where - T: Serialize + ?Sized, - { - self.serialize_element(format!("{}", self.num_elements_serialized).as_str(), value) - } - - #[inline] - fn end(self) -> Result { - self.end_both() - } -} - -impl serde::ser::SerializeStructVariant for VariantSerializer<'_> { - type Ok = (); - - type Error = Error; - - #[inline] - fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<()> - where - T: Serialize + ?Sized, - { - self.serialize_element(key, value) - } - - #[inline] - fn end(self) -> Result { - self.end_both() - } -} - -use serde::ser::Impossible; - -use crate::{to_bson, Bson}; - -/// Serializer used to serialize document or array bodies. -pub(crate) struct DocumentSerializer<'a> { - root_serializer: &'a mut Serializer, - num_keys_serialized: usize, -} - -impl<'a> DocumentSerializer<'a> { - pub(crate) fn start(rs: &'a mut Serializer) -> crate::ser::Result { - rs.enter_doc(); - Ok(Self { - root_serializer: rs, - num_keys_serialized: 0, - }) - } - - /// Serialize a document key using the provided closure. - fn serialize_doc_key_custom Result<()>>( - &mut self, - f: F, - ) -> Result<()> { - f(self.root_serializer)?; - self.num_keys_serialized += 1; - Ok(()) - } - - /// Serialize a document key to string using [`KeySerializer`]. - fn serialize_doc_key(&mut self, key: &T) -> Result<()> - where - T: serde::Serialize + ?Sized, - { - self.serialize_doc_key_custom(|rs| { - key.serialize(KeySerializer { - root_serializer: rs, - })?; - Ok(()) - })?; - Ok(()) - } - - pub(crate) fn end_doc(self) -> crate::ser::Result<&'a mut Serializer> { - self.root_serializer.exit_doc(); - Ok(self.root_serializer) - } -} - -impl serde::ser::SerializeSeq for DocumentSerializer<'_> { - type Ok = (); - type Error = Error; - - #[inline] - fn serialize_element(&mut self, value: &T) -> Result<()> - where - T: serde::Serialize + ?Sized, - { - let index = self.num_keys_serialized; - self.serialize_doc_key_custom(|rs| rs.add_element_name_and_type(index.to_string().len()))?; - value.serialize(&mut *self.root_serializer) - } - - #[inline] - fn end(self) -> Result { - self.end_doc().map(|_| ()) - } -} - -impl serde::ser::SerializeMap for DocumentSerializer<'_> { - type Ok = (); - - type Error = Error; - - #[inline] - fn serialize_key(&mut self, key: &T) -> Result<()> - where - T: serde::Serialize + ?Sized, - { - self.serialize_doc_key(key) - } - - #[inline] - fn serialize_value(&mut self, value: &T) -> Result<()> - where - T: serde::Serialize + ?Sized, - { - value.serialize(&mut *self.root_serializer) - } - - fn end(self) -> Result { - self.end_doc().map(|_| ()) - } -} - -impl serde::ser::SerializeStruct for DocumentSerializer<'_> { - type Ok = (); - - type Error = Error; - - #[inline] - fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<()> - where - T: serde::Serialize + ?Sized, - { - self.serialize_doc_key(key)?; - value.serialize(&mut *self.root_serializer) - } - - #[inline] - fn end(self) -> Result { - self.end_doc().map(|_| ()) - } -} - -impl serde::ser::SerializeTuple for DocumentSerializer<'_> { - type Ok = (); - - type Error = Error; - - #[inline] - fn serialize_element(&mut self, value: &T) -> Result<()> - where - T: serde::Serialize + ?Sized, - { - self.serialize_doc_key(&self.num_keys_serialized.to_string())?; - value.serialize(&mut *self.root_serializer) - } - - #[inline] - fn end(self) -> Result { - self.end_doc().map(|_| ()) - } -} - -impl serde::ser::SerializeTupleStruct for DocumentSerializer<'_> { - type Ok = (); - - type Error = Error; - - #[inline] - fn serialize_field(&mut self, value: &T) -> Result<()> - where - T: serde::Serialize + ?Sized, - { - self.serialize_doc_key(&self.num_keys_serialized.to_string())?; - value.serialize(&mut *self.root_serializer) - } - - #[inline] - fn end(self) -> Result { - self.end_doc().map(|_| ()) - } -} - -/// Serializer used specifically for serializing document keys. -/// Only keys that serialize to strings will be accepted. -struct KeySerializer<'a> { - root_serializer: &'a mut Serializer, -} - -impl KeySerializer<'_> { - fn invalid_key(v: T) -> Error { - Error::InvalidDocumentKey(to_bson(&v).unwrap_or(Bson::Null)) - } -} - -impl serde::Serializer for KeySerializer<'_> { - type Ok = (); - - type Error = Error; - - type SerializeSeq = Impossible<(), Error>; - type SerializeTuple = Impossible<(), Error>; - type SerializeTupleStruct = Impossible<(), Error>; - type SerializeTupleVariant = Impossible<(), Error>; - type SerializeMap = Impossible<(), Error>; - type SerializeStruct = Impossible<(), Error>; - type SerializeStructVariant = Impossible<(), Error>; - - #[inline] - fn serialize_bool(self, v: bool) -> Result { - Err(Self::invalid_key(v)) - } - - #[inline] - fn serialize_i8(self, v: i8) -> Result { - Err(Self::invalid_key(v)) - } - - #[inline] - fn serialize_i16(self, v: i16) -> Result { - Err(Self::invalid_key(v)) - } - - #[inline] - fn serialize_i32(self, v: i32) -> Result { - Err(Self::invalid_key(v)) - } - - #[inline] - fn serialize_i64(self, v: i64) -> Result { - Err(Self::invalid_key(v)) - } - - #[inline] - fn serialize_u8(self, v: u8) -> Result { - Err(Self::invalid_key(v)) - } - - #[inline] - fn serialize_u16(self, v: u16) -> Result { - Err(Self::invalid_key(v)) - } - - #[inline] - fn serialize_u32(self, v: u32) -> Result { - Err(Self::invalid_key(v)) - } - - #[inline] - fn serialize_u64(self, v: u64) -> Result { - Err(Self::invalid_key(v)) - } - - #[inline] - fn serialize_f32(self, v: f32) -> Result { - Err(Self::invalid_key(v)) - } - - #[inline] - fn serialize_f64(self, v: f64) -> Result { - Err(Self::invalid_key(v)) - } - - #[inline] - fn serialize_char(self, v: char) -> Result { - Err(Self::invalid_key(v)) - } - - #[inline] - fn serialize_str(self, v: &str) -> Result { - self.root_serializer.add_element_name_and_type(v.len()) - } - - #[inline] - fn serialize_bytes(self, v: &[u8]) -> Result { - Err(Self::invalid_key(v)) - } - - #[inline] - fn serialize_none(self) -> Result { - Err(Self::invalid_key(Bson::Null)) - } - - #[inline] - fn serialize_some(self, value: &T) -> Result - where - T: Serialize + ?Sized, - { - value.serialize(self) - } - - #[inline] - fn serialize_unit(self) -> Result { - Err(Self::invalid_key(Bson::Null)) - } - - #[inline] - fn serialize_unit_struct(self, _name: &'static str) -> Result { - Err(Self::invalid_key(Bson::Null)) - } - - #[inline] - fn serialize_unit_variant( - self, - _name: &'static str, - _variant_index: u32, - variant: &'static str, - ) -> Result { - self.serialize_str(variant) - } - - #[inline] - fn serialize_newtype_struct(self, _name: &'static str, value: &T) -> Result - where - T: Serialize + ?Sized, - { - value.serialize(self) - } - - #[inline] - fn serialize_newtype_variant( - self, - _name: &'static str, - _variant_index: u32, - _variant: &'static str, - value: &T, - ) -> Result - where - T: Serialize + ?Sized, - { - Err(Self::invalid_key(value)) - } - - #[inline] - fn serialize_seq(self, _len: Option) -> Result { - Err(Self::invalid_key(Bson::Array(vec![]))) - } - - #[inline] - fn serialize_tuple(self, _len: usize) -> Result { - Err(Self::invalid_key(Bson::Array(vec![]))) - } - - #[inline] - fn serialize_tuple_struct( - self, - _name: &'static str, - _len: usize, - ) -> Result { - Err(Self::invalid_key(Bson::Document(doc! {}))) - } - - #[inline] - fn serialize_tuple_variant( - self, - _name: &'static str, - _variant_index: u32, - _variant: &'static str, - _len: usize, - ) -> Result { - Err(Self::invalid_key(Bson::Array(vec![]))) - } - - #[inline] - fn serialize_map(self, _len: Option) -> Result { - Err(Self::invalid_key(Bson::Document(doc! {}))) - } - - #[inline] - fn serialize_struct(self, _name: &'static str, _len: usize) -> Result { - Err(Self::invalid_key(Bson::Document(doc! {}))) - } - - #[inline] - fn serialize_struct_variant( - self, - _name: &'static str, - _variant_index: u32, - _variant: &'static str, - _len: usize, - ) -> Result { - Err(Self::invalid_key(Bson::Document(doc! {}))) - } -} - -use crate::{base64, RawDocument, RawJavaScriptCodeWithScopeRef}; - -/// A serializer used specifically for serializing the serde-data-model form of a BSON type (e.g. -/// [`Binary`]) to raw bytes. -pub(crate) struct ValueSerializer<'a> { - root_serializer: &'a mut Serializer, - state: SerializationStep, -} - -/// State machine used to track which step in the serialization of a given type the serializer is -/// currently on. -#[derive(Debug)] -enum SerializationStep { - Oid, - - DateTime, - DateTimeNumberLong, - - Binary, - /// This step can either transition to the raw or base64 steps depending - /// on whether a string or bytes are serialized. - BinaryBytes, - BinarySubType { - base64: String, - }, - RawBinarySubType { - bytes: Vec, - }, - - Symbol, - - RegEx, - RegExPattern, - RegExOptions, - - Timestamp, - TimestampTime, - TimestampIncrement, - - DbPointer, - DbPointerRef, - DbPointerId, - - Code, - - CodeWithScopeCode, - CodeWithScopeScope { - code: String, - raw: bool, - }, - - MinKey, - - MaxKey, - - Undefined, - - Decimal128, - Decimal128Value, - - Done, -} - -/// Enum of BSON "value" types that this serializer can serialize. -#[derive(Debug, Clone, Copy)] -pub(super) enum ValueType { - DateTime, - Binary, - ObjectId, - Symbol, - RegularExpression, - Timestamp, - DbPointer, - JavaScriptCode, - JavaScriptCodeWithScope, - MinKey, - MaxKey, - Decimal128, - Undefined, -} - -impl From for ElementType { - fn from(vt: ValueType) -> Self { - match vt { - ValueType::Binary => ElementType::Binary, - ValueType::DateTime => ElementType::DateTime, - ValueType::DbPointer => ElementType::DbPointer, - ValueType::Decimal128 => ElementType::Decimal128, - ValueType::Symbol => ElementType::Symbol, - ValueType::RegularExpression => ElementType::RegularExpression, - ValueType::Timestamp => ElementType::Timestamp, - ValueType::JavaScriptCode => ElementType::JavaScriptCode, - ValueType::JavaScriptCodeWithScope => ElementType::JavaScriptCodeWithScope, - ValueType::MaxKey => ElementType::MaxKey, - ValueType::MinKey => ElementType::MinKey, - ValueType::Undefined => ElementType::Undefined, - ValueType::ObjectId => ElementType::ObjectId, - } - } -} - -impl<'a> ValueSerializer<'a> { - pub(super) fn new(rs: &'a mut Serializer, value_type: ValueType) -> Self { - let state = match value_type { - ValueType::DateTime => SerializationStep::DateTime, - ValueType::Binary => SerializationStep::Binary, - ValueType::ObjectId => SerializationStep::Oid, - ValueType::Symbol => SerializationStep::Symbol, - ValueType::RegularExpression => SerializationStep::RegEx, - ValueType::Timestamp => SerializationStep::Timestamp, - ValueType::DbPointer => SerializationStep::DbPointer, - ValueType::JavaScriptCode => SerializationStep::Code, - ValueType::JavaScriptCodeWithScope => SerializationStep::CodeWithScopeCode, - ValueType::MinKey => SerializationStep::MinKey, - ValueType::MaxKey => SerializationStep::MaxKey, - ValueType::Decimal128 => SerializationStep::Decimal128, - ValueType::Undefined => SerializationStep::Undefined, - }; - Self { - root_serializer: rs, - state, - } - } - - fn invalid_step(&self, primitive_type: &'static str) -> Error { - Error::custom(format!( - "cannot serialize {} at step {:?}", - primitive_type, self.state - )) - } -} - -impl<'b> serde::Serializer for &'b mut ValueSerializer<'_> { - type Ok = (); - type Error = Error; - - type SerializeSeq = Impossible<(), Error>; - type SerializeTuple = Impossible<(), Error>; - type SerializeTupleStruct = Impossible<(), Error>; - type SerializeTupleVariant = Impossible<(), Error>; - type SerializeMap = CodeWithScopeSerializer<'b>; - type SerializeStruct = Self; - type SerializeStructVariant = Impossible<(), Error>; - - #[inline] - fn serialize_bool(self, _v: bool) -> Result { - Err(self.invalid_step("bool")) - } - - #[inline] - fn serialize_i8(self, _v: i8) -> Result { - Err(self.invalid_step("i8")) - } - - #[inline] - fn serialize_i16(self, _v: i16) -> Result { - Err(self.invalid_step("i16")) - } - - #[inline] - fn serialize_i32(self, _v: i32) -> Result { - Err(self.invalid_step("i32")) - } - - #[inline] - fn serialize_i64(self, _v: i64) -> Result { - match self.state { - SerializationStep::TimestampTime => { - self.state = SerializationStep::TimestampIncrement; - Ok(()) - } - SerializationStep::TimestampIncrement => self.root_serializer.add_bytes(8), - _ => Err(self.invalid_step("i64")), - } - } - - #[inline] - fn serialize_u8(self, v: u8) -> Result { - match self.state { - SerializationStep::RawBinarySubType { ref bytes } => { - self.root_serializer.add_bin_bytes(bytes.len(), v.into())?; - self.state = SerializationStep::Done; - Ok(()) - } - _ => Err(self.invalid_step("u8")), - } - } - - #[inline] - fn serialize_u16(self, _v: u16) -> Result { - Err(self.invalid_step("u16")) - } - - #[inline] - fn serialize_u32(self, _v: u32) -> Result { - Err(self.invalid_step("u32")) - } - - #[inline] - fn serialize_u64(self, _v: u64) -> Result { - Err(self.invalid_step("u64")) - } - - #[inline] - fn serialize_f32(self, _v: f32) -> Result { - Err(self.invalid_step("f32")) - } - - #[inline] - fn serialize_f64(self, _v: f64) -> Result { - Err(self.invalid_step("f64")) - } - - #[inline] - fn serialize_char(self, _v: char) -> Result { - Err(self.invalid_step("char")) - } - - fn serialize_str(self, v: &str) -> Result { - match &self.state { - SerializationStep::DateTimeNumberLong => { - self.root_serializer.add_bytes(8)?; - } - SerializationStep::Oid => { - self.root_serializer.add_bytes(12)?; - } - SerializationStep::BinaryBytes => { - self.state = SerializationStep::BinarySubType { - base64: v.to_string(), - }; - } - SerializationStep::BinarySubType { base64 } => { - let subtype_byte = hex::decode(v).map_err(Error::custom)?; - let subtype: BinarySubtype = subtype_byte[0].into(); - let bytes = base64::decode(base64.as_str()).map_err(Error::custom)?; - self.root_serializer.add_bin_bytes(bytes.len(), subtype)?; - } - SerializationStep::Symbol | SerializationStep::DbPointerRef => { - self.root_serializer.add_str_bytes(v.len())?; - } - SerializationStep::RegExPattern => { - self.root_serializer.add_cstr_bytes(v.len())?; - } - SerializationStep::RegExOptions => { - self.root_serializer.add_cstr_bytes(v.len())?; - } - SerializationStep::Code => { - self.root_serializer.add_str_bytes(v.len())?; - } - SerializationStep::CodeWithScopeCode => { - self.state = SerializationStep::CodeWithScopeScope { - code: v.to_string(), - raw: false, - }; - } - s => { - return Err(Error::custom(format!( - "can't serialize string for step {:?}", - s - ))) - } - } - Ok(()) - } - - #[inline] - fn serialize_bytes(self, v: &[u8]) -> Result { - match self.state { - SerializationStep::Decimal128Value => self.root_serializer.add_bytes(16), - SerializationStep::BinaryBytes => { - self.state = SerializationStep::RawBinarySubType { bytes: v.to_vec() }; - Ok(()) - } - SerializationStep::CodeWithScopeScope { ref code, raw } if raw => { - let raw = RawJavaScriptCodeWithScopeRef { - code, - scope: RawDocument::from_bytes(v).map_err(Error::custom)?, - }; - self.root_serializer.add_bytes(4)?; - self.root_serializer.add_str_bytes(code.len())?; - self.root_serializer.add_bytes(raw.len())?; - self.state = SerializationStep::Done; - Ok(()) - } - _ => Err(self.invalid_step("&[u8]")), - } - } - - #[inline] - fn serialize_none(self) -> Result { - Err(self.invalid_step("none")) - } - - #[inline] - fn serialize_some(self, _value: &T) -> Result - where - T: Serialize + ?Sized, - { - Err(self.invalid_step("some")) - } - - #[inline] - fn serialize_unit(self) -> Result { - Err(self.invalid_step("unit")) - } - - #[inline] - fn serialize_unit_struct(self, _name: &'static str) -> Result { - Err(self.invalid_step("unit_struct")) - } - - #[inline] - fn serialize_unit_variant( - self, - _name: &'static str, - _variant_index: u32, - _variant: &'static str, - ) -> Result { - Err(self.invalid_step("unit_variant")) - } - - #[inline] - fn serialize_newtype_struct(self, name: &'static str, value: &T) -> Result - where - T: Serialize + ?Sized, - { - match (&mut self.state, name) { - ( - SerializationStep::CodeWithScopeScope { - code: _, - ref mut raw, - }, - RAW_DOCUMENT_NEWTYPE, - ) => { - *raw = true; - value.serialize(self) - } - _ => Err(self.invalid_step("newtype_struct")), - } - } - - #[inline] - fn serialize_newtype_variant( - self, - _name: &'static str, - _variant_index: u32, - _variant: &'static str, - _value: &T, - ) -> Result - where - T: Serialize + ?Sized, - { - Err(self.invalid_step("newtype_variant")) - } - - #[inline] - fn serialize_seq(self, _len: Option) -> Result { - Err(self.invalid_step("seq")) - } - - #[inline] - fn serialize_tuple(self, _len: usize) -> Result { - Err(self.invalid_step("newtype_tuple")) - } - - #[inline] - fn serialize_tuple_struct( - self, - _name: &'static str, - _len: usize, - ) -> Result { - Err(self.invalid_step("tuple_struct")) - } - - #[inline] - fn serialize_tuple_variant( - self, - _name: &'static str, - _variant_index: u32, - _variant: &'static str, - _len: usize, - ) -> Result { - Err(self.invalid_step("tuple_variant")) - } - - #[inline] - fn serialize_map(self, _len: Option) -> Result { - match self.state { - SerializationStep::CodeWithScopeScope { ref code, raw } if !raw => { - CodeWithScopeSerializer::start(code.as_str(), self.root_serializer) - } - _ => Err(self.invalid_step("map")), - } - } - - #[inline] - fn serialize_struct(self, _name: &'static str, _len: usize) -> Result { - Ok(self) - } - - #[inline] - fn serialize_struct_variant( - self, - _name: &'static str, - _variant_index: u32, - _variant: &'static str, - _len: usize, - ) -> Result { - Err(self.invalid_step("struct_variant")) - } - - fn is_human_readable(&self) -> bool { - false - } -} - -impl SerializeStruct for &mut ValueSerializer<'_> { - type Ok = (); - type Error = Error; - - fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<()> - where - T: Serialize + ?Sized, - { - match (&self.state, key) { - (SerializationStep::DateTime, "$date") => { - self.state = SerializationStep::DateTimeNumberLong; - value.serialize(&mut **self)?; - } - (SerializationStep::DateTimeNumberLong, "$numberLong") => { - value.serialize(&mut **self)?; - self.state = SerializationStep::Done; - } - (SerializationStep::Oid, "$oid") => { - value.serialize(&mut **self)?; - self.state = SerializationStep::Done; - } - (SerializationStep::Binary, "$binary") => { - self.state = SerializationStep::BinaryBytes; - value.serialize(&mut **self)?; - } - (SerializationStep::BinaryBytes, key) if key == "bytes" || key == "base64" => { - // state is updated in serialize - value.serialize(&mut **self)?; - } - (SerializationStep::RawBinarySubType { .. }, "subType") => { - value.serialize(&mut **self)?; - self.state = SerializationStep::Done; - } - (SerializationStep::BinarySubType { .. }, "subType") => { - value.serialize(&mut **self)?; - self.state = SerializationStep::Done; - } - (SerializationStep::Symbol, "$symbol") => { - value.serialize(&mut **self)?; - self.state = SerializationStep::Done; - } - (SerializationStep::RegEx, "$regularExpression") => { - self.state = SerializationStep::RegExPattern; - value.serialize(&mut **self)?; - } - (SerializationStep::RegExPattern, "pattern") => { - value.serialize(&mut **self)?; - self.state = SerializationStep::RegExOptions; - } - (SerializationStep::RegExOptions, "options") => { - value.serialize(&mut **self)?; - self.state = SerializationStep::Done; - } - (SerializationStep::Timestamp, "$timestamp") => { - self.state = SerializationStep::TimestampTime; - value.serialize(&mut **self)?; - } - (SerializationStep::TimestampTime, "t") => { - // state is updated in serialize - value.serialize(&mut **self)?; - } - (SerializationStep::TimestampIncrement { .. }, "i") => { - value.serialize(&mut **self)?; - self.state = SerializationStep::Done; - } - (SerializationStep::DbPointer, "$dbPointer") => { - self.state = SerializationStep::DbPointerRef; - value.serialize(&mut **self)?; - } - (SerializationStep::DbPointerRef, "$ref") => { - value.serialize(&mut **self)?; - self.state = SerializationStep::DbPointerId; - } - (SerializationStep::DbPointerId, "$id") => { - self.state = SerializationStep::Oid; - value.serialize(&mut **self)?; - } - (SerializationStep::Code, "$code") => { - value.serialize(&mut **self)?; - self.state = SerializationStep::Done; - } - (SerializationStep::CodeWithScopeCode, "$code") => { - // state is updated in serialize - value.serialize(&mut **self)?; - } - (SerializationStep::CodeWithScopeScope { .. }, "$scope") => { - value.serialize(&mut **self)?; - self.state = SerializationStep::Done; - } - (SerializationStep::MinKey, "$minKey") => { - self.state = SerializationStep::Done; - } - (SerializationStep::MaxKey, "$maxKey") => { - self.state = SerializationStep::Done; - } - (SerializationStep::Undefined, "$undefined") => { - self.state = SerializationStep::Done; - } - (SerializationStep::Decimal128, "$numberDecimal") - | (SerializationStep::Decimal128, "$numberDecimalBytes") => { - self.state = SerializationStep::Decimal128Value; - value.serialize(&mut **self)?; - } - (SerializationStep::Decimal128Value, "$numberDecimal") => { - value.serialize(&mut **self)?; - self.state = SerializationStep::Done; - } - (SerializationStep::Done, k) => { - return Err(Error::custom(format!( - "expected to end serialization of type, got extra key \"{}\"", - k - ))); - } - (state, k) => { - return Err(Error::custom(format!( - "mismatched serialization step and next key: {:?} + \"{}\"", - state, k - ))); - } - } - - Ok(()) - } - - #[inline] - fn end(self) -> Result { - Ok(()) - } -} - -pub(crate) struct CodeWithScopeSerializer<'a> { - doc: DocumentSerializer<'a>, -} - -impl<'a> CodeWithScopeSerializer<'a> { - #[inline] - fn start(code: &str, rs: &'a mut Serializer) -> Result { - rs.enter_doc(); - rs.add_str_bytes(code.len())?; - - let doc = DocumentSerializer::start(rs)?; - Ok(Self { doc }) - } -} - -impl SerializeMap for CodeWithScopeSerializer<'_> { - type Ok = (); - type Error = Error; - - #[inline] - fn serialize_key(&mut self, key: &T) -> Result<()> - where - T: Serialize + ?Sized, - { - self.doc.serialize_key(key) - } - - #[inline] - fn serialize_value(&mut self, value: &T) -> Result<()> - where - T: Serialize + ?Sized, - { - self.doc.serialize_value(value) - } - - #[inline] - fn end(self) -> Result { - let rs = self.doc.end_doc()?; - // code with scope does not have an additional null terminator. - rs.add_bytes(-1)?; - rs.exit_doc(); - Ok(()) - } -} diff --git a/src/ser/raw/mod.rs b/src/ser/raw/mod.rs index f0b9491a..1710c190 100644 --- a/src/ser/raw/mod.rs +++ b/src/ser/raw/mod.rs @@ -1,5 +1,4 @@ mod document_serializer; -pub(super) mod len_serializer; mod value_serializer; use bytes::BufMut; @@ -23,21 +22,28 @@ use document_serializer::DocumentSerializer; // XXX begin_doc and end_doc appear to be infalliable. pub(crate) trait DocumentBufMut: BufMut { /// Track/record information related to the document started at this point. - fn begin_doc(&mut self) -> Result<()>; + fn begin_doc(&mut self, doc_type: ElementType) -> Result<()>; /// Track/record any information related to the end of the current document. fn end_doc(&mut self) -> Result<()>; /// Return true if begin_doc() has been called at least once. fn in_document(&self) -> bool; } +#[derive(Debug)] +struct StackItem { + len_index: usize, + begin_offset: usize, + doc_type: ElementType, +} + pub(crate) struct LenRecordingDocumentBufMut { stream_len: usize, lens: Vec, - stack: Vec<(usize, usize)>, + stack: Vec, } impl LenRecordingDocumentBufMut { - fn new() -> Self { + pub(crate) fn new() -> Self { Self { stream_len: 0, lens: vec![], @@ -45,28 +51,34 @@ impl LenRecordingDocumentBufMut { } } - fn into_lens(self) -> Vec { - assert!(self.stack.is_empty()); + pub(crate) fn into_lens(self) -> Vec { + assert!(self.stack.is_empty(), "{:?}", self.stack); self.lens } } impl DocumentBufMut for LenRecordingDocumentBufMut { - fn begin_doc(&mut self) -> Result<()> { + fn begin_doc(&mut self, doc_type: ElementType) -> Result<()> { if self.stack.is_empty() && self.stream_len > 0 { panic!("must begin stream with a document.") } let index = self.lens.len(); self.lens.push(0); - self.stack.push((index, self.stream_len)); + self.stack.push(StackItem { + len_index: index, + begin_offset: self.stream_len, + doc_type, + }); self.stream_len += 4; // length value that will be written to the stream. Ok(()) } fn end_doc(&mut self) -> Result<()> { - self.stream_len += 1; // null terminator - let (index, doc_begin) = self.stack.pop().unwrap(); - self.lens[index] = self.stream_len as i32 - doc_begin as i32; + let item = self.stack.pop().expect("paired with begin_doc()"); + if item.doc_type != ElementType::JavaScriptCodeWithScope { + self.stream_len += 1; // null terminator + } + self.lens[item.len_index] = self.stream_len as i32 - item.begin_offset as i32; Ok(()) } @@ -107,7 +119,7 @@ unsafe impl BufMut for LenRecordingDocumentBufMut { pub(crate) struct LenReplayingDocumentBufMut { buf: B, lens: std::vec::IntoIter, - started: bool, + doc_type_stack: Vec, } impl LenReplayingDocumentBufMut { @@ -115,7 +127,7 @@ impl LenReplayingDocumentBufMut { Self { buf, lens: lens.into_iter(), - started: false, + doc_type_stack: vec![], } } @@ -125,20 +137,23 @@ impl LenReplayingDocumentBufMut { } impl DocumentBufMut for LenReplayingDocumentBufMut { - fn begin_doc(&mut self) -> Result<()> { + fn begin_doc(&mut self, doc_type: ElementType) -> Result<()> { let len = self.lens.next().unwrap(); self.buf.put_i32_le(len); - self.started = true; + self.doc_type_stack.push(doc_type); Ok(()) } fn end_doc(&mut self) -> Result<()> { - self.buf.put_u8(0); + let doc_type = self.doc_type_stack.pop().expect("paired with begin_doc()"); + if doc_type != ElementType::JavaScriptCodeWithScope { + self.buf.put_u8(0); + } Ok(()) } fn in_document(&self) -> bool { - self.started + !self.doc_type_stack.is_empty() } } @@ -474,7 +489,7 @@ impl<'a, B: DocumentBufMut> serde::Serializer for &'a mut Serializer { T: serde::Serialize + ?Sized, { self.write_key(ElementType::EmbeddedDocument)?; - let mut d = DocumentSerializer::start(&mut *self)?; + let mut d = DocumentSerializer::start(&mut *self, ElementType::EmbeddedDocument)?; d.serialize_entry(variant, value)?; d.end_doc()?; Ok(()) @@ -483,7 +498,7 @@ impl<'a, B: DocumentBufMut> serde::Serializer for &'a mut Serializer { #[inline] fn serialize_seq(self, _len: Option) -> Result { self.write_key(ElementType::Array)?; - DocumentSerializer::start(&mut *self) + DocumentSerializer::start(&mut *self, ElementType::Array) } #[inline] @@ -515,7 +530,7 @@ impl<'a, B: DocumentBufMut> serde::Serializer for &'a mut Serializer { #[inline] fn serialize_map(self, _len: Option) -> Result { self.write_key(ElementType::EmbeddedDocument)?; - DocumentSerializer::start(&mut *self) + DocumentSerializer::start(&mut *self, ElementType::EmbeddedDocument) } #[inline] @@ -544,7 +559,10 @@ impl<'a, B: DocumentBufMut> serde::Serializer for &'a mut Serializer { )?; match value_type { Some(vt) => Ok(StructSerializer::Value(ValueSerializer::new(self, vt))), - None => Ok(StructSerializer::Document(DocumentSerializer::start(self)?)), + None => Ok(StructSerializer::Document(DocumentSerializer::start( + self, + ElementType::EmbeddedDocument, + )?)), } } @@ -613,7 +631,7 @@ impl<'a, B: DocumentBufMut> VariantSerializer<'a, B> { variant: &'static str, inner_type: VariantInnerType, ) -> Result { - rs.buf.begin_doc()?; + rs.buf.begin_doc(ElementType::EmbeddedDocument)?; let inner = match inner_type { VariantInnerType::Struct => ElementType::EmbeddedDocument, @@ -621,7 +639,7 @@ impl<'a, B: DocumentBufMut> VariantSerializer<'a, B> { }; rs.buf.put_u8(inner as u8); rs.write_cstring(&variant)?; - rs.buf.begin_doc()?; + rs.buf.begin_doc(inner)?; Ok(Self { root_serializer: rs, diff --git a/src/ser/raw/value_serializer.rs b/src/ser/raw/value_serializer.rs index 2633545e..918721a7 100644 --- a/src/ser/raw/value_serializer.rs +++ b/src/ser/raw/value_serializer.rs @@ -589,10 +589,10 @@ pub(crate) struct CodeWithScopeSerializer<'a, B> { impl<'a, B: DocumentBufMut> CodeWithScopeSerializer<'a, B> { #[inline] fn start(code: &str, rs: &'a mut Serializer) -> Result { - rs.buf.begin_doc()?; + rs.buf.begin_doc(ElementType::JavaScriptCodeWithScope)?; rs.write_string(code); - let doc = DocumentSerializer::start(rs)?; + let doc = DocumentSerializer::start(rs, ElementType::EmbeddedDocument)?; Ok(Self { doc }) } } @@ -619,6 +619,7 @@ impl SerializeMap for CodeWithScopeSerializer<'_, B> { #[inline] fn end(self) -> Result { - self.doc.end_doc() + let rs = self.doc.end_doc()?; + rs.buf.end_doc() } }