Skip to content

Commit f69fad1

Browse files
committed
Demonstrate using the new APIs in scalable-aware code :)
This is not a complete change, this just updates a few examples found by grepping for getScalableDims().
1 parent ca8250c commit f69fad1

File tree

13 files changed

+68
-123
lines changed

13 files changed

+68
-123
lines changed

mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -490,12 +490,11 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
490490
return {};
491491
if (type.getShape().empty())
492492
return VectorType::get({1}, elementType);
493-
Type vectorType = VectorType::get(type.getShape().back(), elementType,
494-
type.getScalableDims().back());
493+
Type vectorType = VectorType::get(elementType, type.getDims().takeBack());
495494
assert(LLVM::isCompatibleVectorType(vectorType) &&
496495
"expected vector type compatible with the LLVM dialect");
497496
// Only the trailing dimension can be scalable.
498-
if (llvm::is_contained(type.getScalableDims().drop_back(), true))
497+
if (type.getDims().dropBack().hasScalableDims())
499498
return failure();
500499
auto shape = type.getShape();
501500
for (int i = shape.size() - 2; i >= 0; --i)

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,7 @@ using namespace mlir::vector;
3737
// Helper to reduce vector type by *all* but one rank at back.
3838
static VectorType reducedVectorTypeBack(VectorType tp) {
3939
assert((tp.getRank() > 1) && "unlowerable vector type");
40-
return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
41-
tp.getScalableDims().take_back());
40+
return VectorType::get(tp.getElementType(), tp.getDims().takeBack());
4241
}
4342

4443
// Helper that picks the proper sequence for inserting.

mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -319,12 +319,13 @@ static FailureOr<MemRefType> unpackOneDim(MemRefType type) {
319319
auto vectorType = dyn_cast<VectorType>(type.getElementType());
320320
// Vectors with leading scalable dims are not supported.
321321
// It may be possible to support these in future by using dynamic memref dims.
322-
if (vectorType.getScalableDims().front())
322+
VectorDim leadingDim = vectorType.getDims().front();
323+
if (leadingDim.isScalable())
323324
return failure();
324325
auto memrefShape = type.getShape();
325326
SmallVector<int64_t, 8> newMemrefShape;
326327
newMemrefShape.append(memrefShape.begin(), memrefShape.end());
327-
newMemrefShape.push_back(vectorType.getDimSize(0));
328+
newMemrefShape.push_back(leadingDim.getFixedSize());
328329
return MemRefType::get(newMemrefShape,
329330
VectorType::Builder(vectorType).dropDim(0));
330331
}
@@ -1091,18 +1092,17 @@ struct UnrollTransferReadConversion
10911092
auto vecType = dyn_cast<VectorType>(vec.getType());
10921093
auto xferVecType = xferOp.getVectorType();
10931094

