Skip to content

[CIR] Clean up IntAttr #1725

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
mlir::Value getConstAPSInt(mlir::Location loc, const llvm::APSInt &val) {
auto ty =
cir::IntType::get(getContext(), val.getBitWidth(), val.isSigned());
return create<cir::ConstantOp>(loc, getAttr<cir::IntAttr>(ty, val));
return create<cir::ConstantOp>(loc, cir::IntAttr::get(ty, val));
}

mlir::Value getSignedInt(mlir::Location loc, int64_t val, unsigned numBits) {
Expand All @@ -63,7 +63,7 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {

mlir::Value getConstAPInt(mlir::Location loc, mlir::Type typ,
const llvm::APInt &val) {
return create<cir::ConstantOp>(loc, getAttr<cir::IntAttr>(typ, val));
return create<cir::ConstantOp>(loc, cir::IntAttr::get(typ, val));
}

cir::ConstantOp getConstant(mlir::Location loc, mlir::TypedAttr attr) {
Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRAttrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "clang/CIR/Dialect/IR/CIROpsEnums.h"

#include "clang/CIR/Interfaces/ASTAttrInterfaces.h"
#include "clang/CIR/Interfaces/CIRTypeInterfaces.h"

//===----------------------------------------------------------------------===//
// CIR Dialect Attrs
Expand Down
59 changes: 49 additions & 10 deletions clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -406,32 +406,71 @@ def ConstRecordAttr : CIR_Attr<"ConstRecord", "const_record",
// IntegerAttr
//===----------------------------------------------------------------------===//

def IntAttr : CIR_Attr<"Int", "int", [TypedAttrInterface]> {
def CIR_IntAttr : CIR_Attr<"Int", "int", [TypedAttrInterface]> {
let summary = "An Attribute containing a integer value";
let description = [{
An integer attribute is a literal attribute that represents an integral
value of the specified integer type.
}];
let parameters = (ins AttributeSelfTypeParameter<"">:$type, "llvm::APInt":$value);

let parameters = (ins
AttributeSelfTypeParameter<"", "cir::IntTypeInterface">:$type,
"llvm::APInt":$value
);

let builders = [
AttrBuilderWithInferredContext<(ins "mlir::Type":$type,
"const llvm::APInt &":$value), [{
return $_get(type.getContext(), type, value);
return $_get(type.getContext(),
mlir::cast<cir::IntTypeInterface>(type), value);
}]>,
AttrBuilderWithInferredContext<(ins "mlir::Type":$type, "int64_t":$value), [{
IntType intType = mlir::cast<IntType>(type);
auto intType = mlir::cast<cir::IntTypeInterface>(type);
mlir::APInt apValue(intType.getWidth(), value, intType.isSigned());
return $_get(intType.getContext(), intType, apValue);
}]>,
];

let extraClassDeclaration = [{
int64_t getSInt() const { return getValue().getSExtValue(); }
uint64_t getUInt() const { return getValue().getZExtValue(); }
bool isNullValue() const { return getValue() == 0; }
uint64_t getBitWidth() const { return mlir::cast<IntType>(getType()).getWidth(); }
int64_t getSInt() const;
uint64_t getUInt() const;
bool isNullValue() const;
bool isSigned() const;
bool isUnsigned() const;
uint64_t getBitWidth() const;
}];

let extraClassDefinition = [{
int64_t $cppClass::getSInt() const {
return getValue().getSExtValue();
}

uint64_t $cppClass::getUInt() const {
return getValue().getZExtValue();
}

bool $cppClass::isNullValue() const {
return getValue() == 0;
}

bool $cppClass::isSigned() const {
return mlir::cast<IntTypeInterface>(getType()).isSigned();
}

bool $cppClass::isUnsigned() const {
return mlir::cast<IntTypeInterface>(getType()).isUnsigned();
}

uint64_t $cppClass::getBitWidth() const {
return mlir::cast<IntTypeInterface>(getType()).getWidth();
}
}];

let assemblyFormat = [{
`<` custom<IntLiteral>($value, ref($type)) `>`
}];

let genVerifyDecl = 1;
let hasCustomAssemblyFormat = 1;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -892,7 +931,7 @@ def DynamicCastInfoAttr
GlobalViewAttr:$destRtti,
"mlir::FlatSymbolRefAttr":$runtimeFunc,
"mlir::FlatSymbolRefAttr":$badCastFunc,
IntAttr:$offsetHint);
CIR_IntAttr:$offsetHint);

let builders = [
AttrBuilderWithInferredContext<(ins "GlobalViewAttr":$srcRtti,
Expand Down
13 changes: 6 additions & 7 deletions clang/lib/CIR/CodeGen/CIRGenExprConst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,8 @@ bool ConstantAggregateBuilder::addBits(llvm::APInt Bits, uint64_t OffsetInBits,

if (*FirstElemToUpdate == *LastElemToUpdate || isNull) {
// All existing bits are either zero or undef.
add(CGM.getBuilder().getAttr<cir::IntAttr>(charTy, BitsThisChar),
OffsetInChars, /*AllowOverwrite*/ true);
add(cir::IntAttr::get(charTy, BitsThisChar), OffsetInChars,
/*AllowOverwrite*/ true);
} else {
cir::IntAttr CI = dyn_cast<cir::IntAttr>(Elems[*FirstElemToUpdate]);
// In order to perform a partial update, we need the existing bitwise
Expand All @@ -286,8 +286,7 @@ bool ConstantAggregateBuilder::addBits(llvm::APInt Bits, uint64_t OffsetInBits,
assert((!(CI.getValue() & UpdateMask) || AllowOverwrite) &&
"unexpectedly overwriting bitfield");
BitsThisChar |= (CI.getValue() & ~UpdateMask);
Elems[*FirstElemToUpdate] =
CGM.getBuilder().getAttr<cir::IntAttr>(charTy, BitsThisChar);
Elems[*FirstElemToUpdate] = cir::IntAttr::get(charTy, BitsThisChar);
}
}

Expand Down Expand Up @@ -1906,7 +1905,7 @@ mlir::Attribute ConstantEmitter::tryEmitPrivate(const APValue &Value,
if (mlir::isa<cir::BoolType>(ty))
return builder.getCIRBoolAttr(Value.getInt().getZExtValue());
assert(mlir::isa<cir::IntType>(ty) && "expected integral type");
return CGM.getBuilder().getAttr<cir::IntAttr>(ty, Value.getInt());
return cir::IntAttr::get(ty, Value.getInt());
}
case APValue::Float: {
const llvm::APFloat &Init = Value.getFloat();
Expand Down Expand Up @@ -2018,8 +2017,8 @@ mlir::Attribute ConstantEmitter::tryEmitPrivate(const APValue &Value,
llvm::APSInt real = Value.getComplexIntReal();
llvm::APSInt imag = Value.getComplexIntImag();
return builder.getAttr<cir::ComplexAttr>(
complexType, builder.getAttr<cir::IntAttr>(complexElemTy, real),
builder.getAttr<cir::IntAttr>(complexElemTy, imag));
complexType, cir::IntAttr::get(complexElemTy, real),
cir::IntAttr::get(complexElemTy, imag));
}

assert(isa<cir::FPTypeInterface>(complexElemTy) &&
Expand Down
5 changes: 2 additions & 3 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,8 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
// Leaves.
mlir::Value VisitIntegerLiteral(const IntegerLiteral *E) {
mlir::Type Ty = CGF.convertType(E->getType());
return Builder.create<cir::ConstantOp>(
CGF.getLoc(E->getExprLoc()),
Builder.getAttr<cir::IntAttr>(Ty, E->getValue()));
return Builder.getConstAPInt(CGF.getLoc(E->getExprLoc()), Ty,
E->getValue());
}

mlir::Value VisitFixedPointLiteral(const FixedPointLiteral *E) {
Expand Down
105 changes: 57 additions & 48 deletions clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,40 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"

#include "clang/CIR/Interfaces/CIRTypeInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"

// ClangIR holds back AST references when available.
#include "clang/AST/Decl.h"
#include "clang/AST/DeclCXX.h"
#include "clang/AST/ExprCXX.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/SMLoc.h"

//===-----------------------------------------------------------------===//
// RecordMembers
//===-----------------------------------------------------------------===//

static void printRecordMembers(mlir::AsmPrinter &p, mlir::ArrayAttr members);
static mlir::ParseResult parseRecordMembers(::mlir::AsmParser &parser,
mlir::ArrayAttr &members);

//===-----------------------------------------------------------------===//
// IntLiteral
//===-----------------------------------------------------------------===//

static void printIntLiteral(mlir::AsmPrinter &p, llvm::APInt value,
cir::IntTypeInterface ty);

static mlir::ParseResult parseIntLiteral(mlir::AsmParser &parser,
llvm::APInt &value,
cir::IntTypeInterface ty);

//===-----------------------------------------------------------------===//
// FloatLiteral
//===-----------------------------------------------------------------===//

static void printFloatLiteral(mlir::AsmPrinter &p, llvm::APFloat value,
mlir::Type ty);
static mlir::ParseResult
Expand Down Expand Up @@ -204,65 +226,52 @@ static void printConstPtr(AsmPrinter &p, mlir::IntegerAttr value) {
// IntAttr definitions
//===----------------------------------------------------------------------===//

Attribute IntAttr::parse(AsmParser &parser, Type odsType) {
mlir::APInt APValue;
template <typename IntT>
static bool isTooLargeForType(const mlir::APInt &v, IntT expectedValue) {
if constexpr (std::is_signed_v<IntT>) {
return v.getSExtValue() != expectedValue;
} else {
return v.getZExtValue() != expectedValue;
}
}

if (!mlir::isa<IntType>(odsType))
return {};
auto type = mlir::cast<IntType>(odsType);
template <typename IntT>
static ParseResult parseIntLiteralImpl(mlir::AsmParser &p, llvm::APInt &value,
cir::IntTypeInterface ty) {
IntT ivalue;
const bool isSigned = ty.isSigned();
if (p.parseInteger(ivalue))
return p.emitError(p.getCurrentLocation(), "expected integer value");

// Consume the '<' symbol.
if (parser.parseLess())
return {};
value = mlir::APInt(ty.getWidth(), ivalue, isSigned, /*implicitTrunc=*/true);
if (isTooLargeForType(value, ivalue))

// Fetch arbitrary precision integer value.
if (type.isSigned()) {
int64_t value;
if (parser.parseInteger(value))
parser.emitError(parser.getCurrentLocation(), "expected integer value");
APValue = mlir::APInt(type.getWidth(), value, type.isSigned(),
/*implicitTrunc=*/true);
if (APValue.getSExtValue() != value)
parser.emitError(parser.getCurrentLocation(),
return p.emitError(p.getCurrentLocation(),
"integer value too large for the given type");
} else {
uint64_t value;
if (parser.parseInteger(value))
parser.emitError(parser.getCurrentLocation(), "expected integer value");
APValue = mlir::APInt(type.getWidth(), value, type.isSigned(),
/*implicitTrunc=*/true);
if (APValue.getZExtValue() != value)
parser.emitError(parser.getCurrentLocation(),
"integer value too large for the given type");
}

// Consume the '>' symbol.
if (parser.parseGreater())
return {};
return success();
}

return IntAttr::get(type, APValue);
mlir::ParseResult parseIntLiteral(mlir::AsmParser &parser, llvm::APInt &value,
cir::IntTypeInterface ty) {
if (ty.isSigned())
return parseIntLiteralImpl<int64_t>(parser, value, ty);
return parseIntLiteralImpl<uint64_t>(parser, value, ty);
}

void IntAttr::print(AsmPrinter &printer) const {
auto type = mlir::cast<IntType>(getType());
printer << '<';
if (type.isSigned())
printer << getSInt();
void printIntLiteral(mlir::AsmPrinter &p, llvm::APInt value,
cir::IntTypeInterface ty) {
if (ty.isSigned())
p << value.getSExtValue();
else
printer << getUInt();
printer << '>';
p << value.getZExtValue();
}

LogicalResult IntAttr::verify(function_ref<InFlightDiagnostic()> emitError,
Type type, APInt value) {
if (!mlir::isa<IntType>(type))
return emitError() << "expected 'simple.int' type";

auto intType = mlir::cast<IntType>(type);
if (value.getBitWidth() != intType.getWidth())
cir::IntTypeInterface type, llvm::APInt value) {
if (value.getBitWidth() != type.getWidth())
return emitError() << "type and value bitwidth mismatch: "
<< intType.getWidth() << " != " << value.getBitWidth();

<< type.getWidth() << " != " << value.getBitWidth();
return success();
}

Expand Down Expand Up @@ -481,8 +490,8 @@ LogicalResult
GlobalAnnotationValuesAttr::verify(function_ref<InFlightDiagnostic()> emitError,
mlir::ArrayAttr annotations) {
if (annotations.empty())
return emitError()
<< "GlobalAnnotationValuesAttr should at least have one annotation";
return emitError() << "GlobalAnnotationValuesAttr should at least have "
"one annotation";

for (auto &entry : annotations) {
auto annoEntry = ::mlir::dyn_cast<mlir::ArrayAttr>(entry);
Expand Down
5 changes: 5 additions & 0 deletions clang/test/CIR/IR/invalid.cir
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,11 @@ module {
module {
// expected-error@below {{integer value too large for the given type}}
cir.global external @a = #cir.int<256> : !cir.int<u, 8>
}

// -----

module {
// expected-error@below {{integer value too large for the given type}}
cir.global external @b = #cir.int<-129> : !cir.int<s, 8>
}
Expand Down
Loading