From 6ed4aaeb3ac97fc9fd151669ae7364720251d55d Mon Sep 17 00:00:00 2001 From: Tommy McMichen Date: Tue, 1 Jul 2025 17:07:37 -0700 Subject: [PATCH] [CIR] Improved `cir::CastOp` verifier to allow bitcasts between types of the same size. The `cir::CastOp::verify` method was overly conservative, and would fail on any `bitcast` from vector to scalar or scalar to vector. This diff extends the `cir::CastOp::verify` method to check if the source and result types are the same size using the `mlir::DataLayout` of the current scope, and succeeds if the sizes match. This diff also extends the CodeGen vectype tests with vector to scalar, scalar to vector and vector to vector conversions. This diff also extends the IR invalid tests with vector to scalar and scalar to vector conversions with different source and result sizes. --- clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 12 +++++++++++ clang/test/CIR/CodeGen/vectype.cpp | 22 ++++++++++++++----- clang/test/CIR/IR/invalid.cir | 28 +++++++++++++++++++++++++ 3 files changed, 57 insertions(+), 5 deletions(-) diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index d723fc8c7af9..9707f576c728 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -35,6 +35,7 @@ #include "mlir/IR/Location.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/Operation.h" #include "mlir/IR/StorageUniquerSupport.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/DataLayoutInterfaces.h" @@ -570,6 +571,17 @@ LogicalResult cir::CastOp::verify() { mlir::isa(resType)) return success(); + // Handle scalar to vector and vector to scalar conversions. + if (mlir::isa(getSrc().getType()) != + mlir::isa(getType())) { + // The source and result must be the same size. + mlir::DataLayout dataLayout( + getOperation()->getParentOfType()); + if (dataLayout.getTypeSize(getSrc().getType()) == + dataLayout.getTypeSize(getType())) + return success(); + } + // This is the only cast kind where we don't want vector types to decay // into the element type. if ((!mlir::isa(getSrc().getType()) || diff --git a/clang/test/CIR/CodeGen/vectype.cpp b/clang/test/CIR/CodeGen/vectype.cpp index e365b4bd7b12..13e5e54794d3 100644 --- a/clang/test/CIR/CodeGen/vectype.cpp +++ b/clang/test/CIR/CodeGen/vectype.cpp @@ -120,19 +120,31 @@ void vector_int_test(int x, unsigned short usx) { // Shifts vi4 w = a << b; - // CHECK: %{{[0-9]+}} = cir.shift(left, {{%.*}} : !cir.vector, + // CHECK: %{{[0-9]+}} = cir.shift(left, {{%.*}} : !cir.vector, // CHECK-SAME: {{%.*}} : !cir.vector) -> !cir.vector vi4 y = a >> b; - // CHECK: %{{[0-9]+}} = cir.shift(right, {{%.*}} : !cir.vector, + // CHECK: %{{[0-9]+}} = cir.shift(right, {{%.*}} : !cir.vector, // CHECK-SAME: {{%.*}} : !cir.vector) -> !cir.vector - vus2 z = { usx, usx }; + vus2 z = { usx, usx }; // CHECK: %{{[0-9]+}} = cir.vec.create(%{{[0-9]+}}, %{{[0-9]+}} : !u16i, !u16i) : !cir.vector vus2 zamt = { 3, 4 }; // CHECK: %{{[0-9]+}} = cir.const #cir.const_vector<[#cir.int<3> : !u16i, #cir.int<4> : !u16i]> : !cir.vector vus2 zzz = z >> zamt; - // CHECK: %{{[0-9]+}} = cir.shift(right, {{%.*}} : !cir.vector, - // CHECK-SAME: {{%.*}} : !cir.vector) -> !cir.vector + // CHECK: %{{[0-9]+}} = cir.shift(right, {{%.*}} : !cir.vector, + // CHECK-SAME: {{%.*}} : !cir.vector) -> !cir.vector + + // Vector to scalar conversion + unsigned int zi = (unsigned int)z; + // CHECK: %{{[0-9]+}} = cir.cast(bitcast, {{%.*}} : !cir.vector), !u32i + + // Scalar to vector conversion + vus2 zz = (vus2)zi; + // CHECK: %{{[0-9]+}} = cir.cast(bitcast, {{%.*}} : !u32i), !cir.vector + + // Vector to vector conversion + vll2 aaa = (vll2)a; + // CHECK: %{{[0-9]+}} = cir.cast(bitcast, {{%.*}} : !cir.vector), !cir.vector } void vector_double_test(int x, double y) { diff --git a/clang/test/CIR/IR/invalid.cir b/clang/test/CIR/IR/invalid.cir index 535a76552f2a..b03dffff46af 100644 --- a/clang/test/CIR/IR/invalid.cir +++ b/clang/test/CIR/IR/invalid.cir @@ -1405,7 +1405,35 @@ module { // expected-error@+1 {{'cir.cast' op result type address space does not match the address space of the operand}} %1 = cir.cast(bitcast, %0 : !cir.ptr), !cir.ptr } +} + +// ----- + +!s16i = !cir.int +!s64i = !cir.int +module { + cir.func @test_bitcast_vec2scalar_diff_size() { + %0 = cir.const #cir.int<1> : !s16i + %1 = cir.vec.create(%0, %0 : !s16i, !s16i) : !cir.vector + // expected-error@+1 {{'cir.cast' op requires !cir.ptr or !cir.vector type for source and result}} + %2 = cir.cast(bitcast, %1 : !cir.vector), !s64i + cir.return + } +} + +// ----- + +!s32i = !cir.int +!s64i = !cir.int + +module { + cir.func @test_bitcast_scalar2vec_diff_size() { + %0 = cir.const #cir.int<1> : !s64i + // expected-error@+1 {{'cir.cast' op requires !cir.ptr or !cir.vector type for source and result}} + %1 = cir.cast(bitcast, %0 : !s64i), !cir.vector + cir.return + } } // -----