Skip to content

Commit cf0efb3

Browse files
authored
[mlir][vector] Decouple unrolling gather and gather to llvm lowering (#132206)
This patch decouples unrolling vector.gather and lowering vector.gather to llvm.masked.gather. This is consistent with how vector.load, vector.store, vector.maskedload, vector.maskedstore lower to LLVM. Some interesting test changes from this patch: - 2D vector.gather lowering to llvm tests are deleted. This is consistent with other memory load/store ops. - There are still tests for 2D vector.gather, but the constant mask for these tests is modified. This is because with the updated lowering, one of the unrolled vector.gather ops disappears because it is masked off (also demonstrating why this is a better lowering path). Overall, this makes vector.gather take the same consistent path for lowering to LLVM as other load/store ops. Discourse Discussion: https://discourse.llvm.org/t/rfc-improving-gather-codegen-for-vector-dialect/85011/13
1 parent 94783a8 commit cf0efb3

File tree

7 files changed

+39
-92
lines changed

7 files changed

+39
-92
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -241,16 +241,20 @@ void populateVectorStepLoweringPatterns(RewritePatternSet &patterns,
241241

242242
/// Populate the pattern set with the following patterns:
243243
///
244-
/// [FlattenGather]
245-
/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
244+
/// [UnrollGather]
245+
/// Unrolls 2 or more dimensional `vector.gather` ops by unrolling the
246246
/// outermost dimension.
247+
void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
248+
PatternBenefit benefit = 1);
249+
250+
/// Populate the pattern set with the following patterns:
247251
///
248252
/// [Gather1DToConditionalLoads]
249253
/// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or
250254
/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
251255
/// loads/extracts are made conditional using `scf.if` ops.
252-
void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
253-
PatternBenefit benefit = 1);
256+
void populateVectorGatherToConditionalLoadPatterns(RewritePatternSet &patterns,
257+
PatternBenefit benefit = 1);
254258

255259
/// Populates instances of `MaskOpRewritePattern` to lower masked operations
256260
/// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 14 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -269,49 +269,30 @@ class VectorGatherOpConversion
269269
if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
270270
return failure();
271271

272-
auto loc = gather->getLoc();
272+
VectorType vType = gather.getVectorType();
273+
if (vType.getRank() > 1)
274+
return failure();
275+
276+
Location loc = gather->getLoc();
273277

274278
// Resolve alignment.
275279
unsigned align;
276280
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
277281
return failure();
278282

283+
// Resolve address.
279284
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
280285
adaptor.getIndices(), rewriter);
281286
Value base = adaptor.getBase();
287+
Value ptrs =
288+
getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
289+
base, ptr, adaptor.getIndexVec(), vType);
282290

