Skip to content

[mlir] Add first-class support for scalability in VectorType dims #74251

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
234 changes: 234 additions & 0 deletions mlir/include/mlir/IR/BuiltinTypes.h
Original file line number Diff line number Diff line change
@@ -12,6 +12,7 @@
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/Support/ADTExtras.h"
#include "llvm/ADT/STLExtras.h"

namespace llvm {
class BitVector;
@@ -181,6 +182,239 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
};

//===----------------------------------------------------------------------===//
// VectorDim
//===----------------------------------------------------------------------===//

/// This class represents a dimension of a vector type. Unlike other ShapedTypes
/// vector dimensions can have scalable quantities, which means the dimension
/// has a known minimum size, which is scaled by a constant that is only
/// known at runtime.
class VectorDim {
public:
explicit constexpr VectorDim(int64_t quantity, bool scalable)
: quantity(quantity), scalable(scalable){};

/// Constructs a new fixed dimension.
constexpr static VectorDim getFixed(int64_t quantity) {
return VectorDim(quantity, false);
}

/// Constructs a new scalable dimension.
constexpr static VectorDim getScalable(int64_t quantity) {
return VectorDim(quantity, true);
}

/// Returns true if this dimension is scalable;
constexpr bool isScalable() const { return scalable; }

/// Returns true if this dimension is fixed.
constexpr bool isFixed() const { return !isScalable(); }

/// Returns the minimum number of elements this dimension can contain.
constexpr int64_t getMinSize() const { return quantity; }

/// If this dimension is fixed returns the number of elements, otherwise
/// aborts.
constexpr int64_t getFixedSize() const {
assert(isFixed());
return quantity;
}

constexpr bool operator==(VectorDim const &dim) const {
return quantity == dim.quantity && scalable == dim.scalable;
}

constexpr bool operator!=(VectorDim const &dim) const {
return !(*this == dim);
}

/// Print the dim.
void print(raw_ostream &os) {
if (isScalable())
os << '[';
os << getMinSize();
if (isScalable())
os << ']';
}

/// Helper class for indexing into a list of sizes (and possibly empty) list
/// of scalable dimensions, extracting VectorDim elements.
struct Indexer {
explicit Indexer(ArrayRef<int64_t> sizes, ArrayRef<bool> scalableDims)
: sizes(sizes), scalableDims(scalableDims) {
assert(
scalableDims.empty() ||
sizes.size() == scalableDims.size() &&
"expected `scalableDims` to be empty or match `sizes` in length");
}

VectorDim operator[](size_t idx) const {
int64_t size = sizes[idx];
bool scalable = scalableDims.empty() ? false : scalableDims[idx];
return VectorDim(size, scalable);
}

ArrayRef<int64_t> sizes;
ArrayRef<bool> scalableDims;
};

private:
int64_t quantity;
bool scalable;
};

inline raw_ostream &operator<<(raw_ostream &os, VectorDim dim) {
dim.print(os);
return os;
}

//===----------------------------------------------------------------------===//
// VectorDims
//===----------------------------------------------------------------------===//

