@@ -422,23 +422,21 @@ MultiDimReductionOp::getShapeForUnroll() {
422
422
}
423
423
424
424
LogicalResult MultiDimReductionOp::verify () {
425
- SmallVector<int64_t > targetShape;
426
- SmallVector<bool > scalableDims;
425
+ SmallVector<VectorDim> targetDims;
427
426
Type inferredReturnType;
428
- auto sourceScalableDims = getSourceVectorType ().getScalableDims ();
429
- for (auto it : llvm::enumerate (getSourceVectorType ().getShape ()))
430
- if (!llvm::any_of (getReductionDims ().getValue (), [&](Attribute attr) {
431
- return llvm::cast<IntegerAttr>(attr).getValue () == it.index ();
432
- })) {
433
- targetShape.push_back (it.value ());
434
- scalableDims.push_back (sourceScalableDims[it.index ()]);
427
+ for (auto [idx, dim] : llvm::enumerate (getSourceVectorType ().getDims ()))
428
+ if (!llvm::any_of (getReductionDims ().getValue (),
429
+ [idx = idx](Attribute attr) {
430
+ return llvm::cast<IntegerAttr>(attr).getValue () == idx;
431
+ })) {
432
+ targetDims.push_back (dim);
435
433
}
436
434
// TODO: update to also allow 0-d vectors when available.
437
- if (targetShape .empty ())
435
+ if (targetDims .empty ())
438
436
inferredReturnType = getSourceVectorType ().getElementType ();
439
437
else
440
- inferredReturnType = VectorType::get (
441
- targetShape, getSourceVectorType ().getElementType (), scalableDims );
438
+ inferredReturnType =
439
+ VectorType::get ( getSourceVectorType ().getElementType (), targetDims );
442
440
if (getType () != inferredReturnType)
443
441
return emitOpError () << " destination type " << getType ()
444
442
<< " is incompatible with source type "
@@ -450,9 +448,8 @@ LogicalResult MultiDimReductionOp::verify() {
450
448
// / Returns the mask type expected by this operation.
451
449
Type MultiDimReductionOp::getExpectedMaskType () {
452
450
auto vecType = getSourceVectorType ();
453
- return VectorType::get (vecType.getShape (),
454
- IntegerType::get (vecType.getContext (), /* width=*/ 1 ),
455
- vecType.getScalableDims ());
451
+ return VectorType::get (IntegerType::get (vecType.getContext (), /* width=*/ 1 ),
452
+ vecType.getDims ());
456
453
}
457
454
458
455
namespace {
@@ -491,8 +488,7 @@ struct ElideUnitDimsInMultiDimReduction
491
488
if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType ())) {
492
489
if (mask) {
493
490
VectorType newMaskType =
494
- VectorType::get (dstVecType.getShape (), rewriter.getI1Type (),
495
- dstVecType.getScalableDims ());
491
+ VectorType::get (rewriter.getI1Type (), dstVecType.getDims ());
496
492
mask = rewriter.create <vector::ShapeCastOp>(loc, newMaskType, mask);
497
493
}
498
494
cast = rewriter.create <vector::ShapeCastOp>(
@@ -559,9 +555,8 @@ LogicalResult ReductionOp::verify() {
559
555
// / Returns the mask type expected by this operation.
560
556
Type ReductionOp::getExpectedMaskType () {
561
557
auto vecType = getSourceVectorType ();
562
- return VectorType::get (vecType.getShape (),
563
- IntegerType::get (vecType.getContext (), /* width=*/ 1 ),
564
- vecType.getScalableDims ());
558
+ return VectorType::get (IntegerType::get (vecType.getContext (), /* width=*/ 1 ),
559
+ vecType.getDims ());
565
560
}
566
561
567
562
Value mlir::vector::getVectorReductionOp (arith::AtomicRMWKind op,
@@ -1252,8 +1247,7 @@ ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
1252
1247
auto n = std::min<size_t >(adaptor.getStaticPosition ().size (),
1253
1248
vectorType.getRank ());
1254
1249
inferredReturnTypes.push_back (VectorType::get (
1255
- vectorType.getShape ().drop_front (n), vectorType.getElementType (),
1256
- vectorType.getScalableDims ().drop_front (n)));
1250
+ vectorType.getElementType (), vectorType.getDims ().dropFront (n)));
1257
1251
}
1258
1252
return success ();
1259
1253
}
@@ -3040,15 +3034,11 @@ ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
3040
3034
3041
3035
VectorType resType;
3042
3036
if (vRHS) {
3043
- SmallVector<bool > scalableDimsRes{vLHS.getScalableDims ()[0 ],
3044
- vRHS.getScalableDims ()[0 ]};
3045
- resType = VectorType::get ({vLHS.getDimSize (0 ), vRHS.getDimSize (0 )},
3046
- vLHS.getElementType (), scalableDimsRes);
3037
+ resType = VectorType::get (vLHS.getElementType (),
3038
+ {vLHS.getDim (0 ), vRHS.getDim (0 )});
3047
3039
} else {
3048
3040
// Scalar RHS operand
3049
- SmallVector<bool > scalableDimsRes{vLHS.getScalableDims ()[0 ]};
3050
- resType = VectorType::get ({vLHS.getDimSize (0 )}, vLHS.getElementType (),
3051
- scalableDimsRes);
3041
+ resType = VectorType::get (vLHS.getElementType (), {vLHS.getDim (0 )});
3052
3042
}
3053
3043
3054
3044
if (!result.attributes .get (OuterProductOp::getKindAttrName (result.name ))) {
@@ -3115,9 +3105,8 @@ LogicalResult OuterProductOp::verify() {
3115
3105
// / verification purposes. It requires the operation to be vectorized."
3116
3106
Type OuterProductOp::getExpectedMaskType () {
3117
3107
auto vecType = this ->getResultVectorType ();
3118
- return VectorType::get (vecType.getShape (),
3119
- IntegerType::get (vecType.getContext (), /* width=*/ 1 ),
3120
- vecType.getScalableDims ());
3108
+ return VectorType::get (IntegerType::get (vecType.getContext (), /* width=*/ 1 ),
3109
+ vecType.getDims ());
3121
3110
}
3122
3111
3123
3112
// ===----------------------------------------------------------------------===//
@@ -5064,25 +5053,13 @@ class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
5064
5053
// / vector<4x1x1xi1> --> vector<4x1>
5065
5054
// /
5066
5055
static VectorType trimTrailingOneDims (VectorType oldType) {
5067
- ArrayRef<int64_t > oldShape = oldType.getShape ();
5068
- ArrayRef<int64_t > newShape = oldShape;
5069
-
5070
- ArrayRef<bool > oldScalableDims = oldType.getScalableDims ();
5071
- ArrayRef<bool > newScalableDims = oldScalableDims;
5072
-
5073
- while (!newShape.empty () && newShape.back () == 1 && !newScalableDims.back ()) {
5074
- newShape = newShape.drop_back (1 );
5075
- newScalableDims = newScalableDims.drop_back (1 );
5076
- }
5077
-
5078
- // Make sure we have at least 1 dimension.
5056
+ // Note: This will always keep at least one dim (even if it's a unit dim).
5079
5057
// TODO: Add support for 0-D vectors.
5080
- if (newShape.empty ()) {
5081
- newShape = oldShape.take_back ();
5082
- newScalableDims = oldScalableDims.take_back ();
5083
- }
5058
+ VectorDims newDims = oldType.getDims ();
5059
+ while (newDims.size () > 1 && newDims.back () == VectorDim::getFixed (1 ))
5060
+ newDims = newDims.dropBack ();
5084
5061
5085
- return VectorType::get (newShape, oldType.getElementType (), newScalableDims );
5062
+ return VectorType::get (oldType.getElementType (), newDims );
5086
5063
}
5087
5064
5088
5065
// / Folds qualifying shape_cast(create_mask) into a new create_mask
0 commit comments