Conversation


@shmsong shmsong commented Feb 8, 2022

This PR is a minimalist approach to integrating Volta mma in NVFuser. The goal is to natively support mma thread mapping through iter domain ops, while relying on the existing iterdomain math as much as possible.

It turns out that all thread swizzles involved in Volta through Ampere are affine/inverse-affine maps, meaning that they can be expressed as sequences of splits and merges. This fact can be used to reduce the amount of generic swizzle composition needed in a gemm kernel, and it also facilitates generic iterdomain checking and optimization on mma thread swizzles.

This PR focuses on the affine part of the swizzle mapping involved in mma; non-affine components will be introduced in a follow-up.
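
As a purely illustrative example of what "expressible as splits and merges" means (the factor and resulting layout below are made up and are not the actual Volta swizzle; tv stands for a 2-D TensorView whose inner warp tile is [16, 4]), an affine lane mapping can be written entirely with the existing scheduling primitives:

// Illustration only: map a hypothetical 16x4 warp tile onto 32 lanes using
// nothing but split/merge, i.e. an affine iterdomain transformation.
tv->split(0, 8);                               // [16, 4]   -> [2, 8, 4]
tv->merge(1);                                  // [2, 8, 4] -> [2, 32]
tv->axis(1)->parallelize(ParallelType::TIDx);  // the 32-wide domain becomes the lane id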

In this PR:

  • A definition of MmaOp, which is a binary tensor op with a reduction domain on its output.
  • An inventory infra and user interface (MmaBuilder) that provides bookkeeping of mma type and data layout.
  • A scheduling interface (mma_util::WarpMmaScheduler) that transforms the innermost iterdomains of mma inputs and output to the correct thread swizzle format for the corresponding mma type. The input to the scheduler needs to follow the "m,n,k" convention, which is checked by WarpMmaScheduler before any transform (see the sketch after this list).
  • Validation passes for the mma op.
  • Resource strings and corresponding warp mappings for Volta mma's.
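
To give a rough picture of how these pieces fit together, here is a hedged usage sketch. Only MmaOp, MmaBuilder, GemmTileOptions and mma_util::WarpMmaScheduler are names from this PR; the enum values, method names, tensor names and the scheduler entry point below are assumptions for illustration and may not match the exact API:

// Hypothetical usage sketch -- exact APIs in this PR may differ.
GemmTileOptions gemm_tile;                                // tile-size bookkeeping for the gemm
MmaBuilder builder(MmaOptions::MacroType::Volta_16_16_4,  // macro type (name assumed)
                   gemm_tile);
builder.layout(MmaOptions::MmaInputLayout::TN);           // operand data layout (name assumed)

// tv2 = mma(tv0, tv1): a binary tensor op with a reduction (K) domain on its
// output, represented as an MmaOp node in the fusion IR.

// tv0/tv1/tv2 reach the scheduler with their innermost iterdomains in
// "m,n,k" order; WarpMmaScheduler verifies that convention and then applies
// the per-mma-type warp/lane swizzle using only splits and merges.
mma_util::WarpMmaScheduler::schedule(tv2, builder);       // entry point assumed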

@shmsong shmsong requested review from csarofeen and naoyam February 8, 2022 17:22
@shmsong shmsong changed the title Mma operator and volta mma integration [Do not merge] Mma operator and volta mma integration Feb 10, 2022

naoyam commented Feb 12, 2022

I checked pretty much all the changes. Really exciting to have the MMA capability in our system!

My biggest concern is the same as Christian's. User fusions, scheduling, and lowering could be more cleanly separated, and that would be really important for keeping the overall complexity of our system manageable.

Another thing I don't agree with in the design is IterDomain::is_instruciton_ and IterDomain::is_warp_mapped_, as well as the related added function parameters like is_instruciton_loop. Please see the specific inline comments.

Another thing that's not really fundamental but still confusing to me is the terminology of "mma" and "gemm". The mma expression is configured with MmaBuilder, which takes GemmTileOptions as a parameter. Can't we just call everything MMA?

I kind of see that "mma" could be used to indicate a unit of operations, whereas "gemm" would mean an overall matrix multiplication, which can use mma but is not required to do so. But in that sense, the mma expression should be renamed to gemm.

@shmsong shmsong changed the title [Do not merge] Mma operator and volta mma integration Mma operator and volta mma integration Feb 22, 2022
@shmsong shmsong changed the title Mma operator and volta mma integration WIP: Mma operator and volta mma integration Feb 22, 2022
@shmsong shmsong changed the title WIP: Mma operator and volta mma integration Mma operator and volta mma integration Feb 22, 2022
@shmsong shmsong requested review from csarofeen and naoyam February 22, 2022 09:13

@csarofeen csarofeen (Owner) left a comment

Expanded NIT: sometimes your comments seem to get broken up, where a line will split to a new line and get indented. I don't think this was done on purpose, but I found it in a few spots.

Looks good to me! Just minor comments.

tv2->axis(0)->parallelize(ParallelType::BIDx);
tv2->axis(1)->parallelize(ParallelType::BIDy);
tv2->axis(2)->parallelize(ParallelType::TIDz);
tv2->axis(3)->parallelize(ParallelType::TIDy);

csarofeen (Owner):

Can TransformPropagator or parallelizeAllLike reduce any of the above code?

shmsong (Author):

Yes. All schedules are unnecessarily manual at this stage. I will need to add a few more options to the transform propagator, especially an option to specify propagation boundaries, to automate this part. Will be addressed in a follow-up.
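
For context, the rough shape of that automation would be something like the sketch below, using the existing TransformPropagator and scheduler_utils::parallelizeAllLike (the exact call sites and the boundary option are hypothetical here; fusion is the test's Fusion object):

// Hedged sketch, not part of this PR: propagate tv2's transforms to the other
// tensors and copy its parallelization instead of scheduling each tensor by hand.
TransformPropagator::from(tv2);                                       // replay tv2's transforms
scheduler_utils::parallelizeAllLike(tv2, ir_utils::allTvs(&fusion));  // copy parallel types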

bool isReduction = false;
if (axis_iter != axes_set.end() && *axis_iter == dim) {
  isReduction = true;
  axis_iter++;

csarofeen (Owner):

why increment one by one instead of just checking if it's in the axes_set?

shmsong (Author):

This is duplicated code from newForReduction. axes_set is ordered, so I guessed that was why it was originally incrementing.

// will lift this in a follow up when we have a
// more generic axes matching.
TORCH_CHECK(
    axes.size() == 1, "Single axis reduction only for mma op instantiation.")

csarofeen (Owner):

Should contiguity of the input tensors be checked?

shmsong (Author):

I'm not sure we'd want to limit this to contiguous tensors, since we could hypothetically still schedule non-contiguous tensors into registers and run mma. I'm not yet sure whether we'd practically want to go for perf in that case.

(loop->iter_domain()->isThread() && is_local &&
 (same_parallel_type ||
  (within_mma_loops &&
   loop->iter_domain()->getParallelType() == ParallelType::TIDx)))) {

naoyam (Collaborator):

Why is this additional condition required?

Would it make more sense to extend find_matching_parallel_domain so that same_parallel_type would include the case with MMA?

shmsong (Author):

This is basically the swizzled version of same_parallel_type across the mma op on the lane_id dimension. The swizzled lane_id iterdomain is up to the mma instruction, and in almost all cases it is not naturally mapped from input to output. That's just the way they are defined, and I guess it also changes with each generation of hardware.

I have moved this part to same_parallel_type.

I will need to think about how to meaningfully extend find_matching_parallel_domain to cover this case. I could probably try to address this in a follow-up if that's ok.

naoyam (Collaborator):

Sounds good. Can you please add this as a comment in the code?

@@ -359,6 +360,14 @@ class CudaKernelGenerator : private OptOutConstDispatch {
bool is_vector_op = false;
size_t vector_word_size = 1;

if (uop->out()->isA<kir::TensorIndex>()) {

naoyam (Collaborator):

I believe this is an op to initialize the output tensor of an MmaOp.

The condition isn't very intuitive to me. It uses the fact that the definition of the tensor has the original MmaOp expression rather than its initialization, which is true but seems like an implementation detail.

Wouldn't it be more robust to look at the parallel types of the IterDomains? There should be an IterDomain with ParallelType::Mma, and isn't it possible to use that info to determine whether this is an initialization op of an Mma output tensor? I believe that's how we generate vectorized code for UnaryOps with vectorized IterDomains.
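
To sketch what I mean (just a sketch, not necessarily the exact predicate or placement; out_tv here is shorthand for the TensorView behind uop->out(), and <algorithm> is assumed for std::any_of):

// Sketch only: detect the init expression of an MmaOp output by its
// parallelization rather than by inspecting the tensor's definition.
const auto& out_ids = out_tv->domain()->domain();
bool is_mma_init = std::any_of(out_ids.begin(), out_ids.end(), [](IterDomain* id) {
  return id->getParallelType() == ParallelType::Mma;
});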

@shmsong shmsong requested review from csarofeen and naoyam March 17, 2022 06:54

@naoyam naoyam (Collaborator) left a comment

LGTM. Great work!

@shmsong shmsong merged commit 6df7b77 into devel Mar 21, 2022
@shmsong shmsong deleted the volta_mma_op branch March 21, 2022 16:21