Skip to content

Commit dad733e

Browse files
committed
added typetree support for memcpy
1 parent bd70be1 commit dad733e

File tree

20 files changed

+131
-31
lines changed

20 files changed

+131
-31
lines changed

compiler/rustc_codegen_gcc/src/builder.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1383,6 +1383,7 @@ impl<'a, 'gcc, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'gcc, 'tcx> {
13831383
_src_align: Align,
13841384
size: RValue<'gcc>,
13851385
flags: MemFlags,
1386+
_tt: Option<rustc_ast::expand::typetree::FncTree>, // Autodiff TypeTrees are LLVM-only, ignored in GCC backend
13861387
) {
13871388
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
13881389
let size = self.intcast(size, self.type_size_t(), false);

compiler/rustc_codegen_gcc/src/intrinsic/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -771,6 +771,7 @@ impl<'gcc, 'tcx> ArgAbiExt<'gcc, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
771771
scratch_align,
772772
bx.const_usize(self.layout.size.bytes()),
773773
MemFlags::empty(),
774+
None,
774775
);
775776

776777
bx.lifetime_end(scratch, scratch_size);

compiler/rustc_codegen_llvm/src/abi.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ impl<'ll, 'tcx> ArgAbiExt<'ll, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
244244
scratch_align,
245245
bx.const_usize(copy_bytes),
246246
MemFlags::empty(),
247+
None,
247248
);
248249
bx.lifetime_end(llscratch, scratch_size);
249250
}

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::borrow::{Borrow, Cow};
22
use std::ops::Deref;
33
use std::{iter, ptr};
44

5+
use rustc_ast::expand::typetree::FncTree;
56
pub(crate) mod autodiff;
67
pub(crate) mod gpu_offload;
78

@@ -1118,11 +1119,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
11181119
src_align: Align,
11191120
size: &'ll Value,
11201121
flags: MemFlags,
1122+
tt: Option<FncTree>,
11211123
) {
11221124
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
11231125
let size = self.intcast(size, self.type_isize(), false);
11241126
let is_volatile = flags.contains(MemFlags::VOLATILE);
1125-
unsafe {
1127+
let memcpy = unsafe {
11261128
llvm::LLVMRustBuildMemCpy(
11271129
self.llbuilder,
11281130
dst,
@@ -1131,7 +1133,16 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
11311133
src_align.bytes() as c_uint,
11321134
size,
11331135
is_volatile,
1134-
);
1136+
)
1137+
};
1138+
1139+
// TypeTree metadata for memcpy is especially important: when Enzyme encounters
1140+
// a memcpy during autodiff, it needs to know the structure of the data being
1141+
// copied to properly track derivatives. For example, copying an array of floats
1142+
// vs. copying a struct with mixed types requires different derivative handling.
1143+
// The TypeTree tells Enzyme exactly what memory layout to expect.
1144+
if let Some(tt) = tt {
1145+
crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memcpy, tt);
11351146
}
11361147
}
11371148

compiler/rustc_codegen_llvm/src/typetree.rs

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
#[cfg(llvm_enzyme)]
12
use std::ffi::{CString, c_char, c_uint};
23

3-
use rustc_ast::expand::typetree::{FncTree, TypeTree as RustTypeTree};
4+
use rustc_ast::expand::typetree::FncTree;
5+
#[cfg(llvm_enzyme)]
6+
use rustc_ast::expand::typetree::TypeTree as RustTypeTree;
47

58
use crate::attributes;
69
use crate::llvm::{self, Value};
@@ -50,15 +53,6 @@ fn to_enzyme_typetree(
5053
enzyme_tt
5154
}
5255

53-
#[cfg(not(llvm_enzyme))]
54-
fn to_enzyme_typetree(
55-
_rust_typetree: RustTypeTree,
56-
_data_layout: &str,
57-
_llcx: &llvm::Context,
58-
) -> ! {
59-
unimplemented!("TypeTree conversion not available without llvm_enzyme support")
60-
}
61-
6256
// Attaches TypeTree information to LLVM function as enzyme_type attributes.
6357
#[cfg(llvm_enzyme)]
6458
pub(crate) fn add_tt<'ll>(

compiler/rustc_codegen_llvm/src/va_arg.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -735,6 +735,7 @@ fn copy_to_temporary_if_more_aligned<'ll, 'tcx>(
735735
src_align,
736736
bx.const_u32(layout.layout.size().bytes() as u32),
737737
MemFlags::empty(),
738+
None,
738739
);
739740
tmp
740741
} else {

compiler/rustc_codegen_ssa/src/mir/block.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1623,6 +1623,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
16231623
align,
16241624
bx.const_usize(copy_bytes),
16251625
MemFlags::empty(),
1626+
None,
16261627
);
16271628
// ...and then load it with the ABI type.
16281629
llval = load_cast(bx, cast, llscratch, scratch_align);

compiler/rustc_codegen_ssa/src/mir/intrinsic.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ fn copy_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
3030
if allow_overlap {
3131
bx.memmove(dst, align, src, align, size, flags);
3232
} else {
33-
bx.memcpy(dst, align, src, align, size, flags);
33+
bx.memcpy(dst, align, src, align, size, flags, None);
3434
}
3535
}
3636

compiler/rustc_codegen_ssa/src/mir/statement.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
9090
let align = pointee_layout.align;
9191
let dst = dst_val.immediate();
9292
let src = src_val.immediate();
93-
bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty());
93+
bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty(), None);
9494
}
9595
mir::StatementKind::FakeRead(..)
9696
| mir::StatementKind::Retag { .. }

compiler/rustc_codegen_ssa/src/traits/builder.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,7 @@ pub trait BuilderMethods<'a, 'tcx>:
424424
src_align: Align,
425425
size: Self::Value,
426426
flags: MemFlags,
427+
tt: Option<rustc_ast::expand::typetree::FncTree>,
427428
);
428429
fn memmove(
429430
&mut self,
@@ -480,7 +481,7 @@ pub trait BuilderMethods<'a, 'tcx>:
480481
temp.val.store_with_flags(self, dst.with_type(layout), flags);
481482
} else if !layout.is_zst() {
482483
let bytes = self.const_usize(layout.size.bytes());
483-
self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags);
484+
self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags, None);
484485
}
485486
}
486487

0 commit comments

Comments
 (0)