From 73a31af5309a309a2066dcc36828b43c04b5ac7c Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes@arm.com>
Date: Sun, 15 Oct 2023 08:09:57 +0000
Subject: [PATCH 1/9] [mlir][ArmSME] Add optional padding and mask operands to
 tile_load

Padding and mask are optional, but if one is specified both must be
specified. This is consistent with vector.transfer_read.
---
 .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td       | 50 +++++++++++++++++--
 mlir/test/Dialect/ArmSME/invalid.mlir         | 44 ++++++++++++++++
 mlir/test/Dialect/ArmSME/roundtrip.mlir       | 10 ++++
 3 files changed, 101 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index dab54b63d8d22..6f6b54aad0058 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -231,7 +231,24 @@ def ZeroOp : ArmSME_Op<"zero", [Pure]> {
   let assemblyFormat = "attr-dict `:` type($res)";
 }
 
-def TileLoadOp : ArmSME_Op<"tile_load"> {
+def TileLoadOp : ArmSME_Op<"tile_load", [
+  AttrSizedOperandSegments,
+  TypesMatchWith<
+    "padding type matches element type of result (if present)",
+    "result", "padding",
+    "::llvm::cast<VectorType>($_self).getElementType()",
+    "!getPadding() || std::equal_to<>()"
+  >,
+  TypesMatchWith<
+    "mask has i1 element type and same shape as result (if present)",
+    "result", "mask",
+    "VectorType("
+      "VectorType::Builder("
+        "::llvm::cast<VectorType>($_self)"
+      ").setElementType(IntegerType::get($_self.getContext(), 1)))",
+    "!getMask() || std::equal_to<>()"
+  >
+]> {
   let summary = "Tile load operation";
   let description = [{
     Loads a 2D SME "virtual tile" from memory defined by a base and indices,
@@ -242,6 +259,16 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
     dimensions, since the operation is scalable, and the element type must be
    a scalar that matches the element type of the result.
 
+    An optional SSA value `padding` of the same elemental type as the MemRef
+    may be provided to specify a fallback value in the case of masking.
+
+    An optional SSA value `mask` may be specified to mask out elements read
+    from the MemRef. The `mask` type is an `i1` vector with a shape that
+    matches how elements are read from the MemRef. Elements whose corresponding
+    mask element is `0` are masked out and replaced with `padding`.
+
+    If either `padding` or `mask` is specified, both must be specified.
+
     Example 1: Load an 8-bit element ZA tile with horizontal layout (default) from memory (ZA0.B).
     ```mlir
     %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
@@ -256,10 +283,16 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
     ```mlir
     %tile = arm_sme.tile_load %base[%c0, %c0] layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
     ```
+
+    Example 4: Masked load of a 32-bit element ZA tile with horizontal layout (default) from memory.
+    ```mlir
+    %tile = arm_sme.tile_load %base[%c0, %c0], %pad, %mask : memref<?x?xf32>, vector<[4]x[4]xf32>
+    ```
   }];
 
   let arguments = (ins
     Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
     Variadic<Index>:$indices,
+    Optional<AnyType>:$padding, Optional<AnyVector>:$mask,
     ArmSME_TileSliceLayoutAttr:$layout
   );
   let results = (outs SMETile:$result);
@@ -273,9 +306,20 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
     }
   }];
 
+  let builders = [
+    OpBuilder<(ins "VectorType":$resultType, "Value":$base,
+                   "ValueRange":$indices, "TileSliceLayout":$layout), [{
+      build($_builder, $_state, resultType, base, indices, {}, {}, layout);
+    }]>,
+    OpBuilder<(ins "VectorType":$resultType, "Value":$base,
+                   "ValueRange":$indices), [{
+      build($_builder, $_state, resultType, base, indices, {}, {}, {});
+    }]>,
+  ];
+
   let assemblyFormat =
-    "$base `[` $indices `]` (`layout` `` $layout^)? attr-dict "
-    "`:` type($base) `,` type($result)";
+    "$base `[` $indices `]` (`,` $padding `,` $mask^)? (`layout` `` $layout^)? "
+    "attr-dict `:` type($base) `,` type($result)";
 }
 
 def TileStoreOp : ArmSME_Op<"tile_store"> {
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 431009b1b9ede..9229f0415c076 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -1,5 +1,9 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics
 
+//===----------------------------------------------------------------------===//
+// arm_sme.cast_tile_to_vector
+//===----------------------------------------------------------------------===//
+
 // -----
 
 func.func @arm_sme_cast_tile_to_vector__bad_tile_id_bitwidth(%tile_id : i8) -> vector<[8]x[8]xi16> {
@@ -48,6 +52,10 @@ func.func @arm_sme_cast_tile_to_vector_bad_shape(%tile_id : i8) -> vector<[4]x[1
   return %0 : vector<[4]x[16]xi8>
 }
 
+//===----------------------------------------------------------------------===//
+// arm_sme.cast_vector_to_tile
+//===----------------------------------------------------------------------===//
+
 // -----
 
 func.func @arm_sme_cast_vector_to_tile__bad_tile_id_bitwidth(%vector : vector<[1]x[1]xi128>) -> i32 {
@@ -64,6 +72,10 @@ func.func @arm_sme_cast_vector_to_tile__bad_rank_1d(%vector : vector<[16]xi8>) -
   return %0 : i8
 }
 
+//===----------------------------------------------------------------------===//
+// arm_sme.get_tile_id
+//===----------------------------------------------------------------------===//
+
 // -----
 
 func.func @arm_sme_get_tile_id__bad_type() -> i1 {
@@ -72,6 +84,10 @@ func.func @arm_sme_get_tile_id__bad_type() -> i1 {
   return %0 : i1
 }
 
+//===----------------------------------------------------------------------===//
+// arm_sme.move_vector_to_tile_slice
+//===----------------------------------------------------------------------===//
+
 // -----
 
 func.func @arm_sme_move_vector_to_tile_slice_i8__bad_vector_type(%vector : vector<[8]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> vector<[16]x[16]xi8> {
@@ -90,6 +106,10 @@ func.func @arm_sme_move_vector_to_tile_slice_f32__bad_vector_type(%vector : vect
   return %0 : vector<[4]x[4]xf32>
 }
 
+//===----------------------------------------------------------------------===//
+// arm_sme.move_tile_slice_to_vector
+//===----------------------------------------------------------------------===//
+
 // -----
 
 func.func @arm_sme_move_tile_slice_to_vector__bad_result_type(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[2]xf64> {
   %0 = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index]
       : vector<[2]xf64> from vector<[4]x[4]xf32>
   return %0 : vector<[2]xf64>
 }
+
+//===----------------------------------------------------------------------===//
+// arm_sme.tile_load
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_tile_load__bad_padding_type(%src : memref<?x?xf64>, %pad : f32, %mask : vector<[2]x[2]xi1>) {
+  %c0 = arith.constant 0 : index
+  // expected-note@-2 {{prior use here}}
+  // expected-error@+1 {{use of value '%pad' expects different type than prior uses: 'f64' vs 'f32'}}
+  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xf64>, vector<[2]x[2]xf64>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_load__bad_mask_type(%src : memref<?x?xf64>, %pad : f64, %mask : vector<[4]x[4]xi1>) {
+  %c0 = arith.constant 0 : index
+  // expected-note@-2 {{prior use here}}
+  // expected-error@+1 {{use of value '%mask' expects different type than prior uses: 'vector<[2]x[2]xi1>' vs 'vector<[4]x[4]xi1>'}}
+  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xf64>, vector<[2]x[2]xf64>
+  return
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index 427154158e797..f6459f0858436 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -438,6 +438,16 @@ func.func @arm_sme_tile_load_ver_f64(%src : memref<?x?xf64>) {
 
 // -----
 
+/// Padding and mask are optional
+func.func @arm_sme_tile_load_hor_pad_f64(%src : memref<?x?xf64>, %pad : f64, %mask : vector<[2]x[2]xi1>) {
+  // CHECK: arm_sme.tile_load %{{.*}}[{{.*}}], {{.*}}, {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+  %c0 = arith.constant 0 : index
+  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xf64>, vector<[2]x[2]xf64>
+  return
+}
+
+// -----
+
 /// Layout is optional and horizontal is the default, verify it's still parsed.
 func.func @arm_sme_tile_load_explicit_hor(%src : memref<?x?xi8>) {
   // CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]x[16]xi8>

From a8db19de0693461ba51581364d7380ae5fe3e59d Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes@arm.com>
Date: Sun, 15 Oct 2023 08:28:53 +0000
Subject: [PATCH 2/9] [mlir][ArmSME] Add mask operand to load_tile_slice

---
 .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td       |  27 +++--
 .../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp    |  18 ++-
 .../Transforms/LegalizeForLLVMExport.cpp      |  37 +++---
 .../ArmSMEToSCF/arm-sme-to-scf.mlir           |  15 ++-
 mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir |  76 ++++++------
 mlir/test/Dialect/ArmSME/invalid.mlir         |  13 ++
 mlir/test/Dialect/ArmSME/roundtrip.mlir       | 114 +++++++++---------
 7 files changed, 173 insertions(+), 127 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 6f6b54aad0058..8a05ed89799d5 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -367,7 +367,15 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
 }
 
 def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
-  AllTypesMatch<["tile", "result"]>
+  AllTypesMatch<["tile", "result"]>,
+  TypesMatchWith<
+    "mask has i1 element type and same shape as result",
+    "result", "mask",
+    "VectorType("
+      "VectorType::Builder("
+        "::llvm::cast<VectorType>($_self)"
+      ").dropDim(0).setElementType(IntegerType::get($_self.getContext(), 1))"
+    ")">,
 ]> {
   let summary = "Tile slice load and update operation";
   let description = [{
@@ -383,23 +391,27 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
     dimensions since the operation is scalable, and the element type must be
     a scalar that matches the element type of the result.
 
+    An SSA value `mask` is used to mask out elements read from the MemRef.
+    The `mask` type is an `i1` vector with a shape that matches how elements
+    are read from the MemRef.
+
     Example 1: Load a vector<[16]xi8> tile slice from memory into tile horizontally (default) at given index.
     ```mlir
-    %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?xi8>, vector<[16]x[16]xi8>
+    %tile_update = arm_sme.load_tile_slice %base[%c0], %mask, %tile, %tile_slice_index : memref<?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
     ```
 
     Example 2: Load a vector<[4]xf32> tile slice from memory into tile vertically at given index.
     ```mlir
-    %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index layout<vertical> : memref<?xf32>, vector<[4]x[4]xf32>
+    %tile_update = arm_sme.load_tile_slice %base[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
     ```
 
     Example 3: Load a vector<[1]xi128> tile slice from memory into tile vertically at given index.
     ```mlir
-    %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index layout<vertical> : memref<?xi128>, vector<[1]x[1]xi128>
+    %tile_update = arm_sme.load_tile_slice %base[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
     ```
   }];
   let arguments = (ins
-    Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
+    Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base, AnyVector:$mask,
     SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index,
     ArmSME_TileSliceLayoutAttr:$layout
   );
   let results = (outs SMETile:$result);
@@ -415,8 +427,9 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
   }];
 
   let assemblyFormat = [{
-    $base `[` $indices `]` `,` $tile `,` $tile_slice_index (`layout` `` $layout^)?
-      attr-dict `:` type($base) `,` type($result)
+    $base `[` $indices `]` `,` $mask `,` $tile `,` $tile_slice_index
+      (`layout` `` $layout^)? attr-dict `:` type($base) `,` type($mask) `,`
+      type($result)
   }];
 }
 
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 0ec51b7430c02..9cfb13216d9bf 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -60,6 +60,7 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
 ///
 /// AFTER:
 /// ```mlir
+///   %ptrue_s = arith.constant dense<true> : vector<[4]xi1>
 ///   %tile_id = arm_sme.get_tile_id : i32
 ///   %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
 ///   %vscale = vector.vscale
@@ -69,7 +70,8 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
 ///   %svl_s = arith.muli %min_svl_s, %vscale : index
 ///   scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
 ///     %tile_update = arm_sme.load_tile_slice %src[%tile_slice_idx],
-///       %tile, %tile_slice_idx : memref<?x?xi32>, vector<[4]x[4]xi32>
+///       %ptrue_s, %tile, %tile_slice_idx
+///       : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
 ///   }
 /// ```
 struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
   using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
                                 PatternRewriter &rewriter) const override {
+    if (tileLoadOp.getMask())
+      return rewriter.notifyMatchFailure(
+          tileLoadOp, "op has mask, needs masked pattern(s)");
+
     OpBuilder::InsertionGuard g(rewriter);
     auto loc = tileLoadOp.getLoc();
     auto tileType = tileLoadOp.getVectorType();
@@ -109,6 +115,12 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
 
     rewriter.setInsertionPointToStart(forOp.getBody());
 
+    // Create an 'all true' predicate for the tile slice.
+    auto predicateType =
+        VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
+    auto allTruePredicate = rewriter.create<arith::ConstantOp>(
+        loc, DenseElementsAttr::get(predicateType, true));
+
     // Create 'arm_sme.load_tile_slice' to load tile slice from memory into
     // tile.
     SmallVector<Value> memrefIndices;
@@ -117,8 +129,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
                      tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
                      numTileSlices, memrefIndices, loc, rewriter);
     rewriter.create<arm_sme::LoadTileSliceOp>(
-        loc, tileType, tileLoadOp.getBase(), tile, memrefIndices,
-        tileSliceIndex, tileLoadOp.getLayout());
+        loc, tileType, tileLoadOp.getBase(), allTruePredicate, tile,
+        memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
 
     rewriter.setInsertionPointAfter(forOp);
 
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 5e13707ea0aa2..220e0bdd70979 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -179,12 +179,7 @@ struct LoadTileSliceToArmSMELowering
         loc, rewriter.getI32Type(), tileSlice);
 
     // Create all active predicate mask.
-    auto one = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI1Type(),
-        rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
-    auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
-                                  /*scalableDims=*/{true});
-    auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
+    auto maskOp = loadTileSliceOp.getMask();
 
     auto tileI32 = castTileIDToI32(tile, loc, rewriter);
     arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();
@@ -195,24 +190,24 @@ struct LoadTileSliceToArmSMELowering
       default:
         llvm_unreachable("unexpected element type!");
       case 8:
-        rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(
-            loc, allActiveMask, ptr, tileI32, tileSliceI32);
+        rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(loc, maskOp, ptr,
+                                                         tileI32, tileSliceI32);
         break;
       case 16:
-        rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(
-            loc, allActiveMask, ptr, tileI32, tileSliceI32);
+        rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(loc, maskOp, ptr,
+                                                         tileI32, tileSliceI32);
         break;
       case 32:
-        rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(
-            loc, allActiveMask, ptr, tileI32, tileSliceI32);
+        rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(loc, maskOp, ptr,
+                                                         tileI32, tileSliceI32);
         break;
       case 64:
-        rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(
-            loc, allActiveMask, ptr, tileI32, tileSliceI32);
+        rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(loc, maskOp, ptr,
+                                                         tileI32, tileSliceI32);
        break;
      case 128:
-        rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(
-            loc, allActiveMask, ptr, tileI32, tileSliceI32);
+        rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(loc, maskOp, ptr,
+                                                         tileI32, tileSliceI32);
         break;
       }
     } else {
       switch (tileElementWidth) {
       default:
         llvm_unreachable("unexpected element type!");
       case 8:
-        rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(loc, allActiveMask, ptr,
+        rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(loc, maskOp, ptr,
                                                         tileI32, tileSliceI32);
         break;
       case 16:
-        rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(loc, allActiveMask, ptr,
+        rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(loc, maskOp, ptr,
                                                         tileI32, tileSliceI32);
         break;
       case 32:
-        rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(loc, allActiveMask, ptr,
+        rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(loc, maskOp, ptr,
                                                         tileI32, tileSliceI32);
         break;
       case 64:
-        rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(loc, allActiveMask, ptr,
+        rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(loc, maskOp, ptr,
                                                         tileI32, tileSliceI32);
         break;
       case 128:
-        rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(loc, allActiveMask, ptr,
+        rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(loc, maskOp, ptr,
                                                         tileI32, tileSliceI32);
         break;
       }
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 4b3020970d6cc..3fb320c0d219e 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -1,5 +1,9 @@
 // RUN: mlir-opt %s -convert-arm-sme-to-scf -cse -split-input-file | FileCheck %s
 
+//===----------------------------------------------------------------------===//
+// arm_sme.tile_load
+//===----------------------------------------------------------------------===//
+
 // CHECK-LABEL: func.func @arm_sme_tile_load_hor(
 // CHECK-SAME:    %[[SRC:.*]]: memref<?x?xi32>) {
 // CHECK-DAG:     %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
@@ -10,8 +14,9 @@
 // CHECK-DAG:     %[[VSCALE:.*]] = vector.vscale
 // CHECK-NEXT:    %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
 // CHECK-NEXT:    scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
+// CHECK-NEXT:      %[[PTRUE_S:.*]] = arith.constant dense<true> : vector<[4]xi1>
 // CHECK-NEXT:      %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
-// CHECK-NEXT:      arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]x[4]xi32>
+// CHECK-NEXT:      arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[PTRUE_S]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
 func.func @arm_sme_tile_load_hor(%src : memref<?x?xi32>) {
   %c0 = arith.constant 0 : index
  %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
   return
 }
 
@@ -28,6 +33,10 @@ func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
   return
 }
 
+//===----------------------------------------------------------------------===//
+// arm_sme.tile_store
+//===----------------------------------------------------------------------===//
+
 // -----
 
 // CHECK-LABEL: func.func @arm_sme_tile_store_hor(
@@ -57,6 +66,10 @@ func.func @arm_sme_tile_store_ver(%tile : vector<[4]x[4]xi32>, %dest : memref<?
   return
 }
 
+//===----------------------------------------------------------------------===//
+// arm_sme.load_tile_slice
+//===----------------------------------------------------------------------===//
+
 // -----
 
 // CHECK-LABEL: func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?xi32>)
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
index 07485b3ee8ddf..4fb4ca2f102ee 100644
--- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
@@ -8,9 +8,9 @@
 
 // CHECK-LABEL: func.func @arm_sme_load_tile_slice_hor_i8(
 // CHECK-SAME:    %[[SRC:.*]]: memref<?xi8>,
+// CHECK-SAME:    %[[MASK:.*]]: vector<[16]xi1>,
 // CHECK-SAME:    %[[TILE:.*]]: vector<[16]x[16]xi8>,
 // CHECK-SAME:    %[[TILE_SLICE_INDEX:.*]]: index) {
-// CHECK:         %[[PTRUE_B:.*]] = arith.constant dense<true> : vector<[16]xi1>
 // CHECK:         %[[C0:.*]] = arith.constant 0 : index
 // CHECK:         %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:         %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
 // CHECK:         %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
 // CHECK:         %[[TILE_SLICE_INDEX_I32:.*]] = arith.index_castui %[[TILE_SLICE_INDEX]] : index to i32
 // CHECK:         %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
-// CHECK:         "arm_sme.intr.ld1b.horiz"(%[[PTRUE_B]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_INDEX_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK:         "arm_sme.intr.ld1b.horiz"(%[[MASK]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_INDEX_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
 // CHECK:         return
 // CHECK:       }
-func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?xi8>, %mask : vector<[16]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?xi8>, vector<[16]x[16]xi8>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
 
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_i16
 // CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?xi16>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?xi16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?xi16>, vector<[8]x[8]xi16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   return
 }
 
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_i32
 // CHECK: "arm_sme.intr.ld1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?xi32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?xi32>, vector<[4]x[4]xi32>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   return
 }
 
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_i64
 // CHECK: "arm_sme.intr.ld1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?xi64>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?xi64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?xi64>, vector<[2]x[2]xi64>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   return
 }
 
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_i128
 // CHECK: "arm_sme.intr.ld1q.horiz"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?xi128>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?xi128>, %mask : vector<[1]xi1>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?xi128>, vector<[1]x[1]xi128>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   return
 }
 
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_f16
 // CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?xf16>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?xf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?xf16>, vector<[8]x[8]xf16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   return
 }
 
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_bf16
 // CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?xbf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?xbf16>, vector<[8]x[8]xbf16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   return
 }
 
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_f32
 // CHECK: "arm_sme.intr.ld1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?xf32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?xf32>, vector<[4]x[4]xf32>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   return
 }
 
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_f64
 // CHECK: "arm_sme.intr.ld1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?xf64>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?xf64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?xf64>, vector<[2]x[2]xf64>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   return
 }
 
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_i8
 // CHECK: "arm_sme.intr.ld1b.vert"({{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?xi8>, %mask : vector<[16]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?xi8>, vector<[16]x[16]xi8>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
 
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_i16
 // CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?xi16>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?xi16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?xi16>, vector<[8]x[8]xi16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   return
 }
 
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_i32
 // CHECK: "arm_sme.intr.ld1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?xi32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?xi32>, vector<[4]x[4]xi32>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   return
 }
 
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_i64
 // CHECK: "arm_sme.intr.ld1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?xi64>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?xi64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?xi64>, vector<[2]x[2]xi64>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   return
 }
 
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_i128
 // CHECK: "arm_sme.intr.ld1q.vert"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?xi128>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?xi128>, %mask : vector<[1]xi1>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?xi128>, vector<[1]x[1]xi128>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   return
 }
 
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_f16
 // CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?xf16>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?xf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?xf16>, vector<[8]x[8]xf16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   return
 }
 
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_bf16
 // CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?xbf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?xbf16>, vector<[8]x[8]xbf16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   return
 }
 
@@ -184,9 +184,9 @@ func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?xbf16>, %tile : vec
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_f32
 // CHECK: "arm_sme.intr.ld1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?xf32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?xf32>, vector<[4]x[4]xf32>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   return
 }
 
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_f64
 // CHECK: "arm_sme.intr.ld1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?xf64>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?xf64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?xf64>, vector<[2]x[2]xf64>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   return
 }
 
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 9229f0415c076..60350a888c884 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -141,3 +141,16 @@ func.func @arm_sme_tile_load__bad_mask_type(%src : memref<?x?xf64>, %pad : f64,
   %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xf64>, vector<[2]x[2]xf64>
   return
 }
+
+//===----------------------------------------------------------------------===//
+// arm_sme.load_tile_slice
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_load_tile_slice__bad_mask_type(%src : memref<?xi8>, %mask : vector<[2]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+  %c0 = arith.constant 0 : index
+  // expected-error@+1 {{op failed to verify that mask has i1 element type and same shape as result}}
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?xi8>, vector<[2]xi1>, vector<[16]x[16]xi8>
+  return
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index f6459f0858436..93b103fb83ac4 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -638,173 +638,173 @@ func.func @arm_sme_tile_store_ver_i8(%tile : vector<[16]x[16]xi8>, %dest : memre
 
 // -----
 
-func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?xi8>, vector<[16]x[16]xi8>
+func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?xi8>, %mask : vector<[16]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?xi8>, vector<[16]x[16]xi8>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?xi16>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?xi16>, vector<[8]x[8]xi16>
+func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?xi16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?xi16>, vector<[8]x[8]xi16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?xi32>, vector<[4]x[4]xi32>
+func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?xi32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?xi32>, vector<[4]x[4]xi32>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?xi64>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?xi64>, vector<[2]x[2]xi64>
+func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?xi64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?xi64>, vector<[2]x[2]xi64>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?xi128>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?xi128>, vector<[1]x[1]xi128>
+func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?xi128>, %mask : vector<[1]xi1>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?xi128>, vector<[1]x[1]xi128>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?xf16>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?xf16>, vector<[8]x[8]xf16>
+func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?xf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?xf16>, vector<[8]x[8]xf16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?xbf16>, vector<[8]x[8]xbf16>
+func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?xbf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?xbf16>, vector<[8]x[8]xbf16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?xf32>, vector<[4]x[4]xf32>
+func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?xf32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?xf32>, vector<[4]x[4]xf32>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?xf64>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?xf64>, vector<[2]x[2]xf64>
+func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?xf64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?xf64>, vector<[2]x[2]xf64>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?xi8>, vector<[16]x[16]xi8>
+func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?xi8>, %mask : vector<[16]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?xi8>, vector<[16]x[16]xi8>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?xi16>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?xi16>, vector<[8]x[8]xi16>
+func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?xi16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?xi16>, vector<[8]x[8]xi16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?xi32>, vector<[4]x[4]xi32>
+func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?xi32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?xi32>, vector<[4]x[4]xi32>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?xi64>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?xi64>, vector<[2]x[2]xi64>
+func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?xi64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?xi64>, vector<[2]x[2]xi64>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?xi128>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?xi128>, vector<[1]x[1]xi128>
+func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?xi128>, %mask : vector<[1]xi1>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?xi128>, vector<[1]x[1]xi128>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?xf16>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?xf16>, vector<[8]x[8]xf16>
+func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?xf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?xf16>, vector<[8]x[8]xf16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?xbf16>, vector<[8]x[8]xbf16>
+func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?xbf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?xbf16>, vector<[8]x[8]xbf16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?xf32>, vector<[4]x[4]xf32>
+func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?xf32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?xf32>, vector<[4]x[4]xf32>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?xf64>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?xf64>, vector<[2]x[2]xf64>
+func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?xf64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?xf64>, vector<[2]x[2]xf64>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   return
 }
 
 // -----
 
 /// Layout is optional and horizontal is the default, verify it's still parsed.
-func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?xi8>, vector<[16]x[16]xi8>
+func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?xi8>, %mask : vector<[16]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<horizontal> : memref<?xi8>, vector<[16]x[16]xi8>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<horizontal> : memref<?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }

From 05686112de89018c982c7e5f6879a2361c2fd562 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes@arm.com>
Date: Sun, 15 Oct 2023 08:47:40 +0000
Subject: [PATCH 3/9] [mlir][ArmSME] Propagate pad and mask in
 vector.transfer_read lowering

This extends the vector.transfer_read -> arm_sme.tile_load lowering to
propagate the pad and mask. The restriction on the transfer_read being a
transposition is also removed; identity maps are lowered to normal
horizontal loads.
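For illustration, a minimal sketch of the new behaviour (the types here are
taken from the `transfer_read_2d_with_mask_i16` test added below):

```mlir
// A masked 2-D vector.transfer_read with an identity permutation map...
%0 = vector.transfer_read %src[%c0, %c0], %pad, %mask {in_bounds = [true, true]}
  : memref<?x?xi16>, vector<[8]x[8]xi16>

// ...now lowers to a horizontal arm_sme.tile_load that carries both the
// padding value and the mask:
%tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask
  : memref<?x?xi16>, vector<[8]x[8]xi16>
```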
---
 .../VectorToArmSME/VectorToArmSME.cpp         |  57 ++++---
 .../Dialect/ArmSME/vector-ops-to-sme.mlir     | 140 +++++++++---------
 2 files changed, 109 insertions(+), 88 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index d06eb4f5b01c9..02a5bc64fa52c 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -60,15 +60,30 @@ getSMETileAndCastToVector(PatternRewriter &rewriter, Location loc,
 
 namespace {
 
-/// Conversion pattern for vector.transfer_read op with transpose permutation
-/// map to vertical arm_sme.tile_load (in-flight transpose).
+/// Conversion pattern for vector.transfer_read.
+///
+/// ---
+///
+/// Example 1: op with identity permutation map to horizontal
+///            arm_sme.tile_load:
+///
+///   vector.transfer_read ...  permutation_map: (d0, d1) -> (d0, d1)
+///
+/// is converted to:
+///
+///   arm_sme.tile_load ...
+///
+/// ---
+///
+/// Example 2: op with transpose permutation map to vertical arm_sme.tile_load
+///            (in-flight transpose):
 ///
 ///   vector.transfer_read ...  permutation_map: (d0, d1) -> (d1, d0)
 ///
 /// is converted to:
 ///
 ///   arm_sme.tile_load ... layout<vertical>
-struct TransferReadPermutationToArmSMELowering
+struct TransferReadToArmSMELowering
     : public OpRewritePattern<vector::TransferReadOp> {
   using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
 
@@ -79,15 +94,6 @@ struct TransferReadPermutationToArmSMELowering
       return rewriter.notifyMatchFailure(transferReadOp,
                                          "not a 2 result permutation map");
 
-    AffineMap map = transferReadOp.getPermutationMap();
-
-    // Permutation map doesn't perform permutation, can be lowered to
-    // vector.load by TransferReadToVectorLoadLowering and then
-    // arm_sme.tile_load by VectorLoadToArmSMELowering.
-    if (map.isIdentity())
-      return rewriter.notifyMatchFailure(
-          transferReadOp, "map is an identity, apply another pattern");
-
     auto vectorType = transferReadOp.getVectorType();
     if (!arm_sme::isValidSMETileVectorType(vectorType))
       return rewriter.notifyMatchFailure(transferReadOp,
                                          "unsupported vector type");
@@ -96,26 +102,33 @@ struct TransferReadPermutationToArmSMELowering
     if (!llvm::isa<MemRefType>(transferReadOp.getSource().getType()))
       return rewriter.notifyMatchFailure(transferReadOp, "not a memref source");
 
-    if (transferReadOp.getMask())
-      // TODO: support masking.
-      return rewriter.notifyMatchFailure(transferReadOp,
-                                         "masking not yet supported");
-
     // Out-of-bounds dims are not supported.
     if (transferReadOp.hasOutOfBoundsDim())
       return rewriter.notifyMatchFailure(transferReadOp,
                                          "not inbounds transfer read");
 
+    arm_sme::TileSliceLayout layout;
+
     AffineExpr d0, d1;
     bindDims(transferReadOp.getContext(), d0, d1);
-    if (map != AffineMap::get(map.getNumDims(), 0, {d1, d0},
-                              transferReadOp.getContext()))
+    AffineMap map = transferReadOp.getPermutationMap();
+    if (map.isIdentity())
+      layout = arm_sme::TileSliceLayout::Horizontal;
+    else if (map == AffineMap::get(map.getNumDims(), 0, {d1, d0},
+                                   transferReadOp.getContext()))
+      layout = arm_sme::TileSliceLayout::Vertical;
+    else
       return rewriter.notifyMatchFailure(transferReadOp,
-                                         "not true 2-D matrix transpose");
+                                         "unsupported permutation map");
 
+    // Padding isn't optional for transfer_read, but it is only used in the
+    // case of out-of-bounds accesses (not supported here) and/or masking.
+    // Mask is optional; if it's not present, don't pass padding.
+    auto mask = transferReadOp.getMask();
+    auto padding = mask ? transferReadOp.getPadding() : nullptr;
     rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
         transferReadOp, vectorType, transferReadOp.getSource(),
-        transferReadOp.getIndices(), arm_sme::TileSliceLayout::Vertical);
+        transferReadOp.getIndices(), padding, mask, layout);
 
     return success();
   }
@@ -432,7 +445,7 @@ struct TransposeOpToArmSMELowering
 void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
                                           MLIRContext &ctx) {
-  patterns.add<BroadcastOpToArmSMELowering, ConstantOpToArmSMELowering,
-               SplatOpToArmSMELowering, TransferReadPermutationToArmSMELowering,
-               TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
-               VectorLoadToArmSMELowering, VectorStoreToArmSMELowering>(&ctx);
+  patterns.add<BroadcastOpToArmSMELowering, ConstantOpToArmSMELowering,
+               SplatOpToArmSMELowering, TransferReadToArmSMELowering,
+               TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
+               VectorLoadToArmSMELowering, VectorStoreToArmSMELowering>(&ctx);
 }
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index 455b47a83e28f..80ca3d3b82813 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -1,181 +1,189 @@
 // RUN: mlir-opt %s -convert-vector-to-arm-sme -split-input-file -allow-unregistered-dialect | FileCheck %s
 
 //===----------------------------------------------------------------------===//
-// vector.transfer_read (with in-flight transpose)
+// vector.transfer_read
 //===----------------------------------------------------------------------===//
 
-// CHECK-LABEL: @transfer_read_2d_transpose_i8
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
-func.func @transfer_read_2d_transpose_i8(%src : memref<?x?xi8>) {
+// CHECK-LABEL: @transfer_read_2d_i8
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]x[16]xi8>
+func.func @transfer_read_2d_i8(%src : memref<?x?xi8>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0 : i8
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi8>, vector<[16]x[16]xi8>
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi8>, vector<[16]x[16]xi8>
   "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d_transpose_i16
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
-func.func @transfer_read_2d_transpose_i16(%src : memref<?x?xi16>) {
+// CHECK-LABEL: @transfer_read_2d_i16
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi16>, vector<[8]x[8]xi16>
+func.func @transfer_read_2d_i16(%src : memref<?x?xi16>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0 : i16
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16>
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16>
   "prevent.dce"(%0) : (vector<[8]x[8]xi16>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d_transpose_i32
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
-func.func @transfer_read_2d_transpose_i32(%src : memref<?x?xi32>) {
+// CHECK-LABEL: @transfer_read_2d_i32
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi32>, vector<[4]x[4]xi32>
+func.func @transfer_read_2d_i32(%src : memref<?x?xi32>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0 : i32
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
   "prevent.dce"(%0) : (vector<[4]x[4]xi32>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d_transpose_i64
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+// CHECK-LABEL: @transfer_read_2d_i64
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi64>, vector<[2]x[2]xi64>
vector<[2]x[2]xi64> +func.func @transfer_read_2d_i64(%src : memref) { %c0 = arith.constant 0 : index %pad = arith.constant 0 : i64 - %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref, vector<[2]x[2]xi64> + %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref, vector<[2]x[2]xi64> "prevent.dce"(%0) : (vector<[2]x[2]xi64>) -> () return } // ----- -// CHECK-LABEL: @transfer_read_2d_transpose_i128 -// CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[1]x[1]xi128> -func.func @transfer_read_2d_transpose_i128(%src : memref) { +// CHECK-LABEL: @transfer_read_2d_i128 +// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref, vector<[1]x[1]xi128> +func.func @transfer_read_2d_i128(%src : memref) { %c0 = arith.constant 0 : index %pad = arith.constant 0 : i128 - %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref, vector<[1]x[1]xi128> + %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref, vector<[1]x[1]xi128> "prevent.dce"(%0) : (vector<[1]x[1]xi128>) -> () return } // ----- -// CHECK-LABEL: @transfer_read_2d_transpose_f16 -// CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[8]x[8]xf16> -func.func @transfer_read_2d_transpose_f16(%src : memref) { +// CHECK-LABEL: @transfer_read_2d_f16 +// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref, vector<[8]x[8]xf16> +func.func @transfer_read_2d_f16(%src : memref) { %c0 = arith.constant 0 : index %pad = arith.constant 0.0 : f16 - %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref, vector<[8]x[8]xf16> + %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref, vector<[8]x[8]xf16> "prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> () return } // ----- -// CHECK-LABEL: @transfer_read_2d_transpose_bf16 -// CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[8]x[8]xbf16> -func.func @transfer_read_2d_transpose_bf16(%src : memref) { +// CHECK-LABEL: @transfer_read_2d_bf16 +// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref, vector<[8]x[8]xbf16> +func.func @transfer_read_2d_bf16(%src : memref) { %c0 = arith.constant 0 : index %pad = arith.constant 0.0 : bf16 - %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref, vector<[8]x[8]xbf16> + %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref, vector<[8]x[8]xbf16> "prevent.dce"(%0) : (vector<[8]x[8]xbf16>) -> () return } // ----- -// CHECK-LABEL: @transfer_read_2d_transpose_f32 -// CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[4]x[4]xf32> -func.func @transfer_read_2d_transpose_f32(%src : memref) { +// CHECK-LABEL: @transfer_read_2d_f32 +// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref, vector<[4]x[4]xf32> +func.func @transfer_read_2d_f32(%src : memref) { %c0 = arith.constant 0 : index %pad = arith.constant 0.0 : f32 - %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref, vector<[4]x[4]xf32> + %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref, vector<[4]x[4]xf32> "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> () return } // ----- -// CHECK-LABEL: @transfer_read_2d_transpose_f64 -// CHECK: arm_sme.tile_load {{.*}} layout : memref, 
vector<[2]x[2]xf64> -func.func @transfer_read_2d_transpose_f64(%src : memref) { +// CHECK-LABEL: @transfer_read_2d_f64 +// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref, vector<[2]x[2]xf64> +func.func @transfer_read_2d_f64(%src : memref) { %c0 = arith.constant 0 : index %pad = arith.constant 0.0 : f64 - %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref, vector<[2]x[2]xf64> + %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref, vector<[2]x[2]xf64> "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> () return } // ----- -// CHECK-LABEL: @transfer_read_2d__bad_type -// CHECK-NOT: arm_sme.tile_load -// CHECK: vector.transfer_read -func.func @transfer_read_2d__bad_type(%src : memref) { +// CHECK-LABEL: @transfer_read_2d_with_mask_i16 +// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}], {{.*}}, {{.*}} : memref, vector<[8]x[8]xi16> +func.func @transfer_read_2d_with_mask_i16(%src : memref, %mask : vector<[8]x[8]xi1>) { %c0 = arith.constant 0 : index - %pad = arith.constant 0.0 : f64 - %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [false, false]} : memref, vector<[4]x[4]xf64> - "prevent.dce"(%0) : (vector<[4]x[4]xf64>) -> () + %pad = arith.constant 0 : i16 + %0 = vector.transfer_read %src[%c0, %c0], %pad, %mask {in_bounds = [true, true]} : memref, vector<[8]x[8]xi16> + "prevent.dce"(%0) : (vector<[8]x[8]xi16>) -> () return } // ----- -// CHECK-LABEL: @transfer_read_2d__non_memref_type -// CHECK-NOT: arm_sme.tile_load -// CHECK: vector.transfer_read -func.func @transfer_read_2d__non_memref_type(%src : tensor) { +/// in-flight transpose + +// CHECK-LABEL: @transfer_read_2d_transpose_i8 +// CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[16]x[16]xi8> +func.func @transfer_read_2d_transpose_i8(%src : memref) { %c0 = arith.constant 0 : index - %pad = arith.constant 0.0 : f64 - %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : tensor, vector<[2]x[2]xf64> - "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> () + %pad = arith.constant 0 : i8 + %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref, vector<[16]x[16]xi8> + "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> () return } // ----- -// CHECK-LABEL: @transfer_read_2d__bad_transfer_rank +// CHECK-LABEL: @transfer_read_2d_transpose_with_mask_f32 +// CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[4]x[4]xf32> +func.func @transfer_read_2d_transpose_with_mask_f32(%src : memref, %mask : vector<[4]x[4]xi1>) { + %c0 = arith.constant 0 : index + %pad = arith.constant 0.0 : f32 + %0 = vector.transfer_read %src[%c0, %c0], %pad, %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref, vector<[4]x[4]xf32> + "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: @transfer_read_2d__bad_type // CHECK-NOT: arm_sme.tile_load // CHECK: vector.transfer_read -func.func @transfer_read_2d__bad_transfer_rank(%src : memref) { +func.func @transfer_read_2d__bad_type(%src : memref) { %c0 = arith.constant 0 : index %pad = arith.constant 0.0 : f64 - %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]} : memref, vector<[2]xf64> - "prevent.dce"(%0) : (vector<[2]xf64>) -> () + %0 = vector.transfer_read %src[%c0, 
%c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [false, false]} : memref, vector<[4]x[4]xf64> + "prevent.dce"(%0) : (vector<[4]x[4]xf64>) -> () return } // ----- -// CHECK-LABEL: @transfer_read_2d__unsupported_mask +// CHECK-LABEL: @transfer_read_2d__non_memref_type // CHECK-NOT: arm_sme.tile_load // CHECK: vector.transfer_read -func.func @transfer_read_2d__unsupported_mask(%src : memref, %mask : vector<[2]x[2]xi1>) { +func.func @transfer_read_2d__non_memref_type(%src : tensor) { %c0 = arith.constant 0 : index %pad = arith.constant 0.0 : f64 - %0 = vector.transfer_read %src[%c0, %c0], %pad, %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref, vector<[2]x[2]xf64> + %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : tensor, vector<[2]x[2]xf64> "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> () return } // ----- -/// transfer_read with identity map should be lowered to vector.load by -/// TransferReadToVectorLoadLowering and then arm_sme.tile_load by -/// VectorLoadToArmSMELowering. - -// CHECK-LABEL: @transfer_read_2d__non_permuting_map +// CHECK-LABEL: @transfer_read_2d__bad_transfer_rank // CHECK-NOT: arm_sme.tile_load // CHECK: vector.transfer_read -func.func @transfer_read_2d__non_permuting_map(%src : memref) { +func.func @transfer_read_2d__bad_transfer_rank(%src : memref) { %c0 = arith.constant 0 : index %pad = arith.constant 0.0 : f64 - %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0, d1)>, in_bounds = [true, true]} : memref, vector<[2]x[2]xf64> - "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> () + %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]} : memref, vector<[2]xf64> + "prevent.dce"(%0) : (vector<[2]xf64>) -> () return } From 554055a54323f56a482d037b387dcc582abea0b1 Mon Sep 17 00:00:00 2001 From: Cullen Rhodes Date: Sun, 15 Oct 2023 10:19:34 +0000 Subject: [PATCH 4/9] [mlir][ArmSME] Add tile slice layout attr to vector <-> tile ops --- .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td | 36 ++++++++++------ .../Transforms/LegalizeForLLVMExport.cpp | 43 +++++++++++++------ mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir | 32 ++++++++++++++ mlir/test/Dialect/ArmSME/roundtrip.mlir | 16 +++++++ 4 files changed, 100 insertions(+), 27 deletions(-) diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td index 8a05ed89799d5..e35725934315b 100644 --- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td @@ -498,21 +498,24 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [ of a 2-D scalable vector tile at the given index. The type of the 1-D scalable vector to be moved must match the type of the tile slice. A tile slice is a 1-D vector of horizontally or vertically contiguous elements - within a ZA tile. Horizontal tile slices are currently assumed when - lowering to intrinsics. The updated tile is returned as the result. + within a ZA tile. The updated tile is returned as the result. - Example 1: Move a vector<[16]xi8> into tile at given index. + An optional tile slice layout attribute specifies whether the tile slice is + horizontal (default) or vertical. + + Example 1: Move a vector<[16]xi8> into tile horizontally (default) at given index. 
```mlir %tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[16]xi8> into vector<[16]x[16]xi8> ``` - Example 2: Move a vector<[2]xf64> into tile at given index. + Example 2: Move a vector<[2]xf64> into tile vertically at given index. ```mlir - %tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[2]xf64> into vector<[2]x[2]xf64> + %tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout : vector<[2]xf64> into vector<[2]x[2]xf64> ``` }]; let arguments = (ins - SVEVector:$vector, SMETile:$tile, Index:$tile_slice_index); + SVEVector:$vector, SMETile:$tile, Index:$tile_slice_index, + ArmSME_TileSliceLayoutAttr:$layout); let results = (outs SMETile:$result); let extraClassDeclaration = [{ @@ -522,7 +525,7 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [ }]; let assemblyFormat = [{ - $vector `,` $tile `,` $tile_slice_index + $vector `,` $tile `,` $tile_slice_index (`layout` `` $layout^)? attr-dict `:` type($vector) `into` type($result) }]; } @@ -537,21 +540,26 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure, let description = [{ The tile slice to vector operation extracts a 1-D scalable slice from a 2-D scalable tile at the given index. A tile slice is a 1-D vector of - horizontally or vertically contiguous elements within a ZA tile. Horizontal - tile slices are currently assumed when lowering to intrinsics. + horizontally or vertically contiguous elements within a ZA tile. + + An optional tile slice layout attribute specifies whether the tile slice is + horizontal (default) or vertical. - Example 1: Extract `vector<[16]xi8>` from tile at the given index. + Example 1: Extract `vector<[16]xi8>` from tile horizontally at the given index. ```mlir %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[16]xi8> from vector<[16]x[16]xi8> ``` - Example 2: Extract `vector<[2]xf64>` from tile at the given index. + Example 2: Extract `vector<[2]xf64>` from tile vertically at the given index. ```mlir - %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64> + %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout : vector<[2]xf64> from vector<[2]x[2]xf64> ``` }]; - let arguments = (ins SMETile:$tile, Index:$tile_slice_index); + let arguments = (ins + SMETile:$tile, Index:$tile_slice_index, + ArmSME_TileSliceLayoutAttr:$layout + ); let results = (outs SVEVector:$result); let extraClassDeclaration = [{ @@ -559,7 +567,7 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure, }]; let assemblyFormat = [{ - $tile `[` $tile_slice_index `]` attr-dict + $tile `[` $tile_slice_index `]` (`layout` `` $layout^)? attr-dict `:` type($result) `from` type($tile) }]; } diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp index 220e0bdd70979..86f245d82b16c 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp @@ -345,8 +345,7 @@ struct StoreTileSliceToArmSMELowering } }; -/// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics. Only horizontal -/// tile slices are currently supported. +/// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics. 
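+/// The tile slice layout attribute selects the intrinsic: horizontal tile slices lower to 'arm_sme.intr.write.horiz' and vertical tile slices to 'arm_sme.intr.write.vert' (see the switch below).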
struct MoveVectorToTileSliceToArmSMELowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< @@ -383,10 +382,19 @@ struct MoveVectorToTileSliceToArmSMELowering auto tileI32 = castTileIDToI32(tile, loc, rewriter); - // Create 'arm_sme.intr.write.horiz' to write vector to tile slice. - rewriter.create( - loc, tileI32, tileSliceI32, allActiveMask, - moveVectorToTileSliceOp.getVector()); + // Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice. + switch (moveVectorToTileSliceOp.getLayout()) { + case arm_sme::TileSliceLayout::Horizontal: + rewriter.create( + loc, tileI32, tileSliceI32, allActiveMask, + moveVectorToTileSliceOp.getVector()); + break; + case arm_sme::TileSliceLayout::Vertical: + rewriter.create( + loc, tileI32, tileSliceI32, allActiveMask, + moveVectorToTileSliceOp.getVector()); + break; + } // Intrinsic has no result, replace 'arm_sme.move_vector_to_tile_slice' with // 'arm_sme.cast_tile_to_vector' to preserve dataflow. @@ -397,8 +405,7 @@ struct MoveVectorToTileSliceToArmSMELowering } }; -/// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics. Only horizontal -/// tile slices are currently supported. +/// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics. struct MoveTileSliceToVectorArmSMELowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< @@ -430,10 +437,19 @@ struct MoveTileSliceToVectorArmSMELowering auto sliceIndexI32 = rewriter.create( loc, rewriter.getI32Type(), sliceIndex); - // Create 'arm_sme.intr.read.horiz' to extract the tile slice. - rewriter.replaceOpWithNewOp( - moveTileSliceToVector, sliceType, zeroVector, allTruePredicate, - tileIdI32, sliceIndexI32); + // Create 'arm_sme.intr.read.(horiz|vert)' to extract the tile slice. + switch (moveTileSliceToVector.getLayout()) { + case arm_sme::TileSliceLayout::Horizontal: + rewriter.replaceOpWithNewOp( + moveTileSliceToVector, sliceType, zeroVector, allTruePredicate, + tileIdI32, sliceIndexI32); + break; + case arm_sme::TileSliceLayout::Vertical: + rewriter.replaceOpWithNewOp( + moveTileSliceToVector, sliceType, zeroVector, allTruePredicate, + tileIdI32, sliceIndexI32); + break; + } return success(); } @@ -675,7 +691,8 @@ void mlir::configureArmSMELegalizeForExportTarget( arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert, arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert, arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz, - arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_mopa, + arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz, + arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_za_enable, arm_sme::aarch64_sme_za_disable>(); target.addLegalOp(); target.addIllegalOp(); diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir index 4fb4ca2f102ee..30ddb3c468601 100644 --- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir +++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir @@ -400,6 +400,29 @@ func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_s return } +//===----------------------------------------------------------------------===// +// arm_sme.move_vector_to_tile_slice +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: @arm_sme_move_vector_to_tile_slice_hor_i32 +// CHECK: "arm_sme.intr.write.horiz"({{.*}}) : (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> () +func.func @arm_sme_move_vector_to_tile_slice_hor_i32(%vector : 
vector<[4]xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) -> () { + %c0 = arith.constant 0 : index + arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32> + return +} + +// ----- + +// CHECK-LABEL: @arm_sme_move_vector_to_tile_slice_ver_bf16 +// CHECK: "arm_sme.intr.write.vert"({{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> () +func.func @arm_sme_move_vector_to_tile_slice_ver_bf16(%vector : vector<[8]xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) -> () { + %c0 = arith.constant 0 : index + arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout : vector<[8]xbf16> into vector<[8]x[8]xbf16> + return +} //===----------------------------------------------------------------------===// // arm_sme.move_tile_slice_to_vector @@ -485,3 +508,12 @@ func.func @arm_sme_move_tile_slice_to_vector_f64(%tile : vector<[2]x[2]xf64>, %t %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64> return %slice : vector<[2]xf64> } + +// ----- + +// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_ver_i128 +// CHECK: "arm_sme.intr.read.vert"({{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128> +func.func @arm_sme_move_tile_slice_to_vector_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index) -> vector<[1]xi128> { + %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout : vector<[1]xi128> from vector<[1]x[1]xi128> + return %slice : vector<[1]xi128> +} diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir index 93b103fb83ac4..f0704a75ed2fc 100644 --- a/mlir/test/Dialect/ArmSME/roundtrip.mlir +++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir @@ -1069,6 +1069,14 @@ func.func @arm_sme_move_vector_to_tile_slice_f64(%vector : vector<[2]xf64>, %til return } +// ----- + +func.func @arm_sme_move_vector_to_tile_slice_ver_i8(%vector : vector<[16]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> () { + // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} layout : vector<[16]xi8> into vector<[16]x[16]xi8> + %c0 = arith.constant 0 : index + arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout : vector<[16]xi8> into vector<[16]x[16]xi8> + return +} //===----------------------------------------------------------------------===// // arm_sme.move_tile_slice_to_vector @@ -1145,3 +1153,11 @@ func.func @arm_sme_move_tile_slice_to_vector_f64(%tile : vector<[2]x[2]xf64>, %t %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64> return %slice : vector<[2]xf64> } + +// ----- + +func.func @arm_sme_move_tile_slice_to_vector_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index) -> vector<[2]xf64> { + // CHECK: arm_sme.move_tile_slice_to_vector {{.*}} layout : vector<[2]xf64> from vector<[2]x[2]xf64> + %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout : vector<[2]xf64> from vector<[2]x[2]xf64> + return %slice : vector<[2]xf64> +} From 4badd45c9d0152bb502b8b99b79dc4ef2d095f6d Mon Sep 17 00:00:00 2001 From: Cullen Rhodes Date: Sun, 15 Oct 2023 09:15:52 +0000 Subject: [PATCH 5/9] [mlir][ArmSME] Add support for lowering masked tile_load ops This patch extends ArmSMEToSCF to support lowering of masked tile_load ops. Only masks created by 'vector.create_mask' are currently supported. There are two lowerings, one for pad of constant zero and another for non-zero pad. 
For the following example: %pad = arith.constant 0 : i32 %num_rows = arith.constant 2 : index %num_cols = arith.constant 4 : index %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1> %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32> The former (pad of constant zero) is lowered as follows: --------------------------------------------------------- %tile = arm_sme.zero : vector<[4]x[4]xi32> %num_cols = vector.create_mask %c4 : vector<[4]xi1> scf.for %slice_idx = %c0 to %num_rows step %c1 %tile_update = arm_sme.load_tile_slice %src[%slice_idx], %num_cols, %tile, %tile_slice_idx : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32> The tile is zeroed to satisfy the padding and only active rows are loaded. The latter (non-zero pad) is lowered as follows: ------------------------------------------------ scf.for %slice_idx = %c0 to %num_tile_slices step %c1 { %row_is_active = arith.cmpi ult %slice_idx, %num_rows : index %slice = scf.if %row_is_active -> vector<[4]xi32> { %slice = vector.maskedload %src[%slice_idx, %c0], %num_cols, %pad_1d : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32> scf.yield %slice : vector<[4]xi32> } else { scf.yield %pad_1d : vector<[4]xi32> } arm_sme.move_vector_to_tile_slice %slice, %tile, %slice_idx : vector<[4]xi32> into vector<[4]x[4]xi32> The scalar pad is broadcast to a 1-D vector and a regular 'vector.maskedload' (will be lowered to SVE, not SME) loads each slice for active rows, with padding specified as a passthru. For non-active rows the slice is the 1-D pad. The resulting slice is inserted into the tile with 'arm_sme.move_vector_to_tile_slice'. --- .../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp | 251 +++++++++++++++++- .../ArmSMEToSCF/arm-sme-to-scf.mlir | 56 ++++ .../CPU/ArmSME/test-transfer-read-2d.mlir | 237 +++++++++++++++++ 3 files changed, 543 insertions(+), 1 deletion(-) create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp index 9cfb13216d9bf..75b7b8acdd190 100644 --- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp +++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp @@ -141,6 +141,254 @@ struct TileLoadOpConversion : public OpRewritePattern { } }; +/// Lower `arm_sme.tile_load` with mask and pad of constant zero. +/// +/// BEFORE: +/// ```mlir +/// %pad = arith.constant 0 : i32 +/// %num_rows = arith.constant 2 : index +/// %num_cols = arith.constant 4 : index +/// %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1> +/// %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : +/// memref<?x?xi32>, vector<[4]x[4]xi32> +/// ``` +/// +/// AFTER: +/// ```mlir +/// %c0 = arith.constant 0 : index +/// %c1 = arith.constant 1 : index +/// %tile = arm_sme.zero : vector<[4]x[4]xi32> +/// %num_cols = vector.create_mask %c4 : vector<[4]xi1> +/// scf.for %tile_slice_idx = %c0 to %num_rows step %c1 { +/// %tile_update = arm_sme.load_tile_slice +/// %src[%tile_slice_idx], %num_cols, %tile, %tile_slice_idx : +/// memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32> +/// } +/// ``` +/// +/// NOTE: Only masks created by 'vector.create_mask' are currently supported.
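+/// Masks produced by any other op are rejected with a match failure, leaving the op for another pattern.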
+struct TileLoadOpWithMaskAndPadZeroConversion + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp, + PatternRewriter &rewriter) const override { + OpBuilder::InsertionGuard g(rewriter); + auto loc = tileLoadOp.getLoc(); + auto tileType = tileLoadOp.getVectorType(); + + auto maskOp = tileLoadOp.getMask(); + if (!maskOp) + return rewriter.notifyMatchFailure( + tileLoadOp, "op has no mask, needs unmasked pattern"); + + auto padOp = tileLoadOp.getPadding(); + assert(padOp && "expected padding when masking!"); + + auto createMaskOp = maskOp.getDefiningOp(); + if (!createMaskOp) + return rewriter.notifyMatchFailure( + tileLoadOp, "unsupported mask op, only 'vector.create_mask' is " + "currently supported"); + + auto constPadOp = padOp.getDefiningOp(); + if (!constPadOp || constPadOp.getValue() != + rewriter.getZeroAttr(tileType.getElementType())) + return rewriter.notifyMatchFailure( + tileLoadOp, "op has non-zero pad, needs non-zero pad pattern"); + + auto numRows = createMaskOp.getOperands()[0]; + auto numCols = createMaskOp.getOperands()[1]; + + auto predicateType = + VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true); + auto numColsOp = + rewriter.create(loc, predicateType, numCols); + + // Initialize tile with zero to satisfy padding. Inactive cols will be + // zeroed anyway since the loads use zeroing predication. For inactive rows + // however, no load will occur so these need to be zeroed. + auto tile = rewriter.create(loc, tileType); + + // Create a loop to load the active tile slices from memory. + auto step = rewriter.create(loc, 1); + auto lowerBound = rewriter.create(loc, 0); + auto upperBound = numRows; + auto forOp = rewriter.create(loc, lowerBound, upperBound, step); + + rewriter.setInsertionPointToStart(forOp.getBody()); + + // Create 'arm_sme.load_tile_slice' to load tile slice from memory into + // tile. + SmallVector memrefIndices; + auto tileSliceIndex = forOp.getInductionVar(); + getMemrefIndices(tileLoadOp.getIndices(), + tileLoadOp.getMemRefType().getRank(), tileSliceIndex, + upperBound, memrefIndices, loc, rewriter); + rewriter.create( + loc, tileType, tileLoadOp.getBase(), numColsOp, tile, memrefIndices, + tileSliceIndex, tileLoadOp.getLayout()); + + rewriter.setInsertionPointAfter(forOp); + + // Replace 'arm_sme.tile_load' with the tile. + rewriter.replaceOp(tileLoadOp, tile); + + return success(); + } +}; + +/// Lower `arm_sme.tile_load` with mask and non-zero pad. 
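+/// (Sketch of the approach: each active row is loaded with a 'vector.maskedload' whose passthru is the splatted pad; fully inactive rows take the splatted pad directly, see the AFTER example below.)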
+/// +/// BEFORE: +/// ```mlir +/// %pad = arith.constant 1 : i32 +/// %num_rows = arith.constant 2 : index +/// %num_cols = arith.constant 4 : index +/// %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1> +/// %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : +/// memref, vector<[4]x[4]xi32> +/// ``` +/// +/// AFTER: +/// ```mlir +/// %pad_1d = arith.constant dense<1> : vector<[4]xi32> +/// %num_rows = arith.constant 2 : index +/// %num_cols = arith.constant 4 : index +/// %tile_id = arm_sme.get_tile_id : i32 +/// %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32> +/// %vscale = vector.vscale +/// %c0 = arith.constant 0 : index +/// %c1 = arith.constant 1 : index +/// %min_svl_s = arith.constant 4 : index +/// %svl_s = arith.muli %min_svl_s, %vscale : index +/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 { +/// %row_is_active = arith.cmpi ult %tile_slice_idx, %num_rows : index +/// %slice = scf.if %row_is_active -> vector<[4]xi32> { +/// %slice = vector.maskedload %base[%tile_slice_idx, %c0], %num_cols, %pad +/// : memref, vector<[4]xi1>, +/// vector<[4]xi32> into vector<[4]xi32> +/// scf.yield %slice : vector<[4]xi32> +/// } else { +/// scf.yield %pad_1d : vector<[4]xi32> +/// } +/// // Insert slice into tile +/// arm_sme.move_vector_to_tile_slice %slice, %tile, %tile_slice_idx +/// : vector<[4]xi32> into vector<[4]x[4]xi32> +/// } +/// ``` +struct TileLoadOpWithMaskAndPadNonZeroConversion + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp, + PatternRewriter &rewriter) const override { + OpBuilder::InsertionGuard g(rewriter); + auto loc = tileLoadOp.getLoc(); + auto tileType = tileLoadOp.getVectorType(); + auto tileElementType = tileType.getElementType(); + unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth(); + + auto maskOp = tileLoadOp.getMask(); + if (!maskOp) + return rewriter.notifyMatchFailure( + tileLoadOp, "op has no mask, needs unmasked pattern"); + + auto padOp = tileLoadOp.getPadding(); + assert(padOp && "expected padding when masking!"); + + auto createMaskOp = maskOp.getDefiningOp(); + if (!createMaskOp) + return rewriter.notifyMatchFailure( + tileLoadOp, "unsupported mask op, only 'vector.create_mask' is " + "currently supported"); + + auto constPadOp = padOp.getDefiningOp(); + if (constPadOp && + constPadOp.getValue() == rewriter.getZeroAttr(tileElementType)) + return rewriter.notifyMatchFailure( + tileLoadOp, "op has constant zero pad, needs zero pad pattern"); + + auto numRows = createMaskOp.getOperands()[0]; + auto numCols = createMaskOp.getOperands()[1]; + + VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); + auto predicateType = + VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true); + auto numColsOp = + rewriter.create(loc, predicateType, numCols); + + // Create 'arm_sme.get_tile' op. + auto tileId = rewriter.create( + loc, rewriter.getIntegerType(tileElementWidth)); + + // Create `arm_sme.cast_tile_to_vector` to cast tile ID to a vector type to + // use as input tile to 'arm_sme.load_tile_slice' ops. + auto tile = + rewriter.create(loc, tileType, tileId); + + // Create a loop that loads each ZA tile slice from memory. 
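+ // Note: the loop bound here is the full number of tile slices (minimum tile slices x vscale), not the number of active rows as in the zero-pad lowering above, since inactive rows must still be filled with the pad value.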
+ auto step = rewriter.create(loc, 1); + auto minTileSlices = rewriter.create( + loc, arm_sme::getSMETileSliceMinNumElts(tileElementType)); + auto vscale = + rewriter.create(loc, rewriter.getIndexType()); + auto lowerBound = rewriter.create(loc, 0); + auto numTileSlices = + rewriter.create(loc, minTileSlices, vscale); + auto forOp = + rewriter.create(loc, lowerBound, numTileSlices, step); + + rewriter.setInsertionPointToStart(forOp.getBody()); + + auto tileSliceIndex = forOp.getInductionVar(); + + auto rowIsActive = rewriter.create( + loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows); + + SmallVector memrefIndices; + getMemrefIndices(tileLoadOp.getIndices(), + tileLoadOp.getMemRefType().getRank(), tileSliceIndex, + numTileSlices, memrefIndices, loc, rewriter); + + // Splat pad into 1-D vector matching type of tile slice. + auto pad1DOp = rewriter.create(loc, tileSliceType, padOp); + + Operation *slice = rewriter.create( + loc, rowIsActive, + [&](OpBuilder &b, Location loc) { + // If the row is active, emit a masked load where the predicate is + // 'numCols'. Pad is used for inactive elements, taken from + // passthru. + auto loadSlice = rewriter.create( + loc, tileSliceType, tileLoadOp.getBase(), memrefIndices, + numColsOp, /*passthru=*/pad1DOp); + rewriter.create(loc, loadSlice->getResult(0)); + }, + [&](OpBuilder &b, Location loc) { + // Inactive rows are filled with pad. + rewriter.create(loc, pad1DOp.getResult()); + }); + + // TODO: If the load is vertical the transpose can't be done in-flight with + // a regular (SVE) maskedload. Propagate layout to + // 'arm_sme.move_vector_to_tile_slice' below once it supports layout. This + // is currently broken. + + // Create 'arm_sme.move_vector_to_tile_slice' to move slice into tile. + rewriter.create( + loc, tileType, slice->getResult(0), tile, tileSliceIndex, + tileLoadOp.getLayout()); + + rewriter.setInsertionPointAfter(forOp); + + // Replace 'arm_sme.tile_load' with the tile. + rewriter.replaceOp(tileLoadOp, tile); + + return success(); + } +}; + /// Lower `arm_sme.tile_store` to a loop over the tile slices and store each /// slice using `arm_sme.store_tile_slice`. 
/// @@ -265,7 +513,8 @@ struct TileVectorPrintOpConversion : public OpRewritePattern { } // namespace void mlir::populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns) { - patterns.add(patterns.getContext()); } diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir index 3fb320c0d219e..4906812032ae9 100644 --- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir +++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir @@ -33,6 +33,62 @@ func.func @arm_sme_tile_load_ver(%src : memref) { return } +// ----- + +// CHECK-LABEL: func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero( +// CHECK-SAME: %[[SRC:.*]]: memref) { +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[NUM_ROWS:.*]] = arith.constant 3 : index +// CHECK-DAG: %[[NUM_COLS:.*]] = vector.create_mask %c2 : vector<[4]xi1> +// CHECK-DAG: %[[TILEZERO:.*]] = arm_sme.zero : vector<[4]x[4]xi32> +// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_ROWS]] step %[[C1]] { +// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index +// CHECK-NEXT: arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[NUM_COLS]], %[[TILEZERO]], %[[TILE_SLICE_INDEX]] : memref, vector<[4]xi1>, vector<[4]x[4]xi32> +func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref) { + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %pad = arith.constant 0 : i32 + %mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1> + %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref, vector<[4]x[4]xi32> + return +} + +// ----- + +// CHECK-LABEL: func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad( +// CHECK-SAME: %[[SRC:.*]]: memref, +// CHECK-SAME: %[[PAD:.*]]: i32) { +// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32 +// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32> +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index +// CHECK-DAG: %[[NUM_ROWS:.*]] = arith.constant 3 : index +// CHECK-DAG: %[[NUM_COLS:.*]] = vector.create_mask %c2 : vector<[4]xi1> +// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale +// CHECK-NEXT: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index +// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] { +// CHECK-NEXT: %[[ROW_IS_ACTIVE:.*]] = arith.cmpi ult, %[[TILE_SLICE_INDEX]], %[[NUM_ROWS]] : index +// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index +// CHECK: %[[PAD_1D:.*]] = vector.splat %[[PAD]] : vector<[4]xi32> +// CHECK: %[[SLICE:.*]] = scf.if %[[ROW_IS_ACTIVE]] -> (vector<[4]xi32>) { +// CHECK: %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[NUM_COLS]], %[[PAD_1D]] : memref, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32> +// CHECK: scf.yield %[[LOAD_SLICE]] : vector<[4]xi32> +// CHECK: } else { +// CHECK: scf.yield %[[PAD_1D]] : vector<[4]xi32> +// CHECK: } +// CHECK: arm_sme.move_vector_to_tile_slice %[[SLICE]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : vector<[4]xi32> into vector<[4]x[4]xi32> +func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref, %pad : i32) { + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 
: index + %mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1> + %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref, vector<[4]x[4]xi32> + return +} + //===----------------------------------------------------------------------===// // arm_sme.tile_store //===----------------------------------------------------------------------===// diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir new file mode 100644 index 0000000000000..fe40dd13ce291 --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir @@ -0,0 +1,237 @@ +// DEFINE: %{entry_point} = entry +// DEFINE: %{compile} = mlir-opt %s \ +// DEFINE: -enable-arm-streaming="mode=locally enable-za" \ +// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \ +// DEFINE: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \ +// DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm +// DEFINE: %{run} = %mcr_aarch64_cmd \ +// DEFINE: -march=aarch64 -mattr=+sve,+sme \ +// DEFINE: -e %{entry_point} -entry-point-result=void \ +// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils + +// RUN: %{compile} | %{run} | FileCheck %s + +llvm.func @printCString(!llvm.ptr) + +// TODO: replace with vector.print once #68695 lands. +func.func @print_str(%str: !llvm.ptr>) attributes { enable_arm_streaming_ignore } { + %c0 = llvm.mlir.constant(0 : index) : i64 + %str_bytes = llvm.getelementptr %str[%c0, %c0] + : (!llvm.ptr>, i64, i64) -> !llvm.ptr + llvm.call @printCString(%str_bytes) : (!llvm.ptr) -> () + return +} + +// Vector load. +func.func @transfer_read_2d(%A : memref, %base1: index, %base2: index) { + %tile_begin_str = llvm.mlir.addressof @tile_begin : !llvm.ptr> + + %c4 = arith.constant 4 : index + %pad = arith.constant 0.0 : f32 + %0 = vector.transfer_read %A[%base1, %base2], %pad {in_bounds=[true, true]} : + memref, vector<[4]x[4]xf32> + + func.call @print_str(%tile_begin_str) : (!llvm.ptr>) -> () + vector.print %0: vector<[4]x[4]xf32> + + return +} + +// Vector load + transpose. +func.func @transfer_read_2d_transposed(%A : memref, %base1: index, %base2: index) { + %tile_begin_str = llvm.mlir.addressof @tile_begin : !llvm.ptr> + + %pad = arith.constant 0.0 : f32 + %0 = vector.transfer_read %A[%base1, %base2], %pad + {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]} + : memref, vector<[4]x[4]xf32> + + func.call @print_str(%tile_begin_str) : (!llvm.ptr>) -> () + vector.print %0 : vector<[4]x[4]xf32> + + return +} + +// Vector load with mask and pad of zero. +func.func @transfer_read_2d_mask(%A : memref, %base1: index, %base2: index) { + %tile_begin_str = llvm.mlir.addressof @tile_begin : !llvm.ptr> + + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %pad = arith.constant 0.0 : f32 + %mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1> + %0 = vector.transfer_read %A[%base1, %base2], %pad, %mask + {in_bounds = [true, true]} : memref, vector<[4]x[4]xf32> + + func.call @print_str(%tile_begin_str) : (!llvm.ptr>) -> () + vector.print %0: vector<[4]x[4]xf32> + + return +} + +// Vector load with mask and pad of zero + transpose. 
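+// (This exercises the masked tile_load with vertical layout, i.e. the in-flight transpose path.)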
+func.func @transfer_read_2d_mask_transposed(%A : memref, %base1: index, %base2: index) { + %tile_begin_str = llvm.mlir.addressof @tile_begin : !llvm.ptr> + + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %pad = arith.constant 0.0 : f32 + %mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1> + %0 = vector.transfer_read %A[%base1, %base2], %pad, %mask + {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]} + : memref, vector<[4]x[4]xf32> + + func.call @print_str(%tile_begin_str) : (!llvm.ptr>) -> () + vector.print %0: vector<[4]x[4]xf32> + + return +} + +// Vector load with mask and non-zero pad. +func.func @transfer_read_2d_mask_non_zero_pad(%A : memref, %base1: index, %base2: index) { + %tile_begin_str = llvm.mlir.addressof @tile_begin : !llvm.ptr> + + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %pad = arith.constant -42.0 : f32 + %mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1> + %0 = vector.transfer_read %A[%base1, %base2], %pad, %mask + {in_bounds = [true, true]} : memref, vector<[4]x[4]xf32> + + func.call @print_str(%tile_begin_str) : (!llvm.ptr>) -> () + vector.print %0: vector<[4]x[4]xf32> + + return +} + +// Vector load with mask and non-zero pad + transpose. +func.func @transfer_read_2d_mask_non_zero_pad_transposed(%A : memref, %base1: index, %base2: index) { + %tile_begin_str = llvm.mlir.addressof @tile_begin : !llvm.ptr> + + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %pad = arith.constant -42.0 : f32 + %mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1> + %0 = vector.transfer_read %A[%base1, %base2], %pad, %mask + {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]} + : memref, vector<[4]x[4]xf32> + + func.call @print_str(%tile_begin_str) : (!llvm.ptr>) -> () + vector.print %0: vector<[4]x[4]xf32> + + return +} + +// Allocate heap memory of size 'd0' x 'd1' and initialize. +// +// Example: +// +// initialize_memory(%c4, %c5) +// +// 0, 1, 2, 3, 4 +// 10, 11, 12, 13, 14 +// 20, 21, 22, 23, 24 +// 30, 31, 32, 33, 34 +// +// Returns dynamic memref. It's the caller's responsibility to free the returned +// memref. +func.func @initialize_memory(%d0 : index, %d1 : index) -> memref { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1_f32 = arith.constant 1.0 : f32 + %c10_f32 = arith.constant 10.0 : f32 + + %A = memref.alloc(%d0, %d1) : memref + + %init = arith.constant 0.0 : f32 + scf.for %i = %c0 to %d0 step %c1 iter_args(%val = %init) -> f32 { + scf.for %j = %c0 to %d1 step %c1 iter_args(%inner_val = %val) -> f32 { + memref.store %inner_val, %A[%i, %j] : memref + %inner_val_next = arith.addf %inner_val, %c1_f32 : f32 + scf.yield %inner_val_next : f32 + } + %val_next = arith.addf %val, %c10_f32 : f32 + scf.yield %val_next : f32 + } + + return %A : memref +} + +func.func @entry() { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + + // Allocate enough memory to load a 32-bit tile plus a tiny bit more to test + // non-zero offsets while remaining inbounds. + %vscale = vector.vscale + %svl_s = arith.muli %c4, %vscale : index + %svl_s_plus_two = arith.addi %svl_s, %c2 : index + + %A = call @initialize_memory(%svl_s_plus_two, %svl_s_plus_two) : (index, index) -> memref + + // 1.a. Read 2D vector from 2D memref.
+ // + // CHECK-LABEL: TILE BEGIN: + // CHECK-NEXT: ( 0, 1, 2, 3 + // CHECK-NEXT: ( 10, 11, 12, 13 + // CHECK-NEXT: ( 20, 21, 22, 23 + // CHECK-NEXT: ( 30, 31, 32, 33 + call @transfer_read_2d(%A, %c0, %c0) : (memref, index, index) -> () + + // 1.b. Same as 1.a., but with non-zero offsets. + // + // CHECK-LABEL: TILE BEGIN: + // CHECK-NEXT: ( 12, 13, 14, 15 + // CHECK-NEXT: ( 22, 23, 24, 25 + // CHECK-NEXT: ( 32, 33, 34, 35 + // CHECK-NEXT: ( 42, 43, 44, 45 + call @transfer_read_2d(%A, %c1, %c2) : (memref, index, index) -> () + + // 2. Same as 1.a., but with mask and a pad of constant zero. + // CHECK-LABEL: TILE BEGIN: + // CHECK-NEXT: ( 0, 1, 2, 0 + // CHECK-NEXT: ( 10, 11, 12, 0 + // CHECK-NEXT: ( 0, 0, 0, 0 + // CHECK-NEXT: ( 0, 0, 0, 0 + call @transfer_read_2d_mask(%A, %c0, %c0) : (memref, index, index) -> () + + // 3. Same as 1.a., but with mask and non-zero pad. + // CHECK-LABEL: TILE BEGIN: + // CHECK-NEXT: ( 0, 1, 2, -42 + // CHECK-NEXT: ( 10, 11, 12, -42 + // CHECK-NEXT: ( -42, -42, -42, -42 + // CHECK-NEXT: ( -42, -42, -42, -42 + call @transfer_read_2d_mask_non_zero_pad(%A, %c0, %c0) : (memref, index, index) -> () + + // 4. Same as 1.a., but transpose the result. + // CHECK-LABEL: TILE BEGIN: + // CHECK-NEXT: ( 0, 10, 20, 30 + // CHECK-NEXT: ( 1, 11, 21, 31 + // CHECK-NEXT: ( 2, 12, 22, 32 + // CHECK-NEXT: ( 3, 13, 23, 33 + call @transfer_read_2d_transposed(%A, %c0, %c0) : (memref, index, index) -> () + + // 5. Same as 2., but transpose the result. + // CHECK-LABEL: TILE BEGIN: + // CHECK-NEXT: ( 0, 10, 0, 0 + // CHECK-NEXT: ( 1, 11, 0, 0 + // CHECK-NEXT: ( 2, 12, 0, 0 + // CHECK-NEXT: ( 0, 0, 0, 0 + call @transfer_read_2d_mask_transposed(%A, %c0, %c0) : (memref, index, index) -> () + + // 6. Same as 3., but transpose the result. + // CHECK-LABEL: TILE BEGIN: + // CHECK-NEXT: ( 0, 10, -42, -42 + // CHECK-NEXT: ( 1, 11, -42, -42 + // CHECK-NEXT: ( 2, 12, -42, -42 + // CHECK-NEXT: ( -42, -42, -42, -42 + call @transfer_read_2d_mask_non_zero_pad_transposed(%A, %c0, %c0) : (memref, index, index) -> () + + memref.dealloc %A : memref + + return +} + +llvm.mlir.global internal constant @tile_begin("TILE BEGIN: \0A\00") From 0bba1701313b116693a7243826169ff6670c4c31 Mon Sep 17 00:00:00 2001 From: Cullen Rhodes Date: Sun, 15 Oct 2023 13:16:54 +0000 Subject: [PATCH 6/9] [mlir][ArmSME] Add optional mask operand to tile_store --- .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td | 37 +++++++++++++++++-- .../VectorToArmSME/VectorToArmSME.cpp | 4 +- mlir/test/Dialect/ArmSME/invalid.mlir | 14 +++++++ mlir/test/Dialect/ArmSME/roundtrip.mlir | 9 +++++ .../Dialect/ArmSME/vector-ops-to-sme.mlir | 14 +++++++ 5 files changed, 72 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td index e35725934315b..85cbe22acad6f 100644 --- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td @@ -322,7 +322,18 @@ def TileLoadOp : ArmSME_Op<"tile_load", [ "attr-dict `:` type($base) `,` type($result)"; } -def TileStoreOp : ArmSME_Op<"tile_store"> { +def TileStoreOp : ArmSME_Op<"tile_store", [ + AttrSizedOperandSegments, + TypesMatchWith< + "mask has i1 element type and same shape as value to store (if present)", + "valueToStore", "mask", + "VectorType(" + "VectorType::Builder(" + "::llvm::cast($_self)" + ").setElementType(IntegerType::get($_self.getContext(), 1)))", + "!getMask() || std::equal_to<>()" + > +]> { let summary = "Tile store operation"; let description = [{ Stores a 2D SME
"virtual tile" to memory defined by a base and indices, @@ -333,6 +344,11 @@ def TileStoreOp : ArmSME_Op<"tile_store"> { rank 2 with dynamic dimensions, since the operation is scalable, and the element type must be a scalar that matches the element type of the result. + An optional SSA value `mask` may be specified to mask out elements written + to the MemRef. The `mask` type is an `i1` vector of the same shape as the + vector type that matches how elements are written into the MemRef. Elements + whose corresponding mask element is `0` are masked out. + Example 1: Store an 8-bit element ZA tile with horizontal (default) layout to memory (ZA0.B). ```mlir arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[16]x[16]xi8>, memref @@ -347,10 +363,16 @@ def TileStoreOp : ArmSME_Op<"tile_store"> { ```mlir arm_sme.tile_store %tile, %base[%c0, %c0] layout : vector<[1]x[1]xi128>, memref ``` + + Example 4: Masked store a int 32-bit element ZA tile with vertical layout to memory. + ```mlir + arm_sme.tile_store %tile, %base[%c0, %c0], %mask layout : vector<[4]x[4]xf32>, memref + ``` }]; let arguments = (ins SMETile:$valueToStore, Arg:$base, - Variadic:$indices, ArmSME_TileSliceLayoutAttr:$layout + Variadic:$indices, Optional:$mask, + ArmSME_TileSliceLayoutAttr:$layout ); let extraClassDeclaration = [{ MemRefType getMemRefType() { @@ -361,9 +383,16 @@ def TileStoreOp : ArmSME_Op<"tile_store"> { } }]; + let builders = [ + OpBuilder<(ins "Value":$valueToStore, "Value":$base, + "ValueRange":$indices), [{ + build($_builder, $_state, valueToStore, base, indices, {}); + }]>, + ]; + let assemblyFormat = - "$valueToStore `,` $base `[` $indices `]` (`layout` `` $layout^)? attr-dict " - "`:` type($base) `,` type($valueToStore)"; + "$valueToStore `,` $base `[` $indices `]` (`,` $mask^)? (`layout` `` $layout^)?" 
+ "attr-dict `:` type($base) `,` type($valueToStore)"; } def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [ diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp index 02a5bc64fa52c..0cc5732c9212d 100644 --- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp +++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp @@ -157,8 +157,8 @@ struct TransferWriteToArmSMELowering return failure(); rewriter.replaceOpWithNewOp( - writeOp, writeOp.getVector(), writeOp.getSource(), - writeOp.getIndices()); + writeOp, writeOp.getVector(), writeOp.getSource(), writeOp.getIndices(), + writeOp.getMask()); return success(); } }; diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir index 60350a888c884..7a2550b8576d7 100644 --- a/mlir/test/Dialect/ArmSME/invalid.mlir +++ b/mlir/test/Dialect/ArmSME/invalid.mlir @@ -142,6 +142,20 @@ func.func @arm_sme_tile_load__bad_mask_type(%src : memref, %pad : f64, return } +//===----------------------------------------------------------------------===// +// arm_sme.tile_store +//===----------------------------------------------------------------------===// + +// ----- + +func.func @arm_sme_tile_store__bad_mask_type(%tile : vector<[16]x[16]xi8>, %mask : vector<[1]x[1]xi1>, %dest : memref) { + %c0 = arith.constant 0 : index + // expected-note@-2 {{prior use here}} + // expected-error@+1 {{use of value '%mask' expects different type than prior uses: 'vector<[16]x[16]xi1>' vs 'vector<[1]x[1]xi1>}} + arm_sme.tile_store %tile, %dest[%c0, %c0], %mask : memref, vector<[16]x[16]xi8> + return +} + //===----------------------------------------------------------------------===// // arm_sme.load_tile_slice //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir index f0704a75ed2fc..c0c5c539f3f08 100644 --- a/mlir/test/Dialect/ArmSME/roundtrip.mlir +++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir @@ -624,6 +624,15 @@ func.func @arm_sme_tile_store_ver_f64(%tile : vector<[2]x[2]xf64>, %dest : memre // ----- +func.func @arm_sme_tile_store_with_mask_ver_f32(%tile : vector<[4]x[4]xf32>, %dest : memref, %mask : vector<[4]x[4]xi1>) { + // CHECK: arm_sme.tile_store {{.*}} layout : memref, vector<[4]x[4]xf32> + %c0 = arith.constant 0 : index + arm_sme.tile_store %tile, %dest[%c0, %c0], %mask layout : memref, vector<[4]x[4]xf32> + return +} + +// ----- + /// Layout is optional and horizontal is the default, verify it's still parsed. 
func.func @arm_sme_tile_store_ver_i8(%tile : vector<[16]x[16]xi8>, %dest : memref) { // CHECK: arm_sme.tile_store %{{.*}}[{{.*}}] : memref, vector<[16]x[16]xi8> diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir index 80ca3d3b82813..f9251edbe658b 100644 --- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir +++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir @@ -323,6 +323,20 @@ func.func @transfer_write_2d_f64(%vector : vector<[2]x[2]xf64>, %dest : memref, +// CHECK-SAME: %[[DEST:.*]]: memref, +// CHECK-SAME: %[[MASK:.*]]: vector<[2]x[2]xi1>) { +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]], {{.*}} : memref, vector<[2]x[2]xf64> +func.func @transfer_write_2d_with_mask_f64(%vector : vector<[2]x[2]xf64>, %dest : memref, %mask : vector<[2]x[2]xi1>) { + %c0 = arith.constant 0 : index + vector.transfer_write %vector, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[2]x[2]xf64>, memref + return +} + +// ----- + // The following tests check the 'vector.transfer_write' -> 'arm_sme.intr.zero' // lowering only occurs for vector types of correct rank, shape, element size // and number of scalable dims. From e5e3e1405051a88a8490c5e13f8c29645d2bdf72 Mon Sep 17 00:00:00 2001 From: Cullen Rhodes Date: Sun, 15 Oct 2023 13:39:18 +0000 Subject: [PATCH 7/9] [mlir][ArmSME] Add mask operand to store_tile_slice --- .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td | 28 +++-- .../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp | 9 +- .../Transforms/LegalizeForLLVMExport.cpp | 28 ++--- .../ArmSMEToSCF/arm-sme-to-scf.mlir | 3 +- mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir | 76 ++++++------ mlir/test/Dialect/ArmSME/invalid.mlir | 14 +++ mlir/test/Dialect/ArmSME/roundtrip.mlir | 114 +++++++++--------- 7 files changed, 151 insertions(+), 121 deletions(-) diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td index 85cbe22acad6f..36fc4b9a39728 100644 --- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td @@ -462,7 +462,16 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [ }]; } -def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> { +def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [ + TypesMatchWith< + "mask has i1 element type and same shape as tile slice", + "tile", "mask", + "VectorType(" + "VectorType::Builder(" + "::llvm::cast($_self)" + ").dropDim(0).setElementType(IntegerType::get($_self.getContext(), 1))" + ")"> +]> { let summary = "Tile slice store operation"; let description = [{ Stores a 1D tile slice from a 2D SME "virtual tile" into memory. The tile @@ -477,22 +486,27 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> { dimensions since the operation is scalable, and the element type must be a scalar that matches the element type of the input tile. + An SSA value `mask` is used to mask out elements written to the MemRef. + The `mask` type is an `i1` vector with a shape that matches how elements + are written to the MemRef. + Example 1: Store vector<[16]xi8> horizontal (default) tile slice from tile at given index to memory. ```mlir - arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[16]x[16]xi8>, memref + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %base[%c0] : vector<[16]x[16]xi8>, vector<[16]xi1>, memref ``` Example 2: Store vector<[4]xf32> vertical tile slice from tile at given index to memory.
```mlir - arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] layout : vector<[4]x[4]xf32>, memref + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %base[%c0] layout : vector<[4]x[4]xf32>, vector<[4]xi1>, memref ``` Example 3: Store a vector<[1]xi128> vertical tile slice from tile at given index to memory. ```mlir - arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] layout : vector<[1]x[1]xi128>, memref + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %base[%c0] layout : vector<[1]x[1]xi128>, vector<[1]xi1>, memref ``` }]; - let arguments = (ins SMETile:$tile, Index:$tile_slice_index, + let arguments = (ins + SMETile:$tile, Index:$tile_slice_index, AnyVector:$mask, Arg:$base, Variadic:$indices, ArmSME_TileSliceLayoutAttr:$layout ); @@ -506,8 +520,8 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> { }]; let assemblyFormat = [{ - $tile `,` $tile_slice_index `,` $base `[` $indices `]` (`layout` `` $layout^)? - attr-dict `:` type($base) `,` type($tile) + $tile `,` $tile_slice_index `,` $mask `,` $base `[` $indices `]` (`layout` `` $layout^)? + attr-dict `:` type($base) `,` type($mask) `,` type($tile) }]; } diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp index 75b7b8acdd190..e72064651c5ca 100644 --- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp +++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp @@ -437,6 +437,12 @@ struct TileStoreOpConversion : public OpRewritePattern { rewriter.setInsertionPointToStart(forOp.getBody()); + // Create an 'all true' predicate for the tile slice. + auto predicateType = + VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true); + auto allTruePredicate = rewriter.create( + loc, DenseElementsAttr::get(predicateType, true)); + SmallVector memrefIndices; auto tileSliceIndex = forOp.getInductionVar(); getMemrefIndices(tileStoreOp.getIndices(), @@ -444,7 +450,8 @@ struct TileStoreOpConversion : public OpRewritePattern { numTileSlices, memrefIndices, loc, rewriter); rewriter.replaceOpWithNewOp( tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex, - tileStoreOp.getBase(), memrefIndices, tileStoreOp.getLayout()); + allTruePredicate, tileStoreOp.getBase(), memrefIndices, + tileStoreOp.getLayout()); return success(); } diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp index 86f245d82b16c..bbfe41a34e150 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp @@ -278,13 +278,7 @@ struct StoreTileSliceToArmSMELowering auto tileSliceI32 = rewriter.create( loc, rewriter.getI32Type(), tileSlice); - // Create all active predicate mask. 
- auto one = rewriter.create( - loc, rewriter.getI1Type(), - rewriter.getIntegerAttr(rewriter.getI1Type(), 1)); - auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(), - /*scalableDims=*/{true}); - auto allActiveMask = rewriter.create(loc, predTy, one); + auto maskOp = storeTileSliceOp.getMask(); Value tileI32 = castTileIDToI32(tile, loc, rewriter); arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout(); @@ -295,23 +289,23 @@ struct StoreTileSliceToArmSMELowering llvm_unreachable("unexpected element type!"); case 8: rewriter.replaceOpWithNewOp( - storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32); + storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32); break; case 16: rewriter.replaceOpWithNewOp( - storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32); + storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32); break; case 32: rewriter.replaceOpWithNewOp( - storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32); + storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32); break; case 64: rewriter.replaceOpWithNewOp( - storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32); + storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32); break; case 128: rewriter.replaceOpWithNewOp( - storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32); + storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32); break; } } else { @@ -320,23 +314,23 @@ struct StoreTileSliceToArmSMELowering llvm_unreachable("unexpected element type!"); case 8: rewriter.replaceOpWithNewOp( - storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32); + storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32); break; case 16: rewriter.replaceOpWithNewOp( - storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32); + storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32); break; case 32: rewriter.replaceOpWithNewOp( - storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32); + storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32); break; case 64: rewriter.replaceOpWithNewOp( - storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32); + storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32); break; case 128: rewriter.replaceOpWithNewOp( - storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32); + storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32); break; } } diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir index 4906812032ae9..55ea56f42c96e 100644 --- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir +++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir @@ -104,8 +104,9 @@ func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref : vector<[4]xi1> // CHECK: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index -// CHECK: arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref, vector<[4]x[4]xi32> +// CHECK: arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[PTRUE_S]], %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref, vector<[4]xi1>, vector<[4]x[4]xi32> func.func @arm_sme_tile_store_hor(%tile : vector<[4]x[4]xi32>, %dest : memref) { %c0 = arith.constant 0 : index arm_sme.tile_store %tile, %dest[%c0, %c0] : memref, vector<[4]x[4]xi32> diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir index 30ddb3c468601..8fdcf69958244 100644 --- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir +++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir @@ -209,8 +209,8 @@ 
 // CHECK-LABEL: func.func @arm_sme_store_tile_slice_hor_i8(
 // CHECK-SAME: %[[TILE:.*]]: vector<[16]x[16]xi8>,
 // CHECK-SAME: %[[TILE_SLICE_INDEX:.*]]: index,
+// CHECK-SAME: %[[MASK:.*]]: vector<[16]xi1>,
 // CHECK-SAME: %[[DEST:.*]]: memref<?xi8>) {
-// CHECK: %[[PTRUE_B:.*]] = arith.constant dense<true> : vector<[16]xi1>
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[DEST]] : memref<?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
@@ -221,12 +221,12 @@ func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?xf64>, %mask : vecto
 // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
 // CHECK: %[[TILE_SLICE_INDEX_I32:.*]] = arith.index_castui %[[TILE_SLICE_INDEX]] : index to i32
 // CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
-// CHECK: "arm_sme.intr.st1b.horiz"(%[[PTRUE_B]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_INDEX_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK: "arm_sme.intr.st1b.horiz"(%[[MASK]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_INDEX_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
 // CHECK: return
 // CHECK: }
-func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?xi8>) -> () {
+func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %mask : vector<[16]xi1>, %dest : memref<?xi8>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?xi8>, vector<[16]x[16]xi8>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
 
@@ -234,9 +234,9 @@ func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_i16
 // CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref<?xi16>) -> () {
+func.func @arm_sme_store_tile_slice_hor_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?xi16>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?xi16>, vector<[8]x[8]xi16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   return
 }
 
@@ -244,9 +244,9 @@ func.func @arm_sme_store_tile_slice_hor_i16(%tile : vector<[8]x[8]xi16>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_i32
 // CHECK: "arm_sme.intr.st1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref<?xi32>) -> () {
+func.func @arm_sme_store_tile_slice_hor_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?xi32>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   return
 }
 
@@ -254,9 +254,9 @@ func.func @arm_sme_store_tile_slice_hor_i32(%tile : vector<[4]x[4]xi32>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_i64
 // CHECK: "arm_sme.intr.st1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref<?xi64>) -> () {
+func.func @arm_sme_store_tile_slice_hor_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?xi64>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?xi64>, vector<[2]x[2]xi64>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   return
 }
 
@@ -264,9 +264,9 @@ func.func @arm_sme_store_tile_slice_hor_i64(%tile : vector<[2]x[2]xi64>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_i128
 // CHECK: "arm_sme.intr.st1q.horiz"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref<?xi128>) -> () {
+func.func @arm_sme_store_tile_slice_hor_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %mask : vector<[1]xi1>, %dest : memref<?xi128>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?xi128>, vector<[1]x[1]xi128>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   return
 }
 
@@ -274,9 +274,9 @@ func.func @arm_sme_store_tile_slice_hor_i128(%tile : vector<[1]x[1]xi128>, %tile
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_f16
 // CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref<?xf16>) -> () {
+func.func @arm_sme_store_tile_slice_hor_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?xf16>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?xf16>, vector<[8]x[8]xf16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   return
 }
 
@@ -284,9 +284,9 @@ func.func @arm_sme_store_tile_slice_hor_f16(%tile : vector<[8]x[8]xf16>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_bf16
 // CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref<?xbf16>) -> () {
+func.func @arm_sme_store_tile_slice_hor_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?xbf16>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?xbf16>, vector<[8]x[8]xbf16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   return
 }
 
@@ -294,9 +294,9 @@ func.func @arm_sme_store_tile_slice_hor_bf16(%tile : vector<[8]x[8]xbf16>, %tile
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_f32
 // CHECK: "arm_sme.intr.st1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref<?xf32>) -> () {
+func.func @arm_sme_store_tile_slice_hor_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?xf32>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?xf32>, vector<[4]x[4]xf32>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   return
 }
 
@@ -304,9 +304,9 @@ func.func @arm_sme_store_tile_slice_hor_f32(%tile : vector<[4]x[4]xf32>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_f64
 // CHECK: "arm_sme.intr.st1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref<?xf64>) -> () {
+func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?xf64>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?xf64>, vector<[2]x[2]xf64>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   return
 }
 
@@ -314,9 +314,9 @@ func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_i8
 // CHECK: "arm_sme.intr.st1b.vert"({{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?xi8>) -> () {
+func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %mask : vector<[16]xi1>, %dest : memref<?xi8>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?xi8>, vector<[16]x[16]xi8>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
 
@@ -324,9 +324,9 @@ func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_i16
 // CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref<?xi16>) -> () {
+func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?xi16>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?xi16>, vector<[8]x[8]xi16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   return
 }
 
@@ -334,9 +334,9 @@ func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_i32
 // CHECK: "arm_sme.intr.st1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref<?xi32>) -> () {
+func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?xi32>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?xi32>, vector<[4]x[4]xi32>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   return
 }
 
@@ -344,9 +344,9 @@ func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_i64
 // CHECK: "arm_sme.intr.st1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref<?xi64>) -> () {
+func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?xi64>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?xi64>, vector<[2]x[2]xi64>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   return
 }
 
@@ -354,9 +354,9 @@ func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_i128
 // CHECK: "arm_sme.intr.st1q.vert"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref<?xi128>) -> () {
+func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %mask : vector<[1]xi1>, %dest : memref<?xi128>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?xi128>, vector<[1]x[1]xi128>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   return
 }
 
@@ -364,9 +364,9 @@ func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_f16
 // CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref<?xf16>) -> () {
+func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?xf16>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?xf16>, vector<[8]x[8]xf16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   return
 }
 
@@ -374,9 +374,9 @@ func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_bf16
 // CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref<?xbf16>) -> () {
+func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?xbf16>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?xbf16>, vector<[8]x[8]xbf16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   return
 }
 
@@ -384,9 +384,9 @@ func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_f32
 // CHECK: "arm_sme.intr.st1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref<?xf32>) -> () {
+func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?xf32>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?xf32>, vector<[4]x[4]xf32>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   return
 }
 
@@ -394,9 +394,9 @@ func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_f64
 // CHECK: "arm_sme.intr.st1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref<?xf64>) -> () {
+func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?xf64>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?xf64>, vector<[2]x[2]xf64>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   return
 }
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 7a2550b8576d7..c29ae0581d392 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -168,3 +168,17 @@ func.func @arm_sme_load_tile_slice__bad_mask_type(%src : memref<?xi8>, %mask :
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?xi8>, vector<[2]xi1>, vector<[16]x[16]xi8>
   return
 }
+
+//===----------------------------------------------------------------------===//
+// arm_sme.store_tile_slice
+//===----------------------------------------------------------------------===//
+
+
+// -----
+
+func.func @arm_sme_store_tile_slice__bad_mask_type(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?xi8>) -> () {
+  %c0 = arith.constant 0 : index
+  // expected-error@+1 {{op failed to verify that mask has i1 element type and same shape as tile slice}}
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?xi8>, vector<[8]xi1>, vector<[16]x[16]xi8>
+  return
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index c0c5c539f3f08..640ca3835e88a 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -823,173 +823,173 @@ func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?xi8>, %mask : vector<
 
 // -----
 
-func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?xi8>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?xi8>, vector<[16]x[16]xi8>
+func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %mask : vector<[16]xi1>, %dest : memref<?xi8>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?xi8>, vector<[16]x[16]xi8>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_hor_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref<?xi16>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?xi16>, vector<[8]x[8]xi16>
+func.func @arm_sme_store_tile_slice_hor_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?xi16>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?xi16>, vector<[8]x[8]xi16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_hor_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref<?xi32>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?xi32>, vector<[4]x[4]xi32>
+func.func @arm_sme_store_tile_slice_hor_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?xi32>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_hor_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref<?xi64>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?xi64>, vector<[2]x[2]xi64>
+func.func @arm_sme_store_tile_slice_hor_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?xi64>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?xi64>, vector<[2]x[2]xi64>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_hor_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref<?xi128>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?xi128>, vector<[1]x[1]xi128>
+func.func @arm_sme_store_tile_slice_hor_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %mask : vector<[1]xi1>, %dest : memref<?xi128>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?xi128>, vector<[1]x[1]xi128>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_hor_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref<?xf16>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?xf16>, vector<[8]x[8]xf16>
+func.func @arm_sme_store_tile_slice_hor_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?xf16>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?xf16>, vector<[8]x[8]xf16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_hor_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref<?xbf16>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?xbf16>, vector<[8]x[8]xbf16>
+func.func @arm_sme_store_tile_slice_hor_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?xbf16>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?xbf16>, vector<[8]x[8]xbf16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_hor_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref<?xf32>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?xf32>, vector<[4]x[4]xf32>
+func.func @arm_sme_store_tile_slice_hor_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?xf32>) -> () {
  +  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?xf32>, vector<[4]x[4]xf32>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref<?xf64>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?xf64>, vector<[2]x[2]xf64>
+func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?xf64>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?xf64>, vector<[2]x[2]xf64>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?xi8>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?xi8>, vector<[16]x[16]xi8>
+func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %mask : vector<[16]xi1>, %dest : memref<?xi8>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?xi8>, vector<[16]x[16]xi8>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref<?xi16>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?xi16>, vector<[8]x[8]xi16>
+func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?xi16>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?xi16>, vector<[8]x[8]xi16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref<?xi32>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?xi32>, vector<[4]x[4]xi32>
+func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?xi32>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?xi32>, vector<[4]x[4]xi32>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref<?xi64>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?xi64>, vector<[2]x[2]xi64>
+func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?xi64>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?xi64>, vector<[2]x[2]xi64>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref<?xi128>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?xi128>, vector<[1]x[1]xi128>
+func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %mask : vector<[1]xi1>, %dest : memref<?xi128>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?xi128>, vector<[1]x[1]xi128>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref<?xf16>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?xf16>, vector<[8]x[8]xf16>
+func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?xf16>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?xf16>, vector<[8]x[8]xf16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref<?xbf16>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?xbf16>, vector<[8]x[8]xbf16>
+func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?xbf16>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?xbf16>, vector<[8]x[8]xbf16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref<?xf32>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?xf32>, vector<[4]x[4]xf32>
+func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?xf32>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?xf32>, vector<[4]x[4]xf32>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref<?xf64>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?xf64>, vector<[2]x[2]xf64>
+func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?xf64>) -> () {
  +  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?xf64>, vector<[2]x[2]xf64>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   return
 }
 
 // -----
 
 /// Layout is optional and horizontal is the default, verify it's still parsed.
-func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?xi8>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?xi8>, vector<[16]x[16]xi8>
+func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %mask : vector<[16]xi1>, %dest : memref<?xi8>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<horizontal> : memref<?xi8>, vector<[16]x[16]xi8>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<horizontal> : memref<?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }

From f5bbd5a8006545d8350b33a4d3e656c56192b927 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes
Date: Sun, 15 Oct 2023 13:57:59 +0000
Subject: [PATCH 8/9] [mlir][ArmSME] Add support for lowering masked tile_store
 ops

This patch extends ArmSMEToSCF to support lowering of masked tile_store ops.
Only masks created by 'vector.create_mask' are currently supported.
Example:

  %mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
  arm_sme.tile_store %tile, %dest[%c0, %c0], %mask : memref<?x?xi32>, vector<[4]x[4]xi32>

Produces:

  %num_rows = arith.constant 3 : index
  %num_cols = vector.create_mask %c2 : vector<[4]xi1>
  scf.for %slice_idx = %c0 to %num_rows step %c1
    arm_sme.store_tile_slice %tile, %slice_idx, %num_cols, %dest[%slice_idx, %c0]
      : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
---
 .../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp | 65 ++++++++++++-------
 .../ArmSMEToSCF/arm-sme-to-scf.mlir        | 25 ++++++-
 2 files changed, 66 insertions(+), 24 deletions(-)

diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index e72064651c5ca..86d1172ac4957 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -420,38 +420,59 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
     auto tileType = tileStoreOp.getVectorType();
     auto tileElementType = tileType.getElementType();
 
-    // Create a loop that stores each ZA tile slice from memory.
+    auto predicateType =
+        VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
+
+    Value maskCols;
+    Value upperBound;
+    auto maskOp = tileStoreOp.getMask();
+    if (maskOp) {
+      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
+      if (!createMaskOp)
+        return rewriter.notifyMatchFailure(
+            tileStoreOp, "unsupported mask op, only 'vector.create_mask' is "
+                         "currently supported");
+
+      auto numRows = createMaskOp.getOperands()[0];
+      auto numCols = createMaskOp.getOperands()[1];
+
+      upperBound = numRows;
+      maskCols =
+          rewriter.create<vector::CreateMaskOp>(loc, predicateType, numCols);
+    } else {
+      // Store all tile slices if no mask.
+      auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
+          loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
+      auto vscale =
+          rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+      // This describes both the number of ZA tile slices and the number of
+      // elements in a vector of SVL bits for a given element type (SVL_B,
+      // SVL_H, ..., SVL_Q).
+      auto numTileSlices =
+          rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
+
+      upperBound = numTileSlices;
+      // Create an 'all true' predicate for the tile slice.
+      maskCols = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(predicateType, true));
+    }
+
+    // Create a loop that stores each active ZA tile slice to memory.
     auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-    auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
-        loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
-    auto vscale =
-        rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
     auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    // This describes both the number of ZA tile slices and the number of
-    // elements in a vector of SVL bits for a given element type (SVL_B, SVL_H,
-    // ..., SVL_Q).
-    auto numTileSlices =
-        rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
-    auto forOp =
-        rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
+    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
     rewriter.setInsertionPointToStart(forOp.getBody());
 
-    // Create an 'all true' predicate for the tile slice.
-    auto predicateType =
-        VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
-    auto allTruePredicate = rewriter.create<arith::ConstantOp>(
-        loc, DenseElementsAttr::get(predicateType, true));
-
     SmallVector<Value> memrefIndices;
     auto tileSliceIndex = forOp.getInductionVar();
     getMemrefIndices(tileStoreOp.getIndices(),
                      tileStoreOp.getMemRefType().getRank(), tileSliceIndex,
-                     numTileSlices, memrefIndices, loc, rewriter);
+                     upperBound, memrefIndices, loc, rewriter);
     rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
-        tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
-        allTruePredicate, tileStoreOp.getBase(), memrefIndices,
-        tileStoreOp.getLayout());
+        tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex, maskCols,
+        tileStoreOp.getBase(), memrefIndices, tileStoreOp.getLayout());
     return success();
   }
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 55ea56f42c96e..58c6998870edd 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -102,9 +102,9 @@ func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref<?x?xi32>,
+// CHECK-DAG: %[[PTRUE_S:.*]] = arith.constant dense<true> : vector<[4]xi1>
+// CHECK-DAG: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
 // CHECK: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
-// CHECK: %[[PTRUE_S:.*]] = arith.constant dense<true> : vector<[4]xi1>
 // CHECK: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
 // CHECK: arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[PTRUE_S]], %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
 func.func @arm_sme_tile_store_hor(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
   %c0 = arith.constant 0 : index
   arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
@@ -123,6 +123,27 @@ func.func @arm_sme_tile_store_ver(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>
   return
 }
 
+// CHECK-LABEL: func.func @arm_sme_tile_store_hor_with_mask(
+// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xi32>,
+// CHECK-SAME: %[[DEST:.*]]: memref<?x?xi32>) {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[NUM_ROWS:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[NUM_COLS:.*]] = vector.create_mask %c2 : vector<[4]xi1>
+// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_ROWS]] step %[[C1]] {
+// CHECK-NEXT:   %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
+// CHECK-NEXT:   arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[NUM_COLS]], %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+func.func @arm_sme_tile_store_hor_with_mask(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
+  arm_sme.tile_store %tile, %dest[%c0, %c0], %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // vector.print
 //===----------------------------------------------------------------------===//

From 14aac4339638dedc0ed18cc5ab35a346cda32e79 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes
Date: Mon, 16 Oct 2023 11:43:32 +0000
Subject: [PATCH 9/9] [mlir][ArmSME] Lower transfer_write + transpose to
 vertical store

This patch extends the lowering of vector.transfer_write in VectorToArmSME to
support in-flight transpose via SME vertical store.
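As a sketch of the intended transformation (the types here follow the f32
integration test added below; the pattern documentation uses an i8 tile), a
2-D transfer_write whose permutation map swaps the two dimensions, e.g.

  vector.transfer_write %vector, %dest[%c0, %c0]
    {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]}
    : vector<[4]x[4]xf32>, memref<?x?xf32>

now lowers to a vertical tile store that performs the transpose in-flight:

  arm_sme.tile_store %vector, %dest[%c0, %c0] layout<vertical>
    : memref<?x?xf32>, vector<[4]x[4]xf32>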
---
 .../Conversion/VectorToArmSME/VectorToArmSME.cpp |  47 ++++-
 .../Dialect/ArmSME/vector-ops-to-sme.mlir        |  42 +++++
 .../CPU/ArmSME/test-transfer-write-2d.mlir       | 174 ++++++++++++++++++
 3 files changed, 260 insertions(+), 3 deletions(-)
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir

diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 0cc5732c9212d..40e8378306bbf 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -136,13 +136,31 @@ struct TransferReadToArmSMELowering
 /// Conversion pattern for vector.transfer_write.
 ///
-///   vector.transfer_write %vector, %source[%c0, %c0] : vector<[16]x[16]xi8>,
-///                                                      memref<?x?xi8>
+/// ---
+///
+/// Example 1: op with identity permutation map to horizontal
+///            arm_sme.tile_store:
+///
+///   vector.transfer_write %vector, %source[%c0, %c0]
+///     {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
 ///
 /// is converted to:
 ///
-///   arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>,
-///                                                   vector<[16]x[16]xi8>
+///   arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>,
+///                                                   vector<[16]x[16]xi8>
+/// ---
+///
+/// Example 2: op with transpose permutation map to vertical arm_sme.tile_store
+///            (in-flight transpose):
+///
+///   vector.transfer_write %vector, %source[%c0, %c0]
+///     {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
+///      in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
+///
+/// is converted to:
+///
+///   arm_sme.tile_store %vector, %source[%c0, %c0] layout<vertical>
+///     : memref<?x?xi8>, vector<[16]x[16]xi8>
 struct TransferWriteToArmSMELowering
     : public OpRewritePattern<vector::TransferWriteOp> {
   using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
@@ -153,12 +171,35 @@ struct TransferWriteToArmSMELowering
     if (!arm_sme::isValidSMETileVectorType(vType))
       return failure();
 
+    assert(writeOp.getTransferRank() == 2 &&
+           "expected a permutation_map with result dims of the same rank as "
+           "the vector type");
+
     if (!llvm::isa<MemRefType>(writeOp.getSource().getType()))
       return failure();
 
+    // Out-of-bounds dims are not supported.
+    if (writeOp.hasOutOfBoundsDim())
+      return rewriter.notifyMatchFailure(writeOp,
+                                         "not inbounds transfer write");
+
+    arm_sme::TileSliceLayout layout;
+
+    AffineExpr d0, d1;
+    bindDims(writeOp.getContext(), d0, d1);
+    AffineMap map = writeOp.getPermutationMap();
+    if (map.isIdentity())
+      layout = arm_sme::TileSliceLayout::Horizontal;
+    else if (map == AffineMap::get(map.getNumDims(), 0, {d1, d0},
+                                   writeOp.getContext()))
+      layout = arm_sme::TileSliceLayout::Vertical;
+    else
+      return rewriter.notifyMatchFailure(writeOp,
+                                         "unsupported permutation map");
+
     rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
         writeOp, writeOp.getVector(), writeOp.getSource(), writeOp.getIndices(),
-        writeOp.getMask());
+        writeOp.getMask(), layout);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index f9251edbe658b..e1a8a9ff9bf10 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -337,6 +337,37 @@ func.func @transfer_write_2d_with_mask_f64(%vector : vector<[2]x[2]xf64>, %dest
 
 // -----
 
+/// In-flight transpose via vertical store.
+
+// CHECK-LABEL: func.func @transfer_write_2d_transpose_i64(
+// CHECK-SAME: %[[VECTOR:.*]]: vector<[2]x[2]xi64>,
+// CHECK-SAME: %[[DEST:.*]]: memref<?x?xi64>) {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+func.func @transfer_write_2d_transpose_i64(%vector : vector<[2]x[2]xi64>, %dest : memref<?x?xi64>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vector, %dest[%c0, %c0] {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : vector<[2]x[2]xi64>, memref<?x?xi64>
+  return
+}
+
+// -----
+
+/// In-flight transpose via vertical store with mask.
+
+// CHECK-LABEL: func.func @transfer_write_2d_transpose_with_mask_bf16(
+// CHECK-SAME: %[[VECTOR:.*]]: vector<[8]x[8]xbf16>,
+// CHECK-SAME: %[[DEST:.*]]: memref<?x?xbf16>,
+// CHECK-SAME: %[[MASK:.*]]: vector<[8]x[8]xi1>) {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]], %[[MASK]] layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+func.func @transfer_write_2d_transpose_with_mask_bf16(%vector : vector<[8]x[8]xbf16>, %dest : memref<?x?xbf16>, %mask : vector<[8]x[8]xi1>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vector, %dest[%c0, %c0], %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : vector<[8]x[8]xbf16>, memref<?x?xbf16>
+  return
+}
+
+// -----
+
 // The following tests check the 'vector.transfer_write' -> 'arm_sme.intr.zero'
 // lowering only occurs for vector types of correct rank, shape, element size
 // and number of scalable dims.
@@ -398,6 +429,17 @@ func.func @transfer_write_2d__fixed(%vector : vector<16x16xi8>, %dest : memref<16x16xi8>) {
   return
 }
 
+// -----
+
+// CHECK-LABEL: @transfer_write_2d__out_of_bounds
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.tile_store
+func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vector, %dest[%c0, %c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // vector.broadcast
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
new file mode 100644
index 0000000000000..1cb685d7bc27c
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
@@ -0,0 +1,174 @@
+// DEFINE: %{entry_point} = entry
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE: %{run} = %mcr_aarch64_cmd \
+// DEFINE:   -march=aarch64 -mattr=+sve,+sme \
+// DEFINE:   -e %{entry_point} -entry-point-result=void \
+// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile} | %{run} | FileCheck %s
+
+llvm.func @printCString(!llvm.ptr<i8>)
+
+// TODO: replace with vector.print once #68695 lands.
+func.func @print_str(%str: !llvm.ptr<array<14 x i8>>) attributes { enable_arm_streaming_ignore } {
+  %c0 = llvm.mlir.constant(0 : index) : i64
+  %str_bytes = llvm.getelementptr %str[%c0, %c0]
+    : (!llvm.ptr<array<14 x i8>>, i64, i64) -> !llvm.ptr<i8>
+  llvm.call @printCString(%str_bytes) : (!llvm.ptr<i8>) -> ()
+  return
+}
+
+// Vector store.
+func.func @transfer_write_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %c0 = arith.constant 0.0 : f32
+  %zero = vector.splat %c0 : vector<[4]x[4]xf32>
+  vector.transfer_write %zero, %A[%base1, %base2] {in_bounds=[true, true]} :
+    vector<[4]x[4]xf32>, memref<?x?xf32>
+  return
+}
+
+// Masked vector store.
+func.func @transfer_write_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %c0 = arith.constant 0.0 : f32
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
+  %zero = vector.splat %c0 : vector<[4]x[4]xf32>
+  vector.transfer_write %zero, %A[%base1, %base2], %mask {in_bounds=[true, true]} :
+    vector<[4]x[4]xf32>, memref<?x?xf32>
+  return
+}
+
+// Vector store + transpose.
+func.func @transfer_write_2d_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %0 = vector.load %A[%base1, %base2] : memref<?x?xf32>, vector<[4]x[4]xf32>
+  vector.transfer_write %0, %A[%base1, %base2] {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]} :
+    vector<[4]x[4]xf32>, memref<?x?xf32>
+  return
+}
+
+// Masked vector store + transpose.
+func.func @transfer_write_2d_mask_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %c2 = arith.constant 2 : index
+  %c4 = arith.constant 4 : index
+  %mask = vector.create_mask %c4, %c2 : vector<[4]x[4]xi1>
+  %0 = vector.load %A[%base1, %base2] : memref<?x?xf32>, vector<[4]x[4]xf32>
+  vector.transfer_write %0, %A[%base1, %base2], %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]} :
+    vector<[4]x[4]xf32>, memref<?x?xf32>
+  return
+}
+
+// Vector load + print.
+func.func @load_and_print(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %tile_begin_str = llvm.mlir.addressof @tile_begin : !llvm.ptr<array<14 x i8>>
+
+  %0 = vector.load %A[%base1, %base2] : memref<?x?xf32>, vector<[4]x[4]xf32>
+
+  func.call @print_str(%tile_begin_str) : (!llvm.ptr<array<14 x i8>>) -> ()
+  vector.print %0 : vector<[4]x[4]xf32>
+
+  return
+}
+
+// Allocate heap memory of size 'd0' x 'd1' and initialize.
+//
+// Example:
+//
+//   initialize_memory(%c4, %c5)
+//
+//   0,  1,  2,  3,  4
+//   10, 11, 12, 13, 14
+//   20, 21, 22, 23, 24
+//   30, 31, 32, 33, 34
+//
+// Returns dynamic memref. It's the caller's responsibility to free the
+// returned memref.
+func.func @initialize_memory(%d0 : index, %d1 : index) -> memref<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c1_f32 = arith.constant 1.0 : f32
+  %c10_f32 = arith.constant 10.0 : f32
+
+  %A = memref.alloc(%d0, %d1) : memref<?x?xf32>
+
+  %init = arith.constant 0.0 : f32
+  scf.for %i = %c0 to %d0 step %c1 iter_args(%val = %init) -> f32 {
+    scf.for %j = %c0 to %d1 step %c1 iter_args(%inner_val = %val) -> f32 {
+      memref.store %inner_val, %A[%i, %j] : memref<?x?xf32>
+      %inner_val_next = arith.addf %inner_val, %c1_f32 : f32
+      scf.yield %inner_val_next : f32
+    }
+    %val_next = arith.addf %val, %c10_f32 : f32
+    scf.yield %val_next : f32
+  }
+
+  return %A : memref<?x?xf32>
+}
+
+func.func @entry() {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c4 = arith.constant 4 : index
+
+  // Allocate enough memory to load a 32-bit tile plus a tiny bit more to test
+  // non-zero offsets while remaining inbounds.
+  %vscale = vector.vscale
+  %svl_s = arith.muli %c4, %vscale : index
+  %svl_s_plus_two = arith.addi %svl_s, %c2 : index
+
+  // 1. Initialize memory
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 0, 1, 2, 3
+  // CHECK-NEXT: ( 10, 11, 12, 13
+  // CHECK-NEXT: ( 20, 21, 22, 23
+  // CHECK-NEXT: ( 30, 31, 32, 33
+  %A = call @initialize_memory(%svl_s_plus_two, %svl_s_plus_two) : (index, index) -> memref<?x?xf32>
+  call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+  // 2. Write 2-D vector of zeroes to 1. at offset [2, 2].
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 0, 1, 2, 3
+  // CHECK-NEXT: ( 10, 11, 12, 13
+  // CHECK-NEXT: ( 20, 21, 0, 0
+  // CHECK-NEXT: ( 30, 31, 0, 0
+  call @transfer_write_2d(%A, %c2, %c2) : (memref<?x?xf32>, index, index) -> ()
+  call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+  // 3. Write 2-D vector of zeroes to 2. but with mask (nrows=2, ncols=3).
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 0, 0, 0, 3
+  // CHECK-NEXT: ( 0, 0, 0, 13
+  // CHECK-NEXT: ( 20, 21, 0, 0
+  // CHECK-NEXT: ( 30, 31, 0, 0
+  call @transfer_write_2d_mask(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+  call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+  // 4. Reload 3. + store + transpose.
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 0, 0, 20, 30
+  // CHECK-NEXT: ( 0, 0, 21, 31
+  // CHECK-NEXT: ( 0, 0, 0, 0
+  // CHECK-NEXT: ( 3, 13, 0, 0
+  call @transfer_write_2d_transposed(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+  call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+  // 5. Reload 4. + store + transpose but with mask (nrows=4, ncols=2).
+  //    The mask applies after permutation, i.e. it masks the transposed
+  //    vector.
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 0, 0, 20, 30
+  // CHECK-NEXT: ( 0, 0, 21, 31
+  // CHECK-NEXT: ( 20, 21, 0, 0
+  // CHECK-NEXT: ( 30, 31, 0, 0
+  call @transfer_write_2d_mask_transposed(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+  call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+  memref.dealloc %A : memref<?x?xf32>
+
+  return
+}
+
+llvm.mlir.global internal constant @tile_begin("TILE BEGIN: \0A\00")