283-
auto llvmNDVectorTy = adaptor.getIndexVec().getType();
284-
// Handle the simple case of 1-D vector.
285-
if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
286-
auto vType = gather.getVectorType();
287-
// Resolve address.
288-
Value ptrs =
289-
getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
290-
base, ptr, adaptor.getIndexVec(), vType);
291-
// Replace with the gather intrinsic.
292-
rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
293-
gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
294-
adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
295-
return success();
296-
}
297-
298-
const LLVMTypeConverter &typeConverter = *this->getTypeConverter();
299-
auto callback = [align, memRefType, base, ptr, loc, &rewriter,
300-
&typeConverter](Type llvm1DVectorTy,
301-
ValueRange vectorOperands) {
302-
// Resolve address.
303-
Value ptrs = getIndexedPtrs(
304-
rewriter, loc, typeConverter, memRefType, base, ptr,
305-
/*index=*/vectorOperands[0], cast<VectorType>(llvm1DVectorTy));
306-
// Create the gather intrinsic.
307-
return rewriter.create<LLVM::masked_gather>(
308-
loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
309-
/*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
310-
};
311-
SmallVector<Value> vectorOperands = {
312-
adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
313-
return LLVM::detail::handleMultidimensionalVectors(
314-
gather, vectorOperands, *getTypeConverter(), callback, rewriter);
291+
// Replace with the gather intrinsic.
292+
rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
293+
gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
294+
adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
295+
return success();
315296
}
316297
};
317298

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
8181
populateVectorInsertExtractStridedSliceTransforms(patterns);
8282
populateVectorStepLoweringPatterns(patterns);
8383
populateVectorRankReducingFMAPattern(patterns);
84+
populateVectorGatherLoweringPatterns(patterns);
8485
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
8586
}
8687

mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ using namespace mlir;
3838
using namespace mlir::vector;
3939

4040
namespace {
41-
/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
41+
/// Unrolls 2 or more dimensional `vector.gather` ops by unrolling the
4242
/// outermost dimension. For example:
4343
/// ```
4444
/// %g = vector.gather %base[%c0][%v], %mask, %pass_thru :
@@ -56,14 +56,14 @@ namespace {
5656
/// When applied exhaustively, this will produce a sequence of 1-d gather ops.
5757
///
5858
/// Supports vector types with a fixed leading dimension.
59-
struct FlattenGather : OpRewritePattern<vector::GatherOp> {
59+
struct UnrollGather : OpRewritePattern<vector::GatherOp> {
6060
using OpRewritePattern::OpRewritePattern;
6161

6262
LogicalResult matchAndRewrite(vector::GatherOp op,
6363
PatternRewriter &rewriter) const override {
6464
VectorType resultTy = op.getType();
6565
if (resultTy.getRank() < 2)
66-
return rewriter.notifyMatchFailure(op, "already flat");
66+
return rewriter.notifyMatchFailure(op, "already 1-D");
6767

6868
// Unrolling doesn't take vscale into account. Pattern is disabled for
6969
// vectors with leading scalable dim(s).
@@ -107,7 +107,8 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
107107
/// ```mlir
108108
/// %subview = memref.subview %M (...)
109109
/// : memref<100x3xf32> to memref<100xf32, strided<[3]>>
110-
/// %gather = vector.gather %subview[%idxs] (...) : memref<100xf32, strided<[3]>>
110+
/// %gather = vector.gather %subview[%idxs] (...)
111+
/// : memref<100xf32, strided<[3]>>
111112
/// ```
112113
/// ==>
113114
/// ```mlir
@@ -269,6 +270,11 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
269270

270271
void mlir::vector::populateVectorGatherLoweringPatterns(
271272
RewritePatternSet &patterns, PatternBenefit benefit) {
272-
patterns.add<FlattenGather, RemoveStrideFromGatherSource,
273-
Gather1DToConditionalLoads>(patterns.getContext(), benefit);
273+
patterns.add<UnrollGather>(patterns.getContext(), benefit);
274+
}
275+
276+
void mlir::vector::populateVectorGatherToConditionalLoadPatterns(
277+
RewritePatternSet &patterns, PatternBenefit benefit) {
278+
patterns.add<RemoveStrideFromGatherSource, Gather1DToConditionalLoads>(
279+
patterns.getContext(), benefit);
274280
}

mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir

Lines changed: 0 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -2074,52 +2074,6 @@ func.func @gather_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[3]xindex
20742074

20752075
// -----
20762076

2077-
func.func @gather_2d_from_1d(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xi1>, %arg3: vector<2x3xf32>) -> vector<2x3xf32> {
2078-
%0 = arith.constant 0: index
2079-
%1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
2080-
return %1 : vector<2x3xf32>
2081-
}
2082-
2083-
// CHECK-LABEL: func @gather_2d_from_1d
2084-
// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
2085-
// CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi32>>
2086-
// CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi1>>
2087-
// CHECK: %[[S0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
2088-
// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
2089-
// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
2090-
// CHECK: %{{.*}} = llvm.insertvalue %[[G0]], %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
2091-
// CHECK: %[[I1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi32>>
2092-
// CHECK: %[[M1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi1>>
2093-
// CHECK: %[[S1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
2094-
// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
2095-
// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
2096-
// CHECK: %{{.*}} = llvm.insertvalue %[[G1]], %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
2097-
2098-
// -----
2099-
2100-
func.func @gather_2d_from_1d_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi32>, %arg2: vector<2x[3]xi1>, %arg3: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
2101-
%0 = arith.constant 0: index
2102-
%1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32> into vector<2x[3]xf32>
2103-
return %1 : vector<2x[3]xf32>
2104-
}
2105-
2106-
// CHECK-LABEL: func @gather_2d_from_1d_scalable
2107-
// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
2108-
// CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xi32>>
2109-
// CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xi1>>
2110-
// CHECK: %[[S0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xf32>>
2111-
// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
2112-
// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
2113-
// CHECK: %{{.*}} = llvm.insertvalue %[[G0]], %{{.*}}[0] : !llvm.array<2 x vector<[3]xf32>>
2114-
// CHECK: %[[I1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xi32>>
2115-
// CHECK: %[[M1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xi1>>
2116-
// CHECK: %[[S1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xf32>>
2117-
// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
2118-
// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
2119-
// CHECK: %{{.*}} = llvm.insertvalue %[[G1]], %{{.*}}[1] : !llvm.array<2 x vector<[3]xf32>>
2120-
2121-
// -----
2122-
21232077

21242078
func.func @gather_1d_from_2d(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) -> vector<4xf32> {
21252079
%0 = arith.constant 3 : index

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1663,7 +1663,7 @@ func.func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
16631663

16641664
func.func @gather_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
16651665
%0 = arith.constant 0: index
1666-
%1 = vector.constant_mask [1, 2] : vector<2x3xi1>
1666+
%1 = vector.constant_mask [2, 2] : vector<2x3xi1>
16671667
%2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
16681668
return %2 : vector<2x3xf32>
16691669
}
@@ -1677,9 +1677,9 @@ func.func @gather_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2:
16771677
func.func @gather_with_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi32>, %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
16781678
%0 = arith.constant 0: index
16791679
// vector.constant_mask only supports 'none set' or 'all set' scalable
1680-
// dimensions, hence [1, 3] rather than [1, 2] as in the example for fixed
1680+
// dimensions, hence [2, 3] rather than [2, 2] as in the example for fixed
16811681
// width vectors above.
1682-
%1 = vector.constant_mask [1, 3] : vector<2x[3]xi1>
1682+
%1 = vector.constant_mask [2, 3] : vector<2x[3]xi1>
16831683
%2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32> into vector<2x[3]xf32>
16841684
return %2 : vector<2x[3]xf32>
16851685
}

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -782,6 +782,7 @@ struct TestVectorGatherLowering
782782
void runOnOperation() override {
783783
RewritePatternSet patterns(&getContext());
784784
populateVectorGatherLoweringPatterns(patterns);
785+
populateVectorGatherToConditionalLoadPatterns(patterns);
785786
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
786787
}
787788
};

0 commit comments

Comments
 (0)