diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td index 76d97f106dcb8..56fbe9cdc2d21 100644 --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -964,7 +964,7 @@ def AffineVectorLoadOp : AffineLoadOpBase<"vector_load"> { (see [vector.transfer_read](../Vector/#vectortransfer_read-mlirvectortransferreadop)). }]; - let results = (outs AnyVector:$result); + let results = (outs AnyVectorOfNonZeroRank:$result); let builders = [ /// Builds an affine vector load op with the specified map and operands. @@ -1031,7 +1031,7 @@ def AffineVectorStoreOp : AffineStoreOpBase<"vector_store"> { (see [vector.transfer_write](../Vector/#vectortransfer_write-mlirvectortransferwriteop)). }]; - let arguments = (ins AnyVector:$value, + let arguments = (ins AnyVectorOfNonZeroRank:$value, Arg:$memref, Variadic:$indices, diff --git a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td index 9cc792093bf83..475b11f12c5f0 100644 --- a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td +++ b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td @@ -35,7 +35,7 @@ def ArmNeon_Dialect : Dialect { //===----------------------------------------------------------------------===// class NeonVectorOfLength : ShapedContainerType< - [elementType], And<[IsVectorOfShape<[length]>, IsFixedVectorTypePred]>, + [elementType], And<[IsVectorOfShape<[length]>, IsFixedVectorOfAnyRankTypePred]>, "a vector with length " # length, "::mlir::VectorType">; diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td index 9a058ae4fe764..6fd992afbf043 100644 --- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td @@ -371,7 +371,7 @@ def TileLoadOp : ArmSME_Op<"tile_load", [ let arguments = (ins Arg:$base, Variadic:$indices, - Optional:$padding, Optional:$mask, + Optional:$padding, Optional:$mask, ArmSME_TileSliceLayoutAttr:$layout ); let results = (outs SMETile:$result); @@ -444,7 +444,7 @@ def TileStoreOp : ArmSME_Op<"tile_store", [ }]; let arguments = (ins SMETile:$valueToStore, Arg:$base, - Variadic:$indices, Optional:$mask, + Variadic:$indices, Optional:$mask, ArmSME_TileSliceLayoutAttr:$layout ); let extraClassDeclaration = [{ @@ -799,9 +799,9 @@ class OuterProductWideningBase { let arguments = (ins - AnyTypeOf:$lhs, AnyVector:$rhs, - Optional:$lhsMask, Optional:$rhsMask, - Optional:$acc); + AnyTypeOf:$lhs, AnyVectorOfNonZeroRank:$rhs, + Optional:$lhsMask, Optional:$rhsMask, + Optional:$acc); let results = (outs AnyTypeOf:$result); let assemblyFormat = [{ diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td index d7e8b22fbd2d3..cdcf4d8752e87 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td +++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td @@ -100,11 +100,11 @@ class ScalableMaskedFOp:$mask, - ScalableVectorOf<[AnyFloat]>:$src1, - ScalableVectorOf<[AnyFloat]>:$src2 + ScalableVectorOfAnyRank<[I1]>:$mask, + ScalableVectorOfAnyRank<[AnyFloat]>:$src1, + ScalableVectorOfAnyRank<[AnyFloat]>:$src2 ); - let results = (outs ScalableVectorOf<[AnyFloat]>:$res); + let results = (outs ScalableVectorOfAnyRank<[AnyFloat]>:$res); let assemblyFormat = "$mask `,` $src1 `,` $src2 attr-dict `:` type($mask) `,` type($res)"; } @@ -123,11 +123,11 @@ class ScalableMaskedIOp:$mask, - ScalableVectorOf<[I8, I16, I32, I64]>:$src1, - ScalableVectorOf<[I8, I16, I32, I64]>:$src2 + ScalableVectorOfAnyRank<[I1]>:$mask, + ScalableVectorOfAnyRank<[I8, I16, I32, I64]>:$src1, + ScalableVectorOfAnyRank<[I8, I16, I32, I64]>:$src2 ); - let results = (outs ScalableVectorOf<[I8, I16, I32, I64]>:$res); + let results = (outs ScalableVectorOfAnyRank<[I8, I16, I32, I64]>:$res); let assemblyFormat = "$mask `,` $src1 `,` $src2 attr-dict `:` type($mask) `,` type($res)"; } @@ -511,55 +511,55 @@ def ScalableMaskedDivFOp : ScalableMaskedFOp<"masked.divf", "division">; def UmmlaIntrOp : ArmSVE_IntrBinaryOverloadedOp<"ummla">, - Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; + Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>; def SmmlaIntrOp : ArmSVE_IntrBinaryOverloadedOp<"smmla">, - Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; + Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>; def SdotIntrOp : ArmSVE_IntrBinaryOverloadedOp<"sdot">, - Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; + Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>; def UdotIntrOp : ArmSVE_IntrBinaryOverloadedOp<"udot">, - Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; + Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>; def ScalableMaskedAddIIntrOp : ArmSVE_IntrBinaryOverloadedOp<"add">, - Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; + Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>; def ScalableMaskedAddFIntrOp : ArmSVE_IntrBinaryOverloadedOp<"fadd">, - Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; + Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>; def ScalableMaskedMulIIntrOp : ArmSVE_IntrBinaryOverloadedOp<"mul">, - Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; + Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>; def ScalableMaskedMulFIntrOp : ArmSVE_IntrBinaryOverloadedOp<"fmul">, - Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; + Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>; def ScalableMaskedSubIIntrOp : ArmSVE_IntrBinaryOverloadedOp<"sub">, - Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; + Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>; def ScalableMaskedSubFIntrOp : ArmSVE_IntrBinaryOverloadedOp<"fsub">, - Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; + Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>; def ScalableMaskedSDivIIntrOp : ArmSVE_IntrBinaryOverloadedOp<"sdiv">, - Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; + Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>; def ScalableMaskedUDivIIntrOp : ArmSVE_IntrBinaryOverloadedOp<"udiv">, - Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; + Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>; def ScalableMaskedDivFIntrOp : ArmSVE_IntrBinaryOverloadedOp<"fdiv">, - Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; + Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>; def ConvertFromSvboolIntrOp : ArmSVE_IntrOp<"convert.from.svbool", @@ -581,8 +581,8 @@ def ZipX2IntrOp : ArmSVE_IntrOp<"zip.x2", /*overloadedOperands=*/[0], /*overloadedResults=*/[], /*numResults=*/2>, - Arguments<(ins Arg:$v1, - Arg:$v2)>; + Arguments<(ins Arg:$v1, + Arg:$v2)>; // Note: This multi-vector intrinsic requires SME2. def ZipX4IntrOp : ArmSVE_IntrOp<"zip.x4", @@ -590,10 +590,10 @@ def ZipX4IntrOp : ArmSVE_IntrOp<"zip.x4", /*overloadedOperands=*/[0], /*overloadedResults=*/[], /*numResults=*/4>, - Arguments<(ins Arg:$v1, - Arg:$v2, - Arg:$v3, - Arg:$v4)>; + Arguments<(ins Arg:$v1, + Arg:$v2, + Arg:$v3, + Arg:$v4)>; // Note: This intrinsic requires SME or SVE2.1. def PselIntrOp : ArmSVE_IntrOp<"psel", diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td index 1f52f6b91617c..b39f2ee594cd4 100644 --- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td +++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td @@ -255,7 +255,7 @@ def NVGPU_LdMatrixOp : NVGPU_Op<"ldmatrix", [ let arguments = (ins Arg]>:$srcMemref, Variadic:$indices, BoolAttr:$transpose, I32Attr:$numTiles); - let results = (outs AnyVector:$res); + let results = (outs AnyVectorOfNonZeroRank:$res); let assemblyFormat = [{ $srcMemref`[` $indices `]` attr-dict `:` type($srcMemref) `->` type($res) }]; @@ -301,13 +301,13 @@ def NVGPU_MmaSyncOp : NVGPU_MmaSyncOp<"mma.sync"> { (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32> ``` }]; - let arguments = (ins AnyVector:$matrixA, - AnyVector:$matrixB, - AnyVector:$matrixC, + let arguments = (ins AnyVectorOfNonZeroRank:$matrixA, + AnyVectorOfNonZeroRank:$matrixB, + AnyVectorOfNonZeroRank:$matrixC, I64ArrayAttr:$mmaShape, OptionalAttr:$tf32Enabled); - let results = (outs AnyVector:$res); + let results = (outs AnyVectorOfNonZeroRank:$res); let builders = [ OpBuilder<(ins "Value":$matrixA, @@ -357,16 +357,16 @@ def NVGPU_MmaSparseSyncOp : NVGPU_MmaSyncOp<"mma.sp.sync"> { ``` }]; - let arguments = (ins AnyVector:$matrixA, - AnyVector:$matrixB, - AnyVector:$matrixC, + let arguments = (ins AnyVectorOfNonZeroRank:$matrixA, + AnyVectorOfNonZeroRank:$matrixB, + AnyVectorOfNonZeroRank:$matrixC, NVGPU_MmaSparseSyncMetadataType:$sparseMetadata, I64ArrayAttr:$mmaShape, DefaultValuedAttr:$sparsitySelector, OptionalAttr:$tf32Enabled ); - let results = (outs AnyVector:$res); + let results = (outs AnyVectorOfNonZeroRank:$res); let builders = [ OpBuilder<(ins "Value":$matrixA, @@ -825,10 +825,10 @@ def NVGPU_RcpOp : NVGPU_Op<"rcp", [Pure, The input and output must be of the same vector type and shape. }]; - let arguments = (ins VectorOf<[F32]>:$in, + let arguments = (ins VectorOfNonZeroRankOf<[F32]>:$in, DefaultValuedAttr:$rounding, UnitAttr:$ftz); - let results = (outs VectorOf<[F32]>:$out); + let results = (outs VectorOfNonZeroRankOf<[F32]>:$out); let assemblyFormat = [{ $in `{` `rounding` `=` $rounding (`,` `ftz` $ftz^)? `}` attr-dict `:` type($out) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td index a4b43d656fe43..a6d3163d4446f 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -166,7 +166,7 @@ def Tosa_Int32TensorUpto4D : AnyTypeOf<[ class Tosa_TypeLike types, string description = ""> : TypeConstraint.predicate, - VectorOf.predicate, + VectorOfNonZeroRankOf.predicate, TosaTensorOf.predicate]>, description>; diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index cc4cafa869e63..1e257136988bb 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -40,7 +40,7 @@ def Vector_ContractionOp : DeclareOpInterfaceMethods, DeclareOpInterfaceMethods ]>, - Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc, + Arguments<(ins AnyVectorOfNonZeroRank:$lhs, AnyVectorOfNonZeroRank:$rhs, AnyType:$acc, ArrayAttr:$indexing_maps, Vector_IteratorTypeArrayAttr:$iterator_types, DefaultValuedAttr]>, Arguments<(ins Vector_CombiningKindAttr:$kind, - AnyVector:$source, + AnyVectorOfNonZeroRank:$source, AnyType:$acc, DenseI64ArrayAttr:$reduction_dims)>, Results<(outs AnyType:$dest)> { @@ -417,16 +417,18 @@ def Vector_BroadcastOp : let hasVerifier = 1; } -def Vector_ShuffleOp : - Vector_Op<"shuffle", [Pure, - PredOpTrait<"first operand v1 and result have same element type", - TCresVTEtIsSameAsOpBase<0, 0>>, - PredOpTrait<"second operand v2 and result have same element type", - TCresVTEtIsSameAsOpBase<0, 1>>, - InferTypeOpAdaptor]>, - Arguments<(ins AnyFixedVector:$v1, AnyFixedVector:$v2, - DenseI64ArrayAttr:$mask)>, - Results<(outs AnyVector:$vector)> { +def Vector_ShuffleOp + : Vector_Op< + "shuffle", + [Pure, + PredOpTrait<"first operand v1 and result have same element type", + TCresVTEtIsSameAsOpBase<0, 0>>, + PredOpTrait<"second operand v2 and result have same element type", + TCresVTEtIsSameAsOpBase<0, 1>>, + InferTypeOpAdaptor]>, + Arguments<(ins AnyFixedVectorOfAnyRank:$v1, AnyFixedVectorOfAnyRank:$v2, + DenseI64ArrayAttr:$mask)>, + Results<(outs AnyVectorOfNonZeroRank:$vector)> { let summary = "shuffle operation"; let description = [{ The shuffle operation constructs a permutation (or duplication) of elements @@ -531,7 +533,7 @@ def Vector_InterleaveOp : }]; let arguments = (ins AnyVectorOfAnyRank:$lhs, AnyVectorOfAnyRank:$rhs); - let results = (outs AnyVector:$result); + let results = (outs AnyVectorOfNonZeroRank:$result); let assemblyFormat = [{ $lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result) @@ -610,8 +612,8 @@ def Vector_DeinterleaveOp : ``` }]; - let arguments = (ins AnyVector:$source); - let results = (outs AnyVector:$res1, AnyVector:$res2); + let arguments = (ins AnyVectorOfNonZeroRank:$source); + let results = (outs AnyVectorOfNonZeroRank:$res1, AnyVectorOfNonZeroRank:$res2); let assemblyFormat = [{ $source attr-dict `:` type($source) `->` type($res1) @@ -1048,9 +1050,9 @@ def Vector_InsertStridedSliceOp : PredOpTrait<"operand #0 and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>, AllTypesMatch<["dest", "res"]>]>, - Arguments<(ins AnyVector:$source, AnyVector:$dest, I64ArrayAttr:$offsets, + Arguments<(ins AnyVectorOfNonZeroRank:$source, AnyVectorOfNonZeroRank:$dest, I64ArrayAttr:$offsets, I64ArrayAttr:$strides)>, - Results<(outs AnyVector:$res)> { + Results<(outs AnyVectorOfNonZeroRank:$res)> { let summary = "strided_slice operation"; let description = [{ Takes a k-D source vector, an n-D destination vector (n >= k), n-sized @@ -1107,10 +1109,10 @@ def Vector_OuterProductOp : PredOpTrait<"rhs operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 1>>, DeclareOpInterfaceMethods]>, - Arguments<(ins AnyVector:$lhs, AnyType:$rhs, - Optional:$acc, + Arguments<(ins AnyVectorOfNonZeroRank:$lhs, AnyType:$rhs, + Optional:$acc, DefaultValuedAttr:$kind)>, - Results<(outs AnyVector)> { + Results<(outs AnyVectorOfNonZeroRank)> { let summary = "vector outerproduct with optional fused add"; let description = [{ Takes 2 1-D vectors and returns the 2-D vector containing the outer-product, @@ -1190,9 +1192,9 @@ def Vector_ExtractStridedSliceOp : Vector_Op<"extract_strided_slice", [Pure, PredOpTrait<"operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>]>, - Arguments<(ins AnyVector:$vector, I64ArrayAttr:$offsets, + Arguments<(ins AnyVectorOfNonZeroRank:$vector, I64ArrayAttr:$offsets, I64ArrayAttr:$sizes, I64ArrayAttr:$strides)>, - Results<(outs AnyVector)> { + Results<(outs AnyVectorOfNonZeroRank)> { let summary = "extract_strided_slice operation"; let description = [{ Takes an n-D vector, k-D `offsets` integer array attribute, a k-sized @@ -1254,7 +1256,7 @@ def Vector_TransferReadOp : Variadic:$indices, AffineMapAttr:$permutation_map, AnyType:$padding, - Optional>:$mask, + Optional>:$mask, BoolArrayAttr:$in_bounds)>, Results<(outs AnyVectorOfAnyRank:$vector)> { @@ -1502,7 +1504,7 @@ def Vector_TransferWriteOp : AnyShaped:$source, Variadic:$indices, AffineMapAttr:$permutation_map, - Optional>:$mask, + Optional>:$mask, BoolArrayAttr:$in_bounds)>, Results<(outs Optional:$result)> { @@ -1825,9 +1827,9 @@ def Vector_MaskedLoadOp : Vector_Op<"maskedload">, Arguments<(ins Arg:$base, Variadic:$indices, - VectorOf<[I1]>:$mask, - AnyVector:$pass_thru)>, - Results<(outs AnyVector:$result)> { + VectorOfNonZeroRankOf<[I1]>:$mask, + AnyVectorOfNonZeroRank:$pass_thru)>, + Results<(outs AnyVectorOfNonZeroRank:$result)> { let summary = "loads elements from memory into a vector as defined by a mask vector"; @@ -1888,8 +1890,8 @@ def Vector_MaskedStoreOp : Vector_Op<"maskedstore">, Arguments<(ins Arg:$base, Variadic:$indices, - VectorOf<[I1]>:$mask, - AnyVector:$valueToStore)> { + VectorOfNonZeroRankOf<[I1]>:$mask, + AnyVectorOfNonZeroRank:$valueToStore)> { let summary = "stores elements from a vector into memory as defined by a mask vector"; @@ -1951,10 +1953,10 @@ def Vector_GatherOp : ]>, Arguments<(ins Arg:$base, Variadic:$indices, - VectorOf<[AnyInteger, Index]>:$index_vec, - VectorOf<[I1]>:$mask, - AnyVector:$pass_thru)>, - Results<(outs AnyVector:$result)> { + VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec, + VectorOfNonZeroRankOf<[I1]>:$mask, + AnyVectorOfNonZeroRank:$pass_thru)>, + Results<(outs AnyVectorOfNonZeroRank:$result)> { let summary = [{ gathers elements from memory or ranked tensor into a vector as defined by an @@ -2082,9 +2084,9 @@ def Vector_ExpandLoadOp : Vector_Op<"expandload">, Arguments<(ins Arg:$base, Variadic:$indices, - VectorOf<[I1]>:$mask, - AnyVector:$pass_thru)>, - Results<(outs AnyVector:$result)> { + VectorOfNonZeroRankOf<[I1]>:$mask, + AnyVectorOfNonZeroRank:$pass_thru)>, + Results<(outs AnyVectorOfNonZeroRank:$result)> { let summary = "reads elements from memory and spreads them into a vector as defined by a mask"; @@ -2149,8 +2151,8 @@ def Vector_CompressStoreOp : Vector_Op<"compressstore">, Arguments<(ins Arg:$base, Variadic:$indices, - VectorOf<[I1]>:$mask, - AnyVector:$valueToStore)> { + VectorOfNonZeroRankOf<[I1]>:$mask, + AnyVectorOfNonZeroRank:$valueToStore)> { let summary = "writes elements selectively from a vector as defined by a mask"; @@ -2508,7 +2510,7 @@ def Vector_MaskOp : Vector_Op<"mask", [ }]; // TODO: Support multiple passthru values. - let arguments = (ins VectorOf<[I1]>:$mask, + let arguments = (ins VectorOfNonZeroRankOf<[I1]>:$mask, Optional:$passthru); let results = (outs Variadic:$results); let regions = (region SizedRegion<1>:$maskRegion); @@ -2891,11 +2893,11 @@ def Vector_ScanOp : AllTypesMatch<["source", "dest"]>, AllTypesMatch<["initial_value", "accumulated_value"]> ]>, Arguments<(ins Vector_CombiningKindAttr:$kind, - AnyVector:$source, + AnyVectorOfNonZeroRank:$source, AnyVectorOfAnyRank:$initial_value, I64Attr:$reduction_dim, BoolAttr:$inclusive)>, - Results<(outs AnyVector:$dest, + Results<(outs AnyVectorOfNonZeroRank:$dest, AnyVectorOfAnyRank:$accumulated_value)> { let summary = "Scan operation"; let description = [{ diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td index 48e4c24f83865..fc4383d08422c 100644 --- a/mlir/include/mlir/IR/CommonTypeConstraints.td +++ b/mlir/include/mlir/IR/CommonTypeConstraints.td @@ -22,15 +22,15 @@ include "mlir/IR/DialectBase.td" // Whether a type is a VectorType. // Explicitly disallow 0-D vectors for now until we have good enough coverage. -def IsVectorTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">, - CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>; +def IsVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">, + CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>; // Temporary vector type clone that allows gradual transition to 0-D vectors. // TODO: Remove this when all ops support 0-D vectors. def IsVectorOfAnyRankTypePred : CPred<"::llvm::isa<::mlir::VectorType>($_self)">; // Whether a type is a fixed-length VectorType. -def IsFixedVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) && +def IsFixedVectorOfAnyRankTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) && !::llvm::cast($_self).isScalable()}]>; // Whether a type is a scalable VectorType. @@ -53,7 +53,7 @@ def IsVectorTypeWithOnlyTrailingDimScalablePred : And<[ // Whether a type is a VectorType and all dimensions are scalable. def IsVectorTypeWithAllDimsScalablePred : And<[ - IsVectorTypePred, + IsVectorOfNonZeroRankTypePred, CPred<[{::llvm::cast<::mlir::VectorType>($_self).allDimsScalable()}]> ]>; @@ -428,8 +428,8 @@ class ValueSemanticsContainerOf allowedTypes> : // Vector types. -class VectorOf allowedTypes> : - ShapedContainerType allowedTypes> : + ShapedContainerType; // Temporary vector type clone that allows gradual transition to 0-D vectors. @@ -438,11 +438,11 @@ class VectorOfAnyRankOf allowedTypes> : ShapedContainerType; -class FixedVectorOf allowedTypes> : - ShapedContainerType allowedTypes> : + ShapedContainerType; -class ScalableVectorOf allowedTypes> : +class ScalableVectorOfAnyRank allowedTypes> : ShapedContainerType; @@ -458,7 +458,7 @@ class VectorWithTrailingDimScalableOf allowedTypes> : // Whether the number of elements of a vector is from the given // `allowedRanks` list class IsVectorOfRankPred allowedRanks> : - And<[IsVectorTypePred, + And<[IsVectorOfNonZeroRankTypePred, Or($_self).getRank() == }] @@ -467,7 +467,7 @@ class IsVectorOfRankPred allowedRanks> : // Whether the number of elements of a fixed-length vector is from the given // `allowedRanks` list class IsFixedVectorOfRankPred allowedRanks> : - And<[IsFixedVectorTypePred, + And<[IsFixedVectorOfAnyRankTypePred, Or($_self).getRank() == }] @@ -501,22 +501,22 @@ class ScalableVectorOfRank allowedRanks> : Type< // is from the given `allowedTypes` list class VectorOfRankAndType allowedRanks, list allowedTypes> : AllOfType< - [VectorOf, VectorOfRank], - VectorOf.summary # VectorOfRank.summary, + [VectorOfNonZeroRankOf, VectorOfRank], + VectorOfNonZeroRankOf.summary # VectorOfRank.summary, "::mlir::VectorType">; // Fixed-width vector where the rank is from the given `allowedRanks` list and // the type is from the given `allowedTypes` list class FixedVectorOfRankAndType allowedRanks, list allowedTypes> : AllOfType< - [FixedVectorOf, VectorOfRank], - FixedVectorOf.summary # VectorOfRank.summary, + [FixedVectorOfAnyRank, VectorOfRank], + FixedVectorOfAnyRank.summary # VectorOfRank.summary, "::mlir::VectorType">; // Whether the number of elements of a vector is from the given // `allowedLengths` list class IsVectorOfLengthPred allowedLengths> : - And<[IsVectorTypePred, + And<[IsVectorOfNonZeroRankTypePred, Or($_self).getNumElements() == }] @@ -525,7 +525,7 @@ class IsVectorOfLengthPred allowedLengths> : // Whether the number of elements of a fixed-length vector is from the given // `allowedLengths` list class IsFixedVectorOfLengthPred allowedLengths> : - And<[IsFixedVectorTypePred, + And<[IsFixedVectorOfAnyRankTypePred, Or($_self).getNumElements() == }] @@ -604,16 +604,16 @@ class ScalableVectorOfLength allowedLengths> : Type< // list class VectorOfLengthAndType allowedLengths, list allowedTypes> : AllOfType< - [VectorOf, VectorOfLength], - VectorOf.summary # VectorOfLength.summary, + [VectorOfNonZeroRankOf, VectorOfLength], + VectorOfNonZeroRankOf.summary # VectorOfLength.summary, "::mlir::VectorType">; // Any fixed-length vector where the number of elements is from the given // `allowedLengths` list and the type is from the given `allowedTypes` list class FixedVectorOfLengthAndType allowedLengths, list allowedTypes> : AllOfType< - [FixedVectorOf, FixedVectorOfLength], - FixedVectorOf.summary # + [FixedVectorOfAnyRank, FixedVectorOfLength], + FixedVectorOfAnyRank.summary # FixedVectorOfLength.summary, "::mlir::VectorType">; @@ -621,8 +621,8 @@ class FixedVectorOfLengthAndType allowedLengths, // `allowedLengths` list and the type is from the given `allowedTypes` list class ScalableVectorOfLengthAndType allowedLengths, list allowedTypes> : AllOfType< - [ScalableVectorOf, ScalableVectorOfLength], - ScalableVectorOf.summary # + [ScalableVectorOfAnyRank, ScalableVectorOfLength], + ScalableVectorOfAnyRank.summary # ScalableVectorOfLength.summary, "::mlir::VectorType">; @@ -632,10 +632,10 @@ class ScalableVectorOfLengthAndType allowedLengths, class ScalableVectorOfRankAndLengthAndType allowedRanks, list allowedLengths, list allowedTypes> : AllOfType< - [ScalableVectorOfRank, ScalableVectorOf, + [ScalableVectorOfRank, ScalableVectorOfAnyRank, ScalableVectorOfLength], ScalableVectorOfRank.summary # - ScalableVectorOf.summary # + ScalableVectorOfAnyRank.summary # ScalableVectorOfLength.summary, "::mlir::VectorType">; @@ -657,13 +657,14 @@ class VectorWithTrailingDimScalableOfSizeAndType allowedTrailingSizes, ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>.summary, "::mlir::VectorType">; -def AnyVector : VectorOf<[AnyType]>; -// Temporary vector type clone that allows gradual transition to 0-D vectors. +// Unlike the following definitions, this one excludes 0-D vectors +def AnyVectorOfNonZeroRank : VectorOfNonZeroRankOf<[AnyType]>; + def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>; -def AnyFixedVector : FixedVectorOf<[AnyType]>; +def AnyFixedVectorOfAnyRank : FixedVectorOfAnyRank<[AnyType]>; -def AnyScalableVector : ScalableVectorOf<[AnyType]>; +def AnyScalableVectorOfAnyRank : ScalableVectorOfAnyRank<[AnyType]>; // Shaped types. diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index cfe19a2fd5c08..7caf3bc965797 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2781,7 +2781,7 @@ def TestGraphLoopOp : TEST_Op<"graph_loop", //===----------------------------------------------------------------------===// // Test InferIntRangeInterface //===----------------------------------------------------------------------===// -def InferIntRangeType : AnyTypeOf<[AnyInteger, Index, VectorOf<[AnyInteger, Index]>]>; +def InferIntRangeType : AnyTypeOf<[AnyInteger, Index, VectorOfNonZeroRankOf<[AnyInteger, Index]>]>; def TestWithBoundsOp : TEST_Op<"with_bounds", [DeclareOpInterfaceMethods,