Skip to content

[flang] Cast fir.select[_rank] selector to i64. #153239

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 12, 2025
Merged

Conversation

vzakhari
Copy link
Contributor

@vzakhari vzakhari commented Aug 12, 2025

Properly cast the selector to i64 regardless of its integer type.
We used to generate llvm.trunc always.

We have to use i64 as long as the case values may exceed INT_MAX.

Fixes #153050.

Properly cast the selector to i32 regardless of its integer type.
We used to generate llvm.trunc always.

Fixes llvm#153050.
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir flang:codegen labels Aug 12, 2025
@llvmbot
Copy link
Member

llvmbot commented Aug 12, 2025

@llvm/pr-subscribers-flang-fir-hlfir

Author: Slava Zakharin (vzakhari)

Changes

Properly cast the selector to i32 regardless of its integer type.
We used to generate llvm.trunc always.

Fixes #153050.


Full diff: https://github.com/llvm/llvm-project/pull/153239.diff

2 Files Affected:

  • (modified) flang/lib/Optimizer/CodeGen/CodeGen.cpp (+93-90)
  • (modified) flang/test/Fir/convert-to-llvm.fir (+47)
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 1b289ae690cbe..90bbff4a8363e 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -3525,114 +3525,117 @@ struct SelectCaseOpConversion : public fir::FIROpConversion<fir::SelectCaseOp> {
   }
 };
 
-/// Helper function for converting select ops. This function converts the
-/// signature of the given block. If the new block signature is different from
-/// `expectedTypes`, returns "failure".
-static llvm::FailureOr<mlir::Block *>
-getConvertedBlock(mlir::ConversionPatternRewriter &rewriter,
-                  const mlir::TypeConverter *converter,
-                  mlir::Operation *branchOp, mlir::Block *block,
-                  mlir::TypeRange expectedTypes) {
-  assert(converter && "expected non-null type converter");
-  assert(!block->isEntryBlock() && "entry blocks have no predecessors");
-
-  // There is nothing to do if the types already match.
-  if (block->getArgumentTypes() == expectedTypes)
-    return block;
-
-  // Compute the new block argument types and convert the block.
-  std::optional<mlir::TypeConverter::SignatureConversion> conversion =
-      converter->convertBlockSignature(block);
-  if (!conversion)
-    return rewriter.notifyMatchFailure(branchOp,
-                                       "could not compute block signature");
-  if (expectedTypes != conversion->getConvertedTypes())
-    return rewriter.notifyMatchFailure(
-        branchOp,
-        "mismatch between adaptor operand types and computed block signature");
-  return rewriter.applySignatureConversion(block, *conversion, converter);
-}
-
+/// Base class for SelectOpConversion and SelectRankOpConversion.
 template <typename OP>
