Skip to content
Merged
8 changes: 7 additions & 1 deletion clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -612,12 +612,18 @@ class CIRBrCondOpLowering
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::Value i1Condition;

auto hasOneUse = false;

if (auto defOp = brOp.getCond().getDefiningOp())
hasOneUse = defOp->getResult(0).hasOneUse();

if (auto defOp = adaptor.getCond().getDefiningOp()) {
if (auto zext = dyn_cast<mlir::LLVM::ZExtOp>(defOp)) {
if (zext->use_empty() &&
zext->getOperand(0).getType() == rewriter.getI1Type()) {
i1Condition = zext->getOperand(0);
rewriter.eraseOp(zext);
if (hasOneUse)
rewriter.eraseOp(zext);
}
}
}
Expand Down
43 changes: 43 additions & 0 deletions clang/test/CIR/Lowering/brcond.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// RUN: cir-opt %s -cir-to-llvm | FileCheck %s -check-prefix=MLIR
// RUN: cir-translate %s -cir-to-llvmir | FileCheck %s -check-prefix=LLVM

!s32i = !cir.int<s, 32>
#fn_attr = #cir<extra({inline = #cir.inline<no>, nothrow = #cir.nothrow, optnone = #cir.optnone})>
module { cir.func no_proto @test() -> !cir.bool extra(#fn_attr) {
%0 = cir.const #cir.int<0> : !s32i
%1 = cir.cast(int_to_bool, %0 : !s32i), !cir.bool
cir.br ^bb1
^bb1:
cir.brcond %1 ^bb2, ^bb3
^bb2:
cir.return %1 : !cir.bool
^bb3:
cir.br ^bb4
^bb4:
cir.return %1 : !cir.bool
}
}

// MLIR: {{.*}} = llvm.mlir.constant(0 : i32) : i32
// MLIR-NEXT: {{.*}} = llvm.mlir.constant(0 : i32) : i32
// MLIR-NEXT: {{.*}} = llvm.icmp "ne" {{.*}}, {{.*}} : i32
// MLIR-NEXT: {{.*}} = llvm.zext {{.*}} : i1 to i8
// MLIR-NEXT: llvm.br ^bb1
// MLIR-NEXT: ^bb1:
// MLIR-NEXT: llvm.cond_br {{.*}}, ^bb2, ^bb3
// MLIR-NEXT: ^bb2:
// MLIR-NEXT: llvm.return {{.*}} : i8
// MLIR-NEXT: ^bb3:
// MLIR-NEXT: llvm.br ^bb4
// MLIR-NEXT: ^bb4:
// MLIR-NEXT: llvm.return {{.*}} : i8

// LLVM: br label {{.*}}
// LLVM: 1:
// LLVM: br i1 false, label {{.*}}, label {{.*}}
// LLVM: 2:
// LLVM: ret i8 0
// LLVM: 3:
// LLVM: br label {{.*}}
// LLVM: 4:
// LLVM: ret i8 0