Skip to content

Commit 72b8073

Browse files
[mlir][SCF] Add scf.index_switch support for populateSCFStructuralTypeConversionsAndLegality (llvm#160344)
In a downstream project, there is a need for a type conversion pattern for scf.index_switch operation. A test is added into `mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir` (not sure this functionality is really required for sparse tensors, but the test showcase that the new conversion pattern is functional)
1 parent c520531 commit 72b8073

File tree

2 files changed

+73
-4
lines changed

2 files changed

+73
-4
lines changed

mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,30 @@ class ConvertWhileOpTypes
185185
};
186186
} // namespace
187187

188+
namespace {
189+
class ConvertIndexSwitchOpTypes
190+
: public Structural1ToNConversionPattern<IndexSwitchOp,
191+
ConvertIndexSwitchOpTypes> {
192+
public:
193+
using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
194+
195+
std::optional<IndexSwitchOp>
196+
convertSourceOp(IndexSwitchOp op, OneToNOpAdaptor adaptor,
197+
ConversionPatternRewriter &rewriter,
198+
TypeRange dstTypes) const {
199+
auto newOp =
200+
IndexSwitchOp::create(rewriter, op.getLoc(), dstTypes, op.getArg(),
201+
op.getCases(), op.getNumCases());
202+
203+
for (unsigned i = 0u; i < op.getNumRegions(); i++) {
204+
auto &dstRegion = newOp.getRegion(i);
205+
rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end());
206+
}
207+
return newOp;
208+
}
209+
};
210+
} // namespace
211+
188212
namespace {
189213
// When the result types of a ForOp/IfOp get changed, the operand types of the
190214
// corresponding yield op need to be changed. In order to trigger the
@@ -220,18 +244,19 @@ void mlir::scf::populateSCFStructuralTypeConversions(
220244
const TypeConverter &typeConverter, RewritePatternSet &patterns,
221245
PatternBenefit benefit) {
222246
patterns.add<ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes,
223-
ConvertWhileOpTypes, ConvertConditionOpTypes>(
224-
typeConverter, patterns.getContext(), benefit);
247+
ConvertWhileOpTypes, ConvertConditionOpTypes,
248+
ConvertIndexSwitchOpTypes>(typeConverter, patterns.getContext(),
249+
benefit);
225250
}
226251

227252
void mlir::scf::populateSCFStructuralTypeConversionTarget(
228253
const TypeConverter &typeConverter, ConversionTarget &target) {
229-
target.addDynamicallyLegalOp<ForOp, IfOp>(
254+
target.addDynamicallyLegalOp<ForOp, IfOp, IndexSwitchOp>(
230255
[&](Operation *op) { return typeConverter.isLegal(op->getResults()); });
231256
target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) {
232257
// We only have conversions for a subset of ops that use scf.yield
233258
// terminators.
234-
if (!isa<ForOp, IfOp, WhileOp>(op->getParentOp()))
259+
if (!isa<ForOp, IfOp, WhileOp, IndexSwitchOp>(op->getParentOp()))
235260
return true;
236261
return typeConverter.isLegal(op.getOperands());
237262
});

mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,47 @@ func.func @while(%arg0: tensor<1024xf32, #SparseVector>, %c: i1) -> tensor<1024x
8686
}
8787
return %0: tensor<1024xf32, #SparseVector>
8888
}
89+
90+
// CHECK-LABEL: func.func @index_switch(
91+
// CHECK-SAME: %[[PRED:.*0]]: index,
92+
// CHECK-SAME: %[[VAL_A_1:.*1]]: memref<?xindex>,
93+
// CHECK-SAME: %[[VAL_A_2:.*2]]: memref<?xindex>,
94+
// CHECK-SAME: %[[VAL_A_3:.*3]]: memref<?xf32>,
95+
// CHECK-SAME: %[[VAL_A_4:.*4]]: !sparse_tensor.storage_specifier
96+
// CHECK-SAME: %[[VAL_B_1:.*5]]: memref<?xindex>,
97+
// CHECK-SAME: %[[VAL_B_2:.*6]]: memref<?xindex>,
98+
// CHECK-SAME: %[[VAL_B_3:.*7]]: memref<?xf32>,
99+
// CHECK-SAME: %[[VAL_B_4:.*8]]: !sparse_tensor.storage_specifier
100+
// CHECK-SAME: %[[VAL_C_1:.*9]]: memref<?xindex>,
101+
// CHECK-SAME: %[[VAL_C_2:.*10]]: memref<?xindex>,
102+
// CHECK-SAME: %[[VAL_C_3:.*11]]: memref<?xf32>,
103+
// CHECK-SAME: %[[VAL_C_4:.*12]]: !sparse_tensor.storage_specifier
104+
105+
// CHECK: %[[RES:.*]]:4 = scf.index_switch %[[PRED]]
106+
// CHECK-SAME: -> memref<?xindex>, memref<?xindex>, memref<?xf32>, !sparse_tensor.storage_specifier
107+
// CHECK: case 1 {
108+
// CHECK: scf.yield %[[VAL_A_1]], %[[VAL_A_2]], %[[VAL_A_3]], %[[VAL_A_4]]
109+
// CHECK: case 2 {
110+
// CHECK: scf.yield %[[VAL_B_1]], %[[VAL_B_2]], %[[VAL_B_3]], %[[VAL_B_4]]
111+
// CHECK: default {
112+
// CHECK: scf.yield %[[VAL_C_1]], %[[VAL_C_2]], %[[VAL_C_3]], %[[VAL_C_4]]
113+
114+
// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3 :
115+
// CHECK-SAME: memref<?xindex>, memref<?xindex>, memref<?xf32>, !sparse_tensor.storage_specifier
116+
117+
func.func @index_switch(%pred: index, %a: tensor<5xf32, #SparseVector>,
118+
%b: tensor<5xf32, #SparseVector>,
119+
%c: tensor<5xf32, #SparseVector>) -> tensor<5xf32, #SparseVector> {
120+
%0 = scf.index_switch %pred -> tensor<5xf32, #SparseVector>
121+
case 1 {
122+
scf.yield %a : tensor<5xf32, #SparseVector>
123+
}
124+
case 2 {
125+
scf.yield %b : tensor<5xf32, #SparseVector>
126+
}
127+
default {
128+
scf.yield %c : tensor<5xf32, #SparseVector>
129+
}
130+
131+
return %0 : tensor<5xf32, #SparseVector>
132+
}

0 commit comments

Comments
 (0)