diff --git a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp index cc9e36cfca419..8b2be7bc1901b 100644 --- a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp +++ b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp @@ -31,7 +31,7 @@ bool mlir::arm_sme::isValidSMETileElementType(Type type) { } bool mlir::arm_sme::isValidSMETileVectorType(VectorType vType) { - if ((vType.getRank() != 2) && vType.allDimsScalable()) + if ((vType.getRank() != 2) || !vType.allDimsScalable()) return false; auto elemType = vType.getElementType(); diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir index 8b6bd8f52d190..cb35de11ab5b3 100644 --- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir +++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir @@ -154,6 +154,17 @@ func.func @transfer_write_2d_zero__non_memref_type(%arg0 : tensor) -> te return %0 : tensor } +// ----- + +// CHECK-LABEL: @transfer_write_2d__fixed +// CHECK: vector.transfer_write +// CHECK-NOT: arm_sme.tile_store +func.func @transfer_write_2d__fixed(%vector : vector<16x16xi8>, %dest : memref) { + %c0 = arith.constant 0 : index + vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xi8>, memref + return +} + // ============================================================================= // vector.broadcast // =============================================================================