diff --git a/src/descriptor/sortedmulti.rs b/src/descriptor/sortedmulti.rs index 3bd5fa8e0..d87369458 100644 --- a/src/descriptor/sortedmulti.rs +++ b/src/descriptor/sortedmulti.rs @@ -197,7 +197,7 @@ impl SortedMultiVec { impl policy::Liftable for SortedMultiVec { fn lift(&self) -> Result, Error> { - let ret = policy::semantic::Policy::Threshold( + let ret = policy::semantic::Policy::Thresh( self.k, self.pks .iter() diff --git a/src/descriptor/tr.rs b/src/descriptor/tr.rs index 48deb63e0..129883418 100644 --- a/src/descriptor/tr.rs +++ b/src/descriptor/tr.rs @@ -621,7 +621,7 @@ impl Liftable for TapTree { fn lift_helper(s: &TapTree) -> Result, Error> { match *s { TapTree::Tree { ref left, ref right, height: _ } => { - Ok(Policy::Threshold(1, vec![lift_helper(left)?, lift_helper(right)?])) + Ok(Policy::Thresh(1, vec![lift_helper(left)?, lift_helper(right)?])) } TapTree::Leaf(ref leaf) => leaf.lift(), } @@ -636,7 +636,7 @@ impl Liftable for Tr { fn lift(&self) -> Result, Error> { match &self.tree { Some(root) => { - Ok(Policy::Threshold(1, vec![Policy::Key(self.internal_key.clone()), root.lift()?])) + Ok(Policy::Thresh(1, vec![Policy::Key(self.internal_key.clone()), root.lift()?])) } None => Ok(Policy::Key(self.internal_key.clone())), } diff --git a/src/iter/mod.rs b/src/iter/mod.rs index f8dbed939..771bd4d81 100644 --- a/src/iter/mod.rs +++ b/src/iter/mod.rs @@ -77,7 +77,7 @@ impl<'a, Pk: MiniscriptKey> TreeLike for &'a policy::Concrete { | Ripemd160(_) | Hash160(_) => Tree::Nullary, And(ref subs) => Tree::Nary(subs.iter().map(Arc::as_ref).collect()), Or(ref v) => Tree::Nary(v.iter().map(|(_, p)| p.as_ref()).collect()), - Threshold(_, ref subs) => Tree::Nary(subs.iter().map(Arc::as_ref).collect()), + Thresh(thresh) => Tree::Nary(thresh.iter().map(Arc::as_ref).collect()), } } } @@ -90,7 +90,7 @@ impl<'a, Pk: MiniscriptKey> TreeLike for Arc> { | Ripemd160(_) | Hash160(_) => Tree::Nullary, And(ref subs) => Tree::Nary(subs.iter().map(Arc::clone).collect()), Or(ref v) => Tree::Nary(v.iter().map(|(_, p)| Arc::clone(p)).collect()), - Threshold(_, ref subs) => Tree::Nary(subs.iter().map(Arc::clone).collect()), + Thresh(thresh) => Tree::Nary(thresh.iter().map(Arc::clone).collect()), } } } diff --git a/src/lib.rs b/src/lib.rs index 7b798915e..d5b136b31 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -126,6 +126,7 @@ pub mod miniscript; pub mod plan; pub mod policy; pub mod psbt; +pub mod threshold; #[cfg(test)] mod test_utils; @@ -861,7 +862,7 @@ mod prelude { rc, slice, string::{String, ToString}, sync, - vec::Vec, + vec::{self, Vec}, }; #[cfg(any(feature = "std", test))] pub use std::{ @@ -872,7 +873,7 @@ mod prelude { string::{String, ToString}, sync, sync::Mutex, - vec::Vec, + vec::{self, Vec}, }; #[cfg(all(not(feature = "std"), not(test)))] diff --git a/src/policy/compiler.rs b/src/policy/compiler.rs index 66284f046..0f12fabe2 100644 --- a/src/policy/compiler.rs +++ b/src/policy/compiler.rs @@ -920,8 +920,9 @@ where compile_binary!(&mut l_comp[3], &mut r_comp[2], [lw, rw], Terminal::OrI); compile_binary!(&mut r_comp[3], &mut l_comp[2], [rw, lw], Terminal::OrI); } - Concrete::Threshold(k, ref subs) => { - let n = subs.len(); + Concrete::Thresh(ref thresh) => { + let k = thresh.k(); + let n = thresh.n(); let k_over_n = k as f64 / n as f64; let mut sub_ast = Vec::with_capacity(n); @@ -931,7 +932,7 @@ where let mut best_ws = Vec::with_capacity(n); let mut min_value = (0, f64::INFINITY); - for (i, ast) in subs.iter().enumerate() { + for (i, ast) in thresh.iter().enumerate() { let sp = sat_prob * k_over_n; //Expressions must be dissatisfiable let dp = Some(dissat_prob.unwrap_or(0 as f64) + (1.0 - k_over_n) * sat_prob); @@ -949,7 +950,7 @@ where } sub_ext_data.push(best_es[min_value.0].0); sub_ast.push(Arc::clone(&best_es[min_value.0].1.ms)); - for (i, _ast) in subs.iter().enumerate() { + for (i, _ast) in thresh.iter().enumerate() { if i != min_value.0 { sub_ext_data.push(best_ws[i].0); sub_ast.push(Arc::clone(&best_ws[i].1.ms)); @@ -966,7 +967,7 @@ where insert_wrap!(ast_ext); } - let key_vec: Vec = subs + let key_vec: Vec = thresh .iter() .filter_map(|s| { if let Concrete::Key(ref pk) = s.as_ref() { @@ -978,16 +979,16 @@ where .collect(); match Ctx::sig_type() { - SigType::Schnorr if key_vec.len() == subs.len() => { + SigType::Schnorr if key_vec.len() == thresh.n() => { insert_wrap!(AstElemExt::terminal(Terminal::MultiA(k, key_vec))) } SigType::Ecdsa - if key_vec.len() == subs.len() && subs.len() <= MAX_PUBKEYS_PER_MULTISIG => + if key_vec.len() == thresh.n() && thresh.n() <= MAX_PUBKEYS_PER_MULTISIG => { insert_wrap!(AstElemExt::terminal(Terminal::Multi(k, key_vec))) } - _ if k == subs.len() => { - let mut it = subs.iter(); + _ if k == thresh.n() => { + let mut it = thresh.iter(); let mut policy = it.next().expect("No sub policy in thresh() ?").clone(); policy = it.fold(policy, |acc, pol| Concrete::And(vec![acc, pol.clone()]).into()); @@ -1157,6 +1158,7 @@ mod tests { use super::*; use crate::miniscript::{Legacy, Segwitv0, Tap}; use crate::policy::Liftable; + use crate::threshold::Threshold; use crate::{script_num_size, ToPublicKey}; type SPolicy = Concrete; @@ -1301,19 +1303,19 @@ mod tests { let policy: BPolicy = Concrete::Or(vec![ ( 127, - Arc::new(Concrete::Threshold( + Arc::new(Concrete::Thresh(Threshold::new_unchecked( 3, key_pol[0..5].iter().map(|p| (p.clone()).into()).collect(), - )), + ))), ), ( 1, Arc::new(Concrete::And(vec![ Arc::new(Concrete::Older(Sequence::from_height(10000))), - Arc::new(Concrete::Threshold( + Arc::new(Concrete::Thresh(Threshold::new_unchecked( 2, key_pol[5..8].iter().map(|p| (p.clone()).into()).collect(), - )), + ))), ])), ), ]); @@ -1430,7 +1432,7 @@ mod tests { .iter() .map(|pubkey| Arc::new(Concrete::Key(*pubkey))) .collect(); - let big_thresh = Concrete::Threshold(*k, pubkeys); + let big_thresh = Concrete::Thresh(Threshold::new_unchecked(*k, pubkeys)); let big_thresh_ms: SegwitMiniScript = big_thresh.compile().unwrap(); if *k == 21 { // N * (PUSH + pubkey + CHECKSIGVERIFY) @@ -1466,8 +1468,8 @@ mod tests { .collect(); let thresh_res: Result = Concrete::Or(vec![ - (1, Arc::new(Concrete::Threshold(keys_a.len(), keys_a))), - (1, Arc::new(Concrete::Threshold(keys_b.len(), keys_b))), + (1, Arc::new(Concrete::Thresh(Threshold::new_unchecked(keys_a.len(), keys_a)))), + (1, Arc::new(Concrete::Thresh(Threshold::new_unchecked(keys_b.len(), keys_b)))), ]) .compile(); let script_size = thresh_res.clone().and_then(|m| Ok(m.script_size())); @@ -1485,7 +1487,7 @@ mod tests { .map(|pubkey| Arc::new(Concrete::Key(*pubkey))) .collect(); let thresh_res: Result = - Concrete::Threshold(keys.len(), keys).compile(); + Concrete::Thresh(Threshold::new_unchecked(keys.len(), keys)).compile(); let n_elements = thresh_res .clone() .and_then(|m| Ok(m.max_satisfaction_witness_elements())); @@ -1506,7 +1508,7 @@ mod tests { .map(|pubkey| Arc::new(Concrete::Key(*pubkey))) .collect(); let thresh_res: Result = - Concrete::Threshold(keys.len() - 1, keys).compile(); + Concrete::Thresh(Threshold::new_unchecked(keys.len() - 1, keys)).compile(); let ops_count = thresh_res.clone().and_then(|m| Ok(m.ext.ops.op_count())); assert_eq!( thresh_res, @@ -1520,7 +1522,8 @@ mod tests { .iter() .map(|pubkey| Arc::new(Concrete::Key(*pubkey))) .collect(); - let thresh_res = Concrete::Threshold(keys.len() - 1, keys).compile::(); + let thresh_res = + Concrete::Thresh(Threshold::new_unchecked(keys.len() - 1, keys)).compile::(); let ops_count = thresh_res.clone().and_then(|m| Ok(m.ext.ops.op_count())); assert_eq!( thresh_res, diff --git a/src/policy/concrete.rs b/src/policy/concrete.rs index 19e7771b9..3637d90dd 100644 --- a/src/policy/concrete.rs +++ b/src/policy/concrete.rs @@ -27,6 +27,7 @@ use crate::iter::TreeLike; use crate::miniscript::types::extra_props::TimelockInfo; use crate::prelude::*; use crate::sync::Arc; +use crate::threshold::Threshold; #[cfg(all(doc, not(feature = "compiler")))] use crate::Descriptor; use crate::{errstr, AbsLockTime, Error, ForEachKey, MiniscriptKey, Translator}; @@ -67,7 +68,7 @@ pub enum Policy { /// relative probabilities for each one. Or(Vec<(usize, Arc>)>), /// A set of descriptors, satisfactions must be provided for `k` of them. - Threshold(usize, Vec>>), + Thresh(Threshold>>), } impl Policy @@ -210,9 +211,10 @@ impl Policy { }) .collect::>() } - Policy::Threshold(k, ref subs) if *k == 1 => { - let total_odds = subs.len(); - subs.iter() + Policy::Thresh(thresh) if thresh.k() == 1 => { + let total_odds = thresh.n(); + thresh + .iter() .flat_map(|policy| policy.to_tapleaf_prob_vec(prob / total_odds as f64)) .collect::>() } @@ -265,7 +267,7 @@ impl Policy { /// ### TapTree compilation /// /// The policy tree constructed by root-level disjunctions over [`Policy::Or`] and - /// [`Policy::Threshold`](1, ..) which is flattened into a vector (with respective + /// [`Policy::Thresh`](1, ..) which is flattened into a vector (with respective /// probabilities derived from odds) of policies. /// /// For example, the policy `thresh(1,or(pk(A),pk(B)),and(or(pk(C),pk(D)),pk(E)))` gives the @@ -317,7 +319,7 @@ impl Policy { /// ### TapTree compilation /// /// The policy tree constructed by root-level disjunctions over [`Policy::Or`] and - /// [`Policy::Threshold`](k, ..n..) which is flattened into a vector (with respective + /// [`Policy::Thresh`](k, ..n..) which is flattened into a vector (with respective /// probabilities derived from odds) of policies. For example, the policy /// `thresh(1,or(pk(A),pk(B)),and(or(pk(C),pk(D)),pk(E)))` gives the vector /// `[pk(A),pk(B),and(or(pk(C),pk(D)),pk(E)))]`. @@ -430,13 +432,16 @@ impl Policy { .map(|(odds, pol)| (prob * *odds as f64 / total_odds as f64, pol.clone())) .collect::>() } - Policy::Threshold(k, subs) if *k == 1 => { - let total_odds = subs.len(); - subs.iter() + Policy::Thresh(thresh) if thresh.k() == 1 => { + let total_odds = thresh.n(); + thresh + .iter() .map(|pol| (prob / total_odds as f64, pol.clone())) .collect::>() } - Policy::Threshold(k, subs) if *k != subs.len() => generate_combination(subs, prob, *k), + Policy::Thresh(thresh) if thresh.k() != thresh.n() => { + generate_combination(thresh, prob) + } pol => vec![(prob, Arc::new(pol.clone()))], } } @@ -585,7 +590,7 @@ impl Policy { .enumerate() .map(|(i, (prob, _))| (*prob, child_n(i))) .collect()), - Threshold(ref k, ref subs) => Threshold(*k, (0..subs.len()).map(child_n).collect()), + Thresh(ref thresh) => Thresh(thresh.mapped((0..thresh.n()).map(child_n).collect())), }; translated.push(Arc::new(new_policy)); } @@ -611,8 +616,8 @@ impl Policy { .enumerate() .map(|(i, (prob, _))| (*prob, child_n(i))) .collect())), - Threshold(k, ref subs) => { - Some(Threshold(*k, (0..subs.len()).map(child_n).collect())) + Thresh(ref thresh) => { + Some(Thresh(thresh.mapped((0..thresh.n()).map(child_n).collect()))) } _ => None, }; @@ -638,7 +643,7 @@ impl Policy { } /// Gets the number of [TapLeaf](`TapTree::Leaf`)s considering exhaustive root-level [`Policy::Or`] - /// and [`Policy::Threshold`] disjunctions for the `TapTree`. + /// and [`Policy::Thresh`] disjunctions for the `TapTree`. #[cfg(feature = "compiler")] fn num_tap_leaves(&self) -> usize { use Policy::*; @@ -649,7 +654,7 @@ impl Policy { let num = match data.node { Or(subs) => (0..subs.len()).map(num_for_child_n).sum(), - Threshold(k, subs) if *k == 1 => (0..subs.len()).map(num_for_child_n).sum(), + Thresh(thresh) if thresh.k() == 1 => (0..thresh.n()).map(num_for_child_n).sum(), _ => 1, }; nums.push(num); @@ -732,9 +737,9 @@ impl Policy { let iter = (0..subs.len()).map(info_for_child_n); TimelockInfo::combine_threshold(1, iter) } - Threshold(ref k, subs) => { - let iter = (0..subs.len()).map(info_for_child_n); - TimelockInfo::combine_threshold(*k, iter) + Thresh(ref thresh) => { + let iter = (0..thresh.n()).map(info_for_child_n); + TimelockInfo::combine_threshold(thresh.k(), iter) } _ => TimelockInfo::default(), }; @@ -770,21 +775,11 @@ impl Policy { return Err(PolicyError::TimeTooFar); } } - And(ref subs) => { - if subs.len() != 2 { - return Err(PolicyError::NonBinaryArgAnd); - } - } Or(ref subs) => { if subs.len() != 2 { return Err(PolicyError::NonBinaryArgOr); } } - Threshold(k, ref subs) => { - if k == 0 || k > subs.len() { - return Err(PolicyError::IncorrectThresh); - } - } _ => {} } } @@ -824,16 +819,16 @@ impl Policy { }); (all_safe, atleast_one_safe && all_non_mall) } - Threshold(k, ref subs) => { - let (safe_count, non_mall_count) = (0..subs.len()).map(acc_for_child_n).fold( - (0, 0), - |(safe_count, non_mall_count), (safe, non_mall)| { + Policy::Thresh(ref thresh) => { + let (safe_count, non_mall_count) = thresh + .iter() + .map(|sub| sub.is_safe_nonmalleable()) + .fold((0, 0), |(safe_count, non_mall_count), (safe, non_mall)| { (safe_count + safe as usize, non_mall_count + non_mall as usize) - }, - ); + }); ( - safe_count >= (subs.len() - k + 1), - non_mall_count == subs.len() && safe_count >= (subs.len() - k), + safe_count >= (thresh.n() - thresh.k() + 1), + non_mall_count == thresh.n() && safe_count >= (thresh.n() - thresh.k()), ) } }; @@ -876,10 +871,10 @@ impl fmt::Debug for Policy { } f.write_str(")") } - Policy::Threshold(k, ref subs) => { - write!(f, "thresh({}", k)?; - for sub in subs { - write!(f, ",{:?}", sub)?; + Policy::Thresh(ref thresh) => { + write!(f, "thresh({}", thresh.k())?; + for policy in thresh.iter() { + write!(f, ",{:?}", policy)?; } f.write_str(")") } @@ -919,10 +914,10 @@ impl fmt::Display for Policy { } f.write_str(")") } - Policy::Threshold(k, ref subs) => { - write!(f, "thresh({}", k)?; - for sub in subs { - write!(f, ",{}", sub)?; + Policy::Thresh(ref thresh) => { + write!(f, "thresh({}", thresh.k())?; + for policy in thresh.iter() { + write!(f, ",{}", policy)?; } f.write_str(")") } @@ -1035,8 +1030,8 @@ impl_block_str!( return Err(Error::PolicyError(PolicyError::IncorrectThresh)); } - let thresh = expression::parse_num(top.args[0].name)?; - if thresh >= nsubs || thresh == 0 { + let k = expression::parse_num(top.args[0].name)?; + if k >= nsubs || k == 0 { return Err(Error::PolicyError(PolicyError::IncorrectThresh)); } @@ -1044,7 +1039,10 @@ impl_block_str!( for arg in &top.args[1..] { subs.push(Policy::from_tree(arg)?); } - Ok(Policy::Threshold(thresh as usize, subs.into_iter().map(Arc::new).collect())) + let v = subs.into_iter().map(Arc::new).collect(); + + let thresh = Threshold::new(k as usize, v).map_err(|_| PolicyError::IncorrectThresh)?; + Ok(Policy::Thresh(thresh)) } _ => Err(errstr(top.name)), } @@ -1087,7 +1085,7 @@ fn with_huffman_tree( Ok(node) } -/// Enumerates a [`Policy::Threshold(k, ..n..)`] into `n` different thresh's. +/// Enumerates a [`Policy::Thresh(k, ..n..)`] into `n` different thresh's. /// /// ## Strategy /// @@ -1096,20 +1094,20 @@ fn with_huffman_tree( /// any one of the conditions exclusively. #[cfg(feature = "compiler")] fn generate_combination( - policy_vec: &Vec>>, + policy_thresh: &Threshold>>, prob: f64, - k: usize, ) -> Vec<(f64, Arc>)> { - debug_assert!(k <= policy_vec.len()); - let mut ret: Vec<(f64, Arc>)> = vec![]; - for i in 0..policy_vec.len() { - let policies: Vec>> = policy_vec + let k = policy_thresh.k(); + for i in 0..policy_thresh.n() { + let policies: Vec>> = policy_thresh .iter() .enumerate() .filter_map(|(j, sub)| if j != i { Some(Arc::clone(sub)) } else { None }) .collect(); - ret.push((prob / policy_vec.len() as f64, Arc::new(Policy::Threshold(k, policies)))); + if let Ok(thresh) = Threshold::new(k, policies) { + ret.push((prob / policy_thresh.n() as f64, Arc::new(Policy::Thresh(thresh)))); + } } ret } @@ -1130,7 +1128,8 @@ mod compiler_tests { .map(|p| Arc::new(p)) .collect(); - let combinations = generate_combination(&policies, 1.0, 2); + let thresh = Threshold::new_unchecked(2, policies); + let combinations = generate_combination(&thresh, 1.0); let comb_a: Vec> = vec![ policy_str!("pk(B)"), @@ -1157,10 +1156,10 @@ mod compiler_tests { .map(|sub_pol| { ( 0.25, - Arc::new(Policy::Threshold( + Arc::new(Policy::Thresh(Threshold::new_unchecked( 2, sub_pol.into_iter().map(|p| Arc::new(p)).collect(), - )), + ))), ) }) .collect::>(); diff --git a/src/policy/mod.rs b/src/policy/mod.rs index 138ec45b7..122ce5b3f 100644 --- a/src/policy/mod.rs +++ b/src/policy/mod.rs @@ -136,12 +136,12 @@ impl Liftable for Terminal { | Terminal::NonZero(ref sub) | Terminal::ZeroNotEqual(ref sub) => sub.node.lift()?, Terminal::AndV(ref left, ref right) | Terminal::AndB(ref left, ref right) => { - Semantic::Threshold(2, vec![left.node.lift()?, right.node.lift()?]) + Semantic::Thresh(2, vec![left.node.lift()?, right.node.lift()?]) } - Terminal::AndOr(ref a, ref b, ref c) => Semantic::Threshold( + Terminal::AndOr(ref a, ref b, ref c) => Semantic::Thresh( 1, vec![ - Semantic::Threshold(2, vec![a.node.lift()?, b.node.lift()?]), + Semantic::Thresh(2, vec![a.node.lift()?, b.node.lift()?]), c.node.lift()?, ], ), @@ -149,14 +149,14 @@ impl Liftable for Terminal { | Terminal::OrD(ref left, ref right) | Terminal::OrC(ref left, ref right) | Terminal::OrI(ref left, ref right) => { - Semantic::Threshold(1, vec![left.node.lift()?, right.node.lift()?]) + Semantic::Thresh(1, vec![left.node.lift()?, right.node.lift()?]) } Terminal::Thresh(k, ref subs) => { let semantic_subs: Result<_, Error> = subs.iter().map(|s| s.node.lift()).collect(); - Semantic::Threshold(k, semantic_subs?) + Semantic::Thresh(k, semantic_subs?) } Terminal::Multi(k, ref keys) | Terminal::MultiA(k, ref keys) => { - Semantic::Threshold(k, keys.iter().map(|k| Semantic::Key(k.clone())).collect()) + Semantic::Thresh(k, keys.iter().map(|k| Semantic::Key(k.clone())).collect()) } } .normalized(); @@ -198,16 +198,16 @@ impl Liftable for Concrete { Concrete::Hash160(ref h) => Semantic::Hash160(h.clone()), Concrete::And(ref subs) => { let semantic_subs: Result<_, Error> = subs.iter().map(Liftable::lift).collect(); - Semantic::Threshold(2, semantic_subs?) + Semantic::Thresh(2, semantic_subs?) } Concrete::Or(ref subs) => { let semantic_subs: Result<_, Error> = subs.iter().map(|(_p, sub)| sub.lift()).collect(); - Semantic::Threshold(1, semantic_subs?) + Semantic::Thresh(1, semantic_subs?) } - Concrete::Threshold(k, ref subs) => { - let semantic_subs: Result<_, Error> = subs.iter().map(Liftable::lift).collect(); - Semantic::Threshold(k, semantic_subs?) + Concrete::Thresh(ref thresh) => { + let semantic_subs: Result<_, Error> = thresh.iter().map(Liftable::lift).collect(); + Semantic::Thresh(thresh.k(), semantic_subs?) } } .normalized(); @@ -345,10 +345,10 @@ mod tests { .parse() .unwrap(); assert_eq!( - Semantic::Threshold( + Semantic::Thresh( 1, vec![ - Semantic::Threshold( + Semantic::Thresh( 2, vec![ Semantic::Key(key_a), diff --git a/src/policy/semantic.rs b/src/policy/semantic.rs index 9636c0d2a..164f0e9f8 100644 --- a/src/policy/semantic.rs +++ b/src/policy/semantic.rs @@ -42,7 +42,7 @@ pub enum Policy { /// A HASH160 whose preimage must be provided to satisfy the descriptor. Hash160(Pk::Hash160), /// A set of descriptors, satisfactions must be provided for `k` of them. - Threshold(usize, Vec>), + Thresh(usize, Vec>), } impl Policy @@ -79,9 +79,7 @@ impl Policy { | Policy::Hash160(..) | Policy::After(..) | Policy::Older(..) => true, - Policy::Threshold(_, ref subs) => { - subs.iter().all(|sub| sub.real_for_each_key(&mut *pred)) - } + Policy::Thresh(_, ref subs) => subs.iter().all(|sub| sub.real_for_each_key(&mut *pred)), } } @@ -149,10 +147,10 @@ impl Policy { Policy::Hash160(ref h) => t.hash160(h).map(Policy::Hash160), Policy::After(n) => Ok(Policy::After(n)), Policy::Older(n) => Ok(Policy::Older(n)), - Policy::Threshold(k, ref subs) => { + Policy::Thresh(k, ref subs) => { let new_subs: Result>, _> = subs.iter().map(|sub| sub._translate_pk(t)).collect(); - new_subs.map(|ok| Policy::Threshold(k, ok)) + new_subs.map(|ok| Policy::Thresh(k, ok)) } } } @@ -193,7 +191,7 @@ impl Policy { // Helper function to compute the number of constraints in policy. fn n_terminals(&self) -> usize { match self { - &Policy::Threshold(_k, ref subs) => subs.iter().map(|sub| sub.n_terminals()).sum(), + &Policy::Thresh(_k, ref subs) => subs.iter().map(|sub| sub.n_terminals()).sum(), &Policy::Trivial | &Policy::Unsatisfiable => 0, _leaf => 1, } @@ -205,7 +203,7 @@ impl Policy { fn first_constraint(&self) -> Policy { debug_assert!(self.clone().normalized() == self.clone()); match self { - &Policy::Threshold(_k, ref subs) => subs[0].first_constraint(), + &Policy::Thresh(_k, ref subs) => subs[0].first_constraint(), first => first.clone(), } } @@ -216,18 +214,18 @@ impl Policy { // normalized policy pub(crate) fn satisfy_constraint(self, witness: &Policy, available: bool) -> Policy { debug_assert!(self.clone().normalized() == self); - if let Policy::Threshold { .. } = *witness { - // We can't debug_assert on Policy::Threshold. + if let Policy::Thresh { .. } = *witness { + // We can't debug_assert on Policy::Thresh. panic!("should be unreachable") } let ret = match self { - Policy::Threshold(k, subs) => { + Policy::Thresh(k, subs) => { let mut ret_subs = vec![]; for sub in subs { ret_subs.push(sub.satisfy_constraint(witness, available)); } - Policy::Threshold(k, ret_subs) + Policy::Thresh(k, ret_subs) } ref leaf if leaf == witness => { if available { @@ -254,7 +252,7 @@ impl fmt::Debug for Policy { Policy::Hash256(ref h) => write!(f, "hash256({})", h), Policy::Ripemd160(ref h) => write!(f, "ripemd160({})", h), Policy::Hash160(ref h) => write!(f, "hash160({})", h), - Policy::Threshold(k, ref subs) => { + Policy::Thresh(k, ref subs) => { if k == subs.len() { write!(f, "and(")?; } else if k == 1 { @@ -287,7 +285,7 @@ impl fmt::Display for Policy { Policy::Hash256(ref h) => write!(f, "hash256({})", h), Policy::Ripemd160(ref h) => write!(f, "ripemd160({})", h), Policy::Hash160(ref h) => write!(f, "hash160({})", h), - Policy::Threshold(k, ref subs) => { + Policy::Thresh(k, ref subs) => { if k == subs.len() { write!(f, "and(")?; } else if k == 1 { @@ -354,7 +352,7 @@ impl_from_tree!( for arg in &top.args { subs.push(Policy::from_tree(arg)?); } - Ok(Policy::Threshold(nsubs, subs)) + Ok(Policy::Thresh(nsubs, subs)) } ("or", nsubs) => { if nsubs < 2 { @@ -364,7 +362,7 @@ impl_from_tree!( for arg in &top.args { subs.push(Policy::from_tree(arg)?); } - Ok(Policy::Threshold(1, subs)) + Ok(Policy::Thresh(1, subs)) } ("thresh", nsubs) => { if nsubs == 0 || nsubs == 1 { @@ -391,7 +389,7 @@ impl_from_tree!( for arg in &top.args[1..] { subs.push(Policy::from_tree(arg)?); } - Ok(Policy::Threshold(thresh as usize, subs)) + Ok(Policy::Thresh(thresh as usize, subs)) } _ => Err(errstr(top.name)), } @@ -403,7 +401,7 @@ impl Policy { /// `Unsatisfiable`s. Does not reorder any branches; use `.sort`. pub fn normalized(self) -> Policy { match self { - Policy::Threshold(k, subs) => { + Policy::Thresh(k, subs) => { let mut ret_subs = Vec::with_capacity(subs.len()); let subs: Vec<_> = subs.into_iter().map(|sub| sub.normalized()).collect(); @@ -421,15 +419,15 @@ impl Policy { for sub in subs { match sub { Policy::Trivial | Policy::Unsatisfiable => {} - Policy::Threshold(k, subs) => { + Policy::Thresh(k, subs) => { match (is_and, is_or) { (true, true) => { // means m = n = 1, thresh(1,X) type thing. - ret_subs.push(Policy::Threshold(k, subs)); + ret_subs.push(Policy::Thresh(k, subs)); } (true, false) if k == subs.len() => ret_subs.extend(subs), // and case (false, true) if k == 1 => ret_subs.extend(subs), // or case - _ => ret_subs.push(Policy::Threshold(k, subs)), + _ => ret_subs.push(Policy::Thresh(k, subs)), } } x => ret_subs.push(x), @@ -443,11 +441,11 @@ impl Policy { } else if ret_subs.len() == 1 { ret_subs.pop().unwrap() } else if is_and { - Policy::Threshold(ret_subs.len(), ret_subs) + Policy::Thresh(ret_subs.len(), ret_subs) } else if is_or { - Policy::Threshold(1, ret_subs) + Policy::Thresh(1, ret_subs) } else { - Policy::Threshold(m, ret_subs) + Policy::Thresh(m, ret_subs) } } x => x, @@ -480,7 +478,7 @@ impl Policy { | Policy::Hash160(..) => vec![], Policy::After(..) => vec![], Policy::Older(t) => vec![t.to_consensus_u32()], - Policy::Threshold(_, ref subs) => subs.iter().fold(vec![], |mut acc, x| { + Policy::Thresh(_, ref subs) => subs.iter().fold(vec![], |mut acc, x| { acc.extend(x.real_relative_timelocks()); acc }), @@ -508,7 +506,7 @@ impl Policy { | Policy::Hash160(..) => vec![], Policy::Older(..) => vec![], Policy::After(t) => vec![t.to_u32()], - Policy::Threshold(_, ref subs) => subs.iter().fold(vec![], |mut acc, x| { + Policy::Thresh(_, ref subs) => subs.iter().fold(vec![], |mut acc, x| { acc.extend(x.real_absolute_timelocks()); acc }), @@ -538,8 +536,8 @@ impl Policy { Policy::Older(t) } } - Policy::Threshold(k, subs) => { - Policy::Threshold(k, subs.into_iter().map(|sub| sub.at_age(age)).collect()) + Policy::Thresh(k, subs) => { + Policy::Thresh(k, subs.into_iter().map(|sub| sub.at_age(age)).collect()) } x => x, }; @@ -565,8 +563,8 @@ impl Policy { Policy::After(t.into()) } } - Policy::Threshold(k, subs) => { - Policy::Threshold(k, subs.into_iter().map(|sub| sub.at_lock_time(n)).collect()) + Policy::Thresh(k, subs) => { + Policy::Thresh(k, subs.into_iter().map(|sub| sub.at_lock_time(n)).collect()) } x => x, }; @@ -585,7 +583,7 @@ impl Policy { | Policy::Hash256(..) | Policy::Ripemd160(..) | Policy::Hash160(..) => 0, - Policy::Threshold(_, ref subs) => subs.iter().map(|sub| sub.n_keys()).sum::(), + Policy::Thresh(_, ref subs) => subs.iter().map(|sub| sub.n_keys()).sum::(), } } @@ -606,7 +604,7 @@ impl Policy { | Policy::Hash256(..) | Policy::Ripemd160(..) | Policy::Hash160(..) => Some(0), - Policy::Threshold(k, ref subs) => { + Policy::Thresh(k, ref subs) => { let mut sublens: Vec = subs.iter().filter_map(Policy::minimum_n_keys).collect(); if sublens.len() < k { @@ -629,10 +627,10 @@ impl Policy { /// implemented. pub fn sorted(self) -> Policy { match self { - Policy::Threshold(k, subs) => { + Policy::Thresh(k, subs) => { let mut new_subs: Vec<_> = subs.into_iter().map(Policy::sorted).collect(); new_subs.sort(); - Policy::Threshold(k, new_subs) + Policy::Thresh(k, new_subs) } x => x, } @@ -698,7 +696,7 @@ mod tests { let policy = StringPolicy::from_str("or(pk(),older(1000))").unwrap(); assert_eq!( policy, - Policy::Threshold( + Policy::Thresh( 1, vec![ Policy::Key("".to_owned()), @@ -721,7 +719,7 @@ mod tests { let policy = StringPolicy::from_str("or(pk(),UNSATISFIABLE)").unwrap(); assert_eq!( policy, - Policy::Threshold(1, vec![Policy::Key("".to_owned()), Policy::Unsatisfiable,]) + Policy::Thresh(1, vec![Policy::Key("".to_owned()), Policy::Unsatisfiable,]) ); assert_eq!(policy.relative_timelocks(), vec![]); assert_eq!(policy.absolute_timelocks(), vec![]); @@ -731,7 +729,7 @@ mod tests { let policy = StringPolicy::from_str("and(pk(),UNSATISFIABLE)").unwrap(); assert_eq!( policy, - Policy::Threshold(2, vec![Policy::Key("".to_owned()), Policy::Unsatisfiable,]) + Policy::Thresh(2, vec![Policy::Key("".to_owned()), Policy::Unsatisfiable,]) ); assert_eq!(policy.relative_timelocks(), vec![]); assert_eq!(policy.absolute_timelocks(), vec![]); @@ -746,7 +744,7 @@ mod tests { .unwrap(); assert_eq!( policy, - Policy::Threshold( + Policy::Thresh( 2, vec![ Policy::Older(Sequence::from_height(1000)), @@ -770,7 +768,7 @@ mod tests { .unwrap(); assert_eq!( policy, - Policy::Threshold( + Policy::Thresh( 2, vec![ Policy::Older(Sequence::from_height(1000)), @@ -883,7 +881,7 @@ mod tests { "or(and(older(4096),thresh(2,pk(A),pk(B),pk(C))),thresh(11,pk(F1),pk(F2),pk(F3),pk(F4),pk(F5),pk(F6),pk(F7),pk(F8),pk(F9),pk(F10),pk(F11),pk(F12),pk(F13),pk(F14)))").unwrap(); // Very bad idea to add master key,pk but let's have it have 50M blocks let master_key = StringPolicy::from_str("and(older(50000000),pk(master))").unwrap(); - let new_liquid_pol = Policy::Threshold(1, vec![liquid_pol.clone(), master_key]); + let new_liquid_pol = Policy::Thresh(1, vec![liquid_pol.clone(), master_key]); assert!(liquid_pol.clone().entails(new_liquid_pol.clone()).unwrap()); assert!(!new_liquid_pol.entails(liquid_pol.clone()).unwrap()); diff --git a/src/threshold.rs b/src/threshold.rs new file mode 100644 index 000000000..f04aa141c --- /dev/null +++ b/src/threshold.rs @@ -0,0 +1,133 @@ +// SPDX-License-Identifier: CC0-1.0 + +//! A generic (k,n)-threshold type. + +use core::fmt; + +use crate::prelude::{vec, Vec}; + +/// A (k, n)-threshold. +/// +/// This type maintains the following invariants: +/// - n > 0 +/// - k > 0 +/// - k <= n +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct Threshold { + k: usize, + v: Vec, +} + +impl Threshold { + /// Creates a `Theshold` after checking that invariants hold. + pub fn new(k: usize, v: Vec) -> Result, Error> { + if v.len() == 0 { + Err(Error::ZeroN) + } else if k == 0 { + Err(Error::ZeroK) + } else if k > v.len() { + Err(Error::BigK) + } else { + Ok(Threshold { k, v }) + } + } + + /// Creates a `Theshold` without checking that invariants hold. + #[cfg(test)] + pub fn new_unchecked(k: usize, v: Vec) -> Threshold { Threshold { k, v } } + + /// Returns `k`, the threshold value. + pub fn k(&self) -> usize { self.k } + + /// Returns `n`, the total number of elements in the threshold. + pub fn n(&self) -> usize { self.v.len() } + + /// Returns a read-only iterator over the threshold elements. + pub fn iter(&self) -> core::slice::Iter<'_, T> { self.v.iter() } + + /// Creates an iterator over the threshold elements. + pub fn into_iter(self) -> vec::IntoIter { self.v.into_iter() } + + /// Creates an iterator over the threshold elements. + pub fn iter_mut(&mut self) -> core::slice::IterMut<'_, T> { self.v.iter_mut() } + + /// Returns the threshold elements, consuming self. + pub fn into_elements(self) -> Vec { self.v } + + /// Creates a new (k, n)-threshold using a newly mapped vector. + /// + /// Typically this function is called after collecting a vector that was + /// created by iterating this threshold. E.g., + /// + /// `thresh.mapped((0..thresh.n()).map(|element| some_function(element)).collect())` + /// + /// # Panics + /// + /// Panics if the new vector is not the same length as the + /// original i.e., `new.len() != self.n()`. + pub(crate) fn mapped(&self, new: Vec) -> Threshold { + if self.n() != new.len() { + panic!("cannot map to a different length vector") + } + Threshold { k: self.k(), v: new } + } +} + +/// An error attempting to construct a `Threshold`. +#[derive(Debug, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub enum Error { + /// Threshold `n` value must be non-zero. + ZeroN, + /// Threshold `k` value must be non-zero. + ZeroK, + /// Threshold `k` value must be <= `n`. + BigK, +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + use Error::*; + + match *self { + ZeroN => f.write_str("threshold `n` value must be non-zero"), + ZeroK => f.write_str("threshold `k` value must be non-zero"), + BigK => f.write_str("threshold `k` value must be <= `n`"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for Error { + fn cause(&self) -> Option<&dyn std::error::Error> { + use Error::*; + + match *self { + ZeroN | ZeroK | BigK => None, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn threshold_constructor_valid() { + let v = vec![1, 2, 3]; + let n = 3; + + for k in 1..=3 { + let thresh = Threshold::new(k, v.clone()).expect("failed to create threshold"); + assert_eq!(thresh.k(), k); + assert_eq!(thresh.n(), n); + } + } + + #[test] + fn threshold_constructor_invalid() { + let v = vec![1, 2, 3]; + assert!(Threshold::new(0, v.clone()).is_err()); + assert!(Threshold::new(4, v.clone()).is_err()); + } +}