diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp index 2bcec4ea10f92..28e490da330f3 100644 --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -137,15 +137,16 @@ struct BinaryComplexOpConversion : public OpConversionPattern { auto type = cast(adaptor.getLhs().getType()); auto elementType = cast(type.getElementType()); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); + arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); Value realLhs = b.create(elementType, adaptor.getLhs()); Value realRhs = b.create(elementType, adaptor.getRhs()); - Value resultReal = - b.create(elementType, realLhs, realRhs); + Value resultReal = b.create(elementType, realLhs, realRhs, + fmf.getValue()); Value imagLhs = b.create(elementType, adaptor.getLhs()); Value imagRhs = b.create(elementType, adaptor.getRhs()); - Value resultImag = - b.create(elementType, imagLhs, imagRhs); + Value resultImag = b.create(elementType, imagLhs, imagRhs, + fmf.getValue()); rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag); return success(); diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir index bc2ea0dd7a584..9b2eef8254195 100644 --- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir +++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir @@ -723,3 +723,37 @@ func.func @complex_abs_with_fmf(%arg: complex) -> f32 { // CHECK: %[[SQ_NORM:.*]] = arith.addf %[[REAL_SQ]], %[[IMAG_SQ]] fastmath : f32 // CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32 // CHECK: return %[[NORM]] : f32 + +// ----- + +// CHECK-LABEL: func @complex_add_with_fmf +// CHECK-SAME: (%[[LHS:.*]]: complex, %[[RHS:.*]]: complex) +func.func @complex_add_with_fmf(%lhs: complex, %rhs: complex) -> complex { + %add = complex.add %lhs, %rhs fastmath : complex + return %add : complex +} +// CHECK: %[[REAL_LHS:.*]] = complex.re %[[LHS]] : complex +// CHECK: %[[REAL_RHS:.*]] = complex.re %[[RHS]] : complex +// CHECK: %[[RESULT_REAL:.*]] = arith.addf %[[REAL_LHS]], %[[REAL_RHS]] fastmath : f32 +// CHECK: %[[IMAG_LHS:.*]] = complex.im %[[LHS]] : complex +// CHECK: %[[IMAG_RHS:.*]] = complex.im %[[RHS]] : complex +// CHECK: %[[RESULT_IMAG:.*]] = arith.addf %[[IMAG_LHS]], %[[IMAG_RHS]] fastmath : f32 +// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex +// CHECK: return %[[RESULT]] : complex + +// ----- + +// CHECK-LABEL: func @complex_sub_with_fmf +// CHECK-SAME: (%[[LHS:.*]]: complex, %[[RHS:.*]]: complex) +func.func @complex_sub_with_fmf(%lhs: complex, %rhs: complex) -> complex { + %sub = complex.sub %lhs, %rhs fastmath : complex + return %sub : complex +} +// CHECK: %[[REAL_LHS:.*]] = complex.re %[[LHS]] : complex +// CHECK: %[[REAL_RHS:.*]] = complex.re %[[RHS]] : complex +// CHECK: %[[RESULT_REAL:.*]] = arith.subf %[[REAL_LHS]], %[[REAL_RHS]] fastmath : f32 +// CHECK: %[[IMAG_LHS:.*]] = complex.im %[[LHS]] : complex +// CHECK: %[[IMAG_RHS:.*]] = complex.im %[[RHS]] : complex +// CHECK: %[[RESULT_IMAG:.*]] = arith.subf %[[IMAG_LHS]], %[[IMAG_RHS]] fastmath : f32 +// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex +// CHECK: return %[[RESULT]] : complex