Minimal set of MMA resource string and mma swizzle #1252
Conversation
test/cpp/jit/test_gpu.cpp
Outdated
testValidate(&fusion, outputs, aten_inputs, {ref}, __LINE__, __FILE__);
}

TEST(NVFuserTest, FusionMMASwizzlePrimitiveFloatAcc_CUDA) {
Maybe we want to put this in a new test file (e.g., "test_gpu_mma.cpp")?
Yes. I will also need to disable this test for <Volta archs.
test/cpp/jit/test_gpu.cpp
Outdated
// reduction tolerance. half reduction to float
// doesn't reach full float tolerance but is better
// than accumulation into half.
TORCH_CHECK(refC.allclose(outC0, 0.001, 0.0001));
I tried the reference version in double precision, and the check passed with the default tolerance.
auto refC = inA.to(at::kDouble).matmul(inB.to(at::kDouble)).to(at::kFloat);
Thanks for pointing this out. Sorry I didn't mention this in the comment. This mini gemm case does pass with the float/double version of cublas as reference. I could use testValidate here if we want, but it would fail again with larger k (on the gemm integration PR). It didn't seem to consistently meet our precision targets, so I just used manually set tolerances everywhere.
I will change the reference to use the double version. Wouldn't want cublas to accumulate in half.
Generally our tolerances are based on comparisons with a double precision counterpart. Comparing float to float requires much larger tolerances because of the high variance float results can have, so leaving the comparison in double allows us to tighten our validation tolerances, which is quite helpful.
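A minimal sketch of what that validation could look like, reusing the inA/inB/outC0 names from the test above (an illustration, not necessarily the final code): the reference stays in double and the float output is promoted for the comparison, which lets the default allclose tolerances hold.
// Sketch: compute the reference in double so the comparison tolerance can stay tight.
auto refC = inA.to(at::kDouble).matmul(inB.to(at::kDouble));
// Promote the fuser output instead of demoting the reference, then use default tolerances.
TORCH_CHECK(refC.allclose(outC0.to(at::kDouble)));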
test/cpp/jit/test_gpu.cpp
Outdated
// Allocate Smem
__shared__ uint4 As_mem[32 * 8 / 8];
__shared__ uint4 Bs_mem[8 * 32 / 8];
auto As = reinterpret_cast<__half*>(As_mem);
Just curious here: does it make any difference if it's cast to a __shared__ half*? I remember reading somewhere that nvcc would use a generic load PTX instruction, as opposed to a shmem load instruction, if a pointer is not annotated as shared. This is just a test function, so it doesn't matter at all, but I'm curious if we should generally be aware of pointer types.
Thanks for pointing this out.
It looks like nvcc does some inference in this case, but there are no public docs on how far this state-space inference can go.
st.shared.u32 [%rd66], %r46;
st.shared.u32 [%rd66+4], %r47;
st.shared.u32 [%rd66+8], %r48;
st.shared.u32 [%rd66+12], %r49;
...
ld.shared.v4.u32 {%r66, %r67, %r68, %r69}, [%rd75];
I guess it wouldn't hurt to add them. We could consider cleaning these up in follow-ups, since we do something similar for other shared mem allocs too: https://github.com/csarofeen/pytorch/blob/devel/torch/csrc/jit/codegen/cuda/codegen.cpp#L1276-L1279
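As an illustration of the kind of cleanup being discussed (a sketch under the assumption that 16-byte alignment is preserved; not what this PR does verbatim), the buffers could be declared in their element type so every access goes through a visibly __shared__ variable:
// Sketch: declare the smem buffers as __half directly; alignas(16) keeps the
// 128-bit vectorized accesses legal, and the shared address space stays
// visible to the compiler at every use.
__shared__ alignas(16) __half As[32 * 8];
__shared__ alignas(16) __half Bs[8 * 32];

// 128-bit store into shared memory (8 halves per uint4), as in the test's prolog.
*reinterpret_cast<uint4*>(&As[threadIdx.x * 8]) =
    *reinterpret_cast<uint4*>(&Ag[threadIdx.x * 8]);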
test/cpp/jit/test_gpu.cpp
Outdated
// Mini prolog: (single warp)
// global load:
__half Ag_buffer[8];
*reinterpret_cast<uint4*>(Ag_buffer) = *reinterpret_cast<uint4*>(&Ag[threadIdx.x * 8]);
Wouldn't it be simpler to use uint4 rather than __half?
Yes I can change that. Thanks for pointing this out 👍
@@ -0,0 +1,410 @@
For all uses of nvfuser_index_t, do they need to be that type? Can any of them be safely replaced with just int? For example, indexing inside a 32x8 tile should be safe with int.
Yes, I will replace them all with int once this swizzling can be done outside of the main loop.
// (cf. ISA 9.7.13.4)
// swizzling at write store so load can be less integer heavy.
// switch bit 0 and bit 1 of row_group_idx
const nvfuser_index_t row_group_swizzled = (row_group_idx & ~0b11) |
Is nvfuser_index_t assumed to be int?
This just looks like magic to me. Is this technique also used in CUTLASS?
This is a macro that could be either int64_t or int.
Yes, I think similar computations are around in other libs:
https://github.com/NVIDIA/cutlass/blob/master/include/cutlass/layout/tensor_op_multiplicand_sm70.h#L811-L813
We are not doing exactly the same thing, though; I just chose a formulation that maps more conveniently onto the fuser pipeline.
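For reference, the "switch bit 0 and bit 1 of row_group_idx" swizzle quoted above can be spelled out as something like the following sketch (written with plain int, per the earlier discussion; this is an illustration of the described bit swap, not the PR's exact code):
// Sketch: swap bit 0 and bit 1 of row_group_idx, keeping all higher bits.
const int row_group_swizzled = (row_group_idx & ~0b11)  // bits above bit 1
    | ((row_group_idx & 0b01) << 1)                     // bit 0 moves up
    | ((row_group_idx & 0b10) >> 1);                    // bit 1 moves down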
Generally this looks good to me, though I think we'll want to do a few iterations on the resource strings as we make progress.
test/cpp/jit/test_gpu.cpp
Outdated
// The coordinate is what the Ag_buffer data
// corresponds to in un-swizzled layout
// mostly compatible with nvfuser indexing
{(int)threadIdx.x, 0}, lane_id).linearize({8,1})])
Do we need lane_id everywhere? Is this just so the indexing math is simpler in the swizzle? It seems we could just always pass in threadIdx.x instead?
I think we do need lane_id in multiple places, currently in operand A/B read/write and the output un-swizzle, so that'd be 5 uses, and I preferred to define it just once. But using tidx also works for me if that looks cleaner.
I was thinking we'd go across warps with tidx.y or tidx.z, which would make it strictly tidx.x, but if not, then this would be fine.
I'm assuming blockDim.x is a multiple of 32 at the moment, and we could try removing that constraint if it's beneficial. It'd require some generalization effort.
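Under that assumption, the single shared definition could be as simple as this sketch (warp_id is included only for illustration):
// Sketch: with blockDim.x a multiple of 32, lane and warp ids come straight
// from threadIdx.x.
const int lane_id = threadIdx.x % 32;  // position within the warp
const int warp_id = threadIdx.x / 32;  // which warp within the block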
test/cpp/jit/test_gpu.cpp
Outdated
mem_swizzle::swizzle_sequence<
// This is the instruction mem layout
// for mma32X32X8
warp_mma::mmaM32N32K8::SmemWriteASwizzle>(
Are there 4 independent swizzles for reg->smem looking across [N,T] for A and [N,T] for B? I thought there were just 2, i.e. An == Bt.
Yes, there will be just two. I haven't built out all the transpositions yet; they'd need a few other mma instructions and swizzle options in here. Currently I'm only focusing on getting the TT gemm fast. I will need to rename the swizzles a bit to keep the categories organized once we have more of them.
test/cpp/jit/test_gpu.cpp
Outdated
*reinterpret_cast<uint4*>(
&As[
// Utility to compose a sequence of swizzles
mem_swizzle::swizzle_sequence<
What do you mean by a swizzle_sequence?
This would be useful when we come to the bank-conflict removal step. My current understanding is that we have a swizzle for mma and a swizzle for bank conflicts (which will come in the next few PRs). The former is a function of the instruction we are using, and the latter is a function of CTA tile size and vectorization. So swizzle_sequence is an interface for the fuser to configure them separately; the kernel applies them sequentially.
So it's just a mechanism to go from one swizzle format to another?
That could be one way to use this utility as well.
The scenario I was thinking of is that the swizzled format in smem in gemm is always a composition of two, because we want to satisfy the mma requirement and avoid bank conflicts at the same time.
Say we have a smem dimension of 128x32 fp16, with 128b vectorization, storing mma operand A; then we need a swizzle that satisfies the mma requirement and avoids bank conflicts when we load a column.
This could be done with a cyclic shift of 128b blocks by 8 every 2 rows (in a separate PR), and then the mma operand A swizzle.
So swizzle_sequence is a utility to configure these two separately and generate the composed format.
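A minimal sketch of such a composition utility (the struct and function names here are illustrative stand-ins, not the PR's actual interface): each swizzle maps an un-swizzled 2D coordinate to a swizzled one, and the sequence applies them in order.
// Hypothetical sketch of composing swizzles; Coord and the apply() interface
// are assumptions for illustration.
struct Coord {
  int row;
  int col;
};

// Base case: a single swizzle.
template <typename LastSwizzle>
__device__ inline Coord apply_swizzle_sequence(Coord c) {
  return LastSwizzle::apply(c);
}

// Recursive case: apply the first swizzle, then the rest, in order.
template <typename FirstSwizzle, typename SecondSwizzle, typename... RestSwizzles>
__device__ inline Coord apply_swizzle_sequence(Coord c) {
  return apply_swizzle_sequence<SecondSwizzle, RestSwizzles...>(FirstSwizzle::apply(c));
}

// Usage: bank-conflict swizzle first, then the mma operand A swizzle.
// Coord swizzled = apply_swizzle_sequence<BankConflictSwizzle, SmemWriteASwizzle>(coord);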
test/cpp/jit/test_gpu.cpp
Outdated
warp_mma::mmaM32N32K8::SmemWriteBSwizzle>(
// The coordinate is what the Bg_buffer data
// corresponds to in un-swizzled layout
// mostly compatible with nvfuser indexing
What is the native nvFuser mapping this corresponds to?
There's a warp for a 32x32x8 tile. Is there a specific way this is "viewed" by nvFuser? Or is this really focusing on a warp programming methodology?
I think the way we'd really want to think about this is: if we didn't have mma, how would we think about mapping threadIdx.x to the problem?
We don't have to solve this now, but I would think we'd want mma to be a "post process" on a non-mma (more naive) matmul setup. I would think the mma part of a schedule would be best as a very late transformation.
Yes, I was thinking of this as pattern matching and as a post process. The current setup I have on #1274 essentially splits out the inner loops and saves them for mma, i.e. [ ... , M_mma(32), N_mma(32), K_mma(8)], and this is the ground-truth fuser mapping for pattern matching.
To account for the "warp instruction" nature of mma: the 32x32 accumulator is actually distributed across a warp, so I just chose the m dimension to bind to the warp, i.e. [ ... , M_mma.warp(32), N_mma(32), K_mma(8)], but warp isn't modeled explicitly; it is just later merged into tidx.
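As a plain-C++ illustration of that ground-truth mapping (not scheduler code), the three innermost loops reserved for mma form a 32x32x8 tile, with M_mma being the dimension that ends up bound to the warp:
// Reference loop nest for the [ ..., M_mma(32), N_mma(32), K_mma(8) ] mapping;
// in the real kernel M_mma is distributed across the warp (merged into TIDx)
// and the body becomes the mma instruction.
void mma_tile_reference(const float A[32][8], const float B[8][32], float C[32][32]) {
  for (int m_mma = 0; m_mma < 32; ++m_mma) {     // M_mma(32), warp-bound
    for (int n_mma = 0; n_mma < 32; ++n_mma) {   // N_mma(32)
      for (int k_mma = 0; k_mma < 8; ++k_mma) {  // K_mma(8), reduction
        C[m_mma][n_mma] += A[m_mma][k_mma] * B[k_mma][n_mma];
      }
    }
  }
}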
test/cpp/jit/test_gpu.cpp
Outdated
.linearize({32,1})])
= *reinterpret_cast<uint4*>(Bg_buffer);
// Read from Smem:
__half A[16];
Nit: Ar for consistency.
// 8x8x4 mma instruction, per quarter warp (8 threads), fp32 accumulate
// per thread register:
// A[4] x B[4] -> C[8]
__device__ inline void mmaM8n8k4(float* C, __half* A, __half* B) {
Is there something we can pass in here that's encapsulating? I don't like the idea of passing raw pointers; I'd probably even prefer to pass in uint4s or uint8s, but something like an array class would be best if possible (I thought there was something like this we could use).
^^ Same for the rest of the functions in this namespace.
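One possible shape for that, as a sketch (the fuser runtime may already have an equivalent array type; the names here are illustrative): a small fixed-size array struct so the operand counts are carried in the type rather than hidden behind a raw pointer.
// Hypothetical fixed-size array wrapper for register fragments.
template <typename T, int N>
struct Array {
  T data[N];
  __device__ T& operator[](int i) { return data[i]; }
  __device__ const T& operator[](int i) const { return data[i]; }
};

// The fp32-accumulate 8x8x4 helper could then take typed, sized operands:
// __device__ inline void mmaM8n8k4(Array<float, 8>& C,
//                                  const Array<__half, 4>& A,
//                                  const Array<__half, 4>& B);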
// 8x8x4 mma instruction, per quarter warp (8 threads), fp16 accumulate
// per thread register:
// A[4] x B[4] -> C[8]
__device__ inline void mmaM8n8k4(__half* C, __half* A, __half* B) {
Let's not support fp16 accumulate at this point. It can be really dangerous in training.
// Helper function to index shared mem;
// if we decide to wrap smem pointers in a class, this should probably move
// there.
__device__ inline nvfuser_index_t linearize(const MatrixCoordinate& stride) {
It would be good to figure out an interface where we can return the 2D swizzled indices and let the codegen handle the linearization. This is useful if the swizzle coordinates are not the innermost-contiguous coordinates.
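A sketch of what that interface could look like (names are illustrative, not the PR's): the swizzle returns a 2D coordinate, and the generated code performs the linearization with whatever strides apply, so non-innermost-contiguous layouts keep working.
// Hypothetical interface: the swizzle only rearranges the 2D coordinate.
struct Coord2D {
  int row;
  int col;
};

// Codegen would then linearize with the actual (possibly non-contiguous) strides:
__device__ inline int linearize_with_strides(Coord2D c, int row_stride, int col_stride) {
  return c.row * row_stride + c.col * col_stride;
}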
}
};

class SmemWriteBSwizzle {
Is the plan for all the swizzle logic to be extracted out of the main loop? If that's the case, is there any harm in doing these in a less "magic", bit-wise-heavy way and using natural int patterns instead?
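For comparison, a sketch of the row-group swizzle discussed earlier written with ordinary integer arithmetic (an illustration; the compiler typically lowers this to the same masks and shifts):
// Sketch: "swap bit 0 and bit 1" expressed with plain div/mod arithmetic.
__device__ inline int swizzle_row_group(int row_group_idx) {
  const int base = (row_group_idx / 4) * 4;       // everything above the two low bits
  const int low = row_group_idx % 4;              // the two low bits
  const int swapped = (low % 2) * 2 + (low / 2);  // swap them
  return base + swapped;
}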
close in favor of #1439
This PR is about CUDA runtime strings only. It introduces: