Skip to content

Commit b94c587

Browse files
authored
Add a 'decision_function()' method to the 'LogisticRegression' class. (#728)
* Added 'decision_function()' method for LogisticRegression
1 parent 68c9bab commit b94c587

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

dask_ml/linear_model/glm.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ class LogisticRegression(_GLM):
208208
>>> X, y = make_classification()
209209
>>> lr = LogisticRegression()
210210
>>> lr.fit(X, y)
211+
>>> lr.decision_function(X)
211212
>>> lr.predict(X)
212213
>>> lr.predict_proba(X)
213214
>>> lr.score(X, y)"""
@@ -218,6 +219,21 @@ class LogisticRegression(_GLM):
218219
def family(self):
219220
return families.Logistic
220221

222+
def decision_function(self, X):
223+
"""Predict confidence scores for samples in X.
224+
225+
Parameters
226+
----------
227+
X : array-like, shape = [n_samples, n_features]
228+
229+
Returns
230+
-------
231+
T : array-like, shape = [n_samples, n_classes]
232+
The confidence score of the sample for each class in the model.
233+
"""
234+
X_ = self._check_array(X)
235+
return dot(X_, self._coef)
236+
221237
def predict(self, X):
222238
"""Predict class labels for samples in X.
223239
@@ -244,8 +260,7 @@ def predict_proba(self, X):
244260
T : array-like, shape = [n_samples, n_classes]
245261
The probability of the sample for each class in the model.
246262
"""
247-
X_ = self._check_array(X)
248-
return sigmoid(dot(X_, self._coef))
263+
return sigmoid(self.decision_function(X))
249264

250265
def score(self, X, y):
251266
"""The mean accuracy on the given data and labels

tests/linear_model/test_glm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def test_big(fit_intercept):
8989
X, y = make_classification(chunks=50)
9090
lr = LogisticRegression(fit_intercept=fit_intercept)
9191
lr.fit(X, y)
92+
lr.decision_function(X)
9293
lr.predict(X)
9394
lr.predict_proba(X)
9495
if fit_intercept:

0 commit comments

Comments
 (0)