Skip to content

Commit 1bc9f2c

Browse files
GaoXiangYalanza
authored andcommitted
[CIR][Lowering] Add MLIR lowering support for CIR math operations (#592)
This pr adds `cir.ceil` `cir.exp2` `cir.exp` `cir.fabs` `cir.floor` `cir.log` `cir.log10` `cir.log2` `cir.round` `cir.sqrt` lowering to MLIR passes and test files.
1 parent a78dfcd commit 1bc9f2c

File tree

8 files changed

+345
-8
lines changed

8 files changed

+345
-8
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

Lines changed: 135 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,129 @@ class CIRCosOpLowering : public mlir::OpConversionPattern<mlir::cir::CosOp> {
209209
}
210210
};
211211

212+
class CIRSqrtOpLowering : public mlir::OpConversionPattern<mlir::cir::SqrtOp> {
213+
public:
214+
using mlir::OpConversionPattern<mlir::cir::SqrtOp>::OpConversionPattern;
215+
216+
mlir::LogicalResult
217+
matchAndRewrite(mlir::cir::SqrtOp op, OpAdaptor adaptor,
218+
mlir::ConversionPatternRewriter &rewriter) const override {
219+
rewriter.replaceOpWithNewOp<mlir::math::SqrtOp>(op, adaptor.getSrc());
220+
return mlir::LogicalResult::success();
221+
}
222+
};
223+
224+
class CIRFAbsOpLowering : public mlir::OpConversionPattern<mlir::cir::FAbsOp> {
225+
public:
226+
using mlir::OpConversionPattern<mlir::cir::FAbsOp>::OpConversionPattern;
227+
228+
mlir::LogicalResult
229+
matchAndRewrite(mlir::cir::FAbsOp op, OpAdaptor adaptor,
230+
mlir::ConversionPatternRewriter &rewriter) const override {
231+
rewriter.replaceOpWithNewOp<mlir::math::AbsFOp>(op, adaptor.getSrc());
232+
return mlir::LogicalResult::success();
233+
}
234+
};
235+
236+
class CIRFloorOpLowering
237+
: public mlir::OpConversionPattern<mlir::cir::FloorOp> {
238+
public:
239+
using mlir::OpConversionPattern<mlir::cir::FloorOp>::OpConversionPattern;
240+
241+
mlir::LogicalResult
242+
matchAndRewrite(mlir::cir::FloorOp op, OpAdaptor adaptor,
243+
mlir::ConversionPatternRewriter &rewriter) const override {
244+
rewriter.replaceOpWithNewOp<mlir::math::FloorOp>(op, adaptor.getSrc());
245+
return mlir::LogicalResult::success();
246+
}
247+
};
248+
249+
class CIRCeilOpLowering : public mlir::OpConversionPattern<mlir::cir::CeilOp> {
250+
public:
251+
using mlir::OpConversionPattern<mlir::cir::CeilOp>::OpConversionPattern;
252+
253+
mlir::LogicalResult
254+
matchAndRewrite(mlir::cir::CeilOp op, OpAdaptor adaptor,
255+
mlir::ConversionPatternRewriter &rewriter) const override {
256+
rewriter.replaceOpWithNewOp<mlir::math::CeilOp>(op, adaptor.getSrc());
257+
return mlir::LogicalResult::success();
258+
}
259+
};
260+
261+
class CIRLog10OpLowering
262+
: public mlir::OpConversionPattern<mlir::cir::Log10Op> {
263+
public:
264+
using mlir::OpConversionPattern<mlir::cir::Log10Op>::OpConversionPattern;
265+
266+
mlir::LogicalResult
267+
matchAndRewrite(mlir::cir::Log10Op op, OpAdaptor adaptor,
268+
mlir::ConversionPatternRewriter &rewriter) const override {
269+
rewriter.replaceOpWithNewOp<mlir::math::Log10Op>(op, adaptor.getSrc());
270+
return mlir::LogicalResult::success();
271+
}
272+
};
273+
274+
class CIRLogOpLowering : public mlir::OpConversionPattern<mlir::cir::LogOp> {
275+
public:
276+
using mlir::OpConversionPattern<mlir::cir::LogOp>::OpConversionPattern;
277+
278+
mlir::LogicalResult
279+
matchAndRewrite(mlir::cir::LogOp op, OpAdaptor adaptor,
280+
mlir::ConversionPatternRewriter &rewriter) const override {
281+
rewriter.replaceOpWithNewOp<mlir::math::LogOp>(op, adaptor.getSrc());
282+
return mlir::LogicalResult::success();
283+
}
284+
};
285+
286+
class CIRLog2OpLowering : public mlir::OpConversionPattern<mlir::cir::Log2Op> {
287+
public:
288+
using mlir::OpConversionPattern<mlir::cir::Log2Op>::OpConversionPattern;
289+
290+
mlir::LogicalResult
291+
matchAndRewrite(mlir::cir::Log2Op op, OpAdaptor adaptor,
292+
mlir::ConversionPatternRewriter &rewriter) const override {
293+
rewriter.replaceOpWithNewOp<mlir::math::Log2Op>(op, adaptor.getSrc());
294+
return mlir::LogicalResult::success();
295+
}
296+
};
297+
298+
class CIRRoundOpLowering
299+
: public mlir::OpConversionPattern<mlir::cir::RoundOp> {
300+
public:
301+
using mlir::OpConversionPattern<mlir::cir::RoundOp>::OpConversionPattern;
302+
303+
mlir::LogicalResult
304+
matchAndRewrite(mlir::cir::RoundOp op, OpAdaptor adaptor,
305+
mlir::ConversionPatternRewriter &rewriter) const override {
306+
rewriter.replaceOpWithNewOp<mlir::math::RoundOp>(op, adaptor.getSrc());
307+
return mlir::LogicalResult::success();
308+
}
309+
};
310+
311+
class CIRExpOpLowering : public mlir::OpConversionPattern<mlir::cir::ExpOp> {
312+
public:
313+
using mlir::OpConversionPattern<mlir::cir::ExpOp>::OpConversionPattern;
314+
315+
mlir::LogicalResult
316+
matchAndRewrite(mlir::cir::ExpOp op, OpAdaptor adaptor,
317+
mlir::ConversionPatternRewriter &rewriter) const override {
318+
rewriter.replaceOpWithNewOp<mlir::math::ExpOp>(op, adaptor.getSrc());
319+
return mlir::LogicalResult::success();
320+
}
321+
};
322+
323+
class CIRExp2OpLowering : public mlir::OpConversionPattern<mlir::cir::Exp2Op> {
324+
public:
325+
using mlir::OpConversionPattern<mlir::cir::Exp2Op>::OpConversionPattern;
326+
327+
mlir::LogicalResult
328+
matchAndRewrite(mlir::cir::Exp2Op op, OpAdaptor adaptor,
329+
mlir::ConversionPatternRewriter &rewriter) const override {
330+
rewriter.replaceOpWithNewOp<mlir::math::Exp2Op>(op, adaptor.getSrc());
331+
return mlir::LogicalResult::success();
332+
}
333+
};
334+
212335
class CIRSinOpLowering : public mlir::OpConversionPattern<mlir::cir::SinOp> {
213336
public:
214337
using mlir::OpConversionPattern<mlir::cir::SinOp>::OpConversionPattern;
@@ -1000,14 +1123,18 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
10001123
mlir::TypeConverter &converter) {
10011124
patterns.add<CIRReturnLowering, CIRBrOpLowering>(patterns.getContext());
10021125

1003-
patterns.add<CIRCmpOpLowering, CIRCallOpLowering, CIRUnaryOpLowering,
1004-
CIRBinOpLowering, CIRLoadOpLowering, CIRConstantOpLowering,
1005-
CIRStoreOpLowering, CIRAllocaOpLowering, CIRFuncOpLowering,
1006-
CIRScopeOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
1007-
CIRYieldOpLowering, CIRCosOpLowering, CIRGlobalOpLowering,
1008-
CIRGetGlobalOpLowering, CIRCastOpLowering,
1009-
CIRPtrStrideOpLowering, CIRSinOpLowering>(converter,
1010-
patterns.getContext());
1126+
patterns
1127+
.add<CIRCmpOpLowering, CIRCallOpLowering, CIRUnaryOpLowering,
1128+
CIRBinOpLowering, CIRLoadOpLowering, CIRConstantOpLowering,
1129+
CIRStoreOpLowering, CIRAllocaOpLowering, CIRFuncOpLowering,
1130+
CIRScopeOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
1131+
CIRYieldOpLowering, CIRCosOpLowering, CIRGlobalOpLowering,
1132+
CIRGetGlobalOpLowering, CIRCastOpLowering, CIRPtrStrideOpLowering,
1133+
CIRSqrtOpLowering, CIRCeilOpLowering, CIRExp2OpLowering,
1134+
CIRExpOpLowering, CIRFAbsOpLowering, CIRFloorOpLowering,
1135+
CIRLog10OpLowering, CIRLog2OpLowering, CIRLogOpLowering,
1136+
CIRRoundOpLowering, CIRPtrStrideOpLowering, CIRSinOpLowering>(
1137+
converter, patterns.getContext());
10111138
}
10121139

