diff --git a/sklearn_pandas/dataframe_mapper.py b/sklearn_pandas/dataframe_mapper.py
index 2f4c364..690499d 100644
--- a/sklearn_pandas/dataframe_mapper.py
+++ b/sklearn_pandas/dataframe_mapper.py
@@ -33,7 +33,7 @@ class DataFrameMapper(BaseEstimator, TransformerMixin):
         sklearn transformation.
     """
 
-    def __init__(self, features, default=False, sparse=False):
+    def __init__(self, features, default=False, sparse=False, df_out=False):
         """
         Params:
 
@@ -50,6 +50,13 @@ def __init__(self, features, default=False, sparse=False):
 
         sparse      will return sparse matrix if set True and any of the
                     extracted features is sparse. Defaults to False.
+
+        df_out      return a pandas data frame, with each column named using
+                    the pandas column that created it (if there's only one
+                    input and output) or the input columns joined with '_'
+                    if there's multiple inputs, and the name concatenated with
+                    '_1', '_2' etc if there's multiple outputs. NB: does not
+                    work if *default* or *sparse* are true
         """
         if isinstance(features, list):
             features = [(columns, _build_transformer(transformers))
@@ -57,6 +64,9 @@
         self.features = features
         self.default = _build_transformer(default)
         self.sparse = sparse
+        self.df_out = df_out
+        if (df_out and (sparse or default)):
+            raise ValueError("Can not use df_out with sparse or default")
 
     @property
     def _selected_columns(self):
@@ -94,6 +104,7 @@ def __setstate__(self, state):
         # compatibility shim for pickles created before ``default`` init
         # argument existed
         self.default = state.get('default', False)
+        self.df_out = state.get('df_out', False)
 
     def _get_col_subset(self, X, cols):
         """
@@ -145,6 +156,18 @@ def fit(self, X, y=None):
                 self._get_col_subset(X, self._unselected_columns(X)), y)
         return self
 
+
+    def get_names(self, c, t, x):
+        if type(c)==list:
+            c = '_'.join(c)
+        if hasattr(t, 'classes_') and (len(t.classes_)>2):
+            return [c + '_' + o for o in t.classes_]
+        elif len(x.shape)>1 and x.shape[1]>1:
+            return [c + '_' + str(o) for o in range(x.shape[1])]
+        else:
+            return [c]
+
+
     def transform(self, X):
         """
         Transform the given data. Assumes that fit has already been called.
@@ -152,6 +175,7 @@
         X       the data to transform
         """
         extracted = []
+        index = []
         for columns, transformers in self.features:
             # columns could be a string or list of
             # strings; we don't care because pandas
@@ -160,10 +184,13 @@
             if transformers is not None:
                 Xt = transformers.transform(Xt)
             extracted.append(_handle_feature(Xt))
+            if self.df_out:
+                index = index + self.get_names(columns, transformers, Xt)
 
         # handle features not explicitly selected
         if self.default is not False:
-            Xt = self._get_col_subset(X, self._unselected_columns(X))
+            unsel_cols = self._unselected_columns(X)
+            Xt = self._get_col_subset(X, unsel_cols)
             if self.default is not None:
                 Xt = self.default.transform(Xt)
             extracted.append(_handle_feature(Xt))
@@ -185,4 +212,7 @@
         else:
             stacked = np.hstack(extracted)
 
-        return stacked
+        if not self.df_out:
+            return stacked
+
+        return pd.DataFrame(stacked, columns=index)
diff --git a/tests/test_dataframe_mapper.py b/tests/test_dataframe_mapper.py
index 08454fb..8190a14 100644
--- a/tests/test_dataframe_mapper.py
+++ b/tests/test_dataframe_mapper.py
@@ -17,9 +17,10 @@
 from sklearn.pipeline import Pipeline
 from sklearn.svm import SVC
 from sklearn.feature_extraction.text import CountVectorizer
-from sklearn.preprocessing import Imputer, StandardScaler, OneHotEncoder
+from sklearn.preprocessing import Imputer, StandardScaler, OneHotEncoder, LabelBinarizer
 from sklearn.feature_selection import SelectKBest, chi2
 from sklearn.base import BaseEstimator, TransformerMixin
+import sklearn.decomposition
 import numpy as np
 from numpy.testing import assert_array_equal
 import pickle
@@ -77,6 +78,81 @@
                          'feat2': [1, 2, 3, 2, 3, 4]})
 
 
+def test_simple_df(simple_dataframe):
+    """
+    Get a dataframe from a simple mapped dataframe
+    """
+    df = simple_dataframe
+    mapper = DataFrameMapper([('a', None)], df_out=True)
+    transformed = mapper.fit_transform(df)
+    assert type(transformed) == pd.DataFrame
+    assert len(transformed["a"]) == len(simple_dataframe["a"])
+
+
+def test_complex_df(complex_dataframe):
+    """
+    Get a dataframe from a complex mapped dataframe
+    """
+    df = complex_dataframe
+    mapper = DataFrameMapper([('target', None), ('feat1', None), ('feat2', None)], df_out=True)
+    transformed = mapper.fit_transform(df)
+    assert len(transformed) == len(complex_dataframe)
+    for c in df.columns:
+        assert len(transformed[c]) == len(df[c])
+
+
+def test_binarizer_df():
+    """
+    Check level names from LabelBinarizer
+    """
+    df = pd.DataFrame({'target': ['a', 'a', 'b', 'b', 'c', 'a']})
+    mapper = DataFrameMapper([('target', LabelBinarizer())], df_out=True)
+    transformed = mapper.fit_transform(df)
+    cols = transformed.columns
+    assert len(cols) == 3
+    assert cols[0] == 'target_a'
+    assert cols[1] == 'target_b'
+    assert cols[2] == 'target_c'
+
+
+def test_binarizer2_df():
+    """
+    Check level names from LabelBinarizer with just one output column
+    """
+    df = pd.DataFrame({'target': ['a', 'a', 'b', 'b', 'a']})
+    mapper = DataFrameMapper([('target', LabelBinarizer())], df_out=True)
+    transformed = mapper.fit_transform(df)
+    cols = transformed.columns
+    assert len(cols) == 1
+    assert cols[0] == 'target'
+
+
+def test_onehot_df():
+    """
+    Check level ids from one-hot
+    """
+    df = pd.DataFrame({'target': [0, 0, 1, 1, 2, 3, 0]})
+    mapper = DataFrameMapper([(['target'], OneHotEncoder())], df_out=True)
+    transformed = mapper.fit_transform(df)
+    cols = transformed.columns
+    assert len(cols) == 4
+    assert cols[0] == 'target_0'
+    assert cols[3] == 'target_3'
+
+
+def test_pca(complex_dataframe):
+    """
+    Check multi in and out with PCA
+    """
+    df = complex_dataframe
+    mapper = DataFrameMapper([(['feat1', 'feat2'], sklearn.decomposition.PCA(2))], df_out=True)
+    transformed = mapper.fit_transform(df)
+    cols = transformed.columns
+    assert len(cols) == 2
+    assert cols[0] == 'feat1_feat2_0'
+    assert cols[1] == 'feat1_feat2_1'
+
+
 def test_nonexistent_columns_explicit_fail(simple_dataframe):
     """
     If a nonexistent column is selected, KeyError is raised.