Skip to content
Closed
178 changes: 143 additions & 35 deletions mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,24 @@ def ZeroOp : ArmSME_Op<"zero", [Pure]> {
let assemblyFormat = "attr-dict `:` type($res)";
}

def TileLoadOp : ArmSME_Op<"tile_load"> {
def TileLoadOp : ArmSME_Op<"tile_load", [
AttrSizedOperandSegments,
TypesMatchWith<
"padding type matches element type of result (if present)",
"result", "padding",
"::llvm::cast<VectorType>($_self).getElementType()",
"!getPadding() || std::equal_to<>()"
>,
TypesMatchWith<
"mask has i1 element type and same shape as result (if present)",
"result", "mask",
"VectorType("
"VectorType::Builder("
"::llvm::cast<mlir::VectorType>($_self)"
").setElementType(IntegerType::get($_self.getContext(), 1)))",
"!getMask() || std::equal_to<>()"
>
]> {
let summary = "Tile load operation";
let description = [{
Loads a 2D SME "virtual tile" from memory defined by a base and indices,
Expand All @@ -242,6 +259,16 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
dimensions, since the operation is scalable, and the element type must be a
scalar that matches the element type of the result.

An optional SSA value `padding` of the same elemental type as the MemRef is
provided to specify a fallback value in the case of masking.

An optional SSA value `mask` may be specified to mask out elements read
from the MemRef. The `mask` type is an `i1` vector with a shape that
matches how elements are read from the MemRef. Elements whose corresponding
mask element is `0` are masked out and replaced with `padding`.

If either `padding` or `mask` are specified, both must be specified.

Example 1: Load an 8-bit element ZA tile with horizontal layout (default) from memory (ZA0.B).
```mlir
%tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
Expand All @@ -256,10 +283,16 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
```mlir
%tile = arm_sme.tile_load %base[%c0, %c0] layout<horizontal> : memref<?x?xi128>, vector<[1]x[1]xi128>
```

Example 4: Masked load of int 32-bit element ZA tile with horizontal layout (default) from memory.
```mlir
%tile = arm_sme.tile_load %base[%c0, %c0], %pad, %mask : memref<?x?xf32>, vector<[4]x[4]xf32>
```
}];
let arguments = (ins
Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
Variadic<Index>:$indices,
Optional<AnyType>:$padding, Optional<AnyVector>:$mask,
ArmSME_TileSliceLayoutAttr:$layout
);
let results = (outs SMETile:$result);
Expand All @@ -273,12 +306,34 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
}
}];

let builders = [
OpBuilder<(ins "VectorType":$resultType, "Value":$base,
"ValueRange":$indices, "TileSliceLayout":$layout), [{
build($_builder, $_state, resultType, base, indices, {}, {}, layout);
}]>,
OpBuilder<(ins "VectorType":$resultType, "Value":$base,
"ValueRange":$indices), [{
build($_builder, $_state, resultType, base, indices, {}, {}, {});
}]>,
];

let assemblyFormat =
"$base `[` $indices `]` (`layout` `` $layout^)? attr-dict "
"`:` type($base) `,` type($result)";
"$base `[` $indices `]` (`,` $padding `,` $mask^)? (`layout` `` $layout^)?"
"attr-dict `:` type($base) `,` type($result)";
}

def TileStoreOp : ArmSME_Op<"tile_store"> {
def TileStoreOp : ArmSME_Op<"tile_store", [
AttrSizedOperandSegments,
TypesMatchWith<
"mask has i1 element type and same shape as value to store (if present)",
"valueToStore", "mask",
"VectorType("
"VectorType::Builder("
"::llvm::cast<mlir::VectorType>($_self)"
").setElementType(IntegerType::get($_self.getContext(), 1)))",
"!getMask() || std::equal_to<>()"
>
]> {
let summary = "Tile store operation";
let description = [{
Stores a 2D SME "virtual tile" to memory defined by a base and indices,
Expand All @@ -289,6 +344,11 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
rank 2 with dynamic dimensions, since the operation is scalable, and the
element type must be a scalar that matches the element type of the result.

An optional SSA value `mask` may be specified to mask out elements written
to the MemRef. The `mask` type is an `i1` vector of the same shape as the
vector type that matches how elements are written into the MemRef. Elements
whose corresponding mask element is `0` are masked out.

Example 1: Store an 8-bit element ZA tile with horizontal (default) layout to memory (ZA0.B).
```mlir
arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
Expand All @@ -303,10 +363,16 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
```mlir
arm_sme.tile_store %tile, %base[%c0, %c0] layout<horizontal> : vector<[1]x[1]xi128>, memref<?x?xi128>
```

Example 4: Masked store a int 32-bit element ZA tile with vertical layout to memory.
```mlir
arm_sme.tile_store %tile, %base[%c0, %c0], %mask layout<vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
```
}];
let arguments = (ins SMETile:$valueToStore,
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
Variadic<Index>:$indices, ArmSME_TileSliceLayoutAttr:$layout
Variadic<Index>:$indices, Optional<AnyVector>:$mask,
ArmSME_TileSliceLayoutAttr:$layout
);
let extraClassDeclaration = [{
MemRefType getMemRefType() {
Expand All @@ -317,13 +383,28 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
}
}];

let builders = [
OpBuilder<(ins "Value":$valueToStore, "Value":$base,
"ValueRange":$indices), [{
build($_builder, $_state, valueToStore, base, indices, {});
}]>,
];

let assemblyFormat =
"$valueToStore `,` $base `[` $indices `]` (`layout` `` $layout^)? attr-dict "
"`:` type($base) `,` type($valueToStore)";
"$valueToStore `,` $base `[` $indices `]` (`,` $mask^)? (`layout` `` $layout^)?"
"attr-dict `:` type($base) `,` type($valueToStore)";
}

def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
AllTypesMatch<["tile", "result"]>
AllTypesMatch<["tile", "result"]>,
TypesMatchWith<
"mask has i1 element type and same shape as result",
"result", "mask",
"VectorType("
"VectorType::Builder("
"::llvm::cast<mlir::VectorType>($_self)"
").dropDim(0).setElementType(IntegerType::get($_self.getContext(), 1))"
")">,
]> {
let summary = "Tile slice load and update operation";
let description = [{
Expand All @@ -339,23 +420,27 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
dimensions since the operation is scalable, and the element type must be a
scalar that matches the element type of the result.

An SSA value `mask` specifies to mask out elements read from the MemRef.
The `mask` type is an `i1` vector with a shape that matches how elements
are read from the MemRef.

Example 1: Load a vector<[16]xi8> tile slice from memory into tile horizontally (default) at given index.
```mlir
%tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
%tile_update = arm_sme.load_tile_slice %base[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
```

Example 2: Load a vector<[4]xf32> tile slice from memory into tile vertically at given index.
```mlir
%tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
%tile_update = arm_sme.load_tile_slice %base[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
```

Example 3: Load a vector<[1]xi128> tile slice from memory into tile vertically at given index.
```mlir
%tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
%tile_update = arm_sme.load_tile_slice %base[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
```
}];
let arguments = (ins
Arg<AnyMemRef, "the reference to load from">:$base,
Arg<AnyMemRef, "the reference to load from">:$base, AnyVector:$mask,
SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index,
ArmSME_TileSliceLayoutAttr:$layout
);
Expand All @@ -371,12 +456,22 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
}];

