2
2
use ndarray:: { Array1 , ArrayBase , Array2 , stack, Axis , Array , Ix2 , Data } ;
3
3
use ndarray_linalg:: { Solve , random} ;
4
4
use ndarray_stats:: DeviationExt ;
5
-
5
+ use ndarray_rand:: RandomExt ;
6
+ use rand:: distributions:: StandardNormal ;
6
7
7
8
/// The simple linear regression model is
8
9
/// y = bX + e where e ~ N(0, sigma^2 * I)
@@ -18,7 +19,7 @@ use ndarray_stats::DeviationExt;
18
19
/// where (X^T X)^{-1} X^T is known as the pseudoinverse / Moore-Penrose
19
20
/// inverse.
20
21
struct LinearRegression {
21
- beta : Option < Array1 < f32 > > ,
22
+ pub beta : Option < Array1 < f64 > > ,
22
23
fit_intercept : bool ,
23
24
}
24
25
@@ -30,15 +31,15 @@ impl LinearRegression {
30
31
}
31
32
}
32
33
33
- fn fit ( & mut self , mut X : Array2 < f32 > , y : Array1 < f32 > ) {
34
+ fn fit ( & mut self , mut X : Array2 < f64 > , y : Array1 < f64 > ) {
34
35
let ( n_samples, _) = X . dim ( ) ;
35
36
36
37
// Check that our inputs have compatible shapes
37
38
assert_eq ! ( y. dim( ) , n_samples) ;
38
39
39
40
// If we are fitting the intercept, we need an additional column
40
41
if self . fit_intercept {
41
- let dummy_column: Array < f32 , _ > = Array :: ones ( ( n_samples, 1 ) ) ;
42
+ let dummy_column: Array < f64 , _ > = Array :: ones ( ( n_samples, 1 ) ) ;
42
43
X = stack ( Axis ( 1 ) , & [ dummy_column. view ( ) , X . view ( ) ] ) . unwrap ( ) ;
43
44
} ;
44
45
@@ -47,15 +48,15 @@ impl LinearRegression {
47
48
self . beta = Some ( linear_operator. solve_into ( rhs) . unwrap ( ) ) ;
48
49
}
49
50
50
- fn predict < A > ( & self , X : & ArrayBase < A , Ix2 > ) -> Array1 < f32 >
51
+ fn predict < A > ( & self , X : & ArrayBase < A , Ix2 > ) -> Array1 < f64 >
51
52
where
52
- A : Data < Elem =f32 > ,
53
+ A : Data < Elem =f64 > ,
53
54
{
54
55
let ( n_samples, _) = X . dim ( ) ;
55
56
56
57
// If we are fitting the intercept, we need an additional column
57
58
let X = if self . fit_intercept {
58
- let dummy_column: Array < f32 , _ > = Array :: ones ( ( n_samples, 1 ) ) ;
59
+ let dummy_column: Array < f64 , _ > = Array :: ones ( ( n_samples, 1 ) ) ;
59
60
stack ( Axis ( 1 ) , & [ dummy_column. view ( ) , X . view ( ) ] ) . unwrap ( )
60
61
} else {
61
62
X . to_owned ( )
@@ -70,24 +71,31 @@ impl LinearRegression {
70
71
}
71
72
}
72
73
73
- fn get_data ( n_train_samples : usize , n_test_samples : usize , n_features : usize ) -> (
74
- Array2 < f32 > , Array2 < f32 > , Array1 < f32 > , Array1 < f32 >
74
+ fn get_data ( n_samples : usize , n_features : usize ) -> (
75
+ Array2 < f64 > , Array1 < f64 >
75
76
) {
76
- let X_train : Array2 < f32 > = random ( ( n_train_samples, n_features) ) ;
77
- let y_train: Array1 < f32 > = random ( n_train_samples) ;
78
- let X_test : Array2 < f32 > = random ( ( n_test_samples, n_features) ) ;
79
- let y_test: Array1 < f32 > = random ( n_test_samples) ;
80
- ( X_train , X_test , y_train, y_test)
77
+ let shape = ( n_samples, n_features) ;
78
+ let noise: Array1 < f64 > = Array :: random ( n_samples, StandardNormal ) ;
79
+
80
+ let beta: Array1 < f64 > = random ( n_features) * 100. ;
81
+ println ! ( "Beta used to generate target variable: {:.3}" , beta) ;
82
+
83
+ let X : Array2 < f64 > = random ( shape) ;
84
+ let y: Array1 < f64 > = X . dot ( & beta) + noise;
85
+ ( X , y)
81
86
}
82
87
83
88
pub fn main ( ) {
84
89
let n_train_samples = 5000 ;
85
90
let n_test_samples = 1000 ;
86
- let n_features = 15 ;
87
- let ( X_train , X_test , y_train, y_test) = get_data ( n_train_samples, n_test_samples, n_features) ;
88
- let mut linear_regressor = LinearRegression :: new ( true ) ;
89
- linear_regressor. fit ( X_train , y_train) ;
91
+ let n_features = 3 ;
92
+ let ( X , y) = get_data ( n_train_samples + n_test_samples, n_features) ;
93
+ let ( X_train , X_test ) = X . view ( ) . split_at ( Axis ( 0 ) , n_train_samples) ;
94
+ let ( y_train, y_test) = y. view ( ) . split_at ( Axis ( 0 ) , n_train_samples) ;
95
+ let mut linear_regressor = LinearRegression :: new ( false ) ;
96
+ linear_regressor. fit ( X_train . to_owned ( ) , y_train. to_owned ( ) ) ;
90
97
let test_predictions = linear_regressor. predict ( & X_test ) ;
91
- let mean_squared_error = test_predictions. sq_l2_dist ( & y_test) . unwrap ( ) ;
92
- println ! ( "The fitted regressor has a root mean squared error of {:}" , mean_squared_error) ;
98
+ let mean_squared_error = test_predictions. mean_sq_err ( & y_test. to_owned ( ) ) . unwrap ( ) ;
99
+ println ! ( "Beta estimated from the training data: {:.3}" , linear_regressor. beta. unwrap( ) ) ;
100
+ println ! ( "The fitted regressor has a root mean squared error of {:.3}" , mean_squared_error) ;
93
101
}
0 commit comments