1
- from typing import Callable , Tuple
1
+ from __future__ import annotations
2
+
3
+ from typing import Callable , Sequence , Tuple
2
4
3
5
from pathlib import Path
4
6
12
14
REGRESSION ,
13
15
)
14
16
from autosklearn .data .xy_data_manager import XYDataManager
15
- from autosklearn .metrics import Scorer , accuracy , precision , r2
17
+ from autosklearn .metrics import Scorer , accuracy , log_loss , precision , r2
16
18
from autosklearn .util .logging_ import PicklableClientLogger
17
19
18
20
import pytest
21
23
22
24
23
25
@parametrize (
24
- "dataset, metric , task" ,
26
+ "dataset, metrics , task" ,
25
27
[
26
- ("breast_cancer" , accuracy , BINARY_CLASSIFICATION ),
27
- ("wine" , accuracy , MULTICLASS_CLASSIFICATION ),
28
- ("diabetes" , r2 , REGRESSION ),
28
+ ("breast_cancer" , [accuracy ], BINARY_CLASSIFICATION ),
29
+ ("breast_cancer" , [accuracy , log_loss ], BINARY_CLASSIFICATION ),
30
+ ("wine" , [accuracy ], MULTICLASS_CLASSIFICATION ),
31
+ ("diabetes" , [r2 ], REGRESSION ),
29
32
],
30
33
)
31
34
def test_produces_correct_output (
32
35
dataset : str ,
33
36
task : int ,
34
- metric : Scorer ,
37
+ metrics : Sequence [ Scorer ] ,
35
38
mock_logger : PicklableClientLogger ,
36
39
make_automl : Callable [..., AutoML ],
37
40
make_sklearn_dataset : Callable [..., XYDataManager ],
@@ -45,8 +48,8 @@ def test_produces_correct_output(
45
48
task : int
46
49
The task type of the dataset
47
50
48
- metric: Scorer
49
- Metric to use, required as fit usually determines the metric to use
51
+ metrics: Sequence[ Scorer]
52
+ Metric(s) to use, required as fit usually determines the metric to use
50
53
51
54
Fixtures
52
55
--------
@@ -66,7 +69,7 @@ def test_produces_correct_output(
66
69
* It should produce predictions "predictions_ensemble_1337_1_0.0.npy"
67
70
"""
68
71
seed = 1337
69
- automl = make_automl (metrics = [ metric ] , seed = seed )
72
+ automl = make_automl (metrics = metrics , seed = seed )
70
73
automl ._logger = mock_logger
71
74
72
75
datamanager = make_sklearn_dataset (
0 commit comments