diff --git a/src/lib.rs b/src/lib.rs index d65772983..6d14c7fb5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -117,7 +117,7 @@ pub extern crate percent_encoding; use encoding::EncodingOverride; #[cfg(feature = "heapsize")] use heapsize::HeapSizeOf; use host::HostInternal; -use parser::{Parser, Context, SchemeType, to_u32}; +use parser::{Parser, Context, SchemeType, to_u32, ViolationFn}; use percent_encoding::{PATH_SEGMENT_ENCODE_SET, USERINFO_ENCODE_SET, percent_encode, percent_decode, utf8_percent_encode}; use std::borrow::Borrow; @@ -135,7 +135,7 @@ use std::str; pub use origin::{Origin, OpaqueOrigin}; pub use host::{Host, HostAndPort, SocketAddrs}; pub use path_segments::PathSegmentsMut; -pub use parser::ParseError; +pub use parser::{ParseError, SyntaxViolation}; pub use slicing::Position; mod encoding; @@ -186,7 +186,7 @@ impl HeapSizeOf for Url { pub struct ParseOptions<'a> { base_url: Option<&'a Url>, encoding_override: encoding::EncodingOverride, - log_syntax_violation: Option<&'a Fn(&'static str)>, + violation_fn: ViolationFn<'a>, } impl<'a> ParseOptions<'a> { @@ -209,9 +209,47 @@ impl<'a> ParseOptions<'a> { self } - /// Call the provided function or closure on non-fatal parse errors. + /// Call the provided function or closure on non-fatal parse errors, passing + /// a static string description. This method is deprecated in favor of + /// `syntax_violation_callback` and is implemented as an adaptor for the + /// latter, passing the `SyntaxViolation` description. Only the last value + /// passed to either method will be used by a parser. + #[deprecated] pub fn log_syntax_violation(mut self, new: Option<&'a Fn(&'static str)>) -> Self { - self.log_syntax_violation = new; + self.violation_fn = match new { + Some(f) => ViolationFn::OldFn(f), + None => ViolationFn::NoOp + }; + self + } + + /// Call the provided function or closure for a non-fatal `SyntaxViolation` + /// when it occurs during parsing. 
Note that since the provided function is + /// `Fn`, the caller might need to utilize _interior mutability_, such as with + /// a `RefCell`, to collect the violations. + /// + /// ## Example + /// ``` + /// use std::cell::RefCell; + /// use url::{Url, SyntaxViolation}; + /// # use url::ParseError; + /// # fn run() -> Result<(), url::ParseError> { + /// let violations = RefCell::new(Vec::new()); + /// let url = Url::options() + /// .syntax_violation_callback(Some(&|v| violations.borrow_mut().push(v))) + /// .parse("https:////example.com")?; + /// assert_eq!(url.as_str(), "https://example.com/"); + /// assert_eq!(violations.into_inner(), + /// vec!(SyntaxViolation::ExpectedDoubleSlash)); + /// # Ok(()) + /// # } + /// # run().unwrap(); + /// ``` + pub fn syntax_violation_callback(mut self, new: Option<&'a Fn(SyntaxViolation)>) -> Self { + self.violation_fn = match new { + Some(f) => ViolationFn::NewFn(f), + None => ViolationFn::NoOp + }; self } @@ -221,7 +259,7 @@ impl<'a> ParseOptions<'a> { serialization: String::with_capacity(input.len()), base_url: self.base_url, query_encoding_override: self.encoding_override, - log_syntax_violation: self.log_syntax_violation, + violation_fn: self.violation_fn, context: Context::UrlParser, }.parse_url(input) } @@ -229,11 +267,12 @@ impl<'a> ParseOptions<'a> { impl<'a> Debug for ParseOptions<'a> { fn fmt(&self, f: &mut Formatter) -> fmt::Result { - write!(f, "ParseOptions {{ base_url: {:?}, encoding_override: {:?}, log_syntax_violation: ", self.base_url, self.encoding_override)?; - match self.log_syntax_violation { - Some(_) => write!(f, "Some(Fn(&'static str)) }}"), - None => write!(f, "None }}") - } + write!(f, + "ParseOptions {{ base_url: {:?}, encoding_override: {:?}, \ + violation_fn: {:?} }}", + self.base_url, + self.encoding_override, + self.violation_fn) } } @@ -363,7 +402,7 @@ impl Url { ParseOptions { base_url: None, encoding_override: EncodingOverride::utf8(), - log_syntax_violation: None, + violation_fn: 
ViolationFn::NoOp, } } diff --git a/src/parser.rs b/src/parser.rs index b16ecb7f6..92b97afdd 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -70,6 +70,54 @@ impl From<::idna::uts46::Errors> for ParseError { fn from(_: ::idna::uts46::Errors) -> ParseError { ParseError::IdnaError } } +macro_rules! syntax_violation_enum { + ($($name: ident => $description: expr,)+) => { + /// Non-fatal syntax violations that can occur during parsing. + #[derive(PartialEq, Eq, Clone, Copy, Debug)] + pub enum SyntaxViolation { + $( + $name, + )+ + } + + impl SyntaxViolation { + pub fn description(&self) -> &'static str { + match *self { + $( + SyntaxViolation::$name => $description, + )+ + } + } + } + } +} + +syntax_violation_enum! { + Backslash => "backslash", + C0SpaceIgnored => + "leading or trailing control or space character are ignored in URLs", + EmbeddedCredentials => + "embedding authentication information (username or password) \ + in an URL is not recommended", + ExpectedDoubleSlash => "expected //", + ExpectedFileDoubleSlash => "expected // after file:", + FileWithHostAndWindowsDrive => "file: with host and Windows drive letter", + NonUrlCodePoint => "non-URL code point", + NullInFragment => "NULL characters are ignored in URL fragment identifiers", + PercentDecode => "expected 2 hex digits after %", + TabOrNewlineIgnored => "tabs or newlines are ignored in URLs", + UnencodedAtSign => "unencoded @ sign in username or password", +} + +#[cfg(feature = "heapsize")] +known_heap_size!(0, SyntaxViolation); + +impl fmt::Display for SyntaxViolation { + fn fmt(&self, fmt: &mut Formatter) -> fmt::Result { + self.description().fmt(fmt) + } +} + #[derive(Copy, Clone)] pub enum SchemeType { File, @@ -112,18 +160,17 @@ pub struct Input<'i> { impl<'i> Input<'i> { pub fn new(input: &'i str) -> Self { - Input::with_log(input, None) + Input::with_log(input, ViolationFn::NoOp) } - pub fn with_log(original_input: &'i str, log_syntax_violation: Option<&Fn(&'static str)>) - -> Self { + pub fn 
with_log(original_input: &'i str, vfn: ViolationFn) -> Self { let input = original_input.trim_matches(c0_control_or_space); - if let Some(log) = log_syntax_violation { + if vfn.is_set() { if input.len() < original_input.len() { - log("leading or trailing control or space character are ignored in URLs") + vfn.call(SyntaxViolation::C0SpaceIgnored) } if input.chars().any(|c| matches!(c, '\t' | '\n' | '\r')) { - log("tabs or newlines are ignored in URLs") + vfn.call(SyntaxViolation::TabOrNewlineIgnored) } } Input { chars: input.chars() } @@ -216,11 +263,60 @@ impl<'i> Iterator for Input<'i> { } } +/// Wrapper for syntax violation callback functions. +#[derive(Copy, Clone)] +pub enum ViolationFn<'a> { + NewFn(&'a (Fn(SyntaxViolation) + 'a)), + OldFn(&'a (Fn(&'static str) + 'a)), + NoOp +} + +impl<'a> ViolationFn<'a> { + /// Call with a violation. + pub fn call(self, v: SyntaxViolation) { + match self { + ViolationFn::NewFn(f) => f(v), + ViolationFn::OldFn(f) => f(v.description()), + ViolationFn::NoOp => {} + } + } + + /// Call with a violation, if provided test returns true. Avoids + /// the test entirely if `NoOp`. 
+ pub fn call_if<F>(self, v: SyntaxViolation, test: F) + where F: Fn() -> bool + { + match self { + ViolationFn::NewFn(f) => if test() { f(v) }, + ViolationFn::OldFn(f) => if test() { f(v.description()) }, + ViolationFn::NoOp => {} // avoid test + } + } + + /// True if not `NoOp` + pub fn is_set(self) -> bool { + match self { + ViolationFn::NoOp => false, + _ => true + } + } +} + +impl<'a> fmt::Debug for ViolationFn<'a> { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + match *self { + ViolationFn::NewFn(_) => write!(f, "NewFn(Fn(SyntaxViolation))"), + ViolationFn::OldFn(_) => write!(f, "OldFn(Fn(&'static str))"), + ViolationFn::NoOp => write!(f, "NoOp") + } + } +} + pub struct Parser<'a> { pub serialization: String, pub base_url: Option<&'a Url>, pub query_encoding_override: EncodingOverride, - pub log_syntax_violation: Option<&'a Fn(&'static str)>, + pub violation_fn: ViolationFn<'a>, pub context: Context, } @@ -237,29 +333,14 @@ impl<'a> Parser<'a> { serialization: serialization, base_url: None, query_encoding_override: EncodingOverride::utf8(), - log_syntax_violation: None, + violation_fn: ViolationFn::NoOp, context: Context::Setter, } } - fn syntax_violation(&self, reason: &'static str) { - if let Some(log) = self.log_syntax_violation { - log(reason) - } - } - - fn syntax_violation_if<F: Fn() -> bool>(&self, reason: &'static str, test: F) { - // Skip test if not logging. 
- if let Some(log) = self.log_syntax_violation { - if test() { - log(reason) - } - } - } - /// https://url.spec.whatwg.org/#concept-basic-url-parser pub fn parse_url(mut self, input: &str) -> ParseResult { - let input = Input::with_log(input, self.log_syntax_violation); + let input = Input::with_log(input, self.violation_fn); if let Ok(remaining) = self.parse_scheme(input.clone()) { return self.parse_with_scheme(remaining) } @@ -310,12 +391,13 @@ impl<'a> Parser<'a> { } fn parse_with_scheme(mut self, input: Input) -> ParseResult { + use SyntaxViolation::{ExpectedFileDoubleSlash, ExpectedDoubleSlash}; let scheme_end = to_u32(self.serialization.len())?; let scheme_type = SchemeType::from(&self.serialization); self.serialization.push(':'); match scheme_type { SchemeType::File => { - self.syntax_violation_if("expected // after file:", || !input.starts_with("//")); + self.violation_fn.call_if(ExpectedFileDoubleSlash, || !input.starts_with("//")); let base_file_url = self.base_url.and_then(|base| { if base.scheme() == "file" { Some(base) } else { None } }); @@ -335,7 +417,7 @@ impl<'a> Parser<'a> { } } // special authority slashes state - self.syntax_violation_if("expected //", || { + self.violation_fn.call_if(ExpectedDoubleSlash, || { input.clone().take_while(|&c| matches!(c, '/' | '\\')) .collect::<String>() != "//" }); @@ -371,6 +453,7 @@ impl<'a> Parser<'a> { } fn parse_file(mut self, input: Input, mut base_file_url: Option<&Url>) -> ParseResult { + use SyntaxViolation::Backslash; // file state debug_assert!(self.serialization.is_empty()); let (first_char, input_after_first_char) = input.split_first(); @@ -468,10 +551,10 @@ impl<'a> Parser<'a> { } } Some('/') | Some('\\') => { - self.syntax_violation_if("backslash", || first_char == Some('\\')); + self.violation_fn.call_if(Backslash, || first_char == Some('\\')); // file slash state let (next_char, input_after_next_char) = input_after_first_char.split_first(); - self.syntax_violation_if("backslash", || next_char == 
Some('\\')); + self.violation_fn.call_if(Backslash, || next_char == Some('\\')); if matches!(next_char, Some('/') | Some('\\')) { // file host state self.serialization.push_str("file://"); @@ -623,7 +706,7 @@ impl<'a> Parser<'a> { Some('/') | Some('\\') => { let (slashes_count, remaining) = input.count_matching(|c| matches!(c, '/' | '\\')); if slashes_count >= 2 { - self.syntax_violation_if("expected //", || { + self.violation_fn.call_if(SyntaxViolation::ExpectedDoubleSlash, || { input.clone().take_while(|&c| matches!(c, '/' | '\\')) .collect::<String>() != "//" }); @@ -687,11 +770,9 @@ impl<'a> Parser<'a> { match c { '@' => { if last_at.is_some() { - self.syntax_violation("unencoded @ sign in username or password") + self.violation_fn.call(SyntaxViolation::UnencodedAtSign) } else { - self.syntax_violation( - "embedding authentication information (username or password) \ - in an URL is not recommended") + self.violation_fn.call(SyntaxViolation::EmbeddedCredentials) } last_at = Some((char_count, remaining.clone())) }, @@ -889,7 +970,7 @@ impl<'a> Parser<'a> { match input.split_first() { (Some('/'), remaining) => input = remaining, (Some('\\'), remaining) => if scheme_type.is_special() { - self.syntax_violation("backslash"); + self.violation_fn.call(SyntaxViolation::Backslash); input = remaining }, _ => {} @@ -917,7 +998,7 @@ impl<'a> Parser<'a> { }, '\\' if self.context != Context::PathSegmentSetter && scheme_type.is_special() => { - self.syntax_violation("backslash"); + self.violation_fn.call(SyntaxViolation::Backslash); ends_with_slash = true; break }, @@ -958,7 +1039,7 @@ impl<'a> Parser<'a> { self.serialization.push(':'); } if *has_host { - self.syntax_violation("file: with host and Windows drive letter"); + self.violation_fn.call(SyntaxViolation::FileWithHostAndWindowsDrive); *has_host = false; // FIXME account for this in callers } } @@ -1100,7 +1181,7 @@ impl<'a> Parser<'a> { pub fn parse_fragment(&mut self, mut input: Input) { while let Some((c, utf8_c)) = 
input.next_utf8() { if c == '\0' { - self.syntax_violation("NULL characters are ignored in URL fragment identifiers") + self.violation_fn.call(SyntaxViolation::NullInFragment) } else { self.check_url_code_point(c, &input); self.serialization.extend(utf8_percent_encode(utf8_c, @@ -1110,15 +1191,16 @@ impl<'a> Parser<'a> { } fn check_url_code_point(&self, c: char, input: &Input) { - if let Some(log) = self.log_syntax_violation { + let vfn = self.violation_fn; + if vfn.is_set() { if c == '%' { let mut input = input.clone(); if !matches!((input.next(), input.next()), (Some(a), Some(b)) if is_ascii_hex_digit(a) && is_ascii_hex_digit(b)) { - log("expected 2 hex digits after %") + vfn.call(SyntaxViolation::PercentDecode) } } else if !is_url_code_point(c) { - log("non-URL code point") + vfn.call(SyntaxViolation::NonUrlCodePoint) } } } diff --git a/tests/unit.rs b/tests/unit.rs index b76a1f80d..10bb86a9d 100644 --- a/tests/unit.rs +++ b/tests/unit.rs @@ -13,6 +13,7 @@ extern crate url; use std::ascii::AsciiExt; use std::borrow::Cow; +use std::cell::{Cell, RefCell}; use std::net::{Ipv4Addr, Ipv6Addr}; use std::path::{Path, PathBuf}; use url::{Host, HostAndPort, Url, form_urlencoded}; @@ -477,3 +478,68 @@ fn test_windows_unc_path() { let url = Url::from_file_path(Path::new(r"\\.\some\path\file.txt")); assert!(url.is_err()); } + +// Test the now deprecated log_syntax_violation method for backward +// compatibility +#[test] +#[allow(deprecated)] +fn test_old_log_violation_option() { + let violation = Cell::new(None); + let url = Url::options() + .log_syntax_violation(Some(&|s| violation.set(Some(s.to_owned())))) + .parse("http:////mozilla.org:42").unwrap(); + assert_eq!(url.port(), Some(42)); + + let violation = violation.take(); + assert_eq!(violation, Some("expected //".to_string())); +} + +#[test] +fn test_syntax_violation_callback() { + use url::SyntaxViolation::*; + let violation = Cell::new(None); + let url = Url::options() + .syntax_violation_callback(Some(&|v| 
violation.set(Some(v)))) + .parse("http:////mozilla.org:42").unwrap(); + assert_eq!(url.port(), Some(42)); + + let v = violation.take().unwrap(); + assert_eq!(v, ExpectedDoubleSlash); + assert_eq!(v.description(), "expected //"); +} + +#[test] +fn test_syntax_violation_callback_lifetimes() { + use url::SyntaxViolation::*; + let violation = Cell::new(None); + let vfn = |s| violation.set(Some(s)); + + let url = Url::options() + .syntax_violation_callback(Some(&vfn)) + .parse("http:////mozilla.org:42").unwrap(); + assert_eq!(url.port(), Some(42)); + assert_eq!(violation.take(), Some(ExpectedDoubleSlash)); + + let url = Url::options() + .syntax_violation_callback(Some(&vfn)) + .parse("http://mozilla.org\\path").unwrap(); + assert_eq!(url.path(), "/path"); + assert_eq!(violation.take(), Some(Backslash)); +} + +#[test] +fn test_options_reuse() { + use url::SyntaxViolation::*; + let violations = RefCell::new(Vec::new()); + let vfn = |v| violations.borrow_mut().push(v); + + let options = Url::options() + .syntax_violation_callback(Some(&vfn)); + let url = options.parse("http:////mozilla.org").unwrap(); + + let options = options.base_url(Some(&url)); + let url = options.parse("/sub\\path").unwrap(); + assert_eq!(url.as_str(), "http://mozilla.org/sub/path"); + assert_eq!(*violations.borrow(), + vec!(ExpectedDoubleSlash, Backslash)); +}