-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[MLIR,Python] Support converting boolean numpy arrays to and from mlir attributes #113064
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 12 commits
053d0b7
d5da538
52b49ac
6d3204c
f8a21fc
73df6fb
93156b1
90868b8
d216d43
6543732
75c8264
e5b10a3
b65d7d6
c9b2100
a1ae520
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ | |
#include "IRModule.h" | ||
|
||
#include "PybindUtils.h" | ||
#include <pybind11/numpy.h> | ||
|
||
#include "llvm/ADT/ScopeExit.h" | ||
#include "llvm/Support/raw_ostream.h" | ||
|
@@ -757,103 +758,10 @@ class PyDenseElementsAttribute | |
throw py::error_already_set(); | ||
} | ||
auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); }); | ||
SmallVector<int64_t> shape; | ||
if (explicitShape) { | ||
shape.append(explicitShape->begin(), explicitShape->end()); | ||
} else { | ||
shape.append(view.shape, view.shape + view.ndim); | ||
} | ||
|
||
MlirAttribute encodingAttr = mlirAttributeGetNull(); | ||
MlirContext context = contextWrapper->get(); | ||
|
||
// Detect format codes that are suitable for bulk loading. This includes | ||
// all byte aligned integer and floating point types up to 8 bytes. | ||
// Notably, this excludes, bool (which needs to be bit-packed) and | ||
// other exotics which do not have a direct representation in the buffer | ||
// protocol (i.e. complex, etc). | ||
std::optional<MlirType> bulkLoadElementType; | ||
if (explicitType) { | ||
bulkLoadElementType = *explicitType; | ||
} else { | ||
std::string_view format(view.format); | ||
if (format == "f") { | ||
// f32 | ||
assert(view.itemsize == 4 && "mismatched array itemsize"); | ||
bulkLoadElementType = mlirF32TypeGet(context); | ||
} else if (format == "d") { | ||
// f64 | ||
assert(view.itemsize == 8 && "mismatched array itemsize"); | ||
bulkLoadElementType = mlirF64TypeGet(context); | ||
} else if (format == "e") { | ||
// f16 | ||
assert(view.itemsize == 2 && "mismatched array itemsize"); | ||
bulkLoadElementType = mlirF16TypeGet(context); | ||
} else if (isSignedIntegerFormat(format)) { | ||
if (view.itemsize == 4) { | ||
// i32 | ||
bulkLoadElementType = signless | ||
? mlirIntegerTypeGet(context, 32) | ||
: mlirIntegerTypeSignedGet(context, 32); | ||
} else if (view.itemsize == 8) { | ||
// i64 | ||
bulkLoadElementType = signless | ||
? mlirIntegerTypeGet(context, 64) | ||
: mlirIntegerTypeSignedGet(context, 64); | ||
} else if (view.itemsize == 1) { | ||
// i8 | ||
bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) | ||
: mlirIntegerTypeSignedGet(context, 8); | ||
} else if (view.itemsize == 2) { | ||
// i16 | ||
bulkLoadElementType = signless | ||
? mlirIntegerTypeGet(context, 16) | ||
: mlirIntegerTypeSignedGet(context, 16); | ||
} | ||
} else if (isUnsignedIntegerFormat(format)) { | ||
if (view.itemsize == 4) { | ||
// unsigned i32 | ||
bulkLoadElementType = signless | ||
? mlirIntegerTypeGet(context, 32) | ||
: mlirIntegerTypeUnsignedGet(context, 32); | ||
} else if (view.itemsize == 8) { | ||
// unsigned i64 | ||
bulkLoadElementType = signless | ||
? mlirIntegerTypeGet(context, 64) | ||
: mlirIntegerTypeUnsignedGet(context, 64); | ||
} else if (view.itemsize == 1) { | ||
// i8 | ||
bulkLoadElementType = signless | ||
? mlirIntegerTypeGet(context, 8) | ||
: mlirIntegerTypeUnsignedGet(context, 8); | ||
} else if (view.itemsize == 2) { | ||
// i16 | ||
bulkLoadElementType = signless | ||
? mlirIntegerTypeGet(context, 16) | ||
: mlirIntegerTypeUnsignedGet(context, 16); | ||
} | ||
} | ||
if (!bulkLoadElementType) { | ||
throw std::invalid_argument( | ||
std::string("unimplemented array format conversion from format: ") + | ||
std::string(format)); | ||
} | ||
} | ||
|
||
MlirType shapedType; | ||
if (mlirTypeIsAShaped(*bulkLoadElementType)) { | ||
if (explicitShape) { | ||
throw std::invalid_argument("Shape can only be specified explicitly " | ||
"when the type is not a shaped type."); | ||
} | ||
shapedType = *bulkLoadElementType; | ||
} else { | ||
shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(), | ||
*bulkLoadElementType, encodingAttr); | ||
} | ||
size_t rawBufferSize = view.len; | ||
MlirAttribute attr = | ||
mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize, view.buf); | ||
MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType, | ||
explicitShape, context); | ||
if (mlirAttributeIsNull(attr)) { | ||
throw std::invalid_argument( | ||
"DenseElementsAttr could not be constructed from the given buffer. " | ||
|
@@ -963,6 +871,13 @@ class PyDenseElementsAttribute | |
// unsigned i16 | ||
return bufferInfo<uint16_t>(shapedType); | ||
} | ||
} else if (mlirTypeIsAInteger(elementType) && | ||
mlirIntegerTypeGetWidth(elementType) == 1) { | ||
// i1 / bool | ||
// We can not send the buffer directly back to Python, because the i1 | ||
// values are bitpacked within MLIR. We call numpy's unpackbits function | ||
// to convert the bytes. | ||
return getBooleanBufferFromBitpackedAttribute(); | ||
} | ||
|
||
// TODO: Currently crashes the program. | ||
|
@@ -1016,14 +931,177 @@ class PyDenseElementsAttribute | |
code == 'q'; | ||
} | ||
|
||
// Resolves the shaped type used to build a dense elements attribute from
// `view`.
//
// `bulkLoadElementType` must hold a value; it is dereferenced unconditionally.
// If it is already a shaped type it is returned unchanged, and supplying
// `explicitShape` alongside it raises std::invalid_argument. Otherwise a
// ranked tensor type (with a null encoding) is created around the element
// type, taking its dimensions from `explicitShape` when given and from the
// Python buffer's own shape otherwise.
static MlirType
getShapedType(std::optional<MlirType> bulkLoadElementType,
              std::optional<std::vector<int64_t>> explicitShape,
              Py_buffer &view) {
  const MlirType requestedType = *bulkLoadElementType;

  // A shaped type already carries its dimensions; an additional explicit
  // shape would be contradictory.
  if (mlirTypeIsAShaped(requestedType)) {
    if (explicitShape)
      throw std::invalid_argument("Shape can only be specified explicitly "
                                  "when the type is not a shaped type.");
    return requestedType;
  }

  // Plain element type: wrap it in a ranked tensor of the requested shape.
  SmallVector<int64_t> dims;
  if (explicitShape)
    dims.append(explicitShape->begin(), explicitShape->end());
  else
    dims.append(view.shape, view.shape + view.ndim);

  return mlirRankedTensorTypeGet(dims.size(), dims.data(), requestedType,
                                 mlirAttributeGetNull());
}
|
||
// Builds a dense elements attribute from the Python buffer `view`.
//
// When `explicitType` is given it is used as-is. Otherwise the element type
// is derived from the buffer's format code and item size:
//   * "f" / "d" / "e"  -> f32 / f64 / f16
//   * "?"              -> i1, delegated to
//                         `getBitpackedAttributeFromBooleanBuffer` because the
//                         values have to be bit-packed first
//   * signed / unsigned integer codes with an item size of 1, 2, 4 or 8 bytes
//                      -> integer of the matching width; signless when
//                         `signless` is set, otherwise signed or unsigned
// Any other format raises std::invalid_argument.
//
// The raw buffer contents are then handed to
// `mlirDenseElementsAttrRawBufferGet` together with the type produced by
// `getShapedType`.
static MlirAttribute getAttributeFromBuffer(
    Py_buffer &view, bool signless, std::optional<PyType> explicitType,
    std::optional<std::vector<int64_t>> explicitShape, MlirContext &context) {
  std::optional<MlirType> bulkLoadElementType;

  if (explicitType) {
    bulkLoadElementType = *explicitType;
  } else {
    std::string_view format(view.format);

    // Selects the integer type of `bitWidth` bits honouring the `signless`
    // request and the signedness implied by the format code.
    auto integerTypeFor = [&](unsigned bitWidth, bool isSigned) -> MlirType {
      if (signless)
        return mlirIntegerTypeGet(context, bitWidth);
      if (isSigned)
        return mlirIntegerTypeSignedGet(context, bitWidth);
      return mlirIntegerTypeUnsignedGet(context, bitWidth);
    };

    if (format == "f") {
      // f32
      assert(view.itemsize == 4 && "mismatched array itemsize");
      bulkLoadElementType = mlirF32TypeGet(context);
    } else if (format == "d") {
      // f64
      assert(view.itemsize == 8 && "mismatched array itemsize");
      bulkLoadElementType = mlirF64TypeGet(context);
    } else if (format == "e") {
      // f16
      assert(view.itemsize == 2 && "mismatched array itemsize");
      bulkLoadElementType = mlirF16TypeGet(context);
    } else if (format == "?") {
      // i1: cannot be bulk loaded, the booleans must be bit-packed first.
      return getBitpackedAttributeFromBooleanBuffer(view, explicitShape,
                                                    context);
    } else {
      const bool isSigned = isSignedIntegerFormat(format);
      if (isSigned || isUnsignedIntegerFormat(format)) {
        // Only byte-aligned widths up to 64 bits can be bulk loaded; any
        // other item size leaves the type unset and is reported below.
        switch (view.itemsize) {
        case 1:
        case 2:
        case 4:
        case 8:
          bulkLoadElementType = integerTypeFor(
              static_cast<unsigned>(view.itemsize) * 8, isSigned);
          break;
        default:
          break;
        }
      }
    }

    if (!bulkLoadElementType) {
      throw std::invalid_argument(
          std::string("unimplemented array format conversion from format: ") +
          std::string(format));
    }
  }

  MlirType type = getShapedType(bulkLoadElementType, explicitShape, view);
  return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf);
}
|
||
// Builds an `i1` dense elements attribute from a boolean Python buffer.
//
// NumPy stores every boolean in its own byte, whereas MLIR expects `i1`
// elements bit-packed, eight per byte. The bytes of `view` are therefore
// packed with `numpy.packbits` (little bit order) before being handed to
// `mlirDenseElementsAttrRawBufferGet`.
//
// The result is a null attribute when the packed buffer does not fit the
// requested shape; callers are expected to check for that.
static MlirAttribute getBitpackedAttributeFromBooleanBuffer(
    Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
    MlirContext &context) {
  // Read the one-byte-per-boolean data straight from the Python buffer.
  // Routing it through an intermediate ui8 attribute does not reorder any
  // bytes, and that attribute is null whenever an explicit shape disagrees
  // with the buffer length, so querying its raw data is unsafe.
  py::array_t<uint8_t> unpackedArray(view.len,
                                     static_cast<uint8_t *>(view.buf));

  py::module numpy = py::module::import("numpy");
  py::object packbitsFunc = numpy.attr("packbits");
  // `"bitorder"_a = "little"` uses pybind11's keyword-argument literal to
  // pass `bitorder="little"` to the Python function.
  py::object packedBooleans =
      packbitsFunc(unpackedArray, "bitorder"_a = "little");
  py::buffer_info pythonBuffer = packedBooleans.cast<py::buffer>().request();

  MlirType bitpackedType =
      getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view);
  // NOTE(review): this relies on `mlirDenseElementsAttrRawBufferGet` copying
  // the bytes, since `packedBooleans` is released when this function returns.
  return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size,
                                           pythonBuffer.ptr);
}
|
||
// Inverse of `getBitpackedAttributeFromBooleanBuffer`: exposes this `i1`
// attribute as a NumPy-style boolean buffer with one byte per element.
//
// The bit-packed storage is expanded with `numpy.unpackbits` (little bit
// order), trimmed to the real element count, converted to a boolean array and
// reshaped to the attribute's shape. The returned buffer_info is obtained
// from that array's own buffer view, so it keeps the unpacked data alive for
// as long as the buffer_info exists.
//
// NOTE(review): splat attributes are not special-cased here; confirm that
// their raw storage spans `(numElements + 7) / 8` bytes before relying on it.
py::buffer_info getBooleanBufferFromBitpackedAttribute() {
  int64_t numBooleans = mlirElementsAttrGetNumElements(*this);
  int64_t numBitpackedBytes = (numBooleans + 7) / 8;

  uint8_t *bitpackedData = static_cast<uint8_t *>(
      const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
  py::array_t<uint8_t> packedArray(numBitpackedBytes, bitpackedData);

  py::module numpy = py::module::import("numpy");
  py::object unpackbitsFunc = numpy.attr("unpackbits");
  py::object equalFunc = numpy.attr("equal");
  py::object reshapeFunc = numpy.attr("reshape");
  py::array unpackedBooleans =
      unpackbitsFunc(packedArray, "bitorder"_a = "little");

  // `unpackbits` works on whole bytes and yields a flat 0/1 uint8 array, so:
  //   1. drop the padding bits of the last byte,
  //   2. turn the 0/1 integers into real booleans (format "?"),
  //   3. restore the attribute's shape.
  unpackedBooleans = unpackedBooleans[py::slice(0, numBooleans, 1)];
  unpackedBooleans = equalFunc(unpackedBooleans, 1);

  MlirType shapedType = mlirAttributeGetType(*this);
  intptr_t rank = mlirShapedTypeGetRank(shapedType);
  py::tuple dims(rank);
  for (intptr_t i = 0; i < rank; ++i)
    dims[static_cast<size_t>(i)] =
        py::int_(mlirShapedTypeGetDimSize(shapedType, i));
  unpackedBooleans = reshapeFunc(unpackedBooleans, dims);

  // Hand out the array's own buffer view instead of a bare pointer into it:
  // a raw pointer would dangle as soon as the local array is destroyed.
  py::buffer pythonBuffer = unpackedBooleans.cast<py::buffer>();
  return pythonBuffer.request();
}
|
||
template <typename Type> | ||
py::buffer_info bufferInfo(MlirType shapedType, | ||
const char *explicitFormat = nullptr) { | ||
intptr_t rank = mlirShapedTypeGetRank(shapedType); | ||
// Prepare the data for the buffer_info. | ||
// Buffer is configured for read-only access below. | ||
// Buffer is configured for read-only access inside the `bufferInfo` call. | ||
Type *data = static_cast<Type *>( | ||
const_cast<void *>(mlirDenseElementsAttrGetRawData(*this))); | ||
return bufferInfo<Type>(shapedType, data, explicitFormat); | ||
} | ||
|
||
template <typename Type> | ||
py::buffer_info bufferInfo(MlirType shapedType, Type *data, | ||
const char *explicitFormat = nullptr) { | ||
intptr_t rank = mlirShapedTypeGetRank(shapedType); | ||
// Prepare the shape for the buffer_info. | ||
SmallVector<intptr_t, 4> shape; | ||
for (intptr_t i = 0; i < rank; ++i) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had to move the if-statements to the `getAttributeFromBuffer` method, because the i1 case no longer follows the usual flow; it instead calls `getBitpackedAttributeFromBooleanBuffer` to construct the MlirAttribute.