Skip to content

Commit 415fc73

Browse files
authored
[Enhancement] Add tma bulk copy. (#600)
1 parent 29cae4e commit 415fc73

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed

src/tl_templates/cuda/copy_sm90.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,28 @@
1010

1111
namespace tl {
1212

13+
TL_DEVICE void tma_load(void *smem_ptr, void *gmem_ptr, uint64_t &smem_mbar,
14+
uint32_t size) {
15+
uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
16+
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
17+
asm volatile("cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::"
18+
"bytes [%0], [%1], %2, [%3]; \n" ::"r"(smem_int_ptr),
19+
"l"(gmem_ptr), "r"(size), "r"(smem_int_mbar)
20+
:);
21+
}
22+
23+
TL_DEVICE void tma_load_multicast(void *smem_ptr, void *gmem_ptr,
24+
uint64_t &smem_mbar, uint32_t size,
25+
uint16_t mask) {
26+
uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
27+
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
28+
asm volatile(
29+
"cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes."
30+
"multicast::cluster [%0], [%1], %2, [%3], %4; \n" ::"r"(smem_int_ptr),
31+
"l"(gmem_ptr), "r"(size), "r"(smem_int_mbar), "h"(mask)
32+
:);
33+
}
34+
1335
TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
1436
void const *const smem_ptr, int32_t const &crd0) {
1537
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
@@ -105,6 +127,15 @@ TL_DEVICE void tma_load_im2col(const CUtensorMap &descriptor,
105127
: "memory");
106128
}
107129

130+
TL_DEVICE void tma_store(void *dst_gmem_ptr, void *smem_ptr, uint32_t size) {
131+
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
132+
asm volatile(
133+
"cp.async.bulk.global.shared::cta.bulk_group [%1], [%0], %2; \n" ::"r"(
134+
smem_int_ptr),
135+
"l"(dst_gmem_ptr), "r"(size)
136+
:);
137+
}
138+
108139
TL_DEVICE void tma_store(const CUtensorMap &descriptor,
109140
void const *const smem_ptr, int32_t const &crd0) {
110141
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);

0 commit comments

Comments
 (0)