diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td index 9b9dbff10ea2d..b30d0fdb866bd 100644 --- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td @@ -231,7 +231,26 @@ def ZeroOp : ArmSME_Op<"zero", [Pure]> { let assemblyFormat = "attr-dict `:` type($res)"; } -def TileLoadOp : ArmSME_Op<"tile_load"> { +def TileLoadOp : ArmSME_Op<"tile_load", [ + AttrSizedOperandSegments, + OptionalTypesMatchWith< + "padding type matches element type of result", + "result", "padding", + "::llvm::cast($_self).getElementType()" + >, + OptionalTypesMatchWith< + "mask has i1 element type and same shape as result", + "result", "mask", + "VectorType(" + "VectorType::Builder(" + "::llvm::cast($_self)" + ").setElementType(IntegerType::get($_self.getContext(), 1)))" + >, + PredOpTrait< + "both `padding` and `mask` should be provided or neither", + CPred<"bool(getPadding()) == bool(getMask())"> + >, +]> { let summary = "Tile load operation"; let description = [{ Loads a 2D SME "virtual tile" from memory defined by a base and indices, @@ -242,6 +261,16 @@ def TileLoadOp : ArmSME_Op<"tile_load"> { dimensions, since the operation is scalable, and the element type must be a scalar that matches the element type of the result. + An optional SSA value `padding` of the same elemental type as the MemRef is + provided to specify a fallback value in the case of masking. + + An optional SSA value `mask` may be specified to mask out elements read + from the MemRef. The `mask` type is an `i1` vector with a shape that + matches how elements are read from the MemRef. Elements whose corresponding + mask element is `0` are masked out and replaced with `padding`. + + If either `padding` or `mask` is specified, both must be specified. + Example 1: Load an 8-bit element ZA tile with horizontal layout (default) from memory (ZA0.B). 
```mlir %tile = arm_sme.tile_load %base[%c0, %c0] : memref, vector<[16]x[16]xi8> @@ -256,10 +285,16 @@ def TileLoadOp : ArmSME_Op<"tile_load"> { ```mlir %tile = arm_sme.tile_load %base[%c0, %c0] layout : memref, vector<[1]x[1]xi128> ``` + + Example 4: Masked load of a 32-bit floating-point element ZA tile with horizontal layout (default) from memory. + ```mlir + %tile = arm_sme.tile_load %base[%c0, %c0], %pad, %mask : memref, vector<[4]x[4]xf32> + ``` }]; let arguments = (ins Arg:$base, Variadic:$indices, + Optional:$padding, Optional:$mask, ArmSME_TileSliceLayoutAttr:$layout ); let results = (outs SMETile:$result); @@ -273,9 +308,20 @@ def TileLoadOp : ArmSME_Op<"tile_load"> { } }]; + let builders = [ + OpBuilder<(ins "VectorType":$resultType, "Value":$base, + "ValueRange":$indices, "TileSliceLayout":$layout), [{ + build($_builder, $_state, resultType, base, indices, {}, {}, layout); + }]>, + OpBuilder<(ins "VectorType":$resultType, "Value":$base, + "ValueRange":$indices), [{ + build($_builder, $_state, resultType, base, indices, {}, {}, {}); + }]>, + ]; + let assemblyFormat = - "$base `[` $indices `]` (`layout` `` $layout^)? attr-dict " - "`:` type($base) `,` type($result)"; + "$base `[` $indices `]` (`,` $padding `,` $mask^)? (`layout` `` $layout^)?" 
+ "attr-dict `:` type($base) `,` type($result)"; } def TileStoreOp : ArmSME_Op<"tile_store"> { diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir index 431009b1b9ede..25c62f78d8435 100644 --- a/mlir/test/Dialect/ArmSME/invalid.mlir +++ b/mlir/test/Dialect/ArmSME/invalid.mlir @@ -1,5 +1,9 @@ // RUN: mlir-opt %s -split-input-file -verify-diagnostics +//===----------------------------------------------------------------------===// +// arm_sme.cast_tile_to_vector +//===----------------------------------------------------------------------===// + // ----- func.func @arm_sme_cast_tile_to_vector__bad_tile_id_bitwidth(%tile_id : i8) -> vector<[8]x[8]xi16> { @@ -48,6 +52,10 @@ func.func @arm_sme_cast_tile_to_vector_bad_shape(%tile_id : i8) -> vector<[4]x[1 return %0 : vector<[4]x[16]xi8> } +//===----------------------------------------------------------------------===// +// arm_sme.cast_vector_to_tile +//===----------------------------------------------------------------------===// + // ----- func.func @arm_sme_cast_vector_to_tile__bad_tile_id_bitwidth(%vector : vector<[1]x[1]xi128>) -> i32 { @@ -64,6 +72,10 @@ func.func @arm_sme_cast_vector_to_tile__bad_rank_1d(%vector : vector<[16]xi8>) - return %0 : i8 } +//===----------------------------------------------------------------------===// +// arm_sme.get_tile_id +//===----------------------------------------------------------------------===// + // ----- func.func @arm_sme_get_tile_id__bad_type() -> i1 { @@ -72,6 +84,10 @@ func.func @arm_sme_get_tile_id__bad_type() -> i1 { return %0 : i1 } +//===----------------------------------------------------------------------===// +// arm_sme.move_vector_to_tile_slice +//===----------------------------------------------------------------------===// + // ----- func.func @arm_sme_move_vector_to_tile_slice_i8__bad_vector_type(%vector : vector<[8]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> vector<[16]x[16]xi8> { @@ -90,6 +106,10 
@@ func.func @arm_sme_move_vector_to_tile_slice_f32__bad_vector_type(%vector : vect return %0 : vector<[4]x[4]xf32> } +//===----------------------------------------------------------------------===// +// arm_sme.move_tile_slice_to_vector +//===----------------------------------------------------------------------===// + // ----- func.func @arm_sme_move_tile_slice_to_vector__bad_result_type(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[2]xf64> { @@ -97,3 +117,36 @@ func.func @arm_sme_move_tile_slice_to_vector__bad_result_type(%tile : vector<[4] %0 = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[4]x[4]xf32> return %0 : vector<[2]xf64> } + +//===----------------------------------------------------------------------===// +// arm_sme.tile_load +//===----------------------------------------------------------------------===// + +// ----- + +func.func @arm_sme_tile_load__bad_padding_type(%src : memref, %pad : f32, %mask : vector<[2]x[2]xi1>) { + %c0 = arith.constant 0 : index + // expected-note@-2 {{prior use here}} + // expected-error@+1 {{use of value '%pad' expects different type than prior uses: 'f64' vs 'f32'}} + %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref, vector<[2]x[2]xf64> + return +} + +// ----- + +func.func @arm_sme_tile_load__bad_mask_type(%src : memref, %pad : f64, %mask : vector<[4]x[4]xi1>) { + %c0 = arith.constant 0 : index + // expected-note@-2 {{prior use here}} + // expected-error@+1 {{use of value '%mask' expects different type than prior uses: 'vector<[2]x[2]xi1>' vs 'vector<[4]x[4]xi1>'}} + %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref, vector<[2]x[2]xf64> + return +} + +// ----- + +func.func @arm_sme_tile_load__pad_but_no_mask(%src : memref, %pad : f64) { + %c0 = arith.constant 0 : index + // expected-error@+1 {{op failed to verify that both `padding` and `mask` should be provided or neither}} + %tile = arm_sme.tile_load %src[%c0, %c0], %pad, : memref, 
vector<[2]x[2]xf64> + return +} diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir index e5ba81eff8360..6866137267dc6 100644 --- a/mlir/test/Dialect/ArmSME/roundtrip.mlir +++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir @@ -438,6 +438,16 @@ func.func @arm_sme_tile_load_ver_f64(%src : memref) { // ----- +/// Padding and mask are optional +func.func @arm_sme_tile_load_hor_pad_f64(%src : memref, %pad : f64, %mask : vector<[2]x[2]xi1>) { + // CHECK: arm_sme.tile_load %{{.*}}[{{.*}}], {{.*}}, {{.*}} : memref, vector<[2]x[2]xf64> + %c0 = arith.constant 0 : index + %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref, vector<[2]x[2]xf64> + return +} + +// ----- + /// Layout is optional and horizontal is the default, verify it's still parsed. func.func @arm_sme_tile_load_explicit_hor(%src : memref) { // CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref, vector<[16]x[16]xi8>