[mlir][ArmSME] Use liveness information in the tile allocator #90448

Merged
merged 17 commits into from
May 14, 2024

Conversation

Member

@MacDue MacDue commented Apr 29, 2024

This patch rewrites the ArmSME tile allocator to use liveness information, which lets it make better tile allocation decisions and improves the correctness of the ArmSME dialect. The algorithm used here is a linear scan over live ranges, where live ranges are assigned to tiles in the order they appear in the program (chronologically). A live range releases its assigned tile ID once the current program point has passed its end. The algorithm is greedy, mainly to keep the implementation relatively straightforward, and because this seems to be sufficient for most kernels (e.g. matmuls) that use ArmSME. The general steps are roughly those from https://link.springer.com/content/pdf/10.1007/3-540-45937-5_17.pdf, though a few simplifications and assumptions have been made for our use case.
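
As a rough illustration of the greedy linear scan described above (not the code in this patch), here is a minimal standalone C++ sketch. The `LiveRange` struct, the `allocateTiles` function, the half-open `[start, end)` convention, and the policy of handing an in-memory ID to any range that does not fit are all assumptions made for the example; only the value 16 for `kInMemoryTileIdBase` is taken from the diff.

```cpp
#include <algorithm>
#include <cassert>
#include <set>
#include <vector>

// Hypothetical live range over program points, half-open: [start, end).
struct LiveRange {
  unsigned start;
  unsigned end;
  int tileId = -1; // -1 means "not allocated yet".
};

// IDs at or above this value denote in-memory (spilled) tiles. The value
// matches kInMemoryTileIdBase in the diff; everything else is illustrative.
constexpr int kInMemoryTileIdBase = 16;

// Greedy linear scan: visit ranges in order of their start point, release
// tiles of ranges that have ended, then take the lowest free tile ID.
// Assumption for this sketch: a range that finds no free tile simply gets
// the next in-memory ID (the real allocator may choose what to spill).
inline void allocateTiles(std::vector<LiveRange> &ranges, int numTiles) {
  assert(numTiles > 0 && numTiles <= kInMemoryTileIdBase);
  std::stable_sort(ranges.begin(), ranges.end(),
                   [](const LiveRange &a, const LiveRange &b) {
                     return a.start < b.start;
                   });

  std::set<int> freeTiles;
  for (int id = 0; id < numTiles; ++id)
    freeTiles.insert(id);

  std::vector<const LiveRange *> active; // Ranges currently holding a tile.
  int nextInMemoryId = kInMemoryTileIdBase;

  for (LiveRange &range : ranges) {
    // Release tiles whose live range ended at or before this program point.
    for (auto it = active.begin(); it != active.end();) {
      if ((*it)->end <= range.start) {
        freeTiles.insert((*it)->tileId);
        it = active.erase(it);
      } else {
        ++it;
      }
    }

    if (!freeTiles.empty()) {
      range.tileId = *freeTiles.begin();
      freeTiles.erase(freeTiles.begin());
      active.push_back(&range);
    } else {
      range.tileId = nextInMemoryId++;
    }
  }
}
```

With two tiles and the ranges `[0,4)`, `[1,3)`, `[2,6)`, `[3,5)`, this sketch assigns tile IDs 0, 1, 16 and 1: the third range finds no free tile, and the fourth reuses tile 1 once `[1,3)` has ended.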

Hopefully, the only changes needed for users of the ArmSME dialect are that:

  • -allocate-arm-sme-tiles will no longer be a standalone pass
    • -test-arm-sme-tile-allocation is only for unit tests
  • -convert-arm-sme-to-llvm must happen after -convert-scf-to-cf
    • SME tile allocation is now part of the LLVM conversion

By integrating this into the ArmSME -> LLVM conversion, we can allow high-level (value-based) ArmSME operations to be side-effect-free, as we can guarantee nothing will rearrange ArmSME operations before we emit intrinsics (which could invalidate the tile allocation).

The hope is for ArmSME operations to have no hidden state or side effects, which makes it easy to lower dialects such as vector and arith to SME without making assumptions about how the input IR looks, as the semantics of the operations will be the same. That is, there are no (new) side effects and the IR follows the rules of SSA (a value will never change).

The aim is correctness, so we have a base for working on optimizations.

Member Author

MacDue commented Apr 29, 2024

Depends on #90447

Member

llvmbot commented Apr 29, 2024

@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-sme

@llvm/pr-subscribers-mlir-linalg

Author: Benjamin Maxwell (MacDue)

Patch is 129.51 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/90448.diff

30 Files Affected:

  • (modified) mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h (+3-1)
  • (modified) mlir/include/mlir/Conversion/Passes.td (+6-1)
  • (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h (+1-5)
  • (added) mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h (+28)
  • (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td (+47-94)
  • (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h (-3)
  • (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td (+11-6)
  • (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h (+9)
  • (modified) mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h (+20)
  • (modified) mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp (+36-33)
  • (modified) mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp (+11-18)
  • (modified) mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp (+6)
  • (modified) mlir/lib/Dialect/ArmSME/IR/Utils.cpp (+46)
  • (modified) mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp (+487-157)
  • (modified) mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir (+23-4)
  • (modified) mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir (+8-7)
  • (modified) mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir (+2-1)
  • (modified) mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir (+6)
  • (renamed) mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir (+170-95)
  • (modified) mlir/test/Dialect/ArmSME/canonicalize.mlir (+4-6)
  • (removed) mlir/test/Dialect/ArmSME/cse.mlir (-30)
  • (modified) mlir/test/Dialect/ArmSME/enable-arm-za.mlir (+6-14)
  • (modified) mlir/test/Dialect/ArmSME/outer-product-fusion.mlir (+4-3)
  • (modified) mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir (+5-7)
  • (added) mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir (+269)
  • (modified) mlir/test/Dialect/ArmSME/tile-zero-masks.mlir (+30-15)
  • (modified) mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir (+5-2)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/Emulated/test-setArmSVLBits.mlir (+2-1)
  • (modified) mlir/test/lib/Dialect/ArmSME/CMakeLists.txt (+1)
  • (modified) mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp (+11-8)
diff --git a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
index eab871ab499983..403f811a2569a0 100644
--- a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
+++ b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
@@ -12,6 +12,7 @@
 #include <memory>
 
 #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 
 namespace mlir {
 class Pass;
@@ -21,7 +22,8 @@ class RewritePatternSet;
 #include "mlir/Conversion/Passes.h.inc"
 
 /// Create a pass to convert from the ArmSME dialect to LLVM intrinsics.
-std::unique_ptr<Pass> createConvertArmSMEToLLVMPass();
+std::unique_ptr<Pass>
+createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges = false);
 
 /// Configure target to convert from the ArmSME dialect to LLVM intrinsics.
 void configureArmSMEToLLVMConversionLegality(ConversionTarget &target);
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index d094ee3b36ab95..e6d678dc1b12b3 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1285,7 +1285,7 @@ def ConvertArmSMEToSCF : Pass<"convert-arm-sme-to-scf"> {
 // ArmSMEToLLVM
 //===----------------------------------------------------------------------===//
 
-def ConvertArmSMEToLLVM : Pass<"convert-arm-sme-to-llvm"> {
+def ConvertArmSMEToLLVM : InterfacePass<"convert-arm-sme-to-llvm", "FunctionOpInterface"> {
   let summary = "Lower the operations from the ArmSME dialect into the LLVM "
                 "dialect";
   let constructor = "mlir::createConvertArmSMEToLLVMPass()";
@@ -1293,6 +1293,11 @@ def ConvertArmSMEToLLVM : Pass<"convert-arm-sme-to-llvm"> {
     "arm_sme::ArmSMEDialect",
     "LLVM::LLVMDialect"
   ];
+  let options = [
+    Option<"dumpTileLiveRanges", "dump-tile-live-ranges",
+           "bool", /*default=*/"false",
+           "Dump the live ranges of SME tiles (for debugging)">
+  ];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
index c507cea5357a74..dac54712c7f47a 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
@@ -15,6 +15,7 @@
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h"
 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -24,11 +25,6 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
-namespace mlir::arm_sme {
-static constexpr unsigned kInMemoryTileIdBase = 16;
-#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc"
-} // namespace mlir::arm_sme
-
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.h.inc"
 
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
new file mode 100644
index 00000000000000..f31062d8c25ed7
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
@@ -0,0 +1,28 @@
+//===- ArmSMEOpInterfaces.h - Arm SME Dialect OpInterfaces ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the Target dialect for ArmSME in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMSME_OPINTERFACES_H
+#define MLIR_DIALECT_ARMSME_OPINTERFACES_H
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+
+namespace mlir::arm_sme {
+
+namespace detail {
+LogicalResult verifyArmSMETileOpInterface(Operation *);
+}
+
+static constexpr unsigned kInMemoryTileIdBase = 16;
+#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc"
+} // namespace mlir::arm_sme
+
+#endif // MLIR_DIALECT_ARMSME_OPINTERFACES_H
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 239c4beab10d2a..9178655f010c9a 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -39,10 +39,10 @@ def ArmSMETileType : I32EnumAttr<"ArmSMETileType", "Arm SME tile type",
 
 def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
   let description = [{
-    An interface for operations that use or allocate Arm SME tiles. These
-    operations need to be assigned a tile ID, an i32 attribute, which specifies
-    which virtual tile within the ZA storage to use. The number of tiles
-    available depends on the type of the tile. This is summarized below:
+    An interface for operations that use Arm SME tiles. These operations need to
+    be assigned a tile ID, an i32 attribute, which specifies which virtual tile
+    within the ZA storage to use. The number of tiles available depends on the
+    type of the tile. This is summarized below:
 
     | Tile Vector Types                                                       | Possible Tile IDs   |
     |-------------------------------------------------------------------------|---------------------|
@@ -51,10 +51,6 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
     | `vector<[4]x[4]xi32>` or `vector<[4]x[4]xf32>`                          | 0 to 3 (inclusive)  |
     | `vector<[2]x[2]xi64>` or `vector<[2]x[2]xf64>`                          | 0 to 7 (inclusive)  |
     | `vector<[1]x[1]xi128>`                                                  | 0 to 15 (inclusive) |
-
-    Operations that allocate a new tile (such as arm_sme.get_tile), are used as
-    the roots for tile allocation, with all operations that (transitively)
-    depend on a root being assigned the same tile ID.
   }];
   let methods = [
     InterfaceMethod<
@@ -84,20 +80,6 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
         return op->getAttrOfType<mlir::IntegerAttr>("tile_id");
       }]
     >,
-    InterfaceMethod<
-      [{
-        The type of tile this operation allocates. Returns none (std::nullopt)
-        if this operation does not allocate a tile.
-      }],
-      /*returnType=*/"std::optional<::mlir::arm_sme::ArmSMETileType>",
-      /*methodName=*/"getAllocatedTileType",
-      /*arguments=*/(ins),
-      /*methodBody=*/[{}],
-      /*defaultImpl=*/ [{
-        // This operation does not allocate a tile.
-        return std::nullopt;
-      }]
-    >,
     InterfaceMethod<
       "Returns the VectorType of the tile used by this operation.",
       /*returnType=*/"VectorType",
@@ -106,30 +88,13 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
   ];
 
   let extraSharedClassDeclaration = [{
-    // A helper to create a new operation and propagate this operations tile ID.
-    template<typename T, typename... Args>
-    T createOpAndForwardTileId(::mlir::RewriterBase& rewriter, ::mlir::Location loc, Args &&...args) {
-      auto op = rewriter.create<T>(loc, std::forward<Args>(args)...);
-      if (auto tileOp = ::llvm::dyn_cast<ArmSMETileOpInterface>(op.getOperation()))
-        tileOp.setTileId($_op.getTileId());
-      return op;
-    }
-
-    // A helper to replace this operation and forward its tile ID (if present).
-    template<typename T, typename... Args>
-    T replaceWithAndForwardTileId(::mlir::RewriterBase& rewriter, Args &&...args) {
-      auto newOp = createOpAndForwardTileId<T>(rewriter, $_op.getLoc(), std::forward<Args>(args)...);
-      rewriter.replaceOp($_op, newOp);
-      return newOp;
-    }
-
     bool isInMemoryTile() {
       auto tileId = getTileId();
       return tileId && tileId.getInt() >= kInMemoryTileIdBase;
     }
   }];
 
-  let verify = [{ return ::mlir::arm_sme::verifyOperationHasValidTileId($_op); }];
+  let verify = [{ return detail::verifyArmSMETileOpInterface($_op); }];
 }
 
 //===----------------------------------------------------------------------===//
@@ -255,30 +220,30 @@ def ArmSME_TypeSizeAttr : EnumAttr<ArmSME_Dialect, TypeSize,
 class ArmSME_Op<string mnemonic, list<Trait> traits = []> :
   Op<ArmSME_Dialect, mnemonic, traits> {}
 
-def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface]> {
-  let summary = "Returns a SME virtual tile";
+def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface, Pure]> {
+  let summary = "Creates an undefined value of SME virtual tile type";
   let description = [{
-    Allocates a new SME "virtual tile" within a function. The contents of the
-    tile returned from this operation are undefined.
+    Creates a new SME "virtual tile" value within a function. The contents of
+    the tile returned from this operation are undefined.
 
     Example 1:
 
     ```mlir
-    // Allocate an 8-bit element "virtual tile"
+    // Create an 8-bit element "virtual tile" value:
     %za0_b = arm_sme.get_tile: vector<[16]x[16]xi8>
     ```
 
     Example 2:
 
     ```mlir
-    // Allocate two 16-bit element "virtual tiles"
+    // Create two 16-bit element "virtual tiles" values:
     %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
     %za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
     ```
 
     Example 3:
     ```mlir
-    // Allocate an 128-bit element "virtual tile"
+    // Create an 128-bit element "virtual tile" value:
     %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
     ```
   }];
@@ -290,37 +255,15 @@ def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface]> {
     VectorType getTileType() {
       return ::llvm::cast<VectorType>(getTile().getType());
     }
-
-    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
-      return arm_sme::getSMETileType(getTileType());
-    }
-  }];
-}
-
-def MaterializeSSATileOp : ArmSME_Op<"materialize_ssa_tile", [Pure]> {
-  let summary = "SME tile placeholder";
-  let description = [{
-    A placeholder to preserve dataflow while lowering to SME intrinsics (which
-    do not take or return SME virtual tile values). This operation is intended
-    to be DCE'd once all ArmSME operations have been lowered.
-
-    This operation is not intended to be used outside of the ArmSME -> LLVM
-    conversion.
   }];
-  let results = (outs SMETile:$tile);
-  let assemblyFormat = "attr-dict `:` type($tile)";
 }
 
-//
-// Tile reset.
-//
-
-def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
-  let summary = "Initialize the two-dimensional ZA array with 0s";
+def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface, Pure]> {
+  let summary = "Creates a zero-initialized value of SME virtual tile type";
   let results = (outs SMETile:$res);
   let description = [{
-    Initialise ZA with 0. This operation is convenient wrapper for the SME
-    `zero` intrinsic and instruction.
+    Creates a new SME "virtual tile" value within a function. The contents of
+    the tile returned from this operation are zero-initialized.
 
     Example 1: Zero an 8-bit element ZA tile.
 
@@ -338,9 +281,6 @@ def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
     VectorType getVectorType() {
       return ::llvm::cast<VectorType>(getRes().getType());
     }
-    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
-      return arm_sme::getSMETileType(getVectorType());
-    }
     VectorType getTileType() {
       return getVectorType();
     }
@@ -348,6 +288,32 @@ def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
   let assemblyFormat = "attr-dict `:` type($res)";
 }
 
+def CopyTileOp : ArmSME_Op<"copy_tile", [
+  Pure,
+  ArmSMETileOpInterface,
+  AllTypesMatch<["tile", "result"]>
+]> {
+  let summary = "Copies an SME tile value";
+  let arguments = (ins SMETile:$tile);
+  let results = (outs SMETile:$result);
+  let description = [{
+    Copies an SME "virtual tile" value to a new SSA value. This operation is
+    primarily intended to be used to normalize the IR prior to tile allocation.
+
+    Example:
+
+    ```mlir
+    %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
+    ```
+  }];
+  let extraClassDeclaration = [{
+    VectorType getTileType() {
+      return ::llvm::cast<VectorType>(getResult().getType());
+    }
+  }];
+  let assemblyFormat = "$tile attr-dict `:` type($result)";
+}
+
 def TileLoadOp : ArmSME_Op<"tile_load", [
   ArmSMETileOpInterface,
   AttrSizedOperandSegments,
@@ -417,9 +383,6 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
     VectorType getVectorType() {
       return ::llvm::cast<VectorType>(getResult().getType());
     }
-    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
-      return arm_sme::getSMETileType(getVectorType());
-    }
     VectorType getTileType() {
       return getVectorType();
     }
@@ -545,7 +508,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
     ```
   }];
   let arguments = (ins
-    Arg<AnyMemRef, "the reference to load from">:$base, SVEPredicate:$mask,
+    Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base, SVEPredicate:$mask,
     SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index,
     ArmSME_TileSliceLayoutAttr:$layout
   );
@@ -630,7 +593,7 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [
 }
 
 def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
-    ArmSMETileOpInterface,
+    ArmSMETileOpInterface, Pure,
     AllTypesMatch<["tile", "result"]>,
     TypesMatchWith<
       "type of 'vector' matches type of 'tile' slice",
@@ -679,7 +642,7 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
 }
 
 def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [
-    ArmSMETileOpInterface,
+    ArmSMETileOpInterface, Pure,
     TypesMatchWith<
       "type of 'result' matches type of 'tile' slice",
       "tile", "result",
@@ -736,6 +699,7 @@ class OuterProductResultTileTypeConstraint<string operand> :
 
 def OuterProductOp :
   ArmSME_Op<"outerproduct", [
+    Pure,
     ArmSMETileOpInterface,
     AttrSizedOperandSegments,
     AllTypesMatch<["lhs", "rhs"]>,
@@ -802,12 +766,6 @@ let arguments = (ins
     VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
     VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
     VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
-    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
-      // The outerproduct op allocates a new tile if no accumulator is passed.
-      if (!getAcc())
-        return arm_sme::getSMETileType(getResultType());
-      return std::nullopt;
-    }
     VectorType getTileType() {
       return getResultType();
     }
@@ -819,6 +777,7 @@ class OuterProductWideningBase<string mnemonic,
                                list<Type> allowedResultVectorTypes,
                                int numOuterProducts> :
   ArmSME_Op<mnemonic, [
+    Pure,
     ArmSMETileOpInterface,
     AttrSizedOperandSegments,
     AllTypesMatch<["lhs", "rhs"]>,
@@ -857,12 +816,6 @@ class OuterProductWideningBase<string mnemonic,
     VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
     VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
     VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
-    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
-      // The outerproduct op allocates a new tile if no accumulator is passed.
-      if (!getAcc())
-        return arm_sme::getSMETileType(getResultType());
-      return std::nullopt;
-    }
     VectorType getTileType() {
       return getResultType();
     }
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index c2f1b1f1b874ec..156744ba57e7b2 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -29,9 +29,6 @@ std::unique_ptr<Pass> createEnableArmStreamingPass(
     const ArmStreamingMode = ArmStreamingMode::Streaming,
     const ArmZaMode = ArmZaMode::Disabled, bool onlyIfRequiredByOps = false);
 
-/// Pass that allocates tile IDs to ArmSME operations.
-std::unique_ptr<Pass> createTileAllocationPass();
-
 /// Pass that fuses 'arm_sme.outerproduct' ops into 2-way or 4-way widening
 /// variants.
 std::unique_ptr<Pass> createOuterProductFusionPass();
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 7959d291e89267..b9d74fec6756e3 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -124,16 +124,21 @@ def EnableArmStreaming
   let dependentDialects = ["func::FuncDialect"];
 }
 
-def TileAllocation
-    : Pass<"allocate-arm-sme-tiles", "mlir::func::FuncOp"> {
-  let summary = "Allocate SME tiles";
+def TestTileAllocation
+    : Pass<"test-arm-sme-tile-allocation", "mlir::func::FuncOp"> {
+  let summary = "Tests SME tile allocation";
   let description = [{
     This pass does tile allocation for SME "virtual tiles". It is run at the
     'func.func' op level, and assigns tile IDs (via an attribute) to all ops
-    that implement the `ArmSMETileOpInterface`. An error will be emitted when
-    there's no tiles left.
+    that implement the `ArmSMETileOpInterface`. Note: This pass is only intended
+    to be used for testing, tile allocation is done as part of the ArmSME to
+    LLVM conversion.
   }];
-  let constructor = "mlir::arm_sme::createTileAllocationPass()";
+  let options = [
+    Option<"dumpTileLiveRanges", "dump-tile-live-ranges",
+           "bool", /*default=*/"false",
+           "Dump the live ranges of SME tiles (for debugging)">
+  ];
   let dependentDialects = ["func::FuncDialect"];
 }
 
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
index e00c7503e69992..a25b844f01eaa6 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
@@ -9,6 +9,8 @@
 #ifndef MLIR_DIALECT_ARMSME_TRANSFORMS_H
 #define MLIR_DIALECT_ARMSME_TRANSFORMS_H
 
+#include "mlir/Interfaces/FunctionInterfaces.h"
+
 namespace mlir {
 
 class LLVMConversionTarget;
@@ -16,7 +18,14 @@ class LLVMTypeConverter;
 class RewritePatternSet;
 
 namespace arm_sme {
+
 void populateOuterProductFusionPatterns(RewritePatternSet &patterns);
+
+/// Allocate tile IDs to all ArmSME operations in a function. Requires the
+/// function to be lowered to control flow (cf dialect).
+LogicalResult allocateSMETiles(FunctionOpInterface function,
+                               bool dumpRanges = false);
+
 } // namespace arm_sme
 
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 027ad8954f92f5..9ea1c5a5d63fe5 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -16,8 +16,10 @@
 #define MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
 
 #include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include <optional>
 
 namespace mlir {
@@ -42,6 +44,11 @@ bool isValidSMETileElementType(Type type);
 /// otherwise.
 bool isValidSMETileVectorType(VectorType vType);
 
+inline bool isValidSMETileVectorType(Type type) {
+  auto vType = dyn_cast<VectorType>(type);
+  return vType && isValidSMETileVectorType(vType);
+}
+
 /// Returns the type of SME tile this vector type corresponds to, or none if the
 /// vector type does not fit within an SME tile.
 std::optional<ArmSMETileType> getSMETileType(VectorType);
@@ -63,6 +70,19 @@ bool isMultipleOfSMETileVectorType(VectorType vType);
 /// Creates a vector type for the SME tile of `elementType`.
 VectorType getSMETileTypeForElement(Type elementType);
 
+/// Erase trivially dead tile ops from a function.
+void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
+                               FunctionOpInterfa...
[truncated]

Collaborator

@c-rhodes c-rhodes left a comment

Thanks for the patch Ben, this is impressive work and a significant improvement to SME support in MLIR.

I haven't reviewed the whole patch yet but I've left some initial comments for now.

Comment on lines 9 to 11
// This transform allocates SME tiles at the 'func.func' op level for ArmSME
// operations. It does this using a 16-bit tile mask that has a bit for each
// 128-bit element tile (ZA0.Q-ZA15.Q), the smallest ZA tile granule.
Contributor

Please, could we have more documentation in this file? For example, a high-level overview of how tile allocation works (live range calculation, followed by transformation), with some info for every section.
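
For readers unfamiliar with the 16-bit tile mask mentioned in the quoted comment, the idea can be sketched as follows. This is a standalone illustration, not the code from `TileAllocation.cpp`: the `tileMask` and `tilesOverlap` helpers are hypothetical, and the bit order (bit `i` stands for ZA`i`.Q, least significant bit first) is an assumption that may differ from the constants used in the patch. It relies on the fact that a tile with `N` tiles of its element type overlaps every `N`-th 128-bit element tile, starting at its own ID.

```cpp
#include <cassert>
#include <cstdint>

// Returns a 16-bit mask with bit i set if the tile overlaps ZAi.Q.
// `elementBits` is the element width of the tile type: 8 (ZAn.B),
// 16 (ZAn.H), 32 (ZAn.S), 64 (ZAn.D) or 128 (ZAn.Q).
// Assumption for this sketch: bit 0 is ZA0.Q (LSB-first ordering).
inline uint16_t tileMask(unsigned elementBits, unsigned tileId) {
  assert(elementBits == 8 || elementBits == 16 || elementBits == 32 ||
         elementBits == 64 || elementBits == 128);
  // Number of tiles available for this element width: 1, 2, 4, 8 or 16.
  unsigned numTiles = elementBits / 8;
  assert(tileId < numTiles);

  uint16_t mask = 0;
  // Tile `tileId` covers ZA(tileId).Q, ZA(tileId + numTiles).Q, and so on.
  for (unsigned q = tileId; q < 16; q += numTiles)
    mask = static_cast<uint16_t>(mask | (1u << q));
  return mask;
}

// Two tiles conflict if they share at least one 128-bit element tile.
inline bool tilesOverlap(uint16_t maskA, uint16_t maskB) {
  return (maskA & maskB) != 0;
}
```

Under this convention the single 8-bit tile covers all sixteen bits, which is why allocating ZA0.B leaves no room for any other tile, while ZA0.H and ZA1.S can be live at the same time because their masks are disjoint.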

Contributor

@banach-space banach-space left a comment

@MacDue, I'm going over allocateSMETiles and my first reaction is that step one, // 1. Insert copy operations at branch operations., could easily be a separate pass. That would allow for more thorough testing; right now it's either all or nothing. Could you split it out? WDYT?

auto parent = user->getParentOp();
traverseCorrespondingValues(user->getOperands(),
parent->getResults());
rewriter.setInsertionPoint(terminator);
Contributor

What restores the insertion point? Use insertion guard?

Member Author

There's no need to restore the insertion point, as every use of the rewriter sets the insertion point (it would not really make sense to assume it's at any point in particular, given how it's used).

Contributor

I think you are referring to what's happening in this method today. However, this method might get updated, and the same rewriter is also shared with other hooks. Would adding an insertion guard do any harm?
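
To make the trade-off in this thread concrete, here is a toy, standalone C++ sketch of what an insertion guard does. `ToyBuilder`, `ToyInsertionGuard` and `insertCopyBefore` are made up for this example and are not MLIR classes; in MLIR itself the corresponding RAII helper is `OpBuilder::InsertionGuard`. The guard saves the insertion point on construction and restores it on destruction, so the caller's position survives even if the function body changes later.

```cpp
#include <cassert>

// Toy stand-in for a builder/rewriter: the insertion point is just an int.
struct ToyBuilder {
  int insertionPoint = 0;
  void setInsertionPoint(int point) { insertionPoint = point; }
};

// RAII guard: remembers the insertion point and restores it when the guard
// goes out of scope (including on early returns).
class ToyInsertionGuard {
public:
  explicit ToyInsertionGuard(ToyBuilder &builder)
      : builder(builder), saved(builder.insertionPoint) {}
  ~ToyInsertionGuard() { builder.setInsertionPoint(saved); }

  ToyInsertionGuard(const ToyInsertionGuard &) = delete;
  ToyInsertionGuard &operator=(const ToyInsertionGuard &) = delete;

private:
  ToyBuilder &builder;
  int saved;
};

// Hypothetical helper that moves the insertion point to a terminator,
// would create ops there, and leaves the caller's position untouched.
inline void insertCopyBefore(ToyBuilder &builder, int terminator) {
  assert(terminator >= 0);
  ToyInsertionGuard guard(builder);
  builder.setInsertionPoint(terminator);
  // ... ops would be created at `terminator` here ...
}
```

Without the guard, a caller that set its own insertion point before calling `insertCopyBefore` would find it silently moved to the terminator afterwards. That is harmless only while every user resets the insertion point before use, which is the assumption being questioned here.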

Collaborator

@c-rhodes c-rhodes left a comment

LGTM assuming splitCondBranches gets tested. Accepting now before I go on holiday. Once again Ben excellent work! 🙏

@MacDue MacDue force-pushed the new_tile_alloc branch from f7dc1f9 to 54a8e7d Compare May 1, 2024 15:44
Contributor

@nujaa nujaa left a comment

Amazing, thanks.

@@ -1,19 +1,15 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -split-input-file -verify-diagnostics
+// RUN: mlir-opt %s -convert-scf-to-cf -test-arm-sme-tile-allocation -verify-diagnostics
Contributor

If the goal is to test tile allocation, maybe we should use post scf-to-cf IR as the input directly, instead of also relying on -convert-scf-to-cf?

Member Author

I agree that having tests in CF too would be good (since this is an algorithm on CFGs). I've used SCF in these tests as it's easier to write/read (but cases like this one, which is just a branch, can easily be written in CF).

@MacDue MacDue force-pushed the new_tile_alloc branch from 1d9c791 to fc96a7a Compare May 8, 2024 10:39
@MacDue MacDue force-pushed the new_tile_alloc branch 2 times, most recently from abd40ec to a91685a Compare May 9, 2024 16:10
MacDue added 16 commits May 13, 2024 10:04
Review fixups, plus fixes for two bugs that could occur if the spilled tile
type did not match the tile type of the new live range. These are now covered
by two new test cases.
Contributor

@banach-space banach-space left a comment

Impressive work Ben, thank you for addressing all my comments! 🙏🏻

This is a very well designed and complex bit of logic, but absolutely necessary and 100% the right solution for SME virtual tile allocation. This is probably the most important building block for supporting SME in MLIR, and a vital element in future-proofing this work. Thank you so much for seeing this through!

-Andrzej

@MacDue MacDue merged commit 041baf2 into llvm:main May 14, 2024
3 of 4 checks passed
MacDue added a commit that referenced this pull request Aug 7, 2024
…ges (#102125)

- Use vector.interleave rather than the LLVM intrinsic
- Remove dependency on LLVM dialect
- Remove manual outerproduct erases (these are now trivially dead)
- Remove comment explaining issues with previous tile allocator
- Update pipeline in `multi-tile-matmul-mixed-types.mlir`

Recent changes: #90448, #80965
5 participants