diff --git a/Cargo.lock b/Cargo.lock index 1c2f7ad..9fdc800 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -155,10 +155,13 @@ dependencies = [ [[package]] name = "custom-constraints" -version = "0.1.1" +version = "0.1.0" dependencies = [ "ark-ff", + "ark-std", "getrandom", + "rand", + "rayon", "rstest", "wasm-bindgen", "wasm-bindgen-test", @@ -360,6 +363,7 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ + "libc", "rand_chacha", "rand_core", ] @@ -379,6 +383,9 @@ name = "rand_core" version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] [[package]] name = "rayon" diff --git a/Cargo.toml b/Cargo.toml index 3cb3273..4adc85e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,16 +7,28 @@ license = "MIT" name = "custom-constraints" readme = "README.md" repository = "https://github.com/autoparallel/custom-constraints" -version = "0.1.1" +version = "0.1.0" [dependencies] -ark-ff = { version = "0.5", default-features = false, features = ["parallel"] } +ark-ff = { version = "0.5", default-features = false, features = [ + "parallel", + "asm", +] } +rayon = { version = "1.10" } [dev-dependencies] -rstest = { version = "0.24", default-features = false } - +ark-std = { version = "0.5", default-features = false, features = ["std"] } +rand = "0.8" +rstest = { version = "0.24", default-features = false } [target.'cfg(target_arch = "wasm32")'.dev-dependencies] getrandom = { version = "0.2", features = ["js"] } wasm-bindgen = { version = "0.2" } wasm-bindgen-test = { version = "0.3" } + +[profile.release] +codegen-units = 1 +lto = "fat" +opt-level = 3 +panic = "abort" +strip = true diff --git a/benches/matrix.rs b/benches/matrix.rs new file mode 100644 index 0000000..cece205 --- /dev/null +++ b/benches/matrix.rs @@ -0,0 +1,96 @@ +#![feature(test)] +extern crate test; + +use ark_ff::{Fp256, MontBackend, MontConfig}; +use custom_constraints::matrix::SparseMatrix; +use test::Bencher; + +// Define a large prime field for testing +#[derive(MontConfig)] +#[modulus = "52435875175126190479447740508185965837690552500527637822603658699938581184513"] +#[generator = "7"] +pub struct FqConfig; +pub type Fq = Fp256>; + +// Helper function to create a random field element +fn random_field_element() -> Fq { + // Create a random field element using ark_ff's random generation + use ark_ff::UniformRand; + let mut rng = ark_std::test_rng(); + Fq::rand(&mut rng) +} + +// Helper to create a sparse matrix with given density +fn create_sparse_matrix(rows: usize, cols: usize, density: f64) -> SparseMatrix { + let mut row_offsets = vec![0]; + let mut col_indices = Vec::new(); + let mut values = Vec::new(); + let mut current_offset = 0; + + for _ in 0..rows { + for j in 0..cols { + if rand::random::() < density { + col_indices.push(j); + values.push(random_field_element()); + current_offset += 1; + } + } + row_offsets.push(current_offset); + } + + SparseMatrix::new(row_offsets, col_indices, values, cols) +} + +const COLS: usize = 100; +const SMALL: usize = 2_usize.pow(10); +const MEDIUM: usize = 2_usize.pow(15); +const LARGE: usize = 2_usize.pow(20); + +// Matrix-vector multiplication benchmarks +#[bench] +fn bench_sparse_matrix_vec_mul_small(b: &mut Bencher) { + let matrix = create_sparse_matrix(SMALL, COLS, 0.1); + let vector: Vec = (0..COLS).map(|_| random_field_element()).collect(); + + b.iter(|| &matrix * &vector); +} + +#[bench] +fn bench_sparse_matrix_vec_mul_medium(b: &mut Bencher) { + let matrix = create_sparse_matrix(MEDIUM, COLS, 0.01); + let vector: Vec = (0..COLS).map(|_| random_field_element()).collect(); + + b.iter(|| &matrix * &vector); +} + +#[bench] +fn bench_sparse_matrix_vec_mul_large(b: &mut Bencher) { + let matrix = create_sparse_matrix(LARGE, LARGE, 0.01); + let vector: Vec = (0..LARGE).map(|_| random_field_element()).collect(); + + b.iter(|| &matrix * &vector); +} + +#[bench] +fn bench_sparse_matrix_hadamard_small(b: &mut Bencher) { + let matrix1 = create_sparse_matrix(SMALL, COLS, 0.1); + let matrix2 = create_sparse_matrix(SMALL, COLS, 0.1); + + b.iter(|| &matrix1 * &matrix2); +} + +#[bench] +fn bench_sparse_matrix_hadamard_medium(b: &mut Bencher) { + let matrix1 = create_sparse_matrix(MEDIUM, COLS, 0.1); + let matrix2 = create_sparse_matrix(MEDIUM, COLS, 0.1); + + b.iter(|| &matrix1 * &matrix2); +} + +#[bench] +fn bench_sparse_matrix_hadamard_large(b: &mut Bencher) { + let matrix1 = create_sparse_matrix(LARGE, COLS, 0.1); + let matrix2 = create_sparse_matrix(LARGE, COLS, 0.1); + + b.iter(|| &matrix1 * &matrix2); +} diff --git a/justfile b/justfile index 3e0be06..b573579 100644 --- a/justfile +++ b/justfile @@ -123,7 +123,7 @@ build-wasm: # Run tests for native architecture and wasm test: @just header "Running native architecture tests" - cargo test --workspace --all-targets --all-features + cargo test --workspace --tests --all-features @just header "Running wasm tests" wasm-pack test --node diff --git a/rust_toolchain.toml b/rust-toolchain.toml similarity index 100% rename from rust_toolchain.toml rename to rust-toolchain.toml diff --git a/src/ccs.rs b/src/ccs.rs new file mode 100644 index 0000000..efcdae5 --- /dev/null +++ b/src/ccs.rs @@ -0,0 +1,211 @@ +//! Implements the Customizable Constraint System (CCS) format. +//! +//! A CCS represents arithmetic constraints through a combination of matrices +//! and multisets, allowing efficient verification of arithmetic computations. +//! +//! The system consists of: +//! - A set of sparse matrices representing linear combinations +//! - Multisets defining which matrices participate in each constraint +//! - Constants applied to each constraint term + +use matrix::SparseMatrix; + +use super::*; + +/// A Customizable Constraint System over a field F. +#[derive(Debug, Default)] +pub struct CCS { + /// Constants for each constraint term + pub constants: Vec, + /// Sets of matrix indices for Hadamard products + pub multisets: Vec>, + /// Constraint matrices + pub matrices: Vec>, +} + +impl CCS { + /// Creates a new empty CCS. + pub fn new() -> Self { + Self::default() + } + + /// Checks if a witness and public input satisfy the constraint system. + /// + /// Forms vector z = (w, 1, x) and verifies that all constraints are satisfied. + /// + /// # Arguments + /// * `w` - The witness vector + /// * `x` - The public input vector + /// + /// # Returns + /// `true` if all constraints are satisfied, `false` otherwise + pub fn is_satisfied(&self, w: &[F], x: &[F]) -> bool { + // Construct z = (w, 1, x) + let mut z = Vec::with_capacity(w.len() + 1 + x.len()); + z.extend(w.iter().copied()); + z.push(F::ONE); + z.extend(x.iter().copied()); + + // Compute all matrix-vector products + let products: Vec> = self + .matrices + .iter() + .enumerate() + .map(|(i, matrix)| { + let result = matrix * &z; + println!("M{i} · z = {result:?}"); + result + }) + .collect(); + + // For each row in the output... + let m = if let Some(first) = products.first() { + first.len() + } else { + return true; // No constraints + }; + + // For each output coordinate... + for row in 0..m { + let mut sum = F::ZERO; + + // For each constraint... + for (i, multiset) in self.multisets.iter().enumerate() { + let mut term = products[multiset[0]][row]; + + for &idx in multiset.iter().skip(1) { + term *= products[idx][row]; + } + + let contribution = self.constants[i] * term; + sum += contribution; + } + + if sum != F::ZERO { + return false; + } + } + + true + } + + /// Creates a new CCS configured for constraints up to the given degree. + /// + /// # Arguments + /// * `d` - Maximum degree of constraints + /// + /// # Panics + /// If d < 2 + pub fn new_degree(d: usize) -> Self { + assert!(d >= 2, "Degree must be positive"); + + let mut ccs = Self { constants: Vec::new(), multisets: Vec::new(), matrices: Vec::new() }; + + // We'll create terms starting from highest degree down to degree 1 + // For a degree d CCS, we need terms of all degrees from d down to 1 + let mut next_matrix_index = 0; + + // Handle each degree from d down to 1 + for degree in (1..=d).rev() { + // For a term of degree k, we need k matrices Hadamard multiplied + let matrix_indices: Vec = (0..degree).map(|i| next_matrix_index + i).collect(); + + // Add this term's multiset and its coefficient + ccs.multisets.push(matrix_indices); + ccs.constants.push(F::ONE); + + // Update our tracking of matrix indices + next_matrix_index += degree; + } + + // Calculate total number of matrices needed: + // For degree d, we need d + (d-1) + ... + 1 matrices + // This is the triangular number formula: n(n+1)/2 + let total_matrices = (d * (d + 1)) / 2; + + // Initialize empty matrices - their content will be filled during conversion + for _ in 0..total_matrices { + ccs.matrices.push(SparseMatrix::new_rows_cols(1, 0)); + } + + ccs + } +} + +impl Display for CCS { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + writeln!(f, "Customizable Constraint System:")?; + + // First, display all matrices with their indices + writeln!(f, "\nMatrices:")?; + for (i, matrix) in self.matrices.iter().enumerate() { + writeln!(f, "M{i} =")?; + writeln!(f, "{matrix}")?; + } + + // Show how constraints are formed from multisets and constants + writeln!(f, "\nConstraints:")?; + + // We expect multisets to come in pairs, each pair forming one constraint + for i in 0..self.multisets.len() { + // Write the constant for the first multiset + write!(f, "{}·(", self.constants[i])?; + + // Write the Hadamard product for the first multiset + if let Some(first_idx) = self.multisets[i].first() { + write!(f, "M{first_idx}")?; + for &idx in &self.multisets[i][1..] { + write!(f, "∘M{idx}")?; + } + } + write!(f, ")")?; + + // Sum up the expressions to the last one + if i < self.multisets.len() - 1 { + write!(f, " + ")?; + } + } + writeln!(f, " = 0")?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::mock::F17; + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn test_ccs_satisfaction() { + println!("\nSetting up CCS for constraint x * y = z"); + + // For z = (y, z, 1, x), create matrices: + let mut m1 = SparseMatrix::new_rows_cols(1, 4); + m1.write(0, 3, F17::ONE); // Select x + let mut m2 = SparseMatrix::new_rows_cols(1, 4); + m2.write(0, 0, F17::ONE); // Select y + let mut m3 = SparseMatrix::new_rows_cols(1, 4); + m3.write(0, 1, F17::ONE); // Select z + + println!("Created matrices:"); + println!("M1 (selects x): {m1:?}"); + println!("M2 (selects y): {m2:?}"); + println!("M3 (selects z): {m3:?}"); + + let mut ccs = CCS::new(); + ccs.matrices = vec![m1, m2, m3]; + // Encode x * y - z = 0 + ccs.multisets = vec![vec![0, 1], vec![2]]; + ccs.constants = vec![F17::ONE, F17::from(-1)]; + + println!("\nTesting valid case: x=2, y=3, z=6"); + let x = vec![F17::from(2)]; // public input x = 2 + let w = vec![F17::from(3), F17::from(6)]; // witness y = 3, z = 6 + assert!(ccs.is_satisfied(&w, &x)); + + println!("\nTesting invalid case: x=2, y=3, z=7"); + let w_invalid = vec![F17::from(3), F17::from(7)]; // witness y = 3, z = 7 (invalid) + assert!(!ccs.is_satisfied(&w_invalid, &x)); + } +} diff --git a/src/circuit/expression.rs b/src/circuit/expression.rs new file mode 100644 index 0000000..5d9474c --- /dev/null +++ b/src/circuit/expression.rs @@ -0,0 +1,147 @@ +//! Defines arithmetic expressions used in circuit construction. +//! +//! This module provides types for building and manipulating arithmetic expressions +//! over a field, supporting operations like addition, multiplication, and negation. + +use super::*; + +/// Variables used in arithmetic expressions. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum Variable { + /// Public input variable x_i + Public(usize), + /// Witness variable w_i + Witness(usize), + /// Auxiliary variable y_i + Aux(usize), + /// Output variable o_i + Output(usize), +} + +/// An arithmetic expression over a field F. +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +pub enum Expression { + /// A single variable + Variable(Variable), + /// A constant field element + Constant(F), + /// Sum of expressions + Add(Vec>), + /// Product of expressions + Mul(Vec>), +} + +impl std::ops::Add for Expression { + type Output = Self; + + /// Implements addition between expressions. + /// + /// Flattens nested additions to maintain a canonical form: + /// - `(a + b) + c` becomes `a + b + c` + /// - `a + (b + c)` becomes `a + b + c` + fn add(self, rhs: Self) -> Self::Output { + match (self, rhs) { + (Self::Add(mut v1), Self::Add(v2)) => { + v1.extend(v2); + Self::Add(v1) + }, + (Self::Add(mut v), rhs) => { + v.push(rhs); + Self::Add(v) + }, + (lhs, Self::Add(mut v)) => { + v.insert(0, lhs); + Self::Add(v) + }, + (lhs, rhs) => Self::Add(vec![lhs, rhs]), + } + } +} + +impl std::ops::Mul for Expression { + type Output = Self; + + /// Implements multiplication between expressions. + /// + /// Flattens nested multiplications to maintain a canonical form: + /// - `(a * b) * c` becomes `a * b * c` + /// - `a * (b * c)` becomes `a * b * c` + fn mul(self, rhs: Self) -> Self::Output { + match (self, rhs) { + (Self::Mul(mut v1), Self::Mul(v2)) => { + v1.extend(v2); + Self::Mul(v1) + }, + (Self::Mul(mut v), rhs) => { + v.push(rhs); + Self::Mul(v) + }, + (lhs, Self::Mul(mut v)) => { + v.insert(0, lhs); + Self::Mul(v) + }, + (lhs, rhs) => Self::Mul(vec![lhs, rhs]), + } + } +} + +impl std::ops::Neg for Expression { + type Output = Self; + + /// Implements negation by multiplying by -1. + fn neg(self) -> Self::Output { + // Negation is multiplication by -1 + Self::Mul(vec![Self::Constant(F::from(-1)), self]) + } +} + +// Implement subtraction +impl std::ops::Sub for Expression { + type Output = Self; + + /// Implements subtraction as addition with negation: a - b = a + (-b) + fn sub(self, rhs: Self) -> Self::Output { + // a - b is the same as a + (-b) + self + (-rhs) + } +} + +impl Display for Expression { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::Variable(var) => write!(f, "{var}"), + Self::Constant(c) => write!(f, "{c}"), + Self::Add(terms) => { + write!(f, "(")?; + for (i, term) in terms.iter().enumerate() { + if i > 0 { + write!(f, " + ")?; + } + write!(f, "{term}")?; + } + write!(f, ")") + }, + Self::Mul(factors) => { + write!(f, "(")?; + for (i, factor) in factors.iter().enumerate() { + if i > 0 { + write!(f, " * ")?; + } + write!(f, "{factor}")?; + } + write!(f, ")") + }, + } + } +} + +impl Display for Variable { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + match self { + Self::Public(i) => write!(f, "x_{i}"), + Self::Witness(j) => write!(f, "w_{j}"), + Self::Aux(k) => write!(f, "y_{k}"), + Self::Output(l) => write!(f, "o_{l}"), + } + } +} diff --git a/src/circuit/mod.rs b/src/circuit/mod.rs new file mode 100644 index 0000000..595df10 --- /dev/null +++ b/src/circuit/mod.rs @@ -0,0 +1,750 @@ +//! Circuit building and optimization for CCS. +//! +//! Provides a staged compilation pipeline: +//! 1. Building: Initial circuit construction +//! 2. DegreeConstrained: Circuit with enforced degree bounds +//! 3. Optimized: Circuit after optimization passes + +use super::*; + +use std::{collections::HashMap, marker::PhantomData}; + +use crate::{ccs::CCS, matrix::SparseMatrix}; + +pub mod expression; +#[cfg(test)] +mod tests; + +use self::expression::*; + +/// State marker for initial circuit construction. +#[derive(Debug)] +pub struct Building; + +/// State marker for degree-constrained circuit. +#[derive(Debug)] +pub struct DegreeConstrained; + +/// State marker for optimized circuit. +#[derive(Debug)] +pub struct Optimized; + +/// Circuit state trait, implemented by state markers. +pub trait CircuitState {} + +impl CircuitState for Building {} +impl CircuitState for DegreeConstrained {} +impl CircuitState for Optimized {} + +/// An arithmetic circuit with typed state transitions. +#[derive(Debug, Clone, Default)] +pub struct Circuit { + /// Number of public inputs + pub pub_inputs: usize, + /// Number of witness inputs + pub wit_inputs: usize, + /// Number of auxiliary variables + pub aux_count: usize, + /// Number of output variables + pub output_count: usize, + /// Circuit expressions and their assigned variables + expressions: Vec<(Expression, Variable)>, + /// Memoization cache for expressions + memo: HashMap, + /// State type marker + _marker: PhantomData, +} + +impl Circuit { + /// Creates a new empty circuit. + pub fn new() -> Self { + Self { + pub_inputs: 0, + wit_inputs: 0, + aux_count: 0, + output_count: 0, + expressions: Vec::new(), + memo: HashMap::new(), + _marker: PhantomData, + } + } + + /// Creates a public input variable x_i. + pub fn x(&mut self, i: usize) -> Expression { + assert!(i <= self.pub_inputs); + self.pub_inputs = self.pub_inputs.max(i + 1); + Expression::Variable(Variable::Public(i)) + } + + /// Creates a witness variable w_i. + pub fn w(&mut self, i: usize) -> Expression { + assert!(i <= self.wit_inputs); + self.wit_inputs = self.wit_inputs.max(i + 1); + Expression::Variable(Variable::Witness(i)) + } + + /// Creates a constant expression. + pub const fn constant(c: F) -> Expression { + Expression::Constant(c) + } + + /// Adds an internal auxiliary variable. + pub fn add_internal(&mut self, expr: Expression) -> Expression { + self.get_or_create_aux(&expr) + } + + /// Marks an expression as a circuit output. + pub fn mark_output(&mut self, expr: Expression) -> Expression { + if let Expression::Variable(Variable::Aux(aux_idx)) = expr { + // Find and convert the specific auxiliary variable we want to change + for (_, var) in &mut self.expressions { + if *var == Variable::Aux(aux_idx) { + *var = Variable::Output(self.output_count); + break; // Found and converted the variable + } + } + let output_idx = self.output_count; + self.output_count += 1; + self.aux_count -= 1; // Decrease aux count since we converted one + Expression::Variable(Variable::Output(output_idx)) + } else { + // For other expressions, create a new output variable + let output_idx = self.output_count; + let var = Variable::Output(output_idx); + self.output_count += 1; + self.expressions.push((expr, var)); + Expression::Variable(var) + } + } + + // TODO: Remove clone + /// Transitions circuit to degree-constrained state. + pub fn fix_degree(mut self) -> Circuit, F> { + // First, collect all expressions we need to process + let expressions_to_process: Vec<_> = self.expressions.clone(); + + // Clear existing expressions since we'll rebuild them + self.expressions.clear(); + + // Process non-output expressions first + for (expr, var) in &expressions_to_process { + if let Variable::Output(_) = var { + continue; + } + let reduced = self.reduce_degree(expr.clone(), D); + self.expressions.push((reduced, *var)); + } + + // Now handle output expressions + for (expr, var) in &expressions_to_process { + if let Variable::Output(_) = var { + let reduced = self.reduce_degree(expr.clone(), D); + self.expressions.push((reduced, *var)); + } + } + + Circuit { + pub_inputs: self.pub_inputs, + wit_inputs: self.wit_inputs, + aux_count: self.aux_count, + output_count: self.output_count, + expressions: self.expressions, + memo: self.memo, + _marker: PhantomData, + } + } + + /// Creates a new auxiliary variable and increments the counter. + const fn new_aux(&mut self) -> Variable { + let var = Variable::Aux(self.aux_count); + self.aux_count += 1; + var + } + + /// Returns existing auxiliary variable for expression or creates new one. + /// + /// Used for memoization/common subexpression elimination. + fn get_or_create_aux(&mut self, expr: &Expression) -> Expression { + // Create a string representation of the expression for memoization + let expr_key = format!("{expr}"); + + if let Some(&var) = self.memo.get(&expr_key) { + // We've seen this expression before, reuse the existing variable + Expression::Variable(var) + } else { + // First time seeing this expression, create new auxiliary variable + let var = self.new_aux(); + self.expressions.push((expr.clone(), var)); + self.memo.insert(expr_key, var); + Expression::Variable(var) + } + } + + /// Reduces expression degree through auxiliary variable introduction. + /// + /// # Arguments + /// * `expr` - Expression to reduce + /// * `d` - Target degree bound + fn reduce_degree(&mut self, expr: Expression, d: usize) -> Expression { + let current_degree = compute_degree(&expr); + if current_degree <= d { + return expr; + } + + match expr { + Expression::Mul(factors) => { + let mut current_group = Vec::new(); + let mut current_group_degree = 0; + let mut reduced_factors = Vec::new(); + + for factor in factors { + let factor_degree = compute_degree(&factor); + + if current_group_degree + factor_degree > d && !current_group.is_empty() { + let group_expr = if current_group.len() == 1 { + current_group.pop().unwrap() + } else { + Expression::Mul(std::mem::take(&mut current_group)) + }; + reduced_factors.push(self.get_or_create_aux(&group_expr)); + current_group_degree = 0; + } + + let reduced_factor = self.reduce_degree(factor, d); + let reduced_factor_degree = compute_degree(&reduced_factor); + + current_group.push(reduced_factor); + current_group_degree += reduced_factor_degree; + } + + if !current_group.is_empty() { + let group_expr = if current_group.len() == 1 { + current_group.pop().unwrap() + } else { + Expression::Mul(current_group) + }; + reduced_factors.push(self.get_or_create_aux(&group_expr)); + } + + if reduced_factors.len() > 1 { + self.reduce_degree(Expression::Mul(reduced_factors), d) + } else { + reduced_factors.pop().unwrap() + } + }, + Expression::Add(terms) => { + let reduced_terms: Vec<_> = + terms.into_iter().map(|term| self.reduce_degree(term, d)).collect(); + Expression::Add(reduced_terms) + }, + _ => expr, + } + } +} + +impl Circuit, F> { + /// Converts circuit to CCS format. + pub fn into_ccs(self) -> CCS { + let mut ccs = CCS::new_degree(D); + + // Calculate dimensions + let num_cols = 1 + self.pub_inputs + self.wit_inputs + self.aux_count + self.output_count; + + // Initialize matrices + for matrix in &mut ccs.matrices { + *matrix = SparseMatrix::new_rows_cols(num_cols, num_cols); + } + + // Process expressions with a more generic approach + for (expr, var) in &self.expressions { + let row = self.get_z_position(var); + self.create_constraint(&mut ccs, D, row, expr, var); + } + + ccs + } + + /// Creates constraint for an expression at specified row. + /// + /// # Arguments + /// * `ccs` - Target CCS + /// * `d` - Maximum degree + /// * `row` - Row index for constraint + /// * `expr` - Expression to constrain + /// * `output` - Output variable + fn create_constraint( + &self, + ccs: &mut CCS, + d: usize, + row: usize, + expr: &Expression, + output: &Variable, + ) { + // Write -1 times the output variable to the last matrix + let output_pos = self.get_z_position(output); + ccs.matrices.last_mut().unwrap().write(row, output_pos, -F::ONE); + + match expr { + Expression::Add(terms) => { + for term in terms { + self.process_term(ccs, d, row, term); + } + }, + _ => self.process_term(ccs, d, row, expr), + } + } + + /// Processes term in constraint creation. + fn process_term(&self, ccs: &mut CCS, d: usize, row: usize, term: &Expression) { + // First, fully expand the expression + let expanded = expand_expression(term); + + match expanded { + Expression::Add(terms) => { + // Process each term in the addition + for term in terms { + self.process_simple_term(ccs, d, row, &term); + } + }, + _ => self.process_simple_term(ccs, d, row, &expanded), + } + } + + /// Processes a simple (non-compound) term. + fn process_simple_term(&self, ccs: &mut CCS, d: usize, row: usize, term: &Expression) { + match term { + Expression::Mul(factors) => { + // Collect constants and variables + let mut coefficient = F::ONE; + let mut var_factors: Vec<_> = Vec::new(); + + for factor in factors { + match factor { + Expression::Constant(c) => coefficient *= *c, + Expression::Variable(_) => var_factors.push(factor), + _ => panic!("Unexpected non-simple factor after expansion"), + } + } + + let degree = var_factors.len(); + assert!(degree <= d, "Term degree exceeds maximum"); + + if degree == 0 { + // Pure constant term goes in last matrix + ccs.matrices.last_mut().unwrap().write(row, 0, coefficient); + } else { + // Calculate starting matrix index based on variable factors only + let start_idx = if degree == d { 0 } else { (degree + 1..=d).sum() }; + + // Write variable factors with coefficient on first one + for (i, factor) in var_factors.iter().enumerate() { + let pos = self.get_variable_position(factor); + if i == 0 { + ccs.matrices[start_idx + i].write(row, pos, coefficient); + } else { + ccs.matrices[start_idx + i].write(row, pos, F::ONE); + } + } + } + }, + Expression::Variable(_) => { + // Single variable goes in last matrix + let pos = self.get_variable_position(term); + ccs.matrices.last_mut().unwrap().write(row, pos, F::ONE); + }, + Expression::Constant(c) => { + // Single constant goes in last matrix + ccs.matrices.last_mut().unwrap().write(row, 0, *c); + }, + Expression::Add(_) => panic!("Unexpected complex term after expansion"), + } + } + + /// Optimizes the circuit by eliminating auxiliary variables less than degree `D` + pub fn optimize(self) -> Circuit, F> { + println!("\nStarting optimization process..."); + + // Create a new building circuit + let mut new_circuit = Circuit::::new(); + new_circuit.pub_inputs = self.pub_inputs; + new_circuit.wit_inputs = self.wit_inputs; + + println!("\nInitial definitions:"); + let definitions: HashMap> = self + .expressions + .iter() + .map(|(expr, var)| { + println!("{} := {} (degree {})", var, expr, compute_degree(expr)); + (*var, expr.clone()) + }) + .collect(); + + // Map to track how we process auxiliary variables + let mut aux_map = HashMap::new(); + + fn process_expr( + expr: &Expression, + definitions: &HashMap>, + aux_map: &mut HashMap>, + new_circuit: &mut Circuit, + depth: usize, // Add depth parameter for indentation + ) -> Expression { + let indent = " ".repeat(depth); + + println!("{}Processing expression: {}", indent, expr); + + match expr { + Expression::Variable(var @ Variable::Aux(_)) => { + println!("{}Found auxiliary variable: {}", indent, var); + + // Check if we've processed this before + if let Some(mapped_expr) = aux_map.get(var) { + println!("{}Already processed, reusing: {}", indent, mapped_expr); + return mapped_expr.clone(); + } + + // Get its definition + if let Some(def) = definitions.get(var) { + let degree = compute_degree(def); + println!("{}Definition has degree {}: {}", indent, degree, def); + + if degree == D { + println!("{}Creating new aux var for degree {} expression", indent, D); + let new_expr = + process_expr::(def, definitions, aux_map, new_circuit, depth + 1); + let result = new_circuit.add_internal(new_expr); + println!("{}Created new aux var: {}", indent, result); + aux_map.insert(*var, result.clone()); + result + } else { + println!("{}Using definition directly (degree < {})", indent, D); + process_expr::(def, definitions, aux_map, new_circuit, depth + 1) + } + } else { + println!("{}No definition found, using as is", indent); + expr.clone() + } + }, + Expression::Add(terms) => { + println!("{}Processing addition with {} terms", indent, terms.len()); + let processed = Expression::Add( + terms + .iter() + .map(|term| { + let result = + process_expr::(term, definitions, aux_map, new_circuit, depth + 1); + println!("{}Processed term: {} -> {}", indent, term, result); + result + }) + .collect(), + ); + println!("{}Addition result: {}", indent, processed); + processed + }, + Expression::Mul(factors) => { + println!("{}Processing multiplication with {} factors", indent, factors.len()); + let processed = Expression::Mul( + factors + .iter() + .map(|factor| { + let result = + process_expr::(factor, definitions, aux_map, new_circuit, depth + 1); + println!("{}Processed factor: {} -> {}", indent, factor, result); + result + }) + .collect(), + ); + println!("{}Multiplication result: {}", indent, processed); + processed + }, + _ => { + println!("{}Base case: {}", indent, expr); + expr.clone() + }, + } + } + + println!("\nProcessing output expressions:"); + let mut output_exprs = Vec::new(); + for (expr, var) in self.expressions { + if let Variable::Output(_) = var { + println!("\nProcessing output {}", var); + let new_expr = process_expr::(&expr, &definitions, &mut aux_map, &mut new_circuit, 1); + println!("Output {} result: {}", var, new_expr); + output_exprs.push((var, new_expr)); + } + } + + println!("\nMarking outputs in new circuit:"); + for (var, expr) in output_exprs { + println!("Marking output {} := {}", var, expr); + new_circuit.mark_output(expr); + } + + println!("\nFinal new circuit state:"); + for (expr, var) in &new_circuit.expressions { + println!("{} := {} (degree {})", var, expr, compute_degree(expr)); + } + + // Convert to optimized circuit + Circuit { + pub_inputs: new_circuit.pub_inputs, + wit_inputs: new_circuit.wit_inputs, + aux_count: new_circuit.aux_count, + output_count: new_circuit.output_count, + expressions: new_circuit.expressions, + memo: new_circuit.memo, + _marker: PhantomData, + } + } +} + +/// Expands expressions by distributing multiplication over addition. +fn expand_expression(expr: &Expression) -> Expression { + match expr { + Expression::Mul(factors) => { + // First expand each factor + let expanded_factors: Vec<_> = factors.iter().map(|f| expand_expression(f)).collect(); + + // If any factor is an addition, we need to distribute + let mut result = expanded_factors[0].clone(); + for factor in expanded_factors.iter().skip(1) { + result = multiply_expressions(&result, factor); + } + result + }, + Expression::Add(terms) => { + // Expand each term and combine + let expanded_terms: Vec<_> = terms.iter().map(|t| expand_expression(t)).collect(); + Expression::Add(expanded_terms) + }, + // Variables and constants stay as they are + _ => expr.clone(), + } +} + +/// Multiplies two expressions with distribution. +fn multiply_expressions(a: &Expression, b: &Expression) -> Expression { + match (a, b) { + (Expression::Add(terms_a), _) => { + // Distribute multiplication over addition + let distributed: Vec<_> = terms_a.iter().map(|term| multiply_expressions(term, b)).collect(); + Expression::Add(distributed) + }, + (_, Expression::Add(terms_b)) => { + // Distribute multiplication over addition + let distributed: Vec<_> = terms_b.iter().map(|term| multiply_expressions(a, term)).collect(); + Expression::Add(distributed) + }, + (Expression::Mul(factors_a), Expression::Mul(factors_b)) => { + // Combine the factors + let mut new_factors = factors_a.clone(); + new_factors.extend(factors_b.clone()); + Expression::Mul(new_factors) + }, + (Expression::Mul(factors), b) | (b, Expression::Mul(factors)) => { + // Add the new factor to the existing ones + let mut new_factors = factors.clone(); + new_factors.push(b.clone()); + Expression::Mul(new_factors) + }, + (a, b) => Expression::Mul(vec![a.clone(), b.clone()]), + } +} + +impl Circuit, F> { + /// Converts and `Optimized` circuit into CCS. + pub fn into_ccs(self) -> CCS { + let mut ccs = CCS::new_degree(D); + + // Calculate dimensions + let num_cols = 1 + self.pub_inputs + self.wit_inputs + self.aux_count + self.output_count; + + // Initialize matrices + for matrix in &mut ccs.matrices { + *matrix = SparseMatrix::new_rows_cols(num_cols, num_cols); + } + + // Process expressions with a more generic approach + for (expr, var) in &self.expressions { + let row = self.get_z_position(var); + self.create_constraint(&mut ccs, D, row, expr, var); + } + + ccs + } + + /// Creates a constraint in the constraint system + fn create_constraint( + &self, + ccs: &mut CCS, + d: usize, + row: usize, + expr: &Expression, + output: &Variable, + ) { + // Write -1 times the output variable to the last matrix + let output_pos = self.get_z_position(output); + ccs.matrices.last_mut().unwrap().write(row, output_pos, -F::ONE); + + match expr { + Expression::Add(terms) => { + for term in terms { + self.process_term(ccs, d, row, term); + } + }, + _ => self.process_term(ccs, d, row, expr), + } + } + + /// Processes term in constraint creation. + fn process_term(&self, ccs: &mut CCS, d: usize, row: usize, term: &Expression) { + match term { + Expression::Mul(factors) => { + let degree = factors.len(); + assert!(degree <= d, "Term degree exceeds maximum"); + + // For each factor, we need to process it recursively + let mut processed_factors = Vec::new(); + for factor in factors { + match factor { + Expression::Variable(_) | Expression::Constant(_) => { + processed_factors.push(factor.clone()); + }, + Expression::Mul(inner_factors) => { + // If a factor is itself a multiplication, we need to merge it + processed_factors.extend(inner_factors.iter().cloned()); + }, + Expression::Add(terms) => { + // If a factor is an addition, we need to distribute + // This is a more complex case that we might want to handle differently + // For now, we'll just collect all terms + for term in terms { + self.process_term(ccs, d, row, term); + } + return; + }, + } + } + + // Calculate starting matrix index based on degree + let start_idx = match processed_factors.len() { + n if n == D => 0, // Highest degree terms start at 0 + n => { + // For degree k, start after all matrices used by higher degrees + // For degree 3: 0 + // For degree 2: 3 + // For degree 1: 5 + let mut offset = 0; + for deg in (n + 1)..=D { + offset += deg; + } + offset + }, + }; + for (i, factor) in processed_factors.iter().enumerate() { + let pos = self.get_variable_position(factor); + ccs.matrices[start_idx + i].write(row, pos, F::ONE); + } + }, + Expression::Add(terms) => { + // Process each term independently + for term in terms { + self.process_term(ccs, d, row, term); + } + }, + _ => { + // Base case: single variable or constant + let pos = self.get_variable_position(term); + ccs.matrices.last_mut().unwrap().write(row, pos, F::ONE); + }, + } + } +} + +impl Circuit { + /// Returns circuit expressions. + pub fn expressions(&self) -> &[(Expression, Variable)] { + &self.expressions + } + + // TODO: Should this really only be some kind of `#[cfg(test)]` fn? + /// Expands an expression by substituting definitions. + pub fn expand(&self, expr: &Expression) -> Expression { + match expr { + // Base cases: constants and input variables remain unchanged + Expression::Constant(_) + | Expression::Variable(Variable::Public(_) | Variable::Witness(_)) => expr.clone(), + + // For auxiliary and output variables, look up their definition + Expression::Variable(var @ (Variable::Aux(_) | Variable::Output(_))) => { + self.get_definition(var).map_or_else(|| expr.clone(), |definition| self.expand(definition)) + }, + + Expression::Add(terms) => { + Expression::Add(terms.iter().map(|term| self.expand(term)).collect()) + }, + Expression::Mul(factors) => { + Expression::Mul(factors.iter().map(|factor| self.expand(factor)).collect()) + }, + } + } + + /// Gets definition for a variable if it exists. + fn get_definition(&self, var: &Variable) -> Option<&Expression> { + match var { + Variable::Aux(idx) | Variable::Output(idx) => { + self.expressions.get(*idx).map(|(expr, _)| expr) + }, + _ => None, + } + } + + /// Gets position of variable in z vector. + fn get_z_position(&self, var: &Variable) -> usize { + match var { + // Public inputs start at position 1 + Variable::Public(i) => 1 + i, + // Witness variables follow public inputs + Variable::Witness(i) => 1 + self.pub_inputs + i, + // Auxiliary variables follow witness variables + Variable::Aux(i) => 1 + self.pub_inputs + self.wit_inputs + i, + // Output variables follow auxiliary variables + Variable::Output(i) => 1 + self.pub_inputs + self.wit_inputs + self.aux_count + i, + } + } + + /// Gets position of expression in z vector. + fn get_variable_position(&self, expr: &Expression) -> usize { + match expr { + Expression::Variable(var) => self.get_z_position(var), + Expression::Constant(_) => 0, + Expression::Mul(factors) => { + // For a product, we need to handle each factor + assert!(factors.len() == 1, "Expected simplified multiplication"); + self.get_variable_position(&factors[0]) + }, + Expression::Add(terms) => { + // For a sum, we need to handle each term + assert!(terms.len() == 1, "Expected simplified addition"); + self.get_variable_position(&terms[0]) + }, + } + } +} + +/// Computes the degree of an expression. +fn compute_degree(expr: &Expression) -> usize { + match expr { + // Constants are degree 0 + Expression::Constant(_) => 0, + // Base cases: variables degree 1 + Expression::Variable(_) => 1, + + // For addition, take the maximum degree of any term + Expression::Add(terms) => terms.iter().map(|term| compute_degree(term)).max().unwrap_or(0), + + // For multiplication, sum the degrees of all factors + Expression::Mul(factors) => factors.iter().map(|factor| compute_degree(factor)).sum(), + } +} diff --git a/src/circuit/tests.rs b/src/circuit/tests.rs new file mode 100644 index 0000000..5e5ab7e --- /dev/null +++ b/src/circuit/tests.rs @@ -0,0 +1,338 @@ +// TODO: all these tests really need to check things more strictly +use super::*; + +#[test] +#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] +fn test_compute_degree_base_cases() { + // Constants and variables should have degree 1 + let constant = Expression::Constant(F17::from(5)); + assert_eq!(compute_degree(&constant), 0, "Constants should have degree 0"); + + let public = Expression::::Variable(Variable::Public(0)); + assert_eq!(compute_degree(&public), 1, "Public variables should have degree 1"); + + let witness = Expression::::Variable(Variable::Witness(0)); + assert_eq!(compute_degree(&witness), 1, "Witness variables should have degree 1"); + + let aux = Expression::::Variable(Variable::Aux(0)); + assert_eq!(compute_degree(&aux), 1, "Auxiliary variables should have degree 1"); +} + +#[test] +#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] +fn test_compute_degree_addition() { + // Addition should take the maximum degree of its terms + let x = Expression::::Variable(Variable::Public(0)); + let y = Expression::Variable(Variable::Witness(0)); + + // Simple addition: x + y (degree 1) + let simple_add = Expression::Add(vec![x.clone(), y.clone()]); + assert_eq!(compute_degree(&simple_add), 1, "x + y should have degree 1"); + + // x + (x * y) (degree 2) + let mul = Expression::Mul(vec![x.clone(), y.clone()]); + let mixed_add = Expression::Add(vec![x.clone(), mul.clone()]); + assert_eq!(compute_degree(&mixed_add), 2, "x + (x * y) should have degree 2"); + + // x + (x * y) + (x * y * y) (degree 3) + let triple_mul = Expression::Mul(vec![x.clone(), y.clone(), y.clone()]); + let complex_add = Expression::Add(vec![x, mul, triple_mul]); + assert_eq!(compute_degree(&complex_add), 3, "x + (x * y) + (x * y * y) should have degree 3"); +} + +#[test] +#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] +fn test_compute_degree_multiplication() { + // Multiplication should sum the degrees of its factors + let x = Expression::::Variable(Variable::Public(0)); + let y = Expression::Variable(Variable::Witness(0)); + + // Simple multiplication: x * y (degree 2) + let simple_mul = Expression::Mul(vec![x.clone(), y.clone()]); + assert_eq!(compute_degree(&simple_mul), 2, "x * y should have degree 2"); + + // x * y * y (degree 3) + let triple_mul = Expression::Mul(vec![x.clone(), y.clone(), y.clone()]); + assert_eq!(compute_degree(&triple_mul), 3, "x * y * y should have degree 3"); + + // (x * y) * (y * y) (degree 4) + let double_mul = Expression::Mul(vec![y.clone(), y.clone()]); + let nested_mul = Expression::Mul(vec![simple_mul, double_mul]); + assert_eq!(compute_degree(&nested_mul), 4, "(x * y) * (y * y) should have degree 4"); +} + +#[test] +#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] +fn test_compute_degree_complex_expressions() { + let x = Expression::::Variable(Variable::Public(0)); + let y = Expression::Variable(Variable::Witness(0)); + + // Build (x * y * y) + (x + y) + let triple_mul = Expression::Mul(vec![x.clone(), y.clone(), y.clone()]); + let simple_add = Expression::Add(vec![x.clone(), y.clone()]); + let complex = Expression::Add(vec![triple_mul, simple_add]); + assert_eq!(compute_degree(&complex), 3, "(x * y * y) + (x + y) should have degree 3"); + + // Build ((x * y) + y) * (x * x) + let mul_add = Expression::Add(vec![Expression::Mul(vec![x.clone(), y.clone()]), y.clone()]); + let square = Expression::Mul(vec![x.clone(), x.clone()]); + let complex_mul = Expression::Mul(vec![mul_add, square]); + assert_eq!(compute_degree(&complex_mul), 4, "((x * y) + y) * (x * x) should have degree 4"); +} + +#[test] +#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] +fn test_expression_arithmetic() { + let mut builder = Circuit::new(); + + // Create base values + let x0 = builder.x(0); + let x1 = builder.x(1); + let w0 = builder.w(0); + let three = Circuit::constant(F17::from(3)); + + // Test negation: -x0 + let neg_x0 = -x0; + let y0 = builder.add_internal(neg_x0); + + // Test subtraction: w0 - x1 + let sub_expr = w0 - x1; + let y1 = builder.add_internal(sub_expr); + + // Test complex expression: 3 * (w0 - x1) - (-x0) + let complex_expr = three * y1 - y0; + let y2 = builder.add_internal(complex_expr); + builder.mark_output(y2); + + println!("\nOriginal expressions:"); + for (expr, var) in builder.expressions() { + if let Variable::Aux(idx) = var { + println!("y_{} := {}", idx, expr); + } + } + + println!("\nExpanded forms:"); + for (expr, var) in builder.expressions() { + match var { + Variable::Aux(idx) => println!("Auxiliary y_{} := {}", idx, builder.expand(expr)), + Variable::Output(idx) => println!("Output o_{} := {}", idx, builder.expand(expr)), + _ => println!("Other {} := {}", var, builder.expand(expr)), + } + } +} + +#[test] +#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] +fn test_multiple_outputs() { + let mut builder = Circuit::<_, F17>::new(); + + // Let's create a circuit that computes several outputs + let x = builder.x(0); // Public input x + let y = builder.w(0); // Witness y + let z = builder.w(1); // Witness z + + // First output: x * y + let mul1 = x.clone() * y.clone(); + let o1 = builder.mark_output(mul1); // This should become o_0 + + // Second output: y * z + let mul2 = y * z; + let _o2 = builder.mark_output(mul2); // This should become o_1 + + // Third output: x * o1 (using a previous output) + let mul3 = x * o1; + let _o3 = builder.mark_output(mul3); // This should become o_2 + + println!("\nMultiple outputs test:"); + for (expr, var) in builder.expressions() { + match var { + Variable::Output(idx) => println!("Output o_{} := {}", idx, builder.expand(expr)), + _ => println!("Other {} := {}", var, builder.expand(expr)), + } + } + + // Verify we have the right number of outputs + assert_eq!(builder.output_count, 3); + assert_eq!(builder.aux_count, 0); // No auxiliary variables needed +} + +#[test] +#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] +fn test_aux_to_output_conversion() { + let mut builder = Circuit::<_, F17>::new(); + + // Create a more complex computation that needs auxiliary variables + let x = builder.x(0); + let _y = builder.w(0); + + // Create some intermediate computations + let square = x.clone() * x.clone(); + let aux1 = builder.add_internal(square); // y_0 + + let cube = aux1.clone() * x; + let aux2 = builder.add_internal(cube); // y_1 + + // Now convert both to outputs + builder.mark_output(aux1); // Should become o_0 + builder.mark_output(aux2); // Should become o_1 + + println!("\nAux to output conversion test:"); + for (expr, var) in builder.expressions() { + match var { + Variable::Output(idx) => println!("Output o_{} := {}", idx, builder.expand(expr)), + Variable::Aux(idx) => println!("Aux y_{} := {}", idx, builder.expand(expr)), + _ => println!("Other {} := {}", var, builder.expand(expr)), + } + } + + // Verify our counts + assert_eq!(builder.output_count, 2); + assert_eq!(builder.aux_count, 0); // Both aux vars were converted +} + +#[test] +#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] +fn test_mixed_aux_and_output() { + let mut builder = Circuit::<_, F17>::new(); + + let x = builder.x(0); + let y = builder.w(0); + + // Create an auxiliary computation we'll keep as auxiliary + let square = x.clone() * x.clone(); + let aux1 = builder.add_internal(square); // y_0 + + // Create an output directly + let direct_output = y.clone() * y; + let _o1 = builder.mark_output(direct_output); // o_0 + + // Create another auxiliary and convert it + let cube = aux1 * x; + let aux2 = builder.add_internal(cube); // y_1 + builder.mark_output(aux2); // Converts to o_1 + + println!("\nMixed aux and output test:"); + for (expr, var) in builder.expressions() { + match var { + Variable::Output(idx) => println!("Output o_{} := {}", idx, builder.expand(expr)), + Variable::Aux(idx) => println!("Aux y_{} := {}", idx, builder.expand(expr)), + _ => println!("Other {} := {}", var, builder.expand(expr)), + } + } + + // Verify final state + assert_eq!(builder.output_count, 2); + assert_eq!(builder.aux_count, 1); // aux1 remains as auxiliary +} + +#[test] +#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] +fn test_reduce_degree() { + let mut builder = Circuit::<_, F17>::new(); + + // Create expression: x0 * x1 * x2 * x3 + let x0 = builder.x(0); + let x1 = builder.x(1); + let x2 = builder.x(2); + let x3 = builder.x(3); + let expr = x0 * x1 * x2 * x3; // degree 4 + + // Reduce to degree 2 + let reduced = builder.reduce_degree(expr, 2); + + println!("Reduced expression: {reduced}"); + println!("\nAuxiliary variables:"); + for (expr, var) in builder.expressions() { + println!("{var} := {expr}"); + } +} + +#[test] +#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] +fn test_complex_degree_reduction() { + let mut builder = Circuit::<_, F17>::new(); + + // Create inputs and build expressions (same as before) + let x = builder.x(0); + let y = builder.w(0); + + let x_cubed = x.clone() * x.clone() * x.clone(); + let y_squared = y.clone() * y.clone(); + let common_term = x_cubed + y_squared.clone(); + + // Build P1: (x^3 + y^2)^2 * (x + y) + let common_term_squared = common_term.clone() * common_term.clone(); + let x_plus_y = x.clone() + y; + let p1 = common_term_squared.clone() * x_plus_y; + + // Build P2: x * y^4 + (x^3 + y^2)^3 + let y_fourth = y_squared.clone() * y_squared.clone(); + let term1 = x.clone() * y_fourth; + let common_term_cubed = common_term * common_term_squared; + let p2 = term1 + common_term_cubed; + + // Print original expressions before reduction + println!("\nOriginal P1: {}", p1); + println!("Original P2: {}", p2); + + // Mark outputs + builder.mark_output(p1); + builder.mark_output(p2); + + println!("\nOriginal circuit state:"); + for (expr, var) in builder.expressions() { + match var { + Variable::Aux(idx) => println!("y_{} := {}", idx, expr), + Variable::Output(idx) => println!("o_{} := {}", idx, expr), + _ => println!("{} := {}", var, expr), + } + } + + // Now fix the degree + let deg_3_circuit = builder.fix_degree::<3>(); + + // Verify degrees after fixing + println!("\nDegree-constrained expressions:"); + for (expr, var) in deg_3_circuit.expressions() { + let degree = compute_degree(expr); + println!("{} := {} (degree {})", var, expr, degree); + assert!(degree <= 3, "Expression {} exceeds degree bound", var); + } + + let optimized_circuit = deg_3_circuit.optimize(); + let ccs = optimized_circuit.into_ccs(); + println!("\nFinal CCS:\n{}", ccs); +} + +// TODO: This test can show that if we run the optimzer, we may unjustifiably kill off constraints. +// So we need to rethink what optimization means in this case. +#[test] +#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] +fn test_raw_low_degree_constraint_not_removed() { + let mut builder = Circuit::<_, F17>::new(); + + // Create inputs and build expressions (same as before) + let x = builder.x(0); + let y = builder.w(0); + + // Enforce x is a bool + let bool = x.clone() * (Expression::Constant(F17::from(1)) - x.clone()); + let toggle = x * y + Expression::Constant(F17::ONE); + + builder.add_internal(bool); + builder.add_internal(toggle); + + let fixed = builder.fix_degree::<3>(); + // Verify degrees after fixing + + println!("\nDegree-constrained expressions:"); + for (expr, var) in fixed.expressions() { + let degree = compute_degree(expr); + println!("{} := {} (degree {})", var, expr, degree); + assert!(degree <= 3, "Expression {} exceeds degree bound", var); + } + + let ccs = fixed.into_ccs(); + + println!("CCS: {ccs}"); +} diff --git a/src/lib.rs b/src/lib.rs index c0123f1..d0afb6f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,19 +1,41 @@ #![doc = include_str!("../README.md")] -#![warn(missing_docs)] +#![warn(missing_docs, clippy::missing_docs_in_private_items)] +//! Custom Constraints provides an implementation of Customizable Constraint Systems (CCS), +//! a framework for zero-knowledge proof systems. +//! +//! This crate provides tools for: +//! - Building arithmetic circuits with degree bounds +//! - Converting circuits to CCS form +//! - Optimizing circuit representations +//! - Working with sparse matrix operations +//! +//! The core components are: +//! - [`Circuit`](circuit::Circuit): For constructing and manipulating arithmetic circuits +//! - [`CCS`](ccs::CCS): The customizable constraint system representation +//! - [`SparseMatrix`](matrix::SparseMatrix): Efficient sparse matrix operations + +use ark_ff::Field; +#[cfg(test)] +use mock::F17; +use std::fmt::{self, Display, Formatter}; #[cfg(all(target_arch = "wasm32", test))] use wasm_bindgen_test::wasm_bindgen_test; -#[cfg(test)] use {mock::F17, rstest::rstest}; +pub mod ccs; +pub mod circuit; pub mod matrix; #[cfg(test)] mod mock { + //! Test utilities including a simple finite field implementation. use ark_ff::{Fp, MontBackend, MontConfig}; + #[allow(unexpected_cfgs)] #[derive(MontConfig)] #[modulus = "17"] #[generator = "3"] pub struct F17Config; + /// A finite field of order 17 used for testing. pub type F17 = Fp, 1>; } diff --git a/src/matrix.rs b/src/matrix.rs index be996ad..16333b8 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -1,15 +1,29 @@ -use ark_ff::Field; +//! Provides a Compressed Sparse Row (CSR) matrix implementation optimized for efficient operations. +//! +//! The [`SparseMatrix`] type is designed to handle sparse matrices efficiently by storing only +//! non-zero elements in a compressed format. It supports matrix-vector multiplication and +//! element-wise (Hadamard) matrix multiplication. + +use std::ops::Mul; use super::*; +// TODO: Probably just combine values with their col indices +/// A sparse matrix implementation using the Compressed Sparse Row (CSR) format. +#[derive(Clone, Debug, Default, PartialEq, Eq)] pub struct SparseMatrix { + /// Offsets into col_indices/values for the start of each row row_offsets: Vec, + /// Column indices of non-zero elements col_indices: Vec, - values: Vec, - num_cols: usize, + /// Values of non-zero elements + values: Vec, + /// Number of columns in the matrix + num_cols: usize, } impl SparseMatrix { + /// Creates a new sparse matrix from its CSR components. pub fn new( row_offsets: Vec, col_indices: Vec, @@ -20,10 +34,81 @@ impl SparseMatrix { Self { row_offsets, col_indices, values, num_cols } } - pub fn mul_vector(&self, rhs: &[F]) -> Vec { + /// Creates an empty sparse matrix with the specified dimensions. + pub fn new_rows_cols(num_rows: usize, num_cols: usize) -> SparseMatrix { + Self { row_offsets: vec![0; num_rows + 1], col_indices: vec![], values: vec![], num_cols } + } + + /// Writes a value to the specified position in the matrix. + /// + /// # Panics + /// - If row or column indices are out of bounds + /// - If attempting to write a zero value + pub fn write(&mut self, row: usize, col: usize, val: F) { + // Check bounds + assert!(row < self.row_offsets.len() - 1, "Row index out of bounds"); + assert!(col < self.num_cols, "Column index out of bounds"); + assert_ne!(val, F::ZERO, "Trying to add a zero element into the `SparseMatrix`!"); + + // Get the range of indices for the current row + let start = self.row_offsets[row]; + let end = self.row_offsets[row + 1]; + + // Search for the column index in the current row + let pos = + self.col_indices[start..end].binary_search(&col).map_or_else(|i| start + i, |i| start + i); + + if pos < end && self.col_indices[pos] == col { + // Overwrite existing value + self.values[pos] = val; + } else { + // Insert new value + self.col_indices.insert(pos, col); + self.values.insert(pos, val); + + // Update row offsets for subsequent rows + for i in row + 1..self.row_offsets.len() { + self.row_offsets[i] += 1; + } + } + } + + #[allow(unused)] + /// Removes an entry from the [`SparseMatrix`] + fn remove(&mut self, row: usize, col: usize) { + // Get the range of indices for the current row + let start = self.row_offsets[row]; + let end = self.row_offsets[row + 1]; + + // Search for the column index in the current row + if let Ok(pos) = self.col_indices[start..end].binary_search(&col) { + let pos = start + pos; + + // Remove the element + self.col_indices.remove(pos); + self.values.remove(pos); + + // Update row offsets for subsequent rows + for i in row + 1..self.row_offsets.len() { + self.row_offsets[i] -= 1; + } + } + } +} + +impl Mul<&Vec> for &SparseMatrix { + type Output = Vec; + + /// Performs matrix-vector multiplication. + /// + /// # Panics + /// If the vector length doesn't match the matrix column count. + fn mul(self, rhs: &Vec) -> Self::Output { + // TODO: Make error assert_eq!(rhs.len(), self.num_cols, "Invalid vector length"); let mut result = vec![F::ZERO; self.row_offsets.len() - 1]; + #[allow(clippy::needless_range_loop)] for row in 0..self.row_offsets.len() - 1 { let start = self.row_offsets[row]; let end = self.row_offsets[row + 1]; @@ -39,20 +124,123 @@ impl SparseMatrix { } } +impl Mul<&SparseMatrix> for &SparseMatrix { + type Output = SparseMatrix; + + /// Performs element-wise (Hadamard) matrix multiplication. + /// + /// # Panics + /// If matrix dimensions don't match. + fn mul(self, rhs: &SparseMatrix) -> Self::Output { + // We'll implement elementwise multiplication but first check dimensions match + assert_eq!(self.num_cols, rhs.num_cols, "Matrices must have same dimensions"); + + // For the Hadamard product, we'll only have non-zero elements where both matrices + // have non-zero elements at the same position + let mut result_values = Vec::new(); + let mut result_col_indices = Vec::new(); + let mut result_row_offsets = vec![0]; + + // Process each row + for row in 0..self.row_offsets.len() - 1 { + // Get the ranges for non-zero elements in this row for both matrices + let self_start = self.row_offsets[row]; + let self_end = self.row_offsets[row + 1]; + let rhs_start = rhs.row_offsets[row]; + let rhs_end = rhs.row_offsets[row + 1]; + + // Create iterators over the non-zero elements in this row + let mut self_iter = (self_start..self_end).map(|i| (self.col_indices[i], self.values[i])); + let mut rhs_iter = (rhs_start..rhs_end).map(|i| (rhs.col_indices[i], rhs.values[i])); + + // Keep track of our position in each iterator + let mut self_next = self_iter.next(); + let mut rhs_next = rhs_iter.next(); + + // Merge the non-zero elements + while let (Some((self_col, self_val)), Some((rhs_col, rhs_val))) = (self_next, rhs_next) { + match self_col.cmp(&rhs_col) { + std::cmp::Ordering::Equal => { + // When columns match, multiply the values + result_values.push(self_val * rhs_val); + result_col_indices.push(self_col); + self_next = self_iter.next(); + rhs_next = rhs_iter.next(); + }, + std::cmp::Ordering::Less => { + // Skip elements only in self + self_next = self_iter.next(); + }, + std::cmp::Ordering::Greater => { + // Skip elements only in rhs + rhs_next = rhs_iter.next(); + }, + } + } + + // Record where this row ends + result_row_offsets.push(result_values.len()); + } + + SparseMatrix::new(result_row_offsets, result_col_indices, result_values, self.num_cols) + } +} + +impl Display for SparseMatrix { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + // First, we'll find the maximum width needed for any number + // This helps us align columns nicely + let max_width = self.values.iter().map(|v| format!("{v}").len()).max().unwrap_or(1).max(1); // At least 1 character for "0" + + // For each row... + for row in 0..self.row_offsets.len() - 1 { + write!(f, "[")?; + + // Find the non-zero elements in this row + let row_start = self.row_offsets[row]; + let row_end = self.row_offsets[row + 1]; + let mut current_col = 0; + + // Process each column, inserting zeros where needed + for col in 0..self.num_cols { + // Add spacing between elements + if col > 0 { + write!(f, " ")?; + } + + // Check if we have a non-zero element at this position + if current_col < row_end - row_start && self.col_indices[row_start + current_col] == col { + // We found a non-zero element + let val = &self.values[row_start + current_col]; + write!(f, "{val:>max_width$}")?; + current_col += 1; + } else { + // This position is zero + write!(f, "{:>width$}", 0, width = max_width)?; + } + } + + writeln!(f, "]")?; + } + Ok(()) + } +} + #[cfg(test)] mod tests { + use super::*; #[test] #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] - fn test_sparse_matrix_multiplication() { - // Let's create this matrix: + fn test_write() { + // Create this matrix: // [2 0 1] // [0 3 0] // [4 0 5] - let row_offsets = vec![0, 2, 3, 5]; // Points to starts of rows - let col_indices = vec![0, 2, 1, 0, 2]; // Column indices for non-zero elements + let row_offsets = vec![0, 2, 3, 5]; + let col_indices = vec![0, 2, 1, 0, 2]; let values = vec![ F17::from(2), F17::from(1), // First row: 2 and 1 @@ -63,6 +251,25 @@ mod tests { let matrix = SparseMatrix::new(row_offsets, col_indices, values, 3); + let mut write_matrix = SparseMatrix::new_rows_cols(3, 3); + write_matrix.write(0, 0, F17::from(2)); + write_matrix.write(0, 2, F17::ONE); + write_matrix.write(1, 1, F17::from(3)); + write_matrix.write(2, 0, F17::from(4)); + write_matrix.write(2, 2, F17::from(5)); + + assert_eq!(matrix, write_matrix); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn test_sparse_matrix_vector_multiplication() { + let row_offsets = vec![0, 2, 3, 5]; + let col_indices = vec![0, 2, 1, 0, 2]; + let values = vec![F17::from(2), F17::from(1), F17::from(3), F17::from(4), F17::from(5)]; + + let matrix = SparseMatrix::new(row_offsets, col_indices, values, 3); + // Create input vector [1, 2, 3] let input = vec![F17::from(1), F17::from(2), F17::from(3)]; @@ -70,10 +277,54 @@ mod tests { // [2*1 + 0*2 + 1*3] = [5] // [0*1 + 3*2 + 0*3] = [6] // [4*1 + 0*2 + 5*3] = [19 ≡ 2 mod 17] - let result = matrix.mul_vector(&input); + let result = &matrix * &input; assert_eq!(result[0], F17::from(5)); assert_eq!(result[1], F17::from(6)); assert_eq!(result[2], F17::from(2)); // 19 mod 17 = 2 } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn test_hadamard_multiplication() { + // Create two test matrices: + // Matrix 1: + // [2 0 1] + // [0 3 0] + // [4 0 5] + let mut test_matrix1 = SparseMatrix::new_rows_cols(3, 3); + test_matrix1.write(0, 0, F17::from(2)); + test_matrix1.write(0, 2, F17::from(1)); + test_matrix1.write(1, 1, F17::from(3)); + test_matrix1.write(2, 0, F17::from(4)); + test_matrix1.write(2, 2, F17::from(5)); + + // Matrix 2: + // [3 0 0] + // [0 2 1] + // [0 0 2] + let mut test_matrix2 = SparseMatrix::new_rows_cols(3, 3); + test_matrix2.write(0, 0, F17::from(3)); + test_matrix2.write(1, 1, F17::from(2)); + test_matrix2.write(1, 2, F17::from(1)); + test_matrix2.write(2, 2, F17::from(2)); + + // Perform Hadamard multiplication + let result = &test_matrix1 * &test_matrix2; + + // The result should be: + // [6 0 0] + // [0 6 0] + // [0 0 10] + assert_eq!( + result.values, + [ + F17::from(6), // 2*3 at (0,0) + F17::from(6), // 3*2 at (1,1) + F17::from(10), // 5*2 at (2,2) + ] + ); + assert_eq!(result.col_indices, [0, 1, 2]); + assert_eq!(result.row_offsets, [0, 1, 2, 3]); + } }