let assemblyFormat = [{
$base `[` $indices `]` `,` $tile `,` $tile_slice_index (`layout` `` $layout^)?
attr-dict `:` type($base) `,` type($result)
$base `[` $indices `]` `,` $mask `,` $tile `,` $tile_slice_index
(`layout` `` $layout^)? attr-dict `:` type($base) `,` type($mask) `,`
type($result)
}];
}

def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [
TypesMatchWith<
"mask has i1 element type and same shape as tile slice",
"tile", "mask",
"VectorType("
"VectorType::Builder("
"::llvm::cast<mlir::VectorType>($_self)"
").dropDim(0).setElementType(IntegerType::get($_self.getContext(), 1))"
")">
]> {
let summary = "Tile slice store operation";
let description = [{
Stores a 1D tile slice from a 2D SME "virtual tile" into memory. The tile
Expand All @@ -391,22 +486,27 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
dimensions since the operation is scalable, and the element type must be a
scalar that matches the element type of the input tile.

An SSA value `mask` specifies to mask out elements written to the MemRef.
The `mask` type is an `i1` vector with a shape that matches how elements
are written to the MemRef.

Example 1: Store vector<[16]xi8> horizontal (default) tile slice from tile at given index to memory.
```mlir
arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %base[%c0] : vector<[16]x[16]xi8>, vector<[16]xi1>, memref<?x?xi8>
```

Example 2: Store vector<[4]xf32> vertical tile slice from tile at given index to memory.
```mlir
arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] layout<vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %base[%c0] layout<vertical> : vector<[4]x[4]xf32>, vector<[4]xi1>, memref<?x?xf32>
```

Example 3: Store a vector<[1]xi128> vertical tile slice from tile at given index to memory.
```mlir
arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] layout<vertical> : vector<[1]x[1]xi128>, memref<?x?xi128>
arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %base[%c0] layout<vertical> : vector<[1]x[1]xi128>, vector<[1]xi1>, memref<?x?xi128>
```
}];
let arguments = (ins SMETile:$tile, Index:$tile_slice_index,
let arguments = (ins
SMETile:$tile, Index:$tile_slice_index, AnyVector:$mask,
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
Variadic<Index>:$indices, ArmSME_TileSliceLayoutAttr:$layout
);
Expand All @@ -420,8 +520,8 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
}];

