diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h index 92ce053ad5c82..2039316e6ba25 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -12,6 +12,7 @@ #include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/Support/ADTExtras.h" +#include "llvm/ADT/STLExtras.h" namespace llvm { class BitVector; @@ -181,6 +182,239 @@ class BaseMemRefType : public Type, public ShapedType::Trait { operator ShapedType() const { return llvm::cast(*this); } }; +//===----------------------------------------------------------------------===// +// VectorDim +//===----------------------------------------------------------------------===// + +/// This class represents a dimension of a vector type. Unlike other ShapedTypes +/// vector dimensions can have scalable quantities, which means the dimension +/// has a known minimum size, which is scaled by a constant that is only +/// known at runtime. +class VectorDim { +public: + explicit constexpr VectorDim(int64_t quantity, bool scalable) + : quantity(quantity), scalable(scalable){}; + + /// Constructs a new fixed dimension. + constexpr static VectorDim getFixed(int64_t quantity) { + return VectorDim(quantity, false); + } + + /// Constructs a new scalable dimension. + constexpr static VectorDim getScalable(int64_t quantity) { + return VectorDim(quantity, true); + } + + /// Returns true if this dimension is scalable; + constexpr bool isScalable() const { return scalable; } + + /// Returns true if this dimension is fixed. + constexpr bool isFixed() const { return !isScalable(); } + + /// Returns the minimum number of elements this dimension can contain. + constexpr int64_t getMinSize() const { return quantity; } + + /// If this dimension is fixed returns the number of elements, otherwise + /// aborts. + constexpr int64_t getFixedSize() const { + assert(isFixed()); + return quantity; + } + + constexpr bool operator==(VectorDim const &dim) const { + return quantity == dim.quantity && scalable == dim.scalable; + } + + constexpr bool operator!=(VectorDim const &dim) const { + return !(*this == dim); + } + + /// Print the dim. + void print(raw_ostream &os) { + if (isScalable()) + os << '['; + os << getMinSize(); + if (isScalable()) + os << ']'; + } + + /// Helper class for indexing into a list of sizes (and possibly empty) list + /// of scalable dimensions, extracting VectorDim elements. + struct Indexer { + explicit Indexer(ArrayRef sizes, ArrayRef scalableDims) + : sizes(sizes), scalableDims(scalableDims) { + assert( + scalableDims.empty() || + sizes.size() == scalableDims.size() && + "expected `scalableDims` to be empty or match `sizes` in length"); + } + + VectorDim operator[](size_t idx) const { + int64_t size = sizes[idx]; + bool scalable = scalableDims.empty() ? false : scalableDims[idx]; + return VectorDim(size, scalable); + } + + ArrayRef sizes; + ArrayRef scalableDims; + }; + +private: + int64_t quantity; + bool scalable; +}; + +inline raw_ostream &operator<<(raw_ostream &os, VectorDim dim) { + dim.print(os); + return os; +} + +//===----------------------------------------------------------------------===// +// VectorDims +//===----------------------------------------------------------------------===// + +/// Represents a non-owning list of vector dimensions. The underlying dimension +/// sizes and scalability flags are stored a two seperate lists to match the +/// storage of a VectorType. +class VectorDims : public VectorDim::Indexer { +public: + using VectorDim::Indexer::Indexer; + + class Iterator : public llvm::iterator_facade_base< + Iterator, std::random_access_iterator_tag, VectorDim, + std::ptrdiff_t, VectorDim, VectorDim> { + public: + Iterator(VectorDim::Indexer indexer, size_t index) + : indexer(indexer), index(index){}; + + // Iterator boilerplate. + ptrdiff_t operator-(const Iterator &rhs) const { return index - rhs.index; } + bool operator==(const Iterator &rhs) const { return index == rhs.index; } + bool operator<(const Iterator &rhs) const { return index < rhs.index; } + Iterator &operator+=(ptrdiff_t offset) { + index += offset; + return *this; + } + Iterator &operator-=(ptrdiff_t offset) { + index -= offset; + return *this; + } + VectorDim operator*() const { return indexer[index]; } + + VectorDim::Indexer getIndexer() const { return indexer; } + ptrdiff_t getIndex() const { return index; } + + private: + VectorDim::Indexer indexer; + ptrdiff_t index; + }; + + // Generic definitions. + using value_type = VectorDim; + using iterator = Iterator; + using const_iterator = Iterator; + using reverse_iterator = std::reverse_iterator; + using const_reverse_iterator = std::reverse_iterator; + using size_type = size_t; + using difference_type = ptrdiff_t; + + /// Construct from iterator pair. + VectorDims(Iterator begin, Iterator end) + : VectorDims(VectorDims(begin.getIndexer()) + .slice(begin.getIndex(), end - begin)) {} + + VectorDims(VectorDim::Indexer indexer) : VectorDim::Indexer(indexer){}; + + Iterator begin() const { return Iterator(*this, 0); } + Iterator end() const { return Iterator(*this, size()); } + + /// Check if the dims are empty. + bool empty() const { return sizes.empty(); } + + /// Get the number of dims. + size_t size() const { return sizes.size(); } + + /// Return the first dim. + VectorDim front() const { return (*this)[0]; } + + /// Return the last dim. + VectorDim back() const { return (*this)[size() - 1]; } + + /// Chop of thie first \p n dims, and keep the remaining \p m + /// dims. + VectorDims slice(size_t n, size_t m) const { + ArrayRef newSizes = sizes.slice(n, m); + ArrayRef newScalableDims = + scalableDims.empty() ? ArrayRef{} : scalableDims.slice(n, m); + return VectorDims(newSizes, newScalableDims); + } + + /// Drop the first \p n dims. + VectorDims dropFront(size_t n = 1) const { return slice(n, size() - n); } + + /// Drop the last \p n dims. + VectorDims dropBack(size_t n = 1) const { return slice(0, size() - n); } + + /// Return a copy of *this with only the first \p n elements. + VectorDims takeFront(size_t n = 1) const { + if (n >= size()) + return *this; + return dropBack(size() - n); + } + + /// Return a copy of *this with only the last \p n elements. + VectorDims takeBack(size_t n = 1) const { + if (n >= size()) + return *this; + return dropFront(size() - n); + } + + /// Return copy of *this with the first n dims matching the predicate removed. + template + VectorDims dropWhile(PredicateT predicate) const { + return VectorDims(llvm::find_if_not(*this, predicate), end()); + } + + /// Returns true if one or more of the dims are scalable. + bool hasScalableDims() const { + return llvm::is_contained(getScalableDims(), true); + } + + /// Check for dim equality. + bool equals(VectorDims rhs) const { + if (size() != rhs.size()) + return false; + return std::equal(begin(), end(), rhs.begin()); + } + + /// Check for dim equality. + bool equals(ArrayRef rhs) const { + if (size() != rhs.size()) + return false; + return std::equal(begin(), end(), rhs.begin()); + } + + /// Return the underlying sizes. + ArrayRef getSizes() const { return sizes; } + + /// Return the underlying scalable dims. + ArrayRef getScalableDims() const { return scalableDims; } +}; + +inline bool operator==(VectorDims lhs, VectorDims rhs) { + return lhs.equals(rhs); +} + +inline bool operator!=(VectorDims lhs, VectorDims rhs) { return !(lhs == rhs); } + +inline bool operator==(VectorDims lhs, ArrayRef rhs) { + return lhs.equals(rhs); +} + +inline bool operator!=(VectorDims lhs, ArrayRef rhs) { + return !(lhs == rhs); +} + } // namespace mlir //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td index 4cade83dd3c32..8835074efbc66 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -1114,6 +1114,18 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty scalableDims = isScalableVec; } return $_get(elementType.getContext(), shape, elementType, scalableDims); + }]>, + TypeBuilderWithInferredContext<(ins "Type":$elementType, "ArrayRef": $shape), [{ + SmallVector sizes; + SmallVector scalableDims; + for (VectorDim dim : shape) { + sizes.push_back(dim.getMinSize()); + scalableDims.push_back(dim.isScalable()); + } + return get(sizes, elementType, scalableDims); + }]>, + TypeBuilderWithInferredContext<(ins "Type":$elementType, "VectorDims": $shape), [{ + return get(shape.getSizes(), elementType, shape.getScalableDims()); }]> ]; let extraClassDeclaration = [{ @@ -1121,6 +1133,17 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty /// Arguments that are passed into the builder must outlive the builder. class Builder; + /// Returns the value of the specified dimension (including scalability). + VectorDim getDim(unsigned idx) const { + assert(idx < getRank() && "invalid dim index for vector type"); + return getDims()[idx]; + } + + /// Returns the dimensions of this vector type (including scalability). + VectorDims getDims() const { + return VectorDims(getShape(), getScalableDims()); + } + /// Returns true if the given type can be used as an element of a vector /// type. In particular, vectors can consist of integer, index, or float /// primitives. diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp index 3a01795ce3f53..bfbe51d4e4e32 100644 --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -490,12 +490,11 @@ FailureOr LLVMTypeConverter::convertVectorType(VectorType type) const { return {}; if (type.getShape().empty()) return VectorType::get({1}, elementType); - Type vectorType = VectorType::get(type.getShape().back(), elementType, - type.getScalableDims().back()); + Type vectorType = VectorType::get(elementType, type.getDims().takeBack()); assert(LLVM::isCompatibleVectorType(vectorType) && "expected vector type compatible with the LLVM dialect"); // Only the trailing dimension can be scalable. - if (llvm::is_contained(type.getScalableDims().drop_back(), true)) + if (type.getDims().dropBack().hasScalableDims()) return failure(); auto shape = type.getShape(); for (int i = shape.size() - 2; i >= 0; --i) diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index cd5df0be740b9..881bddaf228f9 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -37,8 +37,7 @@ using namespace mlir::vector; // Helper to reduce vector type by *all* but one rank at back. static VectorType reducedVectorTypeBack(VectorType tp) { assert((tp.getRank() > 1) && "unlowerable vector type"); - return VectorType::get(tp.getShape().take_back(), tp.getElementType(), - tp.getScalableDims().take_back()); + return VectorType::get(tp.getElementType(), tp.getDims().takeBack()); } // Helper that picks the proper sequence for inserting. diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp index 2ee314e9fedfe..4869cf304e75b 100644 --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -319,12 +319,13 @@ static FailureOr unpackOneDim(MemRefType type) { auto vectorType = dyn_cast(type.getElementType()); // Vectors with leading scalable dims are not supported. // It may be possible to support these in future by using dynamic memref dims. - if (vectorType.getScalableDims().front()) + VectorDim leadingDim = vectorType.getDims().front(); + if (leadingDim.isScalable()) return failure(); auto memrefShape = type.getShape(); SmallVector newMemrefShape; newMemrefShape.append(memrefShape.begin(), memrefShape.end()); - newMemrefShape.push_back(vectorType.getDimSize(0)); + newMemrefShape.push_back(leadingDim.getFixedSize()); return MemRefType::get(newMemrefShape, VectorType::Builder(vectorType).dropDim(0)); } @@ -1091,18 +1092,17 @@ struct UnrollTransferReadConversion auto vecType = dyn_cast(vec.getType()); auto xferVecType = xferOp.getVectorType(); - if (xferVecType.getScalableDims()[0]) { + VectorDim dim = xferVecType.getDim(0); + if (dim.isScalable()) { // Cannot unroll a scalable dimension at compile time. return failure(); } VectorType newXferVecType = VectorType::Builder(xferVecType).dropDim(0); - int64_t dimSize = xferVecType.getShape()[0]; - // Generate fully unrolled loop of transfer ops. Location loc = xferOp.getLoc(); - for (int64_t i = 0; i < dimSize; ++i) { + for (int64_t i = 0; i < dim.getFixedSize(); ++i) { Value iv = rewriter.create(loc, i); vec = generateInBoundsCheck( diff --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp index 594c9b4c270f2..fa49a21eafa14 100644 --- a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp +++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp @@ -30,8 +30,7 @@ using namespace mlir::arm_sve; static Type getI1SameShape(Type type) { auto i1Type = IntegerType::get(type.getContext(), 1); if (auto sVectorType = llvm::dyn_cast(type)) - return VectorType::get(sVectorType.getShape(), i1Type, - sVectorType.getScalableDims()); + return VectorType::get(i1Type, sVectorType.getDims()); return nullptr; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index c21d007c931b9..08a476c37b3f3 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -1217,9 +1217,8 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state, assert(vecOperand && "Vector operand couldn't be found"); if (firstMaxRankedType) { - auto vecType = VectorType::get(firstMaxRankedType.getShape(), - getElementTypeOrSelf(vecOperand.getType()), - firstMaxRankedType.getScalableDims()); + auto vecType = VectorType::get(getElementTypeOrSelf(vecOperand.getType()), + firstMaxRankedType.getDims()); vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType)); } else { vecOperands.push_back(vecOperand); @@ -1230,8 +1229,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state, for (Type resultType : op->getResultTypes()) { resultTypes.push_back( firstMaxRankedType - ? VectorType::get(firstMaxRankedType.getShape(), resultType, - firstMaxRankedType.getScalableDims()) + ? VectorType::get(resultType, firstMaxRankedType.getDims()) : resultType); } // d. Build and return the new op. diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index c462b23e1133f..7c1af857800dc 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -422,23 +422,21 @@ MultiDimReductionOp::getShapeForUnroll() { } LogicalResult MultiDimReductionOp::verify() { - SmallVector targetShape; - SmallVector scalableDims; + SmallVector targetDims; Type inferredReturnType; - auto sourceScalableDims = getSourceVectorType().getScalableDims(); - for (auto it : llvm::enumerate(getSourceVectorType().getShape())) - if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) { - return llvm::cast(attr).getValue() == it.index(); - })) { - targetShape.push_back(it.value()); - scalableDims.push_back(sourceScalableDims[it.index()]); + for (auto [idx, dim] : llvm::enumerate(getSourceVectorType().getDims())) + if (!llvm::any_of(getReductionDims().getValue(), + [idx = idx](Attribute attr) { + return llvm::cast(attr).getValue() == idx; + })) { + targetDims.push_back(dim); } // TODO: update to also allow 0-d vectors when available. - if (targetShape.empty()) + if (targetDims.empty()) inferredReturnType = getSourceVectorType().getElementType(); else - inferredReturnType = VectorType::get( - targetShape, getSourceVectorType().getElementType(), scalableDims); + inferredReturnType = + VectorType::get(getSourceVectorType().getElementType(), targetDims); if (getType() != inferredReturnType) return emitOpError() << "destination type " << getType() << " is incompatible with source type " @@ -450,9 +448,8 @@ LogicalResult MultiDimReductionOp::verify() { /// Returns the mask type expected by this operation. Type MultiDimReductionOp::getExpectedMaskType() { auto vecType = getSourceVectorType(); - return VectorType::get(vecType.getShape(), - IntegerType::get(vecType.getContext(), /*width=*/1), - vecType.getScalableDims()); + return VectorType::get(IntegerType::get(vecType.getContext(), /*width=*/1), + vecType.getDims()); } namespace { @@ -491,8 +488,7 @@ struct ElideUnitDimsInMultiDimReduction if (auto dstVecType = dyn_cast(reductionOp.getDestType())) { if (mask) { VectorType newMaskType = - VectorType::get(dstVecType.getShape(), rewriter.getI1Type(), - dstVecType.getScalableDims()); + VectorType::get(rewriter.getI1Type(), dstVecType.getDims()); mask = rewriter.create(loc, newMaskType, mask); } cast = rewriter.create( @@ -559,9 +555,8 @@ LogicalResult ReductionOp::verify() { /// Returns the mask type expected by this operation. Type ReductionOp::getExpectedMaskType() { auto vecType = getSourceVectorType(); - return VectorType::get(vecType.getShape(), - IntegerType::get(vecType.getContext(), /*width=*/1), - vecType.getScalableDims()); + return VectorType::get(IntegerType::get(vecType.getContext(), /*width=*/1), + vecType.getDims()); } Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op, @@ -1252,8 +1247,7 @@ ExtractOp::inferReturnTypes(MLIRContext *, std::optional, auto n = std::min(adaptor.getStaticPosition().size(), vectorType.getRank()); inferredReturnTypes.push_back(VectorType::get( - vectorType.getShape().drop_front(n), vectorType.getElementType(), - vectorType.getScalableDims().drop_front(n))); + vectorType.getElementType(), vectorType.getDims().dropFront(n))); } return success(); } @@ -3040,15 +3034,11 @@ ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) { VectorType resType; if (vRHS) { - SmallVector scalableDimsRes{vLHS.getScalableDims()[0], - vRHS.getScalableDims()[0]}; - resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)}, - vLHS.getElementType(), scalableDimsRes); + resType = VectorType::get(vLHS.getElementType(), + {vLHS.getDim(0), vRHS.getDim(0)}); } else { // Scalar RHS operand - SmallVector scalableDimsRes{vLHS.getScalableDims()[0]}; - resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(), - scalableDimsRes); + resType = VectorType::get(vLHS.getElementType(), {vLHS.getDim(0)}); } if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) { @@ -3115,9 +3105,8 @@ LogicalResult OuterProductOp::verify() { /// verification purposes. It requires the operation to be vectorized." Type OuterProductOp::getExpectedMaskType() { auto vecType = this->getResultVectorType(); - return VectorType::get(vecType.getShape(), - IntegerType::get(vecType.getContext(), /*width=*/1), - vecType.getScalableDims()); + return VectorType::get(IntegerType::get(vecType.getContext(), /*width=*/1), + vecType.getDims()); } //===----------------------------------------------------------------------===// @@ -5064,25 +5053,13 @@ class ShapeCastConstantFolder final : public OpRewritePattern { /// vector<4x1x1xi1> --> vector<4x1> /// static VectorType trimTrailingOneDims(VectorType oldType) { - ArrayRef oldShape = oldType.getShape(); - ArrayRef newShape = oldShape; - - ArrayRef oldScalableDims = oldType.getScalableDims(); - ArrayRef newScalableDims = oldScalableDims; - - while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) { - newShape = newShape.drop_back(1); - newScalableDims = newScalableDims.drop_back(1); - } - - // Make sure we have at least 1 dimension. + // Note: This will always keep at least one dim (even if it's a unit dim). // TODO: Add support for 0-D vectors. - if (newShape.empty()) { - newShape = oldShape.take_back(); - newScalableDims = oldScalableDims.take_back(); - } + VectorDims newDims = oldType.getDims(); + while (newDims.size() > 1 && newDims.back() == VectorDim::getFixed(1)) + newDims = newDims.dropBack(); - return VectorType::get(newShape, oldType.getElementType(), newScalableDims); + return VectorType::get(oldType.getElementType(), newDims); } /// Folds qualifying shape_cast(create_mask) into a new create_mask diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp index 4a5e8fcfb6eda..941c00fd96b56 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -113,13 +113,11 @@ struct TransferReadPermutationLowering permutationMap = inversePermutation(permutationMap); AffineMap newMap = permutationMap.compose(map); // Apply the reverse transpose to deduce the type of the transfer_read. - ArrayRef originalShape = op.getVectorType().getShape(); - SmallVector newVectorShape(originalShape.size()); - ArrayRef originalScalableDims = op.getVectorType().getScalableDims(); - SmallVector newScalableDims(originalShape.size()); + auto originalDims = op.getVectorType().getDims(); + SmallVector newVectorDims(op.getVectorType().getRank(), + VectorDim::getFixed(0)); for (const auto &pos : llvm::enumerate(permutation)) { - newVectorShape[pos.value()] = originalShape[pos.index()]; - newScalableDims[pos.value()] = originalScalableDims[pos.index()]; + newVectorDims[pos.value()] = originalDims[pos.index()]; } // Transpose in_bounds attribute. @@ -129,8 +127,8 @@ struct TransferReadPermutationLowering : ArrayAttr(); // Generate new transfer_read operation. - VectorType newReadType = VectorType::get( - newVectorShape, op.getVectorType().getElementType(), newScalableDims); + VectorType newReadType = + VectorType::get(op.getVectorType().getElementType(), newVectorDims); Value newRead = rewriter.create( op.getLoc(), newReadType, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap), op.getPadding(), op.getMask(), diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp index 4d43a76c4a4ef..0ede6239a9495 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp @@ -344,10 +344,8 @@ class TransposeOpLowering : public OpRewritePattern { // Source with leading unit dim (inverse) is also replaced. Unit dim must // be fixed. Non-unit can be scalable. if (resType.getRank() == 2 && - ((resType.getShape().front() == 1 && - !resType.getScalableDims().front()) || - (resType.getShape().back() == 1 && - !resType.getScalableDims().back())) && + (resType.getDims().front() == VectorDim::getFixed(1) || + resType.getDims().back() == VectorDim::getFixed(1)) && transp == ArrayRef({1, 0})) { rewriter.replaceOpWithNewOp(op, resType, input); return success(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp index 84294e4552a60..5f8a9c9e9915a 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -25,24 +25,15 @@ using namespace mlir::vector; // Trims leading one dimensions from `oldType` and returns the result type. // Returns `vector<1xT>` if `oldType` only has one element. static VectorType trimLeadingOneDims(VectorType oldType) { - ArrayRef oldShape = oldType.getShape(); - ArrayRef newShape = oldShape; - - ArrayRef oldScalableDims = oldType.getScalableDims(); - ArrayRef newScalableDims = oldScalableDims; - - while (!newShape.empty() && newShape.front() == 1 && - !newScalableDims.front()) { - newShape = newShape.drop_front(1); - newScalableDims = newScalableDims.drop_front(1); - } + VectorDims oldDims = oldType.getDims(); + VectorDims newDims = oldDims.dropWhile( + [](VectorDim dim) { return dim == VectorDim::getFixed(1); }); // Make sure we have at least 1 dimension per vector type requirements. - if (newShape.empty()) { - newShape = oldShape.take_back(); - newScalableDims = oldType.getScalableDims().take_back(); - } - return VectorType::get(newShape, oldType.getElementType(), newScalableDims); + if (newDims.empty()) + newDims = oldDims.takeBack(1); + + return VectorType::get(oldType.getElementType(), newDims); } /// Return a smallVector of size `rank` containing all zeros. diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index ed42e6508b431..8eec25dbc04ba 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -316,15 +316,11 @@ static int getReducedRank(ArrayRef shape) { /// Trims non-scalable one dimensions from `oldType` and returns the result /// type. static VectorType trimNonScalableUnitDims(VectorType oldType) { - SmallVector newShape; - SmallVector newScalableDims; - for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) { - if (dimSize == 1 && !oldType.getScalableDims()[dimIdx]) - continue; - newShape.push_back(dimSize); - newScalableDims.push_back(oldType.getScalableDims()[dimIdx]); - } - return VectorType::get(newShape, oldType.getElementType(), newScalableDims); + auto newDims = llvm::to_vector( + llvm::make_filter_range(oldType.getDims(), [](VectorDim dim) { + return dim != VectorDim::getFixed(1); + })); + return VectorType::get(oldType.getElementType(), newDims); } // Rewrites vector.create_mask 'op' to drop non-scalable one dimensions. @@ -337,9 +333,9 @@ createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc, return failure(); SmallVector reducedOperands; - for (auto [dim, dimIsScalable, operand] : llvm::zip_equal( - type.getShape(), type.getScalableDims(), op.getOperands())) { - if (dim == 1 && !dimIsScalable) { + for (auto [dim, operand] : + llvm::zip_equal(type.getDims(), op.getOperands())) { + if (dim == VectorDim::getFixed(1)) { // If the mask for the unit dim is not a constant of 1, do nothing. auto constant = operand.getDefiningOp(); if (!constant || (constant.value() != 1)) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index 6e7fab293d3a1..b5bc167c1c53a 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1039,10 +1039,7 @@ struct MaterializeTransferMask : public OpRewritePattern { vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex); Value b = rewriter.create(loc, dim.getType(), dim, off); Value mask = rewriter.create( - loc, - VectorType::get(vtp.getShape(), rewriter.getI1Type(), - vtp.getScalableDims()), - b); + loc, VectorType::get(rewriter.getI1Type(), vtp.getDims()), b); if (xferOp.getMask()) { // Intersect the in-bounds with the mask specified as an op parameter. mask = rewriter.create(loc, mask, xferOp.getMask()); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 4b76dcf7f8a9f..6eb187b6101a3 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -2558,17 +2558,10 @@ void AsmPrinter::Impl::printTypeImpl(Type type) { } }) .Case([&](VectorType vectorTy) { - auto scalableDims = vectorTy.getScalableDims(); os << "vector<"; - auto vShape = vectorTy.getShape(); - unsigned lastDim = vShape.size(); - unsigned dimIdx = 0; - for (dimIdx = 0; dimIdx < lastDim; dimIdx++) { - if (!scalableDims.empty() && scalableDims[dimIdx]) - os << '['; - os << vShape[dimIdx]; - if (!scalableDims.empty() && scalableDims[dimIdx]) - os << ']'; + auto dims = vectorTy.getDims(); + if (!dims.empty()) { + llvm::interleave(dims, os, "x"); os << 'x'; } printType(vectorTy.getElementType()); diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp index 9b8ee3d452803..eac79ea34d655 100644 --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -250,10 +250,10 @@ VectorType VectorType::scaleElementBitwidth(unsigned scale) { return VectorType(); if (auto et = llvm::dyn_cast(getElementType())) if (auto scaledEt = et.scaleElementBitwidth(scale)) - return VectorType::get(getShape(), scaledEt, getScalableDims()); + return VectorType::get(scaledEt, getDims()); if (auto et = llvm::dyn_cast(getElementType())) if (auto scaledEt = et.scaleElementBitwidth(scale)) - return VectorType::get(getShape(), scaledEt, getScalableDims()); + return VectorType::get(scaledEt, getDims()); return VectorType(); } diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp index 61264bc523648..07625da6ee889 100644 --- a/mlir/unittests/IR/ShapedTypeTest.cpp +++ b/mlir/unittests/IR/ShapedTypeTest.cpp @@ -226,4 +226,105 @@ TEST(ShapedTypeTest, RankedTensorTypeBuilder) { } } +TEST(ShapedTypeTest, VectorDims) { + MLIRContext context; + Type f32 = FloatType::getF32(&context); + + SmallVector dims{VectorDim::getFixed(2), VectorDim::getScalable(4), + VectorDim::getFixed(8), VectorDim::getScalable(9), + VectorDim::getFixed(1)}; + VectorType vectorType = VectorType::get(f32, dims); + + // Directly check values + { + auto dim0 = vectorType.getDim(0); + ASSERT_EQ(dim0.getMinSize(), 2); + ASSERT_TRUE(dim0.isFixed()); + + auto dim1 = vectorType.getDim(1); + ASSERT_EQ(dim1.getMinSize(), 4); + ASSERT_TRUE(dim1.isScalable()); + + auto dim2 = vectorType.getDim(2); + ASSERT_EQ(dim2.getMinSize(), 8); + ASSERT_TRUE(dim2.isFixed()); + + auto dim3 = vectorType.getDim(3); + ASSERT_EQ(dim3.getMinSize(), 9); + ASSERT_TRUE(dim3.isScalable()); + + auto dim4 = vectorType.getDim(4); + ASSERT_EQ(dim4.getMinSize(), 1); + ASSERT_TRUE(dim4.isFixed()); + } + + // Test indexing via getDim(idx) + { + for (unsigned i = 0; i < dims.size(); i++) + ASSERT_EQ(vectorType.getDim(i), dims[i]); + } + + // Test using VectorDims::Iterator in for-each loop + { + unsigned i = 0; + for (VectorDim dim : vectorType.getDims()) + ASSERT_EQ(dim, dims[i++]); + ASSERT_EQ(i, vectorType.getRank()); + } + + // Test using VectorDims::Iterator in LLVM iterator helper + { + for (auto [dim, expectedDim] : + llvm::zip_equal(vectorType.getDims(), dims)) { + ASSERT_EQ(dim, expectedDim); + } + } + + // Test dropFront() + { + auto vectorDims = vectorType.getDims(); + auto newDims = vectorDims.dropFront(); + + ASSERT_EQ(newDims.size(), vectorDims.size() - 1); + for (unsigned i = 0; i < newDims.size(); i++) + ASSERT_EQ(newDims[i], vectorDims[i + 1]); + } + + // Test dropBack() + { + auto vectorDims = vectorType.getDims(); + auto newDims = vectorDims.dropBack(); + + ASSERT_EQ(newDims.size(), vectorDims.size() - 1); + for (unsigned i = 0; i < newDims.size(); i++) + ASSERT_EQ(newDims[i], vectorDims[i]); + } + + // Test front() + { ASSERT_EQ(vectorType.getDims().front(), VectorDim::getFixed(2)); } + + // Test back() + { ASSERT_EQ(vectorType.getDims().back(), VectorDim::getFixed(1)); } + + // Test dropWhile. + { + SmallVector dims{ + VectorDim::getFixed(1), VectorDim::getFixed(1), VectorDim::getFixed(1), + VectorDim::getScalable(1), VectorDim::getScalable(4)}; + + VectorType vectorTypeWithLeadingUnitDims = VectorType::get(f32, dims); + ASSERT_EQ(vectorTypeWithLeadingUnitDims.getDims().size(), + unsigned(vectorTypeWithLeadingUnitDims.getRank())); + + // Drop leading unit dims. + auto withoutLeadingUnitDims = + vectorTypeWithLeadingUnitDims.getDims().dropWhile( + [](VectorDim dim) { return dim == VectorDim::getFixed(1); }); + + SmallVector expectedDims{VectorDim::getScalable(1), + VectorDim::getScalable(4)}; + ASSERT_EQ(withoutLeadingUnitDims, expectedDims); + } +} + } // namespace