Skip to content

[CIR] Clean up IntAttr #146661

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {

mlir::Value getConstAPInt(mlir::Location loc, mlir::Type typ,
const llvm::APInt &val) {
return create<cir::ConstantOp>(loc, getAttr<cir::IntAttr>(typ, val));
return create<cir::ConstantOp>(loc, cir::IntAttr::get(typ, val));
}

cir::ConstantOp getConstant(mlir::Location loc, mlir::TypedAttr attr) {
Expand Down
53 changes: 42 additions & 11 deletions clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -117,36 +117,67 @@ def UndefAttr : CIR_TypedAttr<"Undef", "undef"> {
// IntegerAttr
//===----------------------------------------------------------------------===//

def IntAttr : CIR_Attr<"Int", "int", [TypedAttrInterface]> {
def CIR_IntAttr : CIR_Attr<"Int", "int", [TypedAttrInterface]> {
let summary = "An attribute containing an integer value";
let description = [{
An integer attribute is a literal attribute that represents an integral
value of the specified integer type.
}];
let parameters = (ins AttributeSelfTypeParameter<"">:$type,
APIntParameter<"">:$value);

let parameters = (ins
AttributeSelfTypeParameter<"", "cir::IntTypeInterface">:$type,
APIntParameter<"">:$value
);

let builders = [
AttrBuilderWithInferredContext<(ins "mlir::Type":$type,
"const llvm::APInt &":$value), [{
return $_get(type.getContext(), type, value);
auto intType = mlir::cast<cir::IntTypeInterface>(type);
return $_get(type.getContext(), intType, value);
}]>,
AttrBuilderWithInferredContext<(ins "mlir::Type":$type,
"int64_t":$value), [{
IntType intType = mlir::cast<IntType>(type);
auto intType = mlir::cast<cir::IntTypeInterface>(type);
mlir::APInt apValue(intType.getWidth(), value, intType.isSigned());
return $_get(intType.getContext(), intType, apValue);
}]>,
];

let extraClassDeclaration = [{
int64_t getSInt() const { return getValue().getSExtValue(); }
uint64_t getUInt() const { return getValue().getZExtValue(); }
bool isNullValue() const { return getValue() == 0; }
uint64_t getBitWidth() const {
return mlir::cast<IntType>(getType()).getWidth();
int64_t getSInt() const;
uint64_t getUInt() const;
bool isNullValue() const;
bool isSigned() const;
bool isUnsigned() const;
uint64_t getBitWidth() const;
}];

let extraClassDefinition = [{
int64_t $cppClass::getSInt() const {
return getValue().getSExtValue();
}
uint64_t $cppClass::getUInt() const {
return getValue().getZExtValue();
}
bool $cppClass::isNullValue() const {
return getValue() == 0;
}
bool $cppClass::isSigned() const {
return mlir::cast<IntTypeInterface>(getType()).isSigned();
}
bool $cppClass::isUnsigned() const {
return mlir::cast<IntTypeInterface>(getType()).isUnsigned();
}
uint64_t $cppClass::getBitWidth() const {
return mlir::cast<IntTypeInterface>(getType()).getWidth();
}
}];

let assemblyFormat = [{
`<` custom<IntLiteral>($value, ref($type)) `>`
}];

let genVerifyDecl = 1;
let hasCustomAssemblyFormat = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
6 changes: 3 additions & 3 deletions clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -684,7 +684,7 @@ mlir::Attribute ConstantEmitter::tryEmitPrivate(const APValue &value,
if (mlir::isa<cir::BoolType>(ty))
return builder.getCIRBoolAttr(value.getInt().getZExtValue());
assert(mlir::isa<cir::IntType>(ty) && "expected integral type");
return cgm.getBuilder().getAttr<cir::IntAttr>(ty, value.getInt());
return cir::IntAttr::get(ty, value.getInt());
}
case APValue::Float: {
const llvm::APFloat &init = value.getFloat();
Expand Down Expand Up @@ -789,8 +789,8 @@ mlir::Attribute ConstantEmitter::tryEmitPrivate(const APValue &value,
llvm::APSInt real = value.getComplexIntReal();
llvm::APSInt imag = value.getComplexIntImag();
return builder.getAttr<cir::ConstComplexAttr>(
complexType, builder.getAttr<cir::IntAttr>(complexElemTy, real),
builder.getAttr<cir::IntAttr>(complexElemTy, imag));
complexType, cir::IntAttr::get(complexElemTy, real),
cir::IntAttr::get(complexElemTy, imag));
}

assert(isa<cir::FPTypeInterface>(complexElemTy) &&
Expand Down
15 changes: 7 additions & 8 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,7 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
mlir::Value VisitIntegerLiteral(const IntegerLiteral *e) {
mlir::Type type = cgf.convertType(e->getType());
return builder.create<cir::ConstantOp>(
cgf.getLoc(e->getExprLoc()),
builder.getAttr<cir::IntAttr>(type, e->getValue()));
cgf.getLoc(e->getExprLoc()), cir::IntAttr::get(type, e->getValue()));
}

mlir::Value VisitFloatingLiteral(const FloatingLiteral *e) {
Expand Down Expand Up @@ -1970,21 +1969,21 @@ mlir::Value ScalarExprEmitter::VisitUnaryExprOrTypeTraitExpr(
"sizeof operator for VariableArrayType",
e->getStmtClassName());
return builder.getConstant(
loc, builder.getAttr<cir::IntAttr>(
cgf.cgm.UInt64Ty, llvm::APSInt(llvm::APInt(64, 1), true)));
loc, cir::IntAttr::get(cgf.cgm.UInt64Ty,
llvm::APSInt(llvm::APInt(64, 1), true)));
}
} else if (e->getKind() == UETT_OpenMPRequiredSimdAlign) {
cgf.getCIRGenModule().errorNYI(
e->getSourceRange(), "sizeof operator for OpenMpRequiredSimdAlign",
e->getStmtClassName());
return builder.getConstant(
loc, builder.getAttr<cir::IntAttr>(
cgf.cgm.UInt64Ty, llvm::APSInt(llvm::APInt(64, 1), true)));
loc, cir::IntAttr::get(cgf.cgm.UInt64Ty,
llvm::APSInt(llvm::APInt(64, 1), true)));
}

return builder.getConstant(
loc, builder.getAttr<cir::IntAttr>(
cgf.cgm.UInt64Ty, e->EvaluateKnownConstInt(cgf.getContext())));
loc, cir::IntAttr::get(cgf.cgm.UInt64Ty,
e->EvaluateKnownConstInt(cgf.getContext())));
}

/// Return true if the specified expression is cheap enough and side-effect-free
Expand Down
100 changes: 48 additions & 52 deletions clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,19 @@
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/TypeSwitch.h"

//===-----------------------------------------------------------------===//
// IntLiteral
//===-----------------------------------------------------------------===//

static void printIntLiteral(mlir::AsmPrinter &p, llvm::APInt value,
cir::IntTypeInterface ty);
static mlir::ParseResult parseIntLiteral(mlir::AsmParser &parser,
llvm::APInt &value,
cir::IntTypeInterface ty);
//===-----------------------------------------------------------------===//
// FloatLiteral
//===-----------------------------------------------------------------===//

static void printFloatLiteral(mlir::AsmPrinter &p, llvm::APFloat value,
mlir::Type ty);
static mlir::ParseResult
Expand Down Expand Up @@ -82,69 +95,52 @@ static void printConstPtr(AsmPrinter &p, mlir::IntegerAttr value) {
// IntAttr definitions
//===----------------------------------------------------------------------===//

Attribute IntAttr::parse(AsmParser &parser, Type odsType) {
mlir::APInt apValue;

if (!mlir::isa<IntType>(odsType))
return {};
auto type = mlir::cast<IntType>(odsType);

// Consume the '<' symbol.
if (parser.parseLess())
return {};

// Fetch arbitrary precision integer value.
if (type.isSigned()) {
int64_t value = 0;
if (parser.parseInteger(value)) {
parser.emitError(parser.getCurrentLocation(), "expected integer value");
} else {
apValue = mlir::APInt(type.getWidth(), value, type.isSigned(),
/*implicitTrunc=*/true);
if (apValue.getSExtValue() != value)
parser.emitError(parser.getCurrentLocation(),
"integer value too large for the given type");
}
template <typename IntT>
static bool isTooLargeForType(const mlir::APInt &value, IntT expectedValue) {
if constexpr (std::is_signed_v<IntT>) {
return value.getSExtValue() != expectedValue;
} else {
uint64_t value = 0;
if (parser.parseInteger(value)) {
parser.emitError(parser.getCurrentLocation(), "expected integer value");
} else {
apValue = mlir::APInt(type.getWidth(), value, type.isSigned(),
/*implicitTrunc=*/true);
if (apValue.getZExtValue() != value)
parser.emitError(parser.getCurrentLocation(),
"integer value too large for the given type");
}
return value.getZExtValue() != expectedValue;
}
}

// Consume the '>' symbol.
if (parser.parseGreater())
return {};
template <typename IntT>
static mlir::ParseResult parseIntLiteralImpl(mlir::AsmParser &p,
llvm::APInt &value,
cir::IntTypeInterface ty) {
IntT ivalue;
const bool isSigned = ty.isSigned();
if (p.parseInteger(ivalue))
return p.emitError(p.getCurrentLocation(), "expected integer value");

value = mlir::APInt(ty.getWidth(), ivalue, isSigned, /*implicitTrunc=*/true);
if (isTooLargeForType(value, ivalue))
return p.emitError(p.getCurrentLocation(),
"integer value too large for the given type");

return IntAttr::get(type, apValue);
return success();
}

mlir::ParseResult parseIntLiteral(mlir::AsmParser &parser, llvm::APInt &value,
cir::IntTypeInterface ty) {
if (ty.isSigned())
return parseIntLiteralImpl<int64_t>(parser, value, ty);
return parseIntLiteralImpl<uint64_t>(parser, value, ty);
}

void IntAttr::print(AsmPrinter &printer) const {
auto type = mlir::cast<IntType>(getType());
printer << '<';
if (type.isSigned())
printer << getSInt();
void printIntLiteral(mlir::AsmPrinter &p, llvm::APInt value,
cir::IntTypeInterface ty) {
if (ty.isSigned())
p << value.getSExtValue();
else
printer << getUInt();
printer << '>';
p << value.getZExtValue();
}

LogicalResult IntAttr::verify(function_ref<InFlightDiagnostic()> emitError,
Type type, APInt value) {
if (!mlir::isa<IntType>(type))
return emitError() << "expected 'simple.int' type";

auto intType = mlir::cast<IntType>(type);
if (value.getBitWidth() != intType.getWidth())
cir::IntTypeInterface type, llvm::APInt value) {
if (value.getBitWidth() != type.getWidth())
return emitError() << "type and value bitwidth mismatch: "
<< intType.getWidth() << " != " << value.getBitWidth();

<< type.getWidth() << " != " << value.getBitWidth();
return success();
}

Expand Down
Loading