Skip to content

Commit d9420e4

Browse files
authored
View scheduling (#1928)
1 parent c668e13 commit d9420e4

18 files changed

+1556
-150
lines changed

torch/csrc/jit/codegen/cuda/ir_utils.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -874,6 +874,30 @@ bool isReductionTvOp(const Expr* expr) {
874874
return ir_utils::isTvOp(expr) && isReductionOp(expr);
875875
}
876876

877+
// Collects every ViewOp in the fusion whose output is a TensorView carrying
// an rfactor domain — i.e. every non-trivial view operation.
TORCH_CUDA_CU_API std::vector<ViewOp*> getViewOps(Fusion* fusion) {
  auto exprs = fusion->exprs();

  std::vector<ViewOp*> result;
  for (auto view : ir_utils::filterByType<ViewOp>(exprs)) {
    // Keep the op only if at least one output is a TensorView with an
    // rfactor domain; a trivial view would have none.
    bool has_rfactor_output = false;
    for (auto out : view->outputs()) {
      if (out->isA<TensorView>() && out->as<TensorView>()->hasRFactor()) {
        has_rfactor_output = true;
        break;
      }
    }
    if (has_rfactor_output) {
      result.push_back(view);
    }
  }
  return result;
}
900+
877901
namespace {
878902

879903
struct ReplaceValInIndexVal : public OptInDispatch {

torch/csrc/jit/codegen/cuda/ir_utils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,11 @@ TORCH_CUDA_CU_API bool isReductionOp(const Expr*);
317317
// Returns if Expr is a reduction op with TensorView or TensorIndex
318318
TORCH_CUDA_CU_API bool isReductionTvOp(const Expr*);
319319

320+
// Returns all non-trivial view operations. Trivial view operations should
// never occur, but this function guards against pulling them in if they
// ever do.
323+
TORCH_CUDA_CU_API std::vector<ViewOp*> getViewOps(Fusion*);
324+
320325
template <typename T>
321326
std::string toString(const T& nodes) {
322327
std::stringstream ss;

torch/csrc/jit/codegen/cuda/ops/alias.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,6 @@ TensorView* applyViewTransforms(
4343
TORCH_INTERNAL_ASSERT(
4444
post_reduce_tv->nDims() > 0, "Tried to view a 0-dim TensorView");
4545

46-
TORCH_CHECK(
47-
!post_reduce_tv->domain()->hasRFactor(),
48-
"Cannot call view on the same TensorView twice.");
49-
5046
TORCH_INTERNAL_ASSERT(!view_analysis.transforms.empty());
5147

5248
TensorView* consumer = IrBuilder::create<TensorView>(

torch/csrc/jit/codegen/cuda/root_domain_map.cpp

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -835,10 +835,6 @@ void ComputeAtRootDomainMapBuilder::setMaybeMapped(
835835
addToPendingList(producer_bcast_key, consumer_bcast_key);
836836
}
837837
} else {
838-
TORCH_INTERNAL_ASSERT(
839-
!consumer_id->isBroadcast(),
840-
"No concrete domain found for a broadcast domain: ",
841-
consumer_key.toString());
842838
auto producer_concrete_key = producer_key;
843839
if (producer_id->isBroadcast()) {
844840
const auto concrete_id = consumer_id;
@@ -895,11 +891,11 @@ void ComputeAtRootDomainMapBuilder::mapPointwiseOrReductionOp(Expr* e) {
895891
"Unknown multi-output Expr type ",
896892
e->getExprType().value(),
897893
" is found");
898-
for (auto o : e->outputs()) {
899-
auto o_tv = o->as<TensorView>();
900-
auto o_td = o_tv->domain();
901-
auto o_root = o_td->getRootDomain();
902-
setMaybeMapped(in_td, in_root[it], o_td, o_root[it]);
894+
for (auto out : e->outputs()) {
895+
auto out_tv = out->as<TensorView>();
896+
auto out_td = out_tv->domain();
897+
auto out_root = out_td->getRootDomain();
898+
setMaybeMapped(in_td, in_root[it], out_td, out_root[it]);
903899
}
904900
} else {
905901
setMaybeMapped(in_td, in_root[it], out_td, out_root[it]);

torch/csrc/jit/codegen/cuda/scheduler/compile_time_info.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ class ScopePersistentFactorInfo {
121121
//! information.
122122
class BroadcastMultiples {
123123
public:
124-
using DataType = std::vector<scheduler_utils::BroadcastMultiple>;
124+
using DataType = scheduler_utils::BroadcastMultipleInformation;
125125
static const CompileTimeEntryType EntryType =
126126
CompileTimeEntryType::BROADCAST_BYTE_MULTIPLES;
127127
};

torch/csrc/jit/codegen/cuda/scheduler/heuristic.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include <torch/csrc/jit/codegen/cuda/executor_launch_params.h>
4+
#include <torch/csrc/jit/codegen/cuda/utils.h>
45

56
#include <string>
67

@@ -9,7 +10,7 @@ namespace jit {
910
namespace fuser {
1011
namespace cuda {
1112

12-
class HeuristicParams {
13+
class HeuristicParams : public PolymorphicBase {
1314
public:
1415
std::string tag = "";
1516

torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -992,6 +992,8 @@ TORCH_CUDA_CU_API void schedulePersistentKernel(
992992
scheduler_utils::getReductionTvs(fusion /*, ignore_trivial = true */);
993993

994994
TORCH_INTERNAL_ASSERT(reduction_tvs.size());
995+
// The registry assumes the reference tv is the first reduction_tv; if this
// changes, the registry needs to change as well.
995997
auto reduction_tv = reduction_tvs[0];
996998

997999
auto dim_analysis = scheduler_utils::canonicalDimReduction(

0 commit comments

Comments
 (0)