@@ -382,13 +382,17 @@ void IndexLowering::handleGridReduction(
   const auto buffer_size_info =
       getGridCommWorkBufferSize(out_domain, for_loops_, is_persistent);
 
-  const auto reduce_buffer = ir_utils::allocGlobalBufferForGridComm(
-      buffer_size_info.size_of_privatized_buffer, out->dtype(), false);
-
-  const auto sync_buffer = ir_utils::allocGlobalBufferForGridComm(
-      getGridSyncBufferSize(out_domain, for_loops_, is_persistent),
-      DataType::Int,
-      true);
+  auto work_buffer = allocateUniqueBuffer(
+      buffer_size_info.size_of_privatized_buffer,
+      out_tv->dtype(),
+      false,
+      out_tv,
+      work_buffer_map_);
+
+  auto sync_buffer_size =
+      getGridSyncBufferSize(out_domain, for_loops_, is_persistent);
+  auto sync_buffer = allocateUniqueBuffer(
+      sync_buffer_size, DataType::Int, true, out_tv, sync_buffer_map_);
 
   const auto entrance_ind = !is_persistent
       ? getEntranceLinIndGridReduce(for_loops_)
@@ -408,7 +412,7 @@ void IndexLowering::handleGridReduction(
       rop->init(),
       out,
       in,
-      reduce_buffer,
+      work_buffer,
       sync_buffer,
       entrance_ind,
       n_entrances,
@@ -423,17 +427,11 @@ void IndexLowering::handleGridReduction(
     grid_reduction->setWritePredicate(rop->writePredicate());
   }
 
-  pushBack(reduce_buffer);
-  pushBack(sync_buffer);
   pushBack(grid_reduction);
   GpuLower::current()->propagateExprInfo(rop, back());
 
   if (rop->isAllreduce()) {
-    // When using the fused reduction, allocate the reduction object at
-    // the outer-most scope
-    auto fused_reduction_alloc_reduction =
-        IrBuilder::create<kir::AllocateFusedReduction>(grid_reduction);
-    insertAtTopLevel(fused_reduction_alloc_reduction);
+    allocateUniqueFusedReduction(grid_reduction, out_tv);
   }
 }
 
@@ -521,22 +519,24 @@ void IndexLowering::handleGridReduction(
   auto work_buf_size_info =
       getGridCommWorkBufferSize(out_domain, for_loops_, is_persistent);
 
-  std::vector<kir::Allocate*> reduce_buffers;
+  std::vector<kir::Allocate*> work_buffers;
   std::transform(
       outputs.begin(),
       outputs.end(),
-      std::back_inserter(reduce_buffers),
+      std::back_inserter(work_buffers),
       [&](Val* output) {
-        return ir_utils::allocGlobalBufferForGridComm(
+        return allocateUniqueBuffer(
             work_buf_size_info.size_of_privatized_buffer,
             output->dtype(),
-            false);
+            false,
+            output->as<kir::TensorIndex>()->view(),
+            work_buffer_map_);
       });
 
-  const auto sync_buffer = ir_utils::allocGlobalBufferForGridComm(
-      getGridSyncBufferSize(out_domain, for_loops_, is_persistent),
-      DataType::Int,
-      true);
+  auto sync_buffer_size =
+      getGridSyncBufferSize(out_domain, for_loops_, is_persistent);
+  auto sync_buffer = allocateUniqueBuffer(
+      sync_buffer_size, DataType::Int, true, out_tv, sync_buffer_map_);
 
   const auto entrance_ind = !is_persistent
       ? getEntranceLinIndGridReduce(for_loops_)
@@ -556,7 +556,7 @@ void IndexLowering::handleGridReduction(
       grouped_rop->initVals(),
       outputs,
       inputs,
-      reduce_buffers,
+      work_buffers,
       sync_buffer,
       entrance_ind,
       n_entrances,
@@ -572,17 +572,11 @@ void IndexLowering::handleGridReduction(
     grid_reduction->setWritePredicate(grouped_rop->writePredicate());
   }
 
-  for (auto reduce_buffer : reduce_buffers) {
-    pushBack(reduce_buffer);
-  }
-  pushBack(sync_buffer);
   pushBack(grid_reduction);
   GpuLower::current()->propagateExprInfo(grouped_rop, back());
 
   if (grouped_rop->isAllreduce()) {
-    auto fused_reduction_alloc_reduction =
-        IrBuilder::create<kir::AllocateFusedReduction>(grid_reduction);
-    insertAtTopLevel(fused_reduction_alloc_reduction);
+    allocateUniqueFusedReduction(grid_reduction, out_tv);
   }
 }
 
@@ -672,17 +666,29 @@ void IndexLowering::handleGridWelford(WelfordOp* indexed_wop) {
       getGridCommWorkBufferSize(out_domain, for_loops_, is_persistent);
 
   const auto work_buffer_size = buffer_size_info.size_of_privatized_buffer;
-  const auto out_var_buffer = ir_utils::allocGlobalBufferForGridComm(
-      work_buffer_size, indexed_wop->outVar()->dtype(), false);
-  const auto out_avg_buffer = ir_utils::allocGlobalBufferForGridComm(
-      work_buffer_size, indexed_wop->outAvg()->dtype(), false);
-  const auto out_N_buffer = ir_utils::allocGlobalBufferForGridComm(
-      work_buffer_size, indexed_wop->outN()->dtype(), false);
-
-  const auto sync_buffer = ir_utils::allocGlobalBufferForGridComm(
-      getGridSyncBufferSize(out_domain, for_loops_, is_persistent),
-      DataType::Int,
-      true);
+  auto out_var_buffer = allocateUniqueBuffer(
+      work_buffer_size,
+      indexed_wop->outVar()->dtype(),
+      false,
+      indexed_wop->outVar()->as<kir::TensorIndex>()->view(),
+      work_buffer_map_);
+  auto out_avg_buffer = allocateUniqueBuffer(
+      work_buffer_size,
+      indexed_wop->outAvg()->dtype(),
+      false,
+      indexed_wop->outAvg()->as<kir::TensorIndex>()->view(),
+      work_buffer_map_);
+  auto out_N_buffer = allocateUniqueBuffer(
+      work_buffer_size,
+      indexed_wop->outN()->dtype(),
+      false,
+      indexed_wop->outN()->as<kir::TensorIndex>()->view(),
+      work_buffer_map_);
+
+  auto sync_buffer_size =
+      getGridSyncBufferSize(out_domain, for_loops_, is_persistent);
+  auto sync_buffer = allocateUniqueBuffer(
+      sync_buffer_size, DataType::Int, true, out_tv, sync_buffer_map_);
 
   const auto entrance_ind = !is_persistent
       ? getEntranceLinIndGridReduce(for_loops_)
@@ -729,19 +735,13 @@ void IndexLowering::handleGridWelford(WelfordOp* indexed_wop) {
     GpuLower::current()->propagateExprInfo(indexed_wop, back());
   }
 
-  pushBack(out_var_buffer);
-  pushBack(out_avg_buffer);
-  pushBack(out_N_buffer);
-  pushBack(sync_buffer);
   pushBack(grid_welford);
   GpuLower::current()->propagateExprInfo(indexed_wop, back());
 
   if (indexed_wop->isAllreduce()) {
     // When using the fused reduction, allocate the reduction object at
     // the outer-most scope
-    auto fused_reduction_alloc_reduction =
-        IrBuilder::create<kir::AllocateFusedReduction>(grid_welford);
-    insertAtTopLevel(fused_reduction_alloc_reduction);
+    allocateUniqueFusedReduction(grid_welford, out_tv);
   }
 }
 
@@ -792,24 +792,24 @@ void IndexLowering::handle(const BroadcastOp* bop) {
 
   // Grid broadcast
   const auto out_domain = out_tv->domain();
-  const auto broadcast_buffer = ir_utils::allocGlobalBufferForGridComm(
+  const auto work_buffer_size =
       getGridCommWorkBufferSize(out_domain, for_loops_, true)
-          .size_of_privatized_buffer,
-      out->dtype(),
-      false);
+          .size_of_privatized_buffer;
+
+  auto work_buffer = allocateUniqueBuffer(
+      work_buffer_size, out->dtype(), false, out_tv, work_buffer_map_);
 
-  const auto sync_buffer = ir_utils::allocGlobalBufferForGridComm(
-      getGridSyncBufferSize(out_domain, for_loops_, true), DataType::Int, true);
+  auto sync_buffer_size = getGridSyncBufferSize(out_domain, for_loops_, true);
+  auto sync_buffer = allocateUniqueBuffer(
+      sync_buffer_size, DataType::Int, true, out_tv, sync_buffer_map_);
 
   auto grid_broadcast = IrBuilder::create<kir::GridBroadcast>(
-      indexed_expr, broadcast_buffer, sync_buffer);
+      indexed_expr, work_buffer, sync_buffer);
 
   if (bop->predicate()) {
     grid_broadcast->setPredicate(bop->predicate());
   }
 
-  pushBack(broadcast_buffer);
-  pushBack(sync_buffer);
   pushBack(grid_broadcast);
   GpuLower::current()->propagateExprInfo(bop, back());
 }
@@ -840,6 +840,69 @@ void IndexLowering::generate(const std::vector<Expr*>& exprs) {
   }
 }
 
+kir::Allocate* IndexLowering::allocateUniqueBuffer(
+    Val* buffer_size,
+    DataType dtype,
+    bool zero_init,
+    TensorView* out_tv,
+    std::unordered_map<TensorView*, kir::Allocate*>& alloc_map) {
+  // Return an existing allocation if one exists
+  auto it = alloc_map.find(out_tv);
+  if (it != alloc_map.end()) {
+    return it->second;
+  }
+
+  // No existing allocation found. Create a new one.
+  auto new_buffer =
+      ir_utils::allocGlobalBufferForGridComm(buffer_size, dtype, zero_init);
+
+  // Keep track of the allocation
+  alloc_map.emplace(out_tv, new_buffer);
+
+  // A buffer may be used in both of the unswitched paths, so it must be
+  // placed outside of the current scope. Simply placing it at the
+  // top-level scope should work.
+  insertAtTopLevel(new_buffer);
+
+  return new_buffer;
+}
+
+void IndexLowering::allocateUniqueFusedReduction(
+    Expr* expr,
+    TensorView* out_tv) {
+  auto it = fused_reduction_map_.find(out_tv);
+  if (it != fused_reduction_map_.end()) {
+    return;
+  }
+
+  kir::AllocateFusedReduction* fused_reduction_alloc_reduction = nullptr;
+  switch (expr->getExprType().value()) {
+    case ExprType::GridReduction:
+      fused_reduction_alloc_reduction =
+          IrBuilder::create<kir::AllocateFusedReduction>(
+              expr->as<kir::GridReduction>());
+      break;
+    case ExprType::GridWelford:
+      fused_reduction_alloc_reduction =
+          IrBuilder::create<kir::AllocateFusedReduction>(
+              expr->as<kir::GridWelford>());
+      break;
+    case ExprType::GroupedGridReduction:
+      fused_reduction_alloc_reduction =
+          IrBuilder::create<kir::AllocateFusedReduction>(
+              expr->as<kir::GroupedGridReduction>());
+      break;
+    default:
+      TORCH_INTERNAL_ASSERT(false, "Invalid expr: ", expr->toString());
+  }
+
+  fused_reduction_map_.emplace(out_tv, fused_reduction_alloc_reduction);
+
+  // When using the fused reduction, allocate the reduction object at
+  // the outer-most scope
+  insertAtTopLevel(fused_reduction_alloc_reduction);
+}
+
 } // namespace cuda
 } // namespace fuser
 } // namespace jit
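
The core of this change is the pair of memoizing helpers at the end of the diff: every grid communication buffer and every fused-reduction allocation is now keyed on its output TensorView, so lowering the same output more than once (for example, on both sides of an unswitched path, per the comment in allocateUniqueBuffer) reuses a single top-level allocation instead of pushing a fresh one into the current scope each time, which is also why the old pushBack calls for the buffers disappear. The standalone sketch below only illustrates that lookup-or-create pattern; TensorView, Allocate, work_buffer_map, and the function signature here are simplified stand-ins rather than the real nvfuser types, and the real helper additionally hoists the new allocation to the top-level scope via insertAtTopLevel.

    // Simplified, hypothetical sketch of the lookup-or-create pattern behind
    // allocateUniqueBuffer; not the actual nvfuser IR types or API.
    #include <cassert>
    #include <iostream>
    #include <string>
    #include <unordered_map>

    struct TensorView {};   // stand-in for the fusion IR tensor
    struct Allocate {       // stand-in for kir::Allocate
      std::string name;
    };

    // One map per buffer kind, keyed on the output tensor (mirrors
    // work_buffer_map_ / sync_buffer_map_ in the patch).
    std::unordered_map<const TensorView*, Allocate*> work_buffer_map;

    Allocate* allocateUniqueBuffer(const TensorView* out_tv, std::string name) {
      // Return the existing allocation if this output was already lowered once.
      auto it = work_buffer_map.find(out_tv);
      if (it != work_buffer_map.end()) {
        return it->second;
      }
      // Otherwise create a new buffer and remember it. The real helper also
      // hoists the allocation to the outer-most scope at this point.
      auto* new_buffer = new Allocate{std::move(name)};
      work_buffer_map.emplace(out_tv, new_buffer);
      return new_buffer;
    }

    int main() {
      TensorView tv;
      auto* first = allocateUniqueBuffer(&tv, "T0_work_buf");
      auto* second = allocateUniqueBuffer(&tv, "would_be_duplicate");
      assert(first == second);  // same output tensor -> one shared buffer
      std::cout << "allocations for tv: " << work_buffer_map.size() << "\n";
      delete first;
      return 0;
    }

Because handleGridReduction, the grouped-reduction overload, handleGridWelford, and the grid broadcast path all route through the same helpers, the deduplication behavior is identical across them, and allocateUniqueFusedReduction applies the same per-TensorView memoization to the kir::AllocateFusedReduction objects.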