[mlir] fix MemRefToLLVM lowering of atomic operations #139045
We have been confusingly, and arguably incorrectly, lowering `maximumf`/`minimumf` atomic RMW operations in the MemRef dialect to `fmax`/`fmin` atomic RMW operations in the LLVM dialect. The two have different NaN-propagation semantics: `maximumf`/`minimumf` propagate NaNs from either operand, whereas `fmax`/`fmin`, which lower to the `llvm.maxnum`/`llvm.minnum` intrinsics, return the non-NaN operand. This also contradicts the lowering of the `arith.maximumf`/`arith.minimumf` and `arith.maxnumf`/`arith.minnumf` operations.

Change the lowering to match the terminology in arith: `maximumf`/`minimumf` now lower to `fmaximum`/`fminimum`, and `maxnumf`/`minnumf` lower to `fmax`/`fmin`.

Add tests for these lowerings.

Keep a debug message in case of surprising behavior downstream (the code may be producing more NaNs now).
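To make the NaN-propagation difference concrete, here is a minimal standalone sketch in plain C++. It is not part of the patch: the helper names `maximumPropagatingNaN`, `maxnumIgnoringNaN`, and `checkNanSemantics` are hypothetical, and `std::fmax` is used only as a stand-in for the "return the non-NaN operand" behavior described above. Signed-zero ordering is deliberately ignored.

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// Hypothetical helper: NaN-propagating maximum, i.e. the behavior the
// description attributes to `maximumf`. Signed zeros are not distinguished.
float maximumPropagatingNaN(float a, float b) {
  if (std::isnan(a) || std::isnan(b))
    return std::numeric_limits<float>::quiet_NaN();
  return a > b ? a : b;
}

// Hypothetical helper: maxnum-style maximum. std::fmax returns the non-NaN
// operand when exactly one operand is NaN.
float maxnumIgnoringNaN(float a, float b) { return std::fmax(a, b); }

// Checks the two behaviors side by side.
void checkNanSemantics() {
  const float nan = std::numeric_limits<float>::quiet_NaN();
  // A NaN in either position poisons the NaN-propagating variant.
  assert(std::isnan(maximumPropagatingNaN(nan, 1.0f)));
  assert(std::isnan(maximumPropagatingNaN(1.0f, nan)));
  // The maxnum-style variant drops the NaN and keeps the number.
  assert(maxnumIgnoringNaN(nan, 1.0f) == 1.0f);
  assert(maxnumIgnoringNaN(1.0f, nan) == 1.0f);
  // Without NaNs the two variants agree.
  assert(maximumPropagatingNaN(2.0f, 1.0f) == maxnumIgnoringNaN(2.0f, 1.0f));
}
```

Under these assumptions, an atomic update that previously kept the stored number when the incoming value was NaN would now store NaN, which is why downstream code may observe more NaNs after this change.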
@llvm/pr-subscribers-mlir
Author: Oleksandr "Alex" Zinenko (ftynse)
Full diff: https://github.com/llvm/llvm-project/pull/139045.diff
2 Files Affected:
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index c8b2c0bdc6c20..0eade14ee89e5 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -28,6 +28,9 @@
#include "llvm/Support/MathExtras.h"
#include <optional>
+#define DEBUG_TYPE "memref-to-llvm"
+#define DBGS() llvm::dbgs() << "[" DEBUG_TYPE "] "
+
namespace mlir {
#define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
@@ -1773,12 +1776,22 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
case arith::AtomicRMWKind::assign:
return LLVM::AtomicBinOp::xchg;
case arith::AtomicRMWKind::maximumf:
+ // TODO: remove this by end of 2025.
+ LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw maximumf changed "
+ "from fmax to fmaximum, expect more NaNs\n");
+ return LLVM::AtomicBinOp::fmaximum;
+ case arith::AtomicRMWKind::maxnumf:
return LLVM::AtomicBinOp::fmax;
case arith::AtomicRMWKind::maxs:
return LLVM::AtomicBinOp::max;
case arith::AtomicRMWKind::maxu:
return LLVM::AtomicBinOp::umax;
case arith::AtomicRMWKind::minimumf:
+ // TODO: remove this by end of 2025.
+ LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw minimumf changed "
+ "from fmin to fminimum, expect more NaNs\n");
+ return LLVM::AtomicBinOp::fminimum;
+ case arith::AtomicRMWKind::minnumf:
return LLVM::AtomicBinOp::fmin;
case arith::AtomicRMWKind::mins:
return LLVM::AtomicBinOp::min;
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 5538ddf8e4c3c..a986a39fc1e92 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -427,11 +427,19 @@ func.func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fv
// CHECK: llvm.atomicrmw umin %{{.*}}, %{{.*}} acq_rel
memref.atomic_rmw addf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
// CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} acq_rel
+ memref.atomic_rmw maximumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
+ // CHECK: llvm.atomicrmw fmaximum %{{.*}}, %{{.*}} acq_rel
+ memref.atomic_rmw maxnumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
+ // CHECK: llvm.atomicrmw fmax %{{.*}}, %{{.*}} acq_rel
+ memref.atomic_rmw minimumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
+ // CHECK: llvm.atomicrmw fminimum %{{.*}}, %{{.*}} acq_rel
+ memref.atomic_rmw minnumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
+ // CHECK: llvm.atomicrmw fmin %{{.*}}, %{{.*}} acq_rel
memref.atomic_rmw ori %ival, %I[%i] : (i32, memref<10xi32>) -> i32
// CHECK: llvm.atomicrmw _or %{{.*}}, %{{.*}} acq_rel
memref.atomic_rmw andi %ival, %I[%i] : (i32, memref<10xi32>) -> i32
// CHECK: llvm.atomicrmw _and %{{.*}}, %{{.*}} acq_rel
- // CHECK-INTERFACE-COUNT-9: llvm.atomicrmw
+ // CHECK-INTERFACE-COUNT-13: llvm.atomicrmw
return
}
LGTM.
Is it worth adding a release note entry too?
Agreed that we should flag this in a release note, but otherwise seems fine to me.