1094-
if (xferVecType.getScalableDims()[0]) {
1095+
VectorDim dim = xferVecType.getDim(0);
1096+
if (dim.isScalable()) {
10951097
// Cannot unroll a scalable dimension at compile time.
10961098
return failure();
10971099
}
10981100

10991101
VectorType newXferVecType = VectorType::Builder(xferVecType).dropDim(0);
11001102

1101-
int64_t dimSize = xferVecType.getShape()[0];
1102-
11031103
// Generate fully unrolled loop of transfer ops.
11041104
Location loc = xferOp.getLoc();
1105-
for (int64_t i = 0; i < dimSize; ++i) {
1105+
for (int64_t i = 0; i < dim.getFixedSize(); ++i) {
11061106
Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
11071107

11081108
vec = generateInBoundsCheck(

mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@ using namespace mlir::arm_sve;
3030
static Type getI1SameShape(Type type) {
3131
auto i1Type = IntegerType::get(type.getContext(), 1);
3232
if (auto sVectorType = llvm::dyn_cast<VectorType>(type))
33-
return VectorType::get(sVectorType.getShape(), i1Type,
34-
sVectorType.getScalableDims());
33+
return VectorType::get(i1Type, sVectorType.getDims());
3534
return nullptr;
3635
}
3736

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1217,9 +1217,8 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
12171217
assert(vecOperand && "Vector operand couldn't be found");
12181218

12191219
if (firstMaxRankedType) {
1220-
auto vecType = VectorType::get(firstMaxRankedType.getShape(),
1221-
getElementTypeOrSelf(vecOperand.getType()),
1222-
firstMaxRankedType.getScalableDims());
1220+
auto vecType = VectorType::get(getElementTypeOrSelf(vecOperand.getType()),
1221+
firstMaxRankedType.getDims());
12231222
vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
12241223
} else {
12251224
vecOperands.push_back(vecOperand);
@@ -1230,8 +1229,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
12301229
for (Type resultType : op->getResultTypes()) {
12311230
resultTypes.push_back(
12321231
firstMaxRankedType
1233-
? VectorType::get(firstMaxRankedType.getShape(), resultType,
1234-
firstMaxRankedType.getScalableDims())
1232+
? VectorType::get(resultType, firstMaxRankedType.getDims())
12351233
: resultType);
12361234
}
12371235
// d. Build and return the new op.

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 26 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -422,23 +422,21 @@ MultiDimReductionOp::getShapeForUnroll() {
422422
}
423423

424424
LogicalResult MultiDimReductionOp::verify() {
425-
SmallVector<int64_t> targetShape;
426-
SmallVector<bool> scalableDims;
425+
SmallVector<VectorDim> targetDims;
427426
Type inferredReturnType;
428-
auto sourceScalableDims = getSourceVectorType().getScalableDims();
429-
for (auto it : llvm::enumerate(getSourceVectorType().getShape()))
430-
if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) {
431-
return llvm::cast<IntegerAttr>(attr).getValue() == it.index();
432-
})) {
433-
targetShape.push_back(it.value());
434-
scalableDims.push_back(sourceScalableDims[it.index()]);
427+
for (auto [idx, dim] : llvm::enumerate(getSourceVectorType().getDims()))
428+
if (!llvm::any_of(getReductionDims().getValue(),
429+
[idx = idx](Attribute attr) {
430+
return llvm::cast<IntegerAttr>(attr).getValue() == idx;
431+
})) {
432+
targetDims.push_back(dim);
435433
}
436434
// TODO: update to also allow 0-d vectors when available.
437-
if (targetShape.empty())
435+
if (targetDims.empty())
438436
inferredReturnType = getSourceVectorType().getElementType();
439437
else
440-
inferredReturnType = VectorType::get(
441-
targetShape, getSourceVectorType().getElementType(), scalableDims);
438+
inferredReturnType =
439+
VectorType::get(getSourceVectorType().getElementType(), targetDims);
442440
if (getType() != inferredReturnType)
443441
return emitOpError() << "destination type " << getType()
444442
<< " is incompatible with source type "
@@ -450,9 +448,8 @@ LogicalResult MultiDimReductionOp::verify() {
450448
/// Returns the mask type expected by this operation.
451449
Type MultiDimReductionOp::getExpectedMaskType() {
452450
auto vecType = getSourceVectorType();
453-
return VectorType::get(vecType.getShape(),
454-
IntegerType::get(vecType.getContext(), /*width=*/1),
455-
vecType.getScalableDims());
451+
return VectorType::get(IntegerType::get(vecType.getContext(), /*width=*/1),
452+
vecType.getDims());
456453
}
457454

458455
namespace {
@@ -491,8 +488,7 @@ struct ElideUnitDimsInMultiDimReduction
491488
if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
492489
if (mask) {
493490
VectorType newMaskType =
494-
VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
495-
dstVecType.getScalableDims());
491+
VectorType::get(rewriter.getI1Type(), dstVecType.getDims());
496492
mask = rewriter.create<vector::ShapeCastOp>(loc, newMaskType, mask);
497493
}
498494
cast = rewriter.create<vector::ShapeCastOp>(
@@ -559,9 +555,8 @@ LogicalResult ReductionOp::verify() {
559555
/// Returns the mask type expected by this operation.
560556
Type ReductionOp::getExpectedMaskType() {
561557
auto vecType = getSourceVectorType();
562-
return VectorType::get(vecType.getShape(),
563-
IntegerType::get(vecType.getContext(), /*width=*/1),
564-
vecType.getScalableDims());
558+
return VectorType::get(IntegerType::get(vecType.getContext(), /*width=*/1),
559+
vecType.getDims());
565560
}
566561

567562
Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
@@ -1252,8 +1247,7 @@ ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
12521247
auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
12531248
vectorType.getRank());
12541249
inferredReturnTypes.push_back(VectorType::get(
1255-
vectorType.getShape().drop_front(n), vectorType.getElementType(),
1256-
vectorType.getScalableDims().drop_front(n)));
1250+
vectorType.getElementType(), vectorType.getDims().dropFront(n)));
12571251
}
12581252
return success();
12591253
}
@@ -3040,15 +3034,11 @@ ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
30403034

30413035
VectorType resType;
30423036
if (vRHS) {
3043-
SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
3044-
vRHS.getScalableDims()[0]};
3045-
resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
3046-
vLHS.getElementType(), scalableDimsRes);
3037+
resType = VectorType::get(vLHS.getElementType(),
3038+
{vLHS.getDim(0), vRHS.getDim(0)});
30473039
} else {
30483040
// Scalar RHS operand
3049-
SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
3050-
resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
3051-
scalableDimsRes);
3041+
resType = VectorType::get(vLHS.getElementType(), {vLHS.getDim(0)});
30523042
}
30533043

