Skip to content

[mlir][vector] Add more tests for ConvertVectorToLLVM (2/n) #102203

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 9, 2024

Conversation

banach-space
Copy link
Contributor

@banach-space banach-space commented Aug 6, 2024

Adds tests with scalable vectors for the Vector-To-LLVM conversion pass.
Covers the following Ops:

  • vector.outerproduct

@llvmbot
Copy link
Member

llvmbot commented Aug 6, 2024

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes
  • [mlir][vector] Clarify the semantics of BroadcastOp
  • [mlir][vector] Add more tests for ConvertVectorToLLVM (1/n)
  • [mlir][vector] Add more tests for ConvertVectorToLLVM (2/n)

Patch is 47.40 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/102203.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.h (+5-1)
  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+2)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+35-10)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp (+6-1)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+416)
  • (modified) mlir/test/Dialect/Vector/invalid.mlir (+14)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index ac55433fadb2f..9f61f7c866d3d 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -68,9 +68,13 @@ enum class BroadcastableToResult {
   DimensionMismatch = 2,
   SourceTypeNotAVector = 3
 };
+struct VectorDim {
+  int64_t dim;
+  bool scalableFlag;
+};
 BroadcastableToResult
 isBroadcastableTo(Type srcType, VectorType dstVectorType,
-                  std::pair<int, int> *mismatchingDims = nullptr);
+                  std::pair<VectorDim, VectorDim> *mismatchingDims = nullptr);
 
 /// Collect a set of vector-to-vector canonicalization patterns.
 void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns,
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 434ff3956c250..08bff3d5e1382 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -367,6 +367,8 @@ def Vector_BroadcastOp :
                                s_1     x .. x s_j x .. x s_k
                <duplication>         <potential stretch>
        ```
+    * a scalable unit dimeension, `[1]`, must match exactly.
+
     The source operand is duplicated over all the missing leading dimensions
     and stretched over the trailing dimensions where the source has a non-equal
     dimension of 1. These rules imply that any scalar broadcast (k=0) to any
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5047bd925d4c5..673c128932893 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2371,9 +2371,9 @@ Value BroadcastOp::createOrFoldBroadcastOp(
   return res;
 }
 
-BroadcastableToResult
-mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
-                                std::pair<int, int> *mismatchingDims) {
+BroadcastableToResult mlir::vector::isBroadcastableTo(
+    Type srcType, VectorType dstVectorType,
+    std::pair<VectorDim, VectorDim> *mismatchingDims) {
   // Broadcast scalar to vector of the same element type.
   if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
       getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
@@ -2391,12 +2391,28 @@ mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
   // (all leading dimensions are simply duplicated).
   int64_t lead = dstRank - srcRank;
   for (int64_t r = 0; r < srcRank; ++r) {
+    bool mismatch = false;
+
+    // Check fixed-width dims
     int64_t srcDim = srcVectorType.getDimSize(r);
     int64_t dstDim = dstVectorType.getDimSize(lead + r);
-    if (srcDim != 1 && srcDim != dstDim) {
+    if ((srcDim != 1 && srcDim != dstDim))
+      mismatch = true;
+
+    // Check scalable flags
+    bool srcDimScalableFlag = srcVectorType.getScalableDims()[r];
+    bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + r];
+    if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
+        (srcDimScalableFlag && !dstDimScalableFlag))
+      mismatch = true;
+
+    if (mismatch) {
       if (mismatchingDims) {
-        mismatchingDims->first = srcDim;
-        mismatchingDims->second = dstDim;
+        mismatchingDims->first.dim = srcDim;
+        mismatchingDims->first.scalableFlag = srcDimScalableFlag;
+
+        mismatchingDims->second.dim = dstDim;
+        mismatchingDims->second.scalableFlag = dstDimScalableFlag;
       }
       return BroadcastableToResult::DimensionMismatch;
     }
@@ -2406,16 +2422,25 @@ mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
 }
 
 LogicalResult BroadcastOp::verify() {
-  std::pair<int, int> mismatchingDims;
+  std::pair<VectorDim, VectorDim> mismatchingDims;
   BroadcastableToResult res = isBroadcastableTo(
       getSourceType(), getResultVectorType(), &mismatchingDims);
   if (res == BroadcastableToResult::Success)
     return success();
   if (res == BroadcastableToResult::SourceRankHigher)
     return emitOpError("source rank higher than destination rank");
-  if (res == BroadcastableToResult::DimensionMismatch)
-    return emitOpError("dimension mismatch (")
-           << mismatchingDims.first << " vs. " << mismatchingDims.second << ")";
+  if (res == BroadcastableToResult::DimensionMismatch) {
+    std::string msg =
+        (Twine("dimension mismatch (") +
+         (mismatchingDims.first.scalableFlag ? "[" : "") +
+         std::to_string(mismatchingDims.first.dim) +
+         (mismatchingDims.first.scalableFlag ? "]" : "") + " vs. " +
+         (mismatchingDims.second.scalableFlag ? "[" : "") +
+         std::to_string(mismatchingDims.second.dim) +
+         (mismatchingDims.second.scalableFlag ? "]" : "") + ")")
+            .str();
+    return emitOpError(msg);
+  }
   if (res == BroadcastableToResult::SourceTypeNotAVector)
     return emitOpError("source type is not a vector");
   llvm_unreachable("unexpected vector.broadcast op error");
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index 32e7eb27f5e29..6c36bbaee8523 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -125,7 +125,8 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
     //   ..
     //   %x = [%a,%b,%c,%d]
     VectorType resType =
-        VectorType::get(dstType.getShape().drop_front(), eltType);
+        VectorType::get(dstType.getShape().drop_front(), eltType,
+                        dstType.getScalableDims().drop_front());
     Value result = rewriter.create<arith::ConstantOp>(
         loc, dstType, rewriter.getZeroAttr(dstType));
     if (m == 0) {
@@ -136,6 +137,10 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
         result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
     } else {
       // Stetch not at start.
+      if (dstType.getScalableDims()[0]) {
+        // TODO: For scalable vectors we should emit an scf.for loop.
+        return failure();
+      }
       for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
         Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), d);
         Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index c310954b906e4..1034baedb0cd0 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -23,6 +23,15 @@ func.func @bitcast_f32_to_i32_vector(%input: vector<16xf32>) -> vector<16xi32> {
 // CHECK-SAME:  %[[input:.*]]: vector<16xf32>
 // CHECK:       llvm.bitcast %[[input]] : vector<16xf32> to vector<16xi32>
 
+func.func @bitcast_f32_to_i32_vector_scalable(%input: vector<[16]xf32>) -> vector<[16]xi32> {
+  %0 = vector.bitcast %input : vector<[16]xf32> to vector<[16]xi32>
+  return %0 : vector<[16]xi32>
+}
+
+// CHECK-LABEL: @bitcast_f32_to_i32_vector_scalable
+// CHECK-SAME:  %[[input:.*]]: vector<[16]xf32>
+// CHECK:       llvm.bitcast %[[input]] : vector<[16]xf32> to vector<[16]xi32>
+
 // -----
 
 func.func @bitcast_i8_to_f32_vector(%input: vector<64xi8>) -> vector<16xf32> {
@@ -34,6 +43,15 @@ func.func @bitcast_i8_to_f32_vector(%input: vector<64xi8>) -> vector<16xf32> {
 // CHECK-SAME:  %[[input:.*]]: vector<64xi8>
 // CHECK:       llvm.bitcast %[[input]] : vector<64xi8> to vector<16xf32>
 
+func.func @bitcast_i8_to_f32_vector_scalable(%input: vector<[64]xi8>) -> vector<[16]xf32> {
+  %0 = vector.bitcast %input : vector<[64]xi8> to vector<[16]xf32>
+  return %0 : vector<[16]xf32>
+}
+
+// CHECK-LABEL: @bitcast_i8_to_f32_vector_scalable
+// CHECK-SAME:  %[[input:.*]]: vector<[64]xi8>
+// CHECK:       llvm.bitcast %[[input]] : vector<[64]xi8> to vector<[16]xf32>
+
 // -----
 
 func.func @bitcast_index_to_i8_vector(%input: vector<16xindex>) -> vector<128xi8> {
@@ -46,6 +64,16 @@ func.func @bitcast_index_to_i8_vector(%input: vector<16xindex>) -> vector<128xi8
 // CHECK:       %[[T0:.*]] = builtin.unrealized_conversion_cast %[[input]] : vector<16xindex> to vector<16xi64>
 // CHECK:       llvm.bitcast %[[T0]] : vector<16xi64> to vector<128xi8>
 
+func.func @bitcast_index_to_i8_vector_scalable(%input: vector<[16]xindex>) -> vector<[128]xi8> {
+  %0 = vector.bitcast %input : vector<[16]xindex> to vector<[128]xi8>
+  return %0 : vector<[128]xi8>
+}
+
+// CHECK-LABEL: @bitcast_index_to_i8_vector_scalable
+// CHECK-SAME:  %[[input:.*]]: vector<[16]xindex>
+// CHECK:       %[[T0:.*]] = builtin.unrealized_conversion_cast %[[input]] : vector<[16]xindex> to vector<[16]xi64>
+// CHECK:       llvm.bitcast %[[T0]] : vector<[16]xi64> to vector<[128]xi8>
+
 // -----
 
 func.func @broadcast_vec0d_from_f32(%arg0: f32) -> vector<f32> {
@@ -80,6 +108,17 @@ func.func @broadcast_vec1d_from_f32(%arg0: f32) -> vector<2xf32> {
 // CHECK:       %[[T1:.*]] = llvm.shufflevector %[[T0]]
 // CHECK:       return %[[T1]] : vector<2xf32>
 
+
+func.func @broadcast_vec1d_from_f32_scalable(%arg0: f32) -> vector<[2]xf32> {
+  %0 = vector.broadcast %arg0 : f32 to vector<[2]xf32>
+  return %0 : vector<[2]xf32>
+}
+// CHECK-LABEL: @broadcast_vec1d_from_f32_scalable
+// CHECK-SAME:  %[[A:.*]]: f32)
+// CHECK:       %[[T0:.*]] = llvm.insertelement %[[A]]
+// CHECK:       %[[T1:.*]] = llvm.shufflevector %[[T0]]
+// CHECK:       return %[[T1]] : vector<[2]xf32>
+
 // -----
 
 func.func @broadcast_vec1d_from_index(%arg0: index) -> vector<2xindex> {
@@ -94,6 +133,18 @@ func.func @broadcast_vec1d_from_index(%arg0: index) -> vector<2xindex> {
 // CHECK:       %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<2xi64> to vector<2xindex>
 // CHECK:       return %[[T2]] : vector<2xindex>
 
+func.func @broadcast_vec1d_from_index_scalable(%arg0: index) -> vector<[2]xindex> {
+  %0 = vector.broadcast %arg0 : index to vector<[2]xindex>
+  return %0 : vector<[2]xindex>
+}
+// CHECK-LABEL: @broadcast_vec1d_from_index_scalable
+// CHECK-SAME:  %[[A:.*]]: index)
+// CHECK:       %[[A1:.*]] = builtin.unrealized_conversion_cast %[[A]] : index to i64
+// CHECK:       %[[T0:.*]] = llvm.insertelement %[[A1]]
+// CHECK:       %[[T1:.*]] = llvm.shufflevector %[[T0]]
+// CHECK:       %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<[2]xi64> to vector<[2]xindex>
+// CHECK:       return %[[T2]] : vector<[2]xindex>
+
 // -----
 
 func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
@@ -109,6 +160,19 @@ func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
 // CHECK:       %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T3]] : !llvm.array<2 x vector<3xf32>> to vector<2x3xf32>
 // CHECK:       return %[[T4]] : vector<2x3xf32>
 
+func.func @broadcast_vec2d_from_scalar_scalable(%arg0: f32) -> vector<2x[3]xf32> {
+  %0 = vector.broadcast %arg0 : f32 to vector<2x[3]xf32>
+  return %0 : vector<2x[3]xf32>
+}
+// CHECK-LABEL: @broadcast_vec2d_from_scalar_scalable(
+// CHECK-SAME:  %[[A:.*]]: f32)
+// CHECK:       %[[T0:.*]] = llvm.insertelement %[[A]]
+// CHECK:       %[[T1:.*]] = llvm.shufflevector %[[T0]]
+// CHECK:       %[[T2:.*]] = llvm.insertvalue %[[T1]], %{{.*}}[0] : !llvm.array<2 x vector<[3]xf32>>
+// CHECK:       %[[T3:.*]] = llvm.insertvalue %[[T1]], %{{.*}}[1] : !llvm.array<2 x vector<[3]xf32>>
+// CHECK:       %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T3]] : !llvm.array<2 x vector<[3]xf32>> to vector<2x[3]xf32>
+// CHECK:       return %[[T4]] : vector<2x[3]xf32>
+
 // -----
 
 func.func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> {
@@ -125,6 +189,21 @@ func.func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> {
 // CHECK:       %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T3]] : !llvm.array<2 x array<3 x vector<4xf32>>> to vector<2x3x4xf32>
 // CHECK:       return %[[T4]] : vector<2x3x4xf32>
 
+
+func.func @broadcast_vec3d_from_scalar_scalable(%arg0: f32) -> vector<2x3x[4]xf32> {
+  %0 = vector.broadcast %arg0 : f32 to vector<2x3x[4]xf32>
+  return %0 : vector<2x3x[4]xf32>
+}
+// CHECK-LABEL: @broadcast_vec3d_from_scalar_scalable(
+// CHECK-SAME:  %[[A:.*]]: f32)
+// CHECK:       %[[T0:.*]] = llvm.insertelement %[[A]]
+// CHECK:       %[[T1:.*]] = llvm.shufflevector %[[T0]]
+// CHECK:       %[[T2:.*]] = llvm.insertvalue %[[T1]], %{{.*}}[0, 0] : !llvm.array<2 x array<3 x vector<[4]xf32>>>
+// ...
+// CHECK:       %[[T3:.*]] = llvm.insertvalue %[[T1]], %{{.*}}[1, 2] : !llvm.array<2 x array<3 x vector<[4]xf32>>>
+// CHECK:       %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T3]] : !llvm.array<2 x array<3 x vector<[4]xf32>>> to vector<2x3x[4]xf32>
+// CHECK:       return %[[T4]] : vector<2x3x[4]xf32>
+
 // -----
 
 func.func @broadcast_vec1d_from_vec1d(%arg0: vector<2xf32>) -> vector<2xf32> {
@@ -135,6 +214,14 @@ func.func @broadcast_vec1d_from_vec1d(%arg0: vector<2xf32>) -> vector<2xf32> {
 // CHECK-SAME:  %[[A:.*]]: vector<2xf32>)
 // CHECK:       return %[[A]] : vector<2xf32>
 
+func.func @broadcast_vec1d_from_vec1d_scalable(%arg0: vector<[2]xf32>) -> vector<[2]xf32> {
+  %0 = vector.broadcast %arg0 : vector<[2]xf32> to vector<[2]xf32>
+  return %0 : vector<[2]xf32>
+}
+// CHECK-LABEL: @broadcast_vec1d_from_vec1d_scalable(
+// CHECK-SAME:  %[[A:.*]]: vector<[2]xf32>)
+// CHECK:       return %[[A]] : vector<[2]xf32>
+
 // -----
 
 func.func @broadcast_vec2d_from_vec0d(%arg0: vector<f32>) -> vector<3x2xf32> {
@@ -172,6 +259,20 @@ func.func @broadcast_vec2d_from_vec1d(%arg0: vector<2xf32>) -> vector<3x2xf32> {
 // CHECK:       %[[T5:.*]] = builtin.unrealized_conversion_cast %[[T4]] : !llvm.array<3 x vector<2xf32>> to vector<3x2xf32>
 // CHECK:       return %[[T5]] : vector<3x2xf32>
 
+func.func @broadcast_vec2d_from_vec1d_scalable(%arg0: vector<[2]xf32>) -> vector<3x[2]xf32> {
+  %0 = vector.broadcast %arg0 : vector<[2]xf32> to vector<3x[2]xf32>
+  return %0 : vector<3x[2]xf32>
+}
+// CHECK-LABEL: @broadcast_vec2d_from_vec1d_scalable(
+// CHECK-SAME:  %[[A:.*]]: vector<[2]xf32>)
+// CHECK:       %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x[2]xf32>
+// CHECK:       %[[T1:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x[2]xf32> to !llvm.array<3 x vector<[2]xf32>>
+// CHECK:       %[[T2:.*]] = llvm.insertvalue %[[A]], %[[T1]][0] : !llvm.array<3 x vector<[2]xf32>>
+// CHECK:       %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][1] : !llvm.array<3 x vector<[2]xf32>>
+// CHECK:       %[[T4:.*]] = llvm.insertvalue %[[A]], %[[T3]][2] : !llvm.array<3 x vector<[2]xf32>>
+// CHECK:       %[[T5:.*]] = builtin.unrealized_conversion_cast %[[T4]] : !llvm.array<3 x vector<[2]xf32>> to vector<3x[2]xf32>
+// CHECK:       return %[[T5]] : vector<3x[2]xf32>
+
 // -----
 
 func.func @broadcast_vec2d_from_index_vec1d(%arg0: vector<2xindex>) -> vector<3x2xindex> {
@@ -188,6 +289,20 @@ func.func @broadcast_vec2d_from_index_vec1d(%arg0: vector<2xindex>) -> vector<3x
 // CHECK:       %[[T4:.*]] = builtin.unrealized_conversion_cast %{{.*}} : !llvm.array<3 x vector<2xi64>> to vector<3x2xindex>
 // CHECK:       return %[[T4]] : vector<3x2xindex>
 
+func.func @broadcast_vec2d_from_index_vec1d_scalable(%arg0: vector<[2]xindex>) -> vector<3x[2]xindex> {
+  %0 = vector.broadcast %arg0 : vector<[2]xindex> to vector<3x[2]xindex>
+  return %0 : vector<3x[2]xindex>
+}
+// CHECK-LABEL: @broadcast_vec2d_from_index_vec1d_scalable(
+// CHECK-SAME:  %[[A:.*]]: vector<[2]xindex>)
+// CHECK:       %[[T1:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<[2]xindex> to vector<[2]xi64>
+// CHECK:       %[[T0:.*]] = arith.constant dense<0> : vector<3x[2]xindex>
+// CHECK:       %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x[2]xindex> to !llvm.array<3 x vector<[2]xi64>>
+// CHECK:       %[[T3:.*]] = llvm.insertvalue %[[T1]], %[[T2]][0] : !llvm.array<3 x vector<[2]xi64>>
+
+// CHECK:       %[[T4:.*]] = builtin.unrealized_conversion_cast %{{.*}} : !llvm.array<3 x vector<[2]xi64>> to vector<3x[2]xindex>
+// CHECK:       return %[[T4]] : vector<3x[2]xindex>
+
 // -----
 
 func.func @broadcast_vec3d_from_vec1d(%arg0: vector<2xf32>) -> vector<4x3x2xf32> {
@@ -213,6 +328,29 @@ func.func @broadcast_vec3d_from_vec1d(%arg0: vector<2xf32>) -> vector<4x3x2xf32>
 // CHECK:       %[[T11:.*]] = builtin.unrealized_conversion_cast %[[T10]] : !llvm.array<4 x array<3 x vector<2xf32>>> to vector<4x3x2xf32>
 // CHECK:       return %[[T11]] : vector<4x3x2xf32>
 
+func.func @broadcast_vec3d_from_vec1d_scalable(%arg0: vector<[2]xf32>) -> vector<4x3x[2]xf32> {
+  %0 = vector.broadcast %arg0 : vector<[2]xf32> to vector<4x3x[2]xf32>
+  return %0 : vector<4x3x[2]xf32>
+}
+// CHECK-LABEL: @broadcast_vec3d_from_vec1d_scalable(
+// CHECK-SAME:  %[[A:.*]]: vector<[2]xf32>)
+// CHECK-DAG:   %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x[2]xf32>
+// CHECK-DAG:   %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x[2]xf32> to !llvm.array<3 x vector<[2]xf32>>
+// CHECK-DAG:   %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x[2]xf32>
+// CHECK-DAG:   %[[T6:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x[2]xf32> to !llvm.array<4 x array<3 x vector<[2]xf32>>>
+
+// CHECK:       %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][0] : !llvm.array<3 x vector<[2]xf32>>
+// CHECK:       %[[T4:.*]] = llvm.insertvalue %[[A]], %[[T3]][1] : !llvm.array<3 x vector<[2]xf32>>
+// CHECK:       %[[T5:.*]] = llvm.insertvalue %[[A]], %[[T4]][2] : !llvm.array<3 x vector<[2]xf32>>
+
+// CHECK:       %[[T7:.*]] = llvm.insertvalue %[[T5]], %[[T6]][0] : !llvm.array<4 x array<3 x vector<[2]xf32>>>
+// CHECK:       %[[T8:.*]] = llvm.insertvalue %[[T5]], %[[T7]][1] : !llvm.array<4 x array<3 x vector<[2]xf32>>>
+// CHECK:       %[[T9:.*]] = llvm.insertvalue %[[T5]], %[[T8]][2] : !llvm.array<4 x array<3 x vector<[2]xf32>>>
+// CHECK:       %[[T10:.*]] = llvm.insertvalue %[[T5]], %[[T9]][3] : !llvm.array<4 x array<3 x vector<[2]xf32>>>
+
+// CHECK:       %[[T11:.*]] = builtin.unrealized_conversion_cast %[[T10]] : !llvm.array<4 x array<3 x vector<[2]xf32>>> to vector<4x3x[2]xf32>
+// CHECK:       return %[[T11]] : vector<4x3x[2]xf32>
+
 // -----
 
 func.func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf32> {
@@ -231,6 +369,22 @@ func.func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf3
 // CHECK:       %[[T10:.*]] = builtin.unrealized_conversion_cast %[[T9]] : !llvm.array<4 x array<3 x vector<2xf32>>> to vector<4x3x2xf32>
 // CHECK:       return %[[T10]] : vector<4x3x2xf32>
 
+func.func @broadcast_vec3d_from_vec2d_scalable(%arg0: vector<3x[2]xf32>) -> vector<4x3x[2]xf32> {
+  %0 = vector.broadcast %arg0 : vector<3x[2]xf32> to vector<4x3x[2]xf32>
+  return %0 : vector<4x3x[2]xf32>
+}
+// CHECK-LABEL: @broadcast_vec3d_from_vec2d_scalable(
+// CHECK-SAME:  %[[A:.*]]: vector<3x[2]xf32>)
+// CHECK:       %[[T1:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<3x[2]xf32> to !llvm.array<3 x vector<[2]xf32>>
+// CHECK:       %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x[2]xf32>
+// CHECK:       %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<4x3x[2]xf32> to !llvm.array<4 x array<3 x vector<[2]xf32>>>
+// CHECK:       %[[T3:.*]] = llvm.insertvalue %[[T1]], %[[T2]][0] : !llvm.array<4 x array<3 x vector<[2]xf32>>>
+// CHECK:       %[[T5:.*]] = llvm.insertvalue %[[T1]], %[[T3]][1] : !llvm.array<4 x array<3 x vector<[2]xf32>>>
+// CHECK:       %[[T7:.*]] = llvm.insertvalue %[[T1]], %[[T5]][2] : !llvm.array<4 x array<3 x vector<[2]xf32>>>
+// CHECK:       %[[T9:.*]] = llvm.insertvalue %[[T1]], %[[T7]][3] : !llvm.array<4 x array<3 x vector<[2]xf32>>>
+// CHECK:       %[[T10:.*]] = builtin.unrealized_conversion_cast %[[T9]] : !llvm.array<4 x array<3 x vector<[2]xf32>>> to vector<4x3x[2]xf32>
+// CHECK:       return %[[T10]] : vector<4x3x[2]xf32>
+
 
 // -----
 
@@ -246,6 +400,18 @@ func.func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> {
 // CHECK:       %[[T4:.*]] = llvm.shufflevector %[[T3]]
 // CHECK:       return %[[T4]] : vector<4xf32>
 
+func.func @broadcast_stretch_scalable(%arg0: vector<1xf32>) -> vector<[4]xf32> {
+  %0 = vector.broadcast %arg0 : vector<1xf32> to vector<[4]...
[truncated]

@banach-space banach-space changed the title andrzej/extend vector to llvm test 2 [mlir][vector] Add more tests for ConvertVectorToLLVM (2/n) Aug 6, 2024
@banach-space banach-space requested a review from nujaa August 6, 2024 19:41
Copy link
Contributor

@nujaa nujaa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Please ping when the dependent MRs are merged for approval.
A bit off topic but, This MR got me wondering about the handling of cases where the lhs would be scalable and found out it was considered invalid. e.g.

 // expected-error @+1 {{expected either both or only #2 operand dim to be scalable}}
vector.outerproduct %0, %1 : vector<[4]xf32>, vector<4xf32>
// equivalent to
// for i = 0 to vscale
// | %0i = extract %0[i*4] vector<[4]xf32>-> vector<4xf32>
// | vector.outerproduct %0i, %1 : vector<4xf32>, vector<4xf32>

Does it make sense to support this usecase in the future or do we consider this should never be generated and the user never write it ? I have seen another MR recently (from you I think) which did not handle the lowering to a loop.

      // This restriction reflects what's currently supported in terms of
      // scalable vectors. However, we could relax this if there's a use case.

In case I get into sorting this out, do you think we should have some separate scalable vectors canonicalization patterns a bit like ArmSMEToSCF but without the SME dependency or simply have lowerVectorToLLVM support this case is enough ?

@banach-space
Copy link
Contributor Author

Does it make sense to support this usecase in the future or do we consider this should never be generated and the user never write it ?

It's always good to have a motivating example - can you think of a situation where that could be used? Nothing on my mind today, but I tend to look at a rather limited set of examples.

In general, I always advocate for good balance between complexity and what's actually needed. In particular with scalable vectors, we tend to discover that we don't really need every single pattern in the Vector dialect to support them. And we still manage to generate good code 😅 (just as an example)

Having said that, I'm not opposed to improving the support for scalable vectors.

I have seen another MR recently (from you I think) which did not handle the lowering to a loop.

      // This restriction reflects what's currently supported in terms of
      // scalable vectors. However, we could relax this if there's a use case.

In case I get into sorting this out, do you think we should have some separate scalable vectors canonicalization patterns a bit like ArmSMEToSCF but without the SME dependency or simply have lowerVectorToLLVM support this case is enough ?

Depends on complexity. For more involved logic, I'd start with something "hidden" in the SVE dialect. And then, if we find that particularly useful, we could "elevate" that to Vector. Having said that, this particular functionality would be valid for any implementation of scalable vectors, not only SVE.

Btw, I really appreciate all the reviews and wanted to say that I'm always happy to return the favour. In particular, I'm reviewing all Vector tests and if you ever feel like taking a look at one of the tests yourself, here's a list that I use to co-ordinate this:

Sharing mostly so that you are aware :)

@nujaa
Copy link
Contributor

nujaa commented Aug 8, 2024

Alright, thanks for clarification.
One usecase I find is, if scalable vectors(resp. tiles) have been generated but the machine does not support scalable vectors (resp 2d scalable). Let's consider there are no sve/sme specific ops but e.g. vector.outerproduct <[4]xf32> <[4]xf32> . Should lower-to-llvm be able to lower scalable ops to fixed size or compiling without sve/sme/... flags sets vscale to 1 and shall never build e.g. z registers ?
Mmmh, In any case, if vscale is not supported, there is no loop to be generated. So only remains the case we have 2d scalable vectors but only have SVE. That's quite niche.

Btw, I really appreciate all the reviews and wanted to say that I'm always happy to return the favour. In particular, I'm reviewing all Vector tests and if you ever feel like taking a look at one of the tests yourself, here's a list that I use to co-ordinate this:

Sharing mostly so that you are aware :)

Thanks for your kind words. I have to deliver something downstream at present. I have a few features in mind for the short term future. I'll bookmark your status 😉

Adds tests with scalable vectors for the Vector-To-LLVM conversion pass.
Covers the following Ops:
  * vector.outerproduct
@banach-space banach-space force-pushed the andrzej/extend_vector_to_llvm_test_2 branch from 4ad7449 to e3bb7e1 Compare August 8, 2024 15:00
@banach-space
Copy link
Contributor Author

LGTM. Please ping when the dependent MRs are merged for approval.

Ping :)


// CHECK-LABEL: func.func @masked_float_add_outerprod_scalable
// CHECK-SAME: %[[VAL_0:.*]]: vector<[2]xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<[2]xf32>, %[[VAL_3:.*]]: vector<[2]xi1>) -> vector<[2]xf32> {
// CHECK: %[[VAL_8:.*]] = llvm.intr.fmuladd(%[[VAL_0]], %{{.*}}, %[[VAL_2]]) : (vector<[2]xf32>, vector<[2]xf32>, vector<[2]xf32>) -> vector<[2]xf32>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT:missing VAL_[4-7]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's intentional and consistent with other tests :) I am only testing the key stuff (to reduce noise).

@@ -663,6 +734,16 @@ func.func @masked_float_add_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: v
// CHECK: %[[VAL_8:.*]] = llvm.intr.fmuladd(%[[VAL_0]], %{{.*}}, %[[VAL_2]]) : (vector<2xf32>, vector<2xf32>, vector<2xf32>) -> vector<2xf32>
// CHECK: %[[VAL_9:.*]] = arith.select %[[VAL_3]], %[[VAL_8]], %[[VAL_2]] : vector<2xi1>, vector<2xf32>

func.func @masked_float_add_outerprod_scalable(%arg0: vector<[2]xf32>, %arg1: f32, %arg2: vector<[2]xf32>, %m: vector<[2]xi1>) -> vector<[2]xf32> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is renaming arg[0-9]* to vec_[0-9]*, ... in your future plans ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file consistently uses %arg0 and %arg1, so renaming would be just extra noise. In the other files, we were mixing different styles, so I've renamed everything (to %vec) for consistency. But, IMHO, the actual name isn't as important as consistency 😅

}

// CHECK-LABEL: func.func @masked_float_add_outerprod_scalable
// CHECK-SAME: %[[VAL_0:.*]]: vector<[2]xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<[2]xf32>, %[[VAL_3:.*]]: vector<[2]xi1>) -> vector<[2]xf32> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we align arg definition with function body as in the other tests ?

Copy link
Contributor Author

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review, sending update shortly!

@@ -663,6 +734,16 @@ func.func @masked_float_add_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: v
// CHECK: %[[VAL_8:.*]] = llvm.intr.fmuladd(%[[VAL_0]], %{{.*}}, %[[VAL_2]]) : (vector<2xf32>, vector<2xf32>, vector<2xf32>) -> vector<2xf32>
// CHECK: %[[VAL_9:.*]] = arith.select %[[VAL_3]], %[[VAL_8]], %[[VAL_2]] : vector<2xi1>, vector<2xf32>

func.func @masked_float_add_outerprod_scalable(%arg0: vector<[2]xf32>, %arg1: f32, %arg2: vector<[2]xf32>, %m: vector<[2]xi1>) -> vector<[2]xf32> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file consistently uses %arg0 and %arg1, so renaming would be just extra noise. In the other files, we were mixing different styles, so I've renamed everything (to %vec) for consistency. But, IMHO, the actual name isn't as important as consistency 😅


// CHECK-LABEL: func.func @masked_float_add_outerprod_scalable
// CHECK-SAME: %[[VAL_0:.*]]: vector<[2]xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<[2]xf32>, %[[VAL_3:.*]]: vector<[2]xi1>) -> vector<[2]xf32> {
// CHECK: %[[VAL_8:.*]] = llvm.intr.fmuladd(%[[VAL_0]], %{{.*}}, %[[VAL_2]]) : (vector<[2]xf32>, vector<[2]xf32>, vector<[2]xf32>) -> vector<[2]xf32>
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's intentional and consistent with other tests :) I am only testing the key stuff (to reduce noise).

@nujaa nujaa merged commit 7e175b3 into llvm:main Aug 9, 2024
8 checks passed
@banach-space
Copy link
Contributor Author

@nujaa I have noticed that you "click"ed merge without accepting ;-) Could you formally "+1" as well?

Thanks for the review and all the support with this 🙏🏻 !

@banach-space banach-space deleted the andrzej/extend_vector_to_llvm_test_2 branch August 11, 2024 16:39
Copy link
Contributor

@nujaa nujaa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@nujaa
Copy link
Contributor

nujaa commented Aug 12, 2024

I thought It would not let me merge without Approval so I imagined it was approved 🤒

@banach-space
Copy link
Contributor Author

I thought It would not let me merge without Approval so I imagined it was approved 🤒

It's also OK to leave for the author to merge it ;-) In fact, I might not get any buildbot notifications otherwise 🤔 (perhaps that has changed)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants