Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions src/tl_templates/cuda/copy_sm90.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,28 @@

namespace tl {

/// Bulk asynchronous TMA copy of `size` bytes from global memory into
/// shared memory, signaling completion through an mbarrier.
///
/// @param smem_ptr  Destination address in shared memory.
/// @param gmem_ptr  Source address in global memory.
/// @param smem_mbar mbarrier object tracking the transfer; the transaction
///                  byte count is posted to it on completion.
/// @param size      Number of bytes to copy.
TL_DEVICE void tma_load(void *smem_ptr, void *gmem_ptr, uint64_t &smem_mbar,
                        uint32_t size) {
  // smem_ptr_to_uint yields the 32-bit shared-space address PTX expects
  // for the [%0]/[%3] operands; the generic 64-bit pointer cannot be
  // passed to the "r" (32-bit register) constraint directly.
  uint32_t mbar_addr = smem_ptr_to_uint(&smem_mbar);
  uint32_t dst_addr = smem_ptr_to_uint(smem_ptr);
  asm volatile("cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::"
               "bytes [%0], [%1], %2, [%3]; \n"
               :
               : "r"(dst_addr), "l"(gmem_ptr), "r"(size), "r"(mbar_addr));
}

/// Bulk asynchronous TMA copy of `size` bytes from global memory into the
/// shared memory of every CTA selected by `mask`, signaling completion
/// through an mbarrier.
///
/// @param smem_ptr  Destination address in shared memory.
/// @param gmem_ptr  Source address in global memory.
/// @param smem_mbar mbarrier object tracking the transfer.
/// @param size      Number of bytes to copy.
/// @param mask      Multicast CTA mask (one bit per CTA in the cluster —
///                  bound via the 16-bit "h" constraint).
TL_DEVICE void tma_load_multicast(void *smem_ptr, void *gmem_ptr,
                                  uint64_t &smem_mbar, uint32_t size,
                                  uint16_t mask) {
  // Convert generic pointers to the 32-bit shared-space addresses that the
  // PTX [%0]/[%3] operands require.
  uint32_t mbar_addr = smem_ptr_to_uint(&smem_mbar);
  uint32_t dst_addr = smem_ptr_to_uint(smem_ptr);
  asm volatile(
      "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes."
      "multicast::cluster [%0], [%1], %2, [%3], %4; \n"
      :
      : "r"(dst_addr), "l"(gmem_ptr), "r"(size), "r"(mbar_addr), "h"(mask));
}

TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
void const *const smem_ptr, int32_t const &crd0) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
Expand Down Expand Up @@ -105,6 +127,15 @@ TL_DEVICE void tma_load_im2col(const CUtensorMap &descriptor,
: "memory");
}

/// Bulk asynchronous TMA copy of `size` bytes from shared memory out to
/// global memory, committed as part of the current bulk-group.
///
/// @param dst_gmem_ptr Destination address in global memory.
/// @param smem_ptr     Source address in shared memory.
/// @param size         Number of bytes to copy.
TL_DEVICE void tma_store(void *dst_gmem_ptr, void *smem_ptr, uint32_t size) {
  // The shared-memory source must be a 32-bit shared-space address for the
  // "r"-constrained [%0] operand; the global pointer stays 64-bit ("l").
  uint32_t src_addr = smem_ptr_to_uint(smem_ptr);
  asm volatile("cp.async.bulk.global.shared::cta.bulk_group [%1], [%0], %2; \n"
               :
               : "r"(src_addr), "l"(dst_gmem_ptr), "r"(size));
}

TL_DEVICE void tma_store(const CUtensorMap &descriptor,
void const *const smem_ptr, int32_t const &crd0) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
Expand Down
Loading