@@ -418,13 +418,8 @@ def search(
418
418
y_test = y_test ,
419
419
resampling_strategy = self .resampling_strategy ,
420
420
resampling_strategy_args = self .resampling_strategy_args ,
421
- << << << < HEAD
422
421
dataset_name = dataset_name ,
423
422
dataset_compression = self ._dataset_compression )
424
- == == == =
425
- dataset_name = dataset_name
426
- )
427
- >> >> >> > [FIX ] Enable preprocessing in reg_cocktails (#369)
428
423
429
424
return self ._search (
430
425
dataset = self .dataset ,
@@ -465,23 +460,23 @@ def predict(
465
460
raise ValueError ("predict() is only supported after calling search. Kindly call first "
466
461
"the estimator search() method." )
467
462
468
- X_test = self .InputValidator .feature_validator .transform (X_test )
463
+ X_test = self .input_validator .feature_validator .transform (X_test )
469
464
predicted_probabilities = super ().predict (X_test , batch_size = batch_size ,
470
465
n_jobs = n_jobs )
471
466
472
- if self .InputValidator .target_validator .is_single_column_target ():
467
+ if self .input_validator .target_validator .is_single_column_target ():
473
468
predicted_indexes = np .argmax (predicted_probabilities , axis = 1 )
474
469
else :
475
470
predicted_indexes = (predicted_probabilities > 0.5 ).astype (int )
476
471
477
472
# Allow to predict in the original domain -- that is, the user is not interested
478
473
# in our encoded values
479
- return self .InputValidator .target_validator .inverse_transform (predicted_indexes )
474
+ return self .input_validator .target_validator .inverse_transform (predicted_indexes )
480
475
481
476
def predict_proba (self ,
482
477
X_test : Union [np .ndarray , pd .DataFrame , List ],
483
478
batch_size : Optional [int ] = None , n_jobs : int = 1 ) -> np .ndarray :
484
- if self .InputValidator is None or not self .InputValidator ._is_fitted :
479
+ if self .input_validator is None or not self .input_validator ._is_fitted :
485
480
raise ValueError ("predict() is only supported after calling search. Kindly call first "
486
481
"the estimator search() method." )
487
482
X_test = self .input_validator .feature_validator .transform (X_test )
0 commit comments