30543044
if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
@@ -3115,9 +3105,8 @@ LogicalResult OuterProductOp::verify() {
31153105
/// verification purposes. It requires the operation to be vectorized."
31163106
Type OuterProductOp::getExpectedMaskType() {
31173107
auto vecType = this->getResultVectorType();
3118-
return VectorType::get(vecType.getShape(),
3119-
IntegerType::get(vecType.getContext(), /*width=*/1),
3120-
vecType.getScalableDims());
3108+
return VectorType::get(IntegerType::get(vecType.getContext(), /*width=*/1),
3109+
vecType.getDims());
31213110
}
31223111

31233112
//===----------------------------------------------------------------------===//
@@ -5064,25 +5053,13 @@ class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
50645053
/// vector<4x1x1xi1> --> vector<4x1>
50655054
///
50665055
static VectorType trimTrailingOneDims(VectorType oldType) {
5067-
ArrayRef<int64_t> oldShape = oldType.getShape();
5068-
ArrayRef<int64_t> newShape = oldShape;
5069-
5070-
ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
5071-
ArrayRef<bool> newScalableDims = oldScalableDims;
5072-
5073-
while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
5074-
newShape = newShape.drop_back(1);
5075-
newScalableDims = newScalableDims.drop_back(1);
5076-
}
5077-
5078-
// Make sure we have at least 1 dimension.
5056+
// Note: This will always keep at least one dim (even if it's a unit dim).
50795057
// TODO: Add support for 0-D vectors.
5080-
if (newShape.empty()) {
5081-
newShape = oldShape.take_back();
5082-
newScalableDims = oldScalableDims.take_back();
5083-
}
5058+
VectorDims newDims = oldType.getDims();
5059+
while (newDims.size() > 1 && newDims.back() == VectorDim::getFixed(1))
5060+
newDims = newDims.dropBack();
50845061

5085-
return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
5062+
return VectorType::get(oldType.getElementType(), newDims);
50865063
}
50875064

50885065
/// Folds qualifying shape_cast(create_mask) into a new create_mask

mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -113,13 +113,11 @@ struct TransferReadPermutationLowering
113113
permutationMap = inversePermutation(permutationMap);
114114
AffineMap newMap = permutationMap.compose(map);
115115
// Apply the reverse transpose to deduce the type of the transfer_read.
116-
ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
117-
SmallVector<int64_t> newVectorShape(originalShape.size());
118-
ArrayRef<bool> originalScalableDims = op.getVectorType().getScalableDims();
119-
SmallVector<bool> newScalableDims(originalShape.size());
116+
auto originalDims = op.getVectorType().getDims();
117+
SmallVector<VectorDim> newVectorDims(op.getVectorType().getRank(),
118+
VectorDim::getFixed(0));
120119
for (const auto &pos : llvm::enumerate(permutation)) {
121-
newVectorShape[pos.value()] = originalShape[pos.index()];
122-
newScalableDims[pos.value()] = originalScalableDims[pos.index()];
120+
newVectorDims[pos.value()] = originalDims[pos.index()];
123121
}
124122

125123
// Transpose in_bounds attribute.
@@ -129,8 +127,8 @@ struct TransferReadPermutationLowering
129127
: ArrayAttr();
130128

131129
// Generate new transfer_read operation.
132-
VectorType newReadType = VectorType::get(
133-
newVectorShape, op.getVectorType().getElementType(), newScalableDims);
130+
VectorType newReadType =
131+
VectorType::get(op.getVectorType().getElementType(), newVectorDims);
134132
Value newRead = rewriter.create<vector::TransferReadOp>(
135133
op.getLoc(), newReadType, op.getSource(), op.getIndices(),
136134
AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),

mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -344,10 +344,8 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
344344
// Source with leading unit dim (inverse) is also replaced. Unit dim must
345345
// be fixed. Non-unit can be scalable.
346346
if (resType.getRank() == 2 &&
347-
((resType.getShape().front() == 1 &&
348-
!resType.getScalableDims().front()) ||
349-
(resType.getShape().back() == 1 &&
350-
!resType.getScalableDims().back())) &&
347+
(resType.getDims().front() == VectorDim::getFixed(1) ||
348+
resType.getDims().back() == VectorDim::getFixed(1)) &&
351349
transp == ArrayRef<int64_t>({1, 0})) {
352350
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
353351
return success();

mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,24 +25,15 @@ using namespace mlir::vector;
2525
// Trims leading one dimensions from `oldType` and returns the result type.
2626
// Returns `vector<1xT>` if `oldType` only has one element.
2727
static VectorType trimLeadingOneDims(VectorType oldType) {
28-
ArrayRef<int64_t> oldShape = oldType.getShape();
29-
ArrayRef<int64_t> newShape = oldShape;
30-
31-
ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
32-
ArrayRef<bool> newScalableDims = oldScalableDims;
33-
34-
while (!newShape.empty() && newShape.front() == 1 &&
35-
!newScalableDims.front()) {
36-
newShape = newShape.drop_front(1);
37-
newScalableDims = newScalableDims.drop_front(1);
38-
}
28+
VectorDims oldDims = oldType.getDims();
29+
VectorDims newDims = oldDims.dropWhile(
30+
[](VectorDim dim) { return dim == VectorDim::getFixed(1); });
3931

4032
// Make sure we have at least 1 dimension per vector type requirements.
41-
if (newShape.empty()) {
42-
newShape = oldShape.take_back();
43-
newScalableDims = oldType.getScalableDims().take_back();
44-
}
45-
return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
33+
if (newDims.empty())
34+
newDims = oldDims.takeBack(1);
35+
36+
return VectorType::get(oldType.getElementType(), newDims);
4637
}
4738

4839
/// Return a smallVector of size `rank` containing all zeros.

mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -316,15 +316,11 @@ static int getReducedRank(ArrayRef<int64_t> shape) {
316316
/// Trims non-scalable one dimensions from `oldType` and returns the result
317317
/// type.
318318
static VectorType trimNonScalableUnitDims(VectorType oldType) {
319-
SmallVector<int64_t> newShape;
320-
SmallVector<bool> newScalableDims;
321-
for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
322-
if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
323-
continue;
324-
newShape.push_back(dimSize);
325-
newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
326-
}
327-
return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
319+
auto newDims = llvm::to_vector(
320+
llvm::make_filter_range(oldType.getDims(), [](VectorDim dim) {
321+
return dim != VectorDim::getFixed(1);
322+
}));
323+
return VectorType::get(oldType.getElementType(), newDims);
328324
}
329325

330326
// Rewrites vector.create_mask 'op' to drop non-scalable one dimensions.
@@ -337,9 +333,9 @@ createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
337333
return failure();
338334

339335
SmallVector<Value> reducedOperands;
340-
for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
341-
type.getShape(), type.getScalableDims(), op.getOperands())) {
342-
if (dim == 1 && !dimIsScalable) {
336+
for (auto [dim, operand] :
337+
llvm::zip_equal(type.getDims(), op.getOperands())) {
338+
if (dim == VectorDim::getFixed(1)) {
343339
// If the mask for the unit dim is not a constant of 1, do nothing.
344340
auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
345341
if (!constant || (constant.value() != 1))

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1039,10 +1039,7 @@ struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
10391039
vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex);
10401040
Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
10411041
Value mask = rewriter.create<vector::CreateMaskOp>(
1042-
loc,
1043-
VectorType::get(vtp.getShape(), rewriter.getI1Type(),
1044-
vtp.getScalableDims()),
1045-
b);
1042+
loc, VectorType::get(rewriter.getI1Type(), vtp.getDims()), b);
10461043
if (xferOp.getMask()) {
10471044
// Intersect the in-bounds with the mask specified as an op parameter.
10481045
mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());

mlir/lib/IR/AsmPrinter.cpp

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2558,17 +2558,10 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
25582558
}
25592559
})
25602560
.Case<VectorType>([&](VectorType vectorTy) {
2561-
auto scalableDims = vectorTy.getScalableDims();
25622561
os << "vector<";
2563-
auto vShape = vectorTy.getShape();
2564-
unsigned lastDim = vShape.size();
2565-
unsigned dimIdx = 0;
2566-
for (dimIdx = 0; dimIdx < lastDim; dimIdx++) {
2567-
if (!scalableDims.empty() && scalableDims[dimIdx])
2568-
os << '[';
2569-
os << vShape[dimIdx];
2570-
if (!scalableDims.empty() && scalableDims[dimIdx])
2571-
os << ']';
2562+
auto dims = vectorTy.getDims();
2563+
if (!dims.empty()) {
2564+
llvm::interleave(dims, os, "x");
25722565
os << 'x';
25732566
}
25742567
printType(vectorTy.getElementType());

mlir/lib/IR/BuiltinTypes.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,10 +250,10 @@ VectorType VectorType::scaleElementBitwidth(unsigned scale) {
250250
return VectorType();
251251
if (auto et = llvm::dyn_cast<IntegerType>(getElementType()))
252252
if (auto scaledEt = et.scaleElementBitwidth(scale))
253-
return VectorType::get(getShape(), scaledEt, getScalableDims());
253+
return VectorType::get(scaledEt, getDims());
254254
if (auto et = llvm::dyn_cast<FloatType>(getElementType()))
255255
if (auto scaledEt = et.scaleElementBitwidth(scale))
256-
return VectorType::get(getShape(), scaledEt, getScalableDims());
256+
return VectorType::get(scaledEt, getDims());
257257
return VectorType();
258258
}
259259

0 commit comments

Comments
 (0)