-static llvm::LogicalResult
-selectMatchAndRewrite(const fir::LLVMTypeConverter &lowering, OP select,
-                      typename OP::Adaptor adaptor,
-                      mlir::ConversionPatternRewriter &rewriter,
-                      const mlir::TypeConverter *converter) {
-  unsigned conds = select.getNumConditions();
-  auto cases = select.getCases().getValue();
-  mlir::Value selector = adaptor.getSelector();
-  auto loc = select.getLoc();
-  assert(conds > 0 && "select must have cases");
-
-  llvm::SmallVector<mlir::Block *> destinations;
-  llvm::SmallVector<mlir::ValueRange> destinationsOperands;
-  mlir::Block *defaultDestination;
-  mlir::ValueRange defaultOperands;
-  llvm::SmallVector<int32_t> caseValues;
-
-  for (unsigned t = 0; t != conds; ++t) {
-    mlir::Block *dest = select.getSuccessor(t);
-    auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t);
-    const mlir::Attribute &attr = cases[t];
-    if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr)) {
-      destinationsOperands.push_back(destOps ? *destOps : mlir::ValueRange{});
-      auto convertedBlock =
-          getConvertedBlock(rewriter, converter, select, dest,
-                            mlir::TypeRange(destinationsOperands.back()));
+struct SelectOpConversionBase : public fir::FIROpConversion<OP> {
+  using fir::FIROpConversion<OP>::FIROpConversion;
+
+private:
+  /// Helper function for converting select ops. This function converts the
+  /// signature of the given block. If the new block signature is different from
+  /// `expectedTypes`, returns "failure".
+  llvm::FailureOr<mlir::Block *>
+  getConvertedBlock(mlir::ConversionPatternRewriter &rewriter,
+                    mlir::Operation *branchOp, mlir::Block *block,
+                    mlir::TypeRange expectedTypes) const {
+    const mlir::TypeConverter *converter = this->getTypeConverter();
+    assert(converter && "expected non-null type converter");
+    assert(!block->isEntryBlock() && "entry blocks have no predecessors");
+
+    // There is nothing to do if the types already match.
+    if (block->getArgumentTypes() == expectedTypes)
+      return block;
+
+    // Compute the new block argument types and convert the block.
+    std::optional<mlir::TypeConverter::SignatureConversion> conversion =
+        converter->convertBlockSignature(block);
+    if (!conversion)
+      return rewriter.notifyMatchFailure(branchOp,
+                                         "could not compute block signature");
+    if (expectedTypes != conversion->getConvertedTypes())
+      return rewriter.notifyMatchFailure(branchOp,
+                                         "mismatch between adaptor operand "
+                                         "types and computed block signature");
+    return rewriter.applySignatureConversion(block, *conversion, converter);
+  }
+
+protected:
+  llvm::LogicalResult
+  selectMatchAndRewrite(OP select, typename OP::Adaptor adaptor,
+                        mlir::ConversionPatternRewriter &rewriter) const {
+    unsigned conds = select.getNumConditions();
+    auto cases = select.getCases().getValue();
+    mlir::Value selector = adaptor.getSelector();
+    auto loc = select.getLoc();
+    assert(conds > 0 && "select must have cases");
+
+    llvm::SmallVector<mlir::Block *> destinations;
+    llvm::SmallVector<mlir::ValueRange> destinationsOperands;
+    mlir::Block *defaultDestination;
+    mlir::ValueRange defaultOperands;
+    llvm::SmallVector<int32_t> caseValues;
+
+    for (unsigned t = 0; t != conds; ++t) {
+      mlir::Block *dest = select.getSuccessor(t);
+      auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t);
+      const mlir::Attribute &attr = cases[t];
+      if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr)) {
+        destinationsOperands.push_back(destOps ? *destOps : mlir::ValueRange{});
+        auto convertedBlock =
+            getConvertedBlock(rewriter, select, dest,
+                              mlir::TypeRange(destinationsOperands.back()));
+        if (mlir::failed(convertedBlock))
+          return mlir::failure();
+        destinations.push_back(*convertedBlock);
+        caseValues.push_back(intAttr.getInt());
+        continue;
+      }
+      assert(mlir::dyn_cast_or_null<mlir::UnitAttr>(attr));
+      assert((t + 1 == conds) && "unit must be last");
+      defaultOperands = destOps ? *destOps : mlir::ValueRange{};
+      auto convertedBlock = getConvertedBlock(rewriter, select, dest,
+                                              mlir::TypeRange(defaultOperands));
       if (mlir::failed(convertedBlock))
         return mlir::failure();
-      destinations.push_back(*convertedBlock);
-      caseValues.push_back(intAttr.getInt());
-      continue;
+      defaultDestination = *convertedBlock;
     }
