Skip to content

Commit 7d622ca

Browse files
committed
WIP: working solution, but a lot of cruft. Clean up and unwind the
unneeded stuff
1 parent 3fd5949 commit 7d622ca

File tree

5 files changed

+128
-23
lines changed

5 files changed

+128
-23
lines changed

crates/cuda_std/src/rt/mod.rs

Lines changed: 94 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,15 @@ bitflags::bitflags! {
3131

3232
#[derive(Debug)]
3333
pub struct Stream {
34-
raw: cuda::cudaStream_t,
34+
pub raw: cuda::cudaStream_t,
3535
}
3636

3737
impl Stream {
38+
// /// Creates a new stream with flags.
39+
// pub fn new(flags: StreamFlags) -> Self {
40+
// Self {}
41+
// }
42+
3843
/// Creates a new stream with flags.
3944
pub fn new(flags: StreamFlags) -> CudaResult<Self> {
4045
let mut stream = MaybeUninit::uninit();
@@ -47,10 +52,11 @@ impl Stream {
4752
}
4853
}
4954

50-
#[doc(hidden)]
51-
pub fn launch(&self, param_buf: *mut c_void) -> CudaResult<()> {
52-
unsafe { cuda::cudaLaunchDeviceV2(param_buf, self.raw).to_result() }
53-
}
55+
// #[doc(hidden)]
56+
// pub fn launch(&self, param_buf: *mut c_void) -> CudaResult<()> {
57+
// unsafe { cuda::cudaLaunchDeviceV2(param_buf, core::ptr::null_mut()).to_result() }
58+
// // unsafe { cuda::cudaLaunchDeviceV2(param_buf, self.raw).to_result() }
59+
// }
5460
}
5561

5662
impl Drop for Stream {
@@ -63,13 +69,17 @@ impl Drop for Stream {
6369

6470
#[macro_export]
6571
macro_rules! launch {
66-
($func:ident<<<$grid_dim:expr, $block_dim:expr, $smem_size:expr, $stream:ident>>>($($param:expr),* $(,)?)) => {{
72+
// ($func:ident<<<$grid_dim:expr, $block_dim:expr, $smem_size:expr, $stream:ident>>>($($param:expr),* $(,)?)) => {{
73+
($func:ident<<<$grid_dim:expr, $block_dim:expr, ($smem_size:expr)>>>($($param:expr),* $(,)?)) => {{
6774
use $crate::rt::ToResult;
6875
use $crate::float::GpuFloat;
6976
let grid_dim = $crate::rt::GridSize::from($grid_dim);
7077
let block_dim = $crate::rt::BlockSize::from($block_dim);
78+
79+
// Get a device buffer for kernel launch.
80+
let fptr = $func as *const ();
7181
let mut buf = $crate::rt::sys::cudaGetParameterBufferV2(
72-
&$func as *const _ as *const ::core::ffi::c_void,
82+
fptr as *const ::core::ffi::c_void,
7383
$crate::rt::sys::dim3 {
7484
x: grid_dim.x,
7585
y: grid_dim.y,
@@ -80,24 +90,87 @@ macro_rules! launch {
8090
y: block_dim.y,
8191
z: block_dim.z
8292
},
83-
$smem_size
84-
) as *mut u8;
85-
unsafe {
86-
let mut offset = 0;
87-
$(
88-
let param = $param;
89-
let size = ::core::mem::size_of_val(&param);
90-
let mut buf_idx = (offset as f32 / size as f32).ceil() as usize + 1;
91-
offset = buf_idx * size;
92-
let ptr = &param as *const _ as *const u8;
93-
let dst = buf.add(offset);
94-
::core::ptr::copy_nonoverlapping(&param as *const _ as *const u8, dst, size);
95-
)*
93+
$smem_size,
94+
);
95+
96+
// Ensure buffer is not a nil ptr.
97+
if buf.is_null() {
98+
return;
9699
}
100+
101+
// Load data into buffer.
102+
let mut offset = 0;
103+
$(
104+
let param = $param;
105+
let size = ::core::mem::size_of_val(&param);
106+
let param_ptr = &param as *const _ as *const ::core::ffi::c_void;
107+
let dst = buf.add(offset).copy_from(param_ptr, size);
108+
offset += size;
109+
)*
97110
if false {
98111
$func($($param),*);
99112
}
100-
$stream.launch(buf as *mut ::core::ffi::c_void)
113+
// unsafe {
114+
// let mut offset = 0;
115+
// $(
116+
// let param = $param;
117+
// let size = ::core::mem::size_of_val(&param);
118+
// let mut buf_idx = (offset as f32 / size as f32).ceil() as usize + 1;
119+
// offset = buf_idx * size;
120+
// let ptr = &param as *const _ as *const u8;
121+
// let dst = buf.add(offset);
122+
// ::core::ptr::copy_nonoverlapping(&param as *const _ as *const u8, dst, size);
123+
// )*
124+
// }
125+
// if false {
126+
// $func($($param),*);
127+
// }
128+
129+
// Launch the kernel.
130+
$crate::rt::sys::cudaLaunchDeviceV2(buf as *mut ::core::ffi::c_void, ::core::ptr::null_mut() as *mut _)
131+
132+
// let mut buf = $crate::rt::sys::cudaGetParameterBuffer(alignment, size) as *mut u8;
133+
134+
// // Populate the buffer with given arguments.
135+
// let mut offset = 0;
136+
// $(
137+
// let param = $param;
138+
// let size = ::core::mem::size_of_val(&param);
139+
// let buf_bytes_ptr = (buf as *mut u8).add(offset);
140+
// ::core::ptr::copy_nonoverlapping($param as *const _, buf_bytes_ptr.into(), size);
141+
// offset += size;
142+
// )*
143+
144+
// let mut offset = 0;
145+
// $(
146+
// let param = $param;
147+
// let size = ::core::mem::size_of_val(&param);
148+
// let mut buf_idx = (offset as f32 / size as f32).ceil() as usize + 1;
149+
// offset = buf_idx * size;
150+
// let ptr = &param as *const _ as *const u8;
151+
// let dst = buf.add(offset);
152+
// ::core::ptr::copy_nonoverlapping(&param as *const _ as *const u8, dst, size);
153+
// )*
154+
155+
// // Launch the kernel.
156+
// let fptr = $func as *const ();
157+
// $crate::rt::sys::cudaLaunchDevice(
158+
// fptr as *const ::core::ffi::c_void,
159+
// buf as *mut ::core::ffi::c_void,
160+
// $crate::rt::sys::dim3 {
161+
// x: grid_dim.x,
162+
// y: grid_dim.y,
163+
// z: grid_dim.z
164+
// },
165+
// $crate::rt::sys::dim3 {
166+
// x: block_dim.x,
167+
// y: block_dim.y,
168+
// z: block_dim.z
169+
// },
170+
// $smem_size,
171+
// ::core::ptr::null_mut() as *mut _,
172+
// // $stream.raw,
173+
// )
101174
}};
102175
}
103176

crates/cuda_std/src/rt/sys.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,16 @@ pub use crate::rt::driver_types_sys::*;
1414
// to share this stuff with cust.
1515

1616
extern "C" {
17+
pub fn cudaGetParameterBuffer(alignment: usize, size: usize) -> *mut c_void;
18+
pub fn cudaLaunchDevice(
19+
func: *const c_void,
20+
parameterBuffer: *const c_void,
21+
gridDimension: dim3,
22+
blockDimension: dim3,
23+
sharedMemSize: c_uint,
24+
stream: cudaStream_t,
25+
) -> cudaError_t;
26+
1727
pub fn cudaDeviceGetAttribute(
1828
value: *mut c_int,
1929
attr: cudaDeviceAttr,

crates/cust/src/link.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,28 @@ impl Linker {
114114
}
115115
}
116116

117+
/// Link device runtime lib.
118+
pub fn add_libcudadevrt(&mut self) -> CudaResult<()> {
119+
let mut bytes = std::fs::read("/usr/local/cuda-11/lib64/libcudadevrt.a")
120+
.expect("could not read libcudadevrt.a");
121+
122+
unsafe {
123+
cuda::cuLinkAddData_v2(
124+
self.raw,
125+
cuda::CUjitInputType::CU_JIT_INPUT_LIBRARY,
126+
// cuda_sys wants *mut but from the API docs we know we retain ownership so
127+
// this cast is sound.
128+
bytes.as_mut_ptr() as *mut _,
129+
bytes.len(),
130+
UNNAMED.as_ptr().cast(),
131+
0,
132+
std::ptr::null_mut(),
133+
std::ptr::null_mut(),
134+
)
135+
.to_result()
136+
}
137+
}
138+
117139
/// Runs the linker to generate the final cubin bytes. Also returns a duration
118140
/// for how long it took to run the linker.
119141
pub fn complete(self) -> CudaResult<Vec<u8>> {

crates/cust/src/module.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ impl Module {
338338
/// ```
339339
#[deprecated(
340340
since = "0.3.0",
341-
note = "load_from_string was an inconsistent name with inconsistent params, use from_ptx/from_ptx_cstr, passing
341+
note = "load_from_string was an inconsistent name with inconsistent params, use from_ptx/from_ptx_cstr, passing
342342
an empty slice of options (usually)
343343
"
344344
)]

crates/rustc_codegen_nvvm/build.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ fn main() {
2121
// this is set by cuda_builder, but in case somebody is using the codegen
2222
// manually, default to 520 (which is what nvvm defaults to).
2323
if option_env!("CUDA_ARCH").is_none() {
24-
println!("cargo:rustc-env=CUDA_ARCH=520")
24+
println!("cargo:rustc-env=CUDA_ARCH=750")
2525
}
2626
}
2727

0 commit comments

Comments
 (0)