10131140
static mlir::TypeConverter prepareTypeConverter() {
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: cir-opt %s -cir-to-mlir -o %t.mlir
2+
// RUN: FileCheck %s --input-file %t.mlir
3+
4+
module {
5+
cir.func @foo() {
6+
%0 = cir.const #cir.fp<1.31> : !cir.float
7+
%1 = cir.const #cir.fp<3.0> : !cir.long_double<!cir.f80>
8+
%2 = cir.const #cir.fp<2.73> : !cir.double
9+
%3 = cir.const #cir.fp<4.67> : !cir.long_double<!cir.double>
10+
%4 = cir.ceil %0 : !cir.float
11+
%5 = cir.ceil %1 : !cir.long_double<!cir.f80>
12+
%6 = cir.ceil %2 : !cir.double
13+
%7 = cir.ceil %3 : !cir.long_double<!cir.double>
14+
cir.return
15+
}
16+
}
17+
18+
// CHECK: module {
19+
// CHECK-NEXT: func.func @foo() {
20+
// CHECK-NEXT: %[[C0:.+]] = arith.constant 1.310000e+00 : f32
21+
// CHECK-NEXT: %[[C1:.+]] = arith.constant 3.000000e+00 : f80
22+
// CHECK-NEXT: %[[C2:.+]] = arith.constant 2.730000e+00 : f64
23+
// CHECK-NEXT: %[[C3:.+]] = arith.constant 4.670000e+00 : f64
24+
// CHECK-NEXT: %{{.+}} = math.ceil %[[C0]] : f32
25+
// CHECK-NEXT: %{{.+}} = math.ceil %[[C1]] : f80
26+
// CHECK-NEXT: %{{.+}} = math.ceil %[[C2]] : f64
27+
// CHECK-NEXT: %{{.+}} = math.ceil %[[C3]] : f64
28+
// CHECK-NEXT: return
29+
// CHECK-NEXT: }
30+
// CHECK-NEXT: }
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: cir-opt %s -cir-to-mlir -o %t.mlir
2+
// RUN: FileCheck %s --input-file %t.mlir
3+
4+
module {
5+
cir.func @foo() {
6+
%0 = cir.const #cir.fp<1.0> : !cir.float
7+
%1 = cir.const #cir.fp<3.0> : !cir.long_double<!cir.f80>
8+
%2 = cir.const #cir.fp<2.0> : !cir.double
9+
%3 = cir.const #cir.fp<4.00> : !cir.long_double<!cir.double>
10+
%4 = cir.exp %0 : !cir.float
11+
%5 = cir.exp %1 : !cir.long_double<!cir.f80>
12+
%6 = cir.exp2 %2 : !cir.double
13+
%7 = cir.exp2 %3 : !cir.long_double<!cir.double>
14+
cir.return
15+
}
16+
}
17+
18+
// CHECK: module {
19+
// CHECK-NEXT: func.func @foo() {
20+
// CHECK-NEXT: %[[C0:.+]] = arith.constant 1.000000e+00 : f32
21+
// CHECK-NEXT: %[[C1:.+]] = arith.constant 3.000000e+00 : f80
22+
// CHECK-NEXT: %[[C2:.+]] = arith.constant 2.000000e+00 : f64
23+
// CHECK-NEXT: %[[C3:.+]] = arith.constant 4.000000e+00 : f64
24+
// CHECK-NEXT: %{{.+}} = math.exp %[[C0]] : f32
25+
// CHECK-NEXT: %{{.+}} = math.exp %[[C1]] : f80
26+
// CHECK-NEXT: %{{.+}} = math.exp2 %[[C2]] : f64
27+
// CHECK-NEXT: %{{.+}} = math.exp2 %[[C3]] : f64
28+
// CHECK-NEXT: return
29+
// CHECK-NEXT: }
30+
// CHECK-NEXT: }
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: cir-opt %s -cir-to-mlir -o %t.mlir
2+
// RUN: FileCheck %s --input-file %t.mlir
3+
4+
module {
5+
cir.func @foo() {
6+
%0 = cir.const #cir.fp<-1.0> : !cir.float
7+
%1 = cir.const #cir.fp<-3.0> : !cir.long_double<!cir.f80>
8+
%2 = cir.const #cir.fp<-2.0> : !cir.double
9+
%3 = cir.const #cir.fp<-4.00> : !cir.long_double<!cir.double>
10+
%4 = cir.fabs %0 : !cir.float
11+
%5 = cir.fabs %1 : !cir.long_double<!cir.f80>
12+
%6 = cir.fabs %2 : !cir.double
13+
%7 = cir.fabs %3 : !cir.long_double<!cir.double>
14+
cir.return
15+
}
16+
}
17+
18+
// CHECK: module {
19+
// CHECK-NEXT: func.func @foo() {
20+
// CHECK-NEXT: %[[C0:.+]] = arith.constant -1.000000e+00 : f32
21+
// CHECK-NEXT: %[[C1:.+]] = arith.constant -3.000000e+00 : f80
22+
// CHECK-NEXT: %[[C2:.+]] = arith.constant -2.000000e+00 : f64
23+
// CHECK-NEXT: %[[C3:.+]] = arith.constant -4.000000e+00 : f64
24+
// CHECK-NEXT: %{{.+}} = math.absf %[[C0]] : f32
25+
// CHECK-NEXT: %{{.+}} = math.absf %[[C1]] : f80
26+
// CHECK-NEXT: %{{.+}} = math.absf %[[C2]] : f64
27+
// CHECK-NEXT: %{{.+}} = math.absf %[[C3]] : f64
28+
// CHECK-NEXT: return
29+
// CHECK-NEXT: }
30+
// CHECK-NEXT: }
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: cir-opt %s -cir-to-mlir -o %t.mlir
2+
// RUN: FileCheck %s --input-file %t.mlir
3+
4+
module {
5+
cir.func @foo() {
6+
%0 = cir.const #cir.fp<1.51> : !cir.float
7+
%1 = cir.const #cir.fp<3.0> : !cir.long_double<!cir.f80>
8+
%2 = cir.const #cir.fp<2.73> : !cir.double
9+
%3 = cir.const #cir.fp<4.67> : !cir.long_double<!cir.double>
10+
%4 = cir.floor %0 : !cir.float
11+
%5 = cir.floor %1 : !cir.long_double<!cir.f80>
12+
%6 = cir.floor %2 : !cir.double
13+
%7 = cir.floor %3 : !cir.long_double<!cir.double>
14+
cir.return
15+
}
16+
}
17+
18+
// CHECK: module {
19+
// CHECK-NEXT: func.func @foo() {
20+
// CHECK-NEXT: %[[C0:.+]] = arith.constant 1.510000e+00 : f32
21+
// CHECK-NEXT: %[[C1:.+]] = arith.constant 3.000000e+00 : f80
22+
// CHECK-NEXT: %[[C2:.+]] = arith.constant 2.730000e+00 : f64
23+
// CHECK-NEXT: %[[C3:.+]] = arith.constant 4.670000e+00 : f64
24+
// CHECK-NEXT: %{{.+}} = math.floor %[[C0]] : f32
25+
// CHECK-NEXT: %{{.+}} = math.floor %[[C1]] : f80
26+
// CHECK-NEXT: %{{.+}} = math.floor %[[C2]] : f64
27+
// CHECK-NEXT: %{{.+}} = math.floor %[[C3]] : f64
28+
// CHECK-NEXT: return
29+
// CHECK-NEXT: }
30+
// CHECK-NEXT: }
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: cir-opt %s -cir-to-mlir -o %t.mlir
2+
// RUN: FileCheck %s --input-file %t.mlir
3+
4+
module {
5+
cir.func @foo() {
6+
%0 = cir.const #cir.fp<1.0> : !cir.float
7+
%1 = cir.const #cir.fp<3.0> : !cir.long_double<!cir.f80>
8+
%2 = cir.const #cir.fp<2.0> : !cir.double
9+
%3 = cir.const #cir.fp<4.0> : !cir.long_double<!cir.double>
10+
%4 = cir.log %0 : !cir.float
11+
%5 = cir.log %1 : !cir.long_double<!cir.f80>
12+
%6 = cir.log2 %2 : !cir.double
13+
%7 = cir.log10 %3 : !cir.long_double<!cir.double>
14+
cir.return
15+
}
16+
}
17+
18+
// CHECK: module {
19+
// CHECK-NEXT: func.func @foo() {
20+
// CHECK-NEXT: %[[C0:.+]] = arith.constant 1.000000e+00 : f32
21+
// CHECK-NEXT: %[[C1:.+]] = arith.constant 3.000000e+00 : f80
22+
// CHECK-NEXT: %[[C2:.+]] = arith.constant 2.000000e+00 : f64
23+
// CHECK-NEXT: %[[C3:.+]] = arith.constant 4.000000e+00 : f64
24+
// CHECK-NEXT: %{{.+}} = math.log %[[C0]] : f32
25+
// CHECK-NEXT: %{{.+}} = math.log %[[C1]] : f80
26+
// CHECK-NEXT: %{{.+}} = math.log2 %[[C2]] : f64
27+
// CHECK-NEXT: %{{.+}} = math.log10 %[[C3]] : f64
28+
// CHECK-NEXT: return
29+
// CHECK-NEXT: }
30+
// CHECK-NEXT: }
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: cir-opt %s -cir-to-mlir -o %t.mlir
2+
// RUN: FileCheck %s --input-file %t.mlir
3+
4+
module {
5+
cir.func @foo() {
6+
%0 = cir.const #cir.fp<1.31> : !cir.float
7+
%1 = cir.const #cir.fp<3.0> : !cir.long_double<!cir.f80>
8+
%2 = cir.const #cir.fp<2.73> : !cir.double
9+
%3 = cir.const #cir.fp<4.67> : !cir.long_double<!cir.double>
10+
%4 = cir.round %0 : !cir.float
11+
%5 = cir.round %1 : !cir.long_double<!cir.f80>
12+
%6 = cir.round %2 : !cir.double
13+
%7 = cir.round %3 : !cir.long_double<!cir.double>
14+
cir.return
15+
}
16+
}
17+
18+
// CHECK: module {
19+
// CHECK-NEXT: func.func @foo() {
20+
// CHECK-NEXT: %[[C0:.+]] = arith.constant 1.310000e+00 : f32
21+
// CHECK-NEXT: %[[C1:.+]] = arith.constant 3.000000e+00 : f80
22+
// CHECK-NEXT: %[[C2:.+]] = arith.constant 2.730000e+00 : f64
23+
// CHECK-NEXT: %[[C3:.+]] = arith.constant 4.670000e+00 : f64
24+
// CHECK-NEXT: %{{.+}} = math.round %[[C0]] : f32
25+
// CHECK-NEXT: %{{.+}} = math.round %[[C1]] : f80
26+
// CHECK-NEXT: %{{.+}} = math.round %[[C2]] : f64
27+
// CHECK-NEXT: %{{.+}} = math.round %[[C3]] : f64
28+
// CHECK-NEXT: return
29+
// CHECK-NEXT: }
30+
// CHECK-NEXT: }
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: cir-opt %s -cir-to-mlir -o %t.mlir
2+
// RUN: FileCheck %s --input-file %t.mlir
3+
4+
module {
5+
cir.func @foo() {
6+
%0 = cir.const #cir.fp<9.0> : !cir.float
7+
%1 = cir.const #cir.fp<100.0> : !cir.long_double<!cir.f80>
8+
%2 = cir.const #cir.fp<1.0> : !cir.double
9+
%3 = cir.const #cir.fp<2.56> : !cir.long_double<!cir.double>
10+
%4 = cir.sqrt %0 : !cir.float
11+
%5 = cir.sqrt %1 : !cir.long_double<!cir.f80>
12+
%6 = cir.sqrt %2 : !cir.double
13+
%7 = cir.sqrt %3 : !cir.long_double<!cir.double>
14+
cir.return
15+
}
16+
}
17+
18+
// CHECK: module {
19+
// CHECK-NEXT: func.func @foo() {
20+
// CHECK-NEXT: %[[C0:.+]] = arith.constant 9.000000e+00 : f32
21+
// CHECK-NEXT: %[[C1:.+]] = arith.constant 1.000000e+02 : f80
22+
// CHECK-NEXT: %[[C2:.+]] = arith.constant 1.000000e+00 : f64
23+
// CHECK-NEXT: %[[C3:.+]] = arith.constant 2.560000e+00 : f64
24+
// CHECK-NEXT: %{{.+}} = math.sqrt %[[C0]] : f32
25+
// CHECK-NEXT: %{{.+}} = math.sqrt %[[C1]] : f80
26+
// CHECK-NEXT: %{{.+}} = math.sqrt %[[C2]] : f64
27+
// CHECK-NEXT: %{{.+}} = math.sqrt %[[C3]] : f64
28+
// CHECK-NEXT: return
29+
// CHECK-NEXT: }
30+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)