Skip to content

Commit 8bb290a

Browse files
committed
[mlir] Add support for vector types whose number of elements are from a range of values.
Add types and predicates for Vector, Fixed Vector, and Scalable Vector whose number of elements is from a given `allowedRanges` list. The list has two values, start and end of the range (inclusive).
1 parent e0f86ca commit 8bb290a

File tree

1 file changed

+70
-0
lines changed

1 file changed

+70
-0
lines changed

mlir/include/mlir/IR/CommonTypeConstraints.td

+70
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,76 @@ class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
546546
ScalableVectorOfLength<allowedLengths>.summary,
547547
"::mlir::VectorType">;
548548

549+
// Whether the number of elements of a vector is from the given
550+
// `allowedRanges` list, the list has two values, start and end
551+
// of the range (inclusive).
552+
class IsVectorOfLengthRangePred<list<int> allowedRanges>
553+
: And<[IsVectorTypePred,
554+
And<[CPred<[{$_self.cast<::mlir::VectorType>().getNumElements()>= }] # allowedRanges[0]>,
555+
CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() <= }] # allowedRanges[1]>]>]>;
556+
557+
// Whether the number of elements of a fixed-length vector is from the given
558+
// `allowedRanges` list, the list has two values, start and end of the range (inclusive).
559+
class IsFixedVectorOfLengthRangePred<list<int> allowedRanges>
560+
: And<[IsFixedVectorTypePred,
561+
And<[CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() >= }] # allowedRanges[0]>,
562+
CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() <= }] # allowedRanges[1]>]>]>;
563+
564+
// Whether the number of elements of a scalable vector is from the given
565+
// `allowedRanges` list, the list has two values, start and end of the range (inclusive).
566+
class IsScalableVectorOfLengthRangePred<list<int> allowedRanges>
567+
: And<[IsScalableVectorTypePred,
568+
And<[CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() >= }] # allowedRanges[0]>,
569+
CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() <= }] # allowedRanges[1]>]>]>;
570+
571+
// Any vector where the number of elements is from the given
572+
// `allowedRanges` list.
573+
class VectorOfLengthRange<list<int> allowedRanges>
574+
: Type<IsVectorOfLengthRangePred<allowedRanges>,
575+
" of length " # !interleave(allowedRanges, "-"),
576+
"::mlir::VectorType">;
577+
578+
// Any fixed-length vector where the number of elements is from the given
579+
// `allowedRanges` list.
580+
class FixedVectorOfLengthRange<list<int> allowedRanges>
581+
: Type<IsFixedVectorOfLengthRangePred<allowedRanges>,
582+
" of length " # !interleave(allowedRanges, "-"),
583+
"::mlir::VectorType">;
584+
585+
// Any scalable vector where the number of elements is from the given
586+
// `allowedRanges` list.
587+
class ScalableVectorOfLengthRange<list<int> allowedRanges>
588+
: Type<IsScalableVectorOfLengthRangePred<allowedRanges>,
589+
" of length " # !interleave(allowedRanges, "-"),
590+
"::mlir::VectorType">;
591+
592+
// Any vector where the number of elements is from the given
593+
// `allowedRanges` list and the type is from the given `allowedTypes`
594+
// list.
595+
class VectorOfLengthRangeAndType<list<int> allowedRanges, list<Type> allowedTypes>
596+
: Type<And<[VectorOf<allowedTypes>.predicate, VectorOfLengthRange<allowedRanges>.predicate]>,
597+
VectorOf<allowedTypes>.summary # VectorOfLengthRange<allowedRanges>.summary,
598+
"::mlir::VectorType">;
599+
600+
// Any fixed-length vector where the number of elements is from the given
601+
// `allowedRanges` list and the type is from the given `allowedTypes`
602+
// list.
603+
class FixedVectorOfLengthRangeAndType<list<int> allowedRanges, list<Type> allowedTypes>
604+
: Type<
605+
And<[FixedVectorOf<allowedTypes>.predicate, FixedVectorOfLengthRange<allowedRanges>.predicate]>,
606+
FixedVectorOf<allowedTypes>.summary # FixedVectorOfLengthRange<allowedRanges>.summary,
607+
"::mlir::VectorType">;
608+
609+
// Any scalable vector where the number of elements is from the given
610+
// `allowedRanges` list and the type is from the given `allowedTypes`
611+
// list.
612+
class ScalableVectorOfLengthRangeAndType<list<int> allowedRanges, list<Type> allowedTypes>
613+
: Type<
614+
And<[ScalableVectorOf<allowedTypes>.predicate, ScalableVectorOfLengthRange<allowedRanges>.predicate]>,
615+
ScalableVectorOf<allowedTypes>.summary # ScalableVectorOfLengthRange<allowedRanges>.summary,
616+
"::mlir::VectorType">;
617+
618+
549619
def AnyVector : VectorOf<[AnyType]>;
550620
// Temporary vector type clone that allows gradual transition to 0-D vectors.
551621
def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;

0 commit comments

Comments
 (0)