Add new ensemble methods crate: linfa-ensemble #392


Merged: 10 commits, merged May 26, 2025

Changes from all commits
23 changes: 12 additions & 11 deletions README.md
@@ -30,22 +30,23 @@ Where does `linfa` stand right now? [Are we learning yet?](http://www.arewelearningyet.com/)

| Name | Purpose | Status | Category | Notes |
| :--- | :--- | :---| :--- | :---|
| [bayes](algorithms/linfa-bayes/) | Naive Bayes | Tested | Supervised learning | Contains Bernoulli, Gaussian and Multinomial Naive Bayes |
| [clustering](algorithms/linfa-clustering/) | Data clustering | Tested / Benchmarked | Unsupervised learning | Clustering of unlabeled data; contains K-Means, Gaussian-Mixture-Model, DBSCAN and OPTICS |
| [ensemble](algorithms/linfa-ensemble/) | Ensemble methods | Tested | Supervised learning | Contains bagging |
| [elasticnet](algorithms/linfa-elasticnet/) | Elastic Net | Tested | Supervised learning | Linear regression with elastic net constraints |
| [hierarchical](algorithms/linfa-hierarchical/) | Agglomerative hierarchical clustering | Tested | Unsupervised learning | Cluster and build hierarchy of clusters |
| [ica](algorithms/linfa-ica/) | Independent component analysis | Tested | Unsupervised learning | Contains FastICA implementation |
| [kernel](algorithms/linfa-kernel/) | Kernel methods for data transformation | Tested | Pre-processing | Maps feature vector into higher-dimensional space |
| [linear](algorithms/linfa-linear/) | Linear regression | Tested | Partial fit | Contains Ordinary Least Squares (OLS), Generalized Linear Models (GLM) |
| [logistic](algorithms/linfa-logistic/) | Logistic regression | Tested | Partial fit | Builds two-class logistic regression models |
| [nn](algorithms/linfa-nn/) | Nearest Neighbours & Distances | Tested / Benchmarked | Pre-processing | Spatial index structures and distance functions |
| [ftrl](algorithms/linfa-ftrl/) | Follow The Regularized Leader - proximal | Tested / Benchmarked | Partial fit | Contains L1 and L2 regularization. Possible incremental update |
| [pls](algorithms/linfa-pls/) | Partial Least Squares | Tested | Supervised learning | Contains PLS estimators for dimensionality reduction and regression |
| [preprocessing](algorithms/linfa-preprocessing/) | Normalization & Vectorization | Tested / Benchmarked | Pre-processing | Contains data normalization/whitening and count vectorization/tf-idf |
| [reduction](algorithms/linfa-reduction/) | Dimensionality reduction | Tested | Pre-processing | Diffusion mapping, Principal Component Analysis (PCA), Random projections |
| [svm](algorithms/linfa-svm/) | Support Vector Machines | Tested | Supervised learning | Classification or regression analysis of labeled datasets |
| [trees](algorithms/linfa-trees/) | Decision trees | Tested / Benchmarked | Supervised learning | Linear decision trees |
| [tsne](algorithms/linfa-tsne/) | Dimensionality reduction | Tested | Unsupervised learning | Contains exact solution and Barnes-Hut approximation t-SNE |

We believe that only a significant community effort can nurture, build, and sustain a machine learning ecosystem in Rust - there is no other way forward.

40 changes: 40 additions & 0 deletions algorithms/linfa-ensemble/Cargo.toml
@@ -0,0 +1,40 @@
[package]
name = "linfa-ensemble"
version = "0.7.0"
edition = "2018"
authors = [
    "James Knight <[email protected]>",
    "James Kay <[email protected]>",
]
description = "A general method for creating ensemble classifiers"
license = "MIT/Apache-2.0"

repository = "https://github.com/rust-ml/linfa"
readme = "README.md"

keywords = ["machine-learning", "linfa", "ensemble"]
categories = ["algorithms", "mathematics", "science"]

[features]
default = []
serde = ["serde_crate", "ndarray/serde"]

[dependencies.serde_crate]
package = "serde"
optional = true
version = "1.0"
default-features = false
features = ["std", "derive"]

[dependencies]
ndarray = { version = "0.15", features = ["rayon", "approx"] }
ndarray-rand = "0.14"
rand = "0.8.5"

linfa = { version = "0.7.1", path = "../.." }
linfa-trees = { version = "0.7.1", path = "../linfa-trees" }

[dev-dependencies]
linfa-datasets = { version = "0.7.1", path = "../../datasets/", features = [
    "iris",
] }
21 changes: 21 additions & 0 deletions algorithms/linfa-ensemble/README.md
@@ -0,0 +1,21 @@
# Ensemble Learning

`linfa-ensemble` provides pure Rust implementations of Ensemble Learning algorithms for the Linfa toolkit.

## The Big Picture

`linfa-ensemble` is a crate in the [`linfa`](https://crates.io/crates/linfa) ecosystem, an effort to create a toolkit for classical Machine Learning implemented in pure Rust, akin to Python's `scikit-learn`.

## Current state

`linfa-ensemble` currently provides an implementation of bootstrap aggregation (bagging) for other classifiers provided in linfa.
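In code, the builder-style API looks like this minimal sketch (mirroring the bundled iris example; `linfa-datasets` with the `iris` feature is assumed as a dev-dependency):

```rust
use linfa::prelude::{Fit, Predict};
use linfa_ensemble::EnsembleLearnerParams;
use linfa_trees::DecisionTree;

fn main() {
    let (train, test) = linfa_datasets::iris().split_with_ratio(0.8);

    // Bag 100 decision trees, each fit to a 70% bootstrap sample of the training set
    let model = EnsembleLearnerParams::new(DecisionTree::params())
        .ensemble_size(100)
        .bootstrap_proportion(0.7)
        .fit(&train)
        .unwrap();

    // Predictions are the per-sample majority vote across the ensemble
    let predictions = model.predict(&test);
    println!("{:?}", predictions);
}
```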

## Examples

You can find examples in the `examples/` directory. To run bootstrap aggregation for an ensemble of decision trees (a random forest), use:

```bash
$ cargo run --example bagging_iris --release
```


35 changes: 35 additions & 0 deletions algorithms/linfa-ensemble/examples/bagging_iris.rs
@@ -0,0 +1,35 @@
use linfa::prelude::{Fit, Predict, ToConfusionMatrix};
use linfa_ensemble::EnsembleLearnerParams;
use linfa_trees::DecisionTree;
use ndarray_rand::rand::SeedableRng;
use rand::rngs::SmallRng;

fn main() {
    // Number of models in the ensemble
    let ensemble_size = 100;
    // Proportion of training data given to each model
    let bootstrap_proportion = 0.7;

    // Load dataset
    let mut rng = SmallRng::seed_from_u64(42);
    let (train, test) = linfa_datasets::iris()
        .shuffle(&mut rng)
        .split_with_ratio(0.8);

    // Train ensemble learner model
    let model = EnsembleLearnerParams::new(DecisionTree::params())
        .ensemble_size(ensemble_size)
        .bootstrap_proportion(bootstrap_proportion)
        .fit(&train)
        .unwrap();

    // Return the predictions with the most votes across the ensemble
    let final_predictions_ensemble = model.predict(&test);
    println!("Final Predictions: \n{:?}", final_predictions_ensemble);

    let cm = final_predictions_ensemble.confusion_matrix(&test).unwrap();

    println!("{:?}", cm);
    println!("Test accuracy: {} \n with default Decision Tree params, \n Ensemble Size: {},\n Bootstrap Proportion: {}",
        100.0 * cm.accuracy(), ensemble_size, bootstrap_proportion);
}
103 changes: 103 additions & 0 deletions algorithms/linfa-ensemble/src/algorithm.rs
@@ -0,0 +1,103 @@
use crate::EnsembleLearnerValidParams;
use linfa::{
    dataset::{AsTargets, AsTargetsMut, FromTargetArrayOwned, Records},
    error::Error,
    traits::*,
    DatasetBase,
};
use ndarray::{Array2, Axis, Zip};
use rand::Rng;
use std::{cmp::Eq, collections::HashMap, hash::Hash};

pub struct EnsembleLearner<M> {
    pub models: Vec<M>,
}

impl<M> EnsembleLearner<M> {
    // Generates prediction iterator returning predictions from each model
    pub fn generate_predictions<'b, R: Records, T>(
        &'b self,
        x: &'b R,
    ) -> impl Iterator<Item = T> + 'b
    where
        M: Predict<&'b R, T>,
    {
        self.models.iter().map(move |m| m.predict(x))
    }
}

impl<F: Clone, T, M> PredictInplace<Array2<F>, T> for EnsembleLearner<M>
where
    M: PredictInplace<Array2<F>, T>,
    <T as AsTargets>::Elem: Copy + Eq + Hash + std::fmt::Debug,
    T: AsTargets + AsTargetsMut<Elem = <T as AsTargets>::Elem>,
{
    fn predict_inplace(&self, x: &Array2<F>, y: &mut T) {
        let y_array = y.as_targets();
        assert_eq!(
            x.nrows(),
            y_array.len_of(Axis(0)),
            "The number of data points must match the number of outputs."
        );

        let predictions = self.generate_predictions(x);

        // prediction map has same shape as y_array, but the elements are maps
        let mut prediction_maps = y_array.map(|_| HashMap::new());

        for prediction in predictions {
            let p_arr = prediction.as_targets();
            assert_eq!(p_arr.shape(), y_array.shape());
            // Insert each prediction value into the corresponding map
            Zip::from(&mut prediction_maps)
                .and(&p_arr)
                .for_each(|map, val| *map.entry(*val).or_insert(0) += 1);
        }

        // For each prediction, pick the result with the highest number of votes
        // (ties are broken arbitrarily, since `HashMap` iteration order is unspecified)
        let agg_preds = prediction_maps.map(|map| map.iter().max_by_key(|(_, v)| **v).unwrap().0);
        let mut y_array = y.as_targets_mut();
        for (y, pred) in y_array.iter_mut().zip(agg_preds.iter()) {
            *y = **pred
        }
    }

    fn default_target(&self, x: &Array2<F>) -> T {
        self.models[0].default_target(x)
    }
}

impl<D, T, P: Fit<Array2<D>, T::Owned, Error>, R: Rng + Clone> Fit<Array2<D>, T, Error>
    for EnsembleLearnerValidParams<P, R>
where
    D: Clone,
    T: FromTargetArrayOwned,
    T::Elem: Copy + Eq + Hash,
    T::Owned: AsTargets,
{
    type Object = EnsembleLearner<P::Object>;

    fn fit(
        &self,
        dataset: &DatasetBase<Array2<D>, T>,
    ) -> core::result::Result<Self::Object, Error> {
        let mut models = Vec::new();
        let mut rng = self.rng.clone();

        let dataset_size =
            ((dataset.records.nrows() as f64) * self.bootstrap_proportion).ceil() as usize;

        let iter = dataset.bootstrap_samples(dataset_size, &mut rng);
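        // Note: `bootstrap_samples` draws `dataset_size` rows with replacement from the
        // training set, yielding an unbounded iterator of resampled datasets (hence the
        // explicit `break` below once `ensemble_size` models have been fitted).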

        for train in iter {
            let model = self.model_params.fit(&train)?;
            models.push(model);

            if models.len() == self.ensemble_size {
                break;
            }
        }

        Ok(EnsembleLearner { models })
    }
}
73 changes: 73 additions & 0 deletions algorithms/linfa-ensemble/src/hyperparams.rs
@@ -0,0 +1,73 @@
use linfa::{
    error::{Error, Result},
    ParamGuard,
};
use rand::rngs::ThreadRng;
use rand::Rng;

#[derive(Clone, Copy, Debug, PartialEq)]
pub struct EnsembleLearnerValidParams<P, R> {
    /// The number of models in the ensemble
    pub ensemble_size: usize,
    /// The proportion of the total number of training samples that should be given to each model for training
    pub bootstrap_proportion: f64,
    /// The model parameters for the base model
    pub model_params: P,
    pub rng: R,
}

#[derive(Clone, Copy, Debug, PartialEq)]
pub struct EnsembleLearnerParams<P, R>(EnsembleLearnerValidParams<P, R>);

impl<P> EnsembleLearnerParams<P, ThreadRng> {
    pub fn new(model_params: P) -> EnsembleLearnerParams<P, ThreadRng> {
        Self::new_fixed_rng(model_params, rand::thread_rng())
    }
}

impl<P, R: Rng + Clone> EnsembleLearnerParams<P, R> {
    pub fn new_fixed_rng(model_params: P, rng: R) -> EnsembleLearnerParams<P, R> {
        Self(EnsembleLearnerValidParams {
            ensemble_size: 1,
            bootstrap_proportion: 1.0,
            model_params,
            rng,
        })
    }

    pub fn ensemble_size(mut self, size: usize) -> Self {
        self.0.ensemble_size = size;
        self
    }

    pub fn bootstrap_proportion(mut self, proportion: f64) -> Self {
        self.0.bootstrap_proportion = proportion;
        self
    }
}

impl<P, R> ParamGuard for EnsembleLearnerParams<P, R> {
    type Checked = EnsembleLearnerValidParams<P, R>;
    type Error = Error;

    fn check_ref(&self) -> Result<&Self::Checked> {
        if self.0.bootstrap_proportion > 1.0 || self.0.bootstrap_proportion <= 0.0 {
            Err(Error::Parameters(format!(
                "Bootstrap proportion should be greater than zero and less than or equal to one, but was {}",
                self.0.bootstrap_proportion
            )))
        } else if self.0.ensemble_size < 1 {
            Err(Error::Parameters(format!(
                "Ensemble size should be at least one, but was {}",
                self.0.ensemble_size
            )))
        } else {
            Ok(&self.0)
        }
    }

    fn check(self) -> Result<Self::Checked> {
        self.check_ref()?;
        Ok(self.0)
    }
}
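A quick illustration of what this guard enforces (a sketch against the crate's public API as added above, not part of the diff): linfa's `ParamGuard::check` returns the `Error::Parameters` above when the proportion falls outside `(0.0, 1.0]`.

```rust
use linfa::ParamGuard;
use linfa_ensemble::EnsembleLearnerParams;
use linfa_trees::DecisionTree;

fn main() {
    // bootstrap_proportion must lie in (0.0, 1.0], so 1.3 is rejected at check time
    let params = EnsembleLearnerParams::new(DecisionTree::<f64, usize>::params())
        .bootstrap_proportion(1.3);
    assert!(params.check().is_err());
}
```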
77 changes: 77 additions & 0 deletions algorithms/linfa-ensemble/src/lib.rs
@@ -0,0 +1,77 @@
//! # Ensemble Learning Algorithms
//!
//! Ensemble methods combine the predictions of several base estimators built with a given
//! learning algorithm in order to improve generalizability / robustness over a single estimator.
//!
//! ## Bootstrap Aggregation (aka Bagging)
//!
//! A typical example of an ensemble method is Bootstrap Aggregation, which combines the predictions of
//! several decision trees (see `linfa-trees`) trained on different random subsets of the training dataset.
//!
//! ## Reference
//!
//! * [Scikit-Learn User Guide](https://scikit-learn.org/stable/modules/ensemble.html)
//!
//! ## Example
//!
//! This example shows how to train a bagging model using 100 decision trees,
//! each trained on 70% of the training data (bootstrap sampling).
//!
//! ```no_run
//! use linfa::prelude::{Fit, Predict};
//! use linfa_ensemble::EnsembleLearnerParams;
//! use linfa_trees::DecisionTree;
//! use ndarray_rand::rand::SeedableRng;
//! use rand::rngs::SmallRng;
//!
//! // Load Iris dataset
//! let mut rng = SmallRng::seed_from_u64(42);
//! let (train, test) = linfa_datasets::iris()
//!     .shuffle(&mut rng)
//!     .split_with_ratio(0.8);
//!
//! // Train the model on the iris dataset
//! let bagging_model = EnsembleLearnerParams::new(DecisionTree::params())
//!     .ensemble_size(100)
//!     .bootstrap_proportion(0.7)
//!     .fit(&train)
//!     .unwrap();
//!
//! // Make predictions on the test set
//! let predictions = bagging_model.predict(&test);
//! ```
//!
mod algorithm;
mod hyperparams;

pub use algorithm::*;
pub use hyperparams::*;

#[cfg(test)]
mod tests {
    use super::*;
    use linfa::prelude::{Fit, Predict, ToConfusionMatrix};
    use linfa_trees::DecisionTree;
    use ndarray_rand::rand::SeedableRng;
    use rand::rngs::SmallRng;

    #[test]
    fn test_ensemble_learner_accuracy_on_iris_dataset() {
        let mut rng = SmallRng::seed_from_u64(42);
        let (train, test) = linfa_datasets::iris()
            .shuffle(&mut rng)
            .split_with_ratio(0.8);

        let model = EnsembleLearnerParams::new(DecisionTree::params())
            .ensemble_size(100)
            .bootstrap_proportion(0.7)
            .fit(&train)
            .unwrap();

        let predictions = model.predict(&test);

        let cm = predictions.confusion_matrix(&test).unwrap();
        let acc = cm.accuracy();
        assert!(acc >= 0.9, "Expected accuracy to be above 90%, got {}", acc);
    }
}