Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 75 additions & 65 deletions torch/csrc/jit/codegen/cuda/dispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ T* ptr(T* obj) {
* }
*
* And therefore dispatch should never call:
* ptr(mutator)->handle(static_cast<Statement*>(this));
* ptr(mutator)->handle(this->as<Statement>());
*/

template <typename T>
Expand All @@ -45,35 +45,35 @@ void Val::dispatch(T handler, Val* val) {
case ValType::Scalar:
switch (*(val->getDataType())) {
case DataType::Bool:
ptr(handler)->handle(static_cast<Bool*>(val));
ptr(handler)->handle(val->as<Bool>());
return;
case DataType::Float:
ptr(handler)->handle(static_cast<Float*>(val));
ptr(handler)->handle(val->as<Float>());
return;
case DataType::Half:
ptr(handler)->handle(static_cast<Half*>(val));
ptr(handler)->handle(val->as<Half>());
return;
case DataType::Int:
ptr(handler)->handle(static_cast<Int*>(val));
ptr(handler)->handle(val->as<Int>());
return;
default:
break;
}
break;
case ValType::IterDomain:
ptr(handler)->handle(static_cast<IterDomain*>(val));
ptr(handler)->handle(val->as<IterDomain>());
return;
case ValType::TensorDomain:
ptr(handler)->handle(static_cast<TensorDomain*>(val));
ptr(handler)->handle(val->as<TensorDomain>());
return;
case ValType::TensorView:
ptr(handler)->handle(static_cast<TensorView*>(val));
ptr(handler)->handle(val->as<TensorView>());
return;
case ValType::TensorIndex:
ptr(handler)->handle(static_cast<TensorIndex*>(val));
ptr(handler)->handle(val->as<TensorIndex>());
return;
case ValType::NamedScalar:
ptr(handler)->handle(static_cast<NamedScalar*>(val));
ptr(handler)->handle(val->as<NamedScalar>());
return;
default:
break;
Expand All @@ -85,34 +85,34 @@ template <typename T>
void Expr::dispatch(T handler, Expr* expr) {
switch (*(expr->getExprType())) {
case ExprType::Split:
ptr(handler)->handle(static_cast<Split*>(expr));
ptr(handler)->handle(expr->as<Split>());
return;
case ExprType::Merge:
ptr(handler)->handle(static_cast<Merge*>(expr));
ptr(handler)->handle(expr->as<Merge>());
return;
case ExprType::UnaryOp:
ptr(handler)->handle(static_cast<UnaryOp*>(expr));
ptr(handler)->handle(expr->as<UnaryOp>());
return;
case ExprType::BinaryOp:
ptr(handler)->handle(static_cast<BinaryOp*>(expr));
ptr(handler)->handle(expr->as<BinaryOp>());
return;
case ExprType::TernaryOp:
ptr(handler)->handle(static_cast<TernaryOp*>(expr));
ptr(handler)->handle(expr->as<TernaryOp>());
return;
case ExprType::ReductionOp:
ptr(handler)->handle(static_cast<ReductionOp*>(expr));
ptr(handler)->handle(expr->as<ReductionOp>());
return;
case ExprType::BroadcastOp:
ptr(handler)->handle(static_cast<BroadcastOp*>(expr));
ptr(handler)->handle(expr->as<BroadcastOp>());
return;
case ExprType::ForLoop:
ptr(handler)->handle(static_cast<ForLoop*>(expr));
ptr(handler)->handle(expr->as<ForLoop>());
return;
case ExprType::IfThenElse:
ptr(handler)->handle(static_cast<IfThenElse*>(expr));
ptr(handler)->handle(expr->as<IfThenElse>());
return;
case ExprType::Allocate:
ptr(handler)->handle(static_cast<Allocate*>(expr));
ptr(handler)->handle(expr->as<Allocate>());
return;
default:
TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
Expand All @@ -122,9 +122,9 @@ void Expr::dispatch(T handler, Expr* expr) {
template <typename T>
void Statement::dispatch(T handler, Statement* stmt) {
if (stmt->isVal()) {
ptr(handler)->handle(static_cast<Val*>(stmt));
ptr(handler)->handle(stmt->as<Val>());
} else if (stmt->isExpr()) {
ptr(handler)->handle(static_cast<Expr*>(stmt));
ptr(handler)->handle(stmt->as<Expr>());
} else
TORCH_INTERNAL_ASSERT(false, "Unknown stmttype in dispatch!");
}
Expand All @@ -135,35 +135,35 @@ void Val::constDispatch(T handler, const Val* val) {
case ValType::Scalar:
switch (*(val->getDataType())) {
case DataType::Bool:
ptr(handler)->handle(static_cast<const Bool*>(val));
ptr(handler)->handle(val->as<Bool>());
return;
case DataType::Float:
ptr(handler)->handle(static_cast<const Float*>(val));
ptr(handler)->handle(val->as<Float>());
return;
case DataType::Half:
ptr(handler)->handle(static_cast<const Half*>(val));
ptr(handler)->handle(val->as<Half>());
return;
case DataType::Int:
ptr(handler)->handle(static_cast<const Int*>(val));
ptr(handler)->handle(val->as<Int>());
return;
default:
break;
}
break;
case ValType::IterDomain:
ptr(handler)->handle(static_cast<const IterDomain*>(val));
ptr(handler)->handle(val->as<IterDomain>());
return;
case ValType::TensorDomain:
ptr(handler)->handle(static_cast<const TensorDomain*>(val));
ptr(handler)->handle(val->as<TensorDomain>());
return;
case ValType::TensorView:
ptr(handler)->handle(static_cast<const TensorView*>(val));
ptr(handler)->handle(val->as<TensorView>());
return;
case ValType::TensorIndex:
ptr(handler)->handle(static_cast<const TensorIndex*>(val));
ptr(handler)->handle(val->as<TensorIndex>());
return;
case ValType::NamedScalar:
ptr(handler)->handle(static_cast<const NamedScalar*>(val));
ptr(handler)->handle(val->as<NamedScalar>());
return;
default:
break;
Expand All @@ -175,34 +175,34 @@ template <typename T>
void Expr::constDispatch(T handler, const Expr* expr) {
switch (*(expr->getExprType())) {
case ExprType::Split:
ptr(handler)->handle(static_cast<const Split*>(expr));
ptr(handler)->handle(expr->as<Split>());
return;
case ExprType::Merge:
ptr(handler)->handle(static_cast<const Merge*>(expr));
ptr(handler)->handle(expr->as<Merge>());
return;
case ExprType::UnaryOp:
ptr(handler)->handle(static_cast<const UnaryOp*>(expr));
ptr(handler)->handle(expr->as<UnaryOp>());
return;
case ExprType::BinaryOp:
ptr(handler)->handle(static_cast<const BinaryOp*>(expr));
ptr(handler)->handle(expr->as<BinaryOp>());
return;
case ExprType::TernaryOp:
ptr(handler)->handle(static_cast<const TernaryOp*>(expr));
ptr(handler)->handle(expr->as<TernaryOp>());
return;
case ExprType::ReductionOp:
ptr(handler)->handle(static_cast<const ReductionOp*>(expr));
ptr(handler)->handle(expr->as<ReductionOp>());
return;
case ExprType::BroadcastOp:
ptr(handler)->handle(static_cast<const BroadcastOp*>(expr));
ptr(handler)->handle(expr->as<BroadcastOp>());
return;
case ExprType::ForLoop:
ptr(handler)->handle(static_cast<const ForLoop*>(expr));
ptr(handler)->handle(expr->as<ForLoop>());
return;
case ExprType::IfThenElse:
ptr(handler)->handle(static_cast<const IfThenElse*>(expr));
ptr(handler)->handle(expr->as<IfThenElse>());
return;
case ExprType::Allocate:
ptr(handler)->handle(static_cast<const Allocate*>(expr));
ptr(handler)->handle(expr->as<Allocate>());
return;
default:
TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
Expand All @@ -212,9 +212,9 @@ void Expr::constDispatch(T handler, const Expr* expr) {
template <typename T>
void Statement::constDispatch(T handler, const Statement* stmt) {
if (stmt->isVal()) {
ptr(handler)->handle(static_cast<const Val*>(stmt));
ptr(handler)->handle(stmt->as<Val>());
} else if (stmt->isExpr()) {
ptr(handler)->handle(static_cast<const Expr*>(stmt));
ptr(handler)->handle(stmt->as<Expr>());
} else
TORCH_INTERNAL_ASSERT(false, "Unknown stmttype in dispatch!");
}
Expand All @@ -228,35 +228,35 @@ void Statement::constDispatch(T handler, const Statement* stmt) {
* implement Statement* mutate(Statement* stmt){ stmt->mutatorDispatch(this);
* }
* And therefore dispatch should never call:
* ptr(mutator)->mutate(static_cast<Statement*>(this));
* ptr(mutator)->mutate(this->as<Statement>());
*/
template <typename T>
Statement* Val::mutatorDispatch(T mutator, Val* val) {
switch (*(val->getValType())) {
case ValType::Scalar:
switch (*(val->getDataType())) {
case DataType::Bool:
return ptr(mutator)->mutate(static_cast<Bool*>(val));
return ptr(mutator)->mutate(val->as<Bool>());
case DataType::Float:
return ptr(mutator)->mutate(static_cast<Float*>(val));
return ptr(mutator)->mutate(val->as<Float>());
case DataType::Half:
return ptr(mutator)->mutate(static_cast<Half*>(val));
return ptr(mutator)->mutate(val->as<Half>());
case DataType::Int:
return ptr(mutator)->mutate(static_cast<Int*>(val));
return ptr(mutator)->mutate(val->as<Int>());
default:
break;
}
break;
case ValType::IterDomain:
return ptr(mutator)->mutate(static_cast<IterDomain*>(val));
return ptr(mutator)->mutate(val->as<IterDomain>());
case ValType::TensorDomain:
return ptr(mutator)->mutate(static_cast<TensorDomain*>(val));
return ptr(mutator)->mutate(val->as<TensorDomain>());
case ValType::TensorView:
return ptr(mutator)->mutate(static_cast<TensorView*>(val));
return ptr(mutator)->mutate(val->as<TensorView>());
case ValType::TensorIndex:
return ptr(mutator)->mutate(static_cast<TensorIndex*>(val));
return ptr(mutator)->mutate(val->as<TensorIndex>());
case ValType::NamedScalar:
return ptr(mutator)->mutate(static_cast<NamedScalar*>(val));
return ptr(mutator)->mutate(val->as<NamedScalar>());
default:
break;
}
Expand All @@ -267,25 +267,25 @@ template <typename T>
Statement* Expr::mutatorDispatch(T mutator, Expr* expr) {
switch (*(expr->getExprType())) {
case ExprType::Split:
return ptr(mutator)->mutate(static_cast<Split*>(expr));
return ptr(mutator)->mutate(expr->as<Split>());
case ExprType::Merge:
return ptr(mutator)->mutate(static_cast<Merge*>(expr));
return ptr(mutator)->mutate(expr->as<Merge>());
case ExprType::UnaryOp:
return ptr(mutator)->mutate(static_cast<UnaryOp*>(expr));
return ptr(mutator)->mutate(expr->as<UnaryOp>());
case ExprType::BinaryOp:
return ptr(mutator)->mutate(static_cast<BinaryOp*>(expr));
return ptr(mutator)->mutate(expr->as<BinaryOp>());
case ExprType::TernaryOp:
return ptr(mutator)->mutate(static_cast<TernaryOp*>(expr));
return ptr(mutator)->mutate(expr->as<TernaryOp>());
case ExprType::ReductionOp:
return ptr(mutator)->mutate(static_cast<ReductionOp*>(expr));
return ptr(mutator)->mutate(expr->as<ReductionOp>());
case ExprType::BroadcastOp:
return ptr(mutator)->mutate(static_cast<BroadcastOp*>(expr));
return ptr(mutator)->mutate(expr->as<BroadcastOp>());
case ExprType::ForLoop:
return ptr(mutator)->mutate(static_cast<ForLoop*>(expr));
return ptr(mutator)->mutate(expr->as<ForLoop>());
case ExprType::IfThenElse:
return ptr(mutator)->mutate(static_cast<IfThenElse*>(expr));
return ptr(mutator)->mutate(expr->as<IfThenElse>());
case ExprType::Allocate:
return ptr(mutator)->mutate(static_cast<Allocate*>(expr));
return ptr(mutator)->mutate(expr->as<Allocate>());
default:
TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
}
Expand All @@ -294,10 +294,10 @@ Statement* Expr::mutatorDispatch(T mutator, Expr* expr) {
template <typename T>
Statement* Statement::mutatorDispatch(T mutator, Statement* stmt) {
if (stmt->isVal()) {
return ptr(mutator)->mutate(static_cast<Val*>(stmt));
return ptr(mutator)->mutate(stmt->as<Val>());
}
if (stmt->isExpr()) {
return ptr(mutator)->mutate(static_cast<Expr*>(stmt));
return ptr(mutator)->mutate(stmt->as<Expr>());
}
TORCH_INTERNAL_ASSERT(false, "Unknown stmttype in dispatch!");
}
Expand Down Expand Up @@ -352,39 +352,47 @@ template Statement* Expr::mutatorDispatch(OptInMutator*, Expr*);
void OptOutDispatch::handle(Statement* s) {
Statement::dispatch(this, s);
}

void OptOutDispatch::handle(Expr* e) {
Expr::dispatch(this, e);
}

void OptOutDispatch::handle(Val* v) {
Val::dispatch(this, v);
}

void OptInDispatch::handle(Statement* s) {
Statement::dispatch(this, s);
}

void OptInDispatch::handle(Expr* e) {
Expr::dispatch(this, e);
}

void OptInDispatch::handle(Val* v) {
Val::dispatch(this, v);
}

void OptOutConstDispatch::handle(const Statement* s) {
Statement::constDispatch(this, s);
}

void OptOutConstDispatch::handle(const Expr* e) {
Expr::constDispatch(this, e);
}

void OptOutConstDispatch::handle(const Val* v) {
Val::constDispatch(this, v);
}

void OptInConstDispatch::handle(const Statement* s) {
Statement::constDispatch(this, s);
}

void OptInConstDispatch::handle(const Expr* e) {
Expr::constDispatch(this, e);
}

void OptInConstDispatch::handle(const Val* v) {
Val::constDispatch(this, v);
}
Expand All @@ -407,9 +415,11 @@ Statement* OptInMutator::mutate(Val* v) {
Statement* OptOutMutator::mutate(Statement* s) {
return Statement::mutatorDispatch(this, s);
}

Statement* OptOutMutator::mutate(Expr* e) {
return Expr::mutatorDispatch(this, e);
}

Statement* OptOutMutator::mutate(Val* v) {
// If value is already mutated, return the mutation
if (mutations.find(v) != mutations.end())
Expand Down
15 changes: 7 additions & 8 deletions torch/csrc/jit/codegen/cuda/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -331,13 +331,11 @@ struct TORCH_CUDA_API OptOutMutator {
virtual Statement* mutate(Expr* e);
virtual Statement* mutate(Val* v);

/*
 * We always want to dispatch through a Val, so we can capture and dispatch
 * correctly members of nodes like Split->TensorDomain. If we don't call the
 * below function or manually cast to use mutate(Val* v), we can't intercept
 * and mutate by capturing mutate(Val* v), which is what we do when we want to
 * replace all instances of a value.
 */
// We always want to dispatch through a Val, so we can capture and dispatch
// correctly members of nodes like Split->TensorDomain. If we don't call the
// below function or manually cast to use mutate(Val* v), we can't intercept
// and mutate by capturing mutate(Val* v), which is what we do when we want to
// replace all instances of a value.
Statement* mutateAsVal(Val* v) {
return mutate(v);
}
Expand All @@ -352,7 +350,8 @@ struct TORCH_CUDA_API OptOutMutator {

std::unordered_map<Val*, Val*> mutations;

//****Functions below defined in mutator.cpp*****///
//****Functions below defined in mutator.cpp*****

// Vals
virtual Statement* mutate(IterDomain*);
virtual Statement* mutate(TensorDomain*);
Expand Down