/// Represents a non-owning list of vector dimensions. The underlying dimension
/// sizes and scalability flags are stored a two seperate lists to match the
/// storage of a VectorType.
class VectorDims : public VectorDim::Indexer {
public:
using VectorDim::Indexer::Indexer;

class Iterator : public llvm::iterator_facade_base<
Iterator, std::random_access_iterator_tag, VectorDim,
std::ptrdiff_t, VectorDim, VectorDim> {
public:
Iterator(VectorDim::Indexer indexer, size_t index)
: indexer(indexer), index(index){};

// Iterator boilerplate.
ptrdiff_t operator-(const Iterator &rhs) const { return index - rhs.index; }
bool operator==(const Iterator &rhs) const { return index == rhs.index; }
bool operator<(const Iterator &rhs) const { return index < rhs.index; }
Iterator &operator+=(ptrdiff_t offset) {
index += offset;
return *this;
}
Iterator &operator-=(ptrdiff_t offset) {
index -= offset;
return *this;
}
VectorDim operator*() const { return indexer[index]; }

VectorDim::Indexer getIndexer() const { return indexer; }
ptrdiff_t getIndex() const { return index; }

private:
VectorDim::Indexer indexer;
ptrdiff_t index;
};

// Generic definitions.
using value_type = VectorDim;
using iterator = Iterator;
using const_iterator = Iterator;
using reverse_iterator = std::reverse_iterator<iterator>;
using const_reverse_iterator = std::reverse_iterator<const_iterator>;
using size_type = size_t;
using difference_type = ptrdiff_t;

/// Construct from iterator pair.
VectorDims(Iterator begin, Iterator end)
: VectorDims(VectorDims(begin.getIndexer())
.slice(begin.getIndex(), end - begin)) {}

VectorDims(VectorDim::Indexer indexer) : VectorDim::Indexer(indexer){};

Iterator begin() const { return Iterator(*this, 0); }
Iterator end() const { return Iterator(*this, size()); }

/// Check if the dims are empty.
bool empty() const { return sizes.empty(); }

/// Get the number of dims.
size_t size() const { return sizes.size(); }

/// Return the first dim.
VectorDim front() const { return (*this)[0]; }

/// Return the last dim.
VectorDim back() const { return (*this)[size() - 1]; }

/// Chop of thie first \p n dims, and keep the remaining \p m
/// dims.
VectorDims slice(size_t n, size_t m) const {
ArrayRef<int64_t> newSizes = sizes.slice(n, m);
ArrayRef<bool> newScalableDims =
scalableDims.empty() ? ArrayRef<bool>{} : scalableDims.slice(n, m);
return VectorDims(newSizes, newScalableDims);
}

/// Drop the first \p n dims.
VectorDims dropFront(size_t n = 1) const { return slice(n, size() - n); }

/// Drop the last \p n dims.
VectorDims dropBack(size_t n = 1) const { return slice(0, size() - n); }

/// Return a copy of *this with only the first \p n elements.
VectorDims takeFront(size_t n = 1) const {
if (n >= size())
return *this;
return dropBack(size() - n);
}

/// Return a copy of *this with only the last \p n elements.
VectorDims takeBack(size_t n = 1) const {
if (n >= size())
return *this;
return dropFront(size() - n);
}

/// Return copy of *this with the first n dims matching the predicate removed.
template <class PredicateT>
VectorDims dropWhile(PredicateT predicate) const {
return VectorDims(llvm::find_if_not(*this, predicate), end());
}

/// Returns true if one or more of the dims are scalable.
bool hasScalableDims() const {
return llvm::is_contained(getScalableDims(), true);
}

/// Check for dim equality.
bool equals(VectorDims rhs) const {
if (size() != rhs.size())
return false;
return std::equal(begin(), end(), rhs.begin());
}

/// Check for dim equality.
bool equals(ArrayRef<VectorDim> rhs) const {
if (size() != rhs.size())
return false;
return std::equal(begin(), end(), rhs.begin());
}

/// Return the underlying sizes.
ArrayRef<int64_t> getSizes() const { return sizes; }

/// Return the underlying scalable dims.
ArrayRef<bool> getScalableDims() const { return scalableDims; }
};

inline bool operator==(VectorDims lhs, VectorDims rhs) {
return lhs.equals(rhs);
}

inline bool operator!=(VectorDims lhs, VectorDims rhs) { return !(lhs == rhs); }

inline bool operator==(VectorDims lhs, ArrayRef<VectorDim> rhs) {
return lhs.equals(rhs);
}

inline bool operator!=(VectorDims lhs, ArrayRef<VectorDim> rhs) {
return !(lhs == rhs);
}

} // namespace mlir

