-
Notifications
You must be signed in to change notification settings - Fork 7
Add turing mma support and test #1643
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
a449f49
edd43d9
ddac459
2f08d09
ca77ff4
9caeb18
de1d3ec
b393cbe
74f8c12
db34181
4dec827
5ecd102
adb19b2
c97c605
264bc77
5bc918c
2f0ae5e
b96462a
5347fda
3f580d1
4834995
9e18e04
1fb32ed
43d0e72
c1f4374
7c91a0a
ae2bd1f
f4c6f12
d719bd8
a2c28c9
d28574c
2fa3a92
e0d1769
d2274cd
090752d
640d3aa
da40702
36786b9
2edd1cc
bcb18af
fd5e178
86f2d61
484efd8
fbaeaf7
5355e07
9e298b0
9615196
a787b59
7cf98ab
e2e0707
9c224dc
584567c
09f5ce0
998adb5
13266ce
5dc3d4b
b460fb1
fc35480
76563a5
119f0eb
abcd1e8
ebc215a
5e8128f
80e840c
5458e4b
b05ae79
9684acb
98f45dd
f0eb7b6
0e204aa
e92430e
37ce5aa
3d8cb3a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,6 +21,40 @@ DEVICE_INLINE unsigned toSmem(const void* raw_ptr) { | |
return smem_ptr_uint; | ||
} | ||
|
||
// LdMatrix has .x1, .x2 and .x4 options, currently we actively use .x2 and | ||
// .x4. In .x2 option. the the address register of upper half warp (lane 16-31) | ||
// are un-used but on Turing [sm75,sm80) architecture these un-used addresses | ||
// need to be valid, in the sense that: | ||
// 1. The data it points to has to be within allocated shared mem buffer. | ||
// 2. The address needs to be aligned to 16 byte. | ||
// See also: | ||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix | ||
// This function addresses 2. above by masking out the sub-16B component | ||
// of the address in upper warp and 1. is guaranteed by ldmatrix swizzle | ||
// util. | ||
// This will **not** affect any functionality. This is just modification | ||
// of unused pointers to satisfy the alignment requirement on Turing | ||
// hardware. | ||
// The alignment requirement is lifted on sm80+, | ||
// so this function is a no-op on Ampere or above. | ||
DEVICE_INLINE void adjustPartialLdMatrixAddrInTuring(unsigned& addr_in_byte) { | ||
#if (__CUDA_ARCH__ < 800) | ||
const unsigned thread_id = threadIdx.x; | ||
// Upper half warp has 8 bytes offset from aligned in .x2 option | ||
// of ldmatrix. Currently no support for .x1 so assume always | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Any specific reason not to add There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I could add .x1 in a follow up. I currently didn't yet add the smaller mma tiles that .x1 would pair with, they are not immediately useful yet in large CTA tile kernels. |
||
// adjust by half warp. | ||
constexpr unsigned half_warp = 16; | ||
// Need to adjust to 16 byte alignment, mask out un-aligned component. | ||
constexpr unsigned mask_out = 16 - 1; | ||
// Adjust only in upper half warp. | ||
// use bit math to reduce strength | ||
if (thread_id & half_warp) { | ||
// mask out the bits where adjust_mask has 1. | ||
addr_in_byte &= (~mask_out); | ||
} | ||
#endif //(__CUDA_ARCH__ < 800) | ||
} | ||
|
||
} // namespace util | ||
|
||
// Load Matrix (per warp instruction) is to take data from SMEM to Local Memory. | ||
|
@@ -36,6 +70,7 @@ DEVICE_INLINE unsigned toSmem(const void* raw_ptr) { | |
DEVICE_INLINE void ldMatrix(Array<__half, 4, 4>& out, void const* ptr) { | ||
uint2& val = reinterpret_cast<uint2&>(out); | ||
unsigned addr = util::toSmem(ptr); | ||
util::adjustPartialLdMatrixAddrInTuring(addr); | ||
asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0,%1}, [%2];" | ||
: "=r"(val.x), "=r"(val.y) | ||
: "r"(addr)); | ||
|
@@ -47,6 +82,7 @@ DEVICE_INLINE void ldMatrix(Array<__half, 4, 4>& out, void const* ptr) { | |
DEVICE_INLINE void ldMatrixT(Array<__half, 4, 4>& out, void const* ptr) { | ||
uint2& val = reinterpret_cast<uint2&>(out); | ||
unsigned addr = util::toSmem(ptr); | ||
util::adjustPartialLdMatrixAddrInTuring(addr); | ||
asm volatile("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0,%1}, [%2];" | ||
: "=r"(val.x), "=r"(val.y) | ||
: "r"(addr)); | ||
|
Uh oh!
There was an error while loading. Please reload this page.