diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index d723fc8c7af9..9707f576c728 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -35,6 +35,7 @@ #include "mlir/IR/Location.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/Operation.h" #include "mlir/IR/StorageUniquerSupport.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/DataLayoutInterfaces.h" @@ -570,6 +571,17 @@ LogicalResult cir::CastOp::verify() { mlir::isa(resType)) return success(); + // Handle scalar to vector and vector to scalar conversions. + if (mlir::isa(getSrc().getType()) != + mlir::isa(getType())) { + // The source and result must be the same size. + mlir::DataLayout dataLayout( + getOperation()->getParentOfType()); + if (dataLayout.getTypeSize(getSrc().getType()) == + dataLayout.getTypeSize(getType())) + return success(); + } + // This is the only cast kind where we don't want vector types to decay // into the element type. if ((!mlir::isa(getSrc().getType()) || diff --git a/clang/test/CIR/CodeGen/vectype.cpp b/clang/test/CIR/CodeGen/vectype.cpp index e365b4bd7b12..13e5e54794d3 100644 --- a/clang/test/CIR/CodeGen/vectype.cpp +++ b/clang/test/CIR/CodeGen/vectype.cpp @@ -120,19 +120,31 @@ void vector_int_test(int x, unsigned short usx) { // Shifts vi4 w = a << b; - // CHECK: %{{[0-9]+}} = cir.shift(left, {{%.*}} : !cir.vector, + // CHECK: %{{[0-9]+}} = cir.shift(left, {{%.*}} : !cir.vector, // CHECK-SAME: {{%.*}} : !cir.vector) -> !cir.vector vi4 y = a >> b; - // CHECK: %{{[0-9]+}} = cir.shift(right, {{%.*}} : !cir.vector, + // CHECK: %{{[0-9]+}} = cir.shift(right, {{%.*}} : !cir.vector, // CHECK-SAME: {{%.*}} : !cir.vector) -> !cir.vector - vus2 z = { usx, usx }; + vus2 z = { usx, usx }; // CHECK: %{{[0-9]+}} = cir.vec.create(%{{[0-9]+}}, %{{[0-9]+}} : !u16i, !u16i) : !cir.vector vus2 zamt = { 3, 4 }; // CHECK: %{{[0-9]+}} = cir.const #cir.const_vector<[#cir.int<3> : !u16i, #cir.int<4> : !u16i]> : !cir.vector vus2 zzz = z >> zamt; - // CHECK: %{{[0-9]+}} = cir.shift(right, {{%.*}} : !cir.vector, - // CHECK-SAME: {{%.*}} : !cir.vector) -> !cir.vector + // CHECK: %{{[0-9]+}} = cir.shift(right, {{%.*}} : !cir.vector, + // CHECK-SAME: {{%.*}} : !cir.vector) -> !cir.vector + + // Vector to scalar conversion + unsigned int zi = (unsigned int)z; + // CHECK: %{{[0-9]+}} = cir.cast(bitcast, {{%.*}} : !cir.vector), !u32i + + // Scalar to vector conversion + vus2 zz = (vus2)zi; + // CHECK: %{{[0-9]+}} = cir.cast(bitcast, {{%.*}} : !u32i), !cir.vector + + // Vector to vector conversion + vll2 aaa = (vll2)a; + // CHECK: %{{[0-9]+}} = cir.cast(bitcast, {{%.*}} : !cir.vector), !cir.vector } void vector_double_test(int x, double y) { diff --git a/clang/test/CIR/IR/invalid.cir b/clang/test/CIR/IR/invalid.cir index 535a76552f2a..b03dffff46af 100644 --- a/clang/test/CIR/IR/invalid.cir +++ b/clang/test/CIR/IR/invalid.cir @@ -1405,7 +1405,35 @@ module { // expected-error@+1 {{'cir.cast' op result type address space does not match the address space of the operand}} %1 = cir.cast(bitcast, %0 : !cir.ptr), !cir.ptr } +} + +// ----- + +!s16i = !cir.int +!s64i = !cir.int +module { + cir.func @test_bitcast_vec2scalar_diff_size() { + %0 = cir.const #cir.int<1> : !s16i + %1 = cir.vec.create(%0, %0 : !s16i, !s16i) : !cir.vector + // expected-error@+1 {{'cir.cast' op requires !cir.ptr or !cir.vector type for source and result}} + %2 = cir.cast(bitcast, %1 : !cir.vector), !s64i + cir.return + } +} + +// ----- + +!s32i = !cir.int +!s64i = !cir.int + +module { + cir.func @test_bitcast_scalar2vec_diff_size() { + %0 = cir.const #cir.int<1> : !s64i + // expected-error@+1 {{'cir.cast' op requires !cir.ptr or !cir.vector type for source and result}} + %1 = cir.cast(bitcast, %0 : !s64i), !cir.vector + cir.return + } } // -----