let assemblyFormat = [{
$tile `,` $tile_slice_index `,` $base `[` $indices `]` (`layout` `` $layout^)?
attr-dict `:` type($base) `,` type($tile)
$tile `,` $tile_slice_index `,` $mask `,` $base `[` $indices `]` (`layout` `` $layout^)?
attr-dict `:` type($base) `,` type($mask) `,` type($tile)
}];
}

Expand All @@ -441,21 +541,24 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
of a 2-D scalable vector tile at the given index. The type of the 1-D
scalable vector to be moved must match the type of the tile slice. A tile
slice is a 1-D vector of horizontally or vertically contiguous elements
within a ZA tile. Horizontal tile slices are currently assumed when
lowering to intrinsics. The updated tile is returned as the result.
within a ZA tile. The updated tile is returned as the result.

An optional tile slice layout attribute specifies whether the tile slice is
horizontal (default) or vertical.

Example 1: Move a vector<[16]xi8> into tile at given index.
Example 1: Move a vector<[16]xi8> into tile horizontally (default) at given index.
```mlir
%tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[16]xi8> into vector<[16]x[16]xi8>
```

Example 2: Move a vector<[2]xf64> into tile at given index.
Example 2: Move a vector<[2]xf64> into tile vertically at given index.
```mlir
%tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[2]xf64> into vector<[2]x[2]xf64>
%tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[2]xf64> into vector<[2]x[2]xf64>
```
}];
let arguments = (ins
SVEVector:$vector, SMETile:$tile, Index:$tile_slice_index);
SVEVector:$vector, SMETile:$tile, Index:$tile_slice_index,
ArmSME_TileSliceLayoutAttr:$layout);
let results = (outs SMETile:$result);

let extraClassDeclaration = [{
Expand All @@ -465,7 +568,7 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
}];

let assemblyFormat = [{
$vector `,` $tile `,` $tile_slice_index
$vector `,` $tile `,` $tile_slice_index (`layout` `` $layout^)?
attr-dict `:` type($vector) `into` type($result)
}];
}
Expand All @@ -480,29 +583,34 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
let description = [{
The tile slice to vector operation extracts a 1-D scalable slice from a 2-D
scalable tile at the given index. A tile slice is a 1-D vector of
horizontally or vertically contiguous elements within a ZA tile. Horizontal
tile slices are currently assumed when lowering to intrinsics.
horizontally or vertically contiguous elements within a ZA tile.

An optional tile slice layout attribute specifies whether the tile slice is
horizontal (default) or vertical.

Example 1: Extract `vector<[16]xi8>` from tile at the given index.
Example 1: Extract `vector<[16]xi8>` from tile horizontally at the given index.
```mlir
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[16]xi8> from vector<[16]x[16]xi8>
```

Example 2: Extract `vector<[2]xf64>` from tile at the given index.
Example 2: Extract `vector<[2]xf64>` from tile vertically at the given index.
```mlir
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[2]xf64> from vector<[2]x[2]xf64>
```
}];

let arguments = (ins SMETile:$tile, Index:$tile_slice_index);
let arguments = (ins
SMETile:$tile, Index:$tile_slice_index,
ArmSME_TileSliceLayoutAttr:$layout
);
let results = (outs SVEVector:$result);

let extraClassDeclaration = [{
VectorType getSliceType() { return getResult().getType(); }
}];

let assemblyFormat = [{
$tile `[` $tile_slice_index `]` attr-dict
$tile `[` $tile_slice_index `]` (`layout` `` $layout^)? attr-dict
`:` type($result) `from` type($tile)
}];
}
Expand Down
Loading