diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 3f45d0804e045..c5b08d6aa022b 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2770,11 +2770,11 @@ def Vector_FlatTransposeOp : Vector_Op<"flat_transpose", [Pure, TCresVTEtIsSameAsOpBase<0, 0>>]>, Arguments<( // TODO: tighten vector element types that make sense. - ins VectorOfRankAndType<[1], + ins FixedVectorOfRankAndType<[1], [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$matrix, I32Attr:$rows, I32Attr:$columns)>, Results<( - outs VectorOfRankAndType<[1], + outs FixedVectorOfRankAndType<[1], [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$res)> { let summary = "Vector matrix transposition on flattened 1-D MLIR vectors"; let description = [{ @@ -2789,7 +2789,9 @@ def Vector_FlatTransposeOp : Vector_Op<"flat_transpose", [Pure, a 2-D matrix with rows and columns, and returns the transposed matrix in flattened form in 'res'. - Also see: + Note, the corresponding LLVM intrinsic, `@llvm.matrix.transpose.*`, does not + support scalable vectors. Hence, this Op is only available for fixed-width + vectors. Also see: http://llvm.org/docs/LangRef.html#llvm-matrix-transpose-intrinsic diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index 56039d04549aa..d591c60acb64e 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1900,3 +1900,12 @@ func.func @matrix_multiply_scalable(%a: vector<[4]xf64>, %b: vector<4xf64>) { return } + +// ----- + +func.func @flat_transpose_scalable(%arg0: vector<[16]xf32>) -> vector<[16]xf32> { + // expected-error @+1 {{'vector.flat_transpose' op operand #0 must be fixed-length vector of signless integer or signed integer or index or floating-point values of ranks 1, but got 'vector<[16]xf32>'}} + %0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 } + : vector<[16]xf32> -> vector<[16]xf32> + return %0 : vector<[16]xf32> +}