
[Feature] Add Random Forest Classifier to linfa-trees #389


Open
maxprogrammer007 opened this issue May 17, 2025 · 4 comments

@maxprogrammer007

📝 Description

I would like to contribute a new module to the linfa-trees crate that implements the Random Forest algorithm for classification tasks. This will expand linfa-trees from single decision trees into ensemble learning, aligning closely with scikit-learn's functionality in Python.


🚀 Motivation

Random Forests are a powerful ensemble learning method used widely in classification tasks. They provide:

  • Robustness to overfitting, since averaging the votes of many decorrelated trees reduces variance

  • Better generalization than single trees

  • Feature importance estimates

Currently, linfa-trees provides support for single decision trees. By adding Random Forests, we unlock ensemble learning for the Rust ML ecosystem.


📐 Proposed Design

🔹 New Module

A new file will be added:

```
linfa-trees/src/decision_trees/random_forest.rs
```

This will include:

  • RandomForestClassifier<F: Float>

  • RandomForestParams<F> (unchecked)

  • RandomForestValidParams<F> (checked)
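A minimal std-only sketch of the unchecked → checked split, assuming the builder fields from the API preview. It is illustrative rather than linfa's code: in linfa the validation step goes through the `ParamGuard` trait from the `linfa` crate, whereas here a hand-written `check()` plays that role, and the `Copy + PartialOrd + From<f32>` bound is a stand-in for linfa's `Float` so that both `f32` and `f64` work. The defaults and the `ParamError` variants are assumptions.

```rust
/// Unchecked builder, generic over the float type like the proposed
/// `RandomForestParams<F>`. The bound is a std-only stand-in for linfa's `Float`.
#[derive(Debug, Clone, PartialEq)]
pub struct RandomForestParams<F> {
    n_trees: usize,
    feature_subsample: F,
    max_depth: Option<usize>,
}

/// Checked parameters: only obtainable through `check()`.
#[derive(Debug, Clone, PartialEq)]
pub struct RandomForestValidParams<F> {
    pub n_trees: usize,
    pub feature_subsample: F,
    pub max_depth: Option<usize>,
}

/// Hypothetical error type for this sketch.
#[derive(Debug, Clone, PartialEq)]
pub enum ParamError {
    ZeroTrees,
    SubsampleOutOfRange,
    ZeroDepth,
}

impl<F: Copy + PartialOrd + From<f32>> RandomForestParams<F> {
    pub fn new() -> Self {
        // Illustrative defaults only.
        Self { n_trees: 100, feature_subsample: F::from(1.0_f32), max_depth: None }
    }

    pub fn n_trees(mut self, n: usize) -> Self {
        self.n_trees = n;
        self
    }

    pub fn feature_subsample(mut self, ratio: F) -> Self {
        self.feature_subsample = ratio;
        self
    }

    pub fn max_depth(mut self, depth: Option<usize>) -> Self {
        self.max_depth = depth;
        self
    }

    /// Validate once, up front, so that `fit` can rely on the invariants.
    pub fn check(self) -> Result<RandomForestValidParams<F>, ParamError> {
        if self.n_trees == 0 {
            return Err(ParamError::ZeroTrees);
        }
        // The feature fraction must lie in (0, 1]; NaN fails both comparisons.
        let r = self.feature_subsample;
        if !(r > F::from(0.0_f32) && r <= F::from(1.0_f32)) {
            return Err(ParamError::SubsampleOutOfRange);
        }
        if self.max_depth == Some(0) {
            return Err(ParamError::ZeroDepth);
        }
        Ok(RandomForestValidParams {
            n_trees: self.n_trees,
            feature_subsample: r,
            max_depth: self.max_depth,
        })
    }
}
```

With this split, `fit` only ever receives `RandomForestValidParams`, so it never has to re-check `n_trees` or the subsample ratio.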

🔹 Trait Implementations

I will implement the following traits according to linfa conventions:

  • ParamGuard for parameter validation

  • Fit to train the forest using bootstrapped data and random feature subsetting

  • PredictInplace and Predict to perform inference via majority voting
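As a rough illustration of the three ingredients named above (bootstrapped rows, random feature subsets, majority voting), here is a dependency-free sketch. The helper names, the tiny `XorShift64` generator and the "floor, but at least one feature" rounding are all assumptions, and tree fitting itself is omitted. Note that Breiman's original Random Forest draws a fresh feature subset at every split, whereas drawing one subset per tree (as here) is closer to the random subspace method; which of the two `feature_subsample` refers to would need to be settled in the PR.

```rust
use std::collections::HashMap;

/// Tiny xorshift64 generator so the sketch needs no crates; a real
/// implementation would more likely take an RNG from the `rand` crate.
pub struct XorShift64(u64);

impl XorShift64 {
    pub fn new(seed: u64) -> Self {
        // xorshift never leaves the all-zero state, so avoid a zero seed
        Self(if seed == 0 { 0x9E37_79B9_7F4A_7C15 } else { seed })
    }

    fn next_u64(&mut self) -> u64 {
        self.0 ^= self.0 << 13;
        self.0 ^= self.0 >> 7;
        self.0 ^= self.0 << 17;
        self.0
    }

    /// Index in 0..n (modulo bias ignored). Panics if n == 0.
    pub fn below(&mut self, n: usize) -> usize {
        (self.next_u64() % n as u64) as usize
    }
}

/// Bootstrap sample: `n_samples` row indices drawn with replacement.
pub fn bootstrap_indices(n_samples: usize, rng: &mut XorShift64) -> Vec<usize> {
    (0..n_samples).map(|_| rng.below(n_samples)).collect()
}

/// Random feature subset: floor(ratio * n_features) distinct columns (at least
/// one), chosen with a partial Fisher-Yates shuffle and returned sorted.
pub fn feature_subset(n_features: usize, ratio: f64, rng: &mut XorShift64) -> Vec<usize> {
    assert!(n_features > 0, "need at least one feature");
    let k = ((ratio * n_features as f64) as usize).clamp(1, n_features);
    let mut cols: Vec<usize> = (0..n_features).collect();
    for i in 0..k {
        // swap a random not-yet-chosen column into position i
        let j = i + rng.below(n_features - i);
        cols.swap(i, j);
    }
    cols.truncate(k);
    cols.sort_unstable();
    cols
}

/// Majority vote over the labels predicted by each tree for one sample.
/// Ties go to the smallest label so the result is deterministic.
pub fn majority_vote(votes: &[usize]) -> Option<usize> {
    let mut counts: HashMap<usize, usize> = HashMap::new();
    for &label in votes {
        *counts.entry(label).or_insert(0) += 1;
    }
    counts
        .into_iter()
        .max_by(|(la, ca), (lb, cb)| ca.cmp(cb).then(lb.cmp(la)))
        .map(|(label, _)| label)
}
```

In this scheme, `fit` would train each tree on the rows returned by `bootstrap_indices` restricted to the columns from `feature_subset`, and `predict` would collect one label per tree for each sample and pass them to `majority_vote`.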

🔹 Example

An example will be added in:

```
linfa-trees/examples/iris_random_forest.rs
```

It will use the Iris dataset from linfa-datasets.

🔹 Benchmark (Optional)

If approved, I can also add a benchmark using Criterion:

```
linfa-trees/benches/random_forest.rs
```

📁 File Integration Plan

  • src/lib.rs: Re-export random_forest::*

  • src/decision_trees/mod.rs: pub mod random_forest;

  • README.md: Update with a section on Random Forests and example usage

  • examples/iris_random_forest.rs: Demonstrates training and evaluation
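One way to sanity-check the module wiring is to mirror the proposed files as inline modules in a single compilation unit. In this hypothetical sketch the `RandomForestClassifier` body is only a placeholder, and the exact re-export path depends on how `lib.rs` already exposes `decision_trees`.

```rust
// Inline modules stand in for the proposed files (hypothetical layout).
pub mod decision_trees {
    // What src/decision_trees/mod.rs would add:
    pub mod random_forest {
        // Placeholder for linfa-trees/src/decision_trees/random_forest.rs;
        // the real type would hold fitted trees, not just a count.
        #[derive(Debug, Default)]
        pub struct RandomForestClassifier {
            pub n_trees: usize,
        }
    }
}

// What src/lib.rs would add, so users can import the type from the crate root.
pub use decision_trees::random_forest::*;
```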


📦 API Preview

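```rust
use linfa::prelude::*; // Fit, Predict, ToConfusionMatrix
use linfa_trees::RandomForest; // proposed type, not yet in linfa-trees

let model = RandomForest::params()
    .n_trees(100)
    .feature_subsample(0.8) // fraction of features sampled at random
    .max_depth(Some(10))
    .fit(&dataset)?;

// Majority-vote predictions, scored with a confusion matrix
let predictions = model.predict(&dataset);
let acc = predictions.confusion_matrix(&dataset)?.accuracy();
```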


✅ Conformity with CONTRIBUTING.md

  • Uses Float trait for f32/f64 compatibility

  • Follows the Params → ValidParams validation pattern

  • Implements Fit, Predict, and PredictInplace using Dataset

  • Optional serde support via feature flag

  • Will include unit tests and optionally benchmarks



🙋‍♂️ Request

Please let me know if you're open to this contribution. I’d be happy to align with maintainers on:

  • Feature scope (classifier first, regressor later?)

  • Benchmarking standards

  • Integration strategy (e.g., reuse of DecisionTree)
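On reusing DecisionTree: if each tree is trained on its own feature subset, a forest that wraps an unmodified tree type has to remember each tree's columns and project every row onto them before predicting. The sketch below illustrates that design with a hypothetical `BaseTree` trait and a toy `Stump` learner standing in for a fitted `DecisionTree`; none of these names come from linfa.

```rust
use std::collections::BTreeMap;

/// Hypothetical stand-in for an already-fitted single tree (the role a
/// linfa `DecisionTree` would play); the forest only needs per-row labels.
pub trait BaseTree {
    fn predict_row(&self, row: &[f64]) -> usize;
}

/// Toy base learner, used only to exercise the wrapper: a single split.
pub struct Stump {
    pub col: usize, // index into the *projected* row
    pub threshold: f64,
    pub left: usize,
    pub right: usize,
}

impl BaseTree for Stump {
    fn predict_row(&self, row: &[f64]) -> usize {
        if row[self.col] <= self.threshold { self.left } else { self.right }
    }
}

/// Forest that reuses a tree type unchanged: each member is stored with the
/// feature columns it was trained on, and rows are projected onto those
/// columns before the tree sees them.
pub struct Forest<T: BaseTree> {
    pub members: Vec<(T, Vec<usize>)>,
}

impl<T: BaseTree> Forest<T> {
    pub fn predict_row(&self, row: &[f64]) -> Option<usize> {
        let mut counts: BTreeMap<usize, usize> = BTreeMap::new();
        for (tree, cols) in &self.members {
            let projected: Vec<f64> = cols.iter().map(|&c| row[c]).collect();
            *counts.entry(tree.predict_row(&projected)).or_insert(0) += 1;
        }
        // Iterate labels in descending order; `max_by_key` keeps the last
        // maximum it sees, so ties go to the smallest label.
        counts.into_iter().rev().max_by_key(|&(_, n)| n).map(|(label, _)| label)
    }
}
```

The alternative is to train every tree on all columns and sample features per split inside the tree-growing code, which avoids the projection step but means changing DecisionTree itself rather than wrapping it.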

Looking forward to your guidance!

@relf
Member

relf commented May 19, 2025

Thanks for your thorough description. This looks good to me, please proceed with a PR!

@relf
Member

relf commented May 19, 2025

Sorry, I just noticed prior art in #229. It would be great to take a look at it before jumping into a whole new implementation.

@maxprogrammer007
Author

@relf Sure, I will look into #229 and then prepare my PR.

@maxprogrammer007
Author

@relf

I have opened a PR; please take a look. All checks have passed, and I have successfully tested the module.

PR link: #390
