Skip to content

Commit 60fd054

Browse files
authored
Merge pull request #74 from arnau126/master
Fix potential bugs in 'get_names'.
2 parents 4a63261 + 0529b64 commit 60fd054

File tree

2 files changed

+62
-5
lines changed

2 files changed

+62
-5
lines changed

sklearn_pandas/dataframe_mapper.py

+17-5
Original file line numberDiff line numberDiff line change
@@ -158,12 +158,24 @@ def fit(self, X, y=None):
158158

159159

160160
def get_names(self, c, t, x):
161-
if type(c)==list:
161+
"""
162+
Return verbose names for the transformed columns.
163+
164+
c name (or list of names) of the original column(s)
165+
t transformer
166+
x transformed columns (numpy.ndarray)
167+
"""
168+
if isinstance(c, list):
162169
c = '_'.join(c)
163-
if hasattr(t, 'classes_') and (len(t.classes_)>2):
164-
return [c + '_' + o for o in t.classes_]
165-
elif len(x.shape)>1 and x.shape[1]>1:
166-
return [c + '_' + str(o) for o in range(x.shape[1])]
170+
num_cols = x.shape[1] if len(x.shape) > 1 else 1
171+
if num_cols > 1:
172+
# If there are as many columns as classes,
173+
# infer column names from classes names.
174+
if hasattr(t, 'classes_') and (len(t.classes_) == num_cols):
175+
return [c + '_' + str(o) for o in t.classes_]
176+
# otherwise, return name concatenated with '_1', '_2', etc.
177+
else:
178+
return [c + '_' + str(o) for o in range(num_cols)]
167179
else:
168180
return [c]
169181

tests/test_dataframe_mapper.py

+45
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,23 @@ def transform(self, X):
6767
return sparse.csr_matrix(X)
6868

6969

70+
class CustomTransformer(BaseEstimator, TransformerMixin):
71+
"""
72+
Example of transformer in which the number of classes
73+
is not equals to the number of output columns.
74+
"""
75+
def fit(self, X, y=None):
76+
self.min = X.min()
77+
self.classes_ = np.unique(X)
78+
return self
79+
80+
def transform(self, X):
81+
classes = np.unique(X)
82+
if len(np.setdiff1d(classes, self.classes_)) > 0:
83+
raise ValueError('Unknown values found.')
84+
return X - self.min
85+
86+
7087
@pytest.fixture
7188
def simple_dataframe():
7289
return pd.DataFrame({'a': [1, 2, 3]})
@@ -118,6 +135,20 @@ def test_binarizer_df():
118135
assert cols[2] == 'target_c'
119136

120137

138+
def test_binarizer_int_df():
139+
"""
140+
Check level names from LabelBinarizer for a numeric array.
141+
"""
142+
df = pd.DataFrame({'target': [5, 5, 6, 6, 7, 5]})
143+
mapper = DataFrameMapper([('target', LabelBinarizer())], df_out=True)
144+
transformed = mapper.fit_transform(df)
145+
cols = transformed.columns
146+
assert len(cols) == 3
147+
assert cols[0] == 'target_5'
148+
assert cols[1] == 'target_6'
149+
assert cols[2] == 'target_7'
150+
151+
121152
def test_binarizer2_df():
122153
"""
123154
Check level names from LabelBinarizer with just one output column
@@ -143,6 +174,20 @@ def test_onehot_df():
143174
assert cols[3] == 'target_3'
144175

145176

177+
def test_customtransform_df():
178+
"""
179+
Check level ids from a transformer in which
180+
the number of classes is not equals to the number of output columns.
181+
"""
182+
df = pd.DataFrame({'target': [6, 5, 7, 5, 4, 8, 8]})
183+
mapper = DataFrameMapper([(['target'], CustomTransformer())], df_out=True)
184+
transformed = mapper.fit_transform(df)
185+
cols = transformed.columns
186+
assert len(mapper.features[0][1].classes_) == 5
187+
assert len(cols) == 1
188+
assert cols[0] == 'target'
189+
190+
146191
def test_pca(complex_dataframe):
147192
"""
148193
Check multi in and out with PCA

0 commit comments

Comments
 (0)