Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 22 additions & 19 deletions torch/csrc/jit/codegen/cuda/dispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ void Val::dispatch(T handler, Val* val) {
default:
break;
}
break;
case ValType::IterDomain:
ptr(handler)->handle(static_cast<IterDomain*>(val));
return;
Expand Down Expand Up @@ -122,28 +123,29 @@ void Val::constDispatch(T handler, const Val* const val) {
case ValType::Scalar:
switch (*(val->getDataType())) {
case DataType::Float:
ptr(handler)->handle(static_cast<const Float* const>(val));
ptr(handler)->handle(static_cast<const Float*>(val));
return;
case DataType::Int:
ptr(handler)->handle(static_cast<const Int* const>(val));
ptr(handler)->handle(static_cast<const Int*>(val));
return;
default:
break;
}
break;
case ValType::IterDomain:
ptr(handler)->handle(static_cast<const IterDomain* const>(val));
ptr(handler)->handle(static_cast<const IterDomain*>(val));
return;
case ValType::TensorDomain:
ptr(handler)->handle(static_cast<const TensorDomain* const>(val));
ptr(handler)->handle(static_cast<const TensorDomain*>(val));
return;
case ValType::TensorView:
ptr(handler)->handle(static_cast<const TensorView* const>(val));
ptr(handler)->handle(static_cast<const TensorView*>(val));
return;
case ValType::TensorIndex:
ptr(handler)->handle(static_cast<const TensorIndex* const>(val));
ptr(handler)->handle(static_cast<const TensorIndex*>(val));
return;
case ValType::NamedScalar:
ptr(handler)->handle(static_cast<const NamedScalar* const>(val));
ptr(handler)->handle(static_cast<const NamedScalar*>(val));
return;
default:
break;
Expand All @@ -152,43 +154,43 @@ void Val::constDispatch(T handler, const Val* const val) {
}

template <typename T>
void Expr::constDispatch(T handler, const Expr* const expr) {
void Expr::constDispatch(T handler, const Expr* expr) {
switch (*(expr->getExprType())) {
case ExprType::Split:
ptr(handler)->handle(static_cast<const Split* const>(expr));
ptr(handler)->handle(static_cast<const Split*>(expr));
return;
case ExprType::Merge:
ptr(handler)->handle(static_cast<const Merge* const>(expr));
ptr(handler)->handle(static_cast<const Merge*>(expr));
return;
case ExprType::Reorder:
ptr(handler)->handle(static_cast<const Reorder* const>(expr));
ptr(handler)->handle(static_cast<const Reorder*>(expr));
return;
case ExprType::UnaryOp:
ptr(handler)->handle(static_cast<const UnaryOp* const>(expr));
ptr(handler)->handle(static_cast<const UnaryOp*>(expr));
return;
case ExprType::BinaryOp:
ptr(handler)->handle(static_cast<const BinaryOp* const>(expr));
ptr(handler)->handle(static_cast<const BinaryOp*>(expr));
return;
case ExprType::ForLoop:
ptr(handler)->handle(static_cast<const ForLoop* const>(expr));
ptr(handler)->handle(static_cast<const ForLoop*>(expr));
return;
case ExprType::IfThenElse:
ptr(handler)->handle(static_cast<const IfThenElse* const>(expr));
ptr(handler)->handle(static_cast<const IfThenElse*>(expr));
return;
case ExprType::Allocate:
ptr(handler)->handle(static_cast<const Allocate* const>(expr));
ptr(handler)->handle(static_cast<const Allocate*>(expr));
return;
default:
TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
}
}

template <typename T>
void Statement::constDispatch(T handler, const Statement* const stmt) {
void Statement::constDispatch(T handler, const Statement* stmt) {
if (stmt->isVal()) {
ptr(handler)->handle(static_cast<const Val* const>(stmt));
ptr(handler)->handle(static_cast<const Val*>(stmt));
} else if (stmt->isExpr()) {
ptr(handler)->handle(static_cast<const Expr* const>(stmt));
ptr(handler)->handle(static_cast<const Expr*>(stmt));
} else
TORCH_INTERNAL_ASSERT(false, "Unknown stmttype in dispatch!");
}
Expand Down Expand Up @@ -216,6 +218,7 @@ Statement* Val::mutatorDispatch(T mutator, Val* val) {
default:
break;
}
break;
case ValType::IterDomain:
return ptr(mutator)->mutate(static_cast<IterDomain*>(val));
case ValType::TensorDomain:
Expand Down
4 changes: 2 additions & 2 deletions torch/csrc/jit/codegen/cuda/ir_base_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,9 @@ struct TORCH_CUDA_API Val : public Statement {
// was found
Expr* getOrigin();

virtual bool sameType(const Statement* const other) {
virtual bool sameType(const Statement* other) {
return Statement::sameType(other) &&
getDataType() == static_cast<const Val* const>(other)->getDataType();
getDataType() == static_cast<const Val*>(other)->getDataType();
}

// TODO: Make this more sophisticated. A value being the same as another value
Expand Down