Skip to content

Commit d12a90f

Browse files
committed
fragment support in double buffer
1 parent c972116 commit d12a90f

File tree

1 file changed

+41
-0
lines changed

1 file changed

+41
-0
lines changed

torch/csrc/jit/codegen/cuda/index_compute.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,37 @@ Val* getProducerIndexWithPartialSplit(
416416
producer_index, SimplifyingIrBuilder::create<Int>(diff_eval.value()));
417417
}
418418

419+
Val* getAllocSizeInMmaFragments(
420+
Val* alloc_size_in_elements,
421+
const TensorView* producer_tv,
422+
const TensorView* consumer_tv) {
423+
auto mma = dynamic_cast<MmaOp*>(consumer_tv->definition());
424+
TORCH_INTERNAL_ASSERT(mma != nullptr, "consumer tv needs to be mma output");
425+
426+
bool is_a = producer_tv == mma->inA();
427+
bool is_b = producer_tv == mma->inB();
428+
429+
TORCH_INTERNAL_ASSERT(is_a || is_b, "producer tv needs to be mma input");
430+
431+
int64_t fragment_size = is_a ? getInputARegisterSize(mma->options().macro)
432+
: getInputBRegisterSize(mma->options().macro);
433+
434+
auto fragment_size_val = SimplifyingIrBuilder::create<Int>(fragment_size);
435+
// This is assuming we don't do dynamic shaped double buffering.
436+
// Could extend if motivating cases are seen later. This const
437+
// value is just used for sanity check anyways.
438+
auto alloc_size_const = alloc_size_in_elements->evaluateInt();
439+
440+
// This should never fail if the scheduler pass all the mma swizzle
441+
// checks. But still asserting here in case something goes really wrong.
442+
TORCH_INTERNAL_ASSERT(
443+
alloc_size_const % fragment_size == 0,
444+
"Invalid double buffer allocation for mma input");
445+
446+
return SimplifyingIrBuilder::divExpr(
447+
alloc_size_in_elements, fragment_size_val);
448+
}
449+
419450
} // namespace
420451

421452
void IndexCompute::handle(Split* split) {
@@ -1919,6 +1950,16 @@ std::vector<Val*> Index::getNonGlobalProducerStridedIndices(
19191950
loop_index, SimplifyingIrBuilder::create<Int>(2));
19201951
auto original_alloc_size =
19211952
gpu_lower->doubleBufferInfo().getOriginalAllocSize(producer_tv);
1953+
// Modify allocation size for mma fragments:
1954+
if (std::any_of(
1955+
consumer_tv->domain()->domain().begin(),
1956+
consumer_tv->domain()->domain().end(),
1957+
[](IterDomain* id) { return id->isMma(); })) {
1958+
// Double buffer switching will need to be modified since
1959+
// the iteration is in units of fragment instead of numbers.
1960+
original_alloc_size = getAllocSizeInMmaFragments(
1961+
original_alloc_size, producer_tv, consumer_tv);
1962+
}
19221963
auto db_strided_index =
19231964
SimplifyingIrBuilder::mulExpr(db_switch_index, original_alloc_size);
19241965
strided_inds.push_back(db_strided_index);

0 commit comments

Comments
 (0)