Skip to content

Commit 2d7fa91

Browse files
committed
Address comments
1 parent 789d2d4 commit 2d7fa91

File tree

3 files changed

+23
-16
lines changed

3 files changed

+23
-16
lines changed

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -263,22 +263,25 @@ class VectorGatherOpConversion
263263
LogicalResult
264264
matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
265265
ConversionPatternRewriter &rewriter) const override {
266+
Location loc = gather->getLoc();
266267
MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
267268
assert(memRefType && "The base should be bufferized");
268269

269270
if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
270-
return failure();
271+
return rewriter.notifyMatchFailure(gather, "memref type not supported");
271272

272273
VectorType vType = gather.getVectorType();
273-
if (vType.getRank() > 1)
274-
return failure();
275-
276-
Location loc = gather->getLoc();
274+
if (vType.getRank() > 1) {
275+
return rewriter.notifyMatchFailure(
276+
gather, "only 1-D vectors can be lowered to LLVM");
277+
}
277278

278279
// Resolve alignment.
279280
unsigned align;
280-
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
281-
return failure();
281+
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) {
282+
return rewriter.notifyMatchFailure(gather,
283+
"could not resolve memref alignment");
284+
}
282285

283286
// Resolve address.
284287
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
@@ -309,16 +312,20 @@ class VectorScatterOpConversion
309312
MemRefType memRefType = scatter.getMemRefType();
310313

311314
if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
312-
return failure();
315+
return rewriter.notifyMatchFailure(scatter, "memref type not supported");
313316

314317
VectorType vType = scatter.getVectorType();
315-
if (vType.getRank() > 1)
316-
return failure();
318+
if (vType.getRank() > 1) {
319+
return rewriter.notifyMatchFailure(
320+
scatter, "only 1-D vectors can be lowered to LLVM");
321+
}
317322

318323
// Resolve alignment.
319324
unsigned align;
320-
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
321-
return failure();
325+
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) {
326+
return rewriter.notifyMatchFailure(scatter,
327+
"could not resolve memref alignment");
328+
}
322329

323330
// Resolve address.
324331
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1654,7 +1654,7 @@ func.func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
16541654

16551655
//===----------------------------------------------------------------------===//
16561656
// vector.gather
1657-
//
1657+
//
16581658
// NOTE: vector.constant_mask won't lower with
16591659
// * --convert-to-llvm="filter-dialects=vector",
16601660
// hence testing here.
@@ -1741,7 +1741,7 @@ func.func @scatter_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2
17411741
func.func @scatter_with_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi32>, %arg2: vector<2x[3]xf32>) {
17421742
%0 = arith.constant 0: index
17431743
// vector.constant_mask only supports 'none set' or 'all set' scalable
1744-
// dimensions, hence [1, 3] rather than [1, 2] as in the example for fixed
1744+
// dimensions, hence [2, 3] rather than [2, 2] as in the example for fixed
17451745
// width vectors above.
17461746
%1 = vector.constant_mask [2, 3] : vector<2x[3]xi1>
17471747
vector.scatter %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32>

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -882,8 +882,8 @@ func.func @gather_and_scatter2d(%base: memref<?x?xf32>, %v: vector<16xi32>, %mas
882882
return
883883
}
884884

885-
// CHECK-LABEL: @gather_multi_dims
886-
func.func @gather_multi_dims(%base: memref<?xf32>, %v: vector<2x16xi32>, %mask: vector<2x16xi1>, %pass_thru: vector<2x16xf32>) -> vector<2x16xf32> {
885+
// CHECK-LABEL: @gather_and_scatter_multi_dims
886+
func.func @gather_and_scatter_multi_dims(%base: memref<?xf32>, %v: vector<2x16xi32>, %mask: vector<2x16xi1>, %pass_thru: vector<2x16xf32>) -> vector<2x16xf32> {
887887
%c0 = arith.constant 0 : index
888888
// CHECK: %[[X:.*]] = vector.gather %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %{{.*}} : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
889889
%0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>

0 commit comments

Comments
 (0)