Mma operator and volta mma integration #1439
Conversation
I checked pretty much all the changes. Really exciting to have the MMA capability in our system! My biggest concern is the same as Christian's: user fusions, scheduling, and lowering could be more cleanly separated, and that would be really important to keep the overall complexity of our system manageable. Another thing that's not really fundamental but still confusing to me is the terminology of "mma" and "gemm". I kind of see "mma" could be used to indicate a unit of operations, whereas "gemm" would mean an overall matrix multiplication, which can use mma but is not required to do so.
Force-pushed from 2b4f3b0 to edd43d9.
Force-pushed from 194881c to ddac459.
NIT: sometimes your comments seem to get broken up where a line will split to a new line and get indented. I don't think this was done on purpose, but I found it in a few spots.
Looks good to me! Just minor comments.
```cpp
tv2->axis(0)->parallelize(ParallelType::BIDx);
tv2->axis(1)->parallelize(ParallelType::BIDy);
tv2->axis(2)->parallelize(ParallelType::TIDz);
tv2->axis(3)->parallelize(ParallelType::TIDy);
```
Can the transform propagator or `parallelizeAllLike` reduce any of the above code?
Yes. All schedules are unnecessarily manual at this stage. I will need to add a few more options to the transform propagator, especially the option to specify propagation boundaries, to automate this part. Will be addressed in a follow-up.
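For reference, a minimal sketch of what that could eventually look like, assuming the existing `scheduler_utils::parallelizeAllLike` helper (the propagation-boundary options mentioned above don't exist yet):

```cpp
// Sketch only: parallelize the reference tensor once, then propagate the
// parallel types to every other tensor instead of repeating parallelize()
// axis by axis on each of them.
tv2->axis(0)->parallelize(ParallelType::BIDx);
tv2->axis(1)->parallelize(ParallelType::BIDy);
tv2->axis(2)->parallelize(ParallelType::TIDz);
tv2->axis(3)->parallelize(ParallelType::TIDy);
scheduler_utils::parallelizeAllLike(tv2, ir_utils::allTvs(fusion));
```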
```cpp
bool isReduction = false;
if (axis_iter != axes_set.end() && *axis_iter == dim) {
  isReduction = true;
  axis_iter++;
```
Why increment one by one instead of just checking if it's in `axes_set`?
This is duplicated code from `newForReduction`. `axes_set` is ordered, so I guessed that was the reason why it was originally incrementing.
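As a self-contained illustration of the two variants (not the actual `newForReduction` code): the lockstep increment works only because both the dimension loop and the `std::set` iteration are ascending, while a membership test makes no ordering assumption.

```cpp
#include <set>
#include <vector>

int main() {
  std::set<int> axes_set{1, 3};  // ordered reduction axes
  const int ndims = 4;
  std::vector<bool> is_reduction(ndims, false);

  // Lockstep variant: relies on both the loop and the set being ascending.
  auto axis_iter = axes_set.begin();
  for (int dim = 0; dim < ndims; dim++) {
    if (axis_iter != axes_set.end() && *axis_iter == dim) {
      is_reduction[dim] = true;
      axis_iter++;
    }
  }

  // Membership variant: same result, no ordering assumption.
  for (int dim = 0; dim < ndims; dim++) {
    is_reduction[dim] = axes_set.count(dim) != 0;
  }
  return 0;
}
```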
```cpp
// will lift this in a follow up when we have a
// more generic axes matching.
TORCH_CHECK(
    axes.size() == 1, "Single axis reduction only for mma op instantiation.")
```
Should contiguity of the input tensors be checked?
I'm not sure if we'd want to limit to contiguous tensors since we could hypothetically still schedule non-contiguous tensors into registers and run mma. Not yet sure if we'd practically want to go for perf in this case.
```cpp
(loop->iter_domain()->isThread() && is_local &&
 (same_parallel_type ||
  (within_mma_loops &&
   loop->iter_domain()->getParallelType() == ParallelType::TIDx)))) {
```
Why is this additional condition required? Would it make more sense to extend `find_matching_parallel_domain` so that `same_parallel_type` would include the case with MMA?
This is basically the swizzled version of `same_parallel_type` across the mma op on the lane_id dimension. The swizzled lane_id iterdomain is up to the mma instruction, and in almost all cases the lane_id domains are not naturally mapped from input to output. It's just the way they are defined, and I guess it also changes in each generation of hardware.
I have moved this part to `same_parallel_type`.
I will need to think about how to meaningfully extend `find_matching_parallel_domain` to cover this case. Could probably try to address this in a follow-up if that's ok.
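A rough sketch of how the folded-in version could read, reusing the names from the snippet above; this is just the shape of the check, not the extension deferred to the follow-up:

```cpp
// Hypothetical: fold the mma case into same_parallel_type itself.
// Across an mma op the lane_id (TIDx) iterdomains are swizzled by the
// instruction, so input and output TIDx domains are almost never naturally
// mapped; treat TIDx inside the mma loop nest as a parallel-type match.
bool same_parallel_type_ext = same_parallel_type ||
    (within_mma_loops &&
     loop->iter_domain()->getParallelType() == ParallelType::TIDx);
```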
Sounds good. Can you please add this as a comment in the code?
```diff
@@ -359,6 +360,14 @@ class CudaKernelGenerator : private OptOutConstDispatch {
   bool is_vector_op = false;
   size_t vector_word_size = 1;

   if (uop->out()->isA<kir::TensorIndex>()) {
```
I believe this is an op to initialize the output tensor of an MmaOp.
The condition isn't very intuitive to me. It uses the fact that the definition of the tensor is the original MmaOp expression rather than its initialization, which is true but seems like an implementation detail.
Wouldn't it be more robust to look at the parallel type of the IterDomains? There should be an IterDomain with `ParallelType::Mma`, and wouldn't it be possible to use that info to determine if this is an initialization op of an Mma output tensor? I believe that's how we generate vectorized code for UnaryOps with vectorized IterDomains.
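To make the suggestion concrete, a hedged sketch of such a check, mirroring the vectorization path; the exact accessors are assumptions, not code from this PR:

```cpp
// Hypothetical: treat a UnaryOp as the init of an MmaOp output if the
// output TensorView carries an Mma-parallelized IterDomain, analogous to
// detecting vectorized UnaryOps through vectorized IterDomains.
bool is_mma_init_op = false;
if (auto ti = dynamic_cast<kir::TensorIndex*>(uop->out())) {
  for (auto id : ti->view()->domain()->domain()) {
    if (id->getParallelType() == ParallelType::Mma) {
      is_mma_init_op = true;
      break;
    }
  }
}
```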
LGTM. Great work!
This PR is a minimalist approach to integrating Volta mma in NVFuser. The goal is to natively support mma thread mapping through iterdomain ops, while relying on the existing iterdomain math as much as possible.
It turns out that all thread swizzles involved from Volta through Ampere are affine/inverse-affine maps, meaning that they can be expressed as sequences of splits and merges. This fact can be used to reduce the amount of generic swizzle composition needed in a gemm kernel, and it also facilitates generic iterdomain checking and optimization on mma thread swizzles.
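As a toy illustration of that point (not one of the actual Volta swizzles), an affine lane permutation such as j = (i % 8) * 4 + i / 8 on a 32-wide lane dimension can be expressed with nothing but existing iterdomain ops:

```cpp
// Realize the affine permutation j = (i % 8) * 4 + i / 8 via split/merge.
tv->split(-1, 8);                   // [32]   -> [4, 8], i = 8 * outer + inner
tv->reorder({{-2, -1}, {-1, -2}});  // [4, 8] -> [8, 4]
tv->merge(-2);                      // [8, 4] -> [32], j = 4 * inner + outer
tv->axis(-1)->parallelize(ParallelType::TIDx);
```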
This PR focuses on the affine part of the swizzle mapping involved in mma; the non-affine components will be introduced in a follow-up.
In this PR:
- A builder interface (`MmaBuilder`) that provides bookkeeping of mma type and data layout.
- A warp-level scheduling utility (`mma_util::WarpMmaScheduler`) that transforms the innermost iterdomains of mma input and output into the correct thread swizzle format for the corresponding mma type. The input to the scheduler needs to follow the "m,n,k" convention, which is checked by `WarpMmaScheduler` before any transform (see the usage sketch below).
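Finally, a rough end-to-end sketch of how the two pieces are meant to compose, based only on the description above; apart from the `MmaBuilder` and `mma_util::WarpMmaScheduler` names, the enum values and method names here are assumptions rather than the verified API of this PR:

```cpp
// Hypothetical flow: pick an mma type and operand layout via the builder,
// define the contraction with iterdomains ordered as [m, n, k], then let
// the warp scheduler swizzle the innermost iterdomains for that mma type.
auto mma_builder =                                    // assumed constructor
    MmaBuilder(MmaOptions::MacroType::Volta_16_16_4)  // assumed enum value
        .layout(MmaOptions::MmaInputLayout::TT);      // assumed method
// ... build the fusion; mma inputs/outputs follow the m,n,k convention ...
mma_util::WarpMmaScheduler::scheduleWarpTile(         // assumed entry point
    tv2, mma_builder.options());
```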