-    assert(mlir::dyn_cast_or_null<mlir::UnitAttr>(attr));
-    assert((t + 1 == conds) && "unit must be last");
-    defaultOperands = destOps ? *destOps : mlir::ValueRange{};
-    auto convertedBlock = getConvertedBlock(rewriter, converter, select, dest,
-                                            mlir::TypeRange(defaultOperands));
-    if (mlir::failed(convertedBlock))
-      return mlir::failure();
-    defaultDestination = *convertedBlock;
-  }
-
-  // LLVM::SwitchOp takes a i32 type for the selector.
-  if (select.getSelector().getType() != rewriter.getI32Type())
-    selector = mlir::LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(),
-                                           selector);
-
-  rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>(
-      select, selector,
-      /*defaultDestination=*/defaultDestination,
-      /*defaultOperands=*/defaultOperands,
-      /*caseValues=*/caseValues,
-      /*caseDestinations=*/destinations,
-      /*caseOperands=*/destinationsOperands,
-      /*branchWeights=*/llvm::ArrayRef<std::int32_t>());
-  return mlir::success();
-}
 
+    // LLVM::SwitchOp takes a i32 type for the selector.
+    if (select.getSelector().getType() != rewriter.getI32Type())
+      selector =
+          this->integerCast(loc, rewriter, rewriter.getI32Type(), selector);
+
+    rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>(
+        select, selector,
+        /*defaultDestination=*/defaultDestination,
+        /*defaultOperands=*/defaultOperands,
+        /*caseValues=*/caseValues,
+        /*caseDestinations=*/destinations,
+        /*caseOperands=*/destinationsOperands,
+        /*branchWeights=*/llvm::ArrayRef<std::int32_t>());
+    return mlir::success();
+  }
+};
 /// conversion of fir::SelectOp to an if-then-else ladder
-struct SelectOpConversion : public fir::FIROpConversion<fir::SelectOp> {
-  using FIROpConversion::FIROpConversion;
+struct SelectOpConversion : public SelectOpConversionBase<fir::SelectOp> {
+  using SelectOpConversionBase::SelectOpConversionBase;
 
   llvm::LogicalResult
   matchAndRewrite(fir::SelectOp op, OpAdaptor adaptor,
                   mlir::ConversionPatternRewriter &rewriter) const override {
-    return selectMatchAndRewrite<fir::SelectOp>(lowerTy(), op, adaptor,
-                                                rewriter, getTypeConverter());
+    return this->selectMatchAndRewrite(op, adaptor, rewriter);
   }
 };
 
 /// conversion of fir::SelectRankOp to an if-then-else ladder
-struct SelectRankOpConversion : public fir::FIROpConversion<fir::SelectRankOp> {
-  using FIROpConversion::FIROpConversion;
+struct SelectRankOpConversion
+    : public SelectOpConversionBase<fir::SelectRankOp> {
+  using SelectOpConversionBase::SelectOpConversionBase;
 
   llvm::LogicalResult
   matchAndRewrite(fir::SelectRankOp op, OpAdaptor adaptor,
                   mlir::ConversionPatternRewriter &rewriter) const override {
-    return selectMatchAndRewrite<fir::SelectRankOp>(
-        lowerTy(), op, adaptor, rewriter, getTypeConverter());
+    return this->selectMatchAndRewrite(op, adaptor, rewriter);
   }
 };
 
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index 50a98466f0d4b..338d3ab8edbaf 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -2853,6 +2853,8 @@ func.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !fir.ref<i64>) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: @test_call_arg_attrs_indirect
 func.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: (i16)-> i16) -> i16 {
   // CHECK: llvm.call %arg1(%{{.*}}) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
@@ -2860,6 +2862,8 @@ func.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: (i16)-> i16) -> i16 {
   return %0 : i16
 }
 
