diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td index 37a2257a0015c..8a34472d3b3a2 100644 --- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td @@ -60,6 +60,12 @@ def TileElementWidthMatchesTileID : TypesMatchWith< "::llvm::cast($_self).getElementType())" ".getWidth())">; +class HasMatchingMaskTypeConstraint : + OptionalTypesMatchWith< + mask # " has i1 element type and same shape as " # vector, + vector, mask, + "::llvm::cast($_self).cloneWith({}, IntegerType::get($_ctxt, 1))">; + //===----------------------------------------------------------------------===// // ArmSME attr definitions //===----------------------------------------------------------------------===// @@ -259,14 +265,7 @@ def TileLoadOp : ArmSME_Op<"tile_load", [ "result", "padding", "::llvm::cast($_self).getElementType()" >, - OptionalTypesMatchWith< - "mask has i1 element type and same shape as result", - "result", "mask", - "VectorType(" - "VectorType::Builder(" - "::llvm::cast($_self)" - ").setElementType(IntegerType::get($_self.getContext(), 1)))" - >, + HasMatchingMaskTypeConstraint<"result", "mask">, PredOpTrait< "both `padding` and `mask` should be provided or neither", CPred<"bool(getPadding()) == bool(getMask())"> @@ -345,7 +344,10 @@ def TileLoadOp : ArmSME_Op<"tile_load", [ "attr-dict `:` type($base) `,` type($result)"; } -def TileStoreOp : ArmSME_Op<"tile_store"> { +def TileStoreOp : ArmSME_Op<"tile_store", [ + AttrSizedOperandSegments, + HasMatchingMaskTypeConstraint<"valueToStore", "mask">, +]> { let summary = "Tile store operation"; let description = [{ Stores a 2D SME "virtual tile" to memory defined by a base and indices, @@ -356,6 +358,9 @@ def TileStoreOp : ArmSME_Op<"tile_store"> { rank 2 with dynamic dimensions, since the operation is scalable, and the element type must be a scalar that matches the element type of the result. + An optional `mask` may be provided, the shape of which corresponds to the + `tile`, and selects which elements of the tile will be stored. + Example 1: Store an 8-bit element ZA tile with horizontal (default) layout to memory (ZA0.B). ```mlir arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[16]x[16]xi8>, memref @@ -370,10 +375,16 @@ def TileStoreOp : ArmSME_Op<"tile_store"> { ```mlir arm_sme.tile_store %tile, %base[%c0, %c0] layout : vector<[1]x[1]xi128>, memref ``` + + Example 4: Masked store a int 32-bit element ZA tile with vertical layout to memory. + ```mlir + arm_sme.tile_store %tile, %base[%c0, %c0], %mask layout : vector<[4]x[4]xf32>, memref + ``` }]; let arguments = (ins SMETile:$valueToStore, Arg:$base, - Variadic:$indices, ArmSME_TileSliceLayoutAttr:$layout + Variadic:$indices, Optional:$mask, + ArmSME_TileSliceLayoutAttr:$layout ); let extraClassDeclaration = [{ MemRefType getMemRefType() { @@ -384,9 +395,16 @@ def TileStoreOp : ArmSME_Op<"tile_store"> { } }]; + let builders = [ + OpBuilder<(ins "Value":$valueToStore, "Value":$base, + "ValueRange":$indices), [{ + build($_builder, $_state, valueToStore, base, indices, {}); + }]>, + ]; + let assemblyFormat = - "$valueToStore `,` $base `[` $indices `]` (`layout` `` $layout^)? attr-dict " - "`:` type($base) `,` type($valueToStore)"; + "$valueToStore `,` $base `[` $indices `]` (`,` $mask^)? (`layout` `` $layout^)?" + "attr-dict `:` type($base) `,` type($valueToStore)"; } def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [ @@ -595,12 +613,6 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure, }]; } -class HasMatchingMaskTypeConstraint : - OptionalTypesMatchWith< - "shape of `" # operand # "Mask` matches `" # operand # "`", - operand, operand # "Mask", - "::llvm::cast($_self).cloneWith({}, IntegerType::get($_ctxt, 1))">; - class OuterProductResultTileTypeConstraint : OptionalTypesMatchWith, - HasMatchingMaskTypeConstraint<"lhs">, - HasMatchingMaskTypeConstraint<"rhs">, + HasMatchingMaskTypeConstraint<"lhs", "lhsMask">, + HasMatchingMaskTypeConstraint<"rhs", "rhsMask">, PredOpTrait< "both `lhsMask` and `rhsMask` should be provided or neither", CPred<"bool(getLhsMask()) == bool(getRhsMask())">>, diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp index b60c21e2ced7a..005dd546bf163 100644 --- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp +++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp @@ -144,8 +144,8 @@ struct TransferWriteToArmSMELowering return failure(); rewriter.replaceOpWithNewOp( - writeOp, writeOp.getVector(), writeOp.getSource(), - writeOp.getIndices()); + writeOp, writeOp.getVector(), writeOp.getSource(), writeOp.getIndices(), + writeOp.getMask()); return success(); } }; diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir index 1d6386bbf3828..588b8e891fadd 100644 --- a/mlir/test/Dialect/ArmSME/invalid.mlir +++ b/mlir/test/Dialect/ArmSME/invalid.mlir @@ -164,6 +164,20 @@ func.func @arm_sme_load_tile_slice__bad_mask_type(%src : memref, %mask : return } +//===----------------------------------------------------------------------===// +// arm_sme.tile_store +//===----------------------------------------------------------------------===// + +// ----- + +func.func @arm_sme_tile_store__bad_mask_type(%tile : vector<[16]x[16]xi8>, %mask : vector<[1]x[1]xi1>, %dest : memref) { + %c0 = arith.constant 0 : index + // expected-note@-2 {{prior use here}} + // expected-error@+1 {{use of value '%mask' expects different type than prior uses: 'vector<[16]x[16]xi1>' vs 'vector<[1]x[1]xi1>}} + arm_sme.tile_store %tile, %dest[%c0, %c0], %mask : memref, vector<[16]x[16]xi8> + return +} + //===----------------------------------------------------------------------===// // arm_sme.outerproduct //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir index 6d0aa48015c14..8206bd80e3eb4 100644 --- a/mlir/test/Dialect/ArmSME/roundtrip.mlir +++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir @@ -624,6 +624,15 @@ func.func @arm_sme_tile_store_ver_f64(%tile : vector<[2]x[2]xf64>, %dest : memre // ----- +func.func @arm_sme_tile_store_with_mask_ver_f32(%tile : vector<[4]x[4]xf32>, %dest : memref, %mask : vector<[4]x[4]xi1>) { + // CHECK: arm_sme.tile_store {{.*}} layout : memref, vector<[4]x[4]xf32> + %c0 = arith.constant 0 : index + arm_sme.tile_store %tile, %dest[%c0, %c0], %mask layout : memref, vector<[4]x[4]xf32> + return +} + +// ----- + /// Layout is optional and horizontal is the default, verify it's still parsed. func.func @arm_sme_tile_store_ver_i8(%tile : vector<[16]x[16]xi8>, %dest : memref) { // CHECK: arm_sme.tile_store %{{.*}}[{{.*}}] : memref, vector<[16]x[16]xi8> diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir index 9eb7cd143e5b5..5f41313fc6ac7 100644 --- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir +++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir @@ -315,6 +315,20 @@ func.func @transfer_write_2d_f64(%vector : vector<[2]x[2]xf64>, %dest : memref, +// CHECK-SAME: %[[DEST:.*]]: memref, +// CHECK-SAME: %[[MASK:.*]]: vector<[2]x[2]xi1>) { +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]], %[[MASK]] : memref, vector<[2]x[2]xf64> +func.func @transfer_write_2d_with_mask_f64(%vector : vector<[2]x[2]xf64>, %dest : memref, %mask : vector<[2]x[2]xi1>) { + %c0 = arith.constant 0 : index + vector.transfer_write %vector, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[2]x[2]xf64>, memref + return +} + +// ----- + // The following tests check the 'vector.transfer_write' -> 'arm_sme.intr.zero' // lowering only occurs for vector types of correct rank, shape, element size // and number of scalable dims.