[mlir][ArmSME] Use liveness information in the tile allocator #90448
Depends on #90447
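@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir-linalg

Author: Benjamin Maxwell (MacDue)

Changes

This patch rewrites the ArmSME tile allocator to use liveness information to make better tile allocation decisions and improve the correctness of the ArmSME dialect. The algorithm used here is a linear scan over live ranges, where live ranges are assigned to tiles as they appear in the program (chronologically). Live ranges release their assigned tile ID when the current program point is past their end. This is a greedy algorithm, chosen mainly to keep the implementation relatively straightforward and because it seems to be sufficient for most kernels (e.g. matmuls) that use ArmSME. The general steps are roughly from https://link.springer.com/content/pdf/10.1007/3-540-45937-5_17.pdf, though a few simplifications and assumptions have been made for our use case.

Hopefully, the only changes needed for a user of the ArmSME dialect are that:

- `-allocate-arm-sme-tiles` will no longer be a standalone pass
- `-test-arm-sme-tile-allocation` is only for unit tests
- `-convert-arm-sme-to-llvm` must happen after `-convert-scf-to-cf`
- SME tile allocation is now part of the LLVM conversion

By integrating this into the `ArmSME -> LLVM` conversion we can allow high-level (value-based) ArmSME operations to be side-effect-free, as we can guarantee nothing will rearrange ArmSME operations before we emit intrinsics (which could invalidate the tile allocation).

The hope is for ArmSME operations to have no hidden state/side effects and allow easily lowering dialects such as `vector` and `arith` to SME, without making assumptions about how the input IR looks, as the semantics of the operations will be the same. That is, no (new) side effects, and the IR follows the rules of SSA (a value will never change).

The aim is correctness, so we have a base for working on optimizations.

Patch is 129.51 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/90448.diff

30 Files Affected: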
diff --git a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
index eab871ab499983..403f811a2569a0 100644
--- a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
+++ b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
@@ -12,6 +12,7 @@
#include <memory>
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
namespace mlir {
class Pass;
@@ -21,7 +22,8 @@ class RewritePatternSet;
#include "mlir/Conversion/Passes.h.inc"
/// Create a pass to convert from the ArmSME dialect to LLVM intrinsics.
-std::unique_ptr<Pass> createConvertArmSMEToLLVMPass();
+std::unique_ptr<Pass>
+createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges = false);
/// Configure target to convert from the ArmSME dialect to LLVM intrinsics.
void configureArmSMEToLLVMConversionLegality(ConversionTarget &target);
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index d094ee3b36ab95..e6d678dc1b12b3 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1285,7 +1285,7 @@ def ConvertArmSMEToSCF : Pass<"convert-arm-sme-to-scf"> {
// ArmSMEToLLVM
//===----------------------------------------------------------------------===//
-def ConvertArmSMEToLLVM : Pass<"convert-arm-sme-to-llvm"> {
+def ConvertArmSMEToLLVM : InterfacePass<"convert-arm-sme-to-llvm", "FunctionOpInterface"> {
let summary = "Lower the operations from the ArmSME dialect into the LLVM "
"dialect";
let constructor = "mlir::createConvertArmSMEToLLVMPass()";
@@ -1293,6 +1293,11 @@ def ConvertArmSMEToLLVM : Pass<"convert-arm-sme-to-llvm"> {
"arm_sme::ArmSMEDialect",
"LLVM::LLVMDialect"
];
+ let options = [
+ Option<"dumpTileLiveRanges", "dump-tile-live-ranges",
+ "bool", /*default=*/"false",
+ "Dump the live ranges of SME tiles (for debugging)">
+ ];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
index c507cea5357a74..dac54712c7f47a 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
@@ -15,6 +15,7 @@
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -24,11 +25,6 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
-namespace mlir::arm_sme {
-static constexpr unsigned kInMemoryTileIdBase = 16;
-#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc"
-} // namespace mlir::arm_sme
-
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.h.inc"
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
new file mode 100644
index 00000000000000..f31062d8c25ed7
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
@@ -0,0 +1,28 @@
+//===- ArmSMEOpInterfaces.h - Arm SME Dialect OpInterfaces ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the Target dialect for ArmSME in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMSME_OPINTERFACES_H
+#define MLIR_DIALECT_ARMSME_OPINTERFACES_H
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+
+namespace mlir::arm_sme {
+
+namespace detail {
+LogicalResult verifyArmSMETileOpInterface(Operation *);
+}
+
+static constexpr unsigned kInMemoryTileIdBase = 16;
+#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc"
+} // namespace mlir::arm_sme
+
+#endif // MLIR_DIALECT_ARMSME_OPINTERFACES_H
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 239c4beab10d2a..9178655f010c9a 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -39,10 +39,10 @@ def ArmSMETileType : I32EnumAttr<"ArmSMETileType", "Arm SME tile type",
def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
let description = [{
- An interface for operations that use or allocate Arm SME tiles. These
- operations need to be assigned a tile ID, an i32 attribute, which specifies
- which virtual tile within the ZA storage to use. The number of tiles
- available depends on the type of the tile. This is summarized below:
+ An interface for operations that use Arm SME tiles. These operations need to
+ be assigned a tile ID, an i32 attribute, which specifies which virtual tile
+ within the ZA storage to use. The number of tiles available depends on the
+ type of the tile. This is summarized below:
| Tile Vector Types | Possible Tile IDs |
|-------------------------------------------------------------------------|---------------------|
@@ -51,10 +51,6 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
| `vector<[4]x[4]xi32>` or `vector<[4]x[4]xf32>` | 0 to 3 (inclusive) |
| `vector<[2]x[2]xi64>` or `vector<[2]x[2]xf64>` | 0 to 7 (inclusive) |
| `vector<[1]x[1]xi128>` | 0 to 15 (inclusive) |
-
- Operations that allocate a new tile (such as arm_sme.get_tile), are used as
- the roots for tile allocation, with all operations that (transitively)
- depend on a root being assigned the same tile ID.
}];
let methods = [
InterfaceMethod<
@@ -84,20 +80,6 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
return op->getAttrOfType<mlir::IntegerAttr>("tile_id");
}]
>,
- InterfaceMethod<
- [{
- The type of tile this operation allocates. Returns none (std::nullopt)
- if this operation does not allocate a tile.
- }],
- /*returnType=*/"std::optional<::mlir::arm_sme::ArmSMETileType>",
- /*methodName=*/"getAllocatedTileType",
- /*arguments=*/(ins),
- /*methodBody=*/[{}],
- /*defaultImpl=*/ [{
- // This operation does not allocate a tile.
- return std::nullopt;
- }]
- >,
InterfaceMethod<
"Returns the VectorType of the tile used by this operation.",
/*returnType=*/"VectorType",
@@ -106,30 +88,13 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
];
let extraSharedClassDeclaration = [{
- // A helper to create a new operation and propagate this operations tile ID.
- template<typename T, typename... Args>
- T createOpAndForwardTileId(::mlir::RewriterBase& rewriter, ::mlir::Location loc, Args &&...args) {
- auto op = rewriter.create<T>(loc, std::forward<Args>(args)...);
- if (auto tileOp = ::llvm::dyn_cast<ArmSMETileOpInterface>(op.getOperation()))
- tileOp.setTileId($_op.getTileId());
- return op;
- }
-
- // A helper to replace this operation and forward its tile ID (if present).
- template<typename T, typename... Args>
- T replaceWithAndForwardTileId(::mlir::RewriterBase& rewriter, Args &&...args) {
- auto newOp = createOpAndForwardTileId<T>(rewriter, $_op.getLoc(), std::forward<Args>(args)...);
- rewriter.replaceOp($_op, newOp);
- return newOp;
- }
-
bool isInMemoryTile() {
auto tileId = getTileId();
return tileId && tileId.getInt() >= kInMemoryTileIdBase;
}
}];
- let verify = [{ return ::mlir::arm_sme::verifyOperationHasValidTileId($_op); }];
+ let verify = [{ return detail::verifyArmSMETileOpInterface($_op); }];
}
//===----------------------------------------------------------------------===//
@@ -255,30 +220,30 @@ def ArmSME_TypeSizeAttr : EnumAttr<ArmSME_Dialect, TypeSize,
class ArmSME_Op<string mnemonic, list<Trait> traits = []> :
Op<ArmSME_Dialect, mnemonic, traits> {}
-def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface]> {
- let summary = "Returns a SME virtual tile";
+def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface, Pure]> {
+ let summary = "Creates an undefined value of SME virtual tile type";
let description = [{
- Allocates a new SME "virtual tile" within a function. The contents of the
- tile returned from this operation are undefined.
+ Creates a new SME "virtual tile" value within a function. The contents of
+ the tile returned from this operation are undefined.
Example 1:
```mlir
- // Allocate an 8-bit element "virtual tile"
+ // Create an 8-bit element "virtual tile" value:
%za0_b = arm_sme.get_tile: vector<[16]x[16]xi8>
```
Example 2:
```mlir
- // Allocate two 16-bit element "virtual tiles"
+ // Create two 16-bit element "virtual tiles" values:
%za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
%za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
```
Example 3:
```mlir
- // Allocate an 128-bit element "virtual tile"
+ // Create an 128-bit element "virtual tile" value:
%za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
```
}];
@@ -290,37 +255,15 @@ def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface]> {
VectorType getTileType() {
return ::llvm::cast<VectorType>(getTile().getType());
}
-
- std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
- return arm_sme::getSMETileType(getTileType());
- }
- }];
-}
-
-def MaterializeSSATileOp : ArmSME_Op<"materialize_ssa_tile", [Pure]> {
- let summary = "SME tile placeholder";
- let description = [{
- A placeholder to preserve dataflow while lowering to SME intrinsics (which
- do not take or return SME virtual tile values). This operation is intended
- to be DCE'd once all ArmSME operations have been lowered.
-
- This operation is not intended to be used outside of the ArmSME -> LLVM
- conversion.
}];
- let results = (outs SMETile:$tile);
- let assemblyFormat = "attr-dict `:` type($tile)";
}
-//
-// Tile reset.
-//
-
-def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
- let summary = "Initialize the two-dimensional ZA array with 0s";
+def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface, Pure]> {
+ let summary = "Creates a zero-initialized value of SME virtual tile type";
let results = (outs SMETile:$res);
let description = [{
- Initialise ZA with 0. This operation is convenient wrapper for the SME
- `zero` intrinsic and instruction.
+ Creates a new SME "virtual tile" value within a function. The contents of
+ the tile returned from this operation are zero-initialized.
Example 1: Zero an 8-bit element ZA tile.
@@ -338,9 +281,6 @@ def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
VectorType getVectorType() {
return ::llvm::cast<VectorType>(getRes().getType());
}
- std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
- return arm_sme::getSMETileType(getVectorType());
- }
VectorType getTileType() {
return getVectorType();
}
@@ -348,6 +288,32 @@ def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
let assemblyFormat = "attr-dict `:` type($res)";
}
+def CopyTileOp : ArmSME_Op<"copy_tile", [
+ Pure,
+ ArmSMETileOpInterface,
+ AllTypesMatch<["tile", "result"]>
+]> {
+ let summary = "Copies an SME tile value";
+ let arguments = (ins SMETile:$tile);
+ let results = (outs SMETile:$result);
+ let description = [{
+ Copies an SME "virtual tile" value to a new SSA value. This operation is
+ primarily intended to be used to normalize the IR prior to tile allocation.
+
+ Example:
+
+ ```mlir
+ %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
+ ```
+ }];
+ let extraClassDeclaration = [{
+ VectorType getTileType() {
+ return ::llvm::cast<VectorType>(getResult().getType());
+ }
+ }];
+ let assemblyFormat = "$tile attr-dict `:` type($result)";
+}
+
def TileLoadOp : ArmSME_Op<"tile_load", [
ArmSMETileOpInterface,
AttrSizedOperandSegments,
@@ -417,9 +383,6 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
VectorType getVectorType() {
return ::llvm::cast<VectorType>(getResult().getType());
}
- std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
- return arm_sme::getSMETileType(getVectorType());
- }
VectorType getTileType() {
return getVectorType();
}
@@ -545,7 +508,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
```
}];
let arguments = (ins
- Arg<AnyMemRef, "the reference to load from">:$base, SVEPredicate:$mask,
+ Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base, SVEPredicate:$mask,
SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index,
ArmSME_TileSliceLayoutAttr:$layout
);
@@ -630,7 +593,7 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [
}
def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
- ArmSMETileOpInterface,
+ ArmSMETileOpInterface, Pure,
AllTypesMatch<["tile", "result"]>,
TypesMatchWith<
"type of 'vector' matches type of 'tile' slice",
@@ -679,7 +642,7 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
}
def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [
- ArmSMETileOpInterface,
+ ArmSMETileOpInterface, Pure,
TypesMatchWith<
"type of 'result' matches type of 'tile' slice",
"tile", "result",
@@ -736,6 +699,7 @@ class OuterProductResultTileTypeConstraint<string operand> :
def OuterProductOp :
ArmSME_Op<"outerproduct", [
+ Pure,
ArmSMETileOpInterface,
AttrSizedOperandSegments,
AllTypesMatch<["lhs", "rhs"]>,
@@ -802,12 +766,6 @@ let arguments = (ins
VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
- std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
- // The outerproduct op allocates a new tile if no accumulator is passed.
- if (!getAcc())
- return arm_sme::getSMETileType(getResultType());
- return std::nullopt;
- }
VectorType getTileType() {
return getResultType();
}
@@ -819,6 +777,7 @@ class OuterProductWideningBase<string mnemonic,
list<Type> allowedResultVectorTypes,
int numOuterProducts> :
ArmSME_Op<mnemonic, [
+ Pure,
ArmSMETileOpInterface,
AttrSizedOperandSegments,
AllTypesMatch<["lhs", "rhs"]>,
@@ -857,12 +816,6 @@ class OuterProductWideningBase<string mnemonic,
VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
- std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
- // The outerproduct op allocates a new tile if no accumulator is passed.
- if (!getAcc())
- return arm_sme::getSMETileType(getResultType());
- return std::nullopt;
- }
VectorType getTileType() {
return getResultType();
}
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index c2f1b1f1b874ec..156744ba57e7b2 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -29,9 +29,6 @@ std::unique_ptr<Pass> createEnableArmStreamingPass(
const ArmStreamingMode = ArmStreamingMode::Streaming,
const ArmZaMode = ArmZaMode::Disabled, bool onlyIfRequiredByOps = false);
-/// Pass that allocates tile IDs to ArmSME operations.
-std::unique_ptr<Pass> createTileAllocationPass();
-
/// Pass that fuses 'arm_sme.outerproduct' ops into 2-way or 4-way widening
/// variants.
std::unique_ptr<Pass> createOuterProductFusionPass();
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 7959d291e89267..b9d74fec6756e3 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -124,16 +124,21 @@ def EnableArmStreaming
let dependentDialects = ["func::FuncDialect"];
}
-def TileAllocation
- : Pass<"allocate-arm-sme-tiles", "mlir::func::FuncOp"> {
- let summary = "Allocate SME tiles";
+def TestTileAllocation
+ : Pass<"test-arm-sme-tile-allocation", "mlir::func::FuncOp"> {
+ let summary = "Tests SME tile allocation";
let description = [{
This pass does tile allocation for SME "virtual tiles". It is run at the
'func.func' op level, and assigns tile IDs (via an attribute) to all ops
- that implement the `ArmSMETileOpInterface`. An error will be emitted when
- there's no tiles left.
+ that implement the `ArmSMETileOpInterface`. Note: This pass is only intended
+ to be used for testing, tile allocation is done as part of the ArmSME to
+ LLVM conversion.
}];
- let constructor = "mlir::arm_sme::createTileAllocationPass()";
+ let options = [
+ Option<"dumpTileLiveRanges", "dump-tile-live-ranges",
+ "bool", /*default=*/"false",
+ "Dump the live ranges of SME tiles (for debugging)">
+ ];
let dependentDialects = ["func::FuncDialect"];
}
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
index e00c7503e69992..a25b844f01eaa6 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
@@ -9,6 +9,8 @@
#ifndef MLIR_DIALECT_ARMSME_TRANSFORMS_H
#define MLIR_DIALECT_ARMSME_TRANSFORMS_H
+#include "mlir/Interfaces/FunctionInterfaces.h"
+
namespace mlir {
class LLVMConversionTarget;
@@ -16,7 +18,14 @@ class LLVMTypeConverter;
class RewritePatternSet;
namespace arm_sme {
+
void populateOuterProductFusionPatterns(RewritePatternSet &patterns);
+
+/// Allocate tile IDs to all ArmSME operations in a function. Requires the
+/// function to be lowered to control flow (cf dialect).
+LogicalResult allocateSMETiles(FunctionOpInterface function,
+ bool dumpRanges = false);
+
} // namespace arm_sme
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 027ad8954f92f5..9ea1c5a5d63fe5 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -16,8 +16,10 @@
#define MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
#include <optional>
namespace mlir {
@@ -42,6 +44,11 @@ bool isValidSMETileElementType(Type type);
/// otherwise.
bool isValidSMETileVectorType(VectorType vType);
+inline bool isValidSMETileVectorType(Type type) {
+ auto vType = dyn_cast<VectorType>(type);
+ return vType && isValidSMETileVectorType(vType);
+}
+
/// Returns the type of SME tile this vector type corresponds to, or none if the
/// vector type does not fit within an SME tile.
std::optional<ArmSMETileType> getSMETileType(VectorType);
@@ -63,6 +70,19 @@ bool isMultipleOfSMETileVectorType(VectorType vType);
/// Creates a vector type for the SME tile of `elementType`.
VectorType getSMETileTypeForElement(Type elementType);
+/// Erase trivially dead tile ops from a function.
+void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
+ FunctionOpInterfa...
[truncated]
thanks for the patch Ben, this is impressive work and a significant improvement to SME support in MLIR.
I haven't reviewed the whole patch yet but I've left some initial comments for now.
// This transform allocates SME tiles at the 'func.func' op level for ArmSME
// operations. It does this using a 16-bit tile mask that has a bit for each
// 128-bit element tile (ZA0.Q-ZA15.Q), the smallest ZA tile granule.
Please, could we have more documentation in this file? For example, a high level overview of how tile allocation works (live range calculation, followed by transformation). With some info for every section.
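For readers following the quoted comment, the tile-mask idea can be sketched roughly as below. This is a minimal illustration, not the code from the patch: the names `TileMask`, `getTileMask`, and `tilesOverlap` are hypothetical, and the bit order (bit i stands for ZAi.Q) is an assumption made for this sketch that may differ from the real implementation. It relies only on the architectural layout in which tile n of an element type with N tiles covers the 128-bit granules ZAn.Q, ZA(n+N).Q, and so on.

```cpp
#include <cassert>
#include <cstdint>

// Assumed convention for this sketch only: bit i of the mask stands for the
// 128-bit element tile ZAi.Q. The patch may order the bits differently.
using TileMask = std::uint16_t;

// Returns the mask of 128-bit granules covered by tile `tileId`, where
// `numTiles` is the number of tiles available for the element type
// (1 for 8-bit, 2 for 16-bit, 4 for 32-bit, 8 for 64-bit, 16 for 128-bit).
TileMask getTileMask(unsigned tileId, unsigned numTiles) {
  TileMask mask = 0;
  // Tile n covers granules n, n + numTiles, n + 2 * numTiles, ...
  for (unsigned granule = tileId; granule < 16; granule += numTiles)
    mask |= TileMask(1u << granule);
  return mask;
}

// Two tiles conflict when they share at least one 128-bit granule.
bool tilesOverlap(TileMask a, TileMask b) { return (a & b) != 0; }
```

With this convention, a 32-bit tile with ID 0 overlaps the 16-bit tile with ID 0 but not the 16-bit tile with ID 1, which is the kind of conflict an allocator has to track when tiles of different element types are live at the same time.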
@MacDue, I'm going over `allocateSMETiles` and my first reaction is that step one, `// 1. Insert copy operations at branch operations.`, could easily be a separate pass. That would allow for more thorough testing; right now it's either all or nothing. Could you split it out? WDYT?
auto parent = user->getParentOp();
traverseCorrespondingValues(user->getOperands(),
parent->getResults());
rewriter.setInsertionPoint(terminator);
What restores the insertion point? Use insertion guard?
There's no need to restore the insertion point, as every use of the rewriter sets the insertion point (it would not really make sense to assume it's at any point in particular given how it's used).
I think that you are referring to what's happening in this method today. However, this method might get updated, and the same `rewriter` is shared between other hooks. Would adding an insertion guard do any harm?
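To make the suggestion concrete, here is a small standalone sketch of the RAII pattern behind an insertion guard. In MLIR the real type is `OpBuilder::InsertionGuard`; the `Builder`, `InsertionGuard`, and `insertCopyBefore` names below are hypothetical stand-ins so the example compiles without MLIR, and insertion points are modelled as plain integers.

```cpp
#include <cassert>

// Hypothetical stand-in for a builder/rewriter that tracks an insertion point.
struct Builder {
  int insertionPoint = 0;
  int lastCreatedAt = -1;
  void setInsertionPoint(int point) { insertionPoint = point; }
  // Pretend to create an operation at the current insertion point.
  void create() { lastCreatedAt = insertionPoint; }
};

// RAII guard: saves the insertion point on construction and restores it on
// destruction, so callers never observe a moved insertion point.
class InsertionGuard {
public:
  explicit InsertionGuard(Builder &builder)
      : builder(builder), savedPoint(builder.insertionPoint) {}
  ~InsertionGuard() { builder.insertionPoint = savedPoint; }
  InsertionGuard(const InsertionGuard &) = delete;
  InsertionGuard &operator=(const InsertionGuard &) = delete;

private:
  Builder &builder;
  int savedPoint;
};

// A helper that inserts before a terminator without leaking the change.
void insertCopyBefore(Builder &builder, int terminator) {
  InsertionGuard guard(builder);
  builder.setInsertionPoint(terminator);
  builder.create();
} // `guard` restores the previous insertion point here.
```

The guard costs one saved value per scope, which is why adding it is usually harmless even when every current caller already resets the insertion point itself.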
LGTM assuming `splitCondBranches` gets tested. Accepting now before I go on holiday. Once again, Ben, excellent work! 🙏
auto parent = user->getParentOp();
traverseCorrespondingValues(user->getOperands(),
parent->getResults());
rewriter.setInsertionPoint(terminator);
I think that you are referring to what's happening in this method today. However, this method might get updated, and the same `rewriter` is shared between other hooks. Would adding an insertion guard do any harm?
Amazing, thanks.
@@ -1,19 +1,15 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -split-input-file -verify-diagnostics
+// RUN: mlir-opt %s -convert-scf-to-cf -test-arm-sme-tile-allocation -verify-diagnostics
If the goal is to test tile allocation, maybe we should use the post-`scf-to-cf` IR as input directly instead of relying on `-convert-scf-to-cf` as well?
I agree that having tests in CF too would be good (since this is an algorithm on CFGs). I've used SCF in these tests as it's easier to write/read (but cases like this one, which is just a branch, can easily be written in CF).
Force-pushed from abd40ec to a91685a.
This patch rewrites the ArmSME tile allocator to use liveness information to make better tile allocation decisions and improve the correctness of the ArmSME dialect. The algorithm used here is a linear scan over live ranges, where live ranges are assigned to tiles as they appear in the program (chronologically). Live ranges release their assigned tile ID when the current program point is past their end. This is a greedy algorithm, chosen mainly to keep the implementation relatively straightforward and because it seems to be sufficient for most kernels (e.g. matmuls) that use ArmSME. The general steps are roughly from https://link.springer.com/content/pdf/10.1007/3-540-45937-5_17.pdf, though a few simplifications and assumptions have been made for our use case.

Hopefully, the only changes needed for a user of the ArmSME dialect are that:

- `-allocate-arm-sme-tiles` will no longer be a standalone pass
- `-test-arm-sme-tile-allocation` is only for unit tests
- `-convert-arm-sme-to-llvm` must happen after `-convert-scf-to-cf`
- SME tile allocation is now part of the LLVM conversion

By integrating this into the `ArmSME -> LLVM` conversion we can allow high-level (value-based) ArmSME operations to be side-effect-free, as we can guarantee nothing will rearrange ArmSME operations before we emit intrinsics (which could invalidate the tile allocation).

The hope is for ArmSME operations to have no hidden state/side effects and allow easily lowering dialects such as `vector` and `arith` to SME, without making assumptions about how the input IR looks, as the semantics of the operations will be the same. That is, no (new) side effects, and the IR follows the rules of SSA (a value will never change).

The aim is correctness, so we have a base for working on optimizations.
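The greedy linear scan described in this commit message can be sketched as follows. This is only an illustration under stated assumptions, not the allocator from the patch: `LiveRange` and `allocateTiles` are hypothetical names, live ranges are modelled as half-open intervals `[start, end)` of program points, all tiles are treated as interchangeable (a single element type), and a range that finds no free tile is simply left unassigned rather than spilled.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical, simplified live range covering program points [start, end).
struct LiveRange {
  unsigned start;
  unsigned end;
  std::optional<unsigned> tileId; // Stays unset if no tile was free.
};

// Greedy linear scan: visit ranges in order of their start point, release the
// tiles of ranges that have already ended, then take the lowest free tile ID.
void allocateTiles(std::vector<LiveRange> &ranges, unsigned numTiles) {
  std::sort(ranges.begin(), ranges.end(),
            [](const LiveRange &a, const LiveRange &b) {
              return a.start < b.start;
            });
  std::vector<bool> tileInUse(numTiles, false);
  std::vector<std::size_t> active; // Indices of ranges currently holding a tile.
  for (std::size_t i = 0; i < ranges.size(); ++i) {
    // Release tiles whose live range ended at or before the current point.
    std::vector<std::size_t> stillActive;
    for (std::size_t j : active) {
      if (ranges[j].end <= ranges[i].start)
        tileInUse[*ranges[j].tileId] = false;
      else
        stillActive.push_back(j);
    }
    active.swap(stillActive);
    // Assign the lowest-numbered free tile, if there is one.
    for (unsigned tile = 0; tile < numTiles; ++tile) {
      if (tileInUse[tile])
        continue;
      tileInUse[tile] = true;
      ranges[i].tileId = tile;
      active.push_back(i);
      break;
    }
  }
}
```

With two tiles and ranges `[0, 4)`, `[1, 3)`, `[2, 5)`, `[3, 6)`, the first two ranges take tiles 0 and 1, the third finds no free tile, and the fourth reuses tile 1 once `[1, 3)` has ended. The unassigned case is where a real allocator has to fall back to something else, such as the spilling that the review fixups in this PR refer to.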
Review fixups, and fixes two bugs that could occur if the spilled tile type did not match tile type of the new live range. This is now covered by two new test cases.
Impressive work Ben, thank you for addressing all my comments! 🙏🏻
This is a very well designed and complex bit of logic, but absolutely necessary and 100% the right solution for SME virtual tile allocation. This is probably the most important building block for supporting SME in MLIR - vital element in future-proofing this work. Thank you so much for seeing this through!
-Andrzej
…ges (#102125) - Use vector.interleave rather than the LLVM intrinsic - Remove dependency on LLVM dialect - Remove manual outerproduct erases (these are now trivially dead) - Remove comment explaining issues with previous tile allocator - Update pipeline in `multi-tile-matmul-mixed-types.mlir` Recent changes: #90448, #80965