//===----------------------------------------------------------------------===//
23 changes: 23 additions & 0 deletions mlir/include/mlir/IR/BuiltinTypes.td
Original file line number Diff line number Diff line change
@@ -1114,13 +1114,36 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
scalableDims = isScalableVec;
}
return $_get(elementType.getContext(), shape, elementType, scalableDims);
}]>,
TypeBuilderWithInferredContext<(ins "Type":$elementType, "ArrayRef<VectorDim>": $shape), [{
SmallVector<int64_t> sizes;
SmallVector<bool> scalableDims;
for (VectorDim dim : shape) {
sizes.push_back(dim.getMinSize());
scalableDims.push_back(dim.isScalable());
}
return get(sizes, elementType, scalableDims);
}]>,
TypeBuilderWithInferredContext<(ins "Type":$elementType, "VectorDims": $shape), [{
return get(shape.getSizes(), elementType, shape.getScalableDims());
}]>
];
let extraClassDeclaration = [{
/// This is a builder type that keeps local references to arguments.
/// Arguments that are passed into the builder must outlive the builder.
class Builder;

/// Returns the value of the specified dimension (including scalability).
VectorDim getDim(unsigned idx) const {
assert(idx < getRank() && "invalid dim index for vector type");
return getDims()[idx];
}

/// Returns the dimensions of this vector type (including scalability).
VectorDims getDims() const {
return VectorDims(getShape(), getScalableDims());
}

/// Returns true if the given type can be used as an element of a vector
/// type. In particular, vectors can consist of integer, index, or float
/// primitives.
5 changes: 2 additions & 3 deletions mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
Original file line number Diff line number Diff line change
@@ -490,12 +490,11 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
return {};
if (type.getShape().empty())
return VectorType::get({1}, elementType);
Type vectorType = VectorType::get(type.getShape().back(), elementType,
type.getScalableDims().back());
Type vectorType = VectorType::get(elementType, type.getDims().takeBack());
assert(LLVM::isCompatibleVectorType(vectorType) &&
"expected vector type compatible with the LLVM dialect");
// Only the trailing dimension can be scalable.
if (llvm::is_contained(type.getScalableDims().drop_back(), true))
if (type.getDims().dropBack().hasScalableDims())
return failure();
auto shape = type.getShape();
for (int i = shape.size() - 2; i >= 0; --i)
3 changes: 1 addition & 2 deletions mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Original file line number Diff line number Diff line change
@@ -37,8 +37,7 @@ using namespace mlir::vector;
// Helper to reduce vector type by *all* but one rank at back.
static VectorType reducedVectorTypeBack(VectorType tp) {
assert((tp.getRank() > 1) && "unlowerable vector type");
return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
tp.getScalableDims().take_back());
return VectorType::get(tp.getElementType(), tp.getDims().takeBack());
}

// Helper that picks the proper sequence for inserting.
12 changes: 6 additions & 6 deletions mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
Original file line number Diff line number Diff line change
@@ -319,12 +319,13 @@ static FailureOr<MemRefType> unpackOneDim(MemRefType type) {
auto vectorType = dyn_cast<VectorType>(type.getElementType());
// Vectors with leading scalable dims are not supported.
// It may be possible to support these in future by using dynamic memref dims.
if (vectorType.getScalableDims().front())
VectorDim leadingDim = vectorType.getDims().front();
if (leadingDim.isScalable())
return failure();
auto memrefShape = type.getShape();
SmallVector<int64_t, 8> newMemrefShape;
newMemrefShape.append(memrefShape.begin(), memrefShape.end());
newMemrefShape.push_back(vectorType.getDimSize(0));
newMemrefShape.push_back(leadingDim.getFixedSize());
return MemRefType::get(newMemrefShape,
VectorType::Builder(vectorType).dropDim(0));
}
@@ -1091,18 +1092,17 @@ struct UnrollTransferReadConversion
auto vecType = dyn_cast<VectorType>(vec.getType());
auto xferVecType = xferOp.getVectorType();

if (xferVecType.getScalableDims()[0]) {
VectorDim dim = xferVecType.getDim(0);
if (dim.isScalable()) {
// Cannot unroll a scalable dimension at compile time.
return failure();
}

VectorType newXferVecType = VectorType::Builder(xferVecType).dropDim(0);

int64_t dimSize = xferVecType.getShape()[0];

// Generate fully unrolled loop of transfer ops.
Location loc = xferOp.getLoc();
for (int64_t i = 0; i < dimSize; ++i) {
for (int64_t i = 0; i < dim.getFixedSize(); ++i) {
Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);

vec = generateInBoundsCheck(
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
Original file line number Diff line number Diff line change
@@ -30,8 +30,7 @@ using namespace mlir::arm_sve;
static Type getI1SameShape(Type type) {
auto i1Type = IntegerType::get(type.getContext(), 1);
if (auto sVectorType = llvm::dyn_cast<VectorType>(type))
return VectorType::get(sVectorType.getShape(), i1Type,
sVectorType.getScalableDims());
return VectorType::get(i1Type, sVectorType.getDims());
return nullptr;
}

8 changes: 3 additions & 5 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
Original file line number Diff line number Diff line change
@@ -1217,9 +1217,8 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
assert(vecOperand && "Vector operand couldn't be found");

if (firstMaxRankedType) {
auto vecType = VectorType::get(firstMaxRankedType.getShape(),
getElementTypeOrSelf(vecOperand.getType()),
firstMaxRankedType.getScalableDims());
auto vecType = VectorType::get(getElementTypeOrSelf(vecOperand.getType()),
firstMaxRankedType.getDims());
vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
} else {
vecOperands.push_back(vecOperand);
@@ -1230,8 +1229,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
for (Type resultType : op->getResultTypes()) {
resultTypes.push_back(
firstMaxRankedType
? VectorType::get(firstMaxRankedType.getShape(), resultType,
firstMaxRankedType.getScalableDims())
? VectorType::get(resultType, firstMaxRankedType.getDims())
: resultType);
}
// d. Build and return the new op.
75 changes: 26 additions & 49 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
@@ -422,23 +422,21 @@ MultiDimReductionOp::getShapeForUnroll() {
}

LogicalResult MultiDimReductionOp::verify() {
SmallVector<int64_t> targetShape;
SmallVector<bool> scalableDims;
SmallVector<VectorDim> targetDims;
Type inferredReturnType;
auto sourceScalableDims = getSourceVectorType().getScalableDims();
for (auto it : llvm::enumerate(getSourceVectorType().getShape()))
if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) {
return llvm::cast<IntegerAttr>(attr).getValue() == it.index();
})) {
targetShape.push_back(it.value());
scalableDims.push_back(sourceScalableDims[it.index()]);
for (auto [idx, dim] : llvm::enumerate(getSourceVectorType().getDims()))
if (!llvm::any_of(getReductionDims().getValue(),
[idx = idx](Attribute attr) {
return llvm::cast<IntegerAttr>(attr).getValue() == idx;
})) {
targetDims.push_back(dim);
}
// TODO: update to also allow 0-d vectors when available.
if (targetShape.empty())
if (targetDims.empty())
inferredReturnType = getSourceVectorType().getElementType();
else
inferredReturnType = VectorType::get(
targetShape, getSourceVectorType().getElementType(), scalableDims);
inferredReturnType =
VectorType::get(getSourceVectorType().getElementType(), targetDims);
if (getType() != inferredReturnType)
return emitOpError() << "destination type " << getType()
<< " is incompatible with source type "
@@ -450,9 +448,8 @@ LogicalResult MultiDimReductionOp::verify() {
/// Returns the mask type expected by this operation.
Type MultiDimReductionOp::getExpectedMaskType() {
auto vecType = getSourceVectorType();
return VectorType::get(vecType.getShape(),
IntegerType::get(vecType.getContext(), /*width=*/1),
vecType.getScalableDims());
return VectorType::get(IntegerType::get(vecType.getContext(), /*width=*/1),
vecType.getDims());
}

namespace {
@@ -491,8 +488,7 @@ struct ElideUnitDimsInMultiDimReduction
if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
if (mask) {
VectorType newMaskType =
VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
dstVecType.getScalableDims());
VectorType::get(rewriter.getI1Type(), dstVecType.getDims());
mask = rewriter.create<vector::ShapeCastOp>(loc, newMaskType, mask);
}
cast = rewriter.create<vector::ShapeCastOp>(
@@ -559,9 +555,8 @@ LogicalResult ReductionOp::verify() {
/// Returns the mask type expected by this operation.
Type ReductionOp::getExpectedMaskType() {
auto vecType = getSourceVectorType();
return VectorType::get(vecType.getShape(),
IntegerType::get(vecType.getContext(), /*width=*/1),
vecType.getScalableDims());
return VectorType::get(IntegerType::get(vecType.getContext(), /*width=*/1),
vecType.getDims());
}

Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
@@ -1252,8 +1247,7 @@ ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
vectorType.getRank());
inferredReturnTypes.push_back(VectorType::get(
vectorType.getShape().drop_front(n), vectorType.getElementType(),
vectorType.getScalableDims().drop_front(n)));
vectorType.getElementType(), vectorType.getDims().dropFront(n)));
}
return success();
}
@@ -3040,15 +3034,11 @@ ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {

VectorType resType;
if (vRHS) {
SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
vRHS.getScalableDims()[0]};
resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
vLHS.getElementType(), scalableDimsRes);
resType = VectorType::get(vLHS.getElementType(),
{vLHS.getDim(0), vRHS.getDim(0)});
} else {
// Scalar RHS operand
SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
scalableDimsRes);
resType = VectorType::get(vLHS.getElementType(), {vLHS.getDim(0)});
}