+// -----
+
 // CHECK-LABEL: @test_byval
 func.func @test_byval(%arg0: (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64) -> (), %arg1: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, %arg2: f64) {
   //  llvm.call %{{.*}}(%{{.*}}, %{{.*}}) : !llvm.ptr, (!llvm.ptr {llvm.byval = !llvm.struct<"t", (array<5 x f64>)>}, f64) -> ()
@@ -2867,9 +2871,52 @@ func.func @test_byval(%arg0: (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64)
   return
 }
 
+// -----
+
 // CHECK-LABEL: @test_sret
 func.func @test_sret(%arg0: (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64) -> (), %arg1: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, %arg2: f64) {
   //  llvm.call %{{.*}}(%{{.*}}, %{{.*}}) : !llvm.ptr, (!llvm.ptr {llvm.sret = !llvm.struct<"t", (array<5 x f64>)>}, f64) -> ()
   fir.call %arg0(%arg1, %arg2) : (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>> {llvm.sret = !fir.type<t{a:!fir.array<5xf64>}>}, f64) -> ()
   return
 }
+
+// -----
+
+func.func @select_with_cast(%arg1 : i8, %arg2 : i16, %arg3: i64, %arg4: index) -> () {
+  fir.select %arg1 : i8 [ 1, ^bb1, unit, ^bb1 ]
+  ^bb1:
+  fir.select %arg2 : i16 [ 1, ^bb2, unit, ^bb2 ]
+  ^bb2:
+  fir.select %arg3 : i64 [ 1, ^bb3, unit, ^bb3 ]
+  ^bb3:
+  fir.select %arg4 : index [ 1, ^bb4, unit, ^bb4 ]
+  ^bb4:
+  return
+}
+// CHECK-LABEL:   llvm.func @select_with_cast(
+// CHECK-SAME:      %[[ARG0:.*]]: i8,
+// CHECK-SAME:      %[[ARG1:.*]]: i16,
+// CHECK-SAME:      %[[ARG2:.*]]: i64,
+// CHECK-SAME:      %[[ARG3:.*]]: i64) {
+// CHECK:           %[[VAL_0:.*]] = llvm.sext %[[ARG0]] : i8 to i32
+// CHECK:           llvm.switch %[[VAL_0]] : i32, ^bb1 [
+// CHECK:             1: ^bb1
+// CHECK:           ]
+// CHECK:         ^bb1:
+// CHECK:           %[[VAL_1:.*]] = llvm.sext %[[ARG1]] : i16 to i32
+// CHECK:           llvm.switch %[[VAL_1]] : i32, ^bb2 [
+// CHECK:             1: ^bb2
+// CHECK:           ]
+// CHECK:         ^bb2:
+// CHECK:           %[[VAL_2:.*]] = llvm.trunc %[[ARG2]] : i64 to i32
+// CHECK:           llvm.switch %[[VAL_2]] : i32, ^bb3 [
+// CHECK:             1: ^bb3
+// CHECK:           ]
+// CHECK:         ^bb3:
+// CHECK:           %[[VAL_3:.*]] = llvm.trunc %[[ARG3]] : i64 to i32
+// CHECK:           llvm.switch %[[VAL_3]] : i32, ^bb4 [
+// CHECK:             1: ^bb4
+// CHECK:           ]
+// CHECK:         ^bb4:
+// CHECK:           llvm.return
+// CHECK:         }

@llvmbot
Copy link
Member

llvmbot commented Aug 12, 2025

@llvm/pr-subscribers-flang-codegen

Author: Slava Zakharin (vzakhari)

Changes

Properly cast the selector to i32 regardless of its integer type.
We used to generate llvm.trunc always.

Fixes #153050.


Full diff: https://github.com/llvm/llvm-project/pull/153239.diff

2 Files Affected:

  • (modified) flang/lib/Optimizer/CodeGen/CodeGen.cpp (+93-90)
  • (modified) flang/test/Fir/convert-to-llvm.fir (+47)
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 1b289ae690cbe..90bbff4a8363e 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -3525,114 +3525,117 @@ struct SelectCaseOpConversion : public fir::FIROpConversion<fir::SelectCaseOp> {
   }
 };
 
