Skip to content

Commit 26eb428

Browse files
authored
[MLIR][LLVM] Add vararg support in LLVM::CallOp and InvokeOp (#67274)
In order to support indirect vararg calls, we need to have information about the callee type - this patch adds a `callee_type` attribute that holds that. The attribute is required for vararg calls, else, it is optional and the callee type is inferred by the operands and results of the operation if not present. The syntax for non-vararg calls remains the same, whereas for vararg calls, it is changed to this: ``` llvm.call %p(%arg0, %arg0) vararg(!llvm.func<void (i32, ...)>) : !llvm.ptr, (i32, i32) -> () llvm.call @s(%arg0, %arg0) vararg(!llvm.func<void (i32, ...)>) : (i32, i32) -> () ```
1 parent 64ffe64 commit 26eb428

File tree

16 files changed

+345
-105
lines changed

16 files changed

+345
-105
lines changed

mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class PrintOpLowering : public ConversionPattern {
6161
LogicalResult
6262
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
6363
ConversionPatternRewriter &rewriter) const override {
64+
auto context = rewriter.getContext();
6465
auto memRefType = llvm::cast<MemRefType>((*op->operand_type_begin()));
6566
auto memRefShape = memRefType.getShape();
6667
auto loc = op->getLoc();
@@ -92,8 +93,8 @@ class PrintOpLowering : public ConversionPattern {
9293

9394
// Insert a newline after each of the inner dimensions of the shape.
9495
if (i != e - 1)
95-
rewriter.create<func::CallOp>(loc, printfRef,
96-
rewriter.getIntegerType(32), newLineCst);
96+
rewriter.create<LLVM::CallOp>(loc, getPrintfType(context), printfRef,
97+
newLineCst);
9798
rewriter.create<scf::YieldOp>(loc);
9899
rewriter.setInsertionPointToStart(loop.getBody());
99100
}
@@ -102,8 +103,8 @@ class PrintOpLowering : public ConversionPattern {
102103
auto printOp = cast<toy::PrintOp>(op);
103104
auto elementLoad =
104105
rewriter.create<memref::LoadOp>(loc, printOp.getInput(), loopIvs);
105-
rewriter.create<func::CallOp>(
106-
loc, printfRef, rewriter.getIntegerType(32),
106+
rewriter.create<LLVM::CallOp>(
107+
loc, getPrintfType(context), printfRef,
107108
ArrayRef<Value>({formatSpecifierCst, elementLoad}));
108109

109110
// Notify the rewriter that this operation has been removed.
@@ -112,6 +113,16 @@ class PrintOpLowering : public ConversionPattern {
112113
}
113114

114115
private:
116+
/// Create a function declaration for printf, the signature is:
117+
/// * `i32 (i8*, ...)`
118+
static LLVM::LLVMFunctionType getPrintfType(MLIRContext *context) {
119+
auto llvmI32Ty = IntegerType::get(context, 32);
120+
auto llvmI8PtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
121+
auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy,
122+
/*isVarArg=*/true);
123+
return llvmFnType;
124+
}
125+
115126
/// Return a symbol reference to the printf function, inserting it into the
116127
/// module if necessary.
117128
static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
@@ -120,17 +131,11 @@ class PrintOpLowering : public ConversionPattern {
120131
if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
121132
return SymbolRefAttr::get(context, "printf");
122133

123-
// Create a function declaration for printf, the signature is:
124-
// * `i32 (i8*, ...)`
125-
auto llvmI32Ty = IntegerType::get(context, 32);
126-
auto llvmI8PtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
127-
auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy,
128-
/*isVarArg=*/true);
129-
130134
// Insert the printf function into the body of the parent module.
131135
PatternRewriter::InsertionGuard insertGuard(rewriter);
132136
rewriter.setInsertionPointToStart(module.getBody());
133-
rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", llvmFnType);
137+
rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf",
138+
getPrintfType(context));
134139
return SymbolRefAttr::get(context, "printf");
135140
}
136141

mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class PrintOpLowering : public ConversionPattern {
6161
LogicalResult
6262
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
6363
ConversionPatternRewriter &rewriter) const override {
64+
auto context = rewriter.getContext();
6465
auto memRefType = llvm::cast<MemRefType>((*op->operand_type_begin()));
6566
auto memRefShape = memRefType.getShape();
6667
auto loc = op->getLoc();
@@ -92,8 +93,8 @@ class PrintOpLowering : public ConversionPattern {
9293

9394
// Insert a newline after each of the inner dimensions of the shape.
9495
if (i != e - 1)
95-
rewriter.create<func::CallOp>(loc, printfRef,
96-
rewriter.getIntegerType(32), newLineCst);
96+
rewriter.create<LLVM::CallOp>(loc, getPrintfType(context), printfRef,
97+
newLineCst);
9798
rewriter.create<scf::YieldOp>(loc);
9899
rewriter.setInsertionPointToStart(loop.getBody());
99100
}
@@ -102,8 +103,8 @@ class PrintOpLowering : public ConversionPattern {
102103
auto printOp = cast<toy::PrintOp>(op);
103104
auto elementLoad =
104105
rewriter.create<memref::LoadOp>(loc, printOp.getInput(), loopIvs);
105-
rewriter.create<func::CallOp>(
106-
loc, printfRef, rewriter.getIntegerType(32),
106+
rewriter.create<LLVM::CallOp>(
107+
loc, getPrintfType(context), printfRef,
107108
ArrayRef<Value>({formatSpecifierCst, elementLoad}));
108109

109110
// Notify the rewriter that this operation has been removed.
@@ -112,6 +113,16 @@ class PrintOpLowering : public ConversionPattern {
112113
}
113114

114115
private:
116+
/// Create a function declaration for printf, the signature is:
117+
/// * `i32 (i8*, ...)`
118+
static LLVM::LLVMFunctionType getPrintfType(MLIRContext *context) {
119+
auto llvmI32Ty = IntegerType::get(context, 32);
120+
auto llvmI8PtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
121+
auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy,
122+
/*isVarArg=*/true);
123+
return llvmFnType;
124+
}
125+
115126
/// Return a symbol reference to the printf function, inserting it into the
116127
/// module if necessary.
117128
static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
@@ -120,17 +131,11 @@ class PrintOpLowering : public ConversionPattern {
120131
if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
121132
return SymbolRefAttr::get(context, "printf");
122133

123-
// Create a function declaration for printf, the signature is:
124-
// * `i32 (i8*, ...)`
125-
auto llvmI32Ty = IntegerType::get(context, 32);
126-
auto llvmI8PtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
127-
auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy,
128-
/*isVarArg=*/true);
129-
130134
// Insert the printf function into the body of the parent module.
131135
PatternRewriter::InsertionGuard insertGuard(rewriter);
132136
rewriter.setInsertionPointToStart(module.getBody());
133-
rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", llvmFnType);
137+
rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf",
138+
getPrintfType(context));
134139
return SymbolRefAttr::get(context, "printf");
135140
}
136141

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,9 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
542542
DeclareOpInterfaceMethods<CallOpInterface>,
543543
DeclareOpInterfaceMethods<BranchWeightOpInterface>,
544544
Terminator]> {
545-
let arguments = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
545+
let arguments = (ins
546+
OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$callee_type,
547+
OptionalAttr<FlatSymbolRefAttr>:$callee,
546548
Variadic<LLVM_Type>:$callee_operands,
547549
Variadic<LLVM_Type>:$normalDestOperands,
548550
Variadic<LLVM_Type>:$unwindDestOperands,
@@ -552,21 +554,21 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
552554
AnySuccessor:$unwindDest);
553555

554556
let builders = [
557+
OpBuilder<(ins "LLVMFuncOp":$func,
558+
"ValueRange":$ops, "Block*":$normal, "ValueRange":$normalOps,
559+
"Block*":$unwind, "ValueRange":$unwindOps)>,
555560
OpBuilder<(ins "TypeRange":$tys, "FlatSymbolRefAttr":$callee,
556561
"ValueRange":$ops, "Block*":$normal, "ValueRange":$normalOps,
557-
"Block*":$unwind, "ValueRange":$unwindOps),
558-
[{
559-
$_state.addAttribute("callee", callee);
560-
build($_builder, $_state, tys, ops, normal, normalOps, unwind, unwindOps);
561-
}]>,
562-
OpBuilder<(ins "TypeRange":$tys, "ValueRange":$ops, "Block*":$normal,
563-
"ValueRange":$normalOps, "Block*":$unwind, "ValueRange":$unwindOps),
564-
[{
565-
build($_builder, $_state, tys, /*callee=*/FlatSymbolRefAttr(), ops, normalOps,
566-
unwindOps, nullptr, normal, unwind);
567-
}]>];
562+
"Block*":$unwind, "ValueRange":$unwindOps)>,
563+
OpBuilder<(ins "LLVMFunctionType":$calleeType, "FlatSymbolRefAttr":$callee,
564+
"ValueRange":$ops, "Block*":$normal, "ValueRange":$normalOps,
565+
"Block*":$unwind, "ValueRange":$unwindOps)>];
568566
let hasCustomAssemblyFormat = 1;
569567
let hasVerifier = 1;
568+
let extraClassDeclaration = [{
569+
/// Returns the callee function type.
570+
LLVMFunctionType getCalleeFunctionType();
571+
}];
570572
}
571573

