@@ -1,6 +1,7 @@
 #include <torch/csrc/jit/codegen/cuda/scheduler/pointwise.h>
 
 #include <torch/csrc/jit/codegen/cuda/executor_utils.h>
+#include <torch/csrc/jit/codegen/cuda/inline_propagator.h>
 #include <torch/csrc/jit/codegen/cuda/instrumentation.h>
 #include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
 #include <torch/csrc/jit/codegen/cuda/ir_utils.h>
@@ -13,6 +14,9 @@
 
 #include <ATen/cuda/CUDAContext.h>
 
+#include <algorithm>
+#include <unordered_map>
+
 namespace torch {
 namespace jit {
 namespace fuser {
@@ -671,6 +675,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
     }
   }
 
+  int64_t unswitch_pos;
   if (params.break_point) {
     // 2D parallelization scheme
     TORCH_INTERNAL_ASSERT(rhs_i >= 0 && lhs_i >= 0);
@@ -723,8 +728,10 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
         // [i-remainder, BIDy{65535} | BIDx, TIDy | Unswitch, Unroll, TIDx]
         reference_tv->split(0, 65535);
         reference_tv->axis(1)->parallelize(ParallelType::BIDy);
+        unswitch_pos = 5;
       } else {
         reference_tv->axis(0)->parallelize(ParallelType::BIDy);
+        unswitch_pos = 4;
       }
     } else {
       // [BIDx | BIDy TIDy | Unswitch, Unroll, TIDx]
@@ -734,8 +741,10 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
         // [BIDx | i-remainder, BIDy{65535}, TIDy | Unswitch, Unroll, TIDx]
         reference_tv->split(1, 65535);
         reference_tv->axis(2)->parallelize(ParallelType::BIDy);
+        unswitch_pos = 5;
       } else {
         reference_tv->axis(1)->parallelize(ParallelType::BIDy);
+        unswitch_pos = 4;
       }
     }
   } else {
@@ -747,8 +756,10 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
        // [i-remainder, BIDy{65535} | BIDx | Unswitch, Unroll, TIDx]
        reference_tv->split(0, 65535);
        reference_tv->axis(1)->parallelize(ParallelType::BIDy);
+       unswitch_pos = 4;
      } else {
        reference_tv->axis(0)->parallelize(ParallelType::BIDy);
+       unswitch_pos = 3;
      }
    } else {
      // [BIDx | BIDy | Unswitch, Unroll, TIDx]
@@ -757,8 +768,10 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
        // [BIDx | i-remainder, BIDy{65535} | Unswitch, Unroll, TIDx]
        reference_tv->split(1, 65535);
        reference_tv->axis(2)->parallelize(ParallelType::BIDy);
+       unswitch_pos = 4;
      } else {
        reference_tv->axis(1)->parallelize(ParallelType::BIDy);
+       unswitch_pos = 3;
      }
    }
  }
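Note on the hard-coded positions: each `unswitch_pos` value is the 1-based compute-at position just past the `Unswitch` axis in the bracketed layout comments above, replacing the per-tensor `std::find_if` lookup that this patch deletes further down. A minimal sketch of the correspondence, illustrative only, reusing the lookup idiom from the removed code:

```cpp
// With the 2D scheme plus the extra i-remainder split, the reference
// tensor's loop domain is
//   [i-remainder, BIDy, BIDx, TIDy, Unswitch, Unroll, TIDx]
//    0            1     2     3     4         5       6
// so the position just past Unswitch is 4 + 1 = 5. The deleted code
// computed the same value dynamically:
auto dom = reference_tv->domain()->domain();
auto unswitch_it = std::find_if(dom.begin(), dom.end(), [](IterDomain* id) {
  return id->getParallelType() == ParallelType::Unswitch;
});
int64_t unswitch_pos = std::distance(dom.begin(), unswitch_it) + 1;
```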
@@ -803,10 +816,12 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
       reference_tv->axis(1)->parallelize(ParallelType::Unswitch);
       reference_tv->axis(3)->parallelize(ParallelType::TIDx);
     }
+    unswitch_pos = 2;
   }
 
   TransformPropagator propagator(reference_tv);
-  MaxRootDomainInfoSpanningTree(reference_tv).traverse(&propagator);
+  MaxRootDomainInfoSpanningTree spanning_tree(reference_tv);
+  spanning_tree.traverse(&propagator);
   scheduler_utils::parallelizeAllLike(reference_tv, all_tvs);
 
   if (params.vectorize) {
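Binding the spanning tree to a named local matters because the same tree is now traversed several times: once with the `TransformPropagator` here, then with each `InlinePropagator` pass and the `MaxProducerPosUpdater` added below. Roughly, using only names that appear in this patch:

```cpp
// Build the spanning tree once from the scheduled reference tensor...
MaxRootDomainInfoSpanningTree spanning_tree(reference_tv);

// ...then drive each propagation pass over the same tree.
TransformPropagator propagator(reference_tv);
spanning_tree.traverse(&propagator); // replay the reference transforms

InlinePropagator inline_unswitch(
    reference_tv, unswitch_pos, ComputeAtMode::BestEffort);
spanning_tree.traverse(&inline_unswitch); // later: set inline positions
```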
@@ -841,84 +856,31 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
     }
   }
 
-  // Compute at into cached inputs
-  std::vector<TensorView*> consumers_of_cached_inputs;
-  // Cache of input, and one of its consumers
-  std::vector<std::pair<TensorView*, TensorView*>> input_cache_and_consumer;
-  {
-    // Avoid duplicate additions, so track what we add
-    std::unordered_set<TensorView*> added;
-    for (auto cached_input : cached_inputs) {
-      auto consumer_tvs = ir_utils::consumerTvsOf(cached_input);
-      TORCH_INTERNAL_ASSERT(
-          consumer_tvs.size(),
-          "Input was not succesfully filtered out for scheduling but wasn't used.");
-
-      // Grab a consumer which will be used for computeAt structure of cached
-      // input into a consumer
-      input_cache_and_consumer.emplace_back(
-          std::make_pair(cached_input, consumer_tvs[0]));
-
-      // Grab all consumers which will be used for inlining computeAt for the
-      // body of the computation (excluding caching inputs/outputs)
-      for (auto consumer_tv : consumer_tvs) {
-        // Don't duplicate
-        if (added.insert(consumer_tv).second) {
-          consumers_of_cached_inputs.emplace_back(consumer_tv);
-        }
-      }
-    }
-  }
-
-  for (auto entry : input_cache_and_consumer) {
-    // Compute at inside unswitch position:
-    auto input_cache = entry.first;
-    auto input_cache_consumer = entry.second;
-
-    auto unswitch_it = std::find_if(
-        input_cache_consumer->domain()->domain().begin(),
-        input_cache_consumer->domain()->domain().end(),
-        [](IterDomain* id) {
-          return id->getParallelType() == ParallelType::Unswitch;
-        });
-    auto unswitch_pos =
-        unswitch_it == input_cache_consumer->domain()->domain().end()
-        ? -1
-        : std::distance(
-              input_cache_consumer->domain()->domain().begin(), unswitch_it) +
-            1;
+  // Begin by inlining at the unswitch position for the entire DAG. The cached
+  // inputs, and outputs will keep this inline position, but other tensors will
+  // get a higher position in later inline propagation.
+  InlinePropagator inline_unswitch(
+      reference_tv, unswitch_pos, ComputeAtMode::BestEffort);
+  spanning_tree.traverse(&inline_unswitch);
 
-    input_cache->computeAt(
-        input_cache_consumer, unswitch_pos, ComputeAtMode::BestEffort);
+  // Inline at the inner most position. The CA position of all tensors except
+  // inputs, cached inputs and outputs will be updated.
+  std::unordered_set<TensorView*> inner_most_tensors(
+      all_tvs.begin(), all_tvs.end());
+  for (auto cached_input : cached_inputs) {
+    inner_most_tensors.erase(cached_input);
   }
-
-  // Producers for inlined computeAt
-  std::vector<TensorView*> compute_from = consumers_of_cached_inputs;
-
-  // Consumers for inlined computeAt
-  std::vector<TensorView*> compute_to;
-  // Compute at cached outputs
-  // [BIDx, Unswitch, Vectorization, TIDx]
   for (auto entry : cached_outputs) {
-    auto cached_output = entry.first;
     auto output = entry.second;
-
-    auto unswitch_it = std::find_if(
-        output->domain()->domain().begin(),
-        output->domain()->domain().end(),
-        [](IterDomain* id) {
-          return id->getParallelType() == ParallelType::Unswitch;
-        });
-    auto unswitch_pos = unswitch_it == output->domain()->domain().end()
-        ? -1
-        : std::distance(output->domain()->domain().begin(), unswitch_it) + 1;
-
-    cached_output->computeAt(output, unswitch_pos, ComputeAtMode::BestEffort);
-    compute_to.push_back(cached_output);
+    inner_most_tensors.erase(output);
   }
+  InlinePropagator inline_inner_most(
+      reference_tv, -1, ComputeAtMode::BestEffort, inner_most_tensors);
+  spanning_tree.traverse(&inline_inner_most);
 
-  scheduler_utils::computeAtBetween(
-      compute_from, compute_to, -1, ComputeAtMode::BestEffort);
+  // Fix max producer position
+  MaxProducerPosUpdater updater;
+  spanning_tree.traverse(&updater);
 }
 
 } // namespace cuda
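Taken together, the deleted per-tensor `computeAt` bookkeeping collapses into three propagation passes over the one spanning tree. A condensed view of the new flow, using the same calls as the hunk above and assuming `InlinePropagator`'s optional last argument selects which tensors a pass may update:

```cpp
// Pass 1: best-effort inline at unswitch_pos for every tensor in the DAG.
InlinePropagator inline_unswitch(
    reference_tv, unswitch_pos, ComputeAtMode::BestEffort);
spanning_tree.traverse(&inline_unswitch);

// Pass 2: push everything except cached inputs and fusion outputs to the
// innermost position (-1); the excluded tensors keep the unswitch position.
std::unordered_set<TensorView*> inner_most_tensors(
    all_tvs.begin(), all_tvs.end());
for (auto cached_input : cached_inputs) {
  inner_most_tensors.erase(cached_input);
}
for (auto entry : cached_outputs) {
  inner_most_tensors.erase(entry.second);
}
InlinePropagator inline_inner_most(
    reference_tv, -1, ComputeAtMode::BestEffort, inner_most_tensors);
spanning_tree.traverse(&inline_inner_most);

// Pass 3: recompute each tensor's max producer position so lowering sees
// consistent compute-at metadata.
MaxProducerPosUpdater updater;
spanning_tree.traverse(&updater);
```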