-/// Helper function for converting select ops. This function converts the
-/// signature of the given block. If the new block signature is different from
-/// `expectedTypes`, returns "failure".
-static llvm::FailureOr<mlir::Block *>
-getConvertedBlock(mlir::ConversionPatternRewriter &rewriter,
-                  const mlir::TypeConverter *converter,
-                  mlir::Operation *branchOp, mlir::Block *block,
-                  mlir::TypeRange expectedTypes) {
-  assert(converter && "expected non-null type converter");
-  assert(!block->isEntryBlock() && "entry blocks have no predecessors");
-
-  // There is nothing to do if the types already match.
-  if (block->getArgumentTypes() == expectedTypes)
-    return block;
-
-  // Compute the new block argument types and convert the block.
-  std::optional<mlir::TypeConverter::SignatureConversion> conversion =
-      converter->convertBlockSignature(block);
-  if (!conversion)
-    return rewriter.notifyMatchFailure(branchOp,
-                                       "could not compute block signature");
-  if (expectedTypes != conversion->getConvertedTypes())
-    return rewriter.notifyMatchFailure(
-        branchOp,
-        "mismatch between adaptor operand types and computed block signature");
-  return rewriter.applySignatureConversion(block, *conversion, converter);
-}
-
+/// Base class for SelectOpConversion and SelectRankOpConversion.
 template <typename OP>
-static llvm::LogicalResult
-selectMatchAndRewrite(const fir::LLVMTypeConverter &lowering, OP select,
-                      typename OP::Adaptor adaptor,
-                      mlir::ConversionPatternRewriter &rewriter,
-                      const mlir::TypeConverter *converter) {
-  unsigned conds = select.getNumConditions();
-  auto cases = select.getCases().getValue();
-  mlir::Value selector = adaptor.getSelector();
-  auto loc = select.getLoc();
-  assert(conds > 0 && "select must have cases");
-
-  llvm::SmallVector<mlir::Block *> destinations;
-  llvm::SmallVector<mlir::ValueRange> destinationsOperands;
-  mlir::Block *defaultDestination;
-  mlir::ValueRange defaultOperands;
-  llvm::SmallVector<int32_t> caseValues;
-
-  for (unsigned t = 0; t != conds; ++t) {
-    mlir::Block *dest = select.getSuccessor(t);
-    auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t);
-    const mlir::Attribute &attr = cases[t];
-    if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr)) {
-      destinationsOperands.push_back(destOps ? *destOps : mlir::ValueRange{});
-      auto convertedBlock =
-          getConvertedBlock(rewriter, converter, select, dest,
-                            mlir::TypeRange(destinationsOperands.back()));
+struct SelectOpConversionBase : public fir::FIROpConversion<OP> {
+  using fir::FIROpConversion<OP>::FIROpConversion;
+
+private:
+  /// Helper function for converting select ops. This function converts the
+  /// signature of the given block. If the new block signature is different from
+  /// `expectedTypes`, returns "failure".
+  llvm::FailureOr<mlir::Block *>
+  getConvertedBlock(mlir::ConversionPatternRewriter &rewriter,
+                    mlir::Operation *branchOp, mlir::Block *block,
+                    mlir::TypeRange expectedTypes) const {
+    const mlir::TypeConverter *converter = this->getTypeConverter();
+    assert(converter && "expected non-null type converter");
+    assert(!block->isEntryBlock() && "entry blocks have no predecessors");
+
+    // There is nothing to do if the types already match.
+    if (block->getArgumentTypes() == expectedTypes)
+      return block;
+
+    // Compute the new block argument types and convert the block.
+    std::optional<mlir::TypeConverter::SignatureConversion> conversion =
+        converter->convertBlockSignature(block);
+    if (!conversion)
+      return rewriter.notifyMatchFailure(branchOp,
+                                         "could not compute block signature");
+    if (expectedTypes != conversion->getConvertedTypes())
+      return rewriter.notifyMatchFailure(branchOp,
+                                         "mismatch between adaptor operand "
+                                         "types and computed block signature");
+    return rewriter.applySignatureConversion(block, *conversion, converter);
+  }
+
+protected:
+  llvm::LogicalResult
+  selectMatchAndRewrite(OP select, typename OP::Adaptor adaptor,
+                        mlir::ConversionPatternRewriter &rewriter) const {
+    unsigned conds = select.getNumConditions();
+    auto cases = select.getCases().getValue();
+    mlir::Value selector = adaptor.getSelector();
+    auto loc = select.getLoc();
+    assert(conds > 0 && "select must have cases");
+
+    llvm::SmallVector<mlir::Block *> destinations;
+    llvm::SmallVector<mlir::ValueRange> destinationsOperands;
+    mlir::Block *defaultDestination;
+    mlir::ValueRange defaultOperands;
+    llvm::SmallVector<int32_t> caseValues;
+
+    for (unsigned t = 0; t != conds; ++t) {
+      mlir::Block *dest = select.getSuccessor(t);
+      auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t);
+      const mlir::Attribute &attr = cases[t];
+      if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr)) {
+        destinationsOperands.push_back(destOps ? *destOps : mlir::ValueRange{});
+        auto convertedBlock =
+            getConvertedBlock(rewriter, select, dest,
+                              mlir::TypeRange(destinationsOperands.back()));
+        if (mlir::failed(convertedBlock))
+          return mlir::failure();
+        destinations.push_back(*convertedBlock);
+        caseValues.push_back(intAttr.getInt());
+        continue;
+      }
+      assert(mlir::dyn_cast_or_null<mlir::UnitAttr>(attr));
+      assert((t + 1 == conds) && "unit must be last");
+      defaultOperands = destOps ? *destOps : mlir::ValueRange{};
+      auto convertedBlock = getConvertedBlock(rewriter, select, dest,
+                                              mlir::TypeRange(defaultOperands));
       if (mlir::failed(convertedBlock))
         return mlir::failure();
-      destinations.push_back(*convertedBlock);
-      caseValues.push_back(intAttr.getInt());
-      continue;
+      defaultDestination = *convertedBlock;
     }
