Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 99 additions & 90 deletions flang/lib/Optimizer/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3525,114 +3525,123 @@ struct SelectCaseOpConversion : public fir::FIROpConversion<fir::SelectCaseOp> {
}
};

/// Helper function for converting select ops. This function converts the
/// signature of the given block. If the new block signature is different from
/// `expectedTypes`, returns "failure".
static llvm::FailureOr<mlir::Block *>
getConvertedBlock(mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter,
mlir::Operation *branchOp, mlir::Block *block,
mlir::TypeRange expectedTypes) {
assert(converter && "expected non-null type converter");
assert(!block->isEntryBlock() && "entry blocks have no predecessors");

// There is nothing to do if the types already match.
if (block->getArgumentTypes() == expectedTypes)
return block;

// Compute the new block argument types and convert the block.
std::optional<mlir::TypeConverter::SignatureConversion> conversion =
converter->convertBlockSignature(block);
if (!conversion)
return rewriter.notifyMatchFailure(branchOp,
"could not compute block signature");
if (expectedTypes != conversion->getConvertedTypes())
return rewriter.notifyMatchFailure(
branchOp,
"mismatch between adaptor operand types and computed block signature");
return rewriter.applySignatureConversion(block, *conversion, converter);
}

/// Base class for SelectOpConversion and SelectRankOpConversion.
template <typename OP>
static llvm::LogicalResult
selectMatchAndRewrite(const fir::LLVMTypeConverter &lowering, OP select,
typename OP::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter) {
unsigned conds = select.getNumConditions();
auto cases = select.getCases().getValue();
mlir::Value selector = adaptor.getSelector();
auto loc = select.getLoc();
assert(conds > 0 && "select must have cases");

llvm::SmallVector<mlir::Block *> destinations;
llvm::SmallVector<mlir::ValueRange> destinationsOperands;
mlir::Block *defaultDestination;
mlir::ValueRange defaultOperands;
llvm::SmallVector<int32_t> caseValues;

for (unsigned t = 0; t != conds; ++t) {
mlir::Block *dest = select.getSuccessor(t);
auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t);
const mlir::Attribute &attr = cases[t];
if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr)) {
destinationsOperands.push_back(destOps ? *destOps : mlir::ValueRange{});
auto convertedBlock =
getConvertedBlock(rewriter, converter, select, dest,
mlir::TypeRange(destinationsOperands.back()));
struct SelectOpConversionBase : public fir::FIROpConversion<OP> {
using fir::FIROpConversion<OP>::FIROpConversion;

private:
/// Helper function for converting select ops. This function converts the
/// signature of the given block. If the new block signature is different from
/// `expectedTypes`, returns "failure".
llvm::FailureOr<mlir::Block *>
getConvertedBlock(mlir::ConversionPatternRewriter &rewriter,
mlir::Operation *branchOp, mlir::Block *block,
mlir::TypeRange expectedTypes) const {
const mlir::TypeConverter *converter = this->getTypeConverter();
assert(converter && "expected non-null type converter");
assert(!block->isEntryBlock() && "entry blocks have no predecessors");

// There is nothing to do if the types already match.
if (block->getArgumentTypes() == expectedTypes)
return block;

// Compute the new block argument types and convert the block.
std::optional<mlir::TypeConverter::SignatureConversion> conversion =
converter->convertBlockSignature(block);
if (!conversion)
return rewriter.notifyMatchFailure(branchOp,
"could not compute block signature");
if (expectedTypes != conversion->getConvertedTypes())
return rewriter.notifyMatchFailure(branchOp,
"mismatch between adaptor operand "
"types and computed block signature");
return rewriter.applySignatureConversion(block, *conversion, converter);
}

protected:
llvm::LogicalResult
selectMatchAndRewrite(OP select, typename OP::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
unsigned conds = select.getNumConditions();
auto cases = select.getCases().getValue();
mlir::Value selector = adaptor.getSelector();
auto loc = select.getLoc();
assert(conds > 0 && "select must have cases");

llvm::SmallVector<mlir::Block *> destinations;
llvm::SmallVector<mlir::ValueRange> destinationsOperands;
mlir::Block *defaultDestination;
mlir::ValueRange defaultOperands;
// LLVM::SwitchOp selector type and the case values types
// must have the same bit width, so cast the selector to i64,
// and use i64 for the case values. It is hard to imagine
// a computed GO TO with the number of labels in the label-list
// bigger than INT_MAX, but let's use i64 to be on the safe side.
// Moreover, fir.select operation is more relaxed than
// a Fortran computed GO TO, so it may specify such a case value
// even if there is just a single label/case.
llvm::SmallVector<int64_t> caseValues;

for (unsigned t = 0; t != conds; ++t) {
mlir::Block *dest = select.getSuccessor(t);
auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t);
const mlir::Attribute &attr = cases[t];
if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr)) {
destinationsOperands.push_back(destOps ? *destOps : mlir::ValueRange{});
auto convertedBlock =
getConvertedBlock(rewriter, select, dest,
mlir::TypeRange(destinationsOperands.back()));
if (mlir::failed(convertedBlock))
return mlir::failure();
destinations.push_back(*convertedBlock);
caseValues.push_back(intAttr.getInt());
continue;
}
assert(mlir::dyn_cast_or_null<mlir::UnitAttr>(attr));
assert((t + 1 == conds) && "unit must be last");
defaultOperands = destOps ? *destOps : mlir::ValueRange{};
auto convertedBlock = getConvertedBlock(rewriter, select, dest,
mlir::TypeRange(defaultOperands));
if (mlir::failed(convertedBlock))
return mlir::failure();
destinations.push_back(*convertedBlock);
caseValues.push_back(intAttr.getInt());
continue;
defaultDestination = *convertedBlock;
}
assert(mlir::dyn_cast_or_null<mlir::UnitAttr>(attr));
assert((t + 1 == conds) && "unit must be last");
defaultOperands = destOps ? *destOps : mlir::ValueRange{};
auto convertedBlock = getConvertedBlock(rewriter, converter, select, dest,
mlir::TypeRange(defaultOperands));
if (mlir::failed(convertedBlock))
return mlir::failure();
defaultDestination = *convertedBlock;
}

