diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 601a65333d026..14cff4ff893b5 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -241,16 +241,20 @@ void populateVectorStepLoweringPatterns(RewritePatternSet &patterns,
 
 /// Populate the pattern set with the following patterns:
 ///
-/// [FlattenGather]
-/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
+/// [UnrollGather]
+/// Unrolls 2 or more dimensional `vector.gather` ops by unrolling the
 /// outermost dimension.
+void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
+                                          PatternBenefit benefit = 1);
+
+/// Populate the pattern set with the following patterns:
 ///
 /// [Gather1DToConditionalLoads]
 /// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or
 /// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
 /// loads/extracts are made conditional using `scf.if` ops.
-void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
-                                          PatternBenefit benefit = 1);
+void populateVectorGatherToConditionalLoadPatterns(RewritePatternSet &patterns,
+                                                   PatternBenefit benefit = 1);
 
 /// Populates instances of `MaskOpRewritePattern` to lower masked operations
 /// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 94efec61a466c..357152eba8003 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -269,49 +269,30 @@ class VectorGatherOpConversion
     if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
       return failure();
 
-    auto loc = gather->getLoc();
+    VectorType vType = gather.getVectorType();
+    if (vType.getRank() > 1)
+      return failure();
+
+    Location loc = gather->getLoc();
 
     // Resolve alignment.
     unsigned align;
     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
       return failure();
 
+    // Resolve address.
     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
     Value base = adaptor.getBase();
+    Value ptrs =
+        getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
+                       base, ptr, adaptor.getIndexVec(), vType);
 
-    auto llvmNDVectorTy = adaptor.getIndexVec().getType();
-    // Handle the simple case of 1-D vector.
-    if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
-      auto vType = gather.getVectorType();
-      // Resolve address.
-      Value ptrs =
-          getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
-                         base, ptr, adaptor.getIndexVec(), vType);
-      // Replace with the gather intrinsic.
-      rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
-          gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
-          adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
-      return success();
-    }
-
-    const LLVMTypeConverter &typeConverter = *this->getTypeConverter();
-    auto callback = [align, memRefType, base, ptr, loc, &rewriter,
-                     &typeConverter](Type llvm1DVectorTy,
-                                     ValueRange vectorOperands) {
-      // Resolve address.
-      Value ptrs = getIndexedPtrs(
-          rewriter, loc, typeConverter, memRefType, base, ptr,
-          /*index=*/vectorOperands[0], cast<VectorType>(llvm1DVectorTy));
-      // Create the gather intrinsic.
-      return rewriter.create<LLVM::masked_gather>(
-          loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
-          /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
-    };
-    SmallVector<Value> vectorOperands = {
-        adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
-    return LLVM::detail::handleMultidimensionalVectors(
-        gather, vectorOperands, *getTypeConverter(), callback, rewriter);
+    // Replace with the gather intrinsic.
+    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
+        gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
+        adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
+    return success();
   }
 };
 
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index eb1555df5d574..7082b92c95d1d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -81,6 +81,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorInsertExtractStridedSliceTransforms(patterns);
     populateVectorStepLoweringPatterns(patterns);
     populateVectorRankReducingFMAPattern(patterns);
+    populateVectorGatherLoweringPatterns(patterns);
     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
   }
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 3b38505becd18..3000204c8ce17 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -38,7 +38,7 @@ using namespace mlir;
 using namespace mlir::vector;
 
 namespace {
-/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
+/// Unrolls 2 or more dimensional `vector.gather` ops by unrolling the
 /// outermost dimension. For example:
 /// ```
 /// %g = vector.gather %base[%c0][%v], %mask, %pass_thru :
@@ -56,14 +56,14 @@ namespace {
 /// When applied exhaustively, this will produce a sequence of 1-d gather ops.
 ///
 /// Supports vector types with a fixed leading dimension.
-struct FlattenGather : OpRewritePattern<vector::GatherOp> {
+struct UnrollGather : OpRewritePattern<vector::GatherOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::GatherOp op,
                                 PatternRewriter &rewriter) const override {
     VectorType resultTy = op.getType();
     if (resultTy.getRank() < 2)
-      return rewriter.notifyMatchFailure(op, "already flat");
+      return rewriter.notifyMatchFailure(op, "already 1-D");
 
     // Unrolling doesn't take vscale into account. Pattern is disabled for
     // vectors with leading scalable dim(s).
@@ -107,7 +107,8 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
 /// ```mlir
 /// %subview = memref.subview %M (...)
 ///   : memref<100x3xf32> to memref<100xf32, strided<[3]>>
-/// %gather = vector.gather %subview[%idxs] (...) : memref<100xf32, strided<[3]>>
+/// %gather = vector.gather %subview[%idxs] (...)
+///   : memref<100xf32, strided<[3]>>
 /// ```
 /// ==>
 /// ```mlir
@@ -269,6 +270,11 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
 
 void mlir::vector::populateVectorGatherLoweringPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<FlattenGather, Gather1DToConditionalLoads>(patterns.getContext(),
-                                                          benefit);
+  patterns.add<UnrollGather>(patterns.getContext(), benefit);
+}
+
+void mlir::vector::populateVectorGatherToConditionalLoadPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<Gather1DToConditionalLoads>(
+      patterns.getContext(), benefit);
 }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index c3f06dd4d5dd1..44b4a25a051f1 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -2074,52 +2074,6 @@ func.func @gather_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[3]xindex
 
 // -----
 
-func.func @gather_2d_from_1d(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xi1>, %arg3: vector<2x3xf32>) -> vector<2x3xf32> {
-  %0 = arith.constant 0: index
-  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
-  return %1 : vector<2x3xf32>
-}
-
-// CHECK-LABEL: func @gather_2d_from_1d
-// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi32>>
-// CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi1>>
-// CHECK: %[[S0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
-// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G0]], %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[I1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi32>>
-// CHECK: %[[M1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi1>>
-// CHECK: %[[S1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
-// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G1]], %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
-
-// -----
-
-func.func @gather_2d_from_1d_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi32>, %arg2: vector<2x[3]xi1>, %arg3: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
-  %0 = arith.constant 0: index
-  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32> into vector<2x[3]xf32>
-  return %1 : vector<2x[3]xf32>
-}
-
-// CHECK-LABEL: func @gather_2d_from_1d_scalable
-// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xi32>>
-// CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xi1>>
-// CHECK: %[[S0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xf32>>
-// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
-// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G0]], %{{.*}}[0] : !llvm.array<2 x vector<[3]xf32>>
-// CHECK: %[[I1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xi32>>
-// CHECK: %[[M1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xi1>>
-// CHECK: %[[S1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xf32>>
-// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
-// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G1]], %{{.*}}[1] : !llvm.array<2 x vector<[3]xf32>>
-
-// -----
-
 func.func @gather_1d_from_2d(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) -> vector<4xf32> {
   %0 = arith.constant 3 : index
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 1ab28b9df2d19..5404fdda033ee 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1663,7 +1663,7 @@ func.func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
 
 func.func @gather_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
   %0 = arith.constant 0: index
-  %1 = vector.constant_mask [1, 2] : vector<2x3xi1>
+  %1 = vector.constant_mask [2, 2] : vector<2x3xi1>
   %2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
   return %2 : vector<2x3xf32>
 }
@@ -1677,9 +1677,9 @@ func.func @gather_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2:
 func.func @gather_with_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi32>, %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
   %0 = arith.constant 0: index
   // vector.constant_mask only supports 'none set' or 'all set' scalable
-  // dimensions, hence [1, 3] rather than [1, 2] as in the example for fixed
+  // dimensions, hence [2, 3] rather than [2, 2] as in the example for fixed
   // width vectors above.
-  %1 = vector.constant_mask [1, 3] : vector<2x[3]xi1>
+  %1 = vector.constant_mask [2, 3] : vector<2x[3]xi1>
   %2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32> into vector<2x[3]xf32>
   return %2 : vector<2x[3]xf32>
 }
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 74838bc0ca2fb..2cf1dde9bd1b8 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -782,6 +782,7 @@ struct TestVectorGatherLowering
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
     populateVectorGatherLoweringPatterns(patterns);
+    populateVectorGatherToConditionalLoadPatterns(patterns);
     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
   }
 };
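
Note (not part of the patch): with this change, VectorGatherOpConversion returns failure() for any vector.gather of rank greater than 1 instead of expanding it through LLVM::detail::handleMultidimensionalVectors. The n-D cases are expected to be unrolled beforehand: ConvertVectorToLLVMPass now runs populateVectorGatherLoweringPatterns in its greedy pre-lowering stage (the ConvertVectorToLLVMPass.cpp hunk), so only 1-D gathers should reach the LLVM conversion. This is also why the two 2-D gather tests are deleted from vector-to-llvm-interface.mlir and the equivalent coverage moves behind the full pass pipeline in vector-to-llvm.mlir.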
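
A second consequence is that populateVectorGatherLoweringPatterns alone no longer scalarizes 1-D gathers; downstream pipelines that want the old end-to-end lowering must also call the new populateVectorGatherToConditionalLoadPatterns, exactly as the TestVectorTransforms.cpp hunk above does. A minimal sketch of such a pipeline stage follows; the helper name runGatherLowering is illustrative and not part of the patch:

#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical downstream helper: recovers the pre-patch combined behavior.
static void runGatherLowering(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  // Step 1: UnrollGather peels the outermost dimension until all
  // vector.gather ops are 1-D.
  mlir::vector::populateVectorGatherLoweringPatterns(patterns);
  // Step 2: Gather1DToConditionalLoads scalarizes the remaining 1-D gathers
  // into scf.if-guarded vector.load / tensor.extract sequences.
  mlir::vector::populateVectorGatherToConditionalLoadPatterns(patterns);
  (void)mlir::applyPatternsGreedily(root, std::move(patterns));
}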