if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
@@ -3115,9 +3105,8 @@ LogicalResult OuterProductOp::verify() {
/// verification purposes. It requires the operation to be vectorized."
Type OuterProductOp::getExpectedMaskType() {
auto vecType = this->getResultVectorType();
return VectorType::get(vecType.getShape(),
IntegerType::get(vecType.getContext(), /*width=*/1),
vecType.getScalableDims());
return VectorType::get(IntegerType::get(vecType.getContext(), /*width=*/1),
vecType.getDims());
}

//===----------------------------------------------------------------------===//
@@ -5064,25 +5053,13 @@ class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
/// vector<4x1x1xi1> --> vector<4x1>
///
static VectorType trimTrailingOneDims(VectorType oldType) {
ArrayRef<int64_t> oldShape = oldType.getShape();
ArrayRef<int64_t> newShape = oldShape;

ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
ArrayRef<bool> newScalableDims = oldScalableDims;

while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
newShape = newShape.drop_back(1);
newScalableDims = newScalableDims.drop_back(1);
}

// Make sure we have at least 1 dimension.
// Note: This will always keep at least one dim (even if it's a unit dim).
// TODO: Add support for 0-D vectors.
if (newShape.empty()) {
newShape = oldShape.take_back();
newScalableDims = oldScalableDims.take_back();
}
VectorDims newDims = oldType.getDims();
while (newDims.size() > 1 && newDims.back() == VectorDim::getFixed(1))
newDims = newDims.dropBack();

return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
return VectorType::get(oldType.getElementType(), newDims);
}

