From 26bb9059390d42812c55bc126590a8b19e2b7f49 Mon Sep 17 00:00:00 2001 From: Cullen Rhodes Date: Sun, 15 Oct 2023 13:39:18 +0000 Subject: [PATCH 1/2] [mlir][ArmSME] Add mask operand to store_tile_slice --- .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td | 40 +++--- .../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp | 9 +- .../Transforms/LegalizeForLLVMExport.cpp | 28 ++--- .../ArmSMEToSCF/arm-sme-to-scf.mlir | 3 +- mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir | 76 ++++++------ mlir/test/Dialect/ArmSME/invalid.mlir | 16 ++- mlir/test/Dialect/ArmSME/roundtrip.mlir | 114 +++++++++--------- 7 files changed, 155 insertions(+), 131 deletions(-) diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td index 8a34472d3b3a2..beacf7ee91e80 100644 --- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td @@ -66,6 +66,15 @@ class HasMatchingMaskTypeConstraint : vector, mask, "::llvm::cast($_self).cloneWith({}, IntegerType::get($_ctxt, 1))">; +class TileSliceMaskConstraint : + TypesMatchWith< + "`" # mask # "` has i1 element type and the shape is a slice of `" # tile # "`", + tile, mask, + "VectorType(" + "VectorType::Builder(" + "::llvm::cast($_self)" + ").dropDim(0).setElementType(IntegerType::get($_self.getContext(), 1)))">; + //===----------------------------------------------------------------------===// // ArmSME attr definitions //===----------------------------------------------------------------------===// @@ -408,15 +417,7 @@ def TileStoreOp : ArmSME_Op<"tile_store", [ } def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [ - AllTypesMatch<["tile", "result"]>, - TypesMatchWith< - "mask has i1 element type and is a slice of the result", - "result", "mask", - "VectorType(" - "VectorType::Builder(" - "::llvm::cast($_self)" - ").dropDim(0).setElementType(IntegerType::get($_self.getContext(), 1))" - ")">, + AllTypesMatch<["tile", "result"]>, TileSliceMaskConstraint<"result", "mask"> 
]> { let summary = "Tile slice load and update operation"; let description = [{ @@ -474,7 +475,9 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [ }]; } -def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> { +def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [ + TileSliceMaskConstraint<"tile", "mask"> +]> { let summary = "Tile slice store operation"; let description = [{ Stores a 1D tile slice from a 2D SME "virtual tile" into memory. The tile @@ -489,22 +492,27 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> { dimensions since the operation is scalable, and the element type must be a scalar that matches the element type of the input tile. + The SSA value `mask` selects which elements of the tile slice are written to + the MemRef. Its type is an `i1` vector whose shape is the shape of `tile` + with the leading dimension dropped (e.g. `vector<[16]xi1>` for `vector<[16]x[16]xi8>`). + Example 1: Store vector<[16]xi8> horizontal (default) tile slice from tile at given index to memory. ```mlir - arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[16]x[16]xi8>, memref + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %base[%c0] : memref, vector<[16]xi1>, vector<[16]x[16]xi8> ``` Example 2: Store vector<[4]xf32> vertical tile slice from tile at given index to memory. ```mlir - arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] layout : vector<[4]x[4]xf32>, memref + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %base[%c0] layout : memref, vector<[4]xi1>, vector<[4]x[4]xf32> ``` Example 3: Store a vector<[1]xi128> vertical tile slice from tile at given index to memory. 
```mlir - arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] layout : vector<[1]x[1]xi128>, memref + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %base[%c0] layout : memref, vector<[1]xi1>, vector<[1]x[1]xi128> ``` }]; - let arguments = (ins SMETile:$tile, Index:$tile_slice_index, + let arguments = (ins + SMETile:$tile, Index:$tile_slice_index, AnyVector:$mask, Arg:$base, Variadic:$indices, ArmSME_TileSliceLayoutAttr:$layout ); @@ -518,8 +526,8 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> { }]; let assemblyFormat = [{ - $tile `,` $tile_slice_index `,` $base `[` $indices `]` (`layout` `` $layout^)? - attr-dict `:` type($base) `,` type($tile) + $tile `,` $tile_slice_index `,` $mask `,` $base `[` $indices `]` (`layout` `` $layout^)? + attr-dict `:` type($base) `,` type($mask) `,` type($tile) }]; } diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp index 50cc818f1ffc0..80da6ffda1ed2 100644 --- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp +++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp @@ -190,6 +190,12 @@ struct TileStoreOpConversion : public OpRewritePattern { rewriter.setInsertionPointToStart(forOp.getBody()); + // Create an 'all true' predicate for the tile slice. 
+ auto predicateType = + VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true); + auto allTruePredicate = rewriter.create( + loc, DenseElementsAttr::get(predicateType, true)); + SmallVector memrefIndices; auto tileSliceIndex = forOp.getInductionVar(); getMemrefIndices(tileStoreOp.getIndices(), @@ -197,7 +203,8 @@ struct TileStoreOpConversion : public OpRewritePattern { numTileSlices, memrefIndices, loc, rewriter); rewriter.replaceOpWithNewOp( tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex, - tileStoreOp.getBase(), memrefIndices, tileStoreOp.getLayout()); + allTruePredicate, tileStoreOp.getBase(), memrefIndices, + tileStoreOp.getLayout()); return success(); } diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp index 7dd04e25075c8..d1a54658a595b 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp @@ -278,13 +278,7 @@ struct StoreTileSliceToArmSMELowering auto tileSliceI32 = rewriter.create( loc, rewriter.getI32Type(), tileSlice); - // Create all active predicate mask. 
- auto one = rewriter.create( - loc, rewriter.getI1Type(), - rewriter.getIntegerAttr(rewriter.getI1Type(), 1)); - auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(), - /*scalableDims=*/{true}); - auto allActiveMask = rewriter.create(loc, predTy, one); + auto maskOp = storeTileSliceOp.getMask(); Value tileI32 = castTileIDToI32(tile, loc, rewriter); arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout(); @@ -295,23 +289,23 @@ struct StoreTileSliceToArmSMELowering llvm_unreachable("unexpected element type!"); case 8: rewriter.replaceOpWithNewOp( - storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32); + storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32); break; case 16: rewriter.replaceOpWithNewOp( - storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32); + storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32); break; case 32: rewriter.replaceOpWithNewOp( - storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32); + storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32); break; case 64: rewriter.replaceOpWithNewOp( - storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32); + storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32); break; case 128: rewriter.replaceOpWithNewOp( - storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32); + storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32); break; } } else { @@ -320,23 +314,23 @@ struct StoreTileSliceToArmSMELowering llvm_unreachable("unexpected element type!"); case 8: rewriter.replaceOpWithNewOp( - storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32); + storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32); break; case 16: rewriter.replaceOpWithNewOp( - storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32); + storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32); break; case 32: rewriter.replaceOpWithNewOp( - storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32); + storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32); break; case 64: 
rewriter.replaceOpWithNewOp( - storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32); + storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32); break; case 128: rewriter.replaceOpWithNewOp( - storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32); + storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32); break; } } diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir index 3fb320c0d219e..d61f588941b40 100644 --- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir +++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir @@ -48,8 +48,9 @@ func.func @arm_sme_tile_load_ver(%src : memref) { // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale // CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index // CHECK: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] { +// CHECK: %[[PTRUE_S:.*]] = arith.constant dense : vector<[4]xi1> // CHECK: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index -// CHECK: arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref, vector<[4]x[4]xi32> +// CHECK: arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[PTRUE_S]], %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref, vector<[4]xi1>, vector<[4]x[4]xi32> func.func @arm_sme_tile_store_hor(%tile : vector<[4]x[4]xi32>, %dest : memref) { %c0 = arith.constant 0 : index arm_sme.tile_store %tile, %dest[%c0, %c0] : memref, vector<[4]x[4]xi32> diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir index 30ddb3c468601..8fdcf69958244 100644 --- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir +++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir @@ -209,8 +209,8 @@ func.func @arm_sme_load_tile_slice_ver_f64(%src : memref, %mask : vecto // CHECK-LABEL: func.func @arm_sme_store_tile_slice_hor_i8( // CHECK-SAME: %[[TILE:.*]]: vector<[16]x[16]xi8>, // CHECK-SAME: 
%[[TILE_SLICE_INDEX:.*]]: index, +// CHECK-SAME: %[[MASK:.*]]: vector<[16]xi1>, // CHECK-SAME: %[[DEST:.*]]: memref) { -// CHECK: %[[PTRUE_B:.*]] = arith.constant dense : vector<[16]xi1> // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[DEST]] : memref to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64 @@ -221,12 +221,12 @@ func.func @arm_sme_load_tile_slice_ver_f64(%src : memref, %mask : vecto // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8 // CHECK: %[[TILE_SLICE_INDEX_I32:.*]] = arith.index_castui %[[TILE_SLICE_INDEX]] : index to i32 // CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32 -// CHECK: "arm_sme.intr.st1b.horiz"(%[[PTRUE_B]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_INDEX_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> () +// CHECK: "arm_sme.intr.st1b.horiz"(%[[MASK]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_INDEX_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> () // CHECK: return // CHECK: } -func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref) -> () { +func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %mask : vector<[16]xi1>, %dest : memref) -> () { %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref, vector<[16]x[16]xi8> + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref, vector<[16]xi1>, vector<[16]x[16]xi8> return } @@ -234,9 +234,9 @@ func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_s // CHECK-LABEL: @arm_sme_store_tile_slice_hor_i16 // CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> () -func.func @arm_sme_store_tile_slice_hor_i16(%tile : 
vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref) -> () { +func.func @arm_sme_store_tile_slice_hor_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref) -> () { %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref, vector<[8]x[8]xi16> + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref, vector<[8]xi1>, vector<[8]x[8]xi16> return } @@ -244,9 +244,9 @@ func.func @arm_sme_store_tile_slice_hor_i16(%tile : vector<[8]x[8]xi16>, %tile_s // CHECK-LABEL: @arm_sme_store_tile_slice_hor_i32 // CHECK: "arm_sme.intr.st1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> () -func.func @arm_sme_store_tile_slice_hor_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref) -> () { +func.func @arm_sme_store_tile_slice_hor_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref) -> () { %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref, vector<[4]x[4]xi32> + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref, vector<[4]xi1>, vector<[4]x[4]xi32> return } @@ -254,9 +254,9 @@ func.func @arm_sme_store_tile_slice_hor_i32(%tile : vector<[4]x[4]xi32>, %tile_s // CHECK-LABEL: @arm_sme_store_tile_slice_hor_i64 // CHECK: "arm_sme.intr.st1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> () -func.func @arm_sme_store_tile_slice_hor_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref) -> () { +func.func @arm_sme_store_tile_slice_hor_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref) -> () { %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref, vector<[2]x[2]xi64> + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref, vector<[2]xi1>, vector<[2]x[2]xi64> 
return } @@ -264,9 +264,9 @@ func.func @arm_sme_store_tile_slice_hor_i64(%tile : vector<[2]x[2]xi64>, %tile_s // CHECK-LABEL: @arm_sme_store_tile_slice_hor_i128 // CHECK: "arm_sme.intr.st1q.horiz"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> () -func.func @arm_sme_store_tile_slice_hor_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref) -> () { +func.func @arm_sme_store_tile_slice_hor_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %mask : vector<[1]xi1>, %dest : memref) -> () { %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref, vector<[1]x[1]xi128> + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref, vector<[1]xi1>, vector<[1]x[1]xi128> return } @@ -274,9 +274,9 @@ func.func @arm_sme_store_tile_slice_hor_i128(%tile : vector<[1]x[1]xi128>, %tile // CHECK-LABEL: @arm_sme_store_tile_slice_hor_f16 // CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> () -func.func @arm_sme_store_tile_slice_hor_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref) -> () { +func.func @arm_sme_store_tile_slice_hor_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref) -> () { %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref, vector<[8]x[8]xf16> + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref, vector<[8]xi1>, vector<[8]x[8]xf16> return } @@ -284,9 +284,9 @@ func.func @arm_sme_store_tile_slice_hor_f16(%tile : vector<[8]x[8]xf16>, %tile_s // CHECK-LABEL: @arm_sme_store_tile_slice_hor_bf16 // CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> () -func.func @arm_sme_store_tile_slice_hor_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref) -> () { +func.func @arm_sme_store_tile_slice_hor_bf16(%tile : vector<[8]x[8]xbf16>, 
%tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref) -> () { %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref, vector<[8]x[8]xbf16> + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref, vector<[8]xi1>, vector<[8]x[8]xbf16> return } @@ -294,9 +294,9 @@ func.func @arm_sme_store_tile_slice_hor_bf16(%tile : vector<[8]x[8]xbf16>, %tile // CHECK-LABEL: @arm_sme_store_tile_slice_hor_f32 // CHECK: "arm_sme.intr.st1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> () -func.func @arm_sme_store_tile_slice_hor_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref) -> () { +func.func @arm_sme_store_tile_slice_hor_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref) -> () { %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref, vector<[4]x[4]xf32> + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref, vector<[4]xi1>, vector<[4]x[4]xf32> return } @@ -304,9 +304,9 @@ func.func @arm_sme_store_tile_slice_hor_f32(%tile : vector<[4]x[4]xf32>, %tile_s // CHECK-LABEL: @arm_sme_store_tile_slice_hor_f64 // CHECK: "arm_sme.intr.st1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> () -func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref) -> () { +func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref) -> () { %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref, vector<[2]x[2]xf64> + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref, vector<[2]xi1>, vector<[2]x[2]xf64> return } @@ -314,9 +314,9 @@ func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_s // CHECK-LABEL: 
@arm_sme_store_tile_slice_ver_i8 // CHECK: "arm_sme.intr.st1b.vert"({{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> () -func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref) -> () { +func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %mask : vector<[16]xi1>, %dest : memref) -> () { %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout : memref, vector<[16]x[16]xi8> + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout : memref, vector<[16]xi1>, vector<[16]x[16]xi8> return } @@ -324,9 +324,9 @@ func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_s // CHECK-LABEL: @arm_sme_store_tile_slice_ver_i16 // CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> () -func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref) -> () { +func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref) -> () { %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout : memref, vector<[8]x[8]xi16> + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout : memref, vector<[8]xi1>, vector<[8]x[8]xi16> return } @@ -334,9 +334,9 @@ func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_s // CHECK-LABEL: @arm_sme_store_tile_slice_ver_i32 // CHECK: "arm_sme.intr.st1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> () -func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref) -> () { +func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref) -> () { %c0 = arith.constant 0 : index - 
arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout : memref, vector<[4]x[4]xi32> + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout : memref, vector<[4]xi1>, vector<[4]x[4]xi32> return } @@ -344,9 +344,9 @@ func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_s // CHECK-LABEL: @arm_sme_store_tile_slice_ver_i64 // CHECK: "arm_sme.intr.st1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> () -func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref) -> () { +func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref) -> () { %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout : memref, vector<[2]x[2]xi64> + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout : memref, vector<[2]xi1>, vector<[2]x[2]xi64> return } @@ -354,9 +354,9 @@ func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_s // CHECK-LABEL: @arm_sme_store_tile_slice_ver_i128 // CHECK: "arm_sme.intr.st1q.vert"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> () -func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref) -> () { +func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %mask : vector<[1]xi1>, %dest : memref) -> () { %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout : memref, vector<[1]x[1]xi128> + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout : memref, vector<[1]xi1>, vector<[1]x[1]xi128> return } @@ -364,9 +364,9 @@ func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile // CHECK-LABEL: @arm_sme_store_tile_slice_ver_f16 // CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : 
(vector<[8]xi1>, !llvm.ptr, i32, i32) -> () -func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref) -> () { +func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref) -> () { %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout : memref, vector<[8]x[8]xf16> + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout : memref, vector<[8]xi1>, vector<[8]x[8]xf16> return } @@ -374,9 +374,9 @@ func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_s // CHECK-LABEL: @arm_sme_store_tile_slice_ver_bf16 // CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> () -func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref) -> () { +func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref) -> () { %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout : memref, vector<[8]x[8]xbf16> + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout : memref, vector<[8]xi1>, vector<[8]x[8]xbf16> return } @@ -384,9 +384,9 @@ func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile // CHECK-LABEL: @arm_sme_store_tile_slice_ver_f32 // CHECK: "arm_sme.intr.st1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> () -func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref) -> () { +func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref) -> () { %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout : memref, 
vector<[4]x[4]xf32> + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout : memref, vector<[4]xi1>, vector<[4]x[4]xf32> return } @@ -394,9 +394,9 @@ func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_s // CHECK-LABEL: @arm_sme_store_tile_slice_ver_f64 // CHECK: "arm_sme.intr.st1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> () -func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref) -> () { +func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref) -> () { %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout : memref, vector<[2]x[2]xf64> + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout : memref, vector<[2]xi1>, vector<[2]x[2]xf64> return } diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir index 588b8e891fadd..58d7b8f361a23 100644 --- a/mlir/test/Dialect/ArmSME/invalid.mlir +++ b/mlir/test/Dialect/ArmSME/invalid.mlir @@ -159,7 +159,7 @@ func.func @arm_sme_tile_load__pad_but_no_mask(%src : memref, %pad : f64 func.func @arm_sme_load_tile_slice__bad_mask_type(%src : memref, %mask : vector<[2]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) { %c0 = arith.constant 0 : index - // expected-error@+1 {{op failed to verify that mask has i1 element type and is a slice of the result}} + // expected-error@+1 {{op failed to verify that `mask` has i1 element type and the shape is a slice of `result`}} %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref, vector<[2]xi1>, vector<[16]x[16]xi8> return } @@ -178,6 +178,20 @@ func.func @arm_sme_tile_store__bad_mask_type(%tile : vector<[16]x[16]xi8>, %mask return } +//===----------------------------------------------------------------------===// +// 
arm_sme.store_tile_slice +//===----------------------------------------------------------------------===// + + +// ----- + +func.func @arm_sme_store_tile_slice__bad_mask_type(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref) -> () { + %c0 = arith.constant 0 : index + // expected-error@+1 {{op failed to verify that `mask` has i1 element type and the shape is a slice of `tile`}} + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref, vector<[8]xi1>, vector<[16]x[16]xi8> + return +} + //===----------------------------------------------------------------------===// // arm_sme.outerproduct //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir index 8206bd80e3eb4..1dbcc32e9259f 100644 --- a/mlir/test/Dialect/ArmSME/roundtrip.mlir +++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir @@ -823,173 +823,173 @@ func.func @arm_sme_load_tile_slice_hor_i8(%src : memref, %mask : vector< // ----- -func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref) -> () { - // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref, vector<[16]x[16]xi8> +func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %mask : vector<[16]xi1>, %dest : memref) -> () { + // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref, vector<[16]xi1>, vector<[16]x[16]xi8> %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref, vector<[16]x[16]xi8> + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref, vector<[16]xi1>, vector<[16]x[16]xi8> return } // ----- -func.func @arm_sme_store_tile_slice_hor_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref) -> () { - // CHECK: arm_sme.store_tile_slice 
{{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref, vector<[8]x[8]xi16> +func.func @arm_sme_store_tile_slice_hor_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref) -> () { + // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref, vector<[8]xi1>, vector<[8]x[8]xi16> %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref, vector<[8]x[8]xi16> + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref, vector<[8]xi1>, vector<[8]x[8]xi16> return } // ----- -func.func @arm_sme_store_tile_slice_hor_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref) -> () { - // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref, vector<[4]x[4]xi32> +func.func @arm_sme_store_tile_slice_hor_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref) -> () { + // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref, vector<[4]xi1>, vector<[4]x[4]xi32> %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref, vector<[4]x[4]xi32> + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref, vector<[4]xi1>, vector<[4]x[4]xi32> return } // ----- -func.func @arm_sme_store_tile_slice_hor_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref) -> () { - // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref, vector<[2]x[2]xi64> +func.func @arm_sme_store_tile_slice_hor_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref) -> () { + // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref, vector<[2]xi1>, vector<[2]x[2]xi64> %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref, vector<[2]x[2]xi64> + arm_sme.store_tile_slice %tile, %tile_slice_index, 
%mask, %dest[%c0] : memref, vector<[2]xi1>, vector<[2]x[2]xi64> return } // ----- -func.func @arm_sme_store_tile_slice_hor_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref) -> () { - // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref, vector<[1]x[1]xi128> +func.func @arm_sme_store_tile_slice_hor_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %mask : vector<[1]xi1>, %dest : memref) -> () { + // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref, vector<[1]xi1>, vector<[1]x[1]xi128> %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref, vector<[1]x[1]xi128> + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref, vector<[1]xi1>, vector<[1]x[1]xi128> return } // ----- -func.func @arm_sme_store_tile_slice_hor_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref) -> () { - // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref, vector<[8]x[8]xf16> +func.func @arm_sme_store_tile_slice_hor_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref) -> () { + // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref, vector<[8]xi1>, vector<[8]x[8]xf16> %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref, vector<[8]x[8]xf16> + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref, vector<[8]xi1>, vector<[8]x[8]xf16> return } // ----- -func.func @arm_sme_store_tile_slice_hor_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref) -> () { - // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref, vector<[8]x[8]xbf16> +func.func @arm_sme_store_tile_slice_hor_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref) -> () { + // CHECK: 
arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref, vector<[8]xi1>, vector<[8]x[8]xbf16> %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref, vector<[8]x[8]xbf16> + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref, vector<[8]xi1>, vector<[8]x[8]xbf16> return } // ----- -func.func @arm_sme_store_tile_slice_hor_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref) -> () { - // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref, vector<[4]x[4]xf32> +func.func @arm_sme_store_tile_slice_hor_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref) -> () { + // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref, vector<[4]xi1>, vector<[4]x[4]xf32> %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref, vector<[4]x[4]xf32> + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref, vector<[4]xi1>, vector<[4]x[4]xf32> return } // ----- -func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref) -> () { - // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref, vector<[2]x[2]xf64> +func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref) -> () { + // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref, vector<[2]xi1>, vector<[2]x[2]xf64> %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref, vector<[2]x[2]xf64> + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref, vector<[2]xi1>, vector<[2]x[2]xf64> return } // ----- -func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref) -> () { - // CHECK: 
arm_sme.store_tile_slice {{.*}} layout : memref, vector<[16]x[16]xi8> +func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %mask : vector<[16]xi1>, %dest : memref) -> () { + // CHECK: arm_sme.store_tile_slice {{.*}} layout : memref, vector<[16]xi1>, vector<[16]x[16]xi8> %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout : memref, vector<[16]x[16]xi8> + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout : memref, vector<[16]xi1>, vector<[16]x[16]xi8> return } // ----- -func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref) -> () { - // CHECK: arm_sme.store_tile_slice {{.*}} layout : memref, vector<[8]x[8]xi16> +func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref) -> () { + // CHECK: arm_sme.store_tile_slice {{.*}} layout : memref, vector<[8]xi1>, vector<[8]x[8]xi16> %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout : memref, vector<[8]x[8]xi16> + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout : memref, vector<[8]xi1>, vector<[8]x[8]xi16> return } // ----- -func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref) -> () { - // CHECK: arm_sme.store_tile_slice {{.*}} layout : memref, vector<[4]x[4]xi32> +func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref) -> () { + // CHECK: arm_sme.store_tile_slice {{.*}} layout : memref, vector<[4]xi1>, vector<[4]x[4]xi32> %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout : memref, vector<[4]x[4]xi32> + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout : memref, 
vector<[4]xi1>, vector<[4]x[4]xi32> return } // ----- -func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref) -> () { - // CHECK: arm_sme.store_tile_slice {{.*}} layout : memref, vector<[2]x[2]xi64> +func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref) -> () { + // CHECK: arm_sme.store_tile_slice {{.*}} layout : memref, vector<[2]xi1>, vector<[2]x[2]xi64> %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout : memref, vector<[2]x[2]xi64> + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout : memref, vector<[2]xi1>, vector<[2]x[2]xi64> return } // ----- -func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref) -> () { - // CHECK: arm_sme.store_tile_slice {{.*}} layout : memref, vector<[1]x[1]xi128> +func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %mask : vector<[1]xi1>, %dest : memref) -> () { + // CHECK: arm_sme.store_tile_slice {{.*}} layout : memref, vector<[1]xi1>, vector<[1]x[1]xi128> %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout : memref, vector<[1]x[1]xi128> + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout : memref, vector<[1]xi1>, vector<[1]x[1]xi128> return } // ----- -func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref) -> () { - // CHECK: arm_sme.store_tile_slice {{.*}} layout : memref, vector<[8]x[8]xf16> +func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref) -> () { + // CHECK: arm_sme.store_tile_slice {{.*}} layout : memref, vector<[8]xi1>, vector<[8]x[8]xf16> %c0 = arith.constant 0 : index - 
arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout : memref, vector<[8]x[8]xf16> + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout : memref, vector<[8]xi1>, vector<[8]x[8]xf16> return } // ----- -func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref) -> () { - // CHECK: arm_sme.store_tile_slice {{.*}} layout : memref, vector<[8]x[8]xbf16> +func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref) -> () { + // CHECK: arm_sme.store_tile_slice {{.*}} layout : memref, vector<[8]xi1>, vector<[8]x[8]xbf16> %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout : memref, vector<[8]x[8]xbf16> + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout : memref, vector<[8]xi1>, vector<[8]x[8]xbf16> return } // ----- -func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref) -> () { - // CHECK: arm_sme.store_tile_slice {{.*}} layout : memref, vector<[4]x[4]xf32> +func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref) -> () { + // CHECK: arm_sme.store_tile_slice {{.*}} layout : memref, vector<[4]xi1>, vector<[4]x[4]xf32> %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout : memref, vector<[4]x[4]xf32> + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout : memref, vector<[4]xi1>, vector<[4]x[4]xf32> return } // ----- -func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref) -> () { - // CHECK: arm_sme.store_tile_slice {{.*}} layout : memref, vector<[2]x[2]xf64> +func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index 
: index, %mask : vector<[2]xi1>, %dest : memref) -> () { + // CHECK: arm_sme.store_tile_slice {{.*}} layout : memref, vector<[2]xi1>, vector<[2]x[2]xf64> %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout : memref, vector<[2]x[2]xf64> + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout : memref, vector<[2]xi1>, vector<[2]x[2]xf64> return } // ----- /// Layout is optional and horizontal is the default, verify it's still parsed. -func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref) -> () { - // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref, vector<[16]x[16]xi8> +func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %mask : vector<[16]xi1>, %dest : memref) -> () { + // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref, vector<[16]xi1>, vector<[16]x[16]xi8> %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout : memref, vector<[16]x[16]xi8> + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout : memref, vector<[16]xi1>, vector<[16]x[16]xi8> return } From 6c2cf20923dd1cc16b7b110ad266107bd2737d32 Mon Sep 17 00:00:00 2001 From: Cullen Rhodes Date: Wed, 1 Nov 2023 13:45:09 +0000 Subject: [PATCH 2/2] address comments --- mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td index beacf7ee91e80..3b8a4cb60f910 100644 --- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td @@ -433,9 +433,8 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [ dimensions since the operation is scalable, and the element type must be a scalar that matches the element 
type of the result. - An SSA value `mask` specifies to mask out elements read from the MemRef. - The `mask` type is an `i1` vector with a shape that matches how elements - are read from the MemRef. + The provided `mask` is used to specify which elements of the tile slice + will be loaded. Example 1: Load a vector<[16]xi8> tile slice from memory into tile horizontally (default) at given index. ```mlir @@ -492,9 +491,8 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [ dimensions since the operation is scalable, and the element type must be a scalar that matches the element type of the input tile. - An SSA value `mask` specifies to mask out elements written to the MemRef. - The `mask` type is an `i1` vector with a shape that matches how elements - are written to the MemRef. + The provided `mask` is used to specify which elements of the tile slice + will be stored. Example 1: Store vector<[16]xi8> horizontal (default) tile slice from tile at given index to memory. ```mlir @@ -512,7 +510,7 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [ ``` }]; let arguments = (ins - SMETile:$tile, Index:$tile_slice_index, AnyVector:$mask, + SMETile:$tile, Index:$tile_slice_index, SVEPredicate:$mask, Arg:$base, Variadic:$indices, ArmSME_TileSliceLayoutAttr:$layout );