From cc9dfde0466a2ac7cdb23520200d62e5d6532a92 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Thu, 8 May 2025 10:03:37 +0200 Subject: [PATCH] [mlir] fix MemRefToLLVM lowering of atomic operations We have been confusingly, and arguably incorrectly, lowering `m**imumf` atomic RMW operations in the MemRef dialect to `fm**` atomic RMW operations in the LLVM dialect, which have different NaN-propagation semantics: `m**imumf` propagates NaNs from either operand whereas `fm**`, which lowers to the `fm**num` intrinsic, returns the non-NaN operand. This also contradicts the lowering of `arith.m**imumf` and `arith.m**numf` operations. Change the lowering to match the terminology in arith. Add tests for these lowerings. Keep a debug message in case of surprising behavior downstream (the code may be producing more NaNs now). --- mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 13 +++++++++++++ .../Conversion/MemRefToLLVM/memref-to-llvm.mlir | 10 +++++++++- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index c8b2c0bdc6c20..0eade14ee89e5 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -28,6 +28,9 @@ #include "llvm/Support/MathExtras.h" #include +#define DEBUG_TYPE "memref-to-llvm" +#define DBGS() llvm::dbgs() << "[" DEBUG_TYPE "] " + namespace mlir { #define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS #include "mlir/Conversion/Passes.h.inc" @@ -1773,12 +1776,22 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) { case arith::AtomicRMWKind::assign: return LLVM::AtomicBinOp::xchg; case arith::AtomicRMWKind::maximumf: + // TODO: remove this by end of 2025. 
+ LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw maximumf changed " + "from fmax to fmaximum, expect more NaNs"); + return LLVM::AtomicBinOp::fmaximum; + case arith::AtomicRMWKind::maxnumf: return LLVM::AtomicBinOp::fmax; case arith::AtomicRMWKind::maxs: return LLVM::AtomicBinOp::max; case arith::AtomicRMWKind::maxu: return LLVM::AtomicBinOp::umax; case arith::AtomicRMWKind::minimumf: + // TODO: remove this by end of 2025. + LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw minimum changed " + "from fmin to fminimum, expect more NaNs"); + return LLVM::AtomicBinOp::fminimum; + case arith::AtomicRMWKind::minnumf: return LLVM::AtomicBinOp::fmin; case arith::AtomicRMWKind::mins: return LLVM::AtomicBinOp::min; diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir index 5538ddf8e4c3c..a986a39fc1e92 100644 --- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir @@ -427,11 +427,19 @@ func.func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fv // CHECK: llvm.atomicrmw umin %{{.*}}, %{{.*}} acq_rel memref.atomic_rmw addf %fval, %F[%i] : (f32, memref<10xf32>) -> f32 // CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} acq_rel + memref.atomic_rmw maximumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32 + // CHECK: llvm.atomicrmw fmaximum %{{.*}}, %{{.*}} acq_rel + memref.atomic_rmw maxnumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32 + // CHECK: llvm.atomicrmw fmax %{{.*}}, %{{.*}} acq_rel + memref.atomic_rmw minimumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32 + // CHECK: llvm.atomicrmw fminimum %{{.*}}, %{{.*}} acq_rel + memref.atomic_rmw minnumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32 + // CHECK: llvm.atomicrmw fmin %{{.*}}, %{{.*}} acq_rel memref.atomic_rmw ori %ival, %I[%i] : (i32, memref<10xi32>) -> i32 // CHECK: llvm.atomicrmw _or %{{.*}}, %{{.*}} acq_rel memref.atomic_rmw andi %ival, %I[%i] : (i32, 
memref<10xi32>) -> i32 // CHECK: llvm.atomicrmw _and %{{.*}}, %{{.*}} acq_rel - // CHECK-INTERFACE-COUNT-9: llvm.atomicrmw + // CHECK-INTERFACE-COUNT-13: llvm.atomicrmw return }