Skip to content

Conversation

Lewuathe
Copy link
Member

@llvmbot llvmbot added the mlir label Sep 21, 2023
@llvmbot
Copy link
Member

llvmbot commented Sep 21, 2023

@llvm/pr-subscribers-mlir

Changes

See:
https://discourse.llvm.org/t/rfc-fastmath-flags-support-in-complex-dialect/71981


Full diff: https://github.com/llvm/llvm-project/pull/67001.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp (+11-7)
  • (modified) mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir (+41)
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 28e490da330f3c3..174b7ce9fed2df4 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -446,16 +446,19 @@ struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
     auto loc = op.getLoc();
     auto type = cast<ComplexType>(adaptor.getComplex().getType());
     auto elementType = cast<FloatType>(type.getElementType());
+    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
 
     Value real =
         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
     Value imag =
         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
-    Value expReal = rewriter.create<math::ExpOp>(loc, real);
-    Value cosImag = rewriter.create<math::CosOp>(loc, imag);
-    Value resultReal = rewriter.create<arith::MulFOp>(loc, expReal, cosImag);
-    Value sinImag = rewriter.create<math::SinOp>(loc, imag);
-    Value resultImag = rewriter.create<arith::MulFOp>(loc, expReal, sinImag);
+    Value expReal = rewriter.create<math::ExpOp>(loc, real, fmf.getValue());
+    Value cosImag = rewriter.create<math::CosOp>(loc, imag, fmf.getValue());
+    Value resultReal =
+        rewriter.create<arith::MulFOp>(loc, expReal, cosImag, fmf.getValue());
+    Value sinImag = rewriter.create<math::SinOp>(loc, imag, fmf.getValue());
+    Value resultImag =
+        rewriter.create<arith::MulFOp>(loc, expReal, sinImag, fmf.getValue());
 
     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
                                                    resultImag);
@@ -471,14 +474,15 @@ struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
                   ConversionPatternRewriter &rewriter) const override {
     auto type = cast<ComplexType>(adaptor.getComplex().getType());
     auto elementType = cast<FloatType>(type.getElementType());
+    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
 
     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    Value exp = b.create<complex::ExpOp>(adaptor.getComplex());
+    Value exp = b.create<complex::ExpOp>(adaptor.getComplex(), fmf.getValue());
 
     Value real = b.create<complex::ReOp>(elementType, exp);
     Value one = b.create<arith::ConstantOp>(elementType,
                                             b.getFloatAttr(elementType, 1));
-    Value realMinusOne = b.create<arith::SubFOp>(real, one);
+    Value realMinusOne = b.create<arith::SubFOp>(real, one, fmf.getValue());
     Value imag = b.create<complex::ImOp>(elementType, exp);
 
     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realMinusOne,
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 9b2eef82541952a..8264382a02651c2 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -757,3 +757,44 @@ func.func @complex_sub_with_fmf(%lhs: complex<f32>, %rhs: complex<f32>) -> compl
 // CHECK: %[[RESULT_IMAG:.*]] = arith.subf %[[IMAG_LHS]], %[[IMAG_RHS]] fastmath<nnan,contract> : f32
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
+
+// -----
+
+// CHECK-LABEL: func @complex_exp_with_fmf
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_exp_with_fmf(%arg: complex<f32>) -> complex<f32> {
+  %exp = complex.exp %arg fastmath<nnan,contract> : complex<f32>
+  return %exp : complex<f32>
+}
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK-DAG: %[[COS_IMAG:.*]] = math.cos %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[EXP_REAL:.*]] = math.exp %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[RESULT_REAL:.]] = arith.mulf %[[EXP_REAL]], %[[COS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[SIN_IMAG:.*]] = math.sin %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[RESULT_IMAG:.*]] = arith.mulf %[[EXP_REAL]], %[[SIN_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>
+
+// -----
+
+// CHECK-LABEL:   func.func @complex_expm1_with_fmf(
+// CHECK-SAME:                             %[[ARG:.*]]: complex<f32>) -> complex<f32> {
+func.func @complex_expm1_with_fmf(%arg: complex<f32>) -> complex<f32> {
+  %expm1 = complex.expm1 %arg fastmath<nnan,contract> : complex<f32>
+  return %expm1 : complex<f32>
+}
+// CHECK: %[[REAL_I:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG_I:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[EXP:.*]] = math.exp %[[REAL_I]] fastmath<nnan,contract> : f32
+// CHECK: %[[COS:.*]] = math.cos %[[IMAG_I]] fastmath<nnan,contract> : f32
+// CHECK: %[[RES_REAL:.*]] = arith.mulf %[[EXP]], %[[COS]] fastmath<nnan,contract> : f32
+// CHECK: %[[SIN:.*]] = math.sin %[[IMAG_I]] fastmath<nnan,contract> : f32
+// CHECK: %[[RES_IMAG:.*]] = arith.mulf %[[EXP]], %[[SIN]] fastmath<nnan,contract> : f32
+// CHECK: %[[RES_EXP:.*]] = complex.create %[[RES_REAL]], %[[RES_IMAG]] : complex<f32>
+// CHECK: %[[REAL:.*]] = complex.re %[[RES_EXP]] : complex<f32>
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[REAL_M1:.*]] = arith.subf %[[REAL]], %[[ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG:.*]] = complex.im %[[RES_EXP]] : complex<f32>
+// CHECK: %[[RES:.*]] = complex.create %[[REAL_M1]], %[[IMAG]] : complex<f32>
+// CHECK: return %[[RES]] : complex<f32>
\ No newline at end of file

@Lewuathe Lewuathe requested review from bixia1 and tpopp September 22, 2023 04:18
@tpopp
Copy link
Contributor

tpopp commented Sep 22, 2023

I don't feel appropriate to review this as I haven't been involved with anything MLIR for ~4 months. I'll remove myself, but you can add me again and I'll review it if you can't find anyone else.

@tpopp tpopp removed their request for review September 22, 2023 07:03
@Lewuathe
Copy link
Member Author

@tpopp Sorry for bothering you with the review request. I've got the review from @joker-eph this time. Thank!

@Lewuathe Lewuathe merged commit d230bf3 into llvm:main Sep 23, 2023
@Lewuathe Lewuathe deleted the fastmath-for-complex-exp branch September 23, 2023 01:27
Guzhu-AMD pushed a commit to GPUOpen-Drivers/llvm-project that referenced this pull request Sep 28, 2023
Local branch amd-gfx da5edb4 Merged main:12ee3a6f53db into amd-gfx:b62921ced91f
Remote branch main d230bf3 [mlir][complex] Support Fastmath flag in the conversion of exp,expm1 (llvm#67001)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants