@@ -416,6 +416,37 @@ Val* getProducerIndexWithPartialSplit(
       producer_index, SimplifyingIrBuilder::create<Int>(diff_eval.value()));
 }
 
+Val* getAllocSizeInMmaFragments(
+    Val* alloc_size_in_elements,
+    const TensorView* producer_tv,
+    const TensorView* consumer_tv) {
+  auto mma = dynamic_cast<MmaOp*>(consumer_tv->definition());
+  TORCH_INTERNAL_ASSERT(mma != nullptr, "consumer tv needs to be mma output");
+
+  bool is_a = producer_tv == mma->inA();
+  bool is_b = producer_tv == mma->inB();
+
+  TORCH_INTERNAL_ASSERT(is_a || is_b, "producer tv needs to be mma input");
+
+  int64_t fragment_size = is_a ? getInputARegisterSize(mma->options().macro)
+                               : getInputBRegisterSize(mma->options().macro);
+
+  auto fragment_size_val = SimplifyingIrBuilder::create<Int>(fragment_size);
+  // This assumes we don't do dynamically shaped double buffering.
+  // Could be extended if motivating cases are seen later. This constant
+  // value is only used for the sanity check below anyway.
+  auto alloc_size_const = alloc_size_in_elements->evaluateInt();
+
+  // This should never fail if the scheduler passes all the mma swizzle
+  // checks, but still assert here in case something goes really wrong.
+  TORCH_INTERNAL_ASSERT(
+      alloc_size_const % fragment_size == 0,
+      "Invalid double buffer allocation for mma input");
+
+  return SimplifyingIrBuilder::divExpr(
+      alloc_size_in_elements, fragment_size_val);
+}
+
 } // namespace
 
 void IndexCompute::handle(Split* split) {
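The helper added above only rescales a compile-time allocation size from elements to mma fragments. A minimal standalone sketch of the same arithmetic, using plain `int64_t` in place of nvFuser IR nodes, with hypothetical fragment sizes standing in for `getInputARegisterSize` / `getInputBRegisterSize` (the helper name and constants here are illustrative, not from the PR):

```cpp
// Standalone sketch of the element-to-fragment conversion above, using
// plain int64_t instead of nvFuser IR nodes. kFragmentSizeA/B are
// hypothetical stand-ins for getInputARegisterSize / getInputBRegisterSize.
#include <cassert>
#include <cstdint>

int64_t allocSizeInFragments(int64_t alloc_size_in_elements, bool is_operand_a) {
  // Hypothetical per-thread fragment sizes (elements per thread per mma op).
  const int64_t kFragmentSizeA = 8;
  const int64_t kFragmentSizeB = 4;
  const int64_t fragment_size = is_operand_a ? kFragmentSizeA : kFragmentSizeB;

  // Mirrors the TORCH_INTERNAL_ASSERT: a double-buffered mma input must
  // hold a whole number of fragments.
  assert(alloc_size_in_elements % fragment_size == 0);
  return alloc_size_in_elements / fragment_size;
}

int main() {
  // Two double-buffer stages of 16 elements each for operand A:
  // 32 elements total -> 4 fragments of 8 elements.
  assert(allocSizeInFragments(32, /*is_operand_a=*/true) == 4);
  return 0;
}
```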
@@ -1919,6 +1950,16 @@ std::vector<Val*> Index::getNonGlobalProducerStridedIndices(
       loop_index, SimplifyingIrBuilder::create<Int>(2));
   auto original_alloc_size =
       gpu_lower->doubleBufferInfo().getOriginalAllocSize(producer_tv);
+  // Modify allocation size for mma fragments:
+  if (std::any_of(
+          consumer_tv->domain()->domain().begin(),
+          consumer_tv->domain()->domain().end(),
+          [](IterDomain* id) { return id->isMma(); })) {
+    // Double buffer switching needs to be modified since the
+    // iteration is in units of fragments instead of elements.
+    original_alloc_size = getAllocSizeInMmaFragments(
+        original_alloc_size, producer_tv, consumer_tv);
+  }
   auto db_strided_index =
       SimplifyingIrBuilder::mulExpr(db_switch_index, original_alloc_size);
   strided_inds.push_back(db_strided_index);
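At this call site the switch offset is `db_switch_index * original_alloc_size`, so once the allocation size is expressed in fragments the double-buffer switch advances by whole fragments rather than elements. A hedged sketch of that arithmetic, assuming 2-stage double buffering (modeled as `loop_index % 2`) and a hypothetical fragment size of 8 elements; the function and parameter names are illustrative:

```cpp
// Sketch of the switch-offset arithmetic at this call site, with plain
// integers in place of IR. The fragment size of 8 and the mod-2 switch
// (2-stage double buffering) are illustrative assumptions.
#include <cassert>
#include <cstdint>

int64_t dbStridedIndex(
    int64_t loop_index, int64_t alloc_size_in_elements, bool producer_feeds_mma) {
  // db_switch_index alternates 0, 1, 0, 1, ... across main-loop iterations.
  const int64_t db_switch_index = loop_index % 2;
  // For mma inputs the producer is indexed in fragments, so the per-stage
  // allocation size is rescaled to fragment units first.
  const int64_t alloc_size =
      producer_feeds_mma ? alloc_size_in_elements / 8 : alloc_size_in_elements;
  return db_switch_index * alloc_size;
}

int main() {
  // With a 32-element stage, the second stage starts 32 elements in for a
  // regular producer, but only 4 fragments in for an mma input.
  assert(dbStridedIndex(/*loop_index=*/1, 32, /*producer_feeds_mma=*/false) == 32);
  assert(dbStridedIndex(/*loop_index=*/1, 32, /*producer_feeds_mma=*/true) == 4);
  return 0;
}
```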