Commit 4a3344d

gitoleg authored and lanza committed

[CIR][ABI][AArch64] Support struct passing with coercion through memory (#1111)
This PR adds support for one more case of passing structs by value, with `memcpy` emitted. First of all, don't worry - although the PR seems big, it basically consists of helpers plus refactoring. Also, there is a minor change in `CIRBaseBuilder` - I made the `getBestAllocaInsertPoint` method static in order to call it from lowering, as we once discussed - and here we just need it (the alternative is to copy-paste the code, which doesn't seem good). I will add several comments in order to simplify review.
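
For context, this path fires when a by-value struct is coerced to an ABI representation of a different size, so a plain bitcast-and-store is not enough. A minimal sketch of the case exercised by the new test below; the exact definition of `S` is not part of this diff, so the three-`int` layout here is only an assumption consistent with the 12-byte size and 4-byte alignment the CHECK lines show:

```cpp
// Hypothetical definition of S; only its size (12 bytes) and alignment (4)
// are visible in the test below.
typedef struct {
  int a, b, c; // 12 bytes, coerced to [2 x i64] (16 bytes) by the AArch64 ABI
} S;

// Passing s by value: the lowered prologue stores the coerced [2 x i64]
// argument into a 16-byte "tmp" alloca, then memcpy's 12 bytes from it
// into the local copy of S.
void passS(S s) {}
```
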
1 parent 6b1ed8b commit 4a3344d

File tree

3 files changed: +93 -44 lines changed


clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

Lines changed: 1 addition & 1 deletion
@@ -560,7 +560,7 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
   // Block handling helpers
   // ----------------------
   //
-  OpBuilder::InsertPoint getBestAllocaInsertPoint(mlir::Block *block) {
+  static OpBuilder::InsertPoint getBestAllocaInsertPoint(mlir::Block *block) {
     auto last =
         std::find_if(block->rbegin(), block->rend(), [](mlir::Operation &op) {
           return mlir::isa<cir::AllocaOp, cir::LabelOp>(&op);
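
Making the method `static` matters because the lowering code that needs it holds an MLIR `PatternRewriter`, not a `CIRBaseBuilderTy` instance. A condensed sketch of the call shape (the `createTmpAlloca` helper in the next file provides the real context; `rw` and `entry` are stand-in names):

```cpp
// Query the function entry block's best alloca position without having a
// CIRBaseBuilderTy object, then move the rewriter there.
mlir::OpBuilder::InsertPoint ip =
    CIRBaseBuilderTy::getBestAllocaInsertPoint(&entry);
rw.restoreInsertionPoint(ip);
```
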

clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerFunction.cpp

Lines changed: 76 additions & 43 deletions
@@ -19,6 +19,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LogicalResult.h"
 #include "clang/CIR/ABIArgInfo.h"
+#include "clang/CIR/Dialect/Builder/CIRBaseBuilder.h"
 #include "clang/CIR/Dialect/IR/CIRAttrs.h"
 #include "clang/CIR/Dialect/IR/CIRDialect.h"
 #include "clang/CIR/Dialect/IR/CIRTypes.h"
@@ -140,6 +141,76 @@ static mlir::Value coerceIntOrPtrToIntOrPtr(mlir::Value val, mlir::Type typ,
   return val;
 }
 
+// FIXME(cir): Create a custom rewriter class to abstract this away.
+mlir::Value createBitcast(mlir::Value Src, mlir::Type Ty, LowerFunction &LF) {
+  return LF.getRewriter().create<CastOp>(Src.getLoc(), Ty, CastKind::bitcast,
+                                         Src);
+}
+
+AllocaOp createTmpAlloca(LowerFunction &LF, mlir::Location loc, mlir::Type ty) {
+  auto &rw = LF.getRewriter();
+  auto *ctxt = rw.getContext();
+  mlir::PatternRewriter::InsertionGuard guard(rw);
+
+  // find function's entry block and use it to find a best place for alloca
+  auto *blk = rw.getBlock();
+  auto *op = blk->getParentOp();
+  FuncOp fun = mlir::dyn_cast<FuncOp>(op);
+  if (!fun)
+    fun = op->getParentOfType<FuncOp>();
+  auto &entry = fun.getBody().front();
+
+  auto ip = CIRBaseBuilderTy::getBestAllocaInsertPoint(&entry);
+  rw.restoreInsertionPoint(ip);
+
+  auto align = LF.LM.getDataLayout().getABITypeAlign(ty);
+  auto alignAttr = rw.getI64IntegerAttr(align.value());
+  auto ptrTy = PointerType::get(ctxt, ty);
+  return rw.create<AllocaOp>(loc, ptrTy, ty, "tmp", alignAttr);
+}
+
+bool isVoidPtr(mlir::Value v) {
+  if (auto p = mlir::dyn_cast<PointerType>(v.getType()))
+    return mlir::isa<VoidType>(p.getPointee());
+  return false;
+}
+
+MemCpyOp createMemCpy(LowerFunction &LF, mlir::Value dst, mlir::Value src,
+                      uint64_t len) {
+  cir_cconv_assert(mlir::isa<PointerType>(src.getType()));
+  cir_cconv_assert(mlir::isa<PointerType>(dst.getType()));
+
+  auto *ctxt = LF.getRewriter().getContext();
+  auto &rw = LF.getRewriter();
+  auto voidPtr = PointerType::get(ctxt, cir::VoidType::get(ctxt));
+
+  if (!isVoidPtr(src))
+    src = createBitcast(src, voidPtr, LF);
+  if (!isVoidPtr(dst))
+    dst = createBitcast(dst, voidPtr, LF);
+
+  auto i64Ty = IntType::get(ctxt, 64, false);
+  auto length = rw.create<ConstantOp>(src.getLoc(), IntAttr::get(i64Ty, len));
+  return rw.create<MemCpyOp>(src.getLoc(), dst, src, length);
+}
+
+cir::AllocaOp findAlloca(mlir::Operation *op) {
+  if (!op)
+    return {};
+
+  if (auto al = mlir::dyn_cast<cir::AllocaOp>(op)) {
+    return al;
+  } else if (auto ret = mlir::dyn_cast<cir::ReturnOp>(op)) {
+    auto vals = ret.getInput();
+    if (vals.size() == 1)
+      return findAlloca(vals[0].getDefiningOp());
+  } else if (auto load = mlir::dyn_cast<cir::LoadOp>(op)) {
+    return findAlloca(load.getAddr().getDefiningOp());
+  }
+
+  return {};
+}
+
 /// Create a store to \param Dst from \param Src where the source and
 /// destination may have different types.
 ///
@@ -187,16 +258,12 @@ void createCoercedStore(mlir::Value Src, mlir::Value Dst, bool DstIsVolatile,
     auto addr = bld.create<CastOp>(Dst.getLoc(), ptrTy, CastKind::bitcast, Dst);
     bld.create<StoreOp>(Dst.getLoc(), Src, addr);
   } else {
-    cir_cconv_unreachable("NYI");
+    auto tmp = createTmpAlloca(CGF, Src.getLoc(), SrcTy);
+    CGF.getRewriter().create<StoreOp>(Src.getLoc(), Src, tmp);
+    createMemCpy(CGF, Dst, tmp, DstSize.getFixedValue());
   }
 }
 
-// FIXME(cir): Create a custom rewriter class to abstract this away.
-mlir::Value createBitcast(mlir::Value Src, mlir::Type Ty, LowerFunction &LF) {
-  return LF.getRewriter().create<CastOp>(Src.getLoc(), Ty, CastKind::bitcast,
-                                         Src);
-}
-
 /// Coerces a \param Src value to a value of type \param Ty.
 ///
 /// This safely handles the case when the src type is smaller than the
@@ -261,23 +328,6 @@ mlir::Value emitAddressAtOffset(LowerFunction &LF, mlir::Value addr,
   return addr;
 }
 
-cir::AllocaOp findAlloca(mlir::Operation *op) {
-  if (!op)
-    return {};
-
-  if (auto al = mlir::dyn_cast<cir::AllocaOp>(op)) {
-    return al;
-  } else if (auto ret = mlir::dyn_cast<cir::ReturnOp>(op)) {
-    auto vals = ret.getInput();
-    if (vals.size() == 1)
-      return findAlloca(vals[0].getDefiningOp());
-  } else if (auto load = mlir::dyn_cast<cir::LoadOp>(op)) {
-    return findAlloca(load.getAddr().getDefiningOp());
-  }
-
-  return {};
-}
-
 /// After the calling convention is lowered, an ABI-agnostic type might have to
 /// be loaded back to its ABI-aware couterpart so it may be returned. If they
 /// differ, we have to do a coerced load. A coerced load, which means to load a
@@ -329,25 +379,8 @@ mlir::Value castReturnValue(mlir::Value Src, mlir::Type Ty, LowerFunction &LF) {
   // Otherwise do coercion through memory.
   if (auto addr = findAlloca(Src.getDefiningOp())) {
     auto &rewriter = LF.getRewriter();
-    auto *ctxt = LF.LM.getMLIRContext();
-    auto ptrTy = PointerType::get(ctxt, Ty);
-    auto voidPtr = PointerType::get(ctxt, cir::VoidType::get(ctxt));
-
-    // insert alloca near the previuos one
-    auto point = rewriter.saveInsertionPoint();
-    rewriter.setInsertionPointAfter(addr);
-    auto align = LF.LM.getDataLayout().getABITypeAlign(Ty);
-    auto alignAttr = rewriter.getI64IntegerAttr(align.value());
-    auto tmp =
-        rewriter.create<AllocaOp>(Src.getLoc(), ptrTy, Ty, "tmp", alignAttr);
-    rewriter.restoreInsertionPoint(point);
-
-    auto srcVoidPtr = createBitcast(addr, voidPtr, LF);
-    auto dstVoidPtr = createBitcast(tmp, voidPtr, LF);
-    auto i64Ty = IntType::get(ctxt, 64, false);
-    auto len = rewriter.create<ConstantOp>(
-        Src.getLoc(), IntAttr::get(i64Ty, SrcSize.getFixedValue()));
-    rewriter.create<MemCpyOp>(Src.getLoc(), dstVoidPtr, srcVoidPtr, len);
+    auto tmp = createTmpAlloca(LF, Src.getLoc(), Ty);
+    createMemCpy(LF, tmp, addr, SrcSize.getFixedValue());
     return rewriter.create<LoadOp>(Src.getLoc(), tmp.getResult());
   }
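
Taken together, the helpers factor one pattern used in both directions: `createCoercedStore` spills an ABI-coerced argument value and copies it into the struct's storage, while `castReturnValue` copies the struct's storage into a temporary and loads the ABI value back out. A condensed sketch with hypothetical local names (`coercedVal`, `structAddr`, and the sizes are stand-ins):

```cpp
// Argument direction (createCoercedStore): ABI value -> struct memory.
auto tmp = createTmpAlloca(LF, loc, srcTy);    // temporary of the ABI type
rw.create<StoreOp>(loc, coercedVal, tmp);      // spill the coerced value
createMemCpy(LF, structAddr, tmp, dstSize);    // copy into the real struct

// Return direction (castReturnValue): struct memory -> ABI value.
auto tmp2 = createTmpAlloca(LF, loc, abiTy);   // temporary of the ABI type
createMemCpy(LF, tmp2, structAddr, srcSize);   // copy the struct bytes over
auto abiVal = rw.create<LoadOp>(loc, tmp2.getResult());
```
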

clang/test/CIR/CallConvLowering/AArch64/aarch64-cc-structs.c

Lines changed: 16 additions & 0 deletions
@@ -152,3 +152,19 @@ void pass_eq_128(EQ_128 s) {}
 // LLVM: store ptr %0, ptr %[[#V1]], align 8
 // LLVM: %[[#V2:]] = load ptr, ptr %[[#V1]], align 8
 void pass_gt_128(GT_128 s) {}
+
+// CHECK: cir.func @passS(%arg0: !cir.array<!u64i x 2>
+// CHECK: %[[#V0:]] = cir.alloca !ty_S, !cir.ptr<!ty_S>, [""] {alignment = 4 : i64}
+// CHECK: %[[#V1:]] = cir.alloca !cir.array<!u64i x 2>, !cir.ptr<!cir.array<!u64i x 2>>, ["tmp"] {alignment = 8 : i64}
+// CHECK: cir.store %arg0, %[[#V1]] : !cir.array<!u64i x 2>, !cir.ptr<!cir.array<!u64i x 2>>
+// CHECK: %[[#V2:]] = cir.cast(bitcast, %[[#V1]] : !cir.ptr<!cir.array<!u64i x 2>>), !cir.ptr<!void>
+// CHECK: %[[#V3:]] = cir.cast(bitcast, %[[#V0]] : !cir.ptr<!ty_S>), !cir.ptr<!void>
+// CHECK: %[[#V4:]] = cir.const #cir.int<12> : !u64i
+// CHECK: cir.libc.memcpy %[[#V4]] bytes from %[[#V2]] to %[[#V3]] : !u64i, !cir.ptr<!void> -> !cir.ptr<!void>
+
+// LLVM: void @passS([2 x i64] %[[#ARG:]])
+// LLVM: %[[#V1:]] = alloca %struct.S, i64 1, align 4
+// LLVM: %[[#V2:]] = alloca [2 x i64], i64 1, align 8
+// LLVM: store [2 x i64] %[[#ARG]], ptr %[[#V2]], align 8
+// LLVM: call void @llvm.memcpy.p0.p0.i64(ptr %[[#V1]], ptr %[[#V2]], i64 12, i1 false)
+void passS(S s) {}
