diff --git a/README.rst b/README.rst index fa21005..4dc237a 100644 --- a/README.rst +++ b/README.rst @@ -415,6 +415,11 @@ Development with values other than the mode. (#144) +1.6.1 (2018-04-16) +****************** +* Preserve input data types when no transform is supplied (#138) + + 1.6.0 (2017-10-28) ****************** * Add column name to exception during fit/transform (#110). @@ -504,3 +509,4 @@ Other contributors: * Ritesh Agrawal (@ragrawal) * Vitaley Zaretskey (@vzaretsk) * Zac Stewart (@zacstewart) +* Timothy Sweetser (@hacktuarial) diff --git a/sklearn_pandas/__init__.py b/sklearn_pandas/__init__.py index 6d33c3e..fe3d189 100644 --- a/sklearn_pandas/__init__.py +++ b/sklearn_pandas/__init__.py @@ -1,4 +1,4 @@ -__version__ = '1.6.0' +__version__ = '1.6.1' from .dataframe_mapper import DataFrameMapper # NOQA from .cross_validation import cross_val_score, GridSearchCV, RandomizedSearchCV # NOQA diff --git a/sklearn_pandas/dataframe_mapper.py b/sklearn_pandas/dataframe_mapper.py index 596aa76..2c92d94 100644 --- a/sklearn_pandas/dataframe_mapper.py +++ b/sklearn_pandas/dataframe_mapper.py @@ -315,6 +315,15 @@ def transform(self, X): stacked = np.hstack(extracted) if self.df_out: + # output different data types, if appropriate + dtypes = [] + for ex in extracted: + if isinstance(ex, np.ndarray) or sparse.issparse(ex): + dtypes += [ex.dtype] * ex.shape[1] + elif isinstance(ex, pd.DataFrame): + dtypes += list(ex.dtypes) + else: + raise TypeError(type(ex)) # if no rows were dropped preserve the original index, # otherwise use a new integer one no_rows_dropped = len(X) == len(stacked) @@ -323,8 +332,10 @@ def transform(self, X): else: index = None - return pd.DataFrame(stacked, - columns=self.transformed_names_, - index=index) + df_out = pd.DataFrame(dict(zip( + self.transformed_names_, + [pd.Series(stacked[:, i], index=index, dtype=dtypes[i]) + for i in range(stacked.shape[1])]))) + return df_out[self.transformed_names_] # preserve order else: return stacked diff 
--git a/tests/test_dataframe_mapper.py b/tests/test_dataframe_mapper.py index 75da4fd..e55b678 100644 --- a/tests/test_dataframe_mapper.py +++ b/tests/test_dataframe_mapper.py @@ -829,3 +829,17 @@ def test_direct_cross_validation(iris_dataframe): scores = sklearn_cv_score(pipeline, data, labels) assert scores.mean() > 0.96 assert (scores.std() * 2) < 0.04 + + +def test_heterogeneous_output_types_input_df(complex_dataframe): + """ + Scale feat1, but pass feat2 and target (which have different dtypes) + through unmodified. This test fails if input_df == False. + """ + complex_dataframe['feat1'] = complex_dataframe['feat1'].astype(float) + M = DataFrameMapper([ + (['feat1'], StandardScaler()) + ], input_df=True, df_out=True, default=None) + expected_dtypes = complex_dataframe.dtypes + actual_dtypes = M.fit_transform(complex_dataframe).dtypes + assert (expected_dtypes == actual_dtypes).all()