Skip to content

[mlir][ArmSME] Use liveness information in the tile allocator #90448

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <memory>

#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Interfaces/FunctionInterfaces.h"

namespace mlir {
class Pass;
Expand All @@ -21,7 +22,8 @@ class RewritePatternSet;
#include "mlir/Conversion/Passes.h.inc"

/// Create a pass to convert from the ArmSME dialect to LLVM intrinsics.
std::unique_ptr<Pass> createConvertArmSMEToLLVMPass();
std::unique_ptr<Pass>
createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges = false);

/// Configure target to convert from the ArmSME dialect to LLVM intrinsics.
void configureArmSMEToLLVMConversionLegality(ConversionTarget &target);
Expand Down
7 changes: 6 additions & 1 deletion mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -1285,14 +1285,19 @@ def ConvertArmSMEToSCF : Pass<"convert-arm-sme-to-scf"> {
// ArmSMEToLLVM
//===----------------------------------------------------------------------===//

def ConvertArmSMEToLLVM : Pass<"convert-arm-sme-to-llvm"> {
def ConvertArmSMEToLLVM : InterfacePass<"convert-arm-sme-to-llvm", "FunctionOpInterface"> {
let summary = "Lower the operations from the ArmSME dialect into the LLVM "
"dialect";
let constructor = "mlir::createConvertArmSMEToLLVMPass()";
let dependentDialects = [
"arm_sme::ArmSMEDialect",
"LLVM::LLVMDialect"
];
let options = [
Option<"dumpTileLiveRanges", "dump-tile-live-ranges",
"bool", /*default=*/"false",
"Dump the live ranges of SME tiles (for debugging)">
];
}

//===----------------------------------------------------------------------===//
Expand Down
6 changes: 1 addition & 5 deletions mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h"
#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
Expand All @@ -24,11 +25,6 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

namespace mlir::arm_sme {
static constexpr unsigned kInMemoryTileIdBase = 16;
#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc"
} // namespace mlir::arm_sme

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.h.inc"

Expand Down
27 changes: 27 additions & 0 deletions mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
//===- ArmSMEOpInterfaces.h - Arm SME Dialect OpInterfaces ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_ARMSME_OPINTERFACES_H
#define MLIR_DIALECT_ARMSME_OPINTERFACES_H

#include "mlir/Dialect/Vector/IR/VectorOps.h"

namespace mlir::arm_sme {

namespace detail {
LogicalResult verifyArmSMETileOpInterface(Operation *);
}

// The first in-memory SME tile ID. This is set to 16 as that is the first tile
// ID larger than any virtual tile ID supported by the SME ISA.
static constexpr unsigned kInMemoryTileIdBase = 16;

#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc"
} // namespace mlir::arm_sme

#endif // MLIR_DIALECT_ARMSME_OPINTERFACES_H
141 changes: 47 additions & 94 deletions mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ def ArmSMETileType : I32EnumAttr<"ArmSMETileType", "Arm SME tile type",

def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
let description = [{
An interface for operations that use or allocate Arm SME tiles. These
operations need to be assigned a tile ID, an i32 attribute, which specifies
which virtual tile within the ZA storage to use. The number of tiles
available depends on the type of the tile. This is summarized below:
An interface for operations that use Arm SME tiles. These operations need to
be assigned a tile ID, an i32 attribute, which specifies which virtual tile
within the ZA storage to use. The number of tiles available depends on the
type of the tile. This is summarized below:

| Tile Vector Types | Possible Tile IDs |
|-------------------------------------------------------------------------|---------------------|
Expand All @@ -51,10 +51,6 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
| `vector<[4]x[4]xi32>` or `vector<[4]x[4]xf32>` | 0 to 3 (inclusive) |
| `vector<[2]x[2]xi64>` or `vector<[2]x[2]xf64>` | 0 to 7 (inclusive) |
| `vector<[1]x[1]xi128>` | 0 to 15 (inclusive) |

Operations that allocate a new tile (such as arm_sme.get_tile), are used as
the roots for tile allocation, with all operations that (transitively)
depend on a root being assigned the same tile ID.
}];
let methods = [
InterfaceMethod<
Expand Down Expand Up @@ -84,20 +80,6 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
return op->getAttrOfType<mlir::IntegerAttr>("tile_id");
}]
>,
InterfaceMethod<
[{
The type of tile this operation allocates. Returns none (std::nullopt)
if this operation does not allocate a tile.
}],
/*returnType=*/"std::optional<::mlir::arm_sme::ArmSMETileType>",
/*methodName=*/"getAllocatedTileType",
/*arguments=*/(ins),
/*methodBody=*/[{}],
/*defaultImpl=*/ [{
// This operation does not allocate a tile.
return std::nullopt;
}]
>,
InterfaceMethod<
"Returns the VectorType of the tile used by this operation.",
/*returnType=*/"VectorType",
Expand All @@ -106,30 +88,13 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
];

