Skip to content

Commit 99a2bc7

Browse files
committed
Add a Threshold<T> type
We have various enums in the codebase that include a `Thresh` variant, we have to explicitly check that invariants are maintained all over the place because these enums are public (eg, `policy::Concrete`). Add a `Threshold<T>` type that abstracts over a threshold and maintains the following invariants: - v.len() > 0 - k > 0 - k <= v.len()
1 parent b60a702 commit 99a2bc7

File tree

2 files changed

+86
-0
lines changed

2 files changed

+86
-0
lines changed

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ pub mod miniscript;
126126
pub mod plan;
127127
pub mod policy;
128128
pub mod psbt;
129+
pub mod threshold;
129130

130131
#[cfg(test)]
131132
mod test_utils;

src/threshold.rs

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
// SPDX-License-Identifier: CC0-1.0
2+
3+
//! A generic (k,n)-threshold type.
4+
5+
use core::fmt;
6+
7+
use crate::prelude::Vec;
8+
9+
/// A (k, n)-threshold.
10+
///
11+
/// This type maintains the following invariants:
12+
/// - n > 0
13+
/// - k > 0
14+
/// - k <= n
15+
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
16+
pub struct Threshold<T> {
17+
k: usize,
18+
v: Vec<T>,
19+
}
20+
21+
impl<T> Threshold<T> {
22+
/// Creates a `Theshold<T>` after checking that invariants hold.
23+
pub fn new(k: usize, v: Vec<T>) -> Result<Threshold<T>, Error> {
24+
if v.len() == 0 {
25+
Err(Error::ZeroN)
26+
} else if k == 0 {
27+
Err(Error::ZeroK)
28+
} else if k > v.len() {
29+
Err(Error::BigK)
30+
} else {
31+
Ok(Threshold { k, v })
32+
}
33+
}
34+
35+
/// Creates a `Theshold<T>` without checking that invariants hold.
36+
pub fn new_unchecked(k: usize, v: Vec<T>) -> Threshold<T> { Threshold { k, v } }
37+
38+
/// Returns `k`, the threshold value.
39+
pub fn k(&self) -> usize { self.k }
40+
41+
/// Returns `n`, the total number of elements in the threshold.
42+
pub fn n(&self) -> usize { self.v.len() }
43+
44+
/// Returns a read-only iterator over the threshold elements.
45+
pub fn iter(&self) -> core::slice::Iter<'_, T> { self.v.iter() }
46+
47+
/// Returns the threshold elements, consuming self.
48+
// TODO: Find a better name for this functiion.
49+
pub fn into_elements(self) -> Vec<T> { self.v }
50+
}
51+
52+
/// An error attempting to construct a `Threshold<T>`.
53+
#[derive(Debug, Clone, PartialEq, Eq)]
54+
#[non_exhaustive]
55+
pub enum Error {
56+
/// Threshold `n` value must be non-zero.
57+
ZeroN,
58+
/// Threshold `k` value must be non-zero.
59+
ZeroK,
60+
/// Threshold `k` value must be <= `n`.
61+
BigK,
62+
}
63+
64+
impl fmt::Display for Error {
65+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
66+
use Error::*;
67+
68+
match *self {
69+
ZeroN => f.write_str("threshold `n` value must be non-zero"),
70+
ZeroK => f.write_str("threshold `k` value must be non-zero"),
71+
BigK => f.write_str("threshold `k` value must be <= `n`"),
72+
}
73+
}
74+
}
75+
76+
#[cfg(feature = "std")]
77+
impl std::error::Error for Error {
78+
fn cause(&self) -> Option<&dyn std::error::Error> {
79+
use Error::*;
80+
81+
match *self {
82+
ZeroN | ZeroK | BigK => None,
83+
}
84+
}
85+
}

0 commit comments

Comments
 (0)