From 14103cbf7dce52b46b9324b1d896a07a905a99d0 Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Thu, 21 Mar 2024 19:01:12 +0000 Subject: [PATCH] [mlir][arith] Refine the verifier for arith.constant Disallows initialization of scalable vectors with an attribute of arbitrary values, e.g.: ```mlir %c = arith.constant dense<[0, 1]> : vector<[2] x i32> ``` Initialization using vector splats remains allowed (i.e. when all the init values are identical): ```mlir %c = arith.constant dense<[1, 1]> : vector<[2] x i32> ``` Note: This is a re-upload of #86178 --- mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 9 +++++++++ mlir/test/Dialect/Arith/invalid.mlir | 18 ++++++++++++++++++ mlir/test/Dialect/Vector/linearize.mlir | 11 ----------- 3 files changed, 27 insertions(+), 11 deletions(-) diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 1d68a4f7292b5..6f995b93bc3ec 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -213,6 +213,15 @@ LogicalResult arith::ConstantOp::verify() { return emitOpError( "value must be an integer, float, or elements attribute"); } + + // Note, we could relax this for vectors with 1 scalable dim, e.g.: + // * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32> + // However, this would most likely require updating the lowerings to LLVM. + auto vecType = dyn_cast(type); + if (vecType && vecType.isScalable() && !isa(getValue())) + return emitOpError( + "intializing scalable vectors with elements attribute is not supported" + " unless it's a vector splat"); return success(); } diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir index 6d8ac0ada52be..ada849220bb83 100644 --- a/mlir/test/Dialect/Arith/invalid.mlir +++ b/mlir/test/Dialect/Arith/invalid.mlir @@ -64,6 +64,24 @@ func.func @constant_out_of_range() { // ----- +func.func @constant_invalid_scalable_1d_vec_initialization() { +^bb0: + // expected-error@+1 {{'arith.constant' op intializing scalable vectors with elements attribute is not supported unless it's a vector splat}} + %c = arith.constant dense<[0, 1]> : vector<[2] x i32> + return +} + +// ----- + +func.func @constant_invalid_scalable_2d_vec_initialization() { +^bb0: + // expected-error@+1 {{'arith.constant' op intializing scalable vectors with elements attribute is not supported unless it's a vector splat}} + %c = arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32> + return +} + +// ----- + func.func @constant_wrong_type() { ^bb: %x = "arith.constant"(){value = 10.} : () -> f32 // expected-error {{'arith.constant' op failed to verify that all of {value, result} have same type}} diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index 212541c79565b..22be78cd68205 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -153,14 +153,3 @@ func.func @test_0d_vector() -> vector { // ALL: return %[[CST]] return %0 : vector } - -// ----- - -func.func @test_scalable_no_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32> { - // expected-error@+1 {{failed to legalize operation 'arith.constant' that was explicitly marked illegal}} - %0 = arith.constant dense<[[1., 1.], [3., 3.]]> : vector<2x[2]xf32> - %1 = math.sin %arg0 : vector<2x[2]xf32> - %2 = arith.addf %0, %1 : vector<2x[2]xf32> - - return %2 : vector<2x[2]xf32> -}