Skip to content

Commit c39c7c5

Browse files
committed
[refactor] Separate errors for unsupported columns as a func
1 parent 048656e commit c39c7c5

File tree

2 files changed

+31
-47
lines changed

2 files changed

+31
-47
lines changed

autoPyTorch/data/tabular_feature_validator.py

Lines changed: 28 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,28 @@ def get_tabular_preprocessors() -> Dict[str, List[BaseEstimator]]:
7474
return preprocessors
7575

7676

77+
def _error_due_to_unsupported_column(X: pd.DataFrame, column: str) -> None:
78+
# Move away from np.issubdtype as it causes
79+
# TypeError: data type not understood in certain pandas types
80+
def _generate_error_message_prefix(type_name: str, proc_type: Optional[str] = None) -> str:
81+
msg1 = f"column `{column}` has an invalid type `{type_name}`. "
82+
msg2 = "Cast it to a numerical type, category type or bool type by astype method. "
83+
msg3 = f"The following link might help you to know {proc_type} processing: "
84+
return msg1 + msg2 + ("" if proc_type is None else msg3)
85+
86+
dtype = X[column].dtype
87+
if dtype.name == 'object':
88+
err_msg = _generate_error_message_prefix(type_name="object", proc_type="string")
89+
url = "https://scikit-learn.org/stable/tutorial/text_analytics/working_with_text_data.html"
90+
raise TypeError(f"{err_msg}{url}")
91+
elif pd.core.dtypes.common.is_datetime_or_timedelta_dtype(dtype):
92+
err_msg = _generate_error_message_prefix(type_name="time and/or date datatype", proc_type="datetime")
93+
raise TypeError(f"{err_msg}https://stats.stackexchange.com/questions/311494/")
94+
else:
95+
err_msg = _generate_error_message_prefix(type_name=dtype.name)
96+
raise TypeError(err_msg)
97+
98+
7799
class TabularFeatureValidator(BaseFeatureValidator):
78100
"""
79101
A subclass of `BaseFeatureValidator` made for tabular data.
@@ -428,51 +450,15 @@ def _get_columns_to_encode(
428450
feat_type = []
429451

430452
# Make sure each column is a valid type
431-
for i, column in enumerate(X.columns):
432-
if X[column].dtype.name in ['category', 'bool']:
433-
453+
for dtype, column in zip(X.dtypes, X.columns):
454+
if dtype.name in ['category', 'bool']:
434455
transformed_columns.append(column)
435456
feat_type.append('categorical')
436-
# Move away from np.issubdtype as it causes
437-
# TypeError: data type not understood in certain pandas types
438-
elif not is_numeric_dtype(X[column]):
439-
if X[column].dtype.name == 'object':
440-
raise ValueError(
441-
"Input Column {} has invalid type object. "
442-
"Cast it to a valid dtype before using it in AutoPyTorch. "
443-
"Valid types are numerical, categorical or boolean. "
444-
"You can cast it to a valid dtype using "
445-
"pandas.Series.astype ."
446-
"If working with string objects, the following "
447-
"tutorial illustrates how to work with text data: "
448-
"https://scikit-learn.org/stable/tutorial/text_analytics/working_with_text_data.html".format(
449-
# noqa: E501
450-
column,
451-
)
452-
)
453-
elif pd.core.dtypes.common.is_datetime_or_timedelta_dtype(
454-
X[column].dtype
455-
):
456-
raise ValueError(
457-
"AutoPyTorch does not support time and/or date datatype as given "
458-
"in column {}. Please convert the time information to a numerical value "
459-
"first. One example on how to do this can be found on "
460-
"https://stats.stackexchange.com/questions/311494/".format(
461-
column,
462-
)
463-
)
464-
else:
465-
raise ValueError(
466-
"Input Column {} has unsupported dtype {}. "
467-
"Supported column types are categorical/bool/numerical dtypes. "
468-
"Make sure your data is formatted in a correct way, "
469-
"before feeding it to AutoPyTorch.".format(
470-
column,
471-
X[column].dtype.name,
472-
)
473-
)
474-
else:
457+
elif is_numeric_dtype(dtype):
475458
feat_type.append('numerical')
459+
else:
460+
_error_due_to_unsupported_column(X, column)
461+
476462
return transformed_columns, feat_type
477463

478464
def list_to_dataframe(

test/test_data/test_feature_validator.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -328,13 +328,11 @@ def test_features_unsupported_calls_are_raised():
328328
expected
329329
"""
330330
validator = TabularFeatureValidator()
331-
with pytest.raises(ValueError, match=r"AutoPyTorch does not support time"):
332-
validator.fit(
333-
pd.DataFrame({'datetime': [pd.Timestamp('20180310')]})
334-
)
331+
with pytest.raises(TypeError, match=r"invalid type `time and/or date datatype`."):
332+
validator.fit(pd.DataFrame({'datetime': [pd.Timestamp('20180310')]}))
335333
with pytest.raises(ValueError, match=r"AutoPyTorch only supports.*yet, the provided input"):
336334
validator.fit({'input1': 1, 'input2': 2})
337-
with pytest.raises(ValueError, match=r"has unsupported dtype string"):
335+
with pytest.raises(TypeError, match=r"invalid type `string`."):
338336
validator.fit(pd.DataFrame([{'A': 1, 'B': 2}], dtype='string'))
339337
with pytest.raises(ValueError, match=r"The feature dimensionality of the train and test"):
340338
validator.fit(X_train=np.array([[1, 2, 3], [4, 5, 6]]),

0 commit comments

Comments
 (0)