
Commit 416898b

thomasjpfan and ogrisel authored
ENH Adds Column name consistency (#18010)
Co-authored-by: Olivier Grisel <[email protected]>
1 parent c592361 commit 416898b

16 files changed, +462 -7 lines changed

doc/whats_new/v1.0.rst

Lines changed: 6 additions & 0 deletions

@@ -134,6 +134,12 @@ Changelog
 - |API| `np.matrix` usage is deprecated in 1.0 and will raise a `TypeError` in
   1.2. :pr:`20165` by `Thomas Fan`_.
 
+- |API| All estimators store `feature_names_in_` when fitted on pandas DataFrames.
+  These feature names are compared to names seen in non-`fit` methods, e.g.
+  `transform`, and will raise a `FutureWarning` if they are not consistent.
+  These `FutureWarning`s will become `ValueError`s in 1.2.
+  :pr:`18010` by `Thomas Fan`_.
+
 :mod:`sklearn.base`
 ...................

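To illustrate the behavior this changelog entry describes, a minimal sketch (not part of the commit; `StandardScaler` and the column names are only an example):

    import pandas as pd
    from sklearn.preprocessing import StandardScaler

    X = pd.DataFrame({"a": [1.0, 2.0, 3.0], "b": [10.0, 20.0, 30.0]})

    scaler = StandardScaler().fit(X)
    print(scaler.feature_names_in_)  # ['a' 'b'], recorded at fit time

    # Renaming or reordering the columns at transform time now triggers a
    # FutureWarning; per the changelog, this becomes a ValueError in 1.2.
    scaler.transform(X[["b", "a"]])
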
sklearn/base.py

Lines changed: 89 additions & 0 deletions

@@ -24,6 +24,7 @@
 from .utils.validation import _check_y
 from .utils.validation import _num_features
 from .utils._estimator_html_repr import estimator_html_repr
+from .utils.validation import _get_feature_names
 
 
 def clone(estimator, *, safe=True):
@@ -395,6 +396,92 @@ def _check_n_features(self, X, reset):
                 f"is expecting {self.n_features_in_} features as input."
             )
 
+    def _check_feature_names(self, X, *, reset):
+        """Set or check the `feature_names_in_` attribute.
+
+        .. versionadded:: 1.0
+
+        Parameters
+        ----------
+        X : {ndarray, dataframe} of shape (n_samples, n_features)
+            The input samples.
+
+        reset : bool
+            Whether to reset the `feature_names_in_` attribute.
+            If False, the input will be checked for consistency with
+            feature names of data provided when reset was last True.
+            .. note::
+               It is recommended to call `reset=True` in `fit` and in the first
+               call to `partial_fit`. All other methods that validate `X`
+               should set `reset=False`.
+        """
+
+        if reset:
+            feature_names_in = _get_feature_names(X)
+            if feature_names_in is not None:
+                self.feature_names_in_ = feature_names_in
+            return
+
+        fitted_feature_names = getattr(self, "feature_names_in_", None)
+        X_feature_names = _get_feature_names(X)
+
+        if fitted_feature_names is None and X_feature_names is None:
+            # no feature names seen in fit and in X
+            return
+
+        if X_feature_names is not None and fitted_feature_names is None:
+            warnings.warn(
+                f"X has feature names, but {self.__class__.__name__} was fitted without"
+                " feature names"
+            )
+            return
+
+        if X_feature_names is None and fitted_feature_names is not None:
+            warnings.warn(
+                "X does not have valid feature names, but"
+                f" {self.__class__.__name__} was fitted with feature names"
+            )
+            return
+
+        # validate the feature names against the `feature_names_in_` attribute
+        if len(fitted_feature_names) != len(X_feature_names) or np.any(
+            fitted_feature_names != X_feature_names
+        ):
+            message = (
+                "The feature names should match those that were "
+                "passed during fit. Starting version 1.2, an error will be raised.\n"
+            )
+            fitted_feature_names_set = set(fitted_feature_names)
+            X_feature_names_set = set(X_feature_names)
+
+            unexpected_names = sorted(X_feature_names_set - fitted_feature_names_set)
+            missing_names = sorted(fitted_feature_names_set - X_feature_names_set)
+
+            def add_names(names):
+                output = ""
+                max_n_names = 5
+                for i, name in enumerate(names):
+                    if i >= max_n_names:
+                        output += "- ...\n"
+                        break
+                    output += f"- {name}\n"
+                return output
+
+            if unexpected_names:
+                message += "Feature names unseen at fit time:\n"
+                message += add_names(unexpected_names)
+
+            if missing_names:
+                message += "Feature names seen at fit time, yet now missing:\n"
+                message += add_names(missing_names)
+
+            if not missing_names and not unexpected_names:
+                message += (
+                    "Feature names must be in the same order as they were in fit.\n"
+                )
+
+            warnings.warn(message, FutureWarning)
+
     def _validate_data(
         self,
         X="no_validation",
@@ -452,6 +539,8 @@ def _validate_data(
             The validated input. A tuple is returned if both `X` and `y` are
             validated.
         """
+        self._check_feature_names(X, reset=reset)
+
         if y is None and self._get_tags()["requires_y"]:
             raise ValueError(
                 f"This {self.__class__.__name__} estimator "

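Since `_validate_data` now calls `_check_feature_names`, any estimator that routes its input through `_validate_data` picks up the check automatically. A minimal sketch (illustrative only, not part of the commit):

    import numpy as np
    import pandas as pd
    from sklearn.base import BaseEstimator, TransformerMixin

    class PassthroughTransformer(TransformerMixin, BaseEstimator):
        def fit(self, X, y=None):
            # reset=True (the default) records feature_names_in_ for DataFrame input
            self._validate_data(X)
            return self

        def transform(self, X):
            # reset=False compares X's column names against feature_names_in_
            return self._validate_data(X, reset=False)

    df = pd.DataFrame(np.ones((3, 2)), columns=["f1", "f2"])
    t = PassthroughTransformer().fit(df)
    t.transform(df)                                  # silent: names match
    t.transform(df.rename(columns={"f2": "oops"}))   # FutureWarning: names differ
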
sklearn/calibration.py

Lines changed: 2 additions & 0 deletions

@@ -368,6 +368,8 @@ def fit(self, X, y, sample_weight=None):
         first_clf = self.calibrated_classifiers_[0].base_estimator
         if hasattr(first_clf, "n_features_in_"):
             self.n_features_in_ = first_clf.n_features_in_
+        if hasattr(first_clf, "feature_names_in_"):
+            self.feature_names_in_ = first_clf.feature_names_in_
         return self
 
     def predict_proba(self, X):

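CalibratedClassifierCV (like several other meta-estimators touched below) copies `feature_names_in_` from its fitted base estimator rather than validating `X` itself. A sketch of the intended effect (illustrative only; the dataset and base estimator are not part of the commit):

    import pandas as pd
    from sklearn.calibration import CalibratedClassifierCV
    from sklearn.datasets import load_iris
    from sklearn.linear_model import LogisticRegression

    iris = load_iris(as_frame=True)
    X, y = iris.data, iris.target

    clf = CalibratedClassifierCV(LogisticRegression(max_iter=1000), cv=3).fit(X, y)
    # Column names are copied from the first fitted base classifier.
    print(clf.feature_names_in_)
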
sklearn/feature_selection/_from_model.py

Lines changed: 2 additions & 0 deletions

@@ -257,6 +257,8 @@ def fit(self, X, y=None, **fit_params):
             raise NotFittedError("Since 'prefit=True', call transform directly")
         self.estimator_ = clone(self.estimator)
         self.estimator_.fit(X, y, **fit_params)
+        if hasattr(self.estimator_, "feature_names_in_"):
+            self.feature_names_in_ = self.estimator_.feature_names_in_
         return self
 
     @property

sklearn/kernel_approximation.py

Lines changed: 4 additions & 4 deletions

@@ -21,7 +21,7 @@
 
 from .base import BaseEstimator
 from .base import TransformerMixin
-from .utils import check_random_state, as_float_array
+from .utils import check_random_state
 from .utils.extmath import safe_sparse_dot
 from .utils.validation import check_is_fitted
 from .metrics.pairwise import pairwise_kernels, KERNEL_PARAMS
@@ -469,9 +469,9 @@ def transform(self, X):
             Returns the instance itself.
         """
         check_is_fitted(self)
-
-        X = as_float_array(X, copy=True)
-        X = self._validate_data(X, copy=False, reset=False)
+        X = self._validate_data(
+            X, copy=True, dtype=[np.float64, np.float32], reset=False
+        )
         if (X <= -self.skewedness).any():
             raise ValueError("X may not contain entries smaller than -skewedness.")

sklearn/linear_model/_ransac.py

Lines changed: 2 additions & 0 deletions

@@ -556,6 +556,7 @@ def predict(self, X):
             Returns predicted values.
         """
         check_is_fitted(self)
+        self._check_feature_names(X, reset=False)
 
         return self.estimator_.predict(X)
 
@@ -578,6 +579,7 @@ def score(self, X, y):
             Score of the prediction.
         """
         check_is_fitted(self)
+        self._check_feature_names(X, reset=False)
 
         return self.estimator_.score(X, y)

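RANSACRegressor.predict and .score delegate straight to the fitted sub-estimator without calling `_validate_data`, so the commit invokes `_check_feature_names` explicitly to keep the name check. A sketch (illustrative data, not part of the commit):

    import numpy as np
    import pandas as pd
    from sklearn.linear_model import RANSACRegressor

    X = pd.DataFrame({"x1": np.arange(20.0), "x2": 2 * np.arange(20.0)})
    y = 3 * X["x1"] + 1

    reg = RANSACRegressor(random_state=0).fit(X, y)
    reg.predict(X)                # silent: column names match those seen in fit
    reg.predict(X[["x2", "x1"]])  # FutureWarning: column order differs from fit
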
sklearn/linear_model/_ridge.py

Lines changed: 2 additions & 0 deletions

@@ -1983,6 +1983,8 @@ def fit(self, X, y, sample_weight=None):
         self.coef_ = estimator.coef_
         self.intercept_ = estimator.intercept_
         self.n_features_in_ = estimator.n_features_in_
+        if hasattr(estimator, "feature_names_in_"):
+            self.feature_names_in_ = estimator.feature_names_in_
 
         return self

sklearn/manifold/_isomap.py

Lines changed: 2 additions & 0 deletions

@@ -172,6 +172,8 @@ def _fit_transform(self, X):
         )
         self.nbrs_.fit(X)
         self.n_features_in_ = self.nbrs_.n_features_in_
+        if hasattr(self.nbrs_, "feature_names_in_"):
+            self.feature_names_in_ = self.nbrs_.feature_names_in_
 
         self.kernel_pca_ = KernelPCA(
             n_components=self.n_components,

sklearn/manifold/_locally_linear.py

Lines changed: 1 addition & 1 deletion

@@ -768,7 +768,7 @@ def transform(self, X):
         """
         check_is_fitted(self)
 
-        X = check_array(X)
+        X = self._validate_data(X, reset=False)
         ind = self.nbrs_.kneighbors(
             X, n_neighbors=self.n_neighbors, return_distance=False
         )

sklearn/neural_network/_rbm.py

Lines changed: 1 addition & 2 deletions

@@ -15,7 +15,6 @@
 
 from ..base import BaseEstimator
 from ..base import TransformerMixin
-from ..utils import check_array
 from ..utils import check_random_state
 from ..utils import gen_even_slices
 from ..utils.extmath import safe_sparse_dot
@@ -333,7 +332,7 @@ def score_samples(self, X):
         """
         check_is_fitted(self)
 
-        v = check_array(X, accept_sparse="csr")
+        v = self._validate_data(X, accept_sparse="csr", reset=False)
         rng = check_random_state(self.random_state)
 
         # Randomly corrupt one feature in each sample in v.

sklearn/tests/test_base.py

Lines changed: 71 additions & 0 deletions

@@ -1,6 +1,7 @@
 # Author: Gael Varoquaux
 # License: BSD 3 clause
 
+import re
 import numpy as np
 import scipy.sparse as sp
 import pytest
@@ -615,3 +616,73 @@ def test_n_features_in_no_validation():
 
     # does not raise
     est._check_n_features("invalid X", reset=False)
+
+
+def test_feature_names_in():
+    """Check that feature_names_in_ is recorded by `_validate_data`."""
+    pd = pytest.importorskip("pandas")
+    iris = datasets.load_iris()
+    X_np = iris.data
+    df = pd.DataFrame(X_np, columns=iris.feature_names)
+
+    class NoOpTransformer(TransformerMixin, BaseEstimator):
+        def fit(self, X, y=None):
+            self._validate_data(X)
+            return self
+
+        def transform(self, X):
+            self._validate_data(X, reset=False)
+            return X
+
+    # fit on dataframe saves the feature names
+    trans = NoOpTransformer().fit(df)
+    assert_array_equal(trans.feature_names_in_, df.columns)
+
+    msg = "The feature names should match those that were passed"
+    df_bad = pd.DataFrame(X_np, columns=iris.feature_names[::-1])
+    with pytest.warns(FutureWarning, match=msg):
+        trans.transform(df_bad)
+
+    # warns when fitted on dataframe and transforming a ndarray
+    msg = (
+        "X does not have valid feature names, but NoOpTransformer was "
+        "fitted with feature names"
+    )
+    with pytest.warns(UserWarning, match=msg):
+        trans.transform(X_np)
+
+    # warns when fitted on a ndarray and transforming dataframe
+    msg = "X has feature names, but NoOpTransformer was fitted without feature names"
+    trans = NoOpTransformer().fit(X_np)
+    with pytest.warns(UserWarning, match=msg):
+        trans.transform(df)
+
+    # fit on dataframe with all integer feature names works without warning
+    df_int_names = pd.DataFrame(X_np)
+    trans = NoOpTransformer()
+    with pytest.warns(None) as record:
+        trans.fit(df_int_names)
+    assert not record
+
+    # fit on dataframe with no feature names or all integer feature names
+    # -> do not warn on transform
+    Xs = [X_np, df_int_names]
+    for X in Xs:
+        with pytest.warns(None) as record:
+            trans.transform(X)
+        assert not record
+
+    # TODO: Convert to an error in 1.2
+    # fit on dataframe with mixed feature names warns:
+    df_mixed = pd.DataFrame(X_np, columns=["a", "b", 1, 2])
+    trans = NoOpTransformer()
+    msg = re.escape(
+        "Feature names only support names that are all strings. "
+        "Got feature names with dtypes: ['int', 'str']"
+    )
+    with pytest.warns(FutureWarning, match=msg) as record:
+        trans.fit(df_mixed)
+
+    # transform on mixed feature names also warns:
+    with pytest.warns(FutureWarning, match=msg) as record:
+        trans.transform(df_mixed)

sklearn/tests/test_common.py

Lines changed: 39 additions & 0 deletions

@@ -47,6 +47,7 @@
     _get_check_estimator_ids,
     check_class_weight_balanced_linear_classifier,
     parametrize_with_checks,
+    check_dataframe_column_names_consistency,
    check_n_features_in_after_fitting,
 )
 
@@ -313,3 +314,41 @@ def test_search_cv(estimator, check, request):
 def test_check_n_features_in_after_fitting(estimator):
     _set_checking_parameters(estimator)
     check_n_features_in_after_fitting(estimator.__class__.__name__, estimator)
+
+
+# TODO: As more modules gain support, remove them from this list to make
+# sure they get tested. After we finish each module we can move the checks
+# into check_estimator.
+# NOTE: When running `check_dataframe_column_names_consistency` on a
+# meta-estimator that delegates validation to a base estimator, the check
+# tests that the base estimator is checking for column name consistency.
+
+COLUMN_NAME_MODULES_TO_IGNORE = {
+    "compose",
+    "ensemble",
+    "feature_extraction",
+    "kernel_approximation",
+    "model_selection",
+    "multiclass",
+    "multioutput",
+    "pipeline",
+    "semi_supervised",
+}
+
+
+column_name_estimators = [
+    est
+    for est in _tested_estimators()
+    if est.__module__.split(".")[1] not in COLUMN_NAME_MODULES_TO_IGNORE
+]
+
+
+@pytest.mark.parametrize(
+    "estimator", column_name_estimators, ids=_get_check_estimator_ids
+)
+def test_pandas_column_name_consistency(estimator):
+    _set_checking_parameters(estimator)
+    with ignore_warnings(category=(FutureWarning)):
+        check_dataframe_column_names_consistency(
+            estimator.__class__.__name__, estimator
+        )

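The new common check can also be run on a single estimator. A hypothetical standalone invocation, assuming `check_dataframe_column_names_consistency` is importable from `sklearn.utils.estimator_checks` as the import block above suggests:

    from sklearn.linear_model import LogisticRegression
    from sklearn.utils.estimator_checks import check_dataframe_column_names_consistency

    # Exercises the column-name consistency scenarios against a fresh estimator.
    check_dataframe_column_names_consistency("LogisticRegression", LogisticRegression())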