-
Notifications
You must be signed in to change notification settings - Fork 254
[Enhancement] Add tma bulk copy. #600
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,28 @@ | |
|
||
namespace tl { | ||
|
||
TL_DEVICE void tma_load(void *smem_ptr, void *gmem_ptr, uint64_t &smem_mbar, | ||
uint32_t size) { | ||
uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar); | ||
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); | ||
asm volatile("cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::" | ||
"bytes [%0], [%1], %2, [%3]; \n" ::"r"(smem_int_ptr), | ||
"l"(gmem_ptr), "r"(size), "r"(smem_int_mbar) | ||
:); | ||
} | ||
|
||
TL_DEVICE void tma_load_multicast(void *smem_ptr, void *gmem_ptr, | ||
uint64_t &smem_mbar, uint32_t size, | ||
uint16_t mask) { | ||
uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar); | ||
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); | ||
asm volatile( | ||
"cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes." | ||
"multicast::cluster [%0], [%1], %2, [%3], %4; \n" ::"r"(smem_int_ptr), | ||
"l"(gmem_ptr), "r"(size), "r"(smem_int_mbar), "h"(mask) | ||
:); | ||
Comment on lines
+23
to
+32
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to the previous comment, it's better to be explicit about the memory spaces involved by casting the Consider using TL_DEVICE void tma_load_multicast(void *smem_ptr, void *gmem_ptr,
uint64_t &smem_mbar, uint32_t size,
uint16_t mask) {
uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
uintptr_t smem_ptr_int = reinterpret_cast<uintptr_t>(smem_ptr); // Cast to uintptr_t
asm volatile(
"cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes."
"multicast::cluster [%0], [%1], %2, [%3], %4; \n" ::"r"(smem_ptr_int),
"l"(gmem_ptr), "r"(size), "r"(smem_int_mbar), "h"(mask)
:);
} |
||
} | ||
|
||
TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar, | ||
void const *const smem_ptr, int32_t const &crd0) { | ||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor); | ||
|
@@ -105,6 +127,15 @@ TL_DEVICE void tma_load_im2col(const CUtensorMap &descriptor, | |
: "memory"); | ||
} | ||
|
||
TL_DEVICE void tma_store(void *dst_gmem_ptr, void *smem_ptr, uint32_t size) { | ||
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); | ||
asm volatile( | ||
"cp.async.bulk.global.shared::cta.bulk_group [%1], [%0], %2; \n" ::"r"( | ||
smem_int_ptr), | ||
"l"(dst_gmem_ptr), "r"(size) | ||
:); | ||
Comment on lines
+130
to
+136
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to the previous comments, it's better to be explicit about the memory spaces involved by casting the Consider using TL_DEVICE void tma_store(void *dst_gmem_ptr, void *smem_ptr, uint32_t size) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
uintptr_t dst_gmem_ptr_int = reinterpret_cast<uintptr_t>(dst_gmem_ptr); // Cast to uintptr_t
asm volatile(
"cp.async.bulk.global.shared::cta.bulk_group [%1], [%0], %2; \n" ::"r"(
smem_int_ptr),
"l"(dst_gmem_ptr_int), "r"(size)
:);
} |
||
} | ||
|
||
TL_DEVICE void tma_store(const CUtensorMap &descriptor, | ||
void const *const smem_ptr, int32_t const &crd0) { | ||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The inline assembly uses raw pointers, but the function signature uses
void*
. It's better to be explicit about the memory spaces involved by casting thevoid*
touintptr_t
and then casting to the appropriate pointer type within the assembly string. This will improve readability and maintainability, and reduce the risk of unintended pointer arithmetic or type mismatches.Consider using
reinterpret_cast<uintptr_t>(smem_ptr)
andreinterpret_cast<uintptr_t>(gmem_ptr)
and then casting inside the inline assembly.