diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td index dab54b63d8d22..9b9dbff10ea2d 100644 --- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td @@ -441,21 +441,24 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [ of a 2-D scalable vector tile at the given index. The type of the 1-D scalable vector to be moved must match the type of the tile slice. A tile slice is a 1-D vector of horizontally or vertically contiguous elements - within a ZA tile. Horizontal tile slices are currently assumed when - lowering to intrinsics. The updated tile is returned as the result. + within a ZA tile. The updated tile is returned as the result. - Example 1: Move a vector<[16]xi8> into tile at given index. + An optional tile slice layout attribute specifies whether the tile slice is + horizontal (default) or vertical. + + Example 1: Move a vector<[16]xi8> into tile horizontally (default) at given index. ```mlir %tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[16]xi8> into vector<[16]x[16]xi8> ``` - Example 2: Move a vector<[2]xf64> into tile at given index. + Example 2: Move a vector<[2]xf64> into tile vertically at given index. ```mlir - %tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[2]xf64> into vector<[2]x[2]xf64> + %tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout : vector<[2]xf64> into vector<[2]x[2]xf64> ``` }]; let arguments = (ins - SVEVector:$vector, SMETile:$tile, Index:$tile_slice_index); + SVEVector:$vector, SMETile:$tile, Index:$tile_slice_index, + ArmSME_TileSliceLayoutAttr:$layout); let results = (outs SMETile:$result); let extraClassDeclaration = [{ @@ -465,7 +468,7 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [ }]; let assemblyFormat = [{ - $vector `,` $tile `,` $tile_slice_index + $vector `,` $tile `,` $tile_slice_index (`layout` `` $layout^)? attr-dict `:` type($vector) `into` type($result) }]; } @@ -480,21 +483,26 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure, let description = [{ The tile slice to vector operation extracts a 1-D scalable slice from a 2-D scalable tile at the given index. A tile slice is a 1-D vector of - horizontally or vertically contiguous elements within a ZA tile. Horizontal - tile slices are currently assumed when lowering to intrinsics. + horizontally or vertically contiguous elements within a ZA tile. + + An optional tile slice layout attribute specifies whether the tile slice is + horizontal (default) or vertical. - Example 1: Extract `vector<[16]xi8>` from tile at the given index. + Example 1: Extract `vector<[16]xi8>` from tile horizontally at the given index. ```mlir %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[16]xi8> from vector<[16]x[16]xi8> ``` - Example 2: Extract `vector<[2]xf64>` from tile at the given index. + Example 2: Extract `vector<[2]xf64>` from tile vertically at the given index. ```mlir - %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64> + %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout : vector<[2]xf64> from vector<[2]x[2]xf64> ``` }]; - let arguments = (ins SMETile:$tile, Index:$tile_slice_index); + let arguments = (ins + SMETile:$tile, Index:$tile_slice_index, + ArmSME_TileSliceLayoutAttr:$layout + ); let results = (outs SVEVector:$result); let extraClassDeclaration = [{ @@ -502,7 +510,7 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure, }]; let assemblyFormat = [{ - $tile `[` $tile_slice_index `]` attr-dict + $tile `[` $tile_slice_index `]` (`layout` `` $layout^)? attr-dict `:` type($result) `from` type($tile) }]; } diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp index 5e13707ea0aa2..1231da356f8ed 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp @@ -350,8 +350,7 @@ struct StoreTileSliceToArmSMELowering } }; -/// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics. Only horizontal -/// tile slices are currently supported. +/// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics. struct MoveVectorToTileSliceToArmSMELowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< @@ -388,10 +387,19 @@ struct MoveVectorToTileSliceToArmSMELowering auto tileI32 = castTileIDToI32(tile, loc, rewriter); - // Create 'arm_sme.intr.write.horiz' to write vector to tile slice. - rewriter.create( - loc, tileI32, tileSliceI32, allActiveMask, - moveVectorToTileSliceOp.getVector()); + // Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice. + switch (moveVectorToTileSliceOp.getLayout()) { + case arm_sme::TileSliceLayout::Horizontal: + rewriter.create( + loc, tileI32, tileSliceI32, allActiveMask, + moveVectorToTileSliceOp.getVector()); + break; + case arm_sme::TileSliceLayout::Vertical: + rewriter.create( + loc, tileI32, tileSliceI32, allActiveMask, + moveVectorToTileSliceOp.getVector()); + break; + } // Intrinsic has no result, replace 'arm_sme.move_vector_to_tile_slice' with // 'arm_sme.cast_tile_to_vector' to preserve dataflow. @@ -402,8 +410,7 @@ struct MoveVectorToTileSliceToArmSMELowering } }; -/// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics. Only horizontal -/// tile slices are currently supported. +/// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics. struct MoveTileSliceToVectorArmSMELowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< @@ -435,10 +442,19 @@ struct MoveTileSliceToVectorArmSMELowering auto sliceIndexI32 = rewriter.create( loc, rewriter.getI32Type(), sliceIndex); - // Create 'arm_sme.intr.read.horiz' to extract the tile slice. - rewriter.replaceOpWithNewOp( - moveTileSliceToVector, sliceType, zeroVector, allTruePredicate, - tileIdI32, sliceIndexI32); + // Create 'arm_sme.intr.read.(horiz|vert)' to extract the tile slice. + switch (moveTileSliceToVector.getLayout()) { + case arm_sme::TileSliceLayout::Horizontal: + rewriter.replaceOpWithNewOp( + moveTileSliceToVector, sliceType, zeroVector, allTruePredicate, + tileIdI32, sliceIndexI32); + break; + case arm_sme::TileSliceLayout::Vertical: + rewriter.replaceOpWithNewOp( + moveTileSliceToVector, sliceType, zeroVector, allTruePredicate, + tileIdI32, sliceIndexI32); + break; + } return success(); } @@ -680,7 +696,8 @@ void mlir::configureArmSMELegalizeForExportTarget( arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert, arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert, arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz, - arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_mopa, + arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz, + arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_za_enable, arm_sme::aarch64_sme_za_disable>(); target.addLegalOp(); target.addIllegalOp(); diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir index 07485b3ee8ddf..9074f0a7ee655 100644 --- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir +++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir @@ -400,6 +400,29 @@ func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_s return } +//===----------------------------------------------------------------------===// +// arm_sme.move_vector_to_tile_slice +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: @arm_sme_move_vector_to_tile_slice_hor_i32 +// CHECK: "arm_sme.intr.write.horiz"({{.*}}) : (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> () +func.func @arm_sme_move_vector_to_tile_slice_hor_i32(%vector : vector<[4]xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) -> () { + %c0 = arith.constant 0 : index + arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32> + return +} + +// ----- + +// CHECK-LABEL: @arm_sme_move_vector_to_tile_slice_ver_bf16 +// CHECK: "arm_sme.intr.write.vert"({{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> () +func.func @arm_sme_move_vector_to_tile_slice_ver_bf16(%vector : vector<[8]xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) -> () { + %c0 = arith.constant 0 : index + arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout : vector<[8]xbf16> into vector<[8]x[8]xbf16> + return +} //===----------------------------------------------------------------------===// // arm_sme.move_tile_slice_to_vector @@ -485,3 +508,12 @@ func.func @arm_sme_move_tile_slice_to_vector_f64(%tile : vector<[2]x[2]xf64>, %t %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64> return %slice : vector<[2]xf64> } + +// ----- + +// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_ver_i128 +// CHECK: "arm_sme.intr.read.vert"({{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128> +func.func @arm_sme_move_tile_slice_to_vector_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index) -> vector<[1]xi128> { + %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout : vector<[1]xi128> from vector<[1]x[1]xi128> + return %slice : vector<[1]xi128> +} diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir index 427154158e797..e5ba81eff8360 100644 --- a/mlir/test/Dialect/ArmSME/roundtrip.mlir +++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir @@ -1059,6 +1059,14 @@ func.func @arm_sme_move_vector_to_tile_slice_f64(%vector : vector<[2]xf64>, %til return } +// ----- + +func.func @arm_sme_move_vector_to_tile_slice_ver_i8(%vector : vector<[16]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> () { + // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} layout : vector<[16]xi8> into vector<[16]x[16]xi8> + %c0 = arith.constant 0 : index + arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout : vector<[16]xi8> into vector<[16]x[16]xi8> + return +} //===----------------------------------------------------------------------===// // arm_sme.move_tile_slice_to_vector @@ -1135,3 +1143,11 @@ func.func @arm_sme_move_tile_slice_to_vector_f64(%tile : vector<[2]x[2]xf64>, %t %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64> return %slice : vector<[2]xf64> } + +// ----- + +func.func @arm_sme_move_tile_slice_to_vector_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index) -> vector<[2]xf64> { + // CHECK: arm_sme.move_tile_slice_to_vector {{.*}} layout : vector<[2]xf64> from vector<[2]x[2]xf64> + %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout : vector<[2]xf64> from vector<[2]x[2]xf64> + return %slice : vector<[2]xf64> +}