-    assert(mlir::dyn_cast_or_null<mlir::UnitAttr>(attr));
-    assert((t + 1 == conds) && "unit must be last");
-    defaultOperands = destOps ? *destOps : mlir::ValueRange{};
-    auto convertedBlock = getConvertedBlock(rewriter, converter, select, dest,
-                                            mlir::TypeRange(defaultOperands));
-    if (mlir::failed(convertedBlock))
-      return mlir::failure();
-    defaultDestination = *convertedBlock;
-  }
-
-  // LLVM::SwitchOp takes a i32 type for the selector.
-  if (select.getSelector().getType() != rewriter.getI32Type())
-    selector = mlir::LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(),
-                                           selector);
-
-  rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>(
-      select, selector,
-      /*defaultDestination=*/defaultDestination,
-      /*defaultOperands=*/defaultOperands,
-      /*caseValues=*/caseValues,
-      /*caseDestinations=*/destinations,
-      /*caseOperands=*/destinationsOperands,
-      /*branchWeights=*/llvm::ArrayRef<std::int32_t>());
-  return mlir::success();
-}
 
+    // LLVM::SwitchOp takes a i32 type for the selector.
+    if (select.getSelector().getType() != rewriter.getI32Type())
+      selector =
+          this->integerCast(loc, rewriter, rewriter.getI32Type(), selector);
+
+    rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>(
+        select, selector,
+        /*defaultDestination=*/defaultDestination,
+        /*defaultOperands=*/defaultOperands,
+        /*caseValues=*/caseValues,
+        /*caseDestinations=*/destinations,
+        /*caseOperands=*/destinationsOperands,
+        /*branchWeights=*/llvm::ArrayRef<std::int32_t>());
+    return mlir::success();
+  }
+};
 /// conversion of fir::SelectOp to an if-then-else ladder