// LLVM::SwitchOp takes a i32 type for the selector.
if (select.getSelector().getType() != rewriter.getI32Type())
selector = mlir::LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(),
selector);

rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>(
select, selector,
/*defaultDestination=*/defaultDestination,
/*defaultOperands=*/defaultOperands,
/*caseValues=*/caseValues,
/*caseDestinations=*/destinations,
/*caseOperands=*/destinationsOperands,
/*branchWeights=*/llvm::ArrayRef<std::int32_t>());
return mlir::success();
}

selector =
this->integerCast(loc, rewriter, rewriter.getI64Type(), selector);

rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>(
select, selector,
/*defaultDestination=*/defaultDestination,
/*defaultOperands=*/defaultOperands,
/*caseValues=*/rewriter.getI64VectorAttr(caseValues),
/*caseDestinations=*/destinations,
/*caseOperands=*/destinationsOperands,
/*branchWeights=*/llvm::ArrayRef<std::int32_t>());
return mlir::success();
}
};
/// conversion of fir::SelectOp to an if-then-else ladder
struct SelectOpConversion : public fir::FIROpConversion<fir::SelectOp> {
using FIROpConversion::FIROpConversion;
struct SelectOpConversion : public SelectOpConversionBase<fir::SelectOp> {
using SelectOpConversionBase::SelectOpConversionBase;

llvm::LogicalResult
matchAndRewrite(fir::SelectOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
return selectMatchAndRewrite<fir::SelectOp>(lowerTy(), op, adaptor,
rewriter, getTypeConverter());
return this->selectMatchAndRewrite(op, adaptor, rewriter);
}
};

/// conversion of fir::SelectRankOp to an if-then-else ladder
struct SelectRankOpConversion : public fir::FIROpConversion<fir::SelectRankOp> {
using FIROpConversion::FIROpConversion;
struct SelectRankOpConversion
: public SelectOpConversionBase<fir::SelectRankOp> {
using SelectOpConversionBase::SelectOpConversionBase;

llvm::LogicalResult
matchAndRewrite(fir::SelectRankOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
return selectMatchAndRewrite<fir::SelectRankOp>(
lowerTy(), op, adaptor, rewriter, getTypeConverter());
return this->selectMatchAndRewrite(op, adaptor, rewriter);
}
};

Expand Down
57 changes: 54 additions & 3 deletions flang/test/Fir/convert-to-llvm.fir
Original file line number Diff line number Diff line change
Expand Up @@ -338,8 +338,7 @@ func.func @select(%arg : index, %arg2 : i32) -> i32 {
// CHECK: %[[C0:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[C1:.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: %[[C2:.*]] = llvm.mlir.constant(3 : i32) : i32
// CHECK: %[[SELECTOR:.*]] = llvm.trunc %[[SELECTVALUE]] : i{{.*}} to i32
// CHECK: llvm.switch %[[SELECTOR]] : i32, ^bb5 [
// CHECK: llvm.switch %[[SELECTVALUE]] : i64, ^bb5 [
// CHECK: 1: ^bb1(%[[C0]] : i32),
// CHECK: 2: ^bb2(%[[C2]], %[[SELECTVALUE]], %[[ARG1]] : i32, [[IDX]], i32),
// CHECK: 3: ^bb3(%[[ARG1]], %[[C2]] : i32, i32),
Expand Down Expand Up @@ -384,7 +383,8 @@ func.func @select_rank(%arg : i32, %arg2 : i32) -> i32 {
// CHECK: %[[C0:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[C1:.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: %[[C2:.*]] = llvm.mlir.constant(3 : i32) : i32
// CHECK: llvm.switch %[[SELECTVALUE]] : i32, ^bb5 [
// CHECK: %[[SELECTOR:.*]] = llvm.sext %[[SELECTVALUE]] : i{{.*}} to i64
// CHECK: llvm.switch %[[SELECTOR]] : i64, ^bb5 [
// CHECK: 1: ^bb1(%[[C0]] : i32),
// CHECK: 2: ^bb2(%[[C2]], %[[SELECTVALUE]], %[[ARG1]] : i32, i32, i32),
// CHECK: 3: ^bb3(%[[ARG1]], %[[C2]] : i32, i32),
Expand Down Expand Up @@ -2853,23 +2853,74 @@ func.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !fir.ref<i64>) {
return
}

// -----

// CHECK-LABEL: @test_call_arg_attrs_indirect
func.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: (i16)-> i16) -> i16 {
// CHECK: llvm.call %arg1(%{{.*}}) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
%0 = fir.call %arg1(%arg0) : (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
return %0 : i16
}

// -----

// CHECK-LABEL: @test_byval
func.func @test_byval(%arg0: (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64) -> (), %arg1: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, %arg2: f64) {
// llvm.call %{{.*}}(%{{.*}}, %{{.*}}) : !llvm.ptr, (!llvm.ptr {llvm.byval = !llvm.struct<"t", (array<5 x f64>)>}, f64) -> ()
fir.call %arg0(%arg1, %arg2) : (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>> {llvm.byval = !fir.type<t{a:!fir.array<5xf64>}>}, f64) -> ()
return
}

// -----

// CHECK-LABEL: @test_sret
func.func @test_sret(%arg0: (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64) -> (), %arg1: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, %arg2: f64) {
// llvm.call %{{.*}}(%{{.*}}, %{{.*}}) : !llvm.ptr, (!llvm.ptr {llvm.sret = !llvm.struct<"t", (array<5 x f64>)>}, f64) -> ()
fir.call %arg0(%arg1, %arg2) : (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>> {llvm.sret = !fir.type<t{a:!fir.array<5xf64>}>}, f64) -> ()
return
}

// -----

func.func @select_with_cast(%arg1 : i8, %arg2 : i16, %arg3: i64, %arg4: index) -> () {
fir.select %arg1 : i8 [ 1, ^bb1, unit, ^bb1 ]
^bb1:
fir.select %arg2 : i16 [ 1, ^bb2, unit, ^bb2 ]
^bb2:
fir.select %arg3 : i64 [ 1, ^bb3, unit, ^bb3 ]
^bb3:
fir.select %arg4 : index [ 1, ^bb4, unit, ^bb4 ]
^bb4:
fir.select %arg3 : i64 [ 4294967296, ^bb5, unit, ^bb5 ]
^bb5:
return
}
// CHECK-LABEL: llvm.func @select_with_cast(
// CHECK-SAME: %[[ARG0:.*]]: i8,
// CHECK-SAME: %[[ARG1:.*]]: i16,
// CHECK-SAME: %[[ARG2:.*]]: i64,
// CHECK-SAME: %[[ARG3:.*]]: i64) {
// CHECK: %[[VAL_0:.*]] = llvm.sext %[[ARG0]] : i8 to i64
// CHECK: llvm.switch %[[VAL_0]] : i64, ^bb1 [
// CHECK: 1: ^bb1
// CHECK: ]
// CHECK: ^bb1:
// CHECK: %[[VAL_1:.*]] = llvm.sext %[[ARG1]] : i16 to i64
// CHECK: llvm.switch %[[VAL_1]] : i64, ^bb2 [
// CHECK: 1: ^bb2
// CHECK: ]
// CHECK: ^bb2:
// CHECK: llvm.switch %[[ARG2]] : i64, ^bb3 [
// CHECK: 1: ^bb3
// CHECK: ]
// CHECK: ^bb3:
// CHECK: llvm.switch %[[ARG3]] : i64, ^bb4 [
// CHECK: 1: ^bb4
// CHECK: ]
// CHECK: ^bb4:
// CHECK: llvm.switch %[[ARG2]] : i64, ^bb5 [
// CHECK: 4294967296: ^bb5
// CHECK: ]
// CHECK: ^bb5:
// CHECK: llvm.return
// CHECK: }
10 changes: 5 additions & 5 deletions flang/test/Fir/select.fir
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
func.func @f(%a : i32) -> i32 {
%1 = arith.constant 1 : i32
%2 = arith.constant 42 : i32
// CHECK: switch i32 %{{.*}}, label %{{.*}} [
// CHECK: i32 1, label %{{.*}}
// CHECK: switch i64 %{{.*}}, label %{{.*}} [
// CHECK: i64 1, label %{{.*}}
// CHECK: ]
fir.select %a : i32 [1, ^bb2(%1:i32), unit, ^bb3(%2:i32)]
^bb2(%3 : i32) :
Expand All @@ -24,9 +24,9 @@ func.func @g(%a : i32) -> i32 {
%1 = arith.constant 1 : i32
%2 = arith.constant 42 : i32

// CHECK: switch i32 %{{.*}}, label %{{.*}} [
// CHECK: i32 1, label %{{.*}}
// CHECK: i32 -1, label %{{.*}}
// CHECK: switch i64 %{{.*}}, label %{{.*}} [
// CHECK: i64 1, label %{{.*}}
// CHECK: i64 -1, label %{{.*}}
// CHECK: ]
fir.select_rank %a : i32 [1, ^bb2(%1:i32), -1, ^bb4, unit, ^bb3(%2:i32)]
^bb2(%3 : i32) :
Expand Down