Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 88 additions & 1 deletion python/pyarrow/src/arrow/python/inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,68 @@ Status ImportPresentIntervalTypes(OwnedRefNoGIL* interval_types_tuple) {

} // namespace

namespace {

/// \brief Pick the narrowest Arrow integer type that can hold an observed range of
/// Python integers.
///
/// The function does not change the reference count of either argument.
///
/// \param min_obj smallest Python integer observed, or nullptr if none was recorded
/// \param max_obj largest Python integer observed, or nullptr if none was recorded
/// \return
///   - int64() when either argument is nullptr, when a conversion raises a Python
///     error (the error is cleared), or when the range does not fit any narrower type;
///   - a signed type (int8/int16/int32/int64) when the minimum is negative;
///   - otherwise the first type in the order int8, uint8, int16, uint16, int32,
///     uint32, int64, uint64 whose maximum is >= the observed maximum, so a signed
///     type is preferred whenever both signed and unsigned of the same width fit.
std::shared_ptr<DataType> InferIntegerType(PyObject* min_obj, PyObject* max_obj) {
  // Both bounds are dereferenced by the conversion calls below, so guard both.
  if (min_obj == nullptr || max_obj == nullptr) {
    return int64();
  }

  // PyLong_AsLongLongAndOverflow reports out-of-range values through `min_overflow`
  // (-1: below LLONG_MIN, +1: above LLONG_MAX) without raising, which gives us the
  // sign of huge values without relying on the private _PyLong_Sign helper.
  int min_overflow = 0;
  const long long min_val = PyLong_AsLongLongAndOverflow(min_obj, &min_overflow);
  if (min_val == -1 && PyErr_Occurred()) {
    // Genuine conversion failure (not an overflow): fall back to the default type.
    PyErr_Clear();
    return int64();
  }
  if (min_overflow < 0) {
    // Negative and below the int64 range: nothing narrower can be chosen.
    return int64();
  }

  if (min_overflow == 0 && min_val < 0) {
    // Negative values present, so the result must be a signed type.
    int max_overflow = 0;
    const long long max_val = PyLong_AsLongLongAndOverflow(max_obj, &max_overflow);
    if (max_val == -1 && PyErr_Occurred()) {
      PyErr_Clear();
      return int64();
    }
    if (max_overflow != 0) {
      // Maximum is outside the int64 range; keep the default type.
      return int64();
    }
    if (min_val >= std::numeric_limits<int8_t>::min() &&
        max_val <= std::numeric_limits<int8_t>::max()) {
      return int8();
    }
    if (min_val >= std::numeric_limits<int16_t>::min() &&
        max_val <= std::numeric_limits<int16_t>::max()) {
      return int16();
    }
    if (min_val >= std::numeric_limits<int32_t>::min() &&
        max_val <= std::numeric_limits<int32_t>::max()) {
      return int32();
    }
    return int64();
  }

  // Minimum is non-negative: only the maximum decides, and unsigned types are allowed.
  const unsigned long long max_val = PyLong_AsUnsignedLongLong(max_obj);
  if (max_val == static_cast<unsigned long long>(-1) && PyErr_Occurred()) {
    // Typically the maximum does not fit in uint64; keep the default type.
    PyErr_Clear();
    return int64();
  }

  if (max_val <= static_cast<uint64_t>(std::numeric_limits<int8_t>::max())) {
    return int8();
  }
  if (max_val <= static_cast<uint64_t>(std::numeric_limits<uint8_t>::max())) {
    return uint8();
  }
  if (max_val <= static_cast<uint64_t>(std::numeric_limits<int16_t>::max())) {
    return int16();
  }
  if (max_val <= static_cast<uint64_t>(std::numeric_limits<uint16_t>::max())) {
    return uint16();
  }
  if (max_val <= static_cast<uint64_t>(std::numeric_limits<int32_t>::max())) {
    return int32();
  }
  if (max_val <= static_cast<uint64_t>(std::numeric_limits<uint32_t>::max())) {
    return uint32();
  }
  if (max_val <= static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
    return int64();
  }
  return uint64();
}

}  // namespace

// Expands to a `case NPY_<DTYPE>:` label whose body immediately returns `OK`.
// Intended for use inside a `switch` over NumPy type numbers; `OK` must be
// visible at the expansion site (it is not defined by this macro).
#define _NUMPY_UNIFY_NOOP(DTYPE) \
case NPY_##DTYPE: \
return OK;
Expand Down Expand Up @@ -331,6 +393,8 @@ class TypeInferrer {
none_count_(0),
bool_count_(0),
int_count_(0),
min_int_(nullptr),
max_int_(nullptr),
date_count_(0),
time_count_(0),
timestamp_micro_count_(0),
Expand Down Expand Up @@ -368,6 +432,23 @@ class TypeInferrer {
*keep_going = make_unions_;
} else if (internal::IsPyInteger(obj)) {
++int_count_;
if (min_int_ == nullptr) {
min_int_ = obj;
max_int_ = obj;
Py_INCREF(min_int_);
Py_INCREF(max_int_);
} else {
if (PyObject_RichCompareBool(obj, min_int_, Py_LT)) {
Py_DECREF(min_int_);
min_int_ = obj;
Py_INCREF(min_int_);
}
if (PyObject_RichCompareBool(obj, max_int_, Py_GT)) {
Py_DECREF(max_int_);
max_int_ = obj;
Py_INCREF(max_int_);
}
}
} else if (PyDateTime_Check(obj)) {
// infer timezone from the first encountered datetime object
if (!timestamp_micro_count_) {
Expand Down Expand Up @@ -524,7 +605,11 @@ class TypeInferrer {
// Prioritize floats before integers
*out = float64();
} else if (int_count_) {
*out = int64();
*out = InferIntegerType(min_int_, max_int_);
Py_XDECREF(min_int_);
Py_XDECREF(max_int_);
min_int_ = nullptr;
max_int_ = nullptr;
} else if (date_count_) {
*out = date32();
} else if (time_count_) {
Expand Down Expand Up @@ -684,6 +769,8 @@ class TypeInferrer {
int64_t none_count_;
int64_t bool_count_;
int64_t int_count_;
PyObject* min_int_;
PyObject* max_int_;
int64_t date_count_;
int64_t time_count_;
int64_t timestamp_micro_count_;
Expand Down