diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMTAttributes.td b/mlir/include/mlir/Dialect/SMT/IR/SMTAttributes.td index 0909c9abad951..c1c5bfc76e055 100644 --- a/mlir/include/mlir/Dialect/SMT/IR/SMTAttributes.td +++ b/mlir/include/mlir/Dialect/SMT/IR/SMTAttributes.td @@ -45,21 +45,11 @@ def BitVectorAttr : AttrDef:$value); let hasCustomAssemblyFormat = true; let genVerifyDecl = true; - // We need to manually define the storage class because the generated one is - // buggy (because the APInt asserts matching bitwidth in the `==` operator and - // the generated storage uses that directly. - // Alternatively: add a type parameter to redundantly store the bitwidth of - // of the attribute type, it it's in the order before the 'value' it will be - // checked before the APInt equality (this is the reason it works for the - // builtin integer attribute), but would be more fragile (and we'd store - // duplicate data). - let genStorageClass = false; - let builders = [ AttrBuilder<(ins "llvm::StringRef":$value)>, AttrBuilder<(ins "uint64_t":$value, "unsigned":$width)>, diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td index 6826d1a437775..50dcb8de1f7e7 100644 --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -700,7 +700,7 @@ def Builtin_IntegerAttr : Builtin_Attr<"Integer", "integer", false // A bool, i.e. i1, value. ``` }]; - let parameters = (ins AttributeSelfTypeParameter<"">:$type, "APInt":$value); + let parameters = (ins AttributeSelfTypeParameter<"">:$type, APIntParameter<"">:$value); let builders = [ AttrBuilderWithInferredContext<(ins "Type":$type, "const APInt &":$value), [{ diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h index c3d730e42ef70..8c1d399b39e0b 100644 --- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h +++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h @@ -68,7 +68,11 @@ class AttrOrTypeParameter { /// If specified, get the custom allocator code for this parameter. std::optional getAllocator() const; - /// If specified, get the custom comparator code for this parameter. + /// Return true if user defined comparator is specified. + bool hasCustomComparator() const; + + /// Get the custom comparator code for this parameter or fallback to the + /// default. StringRef getComparator() const; /// Get the C++ type of this parameter. diff --git a/mlir/lib/Dialect/SMT/IR/SMTAttributes.cpp b/mlir/lib/Dialect/SMT/IR/SMTAttributes.cpp index c28f3558a02d2..3f40d6a42eafd 100644 --- a/mlir/lib/Dialect/SMT/IR/SMTAttributes.cpp +++ b/mlir/lib/Dialect/SMT/IR/SMTAttributes.cpp @@ -21,42 +21,6 @@ using namespace mlir::smt; // BitVectorAttr //===----------------------------------------------------------------------===// -namespace mlir { -namespace smt { -namespace detail { -struct BitVectorAttrStorage : public mlir::AttributeStorage { - using KeyTy = APInt; - BitVectorAttrStorage(APInt value) : value(std::move(value)) {} - - KeyTy getAsKey() const { return value; } - - // NOTE: the implementation of this operator is the reason we need to define - // the storage manually. The auto-generated version would just do the direct - // equality check of the APInt, but that asserts the bitwidth of both to be - // the same, leading to a crash. This implementation, therefore, checks for - // matching bit-width beforehand. - bool operator==(const KeyTy &key) const { - return (value.getBitWidth() == key.getBitWidth() && value == key); - } - - static llvm::hash_code hashKey(const KeyTy &key) { - return llvm::hash_value(key); - } - - static BitVectorAttrStorage * - construct(mlir::AttributeStorageAllocator &allocator, KeyTy &&key) { - return new (allocator.allocate()) - BitVectorAttrStorage(std::move(key)); - } - - APInt value; -}; -} // namespace detail -} // namespace smt -} // namespace mlir - -APInt BitVectorAttr::getValue() const { return getImpl()->value; } - LogicalResult BitVectorAttr::verify( function_ref emitError, APInt value) { // NOLINT(performance-unnecessary-value-param) diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp index 9e8f789d71b5e..ccb0a6c6c261e 100644 --- a/mlir/lib/TableGen/AttrOrTypeDef.cpp +++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp @@ -278,6 +278,10 @@ std::optional AttrOrTypeParameter::getAllocator() const { return getDefValue("allocator"); } +bool AttrOrTypeParameter::hasCustomComparator() const { + return getDefValue("comparator").has_value(); +} + StringRef AttrOrTypeParameter::getComparator() const { return getDefValue("comparator").value_or("$_lhs == $_rhs"); } diff --git a/mlir/test/mlir-tblgen/apint-param-error.td b/mlir/test/mlir-tblgen/apint-param-error.td new file mode 100644 index 0000000000000..602180790bbff --- /dev/null +++ b/mlir/test/mlir-tblgen/apint-param-error.td @@ -0,0 +1,17 @@ +// RUN: not mlir-tblgen -gen-attrdef-decls -I %S/../../include %s 2>&1 | FileCheck %s + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/OpBase.td" + +def Test_Dialect: Dialect { + let name = "TestDialect"; + let cppNamespace = "::test"; +} + +def RawAPIntAttr : AttrDef { + let mnemonic = "raw_ap_int"; + let parameters = (ins "APInt":$value); + let hasCustomAssemblyFormat = 1; +} + +// CHECK: apint-param-error.td:11:5: error: Using a raw APInt parameter diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp index cf0d827942949..cef859bc24abb 100644 --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -678,8 +678,18 @@ void DefGen::emitStorageClass() { emitConstruct(); // Emit the storage class members as public, at the very end of the struct. storageCls->finalize(); - for (auto ¶m : params) + for (auto ¶m : params) { + if (param.getCppType().contains("APInt") && !param.hasCustomComparator()) { + PrintFatalError( + def.getLoc(), + "Using a raw APInt parameter without a custom comparator is " + "not supported because an assert in the equality operator is " + "triggered when the two APInts have different bit widths. This can " + "lead to unexpected crashes. Use an `APIntParameter` or " + "provide a custom comparator."); + } storageCls->declare(param.getCppType(), param.getName()); + } } //===----------------------------------------------------------------------===//