Skip to content

Commit a879a1b

Browse files
NatashaKnkrsuderman
authored andcommitted
[mlir][tosa] Add tosa.reciprocal and tosa.sigmoid lowerings
Lowering reciprocal and sigmoid elementwise operations to linalg dialect. Reviewed By: rsuderman Differential Revision: https://reviews.llvm.org/D99676
1 parent e470147 commit a879a1b

File tree

2 files changed

+42
-14
lines changed

2 files changed

+42
-14
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,13 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
115115
return rewriter.create<mlir::MulFOp>(loc, resultTypes, args);
116116
}
117117

118+
// tosa::ReciprocalOp
119+
if (isa<tosa::ReciprocalOp>(op) && elementTy.isa<FloatType>()) {
120+
auto one =
121+
rewriter.create<mlir::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
122+
return rewriter.create<mlir::DivFOp>(loc, resultTypes, one, args[0]);
123+
}
124+
118125
if (isa<tosa::MulOp>(op) && elementTy.isa<IntegerType>()) {
119126
Value a = args[0];
120127
Value b = args[1];
@@ -325,6 +332,16 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
325332
rewriter);
326333
}
327334

335+
// tosa::SigmoidOp
336+
if (isa<tosa::SigmoidOp>(op) && elementTy.isa<FloatType>()) {
337+
auto one =
338+
rewriter.create<mlir::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
339+
auto negate = rewriter.create<mlir::NegFOp>(loc, resultTypes, args[0]);
340+
auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
341+
auto added = rewriter.create<mlir::AddFOp>(loc, resultTypes, exp, one);
342+
return rewriter.create<mlir::DivFOp>(loc, resultTypes, one, added);
343+
}
344+
328345
// tosa::CastOp
329346
if (isa<tosa::CastOp>(op)) {
330347
Type srcTy = elementTy;
@@ -1382,11 +1399,11 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
13821399
RewritePatternSet *patterns) {
13831400
patterns->add<
13841401
PointwiseConverter<tosa::AddOp>, PointwiseConverter<tosa::SubOp>,
1385-
PointwiseConverter<tosa::MulOp>, PointwiseConverter<tosa::NegateOp>,
1386-
PointwiseConverter<tosa::PowOp>, PointwiseConverter<tosa::RsqrtOp>,
1387-
PointwiseConverter<tosa::LogOp>, PointwiseConverter<tosa::ExpOp>,
1388-
PointwiseConverter<tosa::AbsOp>, PointwiseConverter<tosa::TanhOp>,
1389-
PointwiseConverter<tosa::BitwiseAndOp>,
1402+
PointwiseConverter<tosa::MulOp>, PointwiseConverter<tosa::ReciprocalOp>,
1403+
PointwiseConverter<tosa::NegateOp>, PointwiseConverter<tosa::PowOp>,
1404+
PointwiseConverter<tosa::RsqrtOp>, PointwiseConverter<tosa::LogOp>,
1405+
PointwiseConverter<tosa::ExpOp>, PointwiseConverter<tosa::AbsOp>,
1406+
PointwiseConverter<tosa::TanhOp>, PointwiseConverter<tosa::BitwiseAndOp>,
13901407
PointwiseConverter<tosa::BitwiseOrOp>,
13911408
PointwiseConverter<tosa::BitwiseNotOp>,
13921409
PointwiseConverter<tosa::BitwiseXorOp>,
@@ -1401,11 +1418,11 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
14011418
PointwiseConverter<tosa::MaximumOp>, PointwiseConverter<tosa::MinimumOp>,
14021419
PointwiseConverter<tosa::CeilOp>, PointwiseConverter<tosa::FloorOp>,
14031420
PointwiseConverter<tosa::ClampOp>, PointwiseConverter<tosa::ReluNOp>,
1404-
IdentityNConverter<tosa::IdentityOp>,
1421+
PointwiseConverter<tosa::SigmoidOp>, IdentityNConverter<tosa::IdentityOp>,
14051422
IdentityNConverter<tosa::IdentityNOp>, ReduceConverter<tosa::ReduceMinOp>,
14061423
ReduceConverter<tosa::ReduceMaxOp>, ReduceConverter<tosa::ReduceSumOp>,
1407-
ReduceConverter<tosa::ReduceProdOp>, ArgMaxConverter, ConcatConverter, PadConverter,
1408-
ReshapeConverter, RescaleConverter, ReverseConverter, TileConverter,
1409-
TransposeConverter, MatMulConverter, FullyConnectedConverter>(
1410-
patterns->getContext());
1424+
ReduceConverter<tosa::ReduceProdOp>, ArgMaxConverter, ConcatConverter,
1425+
PadConverter, ReshapeConverter, RescaleConverter, ReverseConverter,
1426+
TileConverter, TransposeConverter, MatMulConverter,
1427+
FullyConnectedConverter>(patterns->getContext());
14111428
}

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -180,22 +180,33 @@ func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
180180
// CHECK: select
181181
%18 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>
182182

183+
// CHECK: linalg.generic
184+
// CHECK: negf
185+
// CHECK: exp
186+
// CHECK: addf
187+
// CHECK: divf
188+
%19 = "tosa.sigmoid"(%0) : (tensor<1xf32>) -> tensor<1xf32>
189+
183190
// CHECK: linalg.generic
184191
// CHECK: fptosi
185-
%19 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi32>
192+
%20 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi32>
186193

187194
// CHECK: linalg.generic
188195
// CHECK: constant 0
189196
// CHECK: cmpf
190-
%20 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi1>
197+
%21 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi1>
191198

192199
// CHECK: linalg.generic
193200
// CHECK: fptrunc
194-
%21 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf16>
201+
%22 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf16>
195202

196203
// CHECK: linalg.generic
197204
// CHECK: yield
198-
%22 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf32>
205+
%23 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf32>
206+
207+
// CHECK: linalg.generic
208+
// CHECK: divf
209+
%24 = "tosa.reciprocal"(%0) : (tensor<1xf32>) -> tensor<1xf32>
199210

200211
return
201212
}

0 commit comments

Comments
 (0)