Skip to content

Commit 2384f28

Browse files
committed
Add autocasts for bf16 and bf16xN
1 parent 46f7571 commit 2384f28

File tree

5 files changed

+37
-6
lines changed

5 files changed

+37
-6
lines changed

compiler/rustc_codegen_llvm/src/abi.rs

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,8 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
376376
}
377377

378378
match self.type_kind(llvm_ty) {
379+
TypeKind::BFloat => rust_ty == self.type_i16(),
380+
379381
// Some LLVM intrinsics return **non-packed** structs, but they can't be mimicked from Rust
380382
// due to auto field-alignment in non-packed structs (packed structs are represented in LLVM
381383
// as, well, packed structs, so they won't match with those either)
@@ -393,11 +395,18 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
393395
},
394396
)
395397
}
396-
TypeKind::Vector if self.element_type(llvm_ty) == self.type_i1() => {
398+
TypeKind::Vector => {
397399
let element_count = self.vector_length(llvm_ty) as u64;
398-
let int_width = element_count.next_power_of_two().max(8);
400+
let llvm_element_ty = self.element_type(llvm_ty);
399401

400-
rust_ty == self.type_ix(int_width)
402+
if llvm_element_ty == self.type_bf16() {
403+
rust_ty == self.type_vector(self.type_i16(), element_count)
404+
} else if llvm_element_ty == self.type_i1() {
405+
let int_width = element_count.next_power_of_two().max(8);
406+
rust_ty == self.type_ix(int_width)
407+
} else {
408+
false
409+
}
401410
}
402411
_ => false,
403412
}

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1755,7 +1755,7 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
17551755
self.zext_i1_vector_to_int(val, src_ty, dest_ty)
17561756
}
17571757
}
1758-
_ => unreachable!(),
1758+
_ => self.bitcast(val, dest_ty), // for `bf16(xN)` <-> `u16(xN)`
17591759
}
17601760
}
17611761

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1101,6 +1101,9 @@ unsafe extern "C" {
11011101
pub(crate) fn LLVMDoubleTypeInContext(C: &Context) -> &Type;
11021102
pub(crate) fn LLVMFP128TypeInContext(C: &Context) -> &Type;
11031103

1104+
// Operations on non-IEEE real types
1105+
pub(crate) fn LLVMBFloatTypeInContext(C: &Context) -> &Type;
1106+
11041107
// Operations on function types
11051108
pub(crate) fn LLVMFunctionType<'a>(
11061109
ReturnType: &'a Type,

compiler/rustc_codegen_llvm/src/type_.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,10 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
176176
)
177177
}
178178
}
179+
180+
/// Returns the LLVM `bfloat` type obtained from this codegen context's LLVM context.
pub(crate) fn type_bf16(&self) -> &'ll Type {
181+
// Thin wrapper over the LLVM-C constructor, called with this context's `llcx()`.
// NOTE(review): presumably sound because `llcx()` yields a live context reference
// for `'ll` — confirm against the other `type_*` helpers in this impl.
unsafe { llvm::LLVMBFloatTypeInContext(self.llcx()) }
182+
}
179183
}
180184

181185
impl<'ll, CX: Borrow<SCx<'ll>>> BaseTypeCodegenMethods for GenericCx<'ll, CX> {
@@ -249,7 +253,7 @@ impl<'ll, CX: Borrow<SCx<'ll>>> BaseTypeCodegenMethods for GenericCx<'ll, CX> {
249253

250254
fn float_width(&self, ty: &'ll Type) -> usize {
251255
match self.type_kind(ty) {
252-
TypeKind::Half => 16,
256+
TypeKind::Half | TypeKind::BFloat => 16,
253257
TypeKind::Float => 32,
254258
TypeKind::Double => 64,
255259
TypeKind::X86_FP80 => 80,

tests/codegen-llvm/inject-autocast.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#![feature(link_llvm_intrinsics, abi_unadjusted, repr_simd, simd_ffi, portable_simd, f16)]
55
#![crate_type = "lib"]
66

7-
use std::simd::i64x2;
7+
use std::simd::{f32x4, i16x8, i64x2};
88

99
#[repr(simd)]
1010
pub struct Tile([i8; 1024]);
@@ -36,6 +36,19 @@ pub unsafe fn struct_with_i1_vector_autocast(a: i64x2, b: i64x2) -> (u8, u8) {
3636
foo(a, b)
3737
}
3838

39+
// CHECK-LABEL: @bf16_vector_autocast
// Exercises the bf16-vector autocast: the intrinsic is declared on the Rust side as
// returning `i16x8`, while the directives in the body expect the emitted call to
// return `<8 x bfloat>` and the result to be bitcast to `<8 x i16>`.
40+
#[no_mangle]
41+
pub unsafe fn bf16_vector_autocast(a: f32x4) -> i16x8 {
42+
// Bound by link name to the LLVM intrinsic; "unadjusted" is the ABI used for the
// declaration.
extern "unadjusted" {
43+
#[link_name = "llvm.x86.vcvtneps2bf16128"]
44+
fn foo(a: f32x4) -> i16x8;
45+
}
46+
47+
// CHECK: %1 = call <8 x bfloat> @llvm.x86.vcvtneps2bf16128(<4 x float> %0)
48+
// CHECK-NEXT: %2 = bitcast <8 x bfloat> %1 to <8 x i16>
49+
foo(a)
50+
}
51+
3952
// CHECK-LABEL: @struct_autocast
4053
#[no_mangle]
4154
pub unsafe fn struct_autocast(key_metadata: u32, key: i64x2) -> Bar {
@@ -77,6 +90,8 @@ pub unsafe fn i1_vector_autocast(a: f16x8) -> u8 {
7790

7891
// CHECK: declare { <2 x i1>, <2 x i1> } @llvm.x86.avx512.vp2intersect.q.128(<2 x i64>, <2 x i64>)
7992

93+
// CHECK: declare <8 x bfloat> @llvm.x86.vcvtneps2bf16128(<4 x float>)
94+
8095
// CHECK: declare { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } @llvm.x86.encodekey128(i32, <2 x i64>)
8196

8297
// CHECK: declare <8 x i1> @llvm.x86.avx512fp16.fpclass.ph.128(<8 x half>, i32 immarg)

0 commit comments

Comments
 (0)