let extraSharedClassDeclaration = [{
// A helper to create a new operation and propagate this operations tile ID.
template<typename T, typename... Args>
T createOpAndForwardTileId(::mlir::RewriterBase& rewriter, ::mlir::Location loc, Args &&...args) {
auto op = rewriter.create<T>(loc, std::forward<Args>(args)...);
if (auto tileOp = ::llvm::dyn_cast<ArmSMETileOpInterface>(op.getOperation()))
tileOp.setTileId($_op.getTileId());
return op;
}

// A helper to replace this operation and forward its tile ID (if present).
template<typename T, typename... Args>
T replaceWithAndForwardTileId(::mlir::RewriterBase& rewriter, Args &&...args) {
auto newOp = createOpAndForwardTileId<T>(rewriter, $_op.getLoc(), std::forward<Args>(args)...);
rewriter.replaceOp($_op, newOp);
return newOp;
}

bool isInMemoryTile() {
auto tileId = getTileId();
return tileId && tileId.getInt() >= kInMemoryTileIdBase;
}
}];

let verify = [{ return ::mlir::arm_sme::verifyOperationHasValidTileId($_op); }];
let verify = [{ return detail::verifyArmSMETileOpInterface($_op); }];
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -255,30 +220,30 @@ def ArmSME_TypeSizeAttr : EnumAttr<ArmSME_Dialect, TypeSize,
class ArmSME_Op<string mnemonic, list<Trait> traits = []> :
Op<ArmSME_Dialect, mnemonic, traits> {}

def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface]> {
let summary = "Returns a SME virtual tile";
def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface, Pure]> {
let summary = "Creates an undefined value of SME virtual tile type";
let description = [{
Allocates a new SME "virtual tile" within a function. The contents of the
tile returned from this operation are undefined.
Creates a new SME "virtual tile" value within a function. The contents of
the tile returned from this operation are undefined.

Example 1:

```mlir
// Allocate an 8-bit element "virtual tile"
// Create an 8-bit element "virtual tile" value:
%za0_b = arm_sme.get_tile: vector<[16]x[16]xi8>
```

Example 2:

```mlir
// Allocate two 16-bit element "virtual tiles"
// Create two 16-bit element "virtual tiles" values:
%za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
%za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
```

Example 3:
```mlir
// Allocate an 128-bit element "virtual tile"
// Create an 128-bit element "virtual tile" value:
%za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
```
}];
Expand All @@ -290,37 +255,15 @@ def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface]> {
VectorType getTileType() {
return ::llvm::cast<VectorType>(getTile().getType());
}

std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
return arm_sme::getSMETileType(getTileType());
}
}];
}

def MaterializeSSATileOp : ArmSME_Op<"materialize_ssa_tile", [Pure]> {
let summary = "SME tile placeholder";
let description = [{
A placeholder to preserve dataflow while lowering to SME intrinsics (which
do not take or return SME virtual tile values). This operation is intended
to be DCE'd once all ArmSME operations have been lowered.

This operation is not intended to be used outside of the ArmSME -> LLVM
conversion.
}];
let results = (outs SMETile:$tile);
let assemblyFormat = "attr-dict `:` type($tile)";
}

//
// Tile reset.
//

def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
let summary = "Initialize the two-dimensional ZA array with 0s";
def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface, Pure]> {
let summary = "Creates a zero-initialized value of SME virtual tile type";
let results = (outs SMETile:$res);
let description = [{
Initialise ZA with 0. This operation is convenient wrapper for the SME
`zero` intrinsic and instruction.
Creates a new SME "virtual tile" value within a function. The contents of
the tile returned from this operation are zero-initialized.

Example 1: Zero an 8-bit element ZA tile.

Expand All @@ -338,16 +281,39 @@ def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
VectorType getVectorType() {
return ::llvm::cast<VectorType>(getRes().getType());
}
std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
return arm_sme::getSMETileType(getVectorType());
}
VectorType getTileType() {
return getVectorType();
}
}];
let assemblyFormat = "attr-dict `:` type($res)";
}

