diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp index 5b81d0d33d484..5863b0b7d45ab 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -35,8 +35,8 @@ inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr, /// Extend the rank of a vector Value by `addedRanks` by adding outer unit /// dimensions. -static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec, - int64_t addedRank) { +static TypedValue extendVectorRank(OpBuilder &builder, Location loc, + Value vec, int64_t addedRank) { auto originalVecType = cast(vec.getType()); SmallVector newShape(addedRank, 1); newShape.append(originalVecType.getShape().begin(), @@ -53,16 +53,21 @@ static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec, /// Extend the rank of a vector Value by `addedRanks` by adding inner unit /// dimensions. static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec, - int64_t addedRank) { - Value broadcasted = extendVectorRank(builder, loc, vec, addedRank); - SmallVector permutation; - for (int64_t i = addedRank, - e = cast(broadcasted.getType()).getRank(); - i < e; ++i) - permutation.push_back(i); - for (int64_t i = 0; i < addedRank; ++i) - permutation.push_back(i); - return builder.create(loc, broadcasted, permutation); + ArrayRef missingInnerDims) { + TypedValue broadcasted = + extendVectorRank(builder, loc, vec, missingInnerDims.size()); + SmallVector missing(broadcasted.getType().getRank(), false); + SmallVector inversePerm; + for (int64_t i : missingInnerDims) { + inversePerm.push_back(i); + missing[i] = true; + } + for (auto [i, used] : llvm::enumerate(missing)) { + if (!used) + inversePerm.push_back(i); + } + return builder.create( + loc, broadcasted, invertPermutationVector(inversePerm)); } //===----------------------------------------------------------------------===// @@ -268,28 +273,30 @@ struct TransferWriteNonPermutationLowering for (AffineExpr exp : map.getResults()) foundDim[cast(exp).getPosition()] = true; SmallVector exprs; - bool foundFirstDim = false; + std::optional firstDim = std::nullopt; SmallVector missingInnerDim; for (size_t i = 0; i < foundDim.size(); i++) { if (foundDim[i]) { - foundFirstDim = true; + if (!firstDim) { + firstDim = i; + } continue; } - if (!foundFirstDim) + if (!firstDim) continue; // Once we found one outer dimension existing in the map keep track of all // the missing dimensions after that. - missingInnerDim.push_back(i); + missingInnerDim.push_back(i - firstDim.value()); exprs.push_back(rewriter.getAffineDimExpr(i)); } // Vector: add unit dims at the beginning of the shape. Value newVec = extendVectorRank(rewriter, op.getLoc(), op.getVector(), missingInnerDim.size()); - // Mask: add unit dims at the end of the shape. + // Mask: add unit dims at the positions of the missing dimensions. Value newMask; if (op.getMask()) - newMask = extendMaskRank(rewriter, op.getLoc(), op.getMask(), - missingInnerDim.size()); + newMask = + extendMaskRank(rewriter, op.getLoc(), op.getMask(), missingInnerDim); exprs.append(map.getResults().begin(), map.getResults().end()); AffineMap newMap = AffineMap::get(map.getNumDims(), 0, exprs, op.getContext()); diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir index 3ae18835c8367..07e49d7e12e25 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir @@ -160,6 +160,31 @@ func.func @xfer_write_non_minor_identity_with_mask_out_of_bounds( return } +// CHECK-LABEL: func.func @xfer_write_non_minor_identity_with_mask_broadcast( +// CHECK-SAME: %[[MEM:.*]]: memref, +// CHECK-SAME: %[[VEC:.*]]: vector<7x8xf32>, +// CHECK-SAME: %[[MASK:.*]]: vector<7x8xi1>, +// CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index, %[[IDX_3:.*]]: index) { +// CHECK: %[[BC:.*]] = vector.broadcast %[[VEC]] : vector<7x8xf32> to vector<1x7x8xf32> +// CHECK: %[[MBC:.*]] = vector.broadcast %[[MASK]] : vector<7x8xi1> to vector<1x7x8xi1> +// CHECK: %[[MTR:.*]] = vector.transpose %[[MBC]], [1, 0, 2] : vector<1x7x8xi1> to vector<7x1x8xi1> +// CHECK: %[[TR:.*]] = vector.transpose %[[BC]], [1, 0, 2] : vector<1x7x8xf32> to vector<7x1x8xf32> +// CHECK: vector.transfer_write %[[TR]], %[[MEM]]{{\[}}%[[IDX_1]], %[[IDX_2]], %[[IDX_3]]], %[[MTR]] {in_bounds = [false, true, false]} : vector<7x1x8xf32>, memref +func.func @xfer_write_non_minor_identity_with_mask_broadcast_and_transpose( + %mem : memref, + %vec : vector<7x8xf32>, + %mask : vector<7x8xi1>, + %idx_1 : index, + %idx_2 : index, + %idx_3 : index) { + + vector.transfer_write %vec, %mem[%idx_1, %idx_2, %idx_3], %mask { + permutation_map = affine_map<(d0, d1, d2) -> (d0, d2)> + } : vector<7x8xf32>, memref + + return +} + // CHECK-LABEL: func.func @xfer_write_non_minor_identity_with_mask_scalable( // CHECK-SAME: %[[VEC:.*]]: vector<4x[8]xi16>, // CHECK-SAME: %[[MEM:.*]]: memref<1x4x?x1xi16>,