Skip to content

Conversation

@shmsong shmsong commented Nov 4, 2021

This PR touches the CUDA runtime strings only. It introduces:

  • An MMA intrinsic wrapper for sm70
  • An MMA macro definition using the intrinsics (please refer to the design doc for the macro mapping)
  • MMA read and write swizzling for operands and results
  • CUDA swizzling infrastructure for composite swizzle functions
  • A mini-gemm test case demonstrating how these resources can be used

@shmsong shmsong changed the title [WIP] Minimal set of MMA resource string and mma swizzle Minimal set of MMA resource string and mma swizzle Nov 19, 2021
@shmsong shmsong requested review from csarofeen and naoyam November 19, 2021 05:43
testValidate(&fusion, outputs, aten_inputs, {ref}, __LINE__, __FILE__);
}

TEST(NVFuserTest, FusionMMASwizzlePrimitiveFloatAcc_CUDA) {
Collaborator

Maybe we want to put this in a new test file (e.g., "test_gpu_mma.cpp").

Author

Yes. I will also need to disable this test for pre-Volta archs.

// Relaxed reduction tolerance: half reduction into float
// doesn't reach full float tolerance, but is better
// than accumulation into half.
TORCH_CHECK(refC.allclose(outC0, 0.001, 0.0001));
Collaborator

I tried the reference version in double precision, and the check passed with the default tolerance.

 auto refC = inA.to(at::kDouble).matmul(inB.to(at::kDouble)).to(at::kFloat);

Author
@shmsong shmsong Nov 19, 2021

Thanks for pointing this out. Sorry I didn't mention this in the comment. This mini-gemm case does pass with the float/double version of cuBLAS as the reference, and I could use testValidate here if we want. But it would fail again with larger K (on the gemm integration PR). It didn't seem to consistently meet our precision targets, so I just used manually set tolerances for everything.

Author
@shmsong shmsong Nov 19, 2021

I will change the reference to use the double version. Wouldn't want cublas to accumulate in half.

Owner

Generally our tolerances are based on comparisons with a double-precision counterpart. Comparing float to float requires much larger tolerances because of the high variance float results can have, so leaving the comparison in double allows us to tighten our validation tolerances, which is quite helpful.
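
For illustration, the pattern being suggested looks roughly like this (a sketch only; it reuses inA, inB, and outC0 from the test above, and the tolerances are placeholders that could likely be tightened):

// Build the reference in double so float rounding noise doesn't force loose tolerances.
auto refC64 = inA.to(at::kDouble).matmul(inB.to(at::kDouble));
// Compare the kernel output against the double-precision reference.
TORCH_CHECK(refC64.allclose(outC0.to(at::kDouble), /*rtol=*/1e-3, /*atol=*/1e-4));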

// Allocate Smem
__shared__ uint4 As_mem[32*8 / 8];
__shared__ uint4 Bs_mem[8*32 / 8];
auto As = reinterpret_cast<__half*>(As_mem);
Collaborator

Just curious here: does it make any difference if it's cast to __shared__ half*? I remember reading somewhere that nvcc would use a generic load PTX instruction, as opposed to a shmem load instruction, if a pointer is not annotated as shared. This is just a test function, so it doesn't matter at all, but I'm just curious whether we should generally be aware of pointer types.

Author

Thanks for pointing this out.

It looks like nvcc does some inference in this case, but there are no public docs on how far this state space inference can go.

st.shared.u32    [%rd66], %r46;
st.shared.u32    [%rd66+4], %r47;
st.shared.u32    [%rd66+8], %r48;
st.shared.u32    [%rd66+12], %r49;
...
ld.shared.v4.u32 {%r66, %r67, %r68, %r69}, [%rd75];

I guess it wouldn't hurt to add them. We could consider cleaning these up in follow-ups, since we do something similar for other shared mem allocs too: https://github.com/csarofeen/pytorch/blob/devel/torch/csrc/jit/codegen/cuda/codegen.cpp#L1276-L1279

// Mini prolog: (single warp)
// global load:
__half Ag_buffer[8];
*reinterpret_cast<uint4*>(Ag_buffer) = *reinterpret_cast<uint4*>(&Ag[threadIdx.x*8]);
Collaborator

Wouldn't it be simpler to use uint4 rather than __half?

Author

Yes I can change that. Thanks for pointing this out 👍
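
For reference, the suggested simplification would look roughly like this (a sketch based on the snippet above; Ag and the 8-half-per-thread stride come from that snippet):

// Stage the 128-bit (8 x __half) global load directly in a uint4 register.
uint4 Ag_buffer = *reinterpret_cast<const uint4*>(&Ag[threadIdx.x * 8]);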

@@ -0,0 +1,410 @@

Collaborator

For all uses of nvfuser_index_t, do they need to be that type? Can any of them be safely replaced with just int? For example, indexing inside a 32x8 tile should be safe with int.

Author

Yes, I will replace them all with int once we get to the point where this swizzling can be done outside of the main loop.

// (cf. ISA 9.7.13.4)
// Swizzle at the write/store so the load can be less integer heavy.
// Swap bit 0 and bit 1 of row_group_idx:
const nvfuser_index_t row_group_swizzled = (row_group_idx & ~0b11) |
Collaborator

Is nvfuser_index_t assumed to be int?

Collaborator

This just looks like magic to me. Is this technique also used in CUTLASS?

Author

This is a macro that could be either int64_t or int.

Author

Yes, I think similar computations show up in other libs:
https://github.com/NVIDIA/cutlass/blob/master/include/cutlass/layout/tensor_op_multiplicand_sm70.h#L811-L813
We are not exactly the same though; I just chose a formulation that maps more conveniently to the fuser pipeline.
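
To make the "magic" concrete, the trick in the snippet above swaps the two low bits of the index. Spelled out as a standalone sketch (this may not match the PR's exact expression):

// Swap bit 0 and bit 1 of row_group_idx, leaving the higher bits untouched.
const nvfuser_index_t row_group_swizzled =
    (row_group_idx & ~0b11)            // keep bits 2 and above
    | ((row_group_idx & 0b01) << 1)    // move bit 0 up to bit 1
    | ((row_group_idx & 0b10) >> 1);   // move bit 1 down to bit 0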

Owner
@csarofeen csarofeen left a comment

Generally this looks good to me, though I think we'll want to do a few iterations on the resource strings as we make progress.

// The coordinate is what the Ag_buffer data
// corresponds to in un-swizzled layout
// mostly compatible with nvfuser indexing
{(int)threadIdx.x, 0}, lane_id).linearize({8,1})])
Owner
@csarofeen csarofeen Nov 22, 2021

Do we need lane_id everywhere? Is this just so the indexing math is simpler in the swizzle? It seems we could just always pass in threadIdx.x instead?

Author

I think we do need lane_id in multiple places: currently operand A/B read/write and the output un-swizzle, which would be 5 uses, so I preferred to define it just once. But using tidx also works for me if that looks cleaner.

Owner

I was thinking we'd go across warps with tidx.y or tidx.z, which would make this a strict tidx.x, but if not, then this would be fine.

Author

I'm assuming blockDim.x is a multiple of 32 at the moment; we could try removing that constraint if it is beneficial, but it'd require some generalization effort.
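
For reference, the warp decomposition being assumed is just the following (a sketch with illustrative names):

// Assumes blockDim.x is a multiple of 32.
const unsigned lane_id = threadIdx.x & 31;  // lane within the warp
const unsigned warp_id = threadIdx.x >> 5;  // warp index along threadIdx.x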

mem_swizzle::swizzle_sequence<
// This is the instruction mem layout
// for mma32X32X8
warp_mma::mmaM32N32K8::SmemWriteASwizzle>(
Owner

Are there 4 independent swizzles for reg->smem looking across [N,T], [N,T]? I thought there were just 2 i.e. An == Bt.

Author

Yes, there will be just two. I haven't built out all the transpositions yet; they'd need a few other mma instructions and swizzle options in here. Currently I'm only focusing on getting the TT gemm fast. I will need to rename the swizzles a bit to keep the categories organized once we have more of them.

*reinterpret_cast<uint4*>(
&As[
// Utility to compose a sequence of swizzles
mem_swizzle::swizzle_sequence<
Owner

What do you mean by a swizzle_sequence?

Author

This would be useful when we come to the bank conflict removal step. My current understanding is that we have a swizzle for mma and a swizzle for bank conflicts (the latter will be in the next few PRs). The former is a function of the instruction we are using, and the latter is a function of the CTA tile size and vectorization. So swizzle_sequence is an interface for the fuser to configure them separately; the kernel then applies them sequentially.

Owner

So it's just a mechanism to go from one swizzle format to another?

Author

That could be one way to use this utility as well.

The scenario I was thinking of is that the swizzled format in smem for gemm is always a composition of two, because we want to satisfy the mma requirement and avoid bank conflicts at the same time.

If, say, we have an smem dimension of 128x32 fp16 with 128b vectorization, storing mma operand A, then we need a swizzle that satisfies the mma requirement and avoids bank conflicts when we load a column.

This could be done with a cyclic shift of 128b blocks by 8 every 2 rows (in a separate PR), followed by the mma operand A swizzle.

So swizzle_sequence is used as a utility to configure these two separately and generate the composed format.
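
For illustration, the composition utility could be sketched as below (assumed names and interface, not necessarily the PR's swizzle_sequence):

// Compose two swizzles: apply the bank-conflict swizzle first, then the
// mma-operand swizzle, so the two can be configured independently.
template <typename BankConflictSwizzle, typename MmaOperandSwizzle>
struct SwizzleSequence {
  template <typename Coord>
  __device__ static Coord apply(Coord coord) {
    return MmaOperandSwizzle::apply(BankConflictSwizzle::apply(coord));
  }
};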

warp_mma::mmaM32N32K8::SmemWriteBSwizzle>(
// The coordinate is what the Bg_buffer data
// corresponds to in un-swizzled layout
// mostly compatible with nvfuser indexing
Owner

What is the native nvFuser mapping this corresponds to?

Owner

There's a warp for a 32x32x8. Is there a specific way this is "viewed" by nvFuser? Or is this really focusing on a warp programming methodology?

I think maybe the way we'd really want to think about this is: if we didn't have mma, how would we think about mapping threadIdx.x to the problem?

Owner

We don't have to solve this now, but I would think we'd want mma to be a "post process" on a non-mma (more naive) matmul setup. I would think the mma part of a schedule would be best as a very late transformation.

Author

Yes, I was thinking of this as pattern matching and as a post process. The current setup I have on #1274 essentially splits out the inner loops and saves them for mma, i.e. [ ..., M_mma(32), N_mma(32), K_mma(8)], and this is the ground-truth fuser mapping for pattern matching.

To account for the "warp instruction" nature of mma: the 32x32 accumulator is actually distributed across a warp, so I just chose the M dimension to bind to the warp, i.e.
[ ..., M_mma.warp(32), N_mma(32), K_mma(8)]. The warp wasn't modeled explicitly, though; it was just later merged into tidx.

.linearize({32,1})])
=*reinterpret_cast<uint4*>(Bg_buffer);
// Read from Smem:
__half A[16];
Owner

Nit: Ar for consistency.

// 8x8x4 mma instruction, per quarter warp (8 threads), fp32 accumulate
// per thread register:
// A[4] x B[4] -> C[8]
__device__ inline void mmaM8n8k4(float* C, __half* A, __half* B) {
Owner

Is there something we can pass in here that's encapsulating? I don't like the idea of passing raw pointers; I'd probably even prefer to pass in uint4's or uint8's, but something like an array class would be best if possible (I thought there was something like this we could use).

^^ same for the rest of the functions in this namespace.
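
Something along these lines is probably what's meant (a sketch; the Array type and names are assumptions rather than an existing fuser class):

// A tiny fixed-size array wrapper so operand sizes become part of the signature.
template <typename T, int N>
struct Array {
  T val[N];
  __device__ T& operator[](int i) { return val[i]; }
  __device__ const T& operator[](int i) const { return val[i]; }
};

// Existing pointer-based wrapper from this PR:
__device__ inline void mmaM8n8k4(float* C, __half* A, __half* B);

// Sized-operand overload that simply forwards to it:
__device__ inline void mmaM8n8k4(
    Array<float, 8>& C, Array<__half, 4>& A, Array<__half, 4>& B) {
  mmaM8n8k4(C.val, A.val, B.val);
}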

// 8x8x4 mma instruction, per quarter warp (8 threads), fp16 accumulate
// per thread register:
// A[4] x B[4] -> C[8]
__device__ inline void mmaM8n8k4(__half* C, __half* A, __half* B) {
Owner

Let's not support fp16 accumulate at this point. It can be really dangerous in training.

// Helper function to index shared mem.
// If we decide to wrap smem pointers in a class, this should probably
// move there.
__device__ inline nvfuser_index_t linearize(const MatrixCoordinate& stride) {
Owner

It would be good to figure out an interface where we can return the 2D swizzled indices and let the codegen handle the linearization. This is useful if the swizzle coordinates are not the innermost-contiguous coordinates.
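
A sketch of the interface shape being suggested (names and the example swizzle are illustrative; plain int stands in for nvfuser_index_t to keep the sketch self-contained):

struct Coord2D {
  int row;
  int col;
};

// The swizzle returns a 2D coordinate; linearization is left to the codegen,
// which can apply whatever strides the allocation actually has, even when the
// swizzled axes are not the innermost-contiguous ones.
__device__ inline Coord2D exampleSwizzle(Coord2D in) {
  // Illustrative bank-conflict-style swizzle: XOR the column by the low row bits.
  return {in.row, in.col ^ (in.row & 0b111)};
}

__device__ inline int linearize(Coord2D coord, Coord2D stride) {
  return coord.row * stride.row + coord.col * stride.col;
}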

}
};

class SmemWriteBSwizzle {
Owner

Is the plan for all the swizzle logic to be extracted out of the main loop? If that's the case, is there any harm in doing these in a less "magic", bit-wise-heavy way and using natural int patterns instead?

Author
@shmsong shmsong commented Feb 8, 2022

close in favor of #1439

@shmsong shmsong closed this Feb 8, 2022