@@ -67,6 +67,23 @@ def transform(self, X):
67
67
return sparse .csr_matrix (X )
68
68
69
69
70
class CustomTransformer(BaseEstimator, TransformerMixin):
    """
    Example of a transformer in which the number of classes
    is not equal to the number of output columns.
    """

    def fit(self, X, y=None):
        """
        Record the distinct values of X (``classes_``) and the result
        of ``X.min()`` (``min``). ``y`` is accepted but not used.

        Returns the fitted transformer itself.
        """
        self.min = X.min()
        self.classes_ = np.unique(X)
        return self

    def transform(self, X):
        """
        Return ``X`` shifted down by the minimum stored during ``fit``.

        Raises ValueError when X holds a value that is absent from
        ``classes_``.
        """
        # Values present in X but never seen while fitting.
        unseen = np.setdiff1d(np.unique(X), self.classes_)
        if len(unseen) != 0:
            raise ValueError('Unknown values found.')
        shifted = X - self.min
        return shifted
85
+
86
+
70
87
@pytest.fixture
def simple_dataframe():
    """
    Fixture: a DataFrame with a single column 'a' holding 1, 2, 3.
    """
    column_data = {'a': [1, 2, 3]}
    return pd.DataFrame(column_data)
@@ -118,6 +135,20 @@ def test_binarizer_df():
118
135
assert cols [2 ] == 'target_c'
119
136
120
137
138
def test_binarizer_int_df():
    """
    Check level names from LabelBinarizer for a numeric array.
    """
    frame = pd.DataFrame({'target': [5, 5, 6, 6, 7, 5]})
    feature_defs = [('target', LabelBinarizer())]
    binarizing_mapper = DataFrameMapper(feature_defs, df_out=True)
    out_columns = binarizing_mapper.fit_transform(frame).columns

    # One output column is expected per distinct integer in 'target'.
    expected_names = ('target_5', 'target_6', 'target_7')
    assert len(out_columns) == 3
    for position, name in enumerate(expected_names):
        assert out_columns[position] == name
150
+
151
+
121
152
def test_binarizer2_df ():
122
153
"""
123
154
Check level names from LabelBinarizer with just one output column
@@ -143,6 +174,20 @@ def test_onehot_df():
143
174
assert cols [3 ] == 'target_3'
144
175
145
176
177
def test_customtransform_df():
    """
    Check the output column names for a transformer in which
    the number of classes is not equal to the number of output columns.
    """
    frame = pd.DataFrame({'target': [6, 5, 7, 5, 4, 8, 8]})
    feature_defs = [(['target'], CustomTransformer())]
    mapper = DataFrameMapper(feature_defs, df_out=True)
    out_columns = mapper.fit_transform(frame).columns

    # The input has five distinct values (4..8) but the transformer
    # emits a single column, so the name must not get a class suffix.
    fitted_transformer = mapper.features[0][1]
    assert len(fitted_transformer.classes_) == 5
    assert len(out_columns) == 1
    assert out_columns[0] == 'target'
189
+
190
+
146
191
def test_pca (complex_dataframe ):
147
192
"""
148
193
Check multi in and out with PCA
0 commit comments