3
3
// variable in order to cache the compiled artifacts and avoid recompiling too often.
4
4
use anyhow:: { Context , Result } ;
5
5
use rayon:: prelude:: * ;
6
+ use std:: fs;
6
7
use std:: path:: PathBuf ;
7
8
use std:: str:: FromStr ;
8
9
9
- const KERNEL_FILES : [ & str ; 4 ] = [
10
- "flash_api.cu" ,
11
- "fmha_fwd_hdim32.cu" ,
12
- "fmha_fwd_hdim64.cu" ,
13
- "fmha_fwd_hdim128.cu" ,
14
- ] ;
10
+ // const KERNEL_FILES: [&str; 4] = [
11
+ // "flash_api.cu",
12
+ // "fmha_fwd_hdim32.cu",
13
+ // "fmha_fwd_hdim64.cu",
14
+ // "fmha_fwd_hdim128.cu",
15
+ // ];
16
+
17
+ /// Recursively reads the filenames in a directory and stores them in a Vec.
18
+ fn _read_dir_recursively ( dir_path : & PathBuf , paths : & mut Vec < PathBuf > ) -> std:: io:: Result < ( ) > {
19
+ for entry in fs:: read_dir ( dir_path) ? {
20
+ let entry = entry?;
21
+ let path = entry. path ( ) ;
22
+
23
+ if path. is_dir ( ) {
24
+ _read_dir_recursively ( & path, paths) ?;
25
+ } else {
26
+ paths. push ( path) ;
27
+ }
28
+ }
29
+
30
+ Ok ( ( ) )
31
+ }
32
+
33
+ /// Recursively reads the filenames in a directory and stores them in a Vec.
34
+ fn read_dir_recursively ( dir_path : & PathBuf ) -> std:: io:: Result < Vec < PathBuf > > {
35
+ let mut paths = Vec :: new ( ) ;
36
+ _read_dir_recursively ( dir_path, & mut paths) ?;
37
+ Ok ( paths)
38
+ }
15
39
16
40
fn main ( ) -> Result < ( ) > {
17
41
let num_cpus = std:: env:: var ( "RAYON_NUM_THREADS" ) . map_or_else (
@@ -25,12 +49,11 @@ fn main() -> Result<()> {
25
49
. unwrap ( ) ;
26
50
27
51
println ! ( "cargo:rerun-if-changed=build.rs" ) ;
28
- for kernel_file in KERNEL_FILES . iter ( ) {
29
- println ! ( "cargo:rerun-if-changed=kernels/{kernel_file}" ) ;
52
+
53
+ let paths = read_dir_recursively ( & PathBuf :: from_str ( "kernels" ) ?) ?;
54
+ for file in paths. iter ( ) {
55
+ println ! ( "cargo:rerun-if-changed={}" , file. display( ) ) ;
30
56
}
31
- println ! ( "cargo:rerun-if-changed=kernels/**.h" ) ;
32
- println ! ( "cargo:rerun-if-changed=kernels/**.cuh" ) ;
33
- println ! ( "cargo:rerun-if-changed=kernels/fmha/**.h" ) ;
34
57
let out_dir = PathBuf :: from ( std:: env:: var ( "OUT_DIR" ) . context ( "OUT_DIR not set" ) ?) ;
35
58
let build_dir = match std:: env:: var ( "CANDLE_FLASH_ATTN_BUILD_DIR" ) {
36
59
Err ( _) =>
@@ -57,12 +80,17 @@ fn main() -> Result<()> {
57
80
let out_file = build_dir. join ( "libflashattentionv1.a" ) ;
58
81
59
82
let kernel_dir = PathBuf :: from ( "kernels" ) ;
60
- let cu_files: Vec < _ > = KERNEL_FILES
83
+ let kernels: Vec < _ > = paths
84
+ . iter ( )
85
+ . filter ( |f| f. extension ( ) . map ( |ext| ext == "cu" ) . unwrap_or_default ( ) )
86
+ . collect ( ) ;
87
+ let cu_files: Vec < _ > = kernels
61
88
. iter ( )
62
89
. map ( |f| {
63
90
let mut obj_file = out_dir. join ( f) ;
91
+ fs:: create_dir_all ( obj_file. parent ( ) . unwrap ( ) ) . unwrap ( ) ;
64
92
obj_file. set_extension ( "o" ) ;
65
- ( kernel_dir . join ( f ) , obj_file)
93
+ ( f , obj_file)
66
94
} )
67
95
. collect ( ) ;
68
96
let out_modified: Result < _ , _ > = out_file. metadata ( ) . and_then ( |m| m. modified ( ) ) ;
0 commit comments