def CopyTileOp : ArmSME_Op<"copy_tile", [
Pure,
ArmSMETileOpInterface,
AllTypesMatch<["tile", "result"]>
]> {
let summary = "Copies an SME tile value";
let arguments = (ins SMETile:$tile);
let results = (outs SMETile:$result);
let description = [{
Copies an SME "virtual tile" value to a new SSA value. This operation is
primarily intended to be used to normalize the IR prior to tile allocation.

Example:

```mlir
%copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
```
}];
let extraClassDeclaration = [{
VectorType getTileType() {
return ::llvm::cast<VectorType>(getResult().getType());
}
}];
let assemblyFormat = "$tile attr-dict `:` type($result)";
}

def TileLoadOp : ArmSME_Op<"tile_load", [
ArmSMETileOpInterface,
AttrSizedOperandSegments,
Expand Down Expand Up @@ -417,9 +383,6 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
VectorType getVectorType() {
return ::llvm::cast<VectorType>(getResult().getType());
}
std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
return arm_sme::getSMETileType(getVectorType());
}
VectorType getTileType() {
return getVectorType();
}
Expand Down Expand Up @@ -545,7 +508,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
```
}];
let arguments = (ins
Arg<AnyMemRef, "the reference to load from">:$base, SVEPredicate:$mask,
Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base, SVEPredicate:$mask,
SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index,
ArmSME_TileSliceLayoutAttr:$layout
);
Expand Down Expand Up @@ -630,7 +593,7 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [
}

def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
ArmSMETileOpInterface,
ArmSMETileOpInterface, Pure,
AllTypesMatch<["tile", "result"]>,
TypesMatchWith<
"type of 'vector' matches type of 'tile' slice",
Expand Down Expand Up @@ -679,7 +642,7 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
}

def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [
ArmSMETileOpInterface,
ArmSMETileOpInterface, Pure,
TypesMatchWith<
"type of 'result' matches type of 'tile' slice",
"tile", "result",
Expand Down Expand Up @@ -736,6 +699,7 @@ class OuterProductResultTileTypeConstraint<string operand> :

def OuterProductOp :
ArmSME_Op<"outerproduct", [
Pure,
ArmSMETileOpInterface,
AttrSizedOperandSegments,
AllTypesMatch<["lhs", "rhs"]>,
Expand Down Expand Up @@ -802,12 +766,6 @@ let arguments = (ins
VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
// The outerproduct op allocates a new tile if no accumulator is passed.
if (!getAcc())
return arm_sme::getSMETileType(getResultType());
return std::nullopt;
}
VectorType getTileType() {
return getResultType();
}
Expand All @@ -819,6 +777,7 @@ class OuterProductWideningBase<string mnemonic,
list<Type> allowedResultVectorTypes,
int numOuterProducts> :
ArmSME_Op<mnemonic, [
Pure,
ArmSMETileOpInterface,
AttrSizedOperandSegments,
AllTypesMatch<["lhs", "rhs"]>,
Expand Down Expand Up @@ -857,12 +816,6 @@ class OuterProductWideningBase<string mnemonic,
VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
// The outerproduct op allocates a new tile if no accumulator is passed.
if (!getAcc())
return arm_sme::getSMETileType(getResultType());
return std::nullopt;
}
VectorType getTileType() {
return getResultType();
}
Expand Down
3 changes: 0 additions & 3 deletions mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,6 @@ std::unique_ptr<Pass> createEnableArmStreamingPass(
const ArmStreamingMode = ArmStreamingMode::Streaming,
const ArmZaMode = ArmZaMode::Disabled, bool onlyIfRequiredByOps = false);

/// Pass that allocates tile IDs to ArmSME operations.
std::unique_ptr<Pass> createTileAllocationPass();

/// Pass that fuses 'arm_sme.outerproduct' ops into 2-way or 4-way widening
/// variants.
std::unique_ptr<Pass> createOuterProductFusionPass();
Expand Down
Loading
Loading