From c471f8fb90b7493f9da718d235eca8ada8498ccb Mon Sep 17 00:00:00 2001 From: Kai Sasaki Date: Fri, 8 Sep 2023 09:41:01 +0900 Subject: [PATCH] [mlir][complex] Support fastmath in the binary op conversion. Complex dialect arithmetic operations are now able to recognize the given fastmath flags. This PR lets the conversion from complex to standard keep the fastmath flag passed to arith dialect ops. See: https://discourse.llvm.org/t/rfc-fastmath-flags-support-in-complex-dialect/71981 --- .../ComplexToStandard/ComplexToStandard.cpp | 9 ++--- .../convert-to-standard.mlir | 34 +++++++++++++++++++ 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp index 2bcec4ea10f92..28e490da330f3 100644 --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -137,15 +137,16 @@ struct BinaryComplexOpConversion : public OpConversionPattern { auto type = cast(adaptor.getLhs().getType()); auto elementType = cast(type.getElementType()); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); + arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); Value realLhs = b.create(elementType, adaptor.getLhs()); Value realRhs = b.create(elementType, adaptor.getRhs()); - Value resultReal = - b.create(elementType, realLhs, realRhs); + Value resultReal = b.create(elementType, realLhs, realRhs, + fmf.getValue()); Value imagLhs = b.create(elementType, adaptor.getLhs()); Value imagRhs = b.create(elementType, adaptor.getRhs()); - Value resultImag = - b.create(elementType, imagLhs, imagRhs); + Value resultImag = b.create(elementType, imagLhs, imagRhs, + fmf.getValue()); rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag); return success(); diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir index bc2ea0dd7a584..9b2eef8254195 100644 --- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir +++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir @@ -723,3 +723,37 @@ func.func @complex_abs_with_fmf(%arg: complex) -> f32 { // CHECK: %[[SQ_NORM:.*]] = arith.addf %[[REAL_SQ]], %[[IMAG_SQ]] fastmath : f32 // CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32 // CHECK: return %[[NORM]] : f32 + +// ----- + +// CHECK-LABEL: func @complex_add_with_fmf +// CHECK-SAME: (%[[LHS:.*]]: complex, %[[RHS:.*]]: complex) +func.func @complex_add_with_fmf(%lhs: complex, %rhs: complex) -> complex { + %add = complex.add %lhs, %rhs fastmath : complex + return %add : complex +} +// CHECK: %[[REAL_LHS:.*]] = complex.re %[[LHS]] : complex +// CHECK: %[[REAL_RHS:.*]] = complex.re %[[RHS]] : complex +// CHECK: %[[RESULT_REAL:.*]] = arith.addf %[[REAL_LHS]], %[[REAL_RHS]] fastmath : f32 +// CHECK: %[[IMAG_LHS:.*]] = complex.im %[[LHS]] : complex +// CHECK: %[[IMAG_RHS:.*]] = complex.im %[[RHS]] : complex +// CHECK: %[[RESULT_IMAG:.*]] = arith.addf %[[IMAG_LHS]], %[[IMAG_RHS]] fastmath : f32 +// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex +// CHECK: return %[[RESULT]] : complex + +// ----- + +// CHECK-LABEL: func @complex_sub_with_fmf +// CHECK-SAME: (%[[LHS:.*]]: complex, %[[RHS:.*]]: complex) +func.func @complex_sub_with_fmf(%lhs: complex, %rhs: complex) -> complex { + %sub = complex.sub %lhs, %rhs fastmath : complex + return %sub : complex +} +// CHECK: %[[REAL_LHS:.*]] = complex.re %[[LHS]] : complex +// CHECK: %[[REAL_RHS:.*]] = complex.re %[[RHS]] : complex +// CHECK: %[[RESULT_REAL:.*]] = arith.subf %[[REAL_LHS]], %[[REAL_RHS]] fastmath : f32 +// CHECK: %[[IMAG_LHS:.*]] = complex.im %[[LHS]] : complex +// CHECK: %[[IMAG_RHS:.*]] = complex.im %[[RHS]] : complex +// CHECK: %[[RESULT_IMAG:.*]] = arith.subf %[[IMAG_LHS]], %[[IMAG_RHS]] fastmath : f32 +// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex +// CHECK: return %[[RESULT]] : complex