diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td index a32fb3c801114..2c861fc171f6c 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td +++ b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td @@ -273,29 +273,36 @@ def CIR_PointerType : CIR_Type<"Pointer", "ptr", def CIR_FuncType : CIR_Type<"Func", "func"> { let summary = "CIR function type"; let description = [{ - The `!cir.func` is a function type. It consists of a single return type, a - list of parameter types and can optionally be variadic. + The `!cir.func` is a function type. It consists of an optional return type, + a list of parameter types and can optionally be variadic. Example: ```mlir + !cir.func<()> !cir.func + !cir.func<(!s8i, !s8i)> !cir.func !cir.func ``` }]; let parameters = (ins ArrayRefParameter<"mlir::Type">:$inputs, - "mlir::Type":$returnType, "bool":$varArg); + "mlir::Type":$optionalReturnType, "bool":$varArg); + // Use a custom parser to handle the optional return and argument types + // without an optional anchor. let assemblyFormat = [{ - `<` $returnType ` ` `(` custom($inputs, $varArg) `>` + `<` custom($optionalReturnType, $inputs, $varArg) `>` }]; let builders = [ + // Construct with an actual return type or implicit !cir.void. TypeBuilderWithInferredContext<(ins "llvm::ArrayRef":$inputs, "mlir::Type":$returnType, CArg<"bool", "false">:$isVarArg), [{ - return $_get(returnType.getContext(), inputs, returnType, isVarArg); + return $_get(returnType.getContext(), inputs, + mlir::isa(returnType) ? nullptr : returnType, + isVarArg); }]> ]; @@ -309,11 +316,15 @@ def CIR_FuncType : CIR_Type<"Func", "func"> { /// Returns the number of arguments to the function. unsigned getNumInputs() const { return getInputs().size(); } + /// Returns the result type of the function as an actual return type or + /// explicit !cir.void. + mlir::Type getReturnType() const; + /// Returns the result type of the function as an ArrayRef, enabling better /// integration with generic MLIR utilities. llvm::ArrayRef getReturnTypes() const; - /// Returns whether the function is returns void. + /// Returns whether the function returns void. bool isVoid() const; /// Returns a clone of this function type with the given argument diff --git a/clang/lib/CIR/CodeGen/CIRGenTypes.cpp b/clang/lib/CIR/CodeGen/CIRGenTypes.cpp index 551b43ef121b3..72ace7520e087 100644 --- a/clang/lib/CIR/CodeGen/CIRGenTypes.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenTypes.cpp @@ -60,7 +60,7 @@ bool CIRGenTypes::isFuncTypeConvertible(const FunctionType *ft) { mlir::Type CIRGenTypes::convertFunctionTypeInternal(QualType qft) { assert(qft.isCanonical()); const FunctionType *ft = cast(qft.getTypePtr()); - // First, check whether we can build the full fucntion type. If the function + // First, check whether we can build the full function type. If the function // type depends on an incomplete type (e.g. a struct or enum), we cannot lower // the function type. if (!isFuncTypeConvertible(ft)) { diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index 10ad7fb4e6542..3c3b03ccf1b9a 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -399,6 +399,11 @@ LogicalResult cir::FuncOp::verifyType() { if (!isa(type)) return emitOpError("requires '" + getFunctionTypeAttrName().str() + "' attribute of function type"); + if (auto rt = type.getReturnTypes(); + !rt.empty() && mlir::isa(rt.front())) + return emitOpError("The return type for a function returning void should " + "be empty instead of an explicit !cir.void"); + return success(); } diff --git a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp index 48be11ba4e243..dbc36c1b513b4 100644 --- a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp @@ -16,15 +16,19 @@ #include "clang/CIR/Dialect/IR/CIRDialect.h" #include "llvm/ADT/TypeSwitch.h" +#include + //===----------------------------------------------------------------------===// // CIR Custom Parser/Printer Signatures //===----------------------------------------------------------------------===// -static mlir::ParseResult -parseFuncTypeArgs(mlir::AsmParser &p, llvm::SmallVector ¶ms, - bool &isVarArg); -static void printFuncTypeArgs(mlir::AsmPrinter &p, - mlir::ArrayRef params, bool isVarArg); +static mlir::ParseResult parseFuncType(mlir::AsmParser &p, + mlir::Type &optionalReturnTypes, + llvm::SmallVector ¶ms, + bool &isVarArg); + +static void printFuncType(mlir::AsmPrinter &p, mlir::Type optionalReturnTypes, + mlir::ArrayRef params, bool isVarArg); //===----------------------------------------------------------------------===// // Get autogenerated stuff @@ -331,9 +335,38 @@ FuncType FuncType::clone(TypeRange inputs, TypeRange results) const { return get(llvm::to_vector(inputs), results[0], isVarArg()); } -mlir::ParseResult parseFuncTypeArgs(mlir::AsmParser &p, - llvm::SmallVector ¶ms, - bool &isVarArg) { +// A special parser is needed for function returning void to handle the missing +// type. +static mlir::ParseResult parseFuncTypeReturn(mlir::AsmParser &p, + mlir::Type &optionalReturnType) { + if (succeeded(p.parseOptionalLParen())) { + // If we have already a '(', the function has no return type + optionalReturnType = {}; + return mlir::success(); + } + mlir::Type type; + if (p.parseType(type)) + return mlir::failure(); + if (isa(type)) { + // An explicit !cir.void means also no return type. + optionalReturnType = {}; + } else { + // Otherwise use the actual type. + optionalReturnType = type; + } + return p.parseLParen(); +} + +// A special pretty-printer for function returning or not a result. +static void printFuncTypeReturn(mlir::AsmPrinter &p, + mlir::Type optionalReturnType) { + if (optionalReturnType) + p << optionalReturnType << ' '; +} + +static mlir::ParseResult +parseFuncTypeArgs(mlir::AsmParser &p, llvm::SmallVector ¶ms, + bool &isVarArg) { isVarArg = false; // `(` `)` if (succeeded(p.parseOptionalRParen())) @@ -363,8 +396,10 @@ mlir::ParseResult parseFuncTypeArgs(mlir::AsmParser &p, return p.parseRParen(); } -void printFuncTypeArgs(mlir::AsmPrinter &p, mlir::ArrayRef params, - bool isVarArg) { +static void printFuncTypeArgs(mlir::AsmPrinter &p, + mlir::ArrayRef params, + bool isVarArg) { + p << '('; llvm::interleaveComma(params, p, [&p](mlir::Type type) { p.printType(type); }); if (isVarArg) { @@ -375,11 +410,48 @@ void printFuncTypeArgs(mlir::AsmPrinter &p, mlir::ArrayRef params, p << ')'; } +// Use a custom parser to handle the optional return and argument types without +// an optional anchor. +static mlir::ParseResult parseFuncType(mlir::AsmParser &p, + mlir::Type &optionalReturnTypes, + llvm::SmallVector ¶ms, + bool &isVarArg) { + if (failed(parseFuncTypeReturn(p, optionalReturnTypes))) + return failure(); + return parseFuncTypeArgs(p, params, isVarArg); +} + +static void printFuncType(mlir::AsmPrinter &p, mlir::Type optionalReturnTypes, + mlir::ArrayRef params, bool isVarArg) { + printFuncTypeReturn(p, optionalReturnTypes); + printFuncTypeArgs(p, params, isVarArg); +} + +// Return the actual return type or an explicit !cir.void if the function does +// not return anything +mlir::Type FuncType::getReturnType() const { + if (isVoid()) + return cir::VoidType::get(getContext()); + return static_cast(getImpl())->optionalReturnType; +} + +/// Returns the result type of the function as an ArrayRef, enabling better +/// integration with generic MLIR utilities. llvm::ArrayRef FuncType::getReturnTypes() const { - return static_cast(getImpl())->returnType; + if (isVoid()) + return {}; + return static_cast(getImpl())->optionalReturnType; } -bool FuncType::isVoid() const { return mlir::isa(getReturnType()); } +// Whether the function returns void +bool FuncType::isVoid() const { + auto rt = + static_cast(getImpl())->optionalReturnType; + assert((!rt || !mlir::isa(rt)) && + "The return type for a function returning void should be empty " + "instead of a real !cir.void"); + return !rt; +} //===----------------------------------------------------------------------===// // PointerType Definitions diff --git a/clang/test/CIR/func-simple.cpp b/clang/test/CIR/func-simple.cpp index 10c49bc506c87..f5d8b2225c7e4 100644 --- a/clang/test/CIR/func-simple.cpp +++ b/clang/test/CIR/func-simple.cpp @@ -2,12 +2,12 @@ // RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o - | FileCheck %s void empty() { } -// CHECK: cir.func @empty() -> !cir.void { +// CHECK: cir.func @empty() { // CHECK: cir.return // CHECK: } void voidret() { return; } -// CHECK: cir.func @voidret() -> !cir.void { +// CHECK: cir.func @voidret() { // CHECK: cir.return // CHECK: } diff --git a/clang/test/CIR/global-var-simple.cpp b/clang/test/CIR/global-var-simple.cpp index 237070a5b7564..cb3e915a6fa5a 100644 --- a/clang/test/CIR/global-var-simple.cpp +++ b/clang/test/CIR/global-var-simple.cpp @@ -89,10 +89,10 @@ char **cpp; // CHECK: cir.global @cpp : !cir.ptr>> void (*fp)(); -// CHECK: cir.global @fp : !cir.ptr> +// CHECK: cir.global @fp : !cir.ptr> int (*fpii)(int) = 0; // CHECK: cir.global @fpii = #cir.ptr : !cir.ptr (!cir.int)>> void (*fpvar)(int, ...); -// CHECK: cir.global @fpvar : !cir.ptr, ...)>> +// CHECK: cir.global @fpvar : !cir.ptr, ...)>>