Commit 3d1d995

ENH define loss.is_multiclass
1 parent db7fb67 commit 3d1d995

2 files changed: +31 -26 lines

sklearn/_loss/loss.py

Lines changed: 4 additions & 0 deletions
@@ -87,6 +87,8 @@ class BaseLoss(BaseLink, cLossFunction):
         approximated, it should be larger or equal to the exact one.
     constant_hessian : bool
         Indicates whether the hessian is one for this loss.
+    is_multiclass : bool
+        Indicates whether n_classes > 2 is allowed.
     """
 
     # Inherited methods from BaseLink:
@@ -112,6 +114,7 @@ def __init__(self, n_classes=1):
         self.approx_hessian = False
         self.constant_hessian = False
         self.n_classes = n_classes
+        self.is_multiclass = False
         self.interval_y_true = Interval(-np.inf, np.inf, False, False)
         self.interval_y_pred = Interval(-np.inf, np.inf, False, False)
 
@@ -811,6 +814,7 @@ class CategoricalCrossEntropy(
 
     def __init__(self, sample_weight=None, n_classes=3):
         super().__init__(n_classes=n_classes)
+        self.is_multiclass = True
         self.interval_y_true = Interval(0, np.inf, True, False)
         self.interval_y_pred = Interval(0, 1, False, False)
 
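For context, a minimal standalone sketch of the pattern this commit introduces (a toy mock, not the real class hierarchy): the base loss defaults to the binary/regression case, and only the multiclass loss overrides the flag.

class BaseLoss:
    """Toy mock of BaseLoss, for illustration only."""

    def __init__(self, n_classes=1):
        self.n_classes = n_classes
        self.is_multiclass = False  # by default, n_classes > 2 is not allowed


class CategoricalCrossEntropy(BaseLoss):
    """Toy mock: the only loss in this diff that sets the flag to True."""

    def __init__(self, n_classes=3):
        super().__init__(n_classes=n_classes)
        self.is_multiclass = True  # n_classes > 2 is allowed


# callers can now branch on the flag instead of comparing n_classes
loss = CategoricalCrossEntropy()
assert loss.is_multiclass and loss.n_classes == 3
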
sklearn/_loss/tests/test_loss.py

Lines changed: 27 additions & 26 deletions
@@ -53,7 +53,15 @@ def random_y_true_raw_prediction(
 ):
     """Random generate y_true and raw_prediction in valid range."""
     rng = np.random.RandomState(seed)
-    if loss.n_classes <= 2:
+    if loss.is_multiclass:
+        raw_prediction = np.empty((n_samples, loss.n_classes))
+        raw_prediction.flat[:] = rng.uniform(
+            low=raw_bound[0],
+            high=raw_bound[1],
+            size=n_samples * loss.n_classes,
+        )
+        y_true = np.arange(n_samples).astype(float) % loss.n_classes
+    else:
         raw_prediction = rng.uniform(
             low=raw_bound[0], high=raw_bound[1], size=n_samples
         )
@@ -73,14 +81,6 @@ def random_y_true_raw_prediction(
         and loss.interval_y_true.high_inclusive
     ):
         y_true[1:: (n_samples // 3)] = 1
-    else:
-        raw_prediction = np.empty((n_samples, loss.n_classes))
-        raw_prediction.flat[:] = rng.uniform(
-            low=raw_bound[0],
-            high=raw_bound[1],
-            size=n_samples * loss.n_classes,
-        )
-        y_true = np.arange(n_samples).astype(float) % loss.n_classes
 
     return y_true, raw_prediction
 
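To illustrate the shape convention this helper now keys off is_multiclass (a standalone sketch with assumed toy values, mirroring the multiclass branch above): multiclass losses get a 2-d raw_prediction of shape (n_samples, n_classes), while all other losses get a 1-d array.

import numpy as np

rng = np.random.RandomState(0)
n_samples, n_classes = 6, 3
raw_bound = (-5.0, 5.0)  # assumed bounds for the illustration

# multiclass branch: one raw prediction per sample and class
raw_prediction = rng.uniform(
    low=raw_bound[0], high=raw_bound[1], size=(n_samples, n_classes)
)
# y_true cycles through the class labels: 0, 1, 2, 0, 1, 2
y_true = np.arange(n_samples).astype(float) % n_classes

assert raw_prediction.shape == (n_samples, n_classes)
assert y_true.tolist() == [0.0, 1.0, 2.0, 0.0, 1.0, 2.0]
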
@@ -105,11 +105,11 @@ def numerical_derivative(func, x, eps):
 def test_loss_boundary(loss):
     """Test interval ranges of y_true and y_pred in losses."""
     # make sure low and high are always within the interval, used for linspace
-    if loss.n_classes is None or loss.n_classes <= 2:
+    if loss.is_multiclass:
+        y_true = np.linspace(0, 9, num=10)
+    else:
         low, high = _inclusive_low_high(loss.interval_y_true)
         y_true = np.linspace(low, high, num=10)
-    else:
-        y_true = np.linspace(0, 9, num=10)
 
     # add boundaries if they are included
     if loss.interval_y_true.low_inclusive:
@@ -120,13 +120,13 @@ def test_loss_boundary(loss):
     assert loss.in_y_true_range(y_true)
 
     low, high = _inclusive_low_high(loss.interval_y_pred)
-    if loss.n_classes is None or loss.n_classes <= 2:
-        y_pred = np.linspace(low, high, num=10)
-    else:
+    if loss.is_multiclass:
         y_pred = np.empty((10, 3))
         y_pred[:, 0] = np.linspace(low, high, num=10)
         y_pred[:, 1] = 0.5 * (1 - y_pred[:, 0])
         y_pred[:, 2] = 0.5 * (1 - y_pred[:, 0])
+    else:
+        y_pred = np.linspace(low, high, num=10)
 
     assert loss.in_y_pred_range(y_pred)
 
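The multiclass branch above constructs y_pred so that every row is a valid probability vector; a quick standalone check of that construction (assuming inclusive bounds low=0 and high=1 for illustration):

import numpy as np

low, high = 0.0, 1.0  # assumed bounds
y_pred = np.empty((10, 3))
y_pred[:, 0] = np.linspace(low, high, num=10)
y_pred[:, 1] = 0.5 * (1 - y_pred[:, 0])
y_pred[:, 2] = 0.5 * (1 - y_pred[:, 0])

# each row sums to 1, i.e. lies on the probability simplex
assert np.allclose(y_pred.sum(axis=1), 1.0)
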
@@ -153,7 +153,7 @@ def test_loss_boundary(loss):
 ]
 # y_pred and y_true do not always have the same domain (valid value range).
 # Hence, we define extra sets of parameters for each of them.
-Y_TRUE_PARAMS = [
+Y_TRUE_PARAMS = [  # type: ignore
     # (loss, [y success], [y fail])
     (HalfPoissonLoss(), [0], []),
     (HalfTweedieLoss(power=-3), [-100, -0.1, 0], []),
@@ -185,7 +185,8 @@ def test_loss_boundary_y_true(loss, y_true_success, y_true_fail):
 
 
 @pytest.mark.parametrize(
-    "loss, y_pred_success, y_pred_fail", Y_COMMON_PARAMS + Y_PRED_PARAMS
+    "loss, y_pred_success, y_pred_fail",
+    Y_COMMON_PARAMS + Y_PRED_PARAMS  # type: ignore
 )
 def test_loss_boundary_y_pred(loss, y_pred_success, y_pred_fail):
     """Test boundaries of y_pred for loss functions."""
@@ -211,16 +212,16 @@ def test_loss_dtype(
     float64, and all output arrays are either all float32 or all float64.
     """
     loss = loss()
-    if loss.n_classes <= 2:
-        # generate a y_true in valid range
-        low, high = _inclusive_low_high(loss.interval_y_true, dtype=dtype_in)
-        y_true = np.array([0.5 * (high - low)], dtype=dtype_in)
-        raw_prediction = np.array([0.0], dtype=dtype_in)
-    else:
+    # generate a y_true and raw_prediction in valid range
+    if loss.is_multiclass:
         y_true = np.array([0], dtype=dtype_in)
         raw_prediction = np.full(
             shape=(1, loss.n_classes), fill_value=0.0, dtype=dtype_in
         )
+    else:
+        low, high = _inclusive_low_high(loss.interval_y_true, dtype=dtype_in)
+        y_true = np.array([0.5 * (high - low)], dtype=dtype_in)
+        raw_prediction = np.array([0.0], dtype=dtype_in)
 
     if sample_weight is not None:
         sample_weight = np.array([2.0], dtype=dtype_in)
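
A standalone sketch of the single-sample inputs the multiclass branch above generates (n_classes=3 assumed, matching CategoricalCrossEntropy's default): both arrays share the requested input dtype, and raw_prediction is 2-d even for a single sample.

import numpy as np

dtype_in = np.float32  # the test parametrizes over float32 and float64
n_classes = 3  # assumed

y_true = np.array([0], dtype=dtype_in)
raw_prediction = np.full(shape=(1, n_classes), fill_value=0.0, dtype=dtype_in)

assert y_true.dtype == raw_prediction.dtype == np.float32
assert raw_prediction.shape == (1, n_classes)
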
@@ -251,7 +252,7 @@ def test_loss_dtype(
             gradient=out2,
             n_threads=n_threads,
         )
-        if out1 is not None and loss.n_classes >= 3:
+        if out1 is not None and loss.is_multiclass:
             out1 = np.empty_like(raw_prediction, dtype=dtype_out)
         loss.gradient_hessian(
             y_true=y_true,
@@ -350,7 +351,7 @@ def test_loss_same_as_C_functions(loss, sample_weight):
 def test_loss_gradients_are_the_same(loss, sample_weight):
     """Test that loss and gradient are the same across different functions.
 
-    Also test that output arguments contain correct result.
+    Also test that output arguments contain correct results.
     """
     y_true, raw_prediction = random_y_true_raw_prediction(
         loss=loss,
@@ -410,7 +411,7 @@ def test_loss_gradients_are_the_same(loss, sample_weight):
     assert np.shares_memory(g3, out_g3)
 
     if hasattr(loss, "gradient_proba"):
-        assert loss.n_classes >= 3  # only for CategoricalCrossEntropy
+        assert loss.is_multiclass  # only for CategoricalCrossEntropy
         out_g4 = np.empty_like(raw_prediction)
         out_proba = np.empty_like(raw_prediction)
         g4, proba = loss.gradient_proba(
