-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][complex] Support Fastmath flag in the conversion of exp,expm1 #67001
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir ChangesSee: Full diff: https://github.com/llvm/llvm-project/pull/67001.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 28e490da330f3c3..174b7ce9fed2df4 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -446,16 +446,19 @@ struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
auto loc = op.getLoc();
auto type = cast<ComplexType>(adaptor.getComplex().getType());
auto elementType = cast<FloatType>(type.getElementType());
+ arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
Value real =
rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
Value imag =
rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
- Value expReal = rewriter.create<math::ExpOp>(loc, real);
- Value cosImag = rewriter.create<math::CosOp>(loc, imag);
- Value resultReal = rewriter.create<arith::MulFOp>(loc, expReal, cosImag);
- Value sinImag = rewriter.create<math::SinOp>(loc, imag);
- Value resultImag = rewriter.create<arith::MulFOp>(loc, expReal, sinImag);
+ Value expReal = rewriter.create<math::ExpOp>(loc, real, fmf.getValue());
+ Value cosImag = rewriter.create<math::CosOp>(loc, imag, fmf.getValue());
+ Value resultReal =
+ rewriter.create<arith::MulFOp>(loc, expReal, cosImag, fmf.getValue());
+ Value sinImag = rewriter.create<math::SinOp>(loc, imag, fmf.getValue());
+ Value resultImag =
+ rewriter.create<arith::MulFOp>(loc, expReal, sinImag, fmf.getValue());
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
resultImag);
@@ -471,14 +474,15 @@ struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
ConversionPatternRewriter &rewriter) const override {
auto type = cast<ComplexType>(adaptor.getComplex().getType());
auto elementType = cast<FloatType>(type.getElementType());
+ arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- Value exp = b.create<complex::ExpOp>(adaptor.getComplex());
+ Value exp = b.create<complex::ExpOp>(adaptor.getComplex(), fmf.getValue());
Value real = b.create<complex::ReOp>(elementType, exp);
Value one = b.create<arith::ConstantOp>(elementType,
b.getFloatAttr(elementType, 1));
- Value realMinusOne = b.create<arith::SubFOp>(real, one);
+ Value realMinusOne = b.create<arith::SubFOp>(real, one, fmf.getValue());
Value imag = b.create<complex::ImOp>(elementType, exp);
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realMinusOne,
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 9b2eef82541952a..8264382a02651c2 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -757,3 +757,44 @@ func.func @complex_sub_with_fmf(%lhs: complex<f32>, %rhs: complex<f32>) -> compl
// CHECK: %[[RESULT_IMAG:.*]] = arith.subf %[[IMAG_LHS]], %[[IMAG_RHS]] fastmath<nnan,contract> : f32
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
+
+// -----
+
+// CHECK-LABEL: func @complex_exp_with_fmf
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_exp_with_fmf(%arg: complex<f32>) -> complex<f32> {
+ %exp = complex.exp %arg fastmath<nnan,contract> : complex<f32>
+ return %exp : complex<f32>
+}
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK-DAG: %[[COS_IMAG:.*]] = math.cos %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[EXP_REAL:.*]] = math.exp %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[RESULT_REAL:.]] = arith.mulf %[[EXP_REAL]], %[[COS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[SIN_IMAG:.*]] = math.sin %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[RESULT_IMAG:.*]] = arith.mulf %[[EXP_REAL]], %[[SIN_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>
+
+// -----
+
+// CHECK-LABEL: func.func @complex_expm1_with_fmf(
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>) -> complex<f32> {
+func.func @complex_expm1_with_fmf(%arg: complex<f32>) -> complex<f32> {
+ %expm1 = complex.expm1 %arg fastmath<nnan,contract> : complex<f32>
+ return %expm1 : complex<f32>
+}
+// CHECK: %[[REAL_I:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG_I:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[EXP:.*]] = math.exp %[[REAL_I]] fastmath<nnan,contract> : f32
+// CHECK: %[[COS:.*]] = math.cos %[[IMAG_I]] fastmath<nnan,contract> : f32
+// CHECK: %[[RES_REAL:.*]] = arith.mulf %[[EXP]], %[[COS]] fastmath<nnan,contract> : f32
+// CHECK: %[[SIN:.*]] = math.sin %[[IMAG_I]] fastmath<nnan,contract> : f32
+// CHECK: %[[RES_IMAG:.*]] = arith.mulf %[[EXP]], %[[SIN]] fastmath<nnan,contract> : f32
+// CHECK: %[[RES_EXP:.*]] = complex.create %[[RES_REAL]], %[[RES_IMAG]] : complex<f32>
+// CHECK: %[[REAL:.*]] = complex.re %[[RES_EXP]] : complex<f32>
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[REAL_M1:.*]] = arith.subf %[[REAL]], %[[ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG:.*]] = complex.im %[[RES_EXP]] : complex<f32>
+// CHECK: %[[RES:.*]] = complex.create %[[REAL_M1]], %[[IMAG]] : complex<f32>
+// CHECK: return %[[RES]] : complex<f32>
\ No newline at end of file
|
I don't feel appropriate to review this as I haven't been involved with anything MLIR for ~4 months. I'll remove myself, but you can add me again and I'll review it if you can't find anyone else. |
@tpopp Sorry for bothering you with the review request. I've got the review from @joker-eph this time. Thank! |
Local branch amd-gfx da5edb4 Merged main:12ee3a6f53db into amd-gfx:b62921ced91f Remote branch main d230bf3 [mlir][complex] Support Fastmath flag in the conversion of exp,expm1 (llvm#67001)
See:
https://discourse.llvm.org/t/rfc-fastmath-flags-support-in-complex-dialect/71981