diff --git a/mlir/include/mlir/Support/ScalableVectorType.h b/mlir/include/mlir/Support/ScalableVectorType.h
new file mode 100644
index 0000000000000..a28d1ac8d9065
--- /dev/null
+++ b/mlir/include/mlir/Support/ScalableVectorType.h
@@ -0,0 +1,380 @@
+//===- ScalableVectorType.h - Scalable Vector Helpers -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SUPPORT_SCALABLEVECTORTYPE_H
+#define MLIR_SUPPORT_SCALABLEVECTORTYPE_H
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/iterator.h"
+
+#include <algorithm>
+#include <cassert>
+#include <iterator>
+
+namespace mlir {
+
+//===----------------------------------------------------------------------===//
+// VectorDim
+//===----------------------------------------------------------------------===//
+
+/// This class represents a dimension of a vector type. Unlike other ShapedTypes
+/// vector dimensions can have scalable quantities, which means the dimension
+/// has a known minimum size, which is scaled by a constant that is only
+/// known at runtime.
+class VectorDim {
+public:
+  /// Constructs a dimension with a (minimum) size of `quantity`. If `scalable`
+  /// is true, `quantity` is only the known minimum size of the dimension.
+  explicit constexpr VectorDim(int64_t quantity, bool scalable)
+      : quantity(quantity), scalable(scalable) {}
+
+  /// Constructs a new fixed dimension.
+  constexpr static VectorDim getFixed(int64_t quantity) {
+    return VectorDim(quantity, false);
+  }
+
+  /// Constructs a new scalable dimension.
+  constexpr static VectorDim getScalable(int64_t quantity) {
+    return VectorDim(quantity, true);
+  }
+
+  /// Returns true if this dimension is scalable.
+  constexpr bool isScalable() const { return scalable; }
+
+  /// Returns true if this dimension is fixed.
+  constexpr bool isFixed() const { return !isScalable(); }
+
+  /// Returns the minimum number of elements this dimension can contain.
+  constexpr int64_t getMinSize() const { return quantity; }
+
+  /// Returns the number of elements of a fixed dimension. Asserts that the
+  /// dimension is fixed.
+  constexpr int64_t getFixedSize() const {
+    assert(isFixed() && "expected a fixed dimension");
+    return quantity;
+  }
+
+  /// Two dims are equal iff both their size and scalability match.
+  constexpr bool operator==(VectorDim const &dim) const {
+    return quantity == dim.quantity && scalable == dim.scalable;
+  }
+
+  constexpr bool operator!=(VectorDim const &dim) const {
+    return !(*this == dim);
+  }
+
+  /// Prints the dim: `N` for a fixed dim, `[N]` for a scalable dim.
+  void print(raw_ostream &os) const {
+    if (isScalable())
+      os << '[';
+    os << getMinSize();
+    if (isScalable())
+      os << ']';
+  }
+
+  /// Helper class for indexing into a list of sizes and a (possibly empty)
+  /// list of scalable dimensions, extracting VectorDim elements. An empty
+  /// `scalableDims` list means every dimension is fixed.
+  struct Indexer {
+    explicit Indexer(ArrayRef<int64_t> sizes, ArrayRef<bool> scalableDims)
+        : sizes(sizes), scalableDims(scalableDims) {
+      // Parenthesize the `||`: `&&` binds tighter, so the message would
+      // otherwise attach to the size comparison only (-Wparentheses).
+      assert(
+          (scalableDims.empty() || sizes.size() == scalableDims.size()) &&
+          "expected `scalableDims` to be empty or match `sizes` in length");
+    }
+
+    VectorDim operator[](size_t idx) const {
+      int64_t size = sizes[idx];
+      bool scalable = scalableDims.empty() ? false : scalableDims[idx];
+      return VectorDim(size, scalable);
+    }
+
+    ArrayRef<int64_t> sizes;
+    ArrayRef<bool> scalableDims;
+  };
+
+private:
+  int64_t quantity;
+  bool scalable;
+};
+
+inline raw_ostream &operator<<(raw_ostream &os, VectorDim dim) {
+  dim.print(os);
+  return os;
+}
+
+//===----------------------------------------------------------------------===//
+// VectorDimList
+//===----------------------------------------------------------------------===//
+
+/// Represents a non-owning list of vector dimensions. The underlying dimension
+/// sizes and scalability flags are stored as two separate lists to match the
+/// storage of a VectorType.
+class VectorDimList : public VectorDim::Indexer {
+public:
+  using VectorDim::Indexer::Indexer;
+
+  /// Random-access iterator that materializes VectorDim values on dereference.
+  class Iterator : public llvm::iterator_facade_base<
+                       Iterator, std::random_access_iterator_tag, VectorDim,
+                       std::ptrdiff_t, VectorDim, VectorDim> {
+  public:
+    Iterator(VectorDim::Indexer indexer, size_t index)
+        : indexer(indexer), index(static_cast<ptrdiff_t>(index)) {}
+
+    // Iterator boilerplate. Comparisons only look at the index, so they are
+    // only meaningful between iterators into the same list.
+    ptrdiff_t operator-(const Iterator &rhs) const { return index - rhs.index; }
+    bool operator==(const Iterator &rhs) const { return index == rhs.index; }
+    bool operator<(const Iterator &rhs) const { return index < rhs.index; }
+    Iterator &operator+=(ptrdiff_t offset) {
+      index += offset;
+      return *this;
+    }
+    Iterator &operator-=(ptrdiff_t offset) {
+      index -= offset;
+      return *this;
+    }
+    VectorDim operator*() const { return indexer[index]; }
+
+    VectorDim::Indexer getIndexer() const { return indexer; }
+    ptrdiff_t getIndex() const { return index; }
+
+  private:
+    VectorDim::Indexer indexer;
+    ptrdiff_t index;
+  };
+
+  // Generic definitions.
+  using value_type = VectorDim;
+  using iterator = Iterator;
+  using const_iterator = Iterator;
+  using reverse_iterator = std::reverse_iterator<iterator>;
+  using const_reverse_iterator = std::reverse_iterator<const_iterator>;
+  using size_type = size_t;
+  using difference_type = ptrdiff_t;
+
+  /// Construct from iterator pair.
+  VectorDimList(Iterator begin, Iterator end)
+      : VectorDimList(VectorDimList(begin.getIndexer())
+                          .slice(begin.getIndex(), end - begin)) {}
+
+  VectorDimList(VectorDim::Indexer indexer) : VectorDim::Indexer(indexer) {}
+
+  /// Construct from a VectorType. A null `vectorType` yields an empty list.
+  static VectorDimList from(VectorType vectorType) {
+    if (!vectorType)
+      return VectorDimList({}, {});
+    return VectorDimList(vectorType.getShape(), vectorType.getScalableDims());
+  }
+
+  Iterator begin() const { return Iterator(*this, 0); }
+  Iterator end() const { return Iterator(*this, size()); }
+
+  /// Check if the dims are empty.
+  bool empty() const { return sizes.empty(); }
+
+  /// Get the number of dims.
+  size_t size() const { return sizes.size(); }
+
+  /// Return the first dim.
+  VectorDim front() const { return (*this)[0]; }
+
+  /// Return the last dim.
+  VectorDim back() const { return (*this)[size() - 1]; }
+
+  /// Chop off the first \p n dims, and keep the remaining \p m dims.
+  VectorDimList slice(size_t n, size_t m) const {
+    ArrayRef<int64_t> newSizes = sizes.slice(n, m);
+    ArrayRef<bool> newScalableDims =
+        scalableDims.empty() ? ArrayRef<bool>{} : scalableDims.slice(n, m);
+    return VectorDimList(newSizes, newScalableDims);
+  }
+
+  /// Drop the first \p n dims.
+  VectorDimList dropFront(size_t n = 1) const { return slice(n, size() - n); }
+
+  /// Drop the last \p n dims.
+  VectorDimList dropBack(size_t n = 1) const { return slice(0, size() - n); }
+
+  /// Return a copy of *this with only the first \p n elements.
+  VectorDimList takeFront(size_t n = 1) const {
+    if (n >= size())
+      return *this;
+    return dropBack(size() - n);
+  }
+
+  /// Return a copy of *this with only the last \p n elements.
+  VectorDimList takeBack(size_t n = 1) const {
+    if (n >= size())
+      return *this;
+    return dropFront(size() - n);
+  }
+
+  /// Return copy of *this with the first n dims matching the predicate removed.
+  template <typename PredicateT>
+  VectorDimList dropWhile(PredicateT predicate) const {
+    return VectorDimList(llvm::find_if_not(*this, predicate), end());
+  }
+
+  /// Returns true if one or more of the dims are scalable.
+  bool hasScalableDims() const {
+    return llvm::is_contained(getScalableDims(), true);
+  }
+
+  /// Check for dim equality.
+  bool equals(VectorDimList rhs) const {
+    if (size() != rhs.size())
+      return false;
+    return std::equal(begin(), end(), rhs.begin());
+  }
+
+  /// Check for dim equality.
+  bool equals(ArrayRef<VectorDim> rhs) const {
+    if (size() != rhs.size())
+      return false;
+    return std::equal(begin(), end(), rhs.begin());
+  }
+
+  /// Return the underlying sizes.
+  ArrayRef<int64_t> getSizes() const { return sizes; }
+
+  /// Return the underlying scalable dims.
+  ArrayRef<bool> getScalableDims() const { return scalableDims; }
+};
+
+inline bool operator==(VectorDimList lhs, VectorDimList rhs) {
+  return lhs.equals(rhs);
+}
+
+inline bool operator!=(VectorDimList lhs, VectorDimList rhs) {
+  return !(lhs == rhs);
+}
+
+inline bool operator==(VectorDimList lhs, ArrayRef<VectorDim> rhs) {
+  return lhs.equals(rhs);
+}
+
+inline bool operator!=(VectorDimList lhs, ArrayRef<VectorDim> rhs) {
+  return !(lhs == rhs);
+}
+
+//===----------------------------------------------------------------------===//
+// ScalableVectorType
+//===----------------------------------------------------------------------===//
+
+/// A pseudo-type that wraps a VectorType that aims to provide safe APIs for
+/// working with scalable vectors. Slightly contrary to the name this class can
+/// represent both fixed and scalable vectors, however, if you are only dealing
+/// with fixed vectors the plain VectorType is likely more convenient.
+///
+/// The main difference from the regular VectorType is that vector dimensions
+/// are _not_ represented as `int64_t`, which does not allow encoding the
+/// scalability into the dimension. Instead, vector dimensions are represented
+/// by a VectorDim class. A VectorDim stores both the size and scalability of a
+/// dimension. This makes common errors like only checking the size (but not the
+/// scalability) impossible (without being explicit with your intention).
+///
+/// To make this convenient to work with there is VectorDimList which provides
+/// ArrayRef-like helper methods along with an iterator for VectorDims.
+///
+/// ScalableVectorType can be freely converted to VectorType (and vice versa),
+/// though there are two main ways to acquire a ScalableVectorType.
+///
+/// Assignment:
+///
+/// This does not check the scalability of `myVectorType`. This is valid and the
+/// helpers on ScalableVectorType will function as normal.
+/// ```c++
+/// VectorType myVectorType = ...;
+/// ScalableVectorType scalableVector = myVectorType;
+/// ```
+///
+/// Casting:
+///
+/// This checks the scalability of `myVectorType`. In this case,
+/// `scalableVector` will be falsy if `myVectorType` contains no scalable dims.
+/// ```c++
+/// VectorType myVectorType = ...;
+/// auto scalableVector = dyn_cast<ScalableVectorType>(myVectorType);
+/// ```
+class ScalableVectorType {
+public:
+  using Dim = VectorDim;
+  using DimList = VectorDimList;
+
+  ScalableVectorType(VectorType vectorType) : vectorType(vectorType) {}
+
+  /// Construct a new ScalableVectorType.
+  static ScalableVectorType get(DimList shape, Type elementType) {
+    return VectorType::get(shape.getSizes(), elementType,
+                           shape.getScalableDims());
+  }
+
+  /// Construct a new ScalableVectorType.
+  static ScalableVectorType get(ArrayRef<Dim> shape, Type elementType) {
+    SmallVector<int64_t> sizes;
+    SmallVector<bool> scalableDims;
+    sizes.reserve(shape.size());
+    scalableDims.reserve(shape.size());
+    for (Dim dim : shape) {
+      sizes.push_back(dim.getMinSize());
+      scalableDims.push_back(dim.isScalable());
+    }
+    return VectorType::get(sizes, elementType, scalableDims);
+  }
+
+  /// Matches only VectorTypes for which `isScalable()` is true.
+  inline static bool classof(Type type) {
+    auto vectorType = dyn_cast_if_present<VectorType>(type);
+    return vectorType && vectorType.isScalable();
+  }
+
+  /// Returns the value of the specified dimension (including scalability).
+  Dim getDim(unsigned idx) const {
+    assert(idx < getRank() && "invalid dim index for vector type");
+    return getDims()[idx];
+  }
+
+  /// Returns the dimensions of this vector type (including scalability).
+  DimList getDims() const {
+    return DimList(vectorType.getShape(), vectorType.getScalableDims());
+  }
+
+  /// Returns the rank of this vector type.
+  int64_t getRank() const { return vectorType.getRank(); }
+
+  /// Returns true if the vector contains scalable dimensions.
+  bool isScalable() const { return vectorType.isScalable(); }
+
+  /// Forwards to VectorType::allDimsScalable().
+  bool allDimsScalable() const { return vectorType.allDimsScalable(); }
+
+  /// Returns the element type of this vector type.
+  Type getElementType() const { return vectorType.getElementType(); }
+
+  /// Clones this vector type with a new element type.
+  ScalableVectorType clone(Type elementType) const {
+    return vectorType.clone(elementType);
+  }
+
+  operator VectorType() const { return vectorType; }
+
+  explicit operator bool() const { return bool(vectorType); }
+
+private:
+  VectorType vectorType;
+};
+
+} // namespace mlir
+
+#endif
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp index f4fae68da63b3..7c694ca7d55c8 100644 --- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp +++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp @@ -29,6 +29,7 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/Support/ScalableVectorType.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/ArrayRef.h" @@ -39,24 +40,14 @@ using namespace mlir; using namespace mlir::math; using namespace mlir::vector; -// Helper to encapsulate a vector's shape (including scalable dims). -struct VectorShape { - ArrayRef sizes; - ArrayRef scalableFlags; - - bool empty() const { return sizes.empty(); } -}; - // Returns vector shape if the type is a vector. Returns an empty shape if it is // not a vector. -static VectorShape vectorShape(Type type) { +static VectorDimList vectorShape(Type type) { auto vectorType = dyn_cast(type); return vectorType - ? 
VectorShape{vectorType.getShape(), vectorType.getScalableDims()} - : VectorShape{}; + return VectorDimList::from(vectorType); } -static VectorShape vectorShape(Value value) { +static VectorDimList vectorShape(Value value) { return vectorShape(value.getType()); } @@ -65,16 +56,14 @@ static VectorShape vectorShape(Value value) { //----------------------------------------------------------------------------// // Broadcasts scalar type into vector type (iff shape is non-scalar). -static Type broadcast(Type type, VectorShape shape) { +static Type broadcast(Type type, VectorDimList shape) { assert(!isa(type) && "must be scalar type"); - return !shape.empty() - ? VectorType::get(shape.sizes, type, shape.scalableFlags) - : type; + return !shape.empty() ? ScalableVectorType::get(shape, type) : type; } // Broadcasts scalar value into vector (iff shape is non-scalar). static Value broadcast(ImplicitLocOpBuilder &builder, Value value, - VectorShape shape) { + VectorDimList shape) { assert(!isa(value.getType()) && "must be scalar value"); auto type = broadcast(value.getType(), shape); return !shape.empty() ? builder.create(type, value) : value; @@ -227,7 +216,7 @@ static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, static std::pair frexp(ImplicitLocOpBuilder &builder, Value arg, bool isPositive = false) { assert(getElementTypeOrSelf(arg).isF32() && "arg must be f32 type"); - VectorShape shape = vectorShape(arg); + VectorDimList shape = vectorShape(arg); auto bcast = [&](Value value) -> Value { return broadcast(builder, value, shape); @@ -267,7 +256,7 @@ static std::pair frexp(ImplicitLocOpBuilder &builder, Value arg, // Computes exp2 for an i32 argument. 
static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) { assert(getElementTypeOrSelf(arg).isInteger(32) && "arg must be i32 type"); - VectorShape shape = vectorShape(arg); + VectorDimList shape = vectorShape(arg); auto bcast = [&](Value value) -> Value { return broadcast(builder, value, shape); @@ -293,7 +282,7 @@ Value makePolynomialCalculation(ImplicitLocOpBuilder &builder, Type elementType = getElementTypeOrSelf(x); assert((elementType.isF32() || elementType.isF16()) && "x must be f32 or f16 type"); - VectorShape shape = vectorShape(x); + VectorDimList shape = vectorShape(x); if (coeffs.empty()) return broadcast(builder, floatCst(builder, 0.0f, elementType), shape); @@ -391,7 +380,7 @@ AtanApproximation::matchAndRewrite(math::AtanOp op, if (!getElementTypeOrSelf(operand).isF32()) return rewriter.notifyMatchFailure(op, "unsupported operand type"); - VectorShape shape = vectorShape(op.getOperand()); + VectorDimList shape = vectorShape(op.getOperand()); ImplicitLocOpBuilder builder(op->getLoc(), rewriter); Value abs = builder.create(operand); @@ -490,7 +479,7 @@ Atan2Approximation::matchAndRewrite(math::Atan2Op op, return rewriter.notifyMatchFailure(op, "unsupported operand type"); ImplicitLocOpBuilder builder(op->getLoc(), rewriter); - VectorShape shape = vectorShape(op.getResult()); + VectorDimList shape = vectorShape(op.getResult()); // Compute atan in the valid range. 
auto div = builder.create(y, x); @@ -556,7 +545,7 @@ TanhApproximation::matchAndRewrite(math::TanhOp op, if (!getElementTypeOrSelf(op.getOperand()).isF32()) return rewriter.notifyMatchFailure(op, "unsupported operand type"); - VectorShape shape = vectorShape(op.getOperand()); + VectorDimList shape = vectorShape(op.getOperand()); ImplicitLocOpBuilder builder(op->getLoc(), rewriter); auto bcast = [&](Value value) -> Value { @@ -644,7 +633,7 @@ LogApproximationBase::logMatchAndRewrite(Op op, PatternRewriter &rewriter, if (!getElementTypeOrSelf(op.getOperand()).isF32()) return rewriter.notifyMatchFailure(op, "unsupported operand type"); - VectorShape shape = vectorShape(op.getOperand()); + VectorDimList shape = vectorShape(op.getOperand()); ImplicitLocOpBuilder builder(op->getLoc(), rewriter); auto bcast = [&](Value value) -> Value { @@ -791,7 +780,7 @@ Log1pApproximation::matchAndRewrite(math::Log1pOp op, if (!getElementTypeOrSelf(op.getOperand()).isF32()) return rewriter.notifyMatchFailure(op, "unsupported operand type"); - VectorShape shape = vectorShape(op.getOperand()); + VectorDimList shape = vectorShape(op.getOperand()); ImplicitLocOpBuilder builder(op->getLoc(), rewriter); auto bcast = [&](Value value) -> Value { @@ -846,7 +835,7 @@ AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op, if (!(elementType.isF32() || elementType.isF16())) return rewriter.notifyMatchFailure(op, "only f32 and f16 type is supported."); - VectorShape shape = vectorShape(operand); + VectorDimList shape = vectorShape(operand); ImplicitLocOpBuilder builder(op->getLoc(), rewriter); auto bcast = [&](Value value) -> Value { @@ -910,7 +899,7 @@ AcosPolynomialApproximation::matchAndRewrite(math::AcosOp op, if (!(elementType.isF32() || elementType.isF16())) return rewriter.notifyMatchFailure(op, "only f32 and f16 type is supported."); - VectorShape shape = vectorShape(operand); + VectorDimList shape = vectorShape(operand); ImplicitLocOpBuilder builder(op->getLoc(), rewriter); auto 
bcast = [&](Value value) -> Value { @@ -988,7 +977,7 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op, if (!(elementType.isF32() || elementType.isF16())) return rewriter.notifyMatchFailure(op, "only f32 and f16 type is supported."); - VectorShape shape = vectorShape(operand); + VectorDimList shape = vectorShape(operand); ImplicitLocOpBuilder builder(op->getLoc(), rewriter); auto bcast = [&](Value value) -> Value { @@ -1097,7 +1086,7 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op, namespace { -Value clampWithNormals(ImplicitLocOpBuilder &builder, const VectorShape shape, +Value clampWithNormals(ImplicitLocOpBuilder &builder, const VectorDimList shape, Value value, float lowerBound, float upperBound) { assert(!std::isnan(lowerBound)); assert(!std::isnan(upperBound)); @@ -1289,7 +1278,7 @@ ExpM1Approximation::matchAndRewrite(math::ExpM1Op op, if (!getElementTypeOrSelf(op.getOperand()).isF32()) return rewriter.notifyMatchFailure(op, "unsupported operand type"); - VectorShape shape = vectorShape(op.getOperand()); + VectorDimList shape = vectorShape(op.getOperand()); ImplicitLocOpBuilder builder(op->getLoc(), rewriter); auto bcast = [&](Value value) -> Value { @@ -1359,7 +1348,7 @@ LogicalResult SinAndCosApproximation::matchAndRewrite( if (!getElementTypeOrSelf(op.getOperand()).isF32()) return rewriter.notifyMatchFailure(op, "unsupported operand type"); - VectorShape shape = vectorShape(op.getOperand()); + VectorDimList shape = vectorShape(op.getOperand()); ImplicitLocOpBuilder builder(op->getLoc(), rewriter); auto bcast = [&](Value value) -> Value { @@ -1486,7 +1475,7 @@ CbrtApproximation::matchAndRewrite(math::CbrtOp op, return rewriter.notifyMatchFailure(op, "unsupported operand type"); ImplicitLocOpBuilder b(op->getLoc(), rewriter); - VectorShape shape = vectorShape(operand); + VectorDimList shape = vectorShape(operand); Type floatTy = getElementTypeOrSelf(operand.getType()); Type intTy = b.getIntegerType(floatTy.getIntOrFloatBitWidth()); 
@@ -1575,7 +1564,7 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op, if (!getElementTypeOrSelf(op.getOperand()).isF32()) return rewriter.notifyMatchFailure(op, "unsupported operand type"); - VectorShape shape = vectorShape(op.getOperand()); + VectorDimList shape = vectorShape(op.getOperand()); // Only support already-vectorized rsqrt's. if (shape.empty() || shape.sizes.back() % 8 != 0) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 6734c80f2760d..e2ce56e9e188a 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -35,6 +35,7 @@ #include "mlir/Interfaces/SubsetOpInterface.h" #include "mlir/Interfaces/ValueBoundsOpInterface.h" #include "mlir/Support/LLVM.h" +#include "mlir/Support/ScalableVectorType.h" #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" @@ -463,23 +464,22 @@ MultiDimReductionOp::getShapeForUnroll() { } LogicalResult MultiDimReductionOp::verify() { - SmallVector targetShape; - SmallVector scalableDims; + SmallVector targetDims; Type inferredReturnType; - auto sourceScalableDims = getSourceVectorType().getScalableDims(); - for (auto it : llvm::enumerate(getSourceVectorType().getShape())) - if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) { - return llvm::cast(attr).getValue() == it.index(); - })) { - targetShape.push_back(it.value()); - scalableDims.push_back(sourceScalableDims[it.index()]); + auto sourceDims = VectorDimList::from(getSourceVectorType()); + for (auto [idx, dim] : llvm::enumerate(sourceDims)) + if (!llvm::any_of(getReductionDims().getValue(), + [idx = idx](Attribute attr) { + return llvm::cast(attr).getValue() == idx; + })) { + targetDims.push_back(dim); } // TODO: update to also allow 0-d vectors when available. 
- if (targetShape.empty()) + if (targetDims.empty()) inferredReturnType = getSourceVectorType().getElementType(); else - inferredReturnType = VectorType::get( - targetShape, getSourceVectorType().getElementType(), scalableDims); + inferredReturnType = ScalableVectorType::get( + targetDims, getSourceVectorType().getElementType()); if (getType() != inferredReturnType) return emitOpError() << "destination type " << getType() << " is incompatible with source type " @@ -3247,23 +3247,19 @@ ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) { if (operandsInfo.size() < 2) return parser.emitError(parser.getNameLoc(), "expected at least 2 operands"); - VectorType vLHS = llvm::dyn_cast(tLHS); - VectorType vRHS = llvm::dyn_cast(tRHS); + ScalableVectorType vLHS = llvm::dyn_cast(tLHS); + ScalableVectorType vRHS = llvm::dyn_cast(tRHS); if (!vLHS) return parser.emitError(parser.getNameLoc(), "expected vector type for operand #1"); VectorType resType; if (vRHS) { - SmallVector scalableDimsRes{vLHS.getScalableDims()[0], - vRHS.getScalableDims()[0]}; - resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)}, - vLHS.getElementType(), scalableDimsRes); + resType = ScalableVectorType::get({vLHS.getDim(0), vRHS.getDim(0)}, + vLHS.getElementType()); } else { // Scalar RHS operand - SmallVector scalableDimsRes{vLHS.getScalableDims()[0]}; - resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(), - scalableDimsRes); + resType = ScalableVectorType::get(vLHS.getDim(0), vLHS.getElementType()); } if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) { @@ -5308,26 +5304,11 @@ class ShapeCastConstantFolder final : public OpRewritePattern { /// /// vector<4x1x1xi1> --> vector<4x1> /// -static VectorType trimTrailingOneDims(VectorType oldType) { - ArrayRef oldShape = oldType.getShape(); - ArrayRef newShape = oldShape; - - ArrayRef oldScalableDims = oldType.getScalableDims(); - ArrayRef newScalableDims = oldScalableDims; - - 
while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) { - newShape = newShape.drop_back(1); - newScalableDims = newScalableDims.drop_back(1); - } - - // Make sure we have at least 1 dimension. - // TODO: Add support for 0-D vectors. - if (newShape.empty()) { - newShape = oldShape.take_back(); - newScalableDims = oldScalableDims.take_back(); - } - - return VectorType::get(newShape, oldType.getElementType(), newScalableDims); +static ScalableVectorType trimTrailingOneDims(ScalableVectorType oldType) { + VectorDimList newDims = oldType.getDims(); + while (newDims.size() > 1 && newDims.back() == VectorDim::getFixed(1)) + newDims = newDims.dropBack(); + return ScalableVectorType::get(newDims, oldType.getElementType()); } /// Folds qualifying shape_cast(create_mask) into a new create_mask diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp index c31c51489ecc9..b4dd274914edb 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Interfaces/VectorInterfaces.h" +#include "mlir/Support/ScalableVectorType.h" using namespace mlir; using namespace mlir::vector; @@ -122,13 +123,11 @@ struct TransferReadPermutationLowering permutationMap = inversePermutation(permutationMap); AffineMap newMap = permutationMap.compose(map); // Apply the reverse transpose to deduce the type of the transfer_read. 
- ArrayRef originalShape = op.getVectorType().getShape(); - SmallVector newVectorShape(originalShape.size()); - ArrayRef originalScalableDims = op.getVectorType().getScalableDims(); - SmallVector newScalableDims(originalShape.size()); - for (const auto &pos : llvm::enumerate(permutation)) { - newVectorShape[pos.value()] = originalShape[pos.index()]; - newScalableDims[pos.value()] = originalScalableDims[pos.index()]; + auto originalDims = VectorDimList::from(op.getVectorType()); + SmallVector newDims(op.getVectorType().getRank(), + VectorDim::getFixed(0)); + for (auto [originalIdx, newIdx] : llvm::enumerate(permutation)) { + newDims[newIdx] = originalDims[originalIdx]; } // Transpose in_bounds attribute. @@ -138,8 +137,8 @@ struct TransferReadPermutationLowering : ArrayAttr(); // Generate new transfer_read operation. - VectorType newReadType = VectorType::get( - newVectorShape, op.getVectorType().getElementType(), newScalableDims); + VectorType newReadType = + ScalableVectorType::get(newDims, op.getVectorType().getElementType()); Value newRead = rewriter.create( op.getLoc(), newReadType, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap), op.getPadding(), op.getMask(), diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp index ca8a6f6d82a6e..fa259d7bf1449 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp @@ -32,6 +32,7 @@ #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/VectorInterfaces.h" #include "mlir/Support/LogicalResult.h" +#include "mlir/Support/ScalableVectorType.h" #define DEBUG_TYPE "lower-vector-transpose" @@ -432,18 +433,17 @@ class Transpose2DWithUnitDimToShapeCast LogicalResult matchAndRewrite(vector::TransposeOp op, PatternRewriter &rewriter) const override { Value input = op.getVector(); - VectorType resType = op.getResultVectorType(); + ScalableVectorType 
resType = op.getResultVectorType(); // Set up convenience transposition table. ArrayRef transp = op.getPermutation(); if (resType.getRank() == 2 && - ((resType.getShape().front() == 1 && - !resType.getScalableDims().front()) || - (resType.getShape().back() == 1 && - !resType.getScalableDims().back())) && + (resType.getDims().front() == VectorDim::getFixed(1) || + resType.getDims().back() == VectorDim::getFixed(1)) && transp == ArrayRef({1, 0})) { - rewriter.replaceOpWithNewOp(op, resType, input); + rewriter.replaceOpWithNewOp(op, Type(resType), + input); return success(); } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp index 7ed3dea42b771..ca61fe81a3d16 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/Support/ScalableVectorType.h" #define DEBUG_TYPE "vector-drop-unit-dim" @@ -24,25 +25,15 @@ using namespace mlir::vector; // Trims leading one dimensions from `oldType` and returns the result type. // Returns `vector<1xT>` if `oldType` only has one element. 
-static VectorType trimLeadingOneDims(VectorType oldType) { - ArrayRef oldShape = oldType.getShape(); - ArrayRef newShape = oldShape; - - ArrayRef oldScalableDims = oldType.getScalableDims(); - ArrayRef newScalableDims = oldScalableDims; - - while (!newShape.empty() && newShape.front() == 1 && - !newScalableDims.front()) { - newShape = newShape.drop_front(1); - newScalableDims = newScalableDims.drop_front(1); - } +static ScalableVectorType trimLeadingOneDims(ScalableVectorType oldType) { + VectorDimList oldDims = oldType.getDims(); + VectorDimList newDims = oldDims.dropWhile( + [](VectorDim dim) { return dim == VectorDim::getFixed(1); }); // Make sure we have at least 1 dimension per vector type requirements. - if (newShape.empty()) { - newShape = oldShape.take_back(); - newScalableDims = oldType.getScalableDims().take_back(); - } - return VectorType::get(newShape, oldType.getElementType(), newScalableDims); + if (newDims.empty()) + newDims = oldDims.takeBack(1); + return ScalableVectorType::get(newDims, oldType.getElementType()); } /// Return a smallVector of size `rank` containing all zeros. diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index c131fde517f80..18b6882ac37f2 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -23,6 +23,7 @@ #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/Dominance.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/ScalableVectorType.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Debug.h" @@ -312,31 +313,27 @@ static int getReducedRank(ArrayRef shape) { /// Trims non-scalable one dimensions from `oldType` and returns the result /// type. 
-static VectorType trimNonScalableUnitDims(VectorType oldType) { - SmallVector newShape; - SmallVector newScalableDims; - for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) { - if (dimSize == 1 && !oldType.getScalableDims()[dimIdx]) - continue; - newShape.push_back(dimSize); - newScalableDims.push_back(oldType.getScalableDims()[dimIdx]); - } - return VectorType::get(newShape, oldType.getElementType(), newScalableDims); +static ScalableVectorType trimNonScalableUnitDims(ScalableVectorType oldType) { + auto newDims = llvm::to_vector( + llvm::make_filter_range(oldType.getDims(), [](VectorDim dim) { + return dim != VectorDim::getFixed(1); + })); + return ScalableVectorType::get(newDims, oldType.getElementType()); } // Rewrites vector.create_mask 'op' to drop non-scalable one dimensions. static FailureOr createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc, vector::CreateMaskOp op) { - auto type = op.getType(); + ScalableVectorType type = op.getType(); VectorType reducedType = trimNonScalableUnitDims(type); if (reducedType.getRank() == type.getRank()) return failure(); SmallVector reducedOperands; - for (auto [dim, dimIsScalable, operand] : llvm::zip_equal( - type.getShape(), type.getScalableDims(), op.getOperands())) { - if (dim == 1 && !dimIsScalable) { + for (auto [dim, operand] : + llvm::zip_equal(type.getDims(), op.getOperands())) { + if (dim == VectorDim::getFixed(1)) { // If the mask for the unit dim is not a constant of 1, do nothing. 
auto constant = operand.getDefiningOp(); if (!constant || (constant.value() != 1)) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index b824508728ac8..61ea5a8098f8d 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -39,6 +39,7 @@ #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/VectorInterfaces.h" #include "mlir/Support/LogicalResult.h" +#include "mlir/Support/ScalableVectorType.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/MapVector.h" @@ -1239,16 +1240,13 @@ struct FoldI1Select : public OpRewritePattern { /// (the most inner dim in `vectorType` is not a unit dim (it's a "scalable /// unit") static FailureOr -getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) { +getTransferFoldableInnerUnitDims(MemRefType srcType, + ScalableVectorType vectorType) { SmallVector srcStrides; int64_t srcOffset; if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset))) return failure(); - auto isUnitDim = [](VectorType type, int dim) { - return type.getDimSize(dim) == 1 && !type.getScalableDims()[dim]; - }; - // According to vector.transfer_read/write semantics, the vector can be a // slice. Thus, we have to offset the check index with `rankDiff` in // `srcStrides` and source dim sizes. @@ -1259,7 +1257,8 @@ getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) { // It can be folded only if they are 1 and the stride is 1. 
int dim = vectorType.getRank() - i - 1; if (srcStrides[dim + rankDiff] != 1 || - srcType.getDimSize(dim + rankDiff) != 1 || !isUnitDim(vectorType, dim)) + srcType.getDimSize(dim + rankDiff) != 1 || + vectorType.getDim(dim) != VectorDim::getFixed(1)) break; result++; } diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 2c43a6f15aa83..ef1b1812de7c7 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -28,6 +28,7 @@ #include "mlir/IR/OpImplementation.h" #include "mlir/IR/Operation.h" #include "mlir/IR/Verifier.h" +#include "mlir/Support/ScalableVectorType.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" @@ -2607,17 +2608,10 @@ void AsmPrinter::Impl::printTypeImpl(Type type) { } }) .Case([&](VectorType vectorTy) { - auto scalableDims = vectorTy.getScalableDims(); os << "vector<"; - auto vShape = vectorTy.getShape(); - unsigned lastDim = vShape.size(); - unsigned dimIdx = 0; - for (dimIdx = 0; dimIdx < lastDim; dimIdx++) { - if (!scalableDims.empty() && scalableDims[dimIdx]) - os << '['; - os << vShape[dimIdx]; - if (!scalableDims.empty() && scalableDims[dimIdx]) - os << ']'; + auto dims = VectorDimList::from(vectorTy); + if (!dims.empty()) { + llvm::interleave(dims, os, "x"); os << 'x'; } printType(vectorTy.getElementType()); diff --git a/mlir/unittests/Support/CMakeLists.txt b/mlir/unittests/Support/CMakeLists.txt index 1dbf072bcbbfd..680572bbb0cbe 100644 --- a/mlir/unittests/Support/CMakeLists.txt +++ b/mlir/unittests/Support/CMakeLists.txt @@ -1,7 +1,8 @@ add_mlir_unittest(MLIRSupportTests IndentedOstreamTest.cpp StorageUniquerTest.cpp + ScalableVectorTypeTest.cpp ) target_link_libraries(MLIRSupportTests - PRIVATE MLIRSupport) + PRIVATE MLIRSupport MLIRIR) diff --git a/mlir/unittests/Support/ScalableVectorTypeTest.cpp b/mlir/unittests/Support/ScalableVectorTypeTest.cpp new file mode 100644 index 0000000000000..5f0237ac68414 --- /dev/null +++ 
b/mlir/unittests/Support/ScalableVectorTypeTest.cpp @@ -0,0 +1,76 @@ +//===- ScalableVectorTypeTest.cpp - ScalableVectorType Tests --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Support/ScalableVectorType.h" +#include "mlir/IR/Dialect.h" +#include "gtest/gtest.h" + +using namespace mlir; + +TEST(ScalableVectorTypeTest, TestVectorDim) { + auto fixedDim = VectorDim::getFixed(4); + ASSERT_FALSE(fixedDim.isScalable()); + ASSERT_TRUE(fixedDim.isFixed()); + ASSERT_EQ(fixedDim.getFixedSize(), 4); + + auto scalableDim = VectorDim::getScalable(8); + ASSERT_TRUE(scalableDim.isScalable()); + ASSERT_FALSE(scalableDim.isFixed()); + ASSERT_EQ(scalableDim.getMinSize(), 8); +} + +TEST(ScalableVectorTypeTest, BasicFunctionality) { + MLIRContext context; + + Type f32 = FloatType::getF32(&context); + + // Construct n-D scalable vector. + VectorType scalableVector = ScalableVectorType::get( + {VectorDim::getFixed(1), VectorDim::getFixed(2), + VectorDim::getScalable(3), VectorDim::getFixed(4), + VectorDim::getScalable(5)}, + f32); + // Construct fixed vector. + VectorType fixedVector = ScalableVectorType::get(VectorDim::getFixed(1), f32); + + // Check casts. + ASSERT_TRUE(isa(scalableVector)); + ASSERT_FALSE(isa(fixedVector)); + ASSERT_FALSE(VectorDimList::from(fixedVector).hasScalableDims()); + + // Check rank/size. + auto vType = cast(scalableVector); + ASSERT_EQ(vType.getDims().size(), unsigned(scalableVector.getRank())); + ASSERT_TRUE(vType.getDims().hasScalableDims()); + + // Check iterating over dimensions. 
+ std::array expectedDims{VectorDim::getFixed(1), VectorDim::getFixed(2), + VectorDim::getScalable(3), VectorDim::getFixed(4), + VectorDim::getScalable(5)}; + unsigned i = 0; + for (VectorDim dim : vType.getDims()) { + ASSERT_EQ(dim, expectedDims[i]); + i++; + } +} + +TEST(ScalableVectorTypeTest, VectorDimListHelpers) { + std::array sizes{42, 10, 3, 1}; + std::array scalableFlags{false, true, false, true}; + + // Manually construct from sizes + flags. + VectorDimList dimList(sizes, scalableFlags); + + ASSERT_EQ(dimList.size(), 4U); + + ASSERT_EQ(dimList.front(), VectorDim::getFixed(42)); + ASSERT_EQ(dimList.back(), VectorDim::getScalable(1)); + + std::array innerDims{VectorDim::getScalable(10), VectorDim::getFixed(3)}; + ASSERT_EQ(dimList.slice(1, 2), innerDims); +}