-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][vector] Add more tests for ConvertVectorToLLVM (2/n) #102203
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Add more tests for ConvertVectorToLLVM (2/n) #102203
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Andrzej Warzyński (banach-space) Changes
Patch is 47.40 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/102203.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index ac55433fadb2f..9f61f7c866d3d 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -68,9 +68,13 @@ enum class BroadcastableToResult {
DimensionMismatch = 2,
SourceTypeNotAVector = 3
};
+struct VectorDim {
+ int64_t dim;
+ bool scalableFlag;
+};
BroadcastableToResult
isBroadcastableTo(Type srcType, VectorType dstVectorType,
- std::pair<int, int> *mismatchingDims = nullptr);
+ std::pair<VectorDim, VectorDim> *mismatchingDims = nullptr);
/// Collect a set of vector-to-vector canonicalization patterns.
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns,
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 434ff3956c250..08bff3d5e1382 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -367,6 +367,8 @@ def Vector_BroadcastOp :
s_1 x .. x s_j x .. x s_k
<duplication> <potential stretch>
```
+ * a scalable unit dimension, `[1]`, must match exactly.
+
The source operand is duplicated over all the missing leading dimensions
and stretched over the trailing dimensions where the source has a non-equal
dimension of 1. These rules imply that any scalar broadcast (k=0) to any
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5047bd925d4c5..673c128932893 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2371,9 +2371,9 @@ Value BroadcastOp::createOrFoldBroadcastOp(
return res;
}
-BroadcastableToResult
-mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
- std::pair<int, int> *mismatchingDims) {
+BroadcastableToResult mlir::vector::isBroadcastableTo(
+ Type srcType, VectorType dstVectorType,
+ std::pair<VectorDim, VectorDim> *mismatchingDims) {
// Broadcast scalar to vector of the same element type.
if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
@@ -2391,12 +2391,28 @@ mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
// (all leading dimensions are simply duplicated).
int64_t lead = dstRank - srcRank;
for (int64_t r = 0; r < srcRank; ++r) {
+ bool mismatch = false;
+
+ // Check fixed-width dims
int64_t srcDim = srcVectorType.getDimSize(r);
int64_t dstDim = dstVectorType.getDimSize(lead + r);
- if (srcDim != 1 && srcDim != dstDim) {
+ if ((srcDim != 1 && srcDim != dstDim))
+ mismatch = true;
+
+ // Check scalable flags
+ bool srcDimScalableFlag = srcVectorType.getScalableDims()[r];
+ bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + r];
+ if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
+ (srcDimScalableFlag && !dstDimScalableFlag))
+ mismatch = true;
+
+ if (mismatch) {
if (mismatchingDims) {
- mismatchingDims->first = srcDim;
- mismatchingDims->second = dstDim;
+ mismatchingDims->first.dim = srcDim;
+ mismatchingDims->first.scalableFlag = srcDimScalableFlag;
+
+ mismatchingDims->second.dim = dstDim;
+ mismatchingDims->second.scalableFlag = dstDimScalableFlag;
}
return BroadcastableToResult::DimensionMismatch;
}
@@ -2406,16 +2422,25 @@ mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
}
LogicalResult BroadcastOp::verify() {
- std::pair<int, int> mismatchingDims;
+ std::pair<VectorDim, VectorDim> mismatchingDims;
BroadcastableToResult res = isBroadcastableTo(
getSourceType(), getResultVectorType(), &mismatchingDims);
if (res == BroadcastableToResult::Success)
return success();
if (res == BroadcastableToResult::SourceRankHigher)
return emitOpError("source rank higher than destination rank");
- if (res == BroadcastableToResult::DimensionMismatch)
- return emitOpError("dimension mismatch (")
- << mismatchingDims.first << " vs. " << mismatchingDims.second << ")";
+ if (res == BroadcastableToResult::DimensionMismatch) {
+ std::string msg =
+ (Twine("dimension mismatch (") +
+ (mismatchingDims.first.scalableFlag ? "[" : "") +
+ std::to_string(mismatchingDims.first.dim) +
+ (mismatchingDims.first.scalableFlag ? "]" : "") + " vs. " +
+ (mismatchingDims.second.scalableFlag ? "[" : "") +
+ std::to_string(mismatchingDims.second.dim) +
+ (mismatchingDims.second.scalableFlag ? "]" : "") + ")")
+ .str();
+ return emitOpError(msg);
+ }
if (res == BroadcastableToResult::SourceTypeNotAVector)
return emitOpError("source type is not a vector");
llvm_unreachable("unexpected vector.broadcast op error");
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index 32e7eb27f5e29..6c36bbaee8523 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -125,7 +125,8 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
// ..
// %x = [%a,%b,%c,%d]
VectorType resType =
- VectorType::get(dstType.getShape().drop_front(), eltType);
+ VectorType::get(dstType.getShape().drop_front(), eltType,
+ dstType.getScalableDims().drop_front());
Value result = rewriter.create<arith::ConstantOp>(
loc, dstType, rewriter.getZeroAttr(dstType));
if (m == 0) {
@@ -136,6 +137,10 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
} else {
// Stretch not at start.
+ if (dstType.getScalableDims()[0]) {
+ // TODO: For scalable vectors we should emit an scf.for loop.
+ return failure();
+ }
for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), d);
Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index c310954b906e4..1034baedb0cd0 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -23,6 +23,15 @@ func.func @bitcast_f32_to_i32_vector(%input: vector<16xf32>) -> vector<16xi32> {
// CHECK-SAME: %[[input:.*]]: vector<16xf32>
// CHECK: llvm.bitcast %[[input]] : vector<16xf32> to vector<16xi32>
+func.func @bitcast_f32_to_i32_vector_scalable(%input: vector<[16]xf32>) -> vector<[16]xi32> {
+ %0 = vector.bitcast %input : vector<[16]xf32> to vector<[16]xi32>
+ return %0 : vector<[16]xi32>
+}
+
+// CHECK-LABEL: @bitcast_f32_to_i32_vector_scalable
+// CHECK-SAME: %[[input:.*]]: vector<[16]xf32>
+// CHECK: llvm.bitcast %[[input]] : vector<[16]xf32> to vector<[16]xi32>
+
// -----
func.func @bitcast_i8_to_f32_vector(%input: vector<64xi8>) -> vector<16xf32> {
@@ -34,6 +43,15 @@ func.func @bitcast_i8_to_f32_vector(%input: vector<64xi8>) -> vector<16xf32> {
// CHECK-SAME: %[[input:.*]]: vector<64xi8>
// CHECK: llvm.bitcast %[[input]] : vector<64xi8> to vector<16xf32>
+func.func @bitcast_i8_to_f32_vector_scalable(%input: vector<[64]xi8>) -> vector<[16]xf32> {
+ %0 = vector.bitcast %input : vector<[64]xi8> to vector<[16]xf32>
+ return %0 : vector<[16]xf32>
+}
+
+// CHECK-LABEL: @bitcast_i8_to_f32_vector_scalable
+// CHECK-SAME: %[[input:.*]]: vector<[64]xi8>
+// CHECK: llvm.bitcast %[[input]] : vector<[64]xi8> to vector<[16]xf32>
+
// -----
func.func @bitcast_index_to_i8_vector(%input: vector<16xindex>) -> vector<128xi8> {
@@ -46,6 +64,16 @@ func.func @bitcast_index_to_i8_vector(%input: vector<16xindex>) -> vector<128xi8
// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[input]] : vector<16xindex> to vector<16xi64>
// CHECK: llvm.bitcast %[[T0]] : vector<16xi64> to vector<128xi8>
+func.func @bitcast_index_to_i8_vector_scalable(%input: vector<[16]xindex>) -> vector<[128]xi8> {
+ %0 = vector.bitcast %input : vector<[16]xindex> to vector<[128]xi8>
+ return %0 : vector<[128]xi8>
+}
+
+// CHECK-LABEL: @bitcast_index_to_i8_vector_scalable
+// CHECK-SAME: %[[input:.*]]: vector<[16]xindex>
+// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[input]] : vector<[16]xindex> to vector<[16]xi64>
+// CHECK: llvm.bitcast %[[T0]] : vector<[16]xi64> to vector<[128]xi8>
+
// -----
func.func @broadcast_vec0d_from_f32(%arg0: f32) -> vector<f32> {
@@ -80,6 +108,17 @@ func.func @broadcast_vec1d_from_f32(%arg0: f32) -> vector<2xf32> {
// CHECK: %[[T1:.*]] = llvm.shufflevector %[[T0]]
// CHECK: return %[[T1]] : vector<2xf32>
+
+func.func @broadcast_vec1d_from_f32_scalable(%arg0: f32) -> vector<[2]xf32> {
+ %0 = vector.broadcast %arg0 : f32 to vector<[2]xf32>
+ return %0 : vector<[2]xf32>
+}
+// CHECK-LABEL: @broadcast_vec1d_from_f32_scalable
+// CHECK-SAME: %[[A:.*]]: f32)
+// CHECK: %[[T0:.*]] = llvm.insertelement %[[A]]
+// CHECK: %[[T1:.*]] = llvm.shufflevector %[[T0]]
+// CHECK: return %[[T1]] : vector<[2]xf32>
+
// -----
func.func @broadcast_vec1d_from_index(%arg0: index) -> vector<2xindex> {
@@ -94,6 +133,18 @@ func.func @broadcast_vec1d_from_index(%arg0: index) -> vector<2xindex> {
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<2xi64> to vector<2xindex>
// CHECK: return %[[T2]] : vector<2xindex>
+func.func @broadcast_vec1d_from_index_scalable(%arg0: index) -> vector<[2]xindex> {
+ %0 = vector.broadcast %arg0 : index to vector<[2]xindex>
+ return %0 : vector<[2]xindex>
+}
+// CHECK-LABEL: @broadcast_vec1d_from_index_scalable
+// CHECK-SAME: %[[A:.*]]: index)
+// CHECK: %[[A1:.*]] = builtin.unrealized_conversion_cast %[[A]] : index to i64
+// CHECK: %[[T0:.*]] = llvm.insertelement %[[A1]]
+// CHECK: %[[T1:.*]] = llvm.shufflevector %[[T0]]
+// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<[2]xi64> to vector<[2]xindex>
+// CHECK: return %[[T2]] : vector<[2]xindex>
+
// -----
func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
@@ -109,6 +160,19 @@ func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
// CHECK: %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T3]] : !llvm.array<2 x vector<3xf32>> to vector<2x3xf32>
// CHECK: return %[[T4]] : vector<2x3xf32>
+func.func @broadcast_vec2d_from_scalar_scalable(%arg0: f32) -> vector<2x[3]xf32> {
+ %0 = vector.broadcast %arg0 : f32 to vector<2x[3]xf32>
+ return %0 : vector<2x[3]xf32>
+}
+// CHECK-LABEL: @broadcast_vec2d_from_scalar_scalable(
+// CHECK-SAME: %[[A:.*]]: f32)
+// CHECK: %[[T0:.*]] = llvm.insertelement %[[A]]
+// CHECK: %[[T1:.*]] = llvm.shufflevector %[[T0]]
+// CHECK: %[[T2:.*]] = llvm.insertvalue %[[T1]], %{{.*}}[0] : !llvm.array<2 x vector<[3]xf32>>
+// CHECK: %[[T3:.*]] = llvm.insertvalue %[[T1]], %{{.*}}[1] : !llvm.array<2 x vector<[3]xf32>>
+// CHECK: %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T3]] : !llvm.array<2 x vector<[3]xf32>> to vector<2x[3]xf32>
+// CHECK: return %[[T4]] : vector<2x[3]xf32>
+
// -----
func.func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> {
@@ -125,6 +189,21 @@ func.func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> {
// CHECK: %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T3]] : !llvm.array<2 x array<3 x vector<4xf32>>> to vector<2x3x4xf32>
// CHECK: return %[[T4]] : vector<2x3x4xf32>
+
+func.func @broadcast_vec3d_from_scalar_scalable(%arg0: f32) -> vector<2x3x[4]xf32> {
+ %0 = vector.broadcast %arg0 : f32 to vector<2x3x[4]xf32>
+ return %0 : vector<2x3x[4]xf32>
+}
+// CHECK-LABEL: @broadcast_vec3d_from_scalar_scalable(
+// CHECK-SAME: %[[A:.*]]: f32)
+// CHECK: %[[T0:.*]] = llvm.insertelement %[[A]]
+// CHECK: %[[T1:.*]] = llvm.shufflevector %[[T0]]
+// CHECK: %[[T2:.*]] = llvm.insertvalue %[[T1]], %{{.*}}[0, 0] : !llvm.array<2 x array<3 x vector<[4]xf32>>>
+// ...
+// CHECK: %[[T3:.*]] = llvm.insertvalue %[[T1]], %{{.*}}[1, 2] : !llvm.array<2 x array<3 x vector<[4]xf32>>>
+// CHECK: %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T3]] : !llvm.array<2 x array<3 x vector<[4]xf32>>> to vector<2x3x[4]xf32>
+// CHECK: return %[[T4]] : vector<2x3x[4]xf32>
+
// -----
func.func @broadcast_vec1d_from_vec1d(%arg0: vector<2xf32>) -> vector<2xf32> {
@@ -135,6 +214,14 @@ func.func @broadcast_vec1d_from_vec1d(%arg0: vector<2xf32>) -> vector<2xf32> {
// CHECK-SAME: %[[A:.*]]: vector<2xf32>)
// CHECK: return %[[A]] : vector<2xf32>
+func.func @broadcast_vec1d_from_vec1d_scalable(%arg0: vector<[2]xf32>) -> vector<[2]xf32> {
+ %0 = vector.broadcast %arg0 : vector<[2]xf32> to vector<[2]xf32>
+ return %0 : vector<[2]xf32>
+}
+// CHECK-LABEL: @broadcast_vec1d_from_vec1d_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[2]xf32>)
+// CHECK: return %[[A]] : vector<[2]xf32>
+
// -----
func.func @broadcast_vec2d_from_vec0d(%arg0: vector<f32>) -> vector<3x2xf32> {
@@ -172,6 +259,20 @@ func.func @broadcast_vec2d_from_vec1d(%arg0: vector<2xf32>) -> vector<3x2xf32> {
// CHECK: %[[T5:.*]] = builtin.unrealized_conversion_cast %[[T4]] : !llvm.array<3 x vector<2xf32>> to vector<3x2xf32>
// CHECK: return %[[T5]] : vector<3x2xf32>
+func.func @broadcast_vec2d_from_vec1d_scalable(%arg0: vector<[2]xf32>) -> vector<3x[2]xf32> {
+ %0 = vector.broadcast %arg0 : vector<[2]xf32> to vector<3x[2]xf32>
+ return %0 : vector<3x[2]xf32>
+}
+// CHECK-LABEL: @broadcast_vec2d_from_vec1d_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[2]xf32>)
+// CHECK: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x[2]xf32>
+// CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x[2]xf32> to !llvm.array<3 x vector<[2]xf32>>
+// CHECK: %[[T2:.*]] = llvm.insertvalue %[[A]], %[[T1]][0] : !llvm.array<3 x vector<[2]xf32>>
+// CHECK: %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][1] : !llvm.array<3 x vector<[2]xf32>>
+// CHECK: %[[T4:.*]] = llvm.insertvalue %[[A]], %[[T3]][2] : !llvm.array<3 x vector<[2]xf32>>
+// CHECK: %[[T5:.*]] = builtin.unrealized_conversion_cast %[[T4]] : !llvm.array<3 x vector<[2]xf32>> to vector<3x[2]xf32>
+// CHECK: return %[[T5]] : vector<3x[2]xf32>
+
// -----
func.func @broadcast_vec2d_from_index_vec1d(%arg0: vector<2xindex>) -> vector<3x2xindex> {
@@ -188,6 +289,20 @@ func.func @broadcast_vec2d_from_index_vec1d(%arg0: vector<2xindex>) -> vector<3x
// CHECK: %[[T4:.*]] = builtin.unrealized_conversion_cast %{{.*}} : !llvm.array<3 x vector<2xi64>> to vector<3x2xindex>
// CHECK: return %[[T4]] : vector<3x2xindex>
+func.func @broadcast_vec2d_from_index_vec1d_scalable(%arg0: vector<[2]xindex>) -> vector<3x[2]xindex> {
+ %0 = vector.broadcast %arg0 : vector<[2]xindex> to vector<3x[2]xindex>
+ return %0 : vector<3x[2]xindex>
+}
+// CHECK-LABEL: @broadcast_vec2d_from_index_vec1d_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[2]xindex>)
+// CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<[2]xindex> to vector<[2]xi64>
+// CHECK: %[[T0:.*]] = arith.constant dense<0> : vector<3x[2]xindex>
+// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x[2]xindex> to !llvm.array<3 x vector<[2]xi64>>
+// CHECK: %[[T3:.*]] = llvm.insertvalue %[[T1]], %[[T2]][0] : !llvm.array<3 x vector<[2]xi64>>
+
+// CHECK: %[[T4:.*]] = builtin.unrealized_conversion_cast %{{.*}} : !llvm.array<3 x vector<[2]xi64>> to vector<3x[2]xindex>
+// CHECK: return %[[T4]] : vector<3x[2]xindex>
+
// -----
func.func @broadcast_vec3d_from_vec1d(%arg0: vector<2xf32>) -> vector<4x3x2xf32> {
@@ -213,6 +328,29 @@ func.func @broadcast_vec3d_from_vec1d(%arg0: vector<2xf32>) -> vector<4x3x2xf32>
// CHECK: %[[T11:.*]] = builtin.unrealized_conversion_cast %[[T10]] : !llvm.array<4 x array<3 x vector<2xf32>>> to vector<4x3x2xf32>
// CHECK: return %[[T11]] : vector<4x3x2xf32>
+func.func @broadcast_vec3d_from_vec1d_scalable(%arg0: vector<[2]xf32>) -> vector<4x3x[2]xf32> {
+ %0 = vector.broadcast %arg0 : vector<[2]xf32> to vector<4x3x[2]xf32>
+ return %0 : vector<4x3x[2]xf32>
+}
+// CHECK-LABEL: @broadcast_vec3d_from_vec1d_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[2]xf32>)
+// CHECK-DAG: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x[2]xf32>
+// CHECK-DAG: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x[2]xf32> to !llvm.array<3 x vector<[2]xf32>>
+// CHECK-DAG: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x[2]xf32>
+// CHECK-DAG: %[[T6:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x[2]xf32> to !llvm.array<4 x array<3 x vector<[2]xf32>>>
+
+// CHECK: %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][0] : !llvm.array<3 x vector<[2]xf32>>
+// CHECK: %[[T4:.*]] = llvm.insertvalue %[[A]], %[[T3]][1] : !llvm.array<3 x vector<[2]xf32>>
+// CHECK: %[[T5:.*]] = llvm.insertvalue %[[A]], %[[T4]][2] : !llvm.array<3 x vector<[2]xf32>>
+
+// CHECK: %[[T7:.*]] = llvm.insertvalue %[[T5]], %[[T6]][0] : !llvm.array<4 x array<3 x vector<[2]xf32>>>
+// CHECK: %[[T8:.*]] = llvm.insertvalue %[[T5]], %[[T7]][1] : !llvm.array<4 x array<3 x vector<[2]xf32>>>
+// CHECK: %[[T9:.*]] = llvm.insertvalue %[[T5]], %[[T8]][2] : !llvm.array<4 x array<3 x vector<[2]xf32>>>
+// CHECK: %[[T10:.*]] = llvm.insertvalue %[[T5]], %[[T9]][3] : !llvm.array<4 x array<3 x vector<[2]xf32>>>
+
+// CHECK: %[[T11:.*]] = builtin.unrealized_conversion_cast %[[T10]] : !llvm.array<4 x array<3 x vector<[2]xf32>>> to vector<4x3x[2]xf32>
+// CHECK: return %[[T11]] : vector<4x3x[2]xf32>
+
// -----
func.func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf32> {
@@ -231,6 +369,22 @@ func.func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf3
// CHECK: %[[T10:.*]] = builtin.unrealized_conversion_cast %[[T9]] : !llvm.array<4 x array<3 x vector<2xf32>>> to vector<4x3x2xf32>
// CHECK: return %[[T10]] : vector<4x3x2xf32>
+func.func @broadcast_vec3d_from_vec2d_scalable(%arg0: vector<3x[2]xf32>) -> vector<4x3x[2]xf32> {
+ %0 = vector.broadcast %arg0 : vector<3x[2]xf32> to vector<4x3x[2]xf32>
+ return %0 : vector<4x3x[2]xf32>
+}
+// CHECK-LABEL: @broadcast_vec3d_from_vec2d_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<3x[2]xf32>)
+// CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<3x[2]xf32> to !llvm.array<3 x vector<[2]xf32>>
+// CHECK: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x[2]xf32>
+// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<4x3x[2]xf32> to !llvm.array<4 x array<3 x vector<[2]xf32>>>
+// CHECK: %[[T3:.*]] = llvm.insertvalue %[[T1]], %[[T2]][0] : !llvm.array<4 x array<3 x vector<[2]xf32>>>
+// CHECK: %[[T5:.*]] = llvm.insertvalue %[[T1]], %[[T3]][1] : !llvm.array<4 x array<3 x vector<[2]xf32>>>
+// CHECK: %[[T7:.*]] = llvm.insertvalue %[[T1]], %[[T5]][2] : !llvm.array<4 x array<3 x vector<[2]xf32>>>
+// CHECK: %[[T9:.*]] = llvm.insertvalue %[[T1]], %[[T7]][3] : !llvm.array<4 x array<3 x vector<[2]xf32>>>
+// CHECK: %[[T10:.*]] = builtin.unrealized_conversion_cast %[[T9]] : !llvm.array<4 x array<3 x vector<[2]xf32>>> to vector<4x3x[2]xf32>
+// CHECK: return %[[T10]] : vector<4x3x[2]xf32>
+
// -----
@@ -246,6 +400,18 @@ func.func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> {
// CHECK: %[[T4:.*]] = llvm.shufflevector %[[T3]]
// CHECK: return %[[T4]] : vector<4xf32>
+func.func @broadcast_stretch_scalable(%arg0: vector<1xf32>) -> vector<[4]xf32> {
+ %0 = vector.broadcast %arg0 : vector<1xf32> to vector<[4]...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Please ping when the dependent MRs are merged for approval.
A bit off topic, but this MR got me wondering about the handling of cases where the lhs would be scalable, and I found out that this is considered invalid, e.g.:
// expected-error @+1 {{expected either both or only #2 operand dim to be scalable}}
vector.outerproduct %0, %1 : vector<[4]xf32>, vector<4xf32>
// equivalent to
// for i = 0 to vscale
// | %0i = extract %0[i*4] vector<[4]xf32>-> vector<4xf32>
// | vector.outerproduct %0i, %1 : vector<4xf32>, vector<4xf32>
Does it make sense to support this use case in the future, or do we consider that this should never be generated and that users should never write it? I have seen another MR recently (from you, I think) which did not handle the lowering to a loop.
// This restriction reflects what's currently supported in terms of
// scalable vectors. However, we could relax this if there's a use case.
In case I get into sorting this out, do you think we should have separate scalable-vector canonicalization patterns, a bit like ArmSMEToSCF but without the SME dependency, or would it be enough to simply have lowerVectorToLLVM support this case?
It's always good to have a motivating example - can you think of a situation where that could be used? Nothing on my mind today, but I tend to look at a rather limited set of examples. In general, I always advocate for good balance between complexity and what's actually needed. In particular with scalable vectors, we tend to discover that we don't really need every single pattern in the Vector dialect to support them. And we still manage to generate good code 😅 (just as an example) Having said that, I'm not opposed to improving the support for scalable vectors.
Depends on complexity. For more involved logic, I'd start with something "hidden" in the SVE dialect. And then, if we find that particularly useful, we could "elevate" that to Vector. Having said that, this particular functionality would be valid for any implementation of scalable vectors, not only SVE. Btw, I really appreciate all the reviews and wanted to say that I'm always happy to return the favour. In particular, I'm reviewing all Vector tests and if you ever feel like taking a look at one of the tests yourself, here's a list that I use to co-ordinate this: Sharing mostly so that you are aware :) |
Alright, thanks for clarification.
Thanks for your kind words. I have to deliver something downstream at present. I have a few features in mind for the short term future. I'll bookmark your status 😉 |
Adds tests with scalable vectors for the Vector-To-LLVM conversion pass. Covers the following Ops: * vector.outerproduct
4ad7449
to
e3bb7e1
Compare
Ping :) |
|
||
// CHECK-LABEL: func.func @masked_float_add_outerprod_scalable | ||
// CHECK-SAME: %[[VAL_0:.*]]: vector<[2]xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<[2]xf32>, %[[VAL_3:.*]]: vector<[2]xi1>) -> vector<[2]xf32> { | ||
// CHECK: %[[VAL_8:.*]] = llvm.intr.fmuladd(%[[VAL_0]], %{{.*}}, %[[VAL_2]]) : (vector<[2]xf32>, vector<[2]xf32>, vector<[2]xf32>) -> vector<[2]xf32> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NIT: missing VAL_[4-7]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's intentional and consistent with other tests :) I am only testing the key stuff (to reduce noise).
@@ -663,6 +734,16 @@ func.func @masked_float_add_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: v | |||
// CHECK: %[[VAL_8:.*]] = llvm.intr.fmuladd(%[[VAL_0]], %{{.*}}, %[[VAL_2]]) : (vector<2xf32>, vector<2xf32>, vector<2xf32>) -> vector<2xf32> | |||
// CHECK: %[[VAL_9:.*]] = arith.select %[[VAL_3]], %[[VAL_8]], %[[VAL_2]] : vector<2xi1>, vector<2xf32> | |||
|
|||
func.func @masked_float_add_outerprod_scalable(%arg0: vector<[2]xf32>, %arg1: f32, %arg2: vector<[2]xf32>, %m: vector<[2]xi1>) -> vector<[2]xf32> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is renaming `arg[0-9]*` to `vec_[0-9]*`, ... in your future plans?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This file consistently uses `%arg0` and `%arg1`, so renaming would be just extra noise. In the other files, we were mixing different styles, so I've renamed everything (to `%vec`) for consistency. But, IMHO, the actual name isn't as important as consistency 😅
} | ||
|
||
// CHECK-LABEL: func.func @masked_float_add_outerprod_scalable | ||
// CHECK-SAME: %[[VAL_0:.*]]: vector<[2]xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<[2]xf32>, %[[VAL_3:.*]]: vector<[2]xi1>) -> vector<[2]xf32> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shall we align arg definition with function body as in the other tests ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the review, sending update shortly!
@@ -663,6 +734,16 @@ func.func @masked_float_add_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: v | |||
// CHECK: %[[VAL_8:.*]] = llvm.intr.fmuladd(%[[VAL_0]], %{{.*}}, %[[VAL_2]]) : (vector<2xf32>, vector<2xf32>, vector<2xf32>) -> vector<2xf32> | |||
// CHECK: %[[VAL_9:.*]] = arith.select %[[VAL_3]], %[[VAL_8]], %[[VAL_2]] : vector<2xi1>, vector<2xf32> | |||
|
|||
func.func @masked_float_add_outerprod_scalable(%arg0: vector<[2]xf32>, %arg1: f32, %arg2: vector<[2]xf32>, %m: vector<[2]xi1>) -> vector<[2]xf32> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This file consistently uses `%arg0` and `%arg1`, so renaming would be just extra noise. In the other files, we were mixing different styles, so I've renamed everything (to `%vec`) for consistency. But, IMHO, the actual name isn't as important as consistency 😅
|
||
// CHECK-LABEL: func.func @masked_float_add_outerprod_scalable | ||
// CHECK-SAME: %[[VAL_0:.*]]: vector<[2]xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<[2]xf32>, %[[VAL_3:.*]]: vector<[2]xi1>) -> vector<[2]xf32> { | ||
// CHECK: %[[VAL_8:.*]] = llvm.intr.fmuladd(%[[VAL_0]], %{{.*}}, %[[VAL_2]]) : (vector<[2]xf32>, vector<[2]xf32>, vector<[2]xf32>) -> vector<[2]xf32> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's intentional and consistent with other tests :) I am only testing the key stuff (to reduce noise).
@nujaa I have noticed that you "click"ed merge without accepting ;-) Could you formally "+1" as well? Thanks for the review and all the support with this 🙏🏻 ! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
I thought it would not let me merge without approval, so I assumed it was approved 🤒 |
It's also OK to leave for the author to merge it ;-) In fact, I might not get any buildbot notifications otherwise 🤔 (perhaps that has changed) |
Adds tests with scalable vectors for the Vector-To-LLVM conversion pass.
Covers the following Ops: