-
Notifications
You must be signed in to change notification settings - Fork 14.8k
[flang] Cast fir.select[_rank] selector to i64. #153239
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Properly cast the selector to i32 regardless of its integer type. We used to generate llvm.trunc always. Fixes llvm#153050.
@llvm/pr-subscribers-flang-fir-hlfir Author: Slava Zakharin (vzakhari) ChangesProperly cast the selector to i32 regardless of its integer type. Fixes #153050. Full diff: https://github.com/llvm/llvm-project/pull/153239.diff 2 Files Affected:
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 1b289ae690cbe..90bbff4a8363e 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -3525,114 +3525,117 @@ struct SelectCaseOpConversion : public fir::FIROpConversion<fir::SelectCaseOp> {
}
};
-/// Helper function for converting select ops. This function converts the
-/// signature of the given block. If the new block signature is different from
-/// `expectedTypes`, returns "failure".
-static llvm::FailureOr<mlir::Block *>
-getConvertedBlock(mlir::ConversionPatternRewriter &rewriter,
- const mlir::TypeConverter *converter,
- mlir::Operation *branchOp, mlir::Block *block,
- mlir::TypeRange expectedTypes) {
- assert(converter && "expected non-null type converter");
- assert(!block->isEntryBlock() && "entry blocks have no predecessors");
-
- // There is nothing to do if the types already match.
- if (block->getArgumentTypes() == expectedTypes)
- return block;
-
- // Compute the new block argument types and convert the block.
- std::optional<mlir::TypeConverter::SignatureConversion> conversion =
- converter->convertBlockSignature(block);
- if (!conversion)
- return rewriter.notifyMatchFailure(branchOp,
- "could not compute block signature");
- if (expectedTypes != conversion->getConvertedTypes())
- return rewriter.notifyMatchFailure(
- branchOp,
- "mismatch between adaptor operand types and computed block signature");
- return rewriter.applySignatureConversion(block, *conversion, converter);
-}
-
+/// Base class for SelectOpConversion and SelectRankOpConversion.
template <typename OP>
-static llvm::LogicalResult
-selectMatchAndRewrite(const fir::LLVMTypeConverter &lowering, OP select,
- typename OP::Adaptor adaptor,
- mlir::ConversionPatternRewriter &rewriter,
- const mlir::TypeConverter *converter) {
- unsigned conds = select.getNumConditions();
- auto cases = select.getCases().getValue();
- mlir::Value selector = adaptor.getSelector();
- auto loc = select.getLoc();
- assert(conds > 0 && "select must have cases");
-
- llvm::SmallVector<mlir::Block *> destinations;
- llvm::SmallVector<mlir::ValueRange> destinationsOperands;
- mlir::Block *defaultDestination;
- mlir::ValueRange defaultOperands;
- llvm::SmallVector<int32_t> caseValues;
-
- for (unsigned t = 0; t != conds; ++t) {
- mlir::Block *dest = select.getSuccessor(t);
- auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t);
- const mlir::Attribute &attr = cases[t];
- if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr)) {
- destinationsOperands.push_back(destOps ? *destOps : mlir::ValueRange{});
- auto convertedBlock =
- getConvertedBlock(rewriter, converter, select, dest,
- mlir::TypeRange(destinationsOperands.back()));
+struct SelectOpConversionBase : public fir::FIROpConversion<OP> {
+ using fir::FIROpConversion<OP>::FIROpConversion;
+
+private:
+ /// Helper function for converting select ops. This function converts the
+ /// signature of the given block. If the new block signature is different from
+ /// `expectedTypes`, returns "failure".
+ llvm::FailureOr<mlir::Block *>
+ getConvertedBlock(mlir::ConversionPatternRewriter &rewriter,
+ mlir::Operation *branchOp, mlir::Block *block,
+ mlir::TypeRange expectedTypes) const {
+ const mlir::TypeConverter *converter = this->getTypeConverter();
+ assert(converter && "expected non-null type converter");
+ assert(!block->isEntryBlock() && "entry blocks have no predecessors");
+
+ // There is nothing to do if the types already match.
+ if (block->getArgumentTypes() == expectedTypes)
+ return block;
+
+ // Compute the new block argument types and convert the block.
+ std::optional<mlir::TypeConverter::SignatureConversion> conversion =
+ converter->convertBlockSignature(block);
+ if (!conversion)
+ return rewriter.notifyMatchFailure(branchOp,
+ "could not compute block signature");
+ if (expectedTypes != conversion->getConvertedTypes())
+ return rewriter.notifyMatchFailure(branchOp,
+ "mismatch between adaptor operand "
+ "types and computed block signature");
+ return rewriter.applySignatureConversion(block, *conversion, converter);
+ }
+
+protected:
+ llvm::LogicalResult
+ selectMatchAndRewrite(OP select, typename OP::Adaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const {
+ unsigned conds = select.getNumConditions();
+ auto cases = select.getCases().getValue();
+ mlir::Value selector = adaptor.getSelector();
+ auto loc = select.getLoc();
+ assert(conds > 0 && "select must have cases");
+
+ llvm::SmallVector<mlir::Block *> destinations;
+ llvm::SmallVector<mlir::ValueRange> destinationsOperands;
+ mlir::Block *defaultDestination;
+ mlir::ValueRange defaultOperands;
+ llvm::SmallVector<int32_t> caseValues;
+
+ for (unsigned t = 0; t != conds; ++t) {
+ mlir::Block *dest = select.getSuccessor(t);
+ auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t);
+ const mlir::Attribute &attr = cases[t];
+ if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr)) {
+ destinationsOperands.push_back(destOps ? *destOps : mlir::ValueRange{});
+ auto convertedBlock =
+ getConvertedBlock(rewriter, select, dest,
+ mlir::TypeRange(destinationsOperands.back()));
+ if (mlir::failed(convertedBlock))
+ return mlir::failure();
+ destinations.push_back(*convertedBlock);
+ caseValues.push_back(intAttr.getInt());
+ continue;
+ }
+ assert(mlir::dyn_cast_or_null<mlir::UnitAttr>(attr));
+ assert((t + 1 == conds) && "unit must be last");
+ defaultOperands = destOps ? *destOps : mlir::ValueRange{};
+ auto convertedBlock = getConvertedBlock(rewriter, select, dest,
+ mlir::TypeRange(defaultOperands));
if (mlir::failed(convertedBlock))
return mlir::failure();
- destinations.push_back(*convertedBlock);
- caseValues.push_back(intAttr.getInt());
- continue;
+ defaultDestination = *convertedBlock;
}
- assert(mlir::dyn_cast_or_null<mlir::UnitAttr>(attr));
- assert((t + 1 == conds) && "unit must be last");
- defaultOperands = destOps ? *destOps : mlir::ValueRange{};
- auto convertedBlock = getConvertedBlock(rewriter, converter, select, dest,
- mlir::TypeRange(defaultOperands));
- if (mlir::failed(convertedBlock))
- return mlir::failure();
- defaultDestination = *convertedBlock;
- }
-
- // LLVM::SwitchOp takes a i32 type for the selector.
- if (select.getSelector().getType() != rewriter.getI32Type())
- selector = mlir::LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(),
- selector);
-
- rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>(
- select, selector,
- /*defaultDestination=*/defaultDestination,
- /*defaultOperands=*/defaultOperands,
- /*caseValues=*/caseValues,
- /*caseDestinations=*/destinations,
- /*caseOperands=*/destinationsOperands,
- /*branchWeights=*/llvm::ArrayRef<std::int32_t>());
- return mlir::success();
-}
+ // LLVM::SwitchOp takes a i32 type for the selector.
+ if (select.getSelector().getType() != rewriter.getI32Type())
+ selector =
+ this->integerCast(loc, rewriter, rewriter.getI32Type(), selector);
+
+ rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>(
+ select, selector,
+ /*defaultDestination=*/defaultDestination,
+ /*defaultOperands=*/defaultOperands,
+ /*caseValues=*/caseValues,
+ /*caseDestinations=*/destinations,
+ /*caseOperands=*/destinationsOperands,
+ /*branchWeights=*/llvm::ArrayRef<std::int32_t>());
+ return mlir::success();
+ }
+};
/// conversion of fir::SelectOp to an if-then-else ladder
-struct SelectOpConversion : public fir::FIROpConversion<fir::SelectOp> {
- using FIROpConversion::FIROpConversion;
+struct SelectOpConversion : public SelectOpConversionBase<fir::SelectOp> {
+ using SelectOpConversionBase::SelectOpConversionBase;
llvm::LogicalResult
matchAndRewrite(fir::SelectOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
- return selectMatchAndRewrite<fir::SelectOp>(lowerTy(), op, adaptor,
- rewriter, getTypeConverter());
+ return this->selectMatchAndRewrite(op, adaptor, rewriter);
}
};
/// conversion of fir::SelectRankOp to an if-then-else ladder
-struct SelectRankOpConversion : public fir::FIROpConversion<fir::SelectRankOp> {
- using FIROpConversion::FIROpConversion;
+struct SelectRankOpConversion
+ : public SelectOpConversionBase<fir::SelectRankOp> {
+ using SelectOpConversionBase::SelectOpConversionBase;
llvm::LogicalResult
matchAndRewrite(fir::SelectRankOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
- return selectMatchAndRewrite<fir::SelectRankOp>(
- lowerTy(), op, adaptor, rewriter, getTypeConverter());
+ return this->selectMatchAndRewrite(op, adaptor, rewriter);
}
};
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index 50a98466f0d4b..338d3ab8edbaf 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -2853,6 +2853,8 @@ func.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !fir.ref<i64>) {
return
}
+// -----
+
// CHECK-LABEL: @test_call_arg_attrs_indirect
func.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: (i16)-> i16) -> i16 {
// CHECK: llvm.call %arg1(%{{.*}}) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
@@ -2860,6 +2862,8 @@ func.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: (i16)-> i16) -> i16 {
return %0 : i16
}
+// -----
+
// CHECK-LABEL: @test_byval
func.func @test_byval(%arg0: (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64) -> (), %arg1: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, %arg2: f64) {
// llvm.call %{{.*}}(%{{.*}}, %{{.*}}) : !llvm.ptr, (!llvm.ptr {llvm.byval = !llvm.struct<"t", (array<5 x f64>)>}, f64) -> ()
@@ -2867,9 +2871,52 @@ func.func @test_byval(%arg0: (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64)
return
}
+// -----
+
// CHECK-LABEL: @test_sret
func.func @test_sret(%arg0: (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64) -> (), %arg1: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, %arg2: f64) {
// llvm.call %{{.*}}(%{{.*}}, %{{.*}}) : !llvm.ptr, (!llvm.ptr {llvm.sret = !llvm.struct<"t", (array<5 x f64>)>}, f64) -> ()
fir.call %arg0(%arg1, %arg2) : (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>> {llvm.sret = !fir.type<t{a:!fir.array<5xf64>}>}, f64) -> ()
return
}
+
+// -----
+
+func.func @select_with_cast(%arg1 : i8, %arg2 : i16, %arg3: i64, %arg4: index) -> () {
+ fir.select %arg1 : i8 [ 1, ^bb1, unit, ^bb1 ]
+ ^bb1:
+ fir.select %arg2 : i16 [ 1, ^bb2, unit, ^bb2 ]
+ ^bb2:
+ fir.select %arg3 : i64 [ 1, ^bb3, unit, ^bb3 ]
+ ^bb3:
+ fir.select %arg4 : index [ 1, ^bb4, unit, ^bb4 ]
+ ^bb4:
+ return
+}
+// CHECK-LABEL: llvm.func @select_with_cast(
+// CHECK-SAME: %[[ARG0:.*]]: i8,
+// CHECK-SAME: %[[ARG1:.*]]: i16,
+// CHECK-SAME: %[[ARG2:.*]]: i64,
+// CHECK-SAME: %[[ARG3:.*]]: i64) {
+// CHECK: %[[VAL_0:.*]] = llvm.sext %[[ARG0]] : i8 to i32
+// CHECK: llvm.switch %[[VAL_0]] : i32, ^bb1 [
+// CHECK: 1: ^bb1
+// CHECK: ]
+// CHECK: ^bb1:
+// CHECK: %[[VAL_1:.*]] = llvm.sext %[[ARG1]] : i16 to i32
+// CHECK: llvm.switch %[[VAL_1]] : i32, ^bb2 [
+// CHECK: 1: ^bb2
+// CHECK: ]
+// CHECK: ^bb2:
+// CHECK: %[[VAL_2:.*]] = llvm.trunc %[[ARG2]] : i64 to i32
+// CHECK: llvm.switch %[[VAL_2]] : i32, ^bb3 [
+// CHECK: 1: ^bb3
+// CHECK: ]
+// CHECK: ^bb3:
+// CHECK: %[[VAL_3:.*]] = llvm.trunc %[[ARG3]] : i64 to i32
+// CHECK: llvm.switch %[[VAL_3]] : i32, ^bb4 [
+// CHECK: 1: ^bb4
+// CHECK: ]
+// CHECK: ^bb4:
+// CHECK: llvm.return
+// CHECK: }
|
@llvm/pr-subscribers-flang-codegen Author: Slava Zakharin (vzakhari) ChangesProperly cast the selector to i32 regardless of its integer type. Fixes #153050. Full diff: https://github.com/llvm/llvm-project/pull/153239.diff 2 Files Affected:
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 1b289ae690cbe..90bbff4a8363e 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -3525,114 +3525,117 @@ struct SelectCaseOpConversion : public fir::FIROpConversion<fir::SelectCaseOp> {
}
};
-/// Helper function for converting select ops. This function converts the
-/// signature of the given block. If the new block signature is different from
-/// `expectedTypes`, returns "failure".
-static llvm::FailureOr<mlir::Block *>
-getConvertedBlock(mlir::ConversionPatternRewriter &rewriter,
- const mlir::TypeConverter *converter,
- mlir::Operation *branchOp, mlir::Block *block,
- mlir::TypeRange expectedTypes) {
- assert(converter && "expected non-null type converter");
- assert(!block->isEntryBlock() && "entry blocks have no predecessors");
-
- // There is nothing to do if the types already match.
- if (block->getArgumentTypes() == expectedTypes)
- return block;
-
- // Compute the new block argument types and convert the block.
- std::optional<mlir::TypeConverter::SignatureConversion> conversion =
- converter->convertBlockSignature(block);
- if (!conversion)
- return rewriter.notifyMatchFailure(branchOp,
- "could not compute block signature");
- if (expectedTypes != conversion->getConvertedTypes())
- return rewriter.notifyMatchFailure(
- branchOp,
- "mismatch between adaptor operand types and computed block signature");
- return rewriter.applySignatureConversion(block, *conversion, converter);
-}
-
+/// Base class for SelectOpConversion and SelectRankOpConversion.
template <typename OP>
-static llvm::LogicalResult
-selectMatchAndRewrite(const fir::LLVMTypeConverter &lowering, OP select,
- typename OP::Adaptor adaptor,
- mlir::ConversionPatternRewriter &rewriter,
- const mlir::TypeConverter *converter) {
- unsigned conds = select.getNumConditions();
- auto cases = select.getCases().getValue();
- mlir::Value selector = adaptor.getSelector();
- auto loc = select.getLoc();
- assert(conds > 0 && "select must have cases");
-
- llvm::SmallVector<mlir::Block *> destinations;
- llvm::SmallVector<mlir::ValueRange> destinationsOperands;
- mlir::Block *defaultDestination;
- mlir::ValueRange defaultOperands;
- llvm::SmallVector<int32_t> caseValues;
-
- for (unsigned t = 0; t != conds; ++t) {
- mlir::Block *dest = select.getSuccessor(t);
- auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t);
- const mlir::Attribute &attr = cases[t];
- if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr)) {
- destinationsOperands.push_back(destOps ? *destOps : mlir::ValueRange{});
- auto convertedBlock =
- getConvertedBlock(rewriter, converter, select, dest,
- mlir::TypeRange(destinationsOperands.back()));
+struct SelectOpConversionBase : public fir::FIROpConversion<OP> {
+ using fir::FIROpConversion<OP>::FIROpConversion;
+
+private:
+ /// Helper function for converting select ops. This function converts the
+ /// signature of the given block. If the new block signature is different from
+ /// `expectedTypes`, returns "failure".
+ llvm::FailureOr<mlir::Block *>
+ getConvertedBlock(mlir::ConversionPatternRewriter &rewriter,
+ mlir::Operation *branchOp, mlir::Block *block,
+ mlir::TypeRange expectedTypes) const {
+ const mlir::TypeConverter *converter = this->getTypeConverter();
+ assert(converter && "expected non-null type converter");
+ assert(!block->isEntryBlock() && "entry blocks have no predecessors");
+
+ // There is nothing to do if the types already match.
+ if (block->getArgumentTypes() == expectedTypes)
+ return block;
+
+ // Compute the new block argument types and convert the block.
+ std::optional<mlir::TypeConverter::SignatureConversion> conversion =
+ converter->convertBlockSignature(block);
+ if (!conversion)
+ return rewriter.notifyMatchFailure(branchOp,
+ "could not compute block signature");
+ if (expectedTypes != conversion->getConvertedTypes())
+ return rewriter.notifyMatchFailure(branchOp,
+ "mismatch between adaptor operand "
+ "types and computed block signature");
+ return rewriter.applySignatureConversion(block, *conversion, converter);
+ }
+
+protected:
+ llvm::LogicalResult
+ selectMatchAndRewrite(OP select, typename OP::Adaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const {
+ unsigned conds = select.getNumConditions();
+ auto cases = select.getCases().getValue();
+ mlir::Value selector = adaptor.getSelector();
+ auto loc = select.getLoc();
+ assert(conds > 0 && "select must have cases");
+
+ llvm::SmallVector<mlir::Block *> destinations;
+ llvm::SmallVector<mlir::ValueRange> destinationsOperands;
+ mlir::Block *defaultDestination;
+ mlir::ValueRange defaultOperands;
+ llvm::SmallVector<int32_t> caseValues;
+
+ for (unsigned t = 0; t != conds; ++t) {
+ mlir::Block *dest = select.getSuccessor(t);
+ auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t);
+ const mlir::Attribute &attr = cases[t];
+ if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr)) {
+ destinationsOperands.push_back(destOps ? *destOps : mlir::ValueRange{});
+ auto convertedBlock =
+ getConvertedBlock(rewriter, select, dest,
+ mlir::TypeRange(destinationsOperands.back()));
+ if (mlir::failed(convertedBlock))
+ return mlir::failure();
+ destinations.push_back(*convertedBlock);
+ caseValues.push_back(intAttr.getInt());
+ continue;
+ }
+ assert(mlir::dyn_cast_or_null<mlir::UnitAttr>(attr));
+ assert((t + 1 == conds) && "unit must be last");
+ defaultOperands = destOps ? *destOps : mlir::ValueRange{};
+ auto convertedBlock = getConvertedBlock(rewriter, select, dest,
+ mlir::TypeRange(defaultOperands));
if (mlir::failed(convertedBlock))
return mlir::failure();
- destinations.push_back(*convertedBlock);
- caseValues.push_back(intAttr.getInt());
- continue;
+ defaultDestination = *convertedBlock;
}
- assert(mlir::dyn_cast_or_null<mlir::UnitAttr>(attr));
- assert((t + 1 == conds) && "unit must be last");
- defaultOperands = destOps ? *destOps : mlir::ValueRange{};
- auto convertedBlock = getConvertedBlock(rewriter, converter, select, dest,
- mlir::TypeRange(defaultOperands));
- if (mlir::failed(convertedBlock))
- return mlir::failure();
- defaultDestination = *convertedBlock;
- }
-
- // LLVM::SwitchOp takes a i32 type for the selector.
- if (select.getSelector().getType() != rewriter.getI32Type())
- selector = mlir::LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(),
- selector);
-
- rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>(
- select, selector,
- /*defaultDestination=*/defaultDestination,
- /*defaultOperands=*/defaultOperands,
- /*caseValues=*/caseValues,
- /*caseDestinations=*/destinations,
- /*caseOperands=*/destinationsOperands,
- /*branchWeights=*/llvm::ArrayRef<std::int32_t>());
- return mlir::success();
-}
+ // LLVM::SwitchOp takes a i32 type for the selector.
+ if (select.getSelector().getType() != rewriter.getI32Type())
+ selector =
+ this->integerCast(loc, rewriter, rewriter.getI32Type(), selector);
+
+ rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>(
+ select, selector,
+ /*defaultDestination=*/defaultDestination,
+ /*defaultOperands=*/defaultOperands,
+ /*caseValues=*/caseValues,
+ /*caseDestinations=*/destinations,
+ /*caseOperands=*/destinationsOperands,
+ /*branchWeights=*/llvm::ArrayRef<std::int32_t>());
+ return mlir::success();
+ }
+};
/// conversion of fir::SelectOp to an if-then-else ladder
-struct SelectOpConversion : public fir::FIROpConversion<fir::SelectOp> {
- using FIROpConversion::FIROpConversion;
+struct SelectOpConversion : public SelectOpConversionBase<fir::SelectOp> {
+ using SelectOpConversionBase::SelectOpConversionBase;
llvm::LogicalResult
matchAndRewrite(fir::SelectOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
- return selectMatchAndRewrite<fir::SelectOp>(lowerTy(), op, adaptor,
- rewriter, getTypeConverter());
+ return this->selectMatchAndRewrite(op, adaptor, rewriter);
}
};
/// conversion of fir::SelectRankOp to an if-then-else ladder
-struct SelectRankOpConversion : public fir::FIROpConversion<fir::SelectRankOp> {
- using FIROpConversion::FIROpConversion;
+struct SelectRankOpConversion
+ : public SelectOpConversionBase<fir::SelectRankOp> {
+ using SelectOpConversionBase::SelectOpConversionBase;
llvm::LogicalResult
matchAndRewrite(fir::SelectRankOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
- return selectMatchAndRewrite<fir::SelectRankOp>(
- lowerTy(), op, adaptor, rewriter, getTypeConverter());
+ return this->selectMatchAndRewrite(op, adaptor, rewriter);
}
};
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index 50a98466f0d4b..338d3ab8edbaf 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -2853,6 +2853,8 @@ func.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !fir.ref<i64>) {
return
}
+// -----
+
// CHECK-LABEL: @test_call_arg_attrs_indirect
func.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: (i16)-> i16) -> i16 {
// CHECK: llvm.call %arg1(%{{.*}}) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
@@ -2860,6 +2862,8 @@ func.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: (i16)-> i16) -> i16 {
return %0 : i16
}
+// -----
+
// CHECK-LABEL: @test_byval
func.func @test_byval(%arg0: (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64) -> (), %arg1: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, %arg2: f64) {
// llvm.call %{{.*}}(%{{.*}}, %{{.*}}) : !llvm.ptr, (!llvm.ptr {llvm.byval = !llvm.struct<"t", (array<5 x f64>)>}, f64) -> ()
@@ -2867,9 +2871,52 @@ func.func @test_byval(%arg0: (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64)
return
}
+// -----
+
// CHECK-LABEL: @test_sret
func.func @test_sret(%arg0: (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64) -> (), %arg1: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, %arg2: f64) {
// llvm.call %{{.*}}(%{{.*}}, %{{.*}}) : !llvm.ptr, (!llvm.ptr {llvm.sret = !llvm.struct<"t", (array<5 x f64>)>}, f64) -> ()
fir.call %arg0(%arg1, %arg2) : (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>> {llvm.sret = !fir.type<t{a:!fir.array<5xf64>}>}, f64) -> ()
return
}
+
+// -----
+
+func.func @select_with_cast(%arg1 : i8, %arg2 : i16, %arg3: i64, %arg4: index) -> () {
+ fir.select %arg1 : i8 [ 1, ^bb1, unit, ^bb1 ]
+ ^bb1:
+ fir.select %arg2 : i16 [ 1, ^bb2, unit, ^bb2 ]
+ ^bb2:
+ fir.select %arg3 : i64 [ 1, ^bb3, unit, ^bb3 ]
+ ^bb3:
+ fir.select %arg4 : index [ 1, ^bb4, unit, ^bb4 ]
+ ^bb4:
+ return
+}
+// CHECK-LABEL: llvm.func @select_with_cast(
+// CHECK-SAME: %[[ARG0:.*]]: i8,
+// CHECK-SAME: %[[ARG1:.*]]: i16,
+// CHECK-SAME: %[[ARG2:.*]]: i64,
+// CHECK-SAME: %[[ARG3:.*]]: i64) {
+// CHECK: %[[VAL_0:.*]] = llvm.sext %[[ARG0]] : i8 to i32
+// CHECK: llvm.switch %[[VAL_0]] : i32, ^bb1 [
+// CHECK: 1: ^bb1
+// CHECK: ]
+// CHECK: ^bb1:
+// CHECK: %[[VAL_1:.*]] = llvm.sext %[[ARG1]] : i16 to i32
+// CHECK: llvm.switch %[[VAL_1]] : i32, ^bb2 [
+// CHECK: 1: ^bb2
+// CHECK: ]
+// CHECK: ^bb2:
+// CHECK: %[[VAL_2:.*]] = llvm.trunc %[[ARG2]] : i64 to i32
+// CHECK: llvm.switch %[[VAL_2]] : i32, ^bb3 [
+// CHECK: 1: ^bb3
+// CHECK: ]
+// CHECK: ^bb3:
+// CHECK: %[[VAL_3:.*]] = llvm.trunc %[[ARG3]] : i64 to i32
+// CHECK: llvm.switch %[[VAL_3]] : i32, ^bb4 [
+// CHECK: 1: ^bb4
+// CHECK: ]
+// CHECK: ^bb4:
+// CHECK: llvm.return
+// CHECK: }
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Why convert at all? Wouldn't |
Good point! I should have looked at |
The case values' type is computed using the type of the
And the type must match the types of the |
Alternatively, I guess, we can create the |
You know, if it becomes a headache, let's cast. But maybe just cast all to Even casting all to (Although, someone will now create a reproducer that explicitly uses |
Yep, it looks like creating differently typed |
Properly cast the selector to
i64
regardless of its integer type.We used to generate llvm.trunc always.
We have to use
i64
as long as the case values may exceed INT_MAX.Fixes #153050.