/// Folds qualifying shape_cast(create_mask) into a new create_mask
14 changes: 6 additions & 8 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
Original file line number Diff line number Diff line change
@@ -113,13 +113,11 @@ struct TransferReadPermutationLowering
permutationMap = inversePermutation(permutationMap);
AffineMap newMap = permutationMap.compose(map);
// Apply the reverse transpose to deduce the type of the transfer_read.
ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
SmallVector<int64_t> newVectorShape(originalShape.size());
ArrayRef<bool> originalScalableDims = op.getVectorType().getScalableDims();
SmallVector<bool> newScalableDims(originalShape.size());
auto originalDims = op.getVectorType().getDims();
SmallVector<VectorDim> newVectorDims(op.getVectorType().getRank(),
VectorDim::getFixed(0));
for (const auto &pos : llvm::enumerate(permutation)) {
newVectorShape[pos.value()] = originalShape[pos.index()];
newScalableDims[pos.value()] = originalScalableDims[pos.index()];
newVectorDims[pos.value()] = originalDims[pos.index()];
}

// Transpose in_bounds attribute.
@@ -129,8 +127,8 @@ struct TransferReadPermutationLowering
: ArrayAttr();

// Generate new transfer_read operation.
VectorType newReadType = VectorType::get(
newVectorShape, op.getVectorType().getElementType(), newScalableDims);
VectorType newReadType =
VectorType::get(op.getVectorType().getElementType(), newVectorDims);
Value newRead = rewriter.create<vector::TransferReadOp>(
op.getLoc(), newReadType, op.getSource(), op.getIndices(),
AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
6 changes: 2 additions & 4 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
Original file line number Diff line number Diff line change
@@ -344,10 +344,8 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
// Source with leading unit dim (inverse) is also replaced. Unit dim must
// be fixed. Non-unit can be scalable.
if (resType.getRank() == 2 &&
((resType.getShape().front() == 1 &&
!resType.getScalableDims().front()) ||
(resType.getShape().back() == 1 &&
!resType.getScalableDims().back())) &&
(resType.getDims().front() == VectorDim::getFixed(1) ||
resType.getDims().back() == VectorDim::getFixed(1)) &&
transp == ArrayRef<int64_t>({1, 0})) {
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
return success();
23 changes: 7 additions & 16 deletions mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
Original file line number Diff line number Diff line change
@@ -25,24 +25,15 @@ using namespace mlir::vector;
// Trims leading one dimensions from `oldType` and returns the result type.
// Returns `vector<1xT>` if `oldType` only has one element.
static VectorType trimLeadingOneDims(VectorType oldType) {
ArrayRef<int64_t> oldShape = oldType.getShape();
ArrayRef<int64_t> newShape = oldShape;

ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
ArrayRef<bool> newScalableDims = oldScalableDims;

while (!newShape.empty() && newShape.front() == 1 &&
!newScalableDims.front()) {
newShape = newShape.drop_front(1);
newScalableDims = newScalableDims.drop_front(1);
}
VectorDims oldDims = oldType.getDims();
VectorDims newDims = oldDims.dropWhile(
[](VectorDim dim) { return dim == VectorDim::getFixed(1); });

// Make sure we have at least 1 dimension per vector type requirements.
if (newShape.empty()) {
newShape = oldShape.take_back();
newScalableDims = oldType.getScalableDims().take_back();
}
return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
if (newDims.empty())
newDims = oldDims.takeBack(1);

return VectorType::get(oldType.getElementType(), newDims);
}

/// Return a smallVector of size `rank` containing all zeros.
Original file line number Diff line number Diff line change
@@ -316,15 +316,11 @@ static int getReducedRank(ArrayRef<int64_t> shape) {
/// Trims non-scalable one dimensions from `oldType` and returns the result
/// type.
static VectorType trimNonScalableUnitDims(VectorType oldType) {
SmallVector<int64_t> newShape;
SmallVector<bool> newScalableDims;
for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
continue;
newShape.push_back(dimSize);
newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
}
return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
auto newDims = llvm::to_vector(
llvm::make_filter_range(oldType.getDims(), [](VectorDim dim) {
return dim != VectorDim::getFixed(1);
}));
return VectorType::get(oldType.getElementType(), newDims);
}

// Rewrites vector.create_mask 'op' to drop non-scalable one dimensions.
@@ -337,9 +333,9 @@ createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
return failure();

SmallVector<Value> reducedOperands;
for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
type.getShape(), type.getScalableDims(), op.getOperands())) {
if (dim == 1 && !dimIsScalable) {
for (auto [dim, operand] :
llvm::zip_equal(type.getDims(), op.getOperands())) {
if (dim == VectorDim::getFixed(1)) {
// If the mask for the unit dim is not a constant of 1, do nothing.
auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
if (!constant || (constant.value() != 1))
5 changes: 1 addition & 4 deletions mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
Original file line number Diff line number Diff line change
@@ -1039,10 +1039,7 @@ struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex);
Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
Value mask = rewriter.create<vector::CreateMaskOp>(
loc,
VectorType::get(vtp.getShape(), rewriter.getI1Type(),
vtp.getScalableDims()),
b);
loc, VectorType::get(rewriter.getI1Type(), vtp.getDims()), b);
if (xferOp.getMask()) {
// Intersect the in-bounds with the mask specified as an op parameter.
mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());
13 changes: 3 additions & 10 deletions mlir/lib/IR/AsmPrinter.cpp
Original file line number Diff line number Diff line change
@@ -2558,17 +2558,10 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
}
})
.Case<VectorType>([&](VectorType vectorTy) {
auto scalableDims = vectorTy.getScalableDims();
os << "vector<";
auto vShape = vectorTy.getShape();
unsigned lastDim = vShape.size();
unsigned dimIdx = 0;
for (dimIdx = 0; dimIdx < lastDim; dimIdx++) {
if (!scalableDims.empty() && scalableDims[dimIdx])
os << '[';
os << vShape[dimIdx];
if (!scalableDims.empty() && scalableDims[dimIdx])
os << ']';
auto dims = vectorTy.getDims();
if (!dims.empty()) {
llvm::interleave(dims, os, "x");
os << 'x';
}
printType(vectorTy.getElementType());
4 changes: 2 additions & 2 deletions mlir/lib/IR/BuiltinTypes.cpp
Original file line number Diff line number Diff line change
@@ -250,10 +250,10 @@ VectorType VectorType::scaleElementBitwidth(unsigned scale) {
return VectorType();
if (auto et = llvm::dyn_cast<IntegerType>(getElementType()))
if (auto scaledEt = et.scaleElementBitwidth(scale))
return VectorType::get(getShape(), scaledEt, getScalableDims());
return VectorType::get(scaledEt, getDims());
if (auto et = llvm::dyn_cast<FloatType>(getElementType()))
if (auto scaledEt = et.scaleElementBitwidth(scale))
return VectorType::get(getShape(), scaledEt, getScalableDims());
return VectorType::get(scaledEt, getDims());
return VectorType();
}

101 changes: 101 additions & 0 deletions mlir/unittests/IR/ShapedTypeTest.cpp
Original file line number Diff line number Diff line change
@@ -226,4 +226,105 @@ TEST(ShapedTypeTest, RankedTensorTypeBuilder) {
}
}

TEST(ShapedTypeTest, VectorDims) {
MLIRContext context;
Type f32 = FloatType::getF32(&context);

SmallVector<VectorDim> dims{VectorDim::getFixed(2), VectorDim::getScalable(4),
VectorDim::getFixed(8), VectorDim::getScalable(9),
VectorDim::getFixed(1)};
VectorType vectorType = VectorType::get(f32, dims);

// Directly check values
{
auto dim0 = vectorType.getDim(0);
ASSERT_EQ(dim0.getMinSize(), 2);
ASSERT_TRUE(dim0.isFixed());

auto dim1 = vectorType.getDim(1);
ASSERT_EQ(dim1.getMinSize(), 4);
ASSERT_TRUE(dim1.isScalable());

auto dim2 = vectorType.getDim(2);
ASSERT_EQ(dim2.getMinSize(), 8);
ASSERT_TRUE(dim2.isFixed());

auto dim3 = vectorType.getDim(3);
ASSERT_EQ(dim3.getMinSize(), 9);
ASSERT_TRUE(dim3.isScalable());

auto dim4 = vectorType.getDim(4);
ASSERT_EQ(dim4.getMinSize(), 1);
ASSERT_TRUE(dim4.isFixed());
}

// Test indexing via getDim(idx)
{
for (unsigned i = 0; i < dims.size(); i++)
ASSERT_EQ(vectorType.getDim(i), dims[i]);
}

// Test using VectorDims::Iterator in for-each loop
{
unsigned i = 0;
for (VectorDim dim : vectorType.getDims())
ASSERT_EQ(dim, dims[i++]);
ASSERT_EQ(i, vectorType.getRank());
}

// Test using VectorDims::Iterator in LLVM iterator helper
{
for (auto [dim, expectedDim] :
llvm::zip_equal(vectorType.getDims(), dims)) {
ASSERT_EQ(dim, expectedDim);
}
}

// Test dropFront()
{
auto vectorDims = vectorType.getDims();
auto newDims = vectorDims.dropFront();

ASSERT_EQ(newDims.size(), vectorDims.size() - 1);
for (unsigned i = 0; i < newDims.size(); i++)
ASSERT_EQ(newDims[i], vectorDims[i + 1]);
}

// Test dropBack()
{
auto vectorDims = vectorType.getDims();
auto newDims = vectorDims.dropBack();

ASSERT_EQ(newDims.size(), vectorDims.size() - 1);
for (unsigned i = 0; i < newDims.size(); i++)
ASSERT_EQ(newDims[i], vectorDims[i]);
}

// Test front()
{ ASSERT_EQ(vectorType.getDims().front(), VectorDim::getFixed(2)); }

// Test back()
{ ASSERT_EQ(vectorType.getDims().back(), VectorDim::getFixed(1)); }

// Test dropWhile.
{
SmallVector<VectorDim> dims{
VectorDim::getFixed(1), VectorDim::getFixed(1), VectorDim::getFixed(1),
VectorDim::getScalable(1), VectorDim::getScalable(4)};

VectorType vectorTypeWithLeadingUnitDims = VectorType::get(f32, dims);
ASSERT_EQ(vectorTypeWithLeadingUnitDims.getDims().size(),
unsigned(vectorTypeWithLeadingUnitDims.getRank()));

// Drop leading unit dims.
auto withoutLeadingUnitDims =
vectorTypeWithLeadingUnitDims.getDims().dropWhile(
[](VectorDim dim) { return dim == VectorDim::getFixed(1); });

SmallVector<VectorDim> expectedDims{VectorDim::getScalable(1),
VectorDim::getScalable(4)};
ASSERT_EQ(withoutLeadingUnitDims, expectedDims);
}
}

} // namespace