Skip to content

Commit ea7b475

Browse files
Proper generation of data and target
1 parent c4c78e7 commit ea7b475

File tree

2 files changed

+30
-20
lines changed

2 files changed

+30
-20
lines changed

Cargo.toml

+2
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,5 @@ optional = true
5050
[dev-dependencies]
5151
paste = "0.1"
5252
ndarray-stats = {git = "https://github.com/rust-ndarray/ndarray-stats", branch = "master"}
53+
ndarray-rand = "0.9"
54+
rand = "0.6"

examples/linear_regression.rs

+28-20
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
use ndarray::{Array1, ArrayBase, Array2, stack, Axis, Array, Ix2, Data};
33
use ndarray_linalg::{Solve, random};
44
use ndarray_stats::DeviationExt;
5-
5+
use ndarray_rand::RandomExt;
6+
use rand::distributions::StandardNormal;
67

78
/// The simple linear regression model is
89
/// y = bX + e where e ~ N(0, sigma^2 * I)
@@ -18,7 +19,7 @@ use ndarray_stats::DeviationExt;
1819
/// where (X^T X)^{-1} X^T is known as the pseudoinverse / Moore-Penrose
1920
/// inverse.
2021
struct LinearRegression {
21-
beta: Option<Array1<f32>>,
22+
pub beta: Option<Array1<f64>>,
2223
fit_intercept: bool,
2324
}
2425

@@ -30,15 +31,15 @@ impl LinearRegression {
3031
}
3132
}
3233

33-
fn fit(&mut self, mut X: Array2<f32>, y: Array1<f32>) {
34+
fn fit(&mut self, mut X: Array2<f64>, y: Array1<f64>) {
3435
let (n_samples, _) = X.dim();
3536

3637
// Check that our inputs have compatible shapes
3738
assert_eq!(y.dim(), n_samples);
3839

3940
// If we are fitting the intercept, we need an additional column
4041
if self.fit_intercept {
41-
let dummy_column: Array<f32, _> = Array::ones((n_samples, 1));
42+
let dummy_column: Array<f64, _> = Array::ones((n_samples, 1));
4243
X = stack(Axis(1), &[dummy_column.view(), X.view()]).unwrap();
4344
};
4445

@@ -47,15 +48,15 @@ impl LinearRegression {
4748
self.beta = Some(linear_operator.solve_into(rhs).unwrap());
4849
}
4950

50-
fn predict<A>(&self, X: &ArrayBase<A, Ix2>) -> Array1<f32>
51+
fn predict<A>(&self, X: &ArrayBase<A, Ix2>) -> Array1<f64>
5152
where
52-
A: Data<Elem=f32>,
53+
A: Data<Elem=f64>,
5354
{
5455
let (n_samples, _) = X.dim();
5556

5657
// If we are fitting the intercept, we need an additional column
5758
let X = if self.fit_intercept {
58-
let dummy_column: Array<f32, _> = Array::ones((n_samples, 1));
59+
let dummy_column: Array<f64, _> = Array::ones((n_samples, 1));
5960
stack(Axis(1), &[dummy_column.view(), X.view()]).unwrap()
6061
} else {
6162
X.to_owned()
@@ -70,24 +71,31 @@ impl LinearRegression {
7071
}
7172
}
7273

73-
fn get_data(n_train_samples: usize, n_test_samples: usize, n_features: usize) -> (
74-
Array2<f32>, Array2<f32>, Array1<f32>, Array1<f32>
74+
fn get_data(n_samples: usize, n_features: usize) -> (
75+
Array2<f64>, Array1<f64>
7576
) {
76-
let X_train: Array2<f32> = random((n_train_samples, n_features));
77-
let y_train: Array1<f32> = random(n_train_samples);
78-
let X_test: Array2<f32> = random((n_test_samples, n_features));
79-
let y_test: Array1<f32> = random(n_test_samples);
80-
(X_train, X_test, y_train, y_test)
77+
let shape = (n_samples, n_features);
78+
let noise: Array1<f64> = Array::random(n_samples, StandardNormal);
79+
80+
let beta: Array1<f64> = random(n_features) * 100.;
81+
println!("Beta used to generate target variable: {:.3}", beta);
82+
83+
let X: Array2<f64> = random(shape);
84+
let y: Array1<f64> = X.dot(&beta) + noise;
85+
(X, y)
8186
}
8287

8388
pub fn main() {
8489
let n_train_samples = 5000;
8590
let n_test_samples = 1000;
86-
let n_features = 15;
87-
let (X_train, X_test, y_train, y_test) = get_data(n_train_samples, n_test_samples, n_features);
88-
let mut linear_regressor = LinearRegression::new(true);
89-
linear_regressor.fit(X_train, y_train);
91+
let n_features = 3;
92+
let (X, y) = get_data(n_train_samples + n_test_samples, n_features);
93+
let (X_train, X_test) = X.view().split_at(Axis(0), n_train_samples);
94+
let (y_train, y_test) = y.view().split_at(Axis(0), n_train_samples);
95+
let mut linear_regressor = LinearRegression::new(false);
96+
linear_regressor.fit(X_train.to_owned(), y_train.to_owned());
9097
let test_predictions = linear_regressor.predict(&X_test);
91-
let mean_squared_error = test_predictions.sq_l2_dist(&y_test).unwrap();
92-
println!("The fitted regressor has a root mean squared error of {:}", mean_squared_error);
98+
let mean_squared_error = test_predictions.mean_sq_err(&y_test.to_owned()).unwrap();
99+
println!("Beta estimated from the training data: {:.3}", linear_regressor.beta.unwrap());
100+
println!("The fitted regressor has a root mean squared error of {:.3}", mean_squared_error);
93101
}

0 commit comments

Comments
 (0)