diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index ead81a76c0538..c8883c0d8270a 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -13,6 +13,7 @@ #include "IRModule.h" #include "PybindUtils.h" +#include #include "llvm/ADT/ScopeExit.h" #include "llvm/Support/raw_ostream.h" @@ -757,103 +758,10 @@ class PyDenseElementsAttribute throw py::error_already_set(); } auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); }); - SmallVector shape; - if (explicitShape) { - shape.append(explicitShape->begin(), explicitShape->end()); - } else { - shape.append(view.shape, view.shape + view.ndim); - } - MlirAttribute encodingAttr = mlirAttributeGetNull(); MlirContext context = contextWrapper->get(); - - // Detect format codes that are suitable for bulk loading. This includes - // all byte aligned integer and floating point types up to 8 bytes. - // Notably, this excludes, bool (which needs to be bit-packed) and - // other exotics which do not have a direct representation in the buffer - // protocol (i.e. complex, etc). - std::optional bulkLoadElementType; - if (explicitType) { - bulkLoadElementType = *explicitType; - } else { - std::string_view format(view.format); - if (format == "f") { - // f32 - assert(view.itemsize == 4 && "mismatched array itemsize"); - bulkLoadElementType = mlirF32TypeGet(context); - } else if (format == "d") { - // f64 - assert(view.itemsize == 8 && "mismatched array itemsize"); - bulkLoadElementType = mlirF64TypeGet(context); - } else if (format == "e") { - // f16 - assert(view.itemsize == 2 && "mismatched array itemsize"); - bulkLoadElementType = mlirF16TypeGet(context); - } else if (isSignedIntegerFormat(format)) { - if (view.itemsize == 4) { - // i32 - bulkLoadElementType = signless - ? 
mlirIntegerTypeGet(context, 32) - : mlirIntegerTypeSignedGet(context, 32); - } else if (view.itemsize == 8) { - // i64 - bulkLoadElementType = signless - ? mlirIntegerTypeGet(context, 64) - : mlirIntegerTypeSignedGet(context, 64); - } else if (view.itemsize == 1) { - // i8 - bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) - : mlirIntegerTypeSignedGet(context, 8); - } else if (view.itemsize == 2) { - // i16 - bulkLoadElementType = signless - ? mlirIntegerTypeGet(context, 16) - : mlirIntegerTypeSignedGet(context, 16); - } - } else if (isUnsignedIntegerFormat(format)) { - if (view.itemsize == 4) { - // unsigned i32 - bulkLoadElementType = signless - ? mlirIntegerTypeGet(context, 32) - : mlirIntegerTypeUnsignedGet(context, 32); - } else if (view.itemsize == 8) { - // unsigned i64 - bulkLoadElementType = signless - ? mlirIntegerTypeGet(context, 64) - : mlirIntegerTypeUnsignedGet(context, 64); - } else if (view.itemsize == 1) { - // i8 - bulkLoadElementType = signless - ? mlirIntegerTypeGet(context, 8) - : mlirIntegerTypeUnsignedGet(context, 8); - } else if (view.itemsize == 2) { - // i16 - bulkLoadElementType = signless - ? 
mlirIntegerTypeGet(context, 16) - : mlirIntegerTypeUnsignedGet(context, 16); - } - } - if (!bulkLoadElementType) { - throw std::invalid_argument( - std::string("unimplemented array format conversion from format: ") + - std::string(format)); - } - } - - MlirType shapedType; - if (mlirTypeIsAShaped(*bulkLoadElementType)) { - if (explicitShape) { - throw std::invalid_argument("Shape can only be specified explicitly " - "when the type is not a shaped type."); - } - shapedType = *bulkLoadElementType; - } else { - shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(), - *bulkLoadElementType, encodingAttr); - } - size_t rawBufferSize = view.len; - MlirAttribute attr = - mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize, view.buf); + MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType, + explicitShape, context); if (mlirAttributeIsNull(attr)) { throw std::invalid_argument( "DenseElementsAttr could not be constructed from the given buffer. " @@ -963,6 +871,13 @@ class PyDenseElementsAttribute // unsigned i16 return bufferInfo(shapedType); } + } else if (mlirTypeIsAInteger(elementType) && + mlirIntegerTypeGetWidth(elementType) == 1) { + // i1 / bool + // We can not send the buffer directly back to Python, because the i1 + // values are bitpacked within MLIR. We call numpy's unpackbits function + // to convert the bytes. + return getBooleanBufferFromBitpackedAttribute(); } // TODO: Currently crashes the program. 
@@ -1016,14 +931,183 @@ class PyDenseElementsAttribute code == 'q'; } + static MlirType + getShapedType(std::optional bulkLoadElementType, + std::optional> explicitShape, + Py_buffer &view) { + SmallVector shape; + if (explicitShape) { + shape.append(explicitShape->begin(), explicitShape->end()); + } else { + shape.append(view.shape, view.shape + view.ndim); + } + + if (mlirTypeIsAShaped(*bulkLoadElementType)) { + if (explicitShape) { + throw std::invalid_argument("Shape can only be specified explicitly " + "when the type is not a shaped type."); + } + return *bulkLoadElementType; + } else { + MlirAttribute encodingAttr = mlirAttributeGetNull(); + return mlirRankedTensorTypeGet(shape.size(), shape.data(), + *bulkLoadElementType, encodingAttr); + } + } + + static MlirAttribute getAttributeFromBuffer( + Py_buffer &view, bool signless, std::optional explicitType, + std::optional> explicitShape, MlirContext &context) { + // Detect format codes that are suitable for bulk loading. This includes + // all byte aligned integer and floating point types up to 8 bytes. + // Notably, this excludes exotic types which do not have a direct + // representation in the buffer protocol (e.g. complex, etc.). 
+ std::optional bulkLoadElementType; + if (explicitType) { + bulkLoadElementType = *explicitType; + } else { + std::string_view format(view.format); + if (format == "f") { + // f32 + assert(view.itemsize == 4 && "mismatched array itemsize"); + bulkLoadElementType = mlirF32TypeGet(context); + } else if (format == "d") { + // f64 + assert(view.itemsize == 8 && "mismatched array itemsize"); + bulkLoadElementType = mlirF64TypeGet(context); + } else if (format == "e") { + // f16 + assert(view.itemsize == 2 && "mismatched array itemsize"); + bulkLoadElementType = mlirF16TypeGet(context); + } else if (format == "?") { + // i1 + // The i1 type needs to be bit-packed, so we will handle it separately + return getBitpackedAttributeFromBooleanBuffer(view, explicitShape, + context); + } else if (isSignedIntegerFormat(format)) { + if (view.itemsize == 4) { + // i32 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 32) + : mlirIntegerTypeSignedGet(context, 32); + } else if (view.itemsize == 8) { + // i64 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 64) + : mlirIntegerTypeSignedGet(context, 64); + } else if (view.itemsize == 1) { + // i8 + bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) + : mlirIntegerTypeSignedGet(context, 8); + } else if (view.itemsize == 2) { + // i16 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 16) + : mlirIntegerTypeSignedGet(context, 16); + } + } else if (isUnsignedIntegerFormat(format)) { + if (view.itemsize == 4) { + // unsigned i32 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 32) + : mlirIntegerTypeUnsignedGet(context, 32); + } else if (view.itemsize == 8) { + // unsigned i64 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 64) + : mlirIntegerTypeUnsignedGet(context, 64); + } else if (view.itemsize == 1) { + // i8 + bulkLoadElementType = signless + ? 
mlirIntegerTypeGet(context, 8) + : mlirIntegerTypeUnsignedGet(context, 8); + } else if (view.itemsize == 2) { + // i16 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 16) + : mlirIntegerTypeUnsignedGet(context, 16); + } + } + if (!bulkLoadElementType) { + throw std::invalid_argument( + std::string("unimplemented array format conversion from format: ") + + std::string(format)); + } + } + + MlirType type = getShapedType(bulkLoadElementType, explicitShape, view); + return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf); + } + + // There is a complication for boolean numpy arrays, as numpy represents them + // as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8 booleans + // per byte. + static MlirAttribute getBitpackedAttributeFromBooleanBuffer( + Py_buffer &view, std::optional> explicitShape, + MlirContext &context) { + if (llvm::endianness::native != llvm::endianness::little) { + // Given we have no good way of testing the behavior on big-endian systems + // we will throw + throw py::type_error("Constructing a bit-packed MLIR attribute is " + "unsupported on big-endian systems"); + } + + py::array_t unpackedArray(view.len, + static_cast(view.buf)); + + py::module numpy = py::module::import("numpy"); + py::object packbits_func = numpy.attr("packbits"); + py::object packed_booleans = + packbits_func(unpackedArray, "bitorder"_a = "little"); + py::buffer_info pythonBuffer = packed_booleans.cast().request(); + + MlirType bitpackedType = + getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view); + return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size, + pythonBuffer.ptr); + } + + // This does the opposite transformation of + // `getBitpackedAttributeFromBooleanBuffer` + py::buffer_info getBooleanBufferFromBitpackedAttribute() { + if (llvm::endianness::native != llvm::endianness::little) { + // Given we have no good way of testing the behavior on big-endian systems + // we will throw + throw 
py::type_error("Constructing a numpy array from a MLIR attribute " + "is unsupported on big-endian systems"); + } + + int64_t numBooleans = mlirElementsAttrGetNumElements(*this); + int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8); + uint8_t *bitpackedData = static_cast( + const_cast(mlirDenseElementsAttrGetRawData(*this))); + py::array_t packedArray(numBitpackedBytes, bitpackedData); + + py::module numpy = py::module::import("numpy"); + py::object unpackbits_func = numpy.attr("unpackbits"); + py::object unpacked_booleans = + unpackbits_func(packedArray, "bitorder"_a = "little"); + py::buffer_info pythonBuffer = + unpacked_booleans.cast().request(); + + MlirType shapedType = mlirAttributeGetType(*this); + return bufferInfo(shapedType, (bool *)pythonBuffer.ptr, "?"); + } + template py::buffer_info bufferInfo(MlirType shapedType, const char *explicitFormat = nullptr) { - intptr_t rank = mlirShapedTypeGetRank(shapedType); // Prepare the data for the buffer_info. - // Buffer is configured for read-only access below. + // Buffer is configured for read-only access inside the `bufferInfo` call. Type *data = static_cast( const_cast(mlirDenseElementsAttrGetRawData(*this))); + return bufferInfo(shapedType, data, explicitFormat); + } + + template + py::buffer_info bufferInfo(MlirType shapedType, Type *data, + const char *explicitFormat = nullptr) { + intptr_t rank = mlirShapedTypeGetRank(shapedType); // Prepare the shape for the buffer_info. 
SmallVector shape; for (intptr_t i = 0; i < rank; ++i) diff --git a/mlir/test/python/ir/array_attributes.py b/mlir/test/python/ir/array_attributes.py index 2bc403aace834..256a69a939658 100644 --- a/mlir/test/python/ir/array_attributes.py +++ b/mlir/test/python/ir/array_attributes.py @@ -326,6 +326,78 @@ def testGetDenseElementsF64(): print(np.array(attr)) +### 1 bit/boolean integer arrays +# CHECK-LABEL: TEST: testGetDenseElementsI1Signless +@run +def testGetDenseElementsI1Signless(): + with Context(): + array = np.array([True], dtype=np.bool_) + attr = DenseElementsAttr.get(array) + # CHECK: dense : tensor<1xi1> + print(attr) + # CHECK{LITERAL}: [ True] + print(np.array(attr)) + + array = np.array([[True, False, True], [True, True, False]], dtype=np.bool_) + attr = DenseElementsAttr.get(array) + # CHECK{LITERAL}: dense<[[true, false, true], [true, true, false]]> : tensor<2x3xi1> + print(attr) + # CHECK{LITERAL}: [[ True False True] + # CHECK{LITERAL}: [ True True False]] + print(np.array(attr)) + + array = np.array( + [[True, True, False, False], [True, False, True, False]], dtype=np.bool_ + ) + attr = DenseElementsAttr.get(array) + # CHECK{LITERAL}: dense<[[true, true, false, false], [true, false, true, false]]> : tensor<2x4xi1> + print(attr) + # CHECK{LITERAL}: [[ True True False False] + # CHECK{LITERAL}: [ True False True False]] + print(np.array(attr)) + + array = np.array( + [ + [True, True, False, False], + [True, False, True, False], + [False, False, False, False], + [True, True, True, True], + [True, False, False, True], + ], + dtype=np.bool_, + ) + attr = DenseElementsAttr.get(array) + # CHECK{LITERAL}: dense<[[true, true, false, false], [true, false, true, false], [false, false, false, false], [true, true, true, true], [true, false, false, true]]> : tensor<5x4xi1> + print(attr) + # CHECK{LITERAL}: [[ True True False False] + # CHECK{LITERAL}: [ True False True False] + # CHECK{LITERAL}: [False False False False] + # CHECK{LITERAL}: [ True True True 
True] + # CHECK{LITERAL}: [ True False False True]] + print(np.array(attr)) + + array = np.array( + [ + [True, True, False, False, True, True, False, False, False], + [False, False, False, True, False, True, True, False, True], + ], + dtype=np.bool_, + ) + attr = DenseElementsAttr.get(array) + # CHECK{LITERAL}: dense<[[true, true, false, false, true, true, false, false, false], [false, false, false, true, false, true, true, false, true]]> : tensor<2x9xi1> + print(attr) + # CHECK{LITERAL}: [[ True True False False True True False False False] + # CHECK{LITERAL}: [False False False True False True True False True]] + print(np.array(attr)) + + array = np.array([], dtype=np.bool_) + attr = DenseElementsAttr.get(array) + # CHECK: dense<> : tensor<0xi1> + print(attr) + # CHECK{LITERAL}: [] + print(np.array(attr)) + + ### 16 bit integer arrays # CHECK-LABEL: TEST: testGetDenseElementsI16Signless @run