572574
def LLVM_LandingpadOp : LLVM_Op<"landingpad"> {
@@ -581,9 +583,6 @@ def LLVM_LandingpadOp : LLVM_Op<"landingpad"> {
581583
// CallOp
582584
//===----------------------------------------------------------------------===//
583585

584-
// FIXME: Add a type attribute that carries the LLVM function type to support
585-
// indirect calls to variadic functions. The type attribute is necessary to
586-
// distinguish normal and variadic arguments.
587586
def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
588587
[DeclareOpInterfaceMethods<FastmathFlagsInterface>,
589588
DeclareOpInterfaceMethods<CallOpInterface>,
@@ -599,10 +598,11 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
599598
The `call` instruction supports both direct and indirect calls. Direct calls
600599
start with a function name (`@`-prefixed) and indirect calls start with an
601600
SSA value (`%`-prefixed). The direct callee, if present, is stored as a
602-
function attribute `callee`. The trailing type list contains the optional
603-
indirect callee type and the MLIR function type, which differs from the
604-
LLVM function type that uses a explicit void type to model functions that do
605-
not return a value.
601+
function attribute `callee`. If the callee is a variadic function, then the
602+
`callee_type` attribute must carry the function type. The trailing type list
603+
contains the optional indirect callee type and the MLIR function type, which
604+
differs from the LLVM function type that uses a explicit void type to model
605+
functions that do not return a value.
606606

607607
Examples:
608608

@@ -615,10 +615,17 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
615615

616616
// Indirect call with an argument and without a result.
617617
llvm.call %1(%0) : !llvm.ptr, (f32) -> ()
618+
619+
// Direct variadic call.
620+
llvm.call @printf(%0, %1) vararg(!llvm.func<i32 (ptr, ...)>) : (!llvm.ptr, i32) -> i32
621+
622+
// Indirect variadic call
623+
llvm.call %1(%0) vararg(!llvm.func<void (...)>) : !llvm.ptr, (i32) -> ()
618624
```
619625
}];
620626

621-
dag args = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
627+
dag args = (ins OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$callee_type,
628+
OptionalAttr<FlatSymbolRefAttr>:$callee,
622629
Variadic<LLVM_Type>:$callee_operands,
623630
DefaultValuedAttr<LLVM_FastmathFlagsAttr,
624631
"{}">:$fastmathFlags,
@@ -628,14 +635,25 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
628635
let results = (outs Optional<LLVM_Type>:$result);
629636
let builders = [
630637
OpBuilder<(ins "LLVMFuncOp":$func, "ValueRange":$args)>,
638+
OpBuilder<(ins "LLVMFunctionType":$calleeType, "ValueRange":$args)>,
631639
OpBuilder<(ins "TypeRange":$results, "StringAttr":$callee,
632640
CArg<"ValueRange", "{}">:$args)>,
633641
OpBuilder<(ins "TypeRange":$results, "FlatSymbolRefAttr":$callee,
634642
CArg<"ValueRange", "{}">:$args)>,
635643
OpBuilder<(ins "TypeRange":$results, "StringRef":$callee,
644+
CArg<"ValueRange", "{}">:$args)>,
645+
OpBuilder<(ins "LLVMFunctionType":$calleeType, "StringAttr":$callee,
646+
CArg<"ValueRange", "{}">:$args)>,
647+
OpBuilder<(ins "LLVMFunctionType":$calleeType, "FlatSymbolRefAttr":$callee,
648+
CArg<"ValueRange", "{}">:$args)>,
649+
OpBuilder<(ins "LLVMFunctionType":$calleeType, "StringRef":$callee,
636650
CArg<"ValueRange", "{}">:$args)>
637651
];
638652
let hasCustomAssemblyFormat = 1;
653+
let extraClassDeclaration = [{
654+
/// Returns the callee function type.
655+
LLVMFunctionType getCalleeFunctionType();
656+
}];
639657
}
640658

641659
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)