-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][ArmSME] Add arith-to-arm-sme conversion pass #78197
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
The existing 'arith::ConstantOp' conversion and tests are moved from VectorToArmSME. Only a single op is converted at the moment, but this will grow in the future as things like in-tile add are implemented. Also, 'createLoopOverTileSlices' is moved to ArmSME utils since it's relevant for both conversions.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg Author: Cullen Rhodes (c-rhodes) Changes: The existing 'arith::ConstantOp' conversion and tests are moved from VectorToArmSME. Only a single op is converted at the moment, but this will grow in the future as things like in-tile add are implemented. Also, 'createLoopOverTileSlices' is moved to ArmSME utils since it's relevant for both conversions. Patch is 27.35 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/78197.diff 18 Files Affected:
diff --git a/mlir/include/mlir/Conversion/ArithToArmSME/ArithToArmSME.h b/mlir/include/mlir/Conversion/ArithToArmSME/ArithToArmSME.h
new file mode 100644
index 00000000000000..012e7fb5b0af2f
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArithToArmSME/ArithToArmSME.h
@@ -0,0 +1,27 @@
+//===- ArithToArmSME.h - Arith to ArmSME dialect conversion -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARITHTOARMSME_ARITHTOARMSME_H
+#define MLIR_CONVERSION_ARITHTOARMSME_ARITHTOARMSME_H
+
+#include <memory>
+
+namespace mlir {
+
+class RewritePatternSet;
+class Pass;
+
+#define GEN_PASS_DECL_ARITHTOARMSMECONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace arith {
+void populateArithToArmSMEConversionPatterns(RewritePatternSet &patterns);
+} // namespace arith
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARITHTOARMSME_ARITHTOARMSME_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index a25fd17ea923fb..0bfc5064c5dd72 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -12,6 +12,7 @@
#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
+#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
#include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 71be8841ca7c03..3467e042c493e9 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -164,6 +164,15 @@ def ConvertArithToSPIRV : Pass<"convert-arith-to-spirv"> {
];
}
+//===----------------------------------------------------------------------===//
+// ArithToArmSME
+//===----------------------------------------------------------------------===//
+
+def ArithToArmSMEConversionPass : Pass<"convert-arith-to-arm-sme"> {
+ let summary = "Convert Arith dialect to ArmSME dialect";
+ let dependentDialects = ["arm_sme::ArmSMEDialect"];
+}
+
//===----------------------------------------------------------------------===//
// ArmNeon2dToIntr
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index b7d90195d49d76..a15eac7302077b 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -16,9 +16,16 @@
#define MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinTypes.h"
#include <optional>
+namespace mlir {
+class Location;
+class PatternRewriter;
+class Value;
+} // namespace mlir
+
namespace mlir::arm_sme {
constexpr unsigned MinStreamingVectorLengthInBits = 128;
@@ -42,6 +49,16 @@ std::optional<ArmSMETileType> getSMETileType(VectorType);
/// Verifies the tile ID (if set) on this tile operation is valid.
LogicalResult verifyOperationHasValidTileId(Operation *);
+using LoopBodyBuilder =
+ std::function<void(OpBuilder &, Location, Value, Value)>;
+
+/// Generates a for loop over ZA tile slices where the induction variable is
+/// the tile slice index and each iteration yields a new tile. The loop body is
+/// built by LoopBodyBuilder, which must emit the scf.yield of the next tile.
+scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc,
+ Value initTile,
+ LoopBodyBuilder bodyBuilder);
+
} // namespace mlir::arm_sme
#endif // MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
diff --git a/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp b/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp
new file mode 100644
index 00000000000000..9aab969881f75e
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp
@@ -0,0 +1,127 @@
+//===- ArithToArmSME.cpp - Arith to ArmSME dialect conversion -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Utils/Utils.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_ARITHTOARMSMECONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "arith-to-arm-sme"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Conversion helpers
+//===----------------------------------------------------------------------===//
+
+/// Returns true if 'val' is a splat of zero, false otherwise.
+static bool isSplatZero(Type elemType, DenseElementsAttr val) {
+ if (llvm::isa<FloatType>(elemType))
+ return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
+ if (llvm::isa<IntegerType>(elemType))
+ return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
+ return false;
+}
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// ConstantOp
+//===----------------------------------------------------------------------===//
+
+/// Conversion pattern for dense arith.constant.
+struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
+ using OpRewritePattern<arith::ConstantOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(arith::ConstantOp constantOp,
+ PatternRewriter &rewriter) const final {
+ auto tileType = dyn_cast<VectorType>(constantOp.getType());
+ if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
+ return failure();
+
+ auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
+ if (!denseAttr || !denseAttr.isSplat())
+ return failure();
+
+ auto tileElementType = tileType.getElementType();
+
+ // Lower 'arith.constant dense<0>' to 'arm_sme.zero' op.
+ if (isSplatZero(tileElementType, denseAttr)) {
+ rewriter.replaceOpWithNewOp<arm_sme::ZeroOp>(constantOp, tileType);
+ return success();
+ }
+
+ // Lower non-zero constants to a loop of 'arm_sme.move_vector_to_tile_slice'
+ // ops that broadcast the constant to each tile slice.
+ auto loc = constantOp.getLoc();
+
+ // To fill a tile with a constant, we create a 1-D splat of the constant,
+ // then move that into each tile slice (the largest unit we can set at once,
+ // outside of operations like the outerproduct).
+ VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
+ auto denseAttr1D = DenseElementsAttr::get(
+ tileSliceType, denseAttr.getSplatValue<Attribute>());
+ auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
+
+ auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+ arm_sme::LoopBodyBuilder loopBody = [&](OpBuilder &b, Location loc,
+ Value tileSliceIndex,
+ Value currentTile) {
+ // Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile
+ // slice.
+ auto nextTile = b.create<arm_sme::MoveVectorToTileSliceOp>(
+ loc, tileType, constantOp1D, currentTile, tileSliceIndex);
+ b.create<scf::YieldOp>(loc, nextTile.getResult());
+ return;
+ };
+ auto forOp = mlir::arm_sme::createLoopOverTileSlices(rewriter, loc,
+ initTile, loopBody);
+ rewriter.replaceOp(constantOp, forOp.getResult(0));
+
+ return success();
+ }
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pattern population
+//===----------------------------------------------------------------------===//
+
+void mlir::arith::populateArithToArmSMEConversionPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<ConstantOpToArmSMELowering>(patterns.getContext());
+}
+
+//===----------------------------------------------------------------------===//
+// Pass definition
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct ArithToArmSMEConversionPass final
+ : impl::ArithToArmSMEConversionPassBase<ArithToArmSMEConversionPass> {
+ using impl::ArithToArmSMEConversionPassBase<
+ ArithToArmSMEConversionPass>::ArithToArmSMEConversionPassBase;
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ arith::populateArithToArmSMEConversionPatterns(patterns);
+ if (failed(
+ applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+ return signalPassFailure();
+ }
+};
+} // namespace
diff --git a/mlir/lib/Conversion/ArithToArmSME/CMakeLists.txt b/mlir/lib/Conversion/ArithToArmSME/CMakeLists.txt
new file mode 100644
index 00000000000000..c2a6fe5398e7c8
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToArmSME/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_conversion_library(MLIRArithToArmSME
+ ArithToArmSME.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToArmSME
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRArmSMEDialect
+ MLIRArithDialect
+ MLIRPass
+ MLIRTransforms
+ )
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index c3a2481975040c..3a5dbc12c23f5c 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -2,6 +2,7 @@ add_subdirectory(AffineToStandard)
add_subdirectory(AMDGPUToROCDL)
add_subdirectory(ArithCommon)
add_subdirectory(ArithToAMDGPU)
+add_subdirectory(ArithToArmSME)
add_subdirectory(ArithToLLVM)
add_subdirectory(ArithToSPIRV)
add_subdirectory(ArmNeon2dToIntr)
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 87d1bf9bed5a31..88252725bcff26 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -16,39 +16,6 @@
using namespace mlir;
-/// Returns true if 'val' is a splat of zero, false otherwise.
-static bool isSplatZero(Type elemType, DenseElementsAttr val) {
- if (llvm::isa<FloatType>(elemType))
- return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
- if (llvm::isa<IntegerType>(elemType))
- return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
- return false;
-}
-
-/// Generates a for loop over ZA tile slices where the induction variable is
-/// the tile slice index and each iteration yields a new tile. Loop body is
-/// built via the callback, which returns the next tile value.
-template <typename LoopBodyCallback>
-static scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter,
- Location loc, Value initTile,
- LoopBodyCallback callback) {
- OpBuilder::InsertionGuard g(rewriter);
- auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
- loc, llvm::cast<VectorType>(initTile.getType()).getDimSize(0));
- auto vscale =
- rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
- auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- auto numTileSlices =
- rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
- auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step,
- ValueRange{initTile});
- rewriter.setInsertionPointToStart(forOp.getBody());
- auto nextTile = callback(forOp);
- rewriter.create<scf::YieldOp>(loc, nextTile.getResult());
- return forOp;
-}
-
namespace {
/// Conversion pattern for vector.transfer_read.
@@ -223,56 +190,6 @@ struct VectorStoreToArmSMELowering : public OpRewritePattern<vector::StoreOp> {
}
};
-/// Conversion pattern for dense arith.constant.
-struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
- using OpRewritePattern<arith::ConstantOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(arith::ConstantOp constantOp,
- PatternRewriter &rewriter) const final {
- auto tileType = dyn_cast<VectorType>(constantOp.getType());
- if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
- return failure();
-
- auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
- if (!denseAttr || !denseAttr.isSplat())
- return failure();
-
- auto tileElementType = tileType.getElementType();
-
- // Lower 'arith.constant dense<0>' to 'arm_sme.zero' op.
- if (isSplatZero(tileElementType, denseAttr)) {
- rewriter.replaceOpWithNewOp<arm_sme::ZeroOp>(constantOp, tileType);
- return success();
- }
-
- // Lower non-zero constants to a loop of 'arm_sme.move_vector_to_tile_slice'
- // ops that broadcast the constant to each tile slice.
- auto loc = constantOp.getLoc();
-
- // To fill a tile with a constant, we create a 1-D splat of the constant,
- // then move that into each tile slice (the largest unit we can set at once,
- // outside of operations like the outerproduct).
- VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
- auto denseAttr1D = DenseElementsAttr::get(
- tileSliceType, denseAttr.getSplatValue<Attribute>());
- auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
-
- auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
- auto forOp =
- createLoopOverTileSlices(rewriter, loc, initTile, [&](auto forOp) {
- auto tileSliceIndex = forOp.getInductionVar();
- auto currentTile = forOp.getRegionIterArg(0);
- // Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile
- // slice.
- return rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
- loc, tileType, constantOp1D, currentTile, tileSliceIndex);
- });
- rewriter.replaceOp(constantOp, forOp.getResult(0));
-
- return success();
- }
-};
-
/// Conversion pattern for vector.broadcast.
///
/// Example:
@@ -322,16 +239,19 @@ struct BroadcastOpToArmSMELowering
auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+ arm_sme::LoopBodyBuilder loopBody = [&](OpBuilder &b, Location loc,
+ Value tileSliceIndex,
+ Value currentTile) {
+ // Create 'arm_sme.move_vector_to_tile_slice' to broadcast the value
+ // to each tile slice.
+ auto nextTile = b.create<arm_sme::MoveVectorToTileSliceOp>(
+ loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
+ b.create<scf::YieldOp>(loc, nextTile.getResult());
+ return;
+ };
+
// Create a loop over ZA tile slices.
- auto forOp =
- createLoopOverTileSlices(rewriter, loc, initTile, [&](auto forOp) {
- auto tileSliceIndex = forOp.getInductionVar();
- auto currentTile = forOp.getRegionIterArg(0);
- // Create 'arm_sme.move_vector_to_tile_slice' to broadcast the value
- // to each tile slice.
- return rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
- loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
- });
+ auto forOp = createLoopOverTileSlices(rewriter, loc, initTile, loopBody);
rewriter.replaceOp(broadcastOp, forOp.getResult(0));
@@ -381,15 +301,18 @@ struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+ arm_sme::LoopBodyBuilder loopBody = [&](OpBuilder &b, Location loc,
+ Value tileSliceIndex,
+ Value currentTile) {
+ auto nextTile = b.create<arm_sme::MoveVectorToTileSliceOp>(
+ loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
+ b.create<scf::YieldOp>(loc, nextTile.getResult());
+ return;
+ };
+
// Next, create a loop over ZA tile slices and "move" the generated 1-d
// vector to each slice.
- auto forOp =
- createLoopOverTileSlices(rewriter, loc, initTile, [&](auto forOp) {
- auto tileSliceIndex = forOp.getInductionVar();
- auto currentTile = forOp.getRegionIterArg(0);
- return rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
- loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
- });
+ auto forOp = createLoopOverTileSlices(rewriter, loc, initTile, loopBody);
rewriter.replaceOp(splatOp, forOp.getResult(0));
@@ -741,11 +664,10 @@ struct VectorPrintToArmSMELowering : public OpRewritePattern<vector::PrintOp> {
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
MLIRContext &ctx) {
- patterns
- .add<BroadcastOpToArmSMELowering, ConstantOpToArmSMELowering,
- SplatOpToArmSMELowering, TransferReadToArmSMELowering,
- TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
- VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
- VectorOuterProductToArmSMELowering, VectorExtractToArmSMELowering,
- VectorInsertToArmSMELowering, VectorPrintToArmSMELowering>(&ctx);
+ patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
+ TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
+ TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
+ VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
+ VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
+ VectorPrintToArmSMELowering>(&ctx);
}
diff --git a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
index 1fa060cafc0bc6..2e159abb1e89eb 100644
--- a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
@@ -72,4 +72,24 @@ LogicalResult verifyOperationHasValidTileId(Operation *op) {
return success();
}
+scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc,
+ Value initTile,
+ LoopBodyBuilder bodyBuilder) {
+ OpBuilder::InsertionGuard g(rewriter);
+ auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
+ loc, llvm::cast<VectorType>(initTile.getType()).getDimSize(0));
+ auto vscale =
+ rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+ auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto numTileSlices =
+ rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
+ auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step,
+ ValueRange{initTile});
+ rewriter.setInsertionPointToStart(forOp.getBody());
+ bodyBuilder(rewriter, loc, /*tileSliceIndex=*/forOp.getInductionVar(),
+ /*currentTile=*/forOp.getRegionIterArg(0));
+ return forOp;
+}
+
} // namespace mlir::arm_sme
diff --git a/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir b/mlir/test/Conversion/ArithToArmSME/arith-to-arm-sme.mlir
similarity index 97%
rename from mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
rename to mlir/test/Conversion/ArithToArmSME/arith-to-arm-sme.mlir
index e51f2485dadbcc..49d2e2f3c182b9 100644
--- a...
[truncated]
|
@llvm/pr-subscribers-mlir-sme Author: Cullen Rhodes (c-rhodes) Changes: The existing 'arith::ConstantOp' conversion and tests are moved from VectorToArmSME. Only a single op is converted at the moment, but this will grow in the future as things like in-tile add are implemented. Also, 'createLoopOverTileSlices' is moved to ArmSME utils since it's relevant for both conversions. Patch is 27.35 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/78197.diff 18 Files Affected:
diff --git a/mlir/include/mlir/Conversion/ArithToArmSME/ArithToArmSME.h b/mlir/include/mlir/Conversion/ArithToArmSME/ArithToArmSME.h
new file mode 100644
index 00000000000000..012e7fb5b0af2f
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArithToArmSME/ArithToArmSME.h
@@ -0,0 +1,27 @@
+//===- ArithToArmSME.h - Arith to ArmSME dialect conversion -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARITHTOARMSME_ARITHTOARMSME_H
+#define MLIR_CONVERSION_ARITHTOARMSME_ARITHTOARMSME_H
+
+#include <memory>
+
+namespace mlir {
+
+class RewritePatternSet;
+class Pass;
+
+#define GEN_PASS_DECL_ARITHTOARMSMECONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace arith {
+void populateArithToArmSMEConversionPatterns(RewritePatternSet &patterns);
+} // namespace arith
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARITHTOARMSME_ARITHTOARMSME_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index a25fd17ea923fb..0bfc5064c5dd72 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -12,6 +12,7 @@
#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
+#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
#include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 71be8841ca7c03..3467e042c493e9 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -164,6 +164,15 @@ def ConvertArithToSPIRV : Pass<"convert-arith-to-spirv"> {
];
}
+//===----------------------------------------------------------------------===//
+// ArithToArmSME
+//===----------------------------------------------------------------------===//
+
+def ArithToArmSMEConversionPass : Pass<"convert-arith-to-arm-sme"> {
+ let summary = "Convert Arith dialect to ArmSME dialect";
+ let dependentDialects = ["arm_sme::ArmSMEDialect"];
+}
+
//===----------------------------------------------------------------------===//
// ArmNeon2dToIntr
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index b7d90195d49d76..a15eac7302077b 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -16,9 +16,16 @@
#define MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinTypes.h"
#include <optional>
+namespace mlir {
+class Location;
+class PatternRewriter;
+class Value;
+} // namespace mlir
+
namespace mlir::arm_sme {
constexpr unsigned MinStreamingVectorLengthInBits = 128;
@@ -42,6 +49,16 @@ std::optional<ArmSMETileType> getSMETileType(VectorType);
/// Verifies the tile ID (if set) on this tile operation is valid.
LogicalResult verifyOperationHasValidTileId(Operation *);
+using LoopBodyBuilder =
+ std::function<void(OpBuilder &, Location, Value, Value)>;
+
+/// Generates a for loop over ZA tile slices where the induction variable is
+/// the tile slice index and each iteration yields a new tile. The loop body is
+/// built by LoopBodyBuilder, which must emit the scf.yield of the next tile.
+scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc,
+ Value initTile,
+ LoopBodyBuilder bodyBuilder);
+
} // namespace mlir::arm_sme
#endif // MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
diff --git a/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp b/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp
new file mode 100644
index 00000000000000..9aab969881f75e
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp
@@ -0,0 +1,127 @@
+//===- ArithToArmSME.cpp - Arith to ArmSME dialect conversion -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Utils/Utils.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_ARITHTOARMSMECONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "arith-to-arm-sme"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Conversion helpers
+//===----------------------------------------------------------------------===//
+
+/// Returns true if 'val' is a splat of zero, false otherwise.
+static bool isSplatZero(Type elemType, DenseElementsAttr val) {
+ if (llvm::isa<FloatType>(elemType))
+ return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
+ if (llvm::isa<IntegerType>(elemType))
+ return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
+ return false;
+}
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// ConstantOp
+//===----------------------------------------------------------------------===//
+
+/// Conversion pattern for dense arith.constant.
+struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
+ using OpRewritePattern<arith::ConstantOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(arith::ConstantOp constantOp,
+ PatternRewriter &rewriter) const final {
+ auto tileType = dyn_cast<VectorType>(constantOp.getType());
+ if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
+ return failure();
+
+ auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
+ if (!denseAttr || !denseAttr.isSplat())
+ return failure();
+
+ auto tileElementType = tileType.getElementType();
+
+ // Lower 'arith.constant dense<0>' to 'arm_sme.zero' op.
+ if (isSplatZero(tileElementType, denseAttr)) {
+ rewriter.replaceOpWithNewOp<arm_sme::ZeroOp>(constantOp, tileType);
+ return success();
+ }
+
+ // Lower non-zero constants to a loop of 'arm_sme.move_vector_to_tile_slice'
+ // ops that broadcast the constant to each tile slice.
+ auto loc = constantOp.getLoc();
+
+ // To fill a tile with a constant, we create a 1-D splat of the constant,
+ // then move that into each tile slice (the largest unit we can set at once,
+ // outside of operations like the outerproduct).
+ VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
+ auto denseAttr1D = DenseElementsAttr::get(
+ tileSliceType, denseAttr.getSplatValue<Attribute>());
+ auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
+
+ auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+ arm_sme::LoopBodyBuilder loopBody = [&](OpBuilder &b, Location loc,
+ Value tileSliceIndex,
+ Value currentTile) {
+ // Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile
+ // slice.
+ auto nextTile = b.create<arm_sme::MoveVectorToTileSliceOp>(
+ loc, tileType, constantOp1D, currentTile, tileSliceIndex);
+ b.create<scf::YieldOp>(loc, nextTile.getResult());
+ return;
+ };
+ auto forOp = mlir::arm_sme::createLoopOverTileSlices(rewriter, loc,
+ initTile, loopBody);
+ rewriter.replaceOp(constantOp, forOp.getResult(0));
+
+ return success();
+ }
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pattern population
+//===----------------------------------------------------------------------===//
+
+void mlir::arith::populateArithToArmSMEConversionPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<ConstantOpToArmSMELowering>(patterns.getContext());
+}
+
+//===----------------------------------------------------------------------===//
+// Pass definition
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct ArithToArmSMEConversionPass final
+ : impl::ArithToArmSMEConversionPassBase<ArithToArmSMEConversionPass> {
+ using impl::ArithToArmSMEConversionPassBase<
+ ArithToArmSMEConversionPass>::ArithToArmSMEConversionPassBase;
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ arith::populateArithToArmSMEConversionPatterns(patterns);
+ if (failed(
+ applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+ return signalPassFailure();
+ }
+};
+} // namespace
diff --git a/mlir/lib/Conversion/ArithToArmSME/CMakeLists.txt b/mlir/lib/Conversion/ArithToArmSME/CMakeLists.txt
new file mode 100644
index 00000000000000..c2a6fe5398e7c8
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToArmSME/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_conversion_library(MLIRArithToArmSME
+ ArithToArmSME.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToArmSME
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRArmSMEDialect
+ MLIRArithDialect
+ MLIRPass
+ MLIRTransforms
+ )
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index c3a2481975040c..3a5dbc12c23f5c 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -2,6 +2,7 @@ add_subdirectory(AffineToStandard)
add_subdirectory(AMDGPUToROCDL)
add_subdirectory(ArithCommon)
add_subdirectory(ArithToAMDGPU)
+add_subdirectory(ArithToArmSME)
add_subdirectory(ArithToLLVM)
add_subdirectory(ArithToSPIRV)
add_subdirectory(ArmNeon2dToIntr)
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 87d1bf9bed5a31..88252725bcff26 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -16,39 +16,6 @@
using namespace mlir;
-/// Returns true if 'val' is a splat of zero, false otherwise.
-static bool isSplatZero(Type elemType, DenseElementsAttr val) {
- if (llvm::isa<FloatType>(elemType))
- return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
- if (llvm::isa<IntegerType>(elemType))
- return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
- return false;
-}
-
-/// Generates a for loop over ZA tile slices where the induction variable is
-/// the tile slice index and each iteration yields a new tile. Loop body is
-/// built via the callback, which returns the next tile value.
-template <typename LoopBodyCallback>
-static scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter,
- Location loc, Value initTile,
- LoopBodyCallback callback) {
- OpBuilder::InsertionGuard g(rewriter);
- auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
- loc, llvm::cast<VectorType>(initTile.getType()).getDimSize(0));
- auto vscale =
- rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
- auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- auto numTileSlices =
- rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
- auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step,
- ValueRange{initTile});
- rewriter.setInsertionPointToStart(forOp.getBody());
- auto nextTile = callback(forOp);
- rewriter.create<scf::YieldOp>(loc, nextTile.getResult());
- return forOp;
-}
-
namespace {
/// Conversion pattern for vector.transfer_read.
@@ -223,56 +190,6 @@ struct VectorStoreToArmSMELowering : public OpRewritePattern<vector::StoreOp> {
}
};
-/// Conversion pattern for dense arith.constant.
-struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
- using OpRewritePattern<arith::ConstantOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(arith::ConstantOp constantOp,
- PatternRewriter &rewriter) const final {
- auto tileType = dyn_cast<VectorType>(constantOp.getType());
- if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
- return failure();
-
- auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
- if (!denseAttr || !denseAttr.isSplat())
- return failure();
-
- auto tileElementType = tileType.getElementType();
-
- // Lower 'arith.constant dense<0>' to 'arm_sme.zero' op.
- if (isSplatZero(tileElementType, denseAttr)) {
- rewriter.replaceOpWithNewOp<arm_sme::ZeroOp>(constantOp, tileType);
- return success();
- }
-
- // Lower non-zero constants to a loop of 'arm_sme.move_vector_to_tile_slice'
- // ops that broadcast the constant to each tile slice.
- auto loc = constantOp.getLoc();
-
- // To fill a tile with a constant, we create a 1-D splat of the constant,
- // then move that into each tile slice (the largest unit we can set at once,
- // outside of operations like the outerproduct).
- VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
- auto denseAttr1D = DenseElementsAttr::get(
- tileSliceType, denseAttr.getSplatValue<Attribute>());
- auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
-
- auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
- auto forOp =
- createLoopOverTileSlices(rewriter, loc, initTile, [&](auto forOp) {
- auto tileSliceIndex = forOp.getInductionVar();
- auto currentTile = forOp.getRegionIterArg(0);
- // Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile
- // slice.
- return rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
- loc, tileType, constantOp1D, currentTile, tileSliceIndex);
- });
- rewriter.replaceOp(constantOp, forOp.getResult(0));
-
- return success();
- }
-};
-
/// Conversion pattern for vector.broadcast.
///
/// Example:
@@ -322,16 +239,19 @@ struct BroadcastOpToArmSMELowering
auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+ arm_sme::LoopBodyBuilder loopBody = [&](OpBuilder &b, Location loc,
+ Value tileSliceIndex,
+ Value currentTile) {
+ // Create 'arm_sme.move_vector_to_tile_slice' to broadcast the value
+ // to each tile slice.
+ auto nextTile = b.create<arm_sme::MoveVectorToTileSliceOp>(
+ loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
+ b.create<scf::YieldOp>(loc, nextTile.getResult());
+ return;
+ };
+
// Create a loop over ZA tile slices.
- auto forOp =
- createLoopOverTileSlices(rewriter, loc, initTile, [&](auto forOp) {
- auto tileSliceIndex = forOp.getInductionVar();
- auto currentTile = forOp.getRegionIterArg(0);
- // Create 'arm_sme.move_vector_to_tile_slice' to broadcast the value
- // to each tile slice.
- return rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
- loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
- });
+ auto forOp = createLoopOverTileSlices(rewriter, loc, initTile, loopBody);
rewriter.replaceOp(broadcastOp, forOp.getResult(0));
@@ -381,15 +301,18 @@ struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+ arm_sme::LoopBodyBuilder loopBody = [&](OpBuilder &b, Location loc,
+ Value tileSliceIndex,
+ Value currentTile) {
+ auto nextTile = b.create<arm_sme::MoveVectorToTileSliceOp>(
+ loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
+ b.create<scf::YieldOp>(loc, nextTile.getResult());
+ return;
+ };
+
// Next, create a loop over ZA tile slices and "move" the generated 1-d
// vector to each slice.
- auto forOp =
- createLoopOverTileSlices(rewriter, loc, initTile, [&](auto forOp) {
- auto tileSliceIndex = forOp.getInductionVar();
- auto currentTile = forOp.getRegionIterArg(0);
- return rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
- loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
- });
+ auto forOp = createLoopOverTileSlices(rewriter, loc, initTile, loopBody);
rewriter.replaceOp(splatOp, forOp.getResult(0));
@@ -741,11 +664,10 @@ struct VectorPrintToArmSMELowering : public OpRewritePattern<vector::PrintOp> {
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
MLIRContext &ctx) {
- patterns
- .add<BroadcastOpToArmSMELowering, ConstantOpToArmSMELowering,
- SplatOpToArmSMELowering, TransferReadToArmSMELowering,
- TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
- VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
- VectorOuterProductToArmSMELowering, VectorExtractToArmSMELowering,
- VectorInsertToArmSMELowering, VectorPrintToArmSMELowering>(&ctx);
+ patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
+ TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
+ TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
+ VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
+ VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
+ VectorPrintToArmSMELowering>(&ctx);
}
diff --git a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
index 1fa060cafc0bc6..2e159abb1e89eb 100644
--- a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
@@ -72,4 +72,24 @@ LogicalResult verifyOperationHasValidTileId(Operation *op) {
return success();
}
+/// Builds an `scf.for` over the slices of the tile `initTile` and returns it.
+///
+/// The loop runs from 0 to `dim0 * vscale` with a step of 1, where `dim0` is
+/// the size of the first dimension of `initTile`'s vector type, and carries
+/// `initTile` as its single iter_arg. `bodyBuilder` is invoked once, with the
+/// insertion point at the start of the loop body, and receives the induction
+/// variable (tile slice index) and the loop-carried tile. This function does
+/// not emit a terminator for the body itself. The rewriter's insertion point
+/// is restored on return.
+scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc,
+                                    Value initTile,
+                                    LoopBodyBuilder bodyBuilder) {
+  OpBuilder::InsertionGuard insertionGuard(rewriter);
+
+  // NOTE: the ops below are created in a fixed order; keep it stable since it
+  // determines the order of the emitted IR.
+  auto unitStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+  auto tileType = llvm::cast<VectorType>(initTile.getType());
+  auto baseSliceCount =
+      rewriter.create<arith::ConstantIndexOp>(loc, tileType.getDimSize(0));
+  auto vectorScale =
+      rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+  auto zeroIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  // Upper bound: base slice count scaled by the runtime vscale.
+  auto sliceCount =
+      rewriter.create<arith::MulIOp>(loc, baseSliceCount, vectorScale);
+  auto sliceLoop = rewriter.create<scf::ForOp>(
+      loc, zeroIndex, sliceCount, unitStep, ValueRange{initTile});
+
+  // Hand the loop body over to the caller-provided builder.
+  rewriter.setInsertionPointToStart(sliceLoop.getBody());
+  Value tileSliceIndex = sliceLoop.getInductionVar();
+  Value currentTile = sliceLoop.getRegionIterArg(0);
+  bodyBuilder(rewriter, loc, tileSliceIndex, currentTile);
+  return sliceLoop;
+}
+
} // namespace mlir::arm_sme
diff --git a/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir b/mlir/test/Conversion/ArithToArmSME/arith-to-arm-sme.mlir
similarity index 97%
rename from mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
rename to mlir/test/Conversion/ArithToArmSME/arith-to-arm-sme.mlir
index e51f2485dadbcc..49d2e2f3c182b9 100644
--- a...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
Existing 'arith::ConstantOp' conversion and tests are moved from VectorToArmSME. Only a single op is converted at the moment, but this will grow in the future as things like in-tile add are implemented. Also, 'createLoopOverTileSlices' is moved to ArmSME utils since it's relevant for both conversions.