Skip to content

Commit 5f3d4b6

Browse files
committed
autoPyTorch/api/
1 parent 8f8dee1 commit 5f3d4b6

File tree

2 files changed

+3
-5
lines changed

2 files changed

+3
-5
lines changed

test/test_api/test_api.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,6 @@ def test_tabular_input_support(openml_id, backend):
609609
estimator = TabularClassificationTask(
610610
backend=backend,
611611
resampling_strategy=HoldoutValTypes.holdout_validation,
612-
ensemble_size=0,
613612
)
614613

615614
estimator._do_dummy_prediction = unittest.mock.MagicMock()
@@ -624,6 +623,7 @@ def test_tabular_input_support(openml_id, backend):
624623
func_eval_time_limit_secs=50,
625624
enable_traditional_pipeline=False,
626625
load_models=False,
626+
ensemble_size=0,
627627
)
628628

629629

@@ -633,7 +633,6 @@ def test_do_dummy_prediction(dask_client, fit_dictionary_tabular):
633633
estimator = TabularClassificationTask(
634634
backend=backend,
635635
resampling_strategy=HoldoutValTypes.holdout_validation,
636-
ensemble_size=0,
637636
)
638637

639638
# Setup pre-requisites normally set by search()

test/test_api/test_base_api.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def test_set_pipeline_config():
118118
])
119119
def test_pipeline_get_budget(fit_dictionary_tabular, min_budget, max_budget, budget_type, expected):
120120
BaseTask.__abstractmethods__ = set()
121-
estimator = BaseTask(task_type='tabular_classification', ensemble_size=0)
121+
estimator = BaseTask(task_type='tabular_classification')
122122

123123
# Fixture pipeline config
124124
default_pipeline_config = {
@@ -141,7 +141,7 @@ def test_pipeline_get_budget(fit_dictionary_tabular, min_budget, max_budget, bud
141141
smac_mock.return_value = smac
142142
estimator._search(optimize_metric='accuracy', dataset=dataset, tae_func=pipeline_fit,
143143
min_budget=min_budget, max_budget=max_budget, budget_type=budget_type,
144-
enable_traditional_pipeline=False,
144+
ensemble_size=0, enable_traditional_pipeline=False,
145145
total_walltime_limit=20, func_eval_time_limit_secs=10,
146146
load_models=False)
147147
assert list(smac_mock.call_args)[1]['ta_kwargs']['pipeline_config'] == default_pipeline_config
@@ -210,7 +210,6 @@ def test_init_ensemble_builder(backend):
210210
BaseTask.__abstractmethods__ = set()
211211
estimator = BaseTask(
212212
backend=backend,
213-
ensemble_size=0,
214213
)
215214

216215
# Setup pre-requisites normally set by search()

0 commit comments

Comments
 (0)