-struct SelectOpConversion : public fir::FIROpConversion<fir::SelectOp> {
-  using FIROpConversion::FIROpConversion;
+struct SelectOpConversion : public SelectOpConversionBase<fir::SelectOp> {
+  using SelectOpConversionBase::SelectOpConversionBase;
 
   llvm::LogicalResult
   matchAndRewrite(fir::SelectOp op, OpAdaptor adaptor,
                   mlir::ConversionPatternRewriter &rewriter) const override {
-    return selectMatchAndRewrite<fir::SelectOp>(lowerTy(), op, adaptor,
-                                                rewriter, getTypeConverter());
+    return this->selectMatchAndRewrite(op, adaptor, rewriter);
   }
 };
 
 /// conversion of fir::SelectRankOp to an if-then-else ladder
-struct SelectRankOpConversion : public fir::FIROpConversion<fir::SelectRankOp> {
-  using FIROpConversion::FIROpConversion;
+struct SelectRankOpConversion
+    : public SelectOpConversionBase<fir::SelectRankOp> {
+  using SelectOpConversionBase::SelectOpConversionBase;
 
   llvm::LogicalResult
   matchAndRewrite(fir::SelectRankOp op, OpAdaptor adaptor,
                   mlir::ConversionPatternRewriter &rewriter) const override {
-    return selectMatchAndRewrite<fir::SelectRankOp>(
-        lowerTy(), op, adaptor, rewriter, getTypeConverter());
+    return this->selectMatchAndRewrite(op, adaptor, rewriter);
   }
 };
 
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index 50a98466f0d4b..338d3ab8edbaf 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -2853,6 +2853,8 @@ func.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !fir.ref<i64>) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: @test_call_arg_attrs_indirect
 func.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: (i16)-> i16) -> i16 {
   // CHECK: llvm.call %arg1(%{{.*}}) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
@@ -2860,6 +2862,8 @@ func.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: (i16)-> i16) -> i16 {
   return %0 : i16
 }
 
