From cb7f52b667926c1485923e9839237a8419bcde36 Mon Sep 17 00:00:00 2001 From: Guojin He Date: Fri, 15 Nov 2024 15:05:45 -0500 Subject: [PATCH] Extend UnaryFPToFPBuiltinOp to vector of FP type --- clang/include/clang/CIR/Dialect/IR/CIROps.td | 4 +- .../include/clang/CIR/Dialect/IR/CIRTypes.td | 10 +++ .../CIR/Lowering/builtin-floating-point.cir | 86 ++++++++++++++++++- 3 files changed, 97 insertions(+), 3 deletions(-) diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index bd2f34dbfaaf..e6159d1474be 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -4342,8 +4342,8 @@ def LLrintOp : UnaryFPToIntBuiltinOp<"llrint", "LlrintOp">; class UnaryFPToFPBuiltinOp : CIR_Op { - let arguments = (ins CIR_AnyFloat:$src); - let results = (outs CIR_AnyFloat:$result); + let arguments = (ins CIR_AnyFloatOrVecOfFloat:$src); + let results = (outs CIR_AnyFloatOrVecOfFloat:$result); let summary = "libc builtin equivalent ignoring " "floating point exceptions and errno"; let assemblyFormat = "$src `:` type($src) attr-dict"; diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td index f73d80402047..c805b6887cf3 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td +++ b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td @@ -553,10 +553,20 @@ def SignedIntegerVector : Type< ]>, "!cir.vector of !cir.int"> { } +// Vector of Float type +def FPVector : Type< + And<[ + CPred<"::mlir::isa<::cir::VectorType>($_self)">, + CPred<"::mlir::isa<::cir::SingleType, ::cir::DoubleType>(" + "::mlir::cast<::cir::VectorType>($_self).getEltType())">, + ]>, "!cir.vector of !cir.fp"> { +} + // Constraints def CIR_AnyIntOrVecOfInt: AnyTypeOf<[CIR_IntType, IntegerVector]>; def CIR_AnySignedIntOrVecOfSignedInt: AnyTypeOf< [PrimitiveSInt, SignedIntegerVector]>; +def CIR_AnyFloatOrVecOfFloat: AnyTypeOf<[CIR_AnyFloat, FPVector]>; // Pointer to Arrays def ArrayPtr : Type< diff --git a/clang/test/CIR/Lowering/builtin-floating-point.cir b/clang/test/CIR/Lowering/builtin-floating-point.cir index 82b733233da3..157b3abe10f5 100644 --- a/clang/test/CIR/Lowering/builtin-floating-point.cir +++ b/clang/test/CIR/Lowering/builtin-floating-point.cir @@ -2,49 +2,133 @@ // RUN: FileCheck --input-file=%t.ll %s module { - cir.func @test(%arg0 : !cir.float) { + cir.func @test(%arg0 : !cir.float, %arg1 : !cir.vector, %arg2 : !cir.vector) { %1 = cir.cos %arg0 : !cir.float // CHECK: llvm.intr.cos(%arg0) : (f32) -> f32 + + %101 = cir.cos %arg1 : !cir.vector + // CHECK: llvm.intr.cos(%arg1) : (vector<2xf64>) -> vector<2xf64> + %201 = cir.cos %arg2 : !cir.vector + // CHECK: llvm.intr.cos(%arg2) : (vector<4xf32>) -> vector<4xf32> + %2 = cir.ceil %arg0 : !cir.float // CHECK: llvm.intr.ceil(%arg0) : (f32) -> f32 + %102 = cir.ceil %arg1 : !cir.vector + // CHECK: llvm.intr.ceil(%arg1) : (vector<2xf64>) -> vector<2xf64> + + %202 = cir.ceil %arg2 : !cir.vector + // CHECK: llvm.intr.ceil(%arg2) : (vector<4xf32>) -> vector<4xf32> + %3 = cir.exp %arg0 : !cir.float // CHECK: llvm.intr.exp(%arg0) : (f32) -> f32 + %103 = cir.exp %arg1 : !cir.vector + // CHECK: llvm.intr.exp(%arg1) : (vector<2xf64>) -> vector<2xf64> + + %203 = cir.exp %arg2 : !cir.vector + // CHECK: llvm.intr.exp(%arg2) : (vector<4xf32>) -> vector<4xf32> + %4 = cir.exp2 %arg0 : !cir.float // CHECK: llvm.intr.exp2(%arg0) : (f32) -> f32 + %104 = cir.exp2 %arg1 : !cir.vector + // CHECK: llvm.intr.exp2(%arg1) : (vector<2xf64>) -> vector<2xf64> + + %204 = cir.exp2 %arg2 : !cir.vector + // CHECK: llvm.intr.exp2(%arg2) : (vector<4xf32>) -> vector<4xf32> + %5 = cir.fabs %arg0 : !cir.float // CHECK: llvm.intr.fabs(%arg0) : (f32) -> f32 + %105 = cir.fabs %arg1 : !cir.vector + // CHECK: llvm.intr.fabs(%arg1) : (vector<2xf64>) -> vector<2xf64> + + %205 = cir.fabs %arg2 : !cir.vector + // CHECK: llvm.intr.fabs(%arg2) : (vector<4xf32>) -> vector<4xf32> + %6 = cir.floor %arg0 : !cir.float // CHECK: llvm.intr.floor(%arg0) : (f32) -> f32 + %106 = cir.floor %arg1 : !cir.vector + // CHECK: llvm.intr.floor(%arg1) : (vector<2xf64>) -> vector<2xf64> + + %206 = cir.floor %arg2 : !cir.vector + // CHECK: llvm.intr.floor(%arg2) : (vector<4xf32>) -> vector<4xf32> + %7 = cir.log %arg0 : !cir.float // CHECK: llvm.intr.log(%arg0) : (f32) -> f32 + %107 = cir.log %arg1 : !cir.vector + // CHECK: llvm.intr.log(%arg1) : (vector<2xf64>) -> vector<2xf64> + + %207 = cir.log %arg2 : !cir.vector + // CHECK: llvm.intr.log(%arg2) : (vector<4xf32>) -> vector<4xf32> + %8 = cir.log10 %arg0 : !cir.float // CHECK: llvm.intr.log10(%arg0) : (f32) -> f32 + %108 = cir.log10 %arg1 : !cir.vector + // CHECK: llvm.intr.log10(%arg1) : (vector<2xf64>) -> vector<2xf64> + + %208 = cir.log10 %arg2 : !cir.vector + // CHECK: llvm.intr.log10(%arg2) : (vector<4xf32>) -> vector<4xf32> + %9 = cir.log2 %arg0 : !cir.float // CHECK: llvm.intr.log2(%arg0) : (f32) -> f32 + %109 = cir.log2 %arg1 : !cir.vector + // CHECK: llvm.intr.log2(%arg1) : (vector<2xf64>) -> vector<2xf64> + + %209 = cir.log2 %arg2 : !cir.vector + // CHECK: llvm.intr.log2(%arg2) : (vector<4xf32>) -> vector<4xf32> + %10 = cir.nearbyint %arg0 : !cir.float // CHECK: llvm.intr.nearbyint(%arg0) : (f32) -> f32 + %110 = cir.nearbyint %arg1 : !cir.vector + // CHECK: llvm.intr.nearbyint(%arg1) : (vector<2xf64>) -> vector<2xf64> + + %210 = cir.nearbyint %arg2 : !cir.vector + // CHECK: llvm.intr.nearbyint(%arg2) : (vector<4xf32>) -> vector<4xf32> + %11 = cir.rint %arg0 : !cir.float // CHECK: llvm.intr.rint(%arg0) : (f32) -> f32 + %111 = cir.rint %arg1 : !cir.vector + // CHECK: llvm.intr.rint(%arg1) : (vector<2xf64>) -> vector<2xf64> + + %211 = cir.rint %arg2 : !cir.vector + // CHECK: llvm.intr.rint(%arg2) : (vector<4xf32>) -> vector<4xf32> + %12 = cir.round %arg0 : !cir.float // CHECK: llvm.intr.round(%arg0) : (f32) -> f32 + %112 = cir.round %arg1 : !cir.vector + // CHECK: llvm.intr.round(%arg1) : (vector<2xf64>) -> vector<2xf64> + + %212 = cir.round %arg2 : !cir.vector + // CHECK: llvm.intr.round(%arg2) : (vector<4xf32>) -> vector<4xf32> + %13 = cir.sin %arg0 : !cir.float // CHECK: llvm.intr.sin(%arg0) : (f32) -> f32 + %113 = cir.sin %arg1 : !cir.vector + // CHECK: llvm.intr.sin(%arg1) : (vector<2xf64>) -> vector<2xf64> + + %213 = cir.sin %arg2 : !cir.vector + // CHECK: llvm.intr.sin(%arg2) : (vector<4xf32>) -> vector<4xf32> + %14 = cir.sqrt %arg0 : !cir.float // CHECK: llvm.intr.sqrt(%arg0) : (f32) -> f32 + %114 = cir.sqrt %arg1 : !cir.vector + // CHECK: llvm.intr.sqrt(%arg1) : (vector<2xf64>) -> vector<2xf64> + + %214 = cir.sqrt %arg2 : !cir.vector + // CHECK: llvm.intr.sqrt(%arg2) : (vector<4xf32>) -> vector<4xf32> + cir.return } }