|
10 | 10 |
|
11 | 11 | namespace tl {
|
12 | 12 |
|
| 13 | +TL_DEVICE void tma_load(void *smem_ptr, void *gmem_ptr, uint64_t &smem_mbar, |
| 14 | + uint32_t size) { |
| 15 | + uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar); |
| 16 | + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); |
| 17 | + asm volatile("cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::" |
| 18 | + "bytes [%0], [%1], %2, [%3]; \n" ::"r"(smem_int_ptr), |
| 19 | + "l"(gmem_ptr), "r"(size), "r"(smem_int_mbar) |
| 20 | + :); |
| 21 | +} |
| 22 | + |
| 23 | +TL_DEVICE void tma_load_multicast(void *smem_ptr, void *gmem_ptr, |
| 24 | + uint64_t &smem_mbar, uint32_t size, |
| 25 | + uint16_t mask) { |
| 26 | + uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar); |
| 27 | + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); |
| 28 | + asm volatile( |
| 29 | + "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes." |
| 30 | + "multicast::cluster [%0], [%1], %2, [%3], %4; \n" ::"r"(smem_int_ptr), |
| 31 | + "l"(gmem_ptr), "r"(size), "r"(smem_int_mbar), "h"(mask) |
| 32 | + :); |
| 33 | +} |
| 34 | + |
13 | 35 | TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
|
14 | 36 | void const *const smem_ptr, int32_t const &crd0) {
|
15 | 37 | uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
|
@@ -105,6 +127,15 @@ TL_DEVICE void tma_load_im2col(const CUtensorMap &descriptor,
|
105 | 127 | : "memory");
|
106 | 128 | }
|
107 | 129 |
|
| 130 | +TL_DEVICE void tma_store(void *dst_gmem_ptr, void *smem_ptr, uint32_t size) { |
| 131 | + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); |
| 132 | + asm volatile( |
| 133 | + "cp.async.bulk.global.shared::cta.bulk_group [%1], [%0], %2; \n" ::"r"( |
| 134 | + smem_int_ptr), |
| 135 | + "l"(dst_gmem_ptr), "r"(size) |
| 136 | + :); |
| 137 | +} |
| 138 | + |
108 | 139 | TL_DEVICE void tma_store(const CUtensorMap &descriptor,
|
109 | 140 | void const *const smem_ptr, int32_t const &crd0) {
|
110 | 141 | uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
|
|
0 commit comments