+// -----
+
 // CHECK-LABEL: @test_byval
 func.func @test_byval(%arg0: (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64) -> (), %arg1: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, %arg2: f64) {
   //  llvm.call %{{.*}}(%{{.*}}, %{{.*}}) : !llvm.ptr, (!llvm.ptr {llvm.byval = !llvm.struct<"t", (array<5 x f64>)>}, f64) -> ()
@@ -2867,9 +2871,52 @@ func.func @test_byval(%arg0: (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64)
   return
 }
 
+// -----
+
 // CHECK-LABEL: @test_sret
 func.func @test_sret(%arg0: (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64) -> (), %arg1: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, %arg2: f64) {
   //  llvm.call %{{.*}}(%{{.*}}, %{{.*}}) : !llvm.ptr, (!llvm.ptr {llvm.sret = !llvm.struct<"t", (array<5 x f64>)>}, f64) -> ()
   fir.call %arg0(%arg1, %arg2) : (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>> {llvm.sret = !fir.type<t{a:!fir.array<5xf64>}>}, f64) -> ()
   return
 }
+
+// -----
+
+func.func @select_with_cast(%arg1 : i8, %arg2 : i16, %arg3: i64, %arg4: index) -> () {
+  fir.select %arg1 : i8 [ 1, ^bb1, unit, ^bb1 ]
+  ^bb1:
+  fir.select %arg2 : i16 [ 1, ^bb2, unit, ^bb2 ]
+  ^bb2:
+  fir.select %arg3 : i64 [ 1, ^bb3, unit, ^bb3 ]
+  ^bb3:
+  fir.select %arg4 : index [ 1, ^bb4, unit, ^bb4 ]
+  ^bb4:
+  return
+}
+// CHECK-LABEL:   llvm.func @select_with_cast(
+// CHECK-SAME:      %[[ARG0:.*]]: i8,
+// CHECK-SAME:      %[[ARG1:.*]]: i16,
+// CHECK-SAME:      %[[ARG2:.*]]: i64,
+// CHECK-SAME:      %[[ARG3:.*]]: i64) {
+// CHECK:           %[[VAL_0:.*]] = llvm.sext %[[ARG0]] : i8 to i32
+// CHECK:           llvm.switch %[[VAL_0]] : i32, ^bb1 [
+// CHECK:             1: ^bb1
+// CHECK:           ]
+// CHECK:         ^bb1:
+// CHECK:           %[[VAL_1:.*]] = llvm.sext %[[ARG1]] : i16 to i32
+// CHECK:           llvm.switch %[[VAL_1]] : i32, ^bb2 [
+// CHECK:             1: ^bb2
+// CHECK:           ]
+// CHECK:         ^bb2:
+// CHECK:           %[[VAL_2:.*]] = llvm.trunc %[[ARG2]] : i64 to i32
+// CHECK:           llvm.switch %[[VAL_2]] : i32, ^bb3 [
+// CHECK:             1: ^bb3
+// CHECK:           ]
+// CHECK:         ^bb3:
+// CHECK:           %[[VAL_3:.*]] = llvm.trunc %[[ARG3]] : i64 to i32
+// CHECK:           llvm.switch %[[VAL_3]] : i32, ^bb4 [
+// CHECK:             1: ^bb4
+// CHECK:           ]
+// CHECK:         ^bb4:
+// CHECK:           llvm.return
+// CHECK:         }

Copy link
Contributor

@clementval clementval left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@eugeneepshteyn
Copy link
Contributor

Properly cast the selector to i32 regardless of its integer type.

Why convert at all? Wouldn't llvm.switch accept any integer type, as long as the types of the operands match?

@vzakhari
Copy link
Contributor Author

Why convert at all? Wouldn't llvm.switch accept any integer type, as long as the types of the operands match?

Good point! I should have looked at llvm.switch definition - it does allow any signless integer. Let me rework this.

@vzakhari
Copy link
Contributor Author

vzakhari commented Aug 12, 2025

The case values' type is computed using the type of the selector:

void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, <--- selector
                     Block *defaultDestination, ValueRange defaultOperands,
                     ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
                     ArrayRef<ValueRange> caseOperands,
                     ArrayRef<int32_t> branchWeights) {
  DenseIntElementsAttr caseValuesAttr;
  if (!caseValues.empty()) {
    ShapedType caseValueType = VectorType::get(
        static_cast<int64_t>(caseValues.size()), value.getType());
    caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
  }

And the type must match the types of the caseValues, which we create as int32_t values. So it looks like we have to cast selector to i32.

@vzakhari
Copy link
Contributor Author

Alternatively, I guess, we can create the caseValues vector using the iW with W matching the bit-width of the selector.

@eugeneepshteyn
Copy link
Contributor

eugeneepshteyn commented Aug 12, 2025

You know, if it becomes a headache, let's cast. But maybe just cast all to i64?

Even casting all to i32 is likely to work in practice. I haven't (yet) seen Fortran code that used anything more than default integer for computed GOTOs.

(Although, someone will now create a reproducer that explicitly uses INTEGER(8) with large values and breaks things.)

@vzakhari
Copy link
Contributor Author

You know, if it becomes a headache, let's cast. But maybe just cast all to i64?

Even casting all to i32 is likely to work in practice. I haven't (yet) seen Fortran code that used anything more than default integer for computed GOTOs.

(Although, someone will now create a reproducer that explicitly uses INTEGER(8) with large values and breaks things.)

Yep, it looks like creating differently typed caseValues dense attr is bulky. I will keep the casting, and will use i64. Thanks!

@vzakhari vzakhari changed the title [flang] Cast fir.select[_rank] selector to i32. [flang] Cast fir.select[_rank] selector to i64. Aug 12, 2025
@vzakhari vzakhari merged commit b8e4232 into llvm:main Aug 12, 2025
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:codegen flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[flang] Improper casting of integer(2) selector for computed gotos
4 participants