diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index 18f1bba6ebfb8..e27c519304e2e 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -24,7 +24,7 @@ import numpy as np -from pandas._libs import lib, missing as libmissing, tslib +from pandas._libs import lib, tslib from pandas._libs.tslibs import ( NaT, OutOfBoundsDatetime, @@ -86,7 +86,12 @@ ABCSeries, ) from pandas.core.dtypes.inference import is_list_like -from pandas.core.dtypes.missing import is_valid_na_for_dtype, isna, notna +from pandas.core.dtypes.missing import ( + is_valid_na_for_dtype, + isna, + na_value_for_dtype, + notna, +) if TYPE_CHECKING: from pandas import Series @@ -529,16 +534,26 @@ def maybe_promote(dtype: np.dtype, fill_value=np.nan): dtype = np.dtype(object) return dtype, fill_value + kinds = ["i", "u", "f", "c", "m", "M"] + if is_valid_na_for_dtype(fill_value, dtype) and dtype.kind in kinds: + dtype = ensure_dtype_can_hold_na(dtype) + fv = na_value_for_dtype(dtype) + return dtype, fv + + elif isna(fill_value): + dtype = np.dtype(object) + if fill_value is None: + # but we retain e.g. pd.NA + fill_value = np.nan + return dtype, fill_value + # returns tuple of (dtype, fill_value) if issubclass(dtype.type, np.datetime64): if isinstance(fill_value, datetime) and fill_value.tzinfo is not None: # Trying to insert tzaware into tznaive, have to cast to object dtype = np.dtype(np.object_) - elif is_integer(fill_value) or (is_float(fill_value) and not isna(fill_value)): + elif is_integer(fill_value) or is_float(fill_value): dtype = np.dtype(np.object_) - elif is_valid_na_for_dtype(fill_value, dtype): - # e.g. pd.NA, which is not accepted by Timestamp constructor - fill_value = np.datetime64("NaT", "ns") else: try: fill_value = Timestamp(fill_value).to_datetime64() @@ -547,14 +562,11 @@ def maybe_promote(dtype: np.dtype, fill_value=np.nan): elif issubclass(dtype.type, np.timedelta64): if ( is_integer(fill_value) - or (is_float(fill_value) and not np.isnan(fill_value)) + or is_float(fill_value) or isinstance(fill_value, str) ): # TODO: What about str that can be a timedelta? dtype = np.dtype(np.object_) - elif is_valid_na_for_dtype(fill_value, dtype): - # e.g pd.NA, which is not accepted by the Timedelta constructor - fill_value = np.timedelta64("NaT", "ns") else: try: fv = Timedelta(fill_value) @@ -615,17 +627,6 @@ def maybe_promote(dtype: np.dtype, fill_value=np.nan): # e.g. mst is np.complex128 and dtype is np.complex64 dtype = mst - elif fill_value is None or fill_value is libmissing.NA: - # Note: we already excluded dt64/td64 dtypes above - if is_float_dtype(dtype) or is_complex_dtype(dtype): - fill_value = np.nan - elif is_integer_dtype(dtype): - dtype = np.dtype(np.float64) - fill_value = np.nan - else: - dtype = np.dtype(np.object_) - if fill_value is not libmissing.NA: - fill_value = np.nan else: dtype = np.dtype(np.object_)