From 412505254189677e6f6c396029dcf3a4c4f66a9c Mon Sep 17 00:00:00 2001
From: Lemo
Date: Fri, 19 Jun 2020 15:03:12 -0700
Subject: [PATCH] Misc cleanup

---
 torch/csrc/jit/codegen/cuda/dispatch.cpp | 186 ++++++++++++-----------
 torch/csrc/jit/codegen/cuda/dispatch.h   | 103 +++++++------
 2 files changed, 145 insertions(+), 144 deletions(-)

diff --git a/torch/csrc/jit/codegen/cuda/dispatch.cpp b/torch/csrc/jit/codegen/cuda/dispatch.cpp
index a3a55220e534..197c5e68e7e3 100644
--- a/torch/csrc/jit/codegen/cuda/dispatch.cpp
+++ b/torch/csrc/jit/codegen/cuda/dispatch.cpp
@@ -36,7 +36,7 @@ T* ptr(T* obj) {
  * }
  *
  * And therefore dispatch should never call:
- * ptr(mutator)->handle(static_cast<Statement*>(this));
+ * ptr(mutator)->handle(this->as<Statement>());
  */
 
 template <typename T>
@@ -45,35 +45,35 @@ void Val::dispatch(T handler, Val* val) {
     case ValType::Scalar:
       switch (*(val->getDataType())) {
         case DataType::Bool:
-          ptr(handler)->handle(static_cast<Bool*>(val));
+          ptr(handler)->handle(val->as<Bool>());
           return;
         case DataType::Float:
-          ptr(handler)->handle(static_cast<Float*>(val));
+          ptr(handler)->handle(val->as<Float>());
           return;
         case DataType::Half:
-          ptr(handler)->handle(static_cast<Half*>(val));
+          ptr(handler)->handle(val->as<Half>());
           return;
         case DataType::Int:
-          ptr(handler)->handle(static_cast<Int*>(val));
+          ptr(handler)->handle(val->as<Int>());
           return;
         default:
           break;
       }
       break;
     case ValType::IterDomain:
-      ptr(handler)->handle(static_cast<IterDomain*>(val));
+      ptr(handler)->handle(val->as<IterDomain>());
       return;
     case ValType::TensorDomain:
-      ptr(handler)->handle(static_cast<TensorDomain*>(val));
+      ptr(handler)->handle(val->as<TensorDomain>());
       return;
     case ValType::TensorView:
-      ptr(handler)->handle(static_cast<TensorView*>(val));
+      ptr(handler)->handle(val->as<TensorView>());
       return;
     case ValType::TensorIndex:
-      ptr(handler)->handle(static_cast<TensorIndex*>(val));
+      ptr(handler)->handle(val->as<TensorIndex>());
       return;
     case ValType::NamedScalar:
-      ptr(handler)->handle(static_cast<NamedScalar*>(val));
+      ptr(handler)->handle(val->as<NamedScalar>());
       return;
     default:
       break;
@@ -85,34 +85,34 @@ template <typename T>
 void Expr::dispatch(T handler, Expr* expr) {
   switch (*(expr->getExprType())) {
     case ExprType::Split:
-      ptr(handler)->handle(static_cast<Split*>(expr));
+      ptr(handler)->handle(expr->as<Split>());
       return;
     case ExprType::Merge:
-      ptr(handler)->handle(static_cast<Merge*>(expr));
+      ptr(handler)->handle(expr->as<Merge>());
       return;
     case ExprType::UnaryOp:
-      ptr(handler)->handle(static_cast<UnaryOp*>(expr));
+      ptr(handler)->handle(expr->as<UnaryOp>());
       return;
     case ExprType::BinaryOp:
-      ptr(handler)->handle(static_cast<BinaryOp*>(expr));
+      ptr(handler)->handle(expr->as<BinaryOp>());
       return;
     case ExprType::TernaryOp:
-      ptr(handler)->handle(static_cast<TernaryOp*>(expr));
+      ptr(handler)->handle(expr->as<TernaryOp>());
       return;
     case ExprType::ReductionOp:
-      ptr(handler)->handle(static_cast<ReductionOp*>(expr));
+      ptr(handler)->handle(expr->as<ReductionOp>());
       return;
     case ExprType::BroadcastOp:
-      ptr(handler)->handle(static_cast<BroadcastOp*>(expr));
+      ptr(handler)->handle(expr->as<BroadcastOp>());
       return;
     case ExprType::ForLoop:
-      ptr(handler)->handle(static_cast<ForLoop*>(expr));
+      ptr(handler)->handle(expr->as<ForLoop>());
       return;
     case ExprType::IfThenElse:
-      ptr(handler)->handle(static_cast<IfThenElse*>(expr));
+      ptr(handler)->handle(expr->as<IfThenElse>());
       return;
     case ExprType::Allocate:
-      ptr(handler)->handle(static_cast<Allocate*>(expr));
+      ptr(handler)->handle(expr->as<Allocate>());
       return;
     default:
       TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
@@ -122,9 +122,9 @@ void Expr::dispatch(T handler, Expr* expr) {
 template <typename T>
 void Statement::dispatch(T handler, Statement* stmt) {
   if (stmt->isVal()) {
-    ptr(handler)->handle(static_cast<Val*>(stmt));
+    ptr(handler)->handle(stmt->as<Val>());
   } else if (stmt->isExpr()) {
-    ptr(handler)->handle(static_cast<Expr*>(stmt));
+    ptr(handler)->handle(stmt->as<Expr>());
   } else
     TORCH_INTERNAL_ASSERT(false, "Unknown stmttype in dispatch!");
 }
@@ -135,35 +135,35 @@ void Val::constDispatch(T handler, const Val* val) {
     case ValType::Scalar:
       switch (*(val->getDataType())) {
         case DataType::Bool:
-          ptr(handler)->handle(static_cast<const Bool*>(val));
+          ptr(handler)->handle(val->as<Bool>());
           return;
         case DataType::Float:
-          ptr(handler)->handle(static_cast<const Float*>(val));
+          ptr(handler)->handle(val->as<Float>());
           return;
         case DataType::Half:
-          ptr(handler)->handle(static_cast<const Half*>(val));
+          ptr(handler)->handle(val->as<Half>());
           return;
         case DataType::Int:
-          ptr(handler)->handle(static_cast<const Int*>(val));
+          ptr(handler)->handle(val->as<Int>());
           return;
         default:
          break;
       }
       break;
     case ValType::IterDomain:
-      ptr(handler)->handle(static_cast<const IterDomain*>(val));
+      ptr(handler)->handle(val->as<IterDomain>());
       return;
     case ValType::TensorDomain:
-      ptr(handler)->handle(static_cast<const TensorDomain*>(val));
+      ptr(handler)->handle(val->as<TensorDomain>());
       return;
     case ValType::TensorView:
-      ptr(handler)->handle(static_cast<const TensorView*>(val));
+      ptr(handler)->handle(val->as<TensorView>());
       return;
     case ValType::TensorIndex:
-      ptr(handler)->handle(static_cast<const TensorIndex*>(val));
+      ptr(handler)->handle(val->as<TensorIndex>());
       return;
     case ValType::NamedScalar:
-      ptr(handler)->handle(static_cast<const NamedScalar*>(val));
+      ptr(handler)->handle(val->as<NamedScalar>());
       return;
     default:
       break;
@@ -175,34 +175,34 @@ template <typename T>
 void Expr::constDispatch(T handler, const Expr* expr) {
   switch (*(expr->getExprType())) {
     case ExprType::Split:
-      ptr(handler)->handle(static_cast<const Split*>(expr));
+      ptr(handler)->handle(expr->as<Split>());
       return;
     case ExprType::Merge:
-      ptr(handler)->handle(static_cast<const Merge*>(expr));
+      ptr(handler)->handle(expr->as<Merge>());
       return;
     case ExprType::UnaryOp:
-      ptr(handler)->handle(static_cast<const UnaryOp*>(expr));
+      ptr(handler)->handle(expr->as<UnaryOp>());
       return;
     case ExprType::BinaryOp:
-      ptr(handler)->handle(static_cast<const BinaryOp*>(expr));
+      ptr(handler)->handle(expr->as<BinaryOp>());
       return;
     case ExprType::TernaryOp:
-      ptr(handler)->handle(static_cast<const TernaryOp*>(expr));
+      ptr(handler)->handle(expr->as<TernaryOp>());
       return;
     case ExprType::ReductionOp:
-      ptr(handler)->handle(static_cast<const ReductionOp*>(expr));
+      ptr(handler)->handle(expr->as<ReductionOp>());
       return;
     case ExprType::BroadcastOp:
-      ptr(handler)->handle(static_cast<const BroadcastOp*>(expr));
+      ptr(handler)->handle(expr->as<BroadcastOp>());
       return;
     case ExprType::ForLoop:
-      ptr(handler)->handle(static_cast<const ForLoop*>(expr));
+      ptr(handler)->handle(expr->as<ForLoop>());
       return;
     case ExprType::IfThenElse:
-      ptr(handler)->handle(static_cast<const IfThenElse*>(expr));
+      ptr(handler)->handle(expr->as<IfThenElse>());
       return;
     case ExprType::Allocate:
-      ptr(handler)->handle(static_cast<const Allocate*>(expr));
+      ptr(handler)->handle(expr->as<Allocate>());
       return;
     default:
       TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
@@ -212,9 +212,9 @@ void Expr::constDispatch(T handler, const Expr* expr) {
 template <typename T>
 void Statement::constDispatch(T handler, const Statement* stmt) {
   if (stmt->isVal()) {
-    ptr(handler)->handle(static_cast<const Val*>(stmt));
+    ptr(handler)->handle(stmt->as<Val>());
   } else if (stmt->isExpr()) {
-    ptr(handler)->handle(static_cast<const Expr*>(stmt));
+    ptr(handler)->handle(stmt->as<Expr>());
   } else
     TORCH_INTERNAL_ASSERT(false, "Unknown stmttype in dispatch!");
 }
@@ -228,7 +228,7 @@ void Statement::constDispatch(T handler, const Statement* stmt) {
  * implement Statement* mutate(Statement* stmt){ stmt->mutatorDispatch(this);
  * }
  * And therefore dispatch should never call:
- * ptr(mutator)->mutate(static_cast<Statement*>(this));
+ * ptr(mutator)->mutate(this->as<Statement>());
  */
 template <typename T>
 Statement* Val::mutatorDispatch(T mutator, Val* val) {
@@ -236,27 +236,27 @@ Statement* Val::mutatorDispatch(T mutator, Val* val) {
     case ValType::Scalar:
       switch (*(val->getDataType())) {
         case DataType::Bool:
-          return ptr(mutator)->mutate(static_cast<Bool*>(val));
+          return ptr(mutator)->mutate(val->as<Bool>());
         case DataType::Float:
-          return ptr(mutator)->mutate(static_cast<Float*>(val));
+          return ptr(mutator)->mutate(val->as<Float>());
         case DataType::Half:
-          return ptr(mutator)->mutate(static_cast<Half*>(val));
+          return ptr(mutator)->mutate(val->as<Half>());
        case DataType::Int:
-          return ptr(mutator)->mutate(static_cast<Int*>(val));
+          return ptr(mutator)->mutate(val->as<Int>());
         default:
           break;
       }
       break;
     case ValType::IterDomain:
-      return ptr(mutator)->mutate(static_cast<IterDomain*>(val));
+      return ptr(mutator)->mutate(val->as<IterDomain>());
     case ValType::TensorDomain:
-      return ptr(mutator)->mutate(static_cast<TensorDomain*>(val));
+      return ptr(mutator)->mutate(val->as<TensorDomain>());
     case ValType::TensorView:
-      return ptr(mutator)->mutate(static_cast<TensorView*>(val));
+      return ptr(mutator)->mutate(val->as<TensorView>());
     case ValType::TensorIndex:
-      return ptr(mutator)->mutate(static_cast<TensorIndex*>(val));
+      return ptr(mutator)->mutate(val->as<TensorIndex>());
     case ValType::NamedScalar:
-      return ptr(mutator)->mutate(static_cast<NamedScalar*>(val));
+      return ptr(mutator)->mutate(val->as<NamedScalar>());
     default:
       break;
   }
@@ -267,25 +267,25 @@ template <typename T>
 Statement* Expr::mutatorDispatch(T mutator, Expr* expr) {
   switch (*(expr->getExprType())) {
     case ExprType::Split:
-      return ptr(mutator)->mutate(static_cast<Split*>(expr));
+      return ptr(mutator)->mutate(expr->as<Split>());
     case ExprType::Merge:
-      return ptr(mutator)->mutate(static_cast<Merge*>(expr));
+      return ptr(mutator)->mutate(expr->as<Merge>());
     case ExprType::UnaryOp:
-      return ptr(mutator)->mutate(static_cast<UnaryOp*>(expr));
+      return ptr(mutator)->mutate(expr->as<UnaryOp>());
     case ExprType::BinaryOp:
-      return ptr(mutator)->mutate(static_cast<BinaryOp*>(expr));
+      return ptr(mutator)->mutate(expr->as<BinaryOp>());
     case ExprType::TernaryOp:
-      return ptr(mutator)->mutate(static_cast<TernaryOp*>(expr));
+      return ptr(mutator)->mutate(expr->as<TernaryOp>());
     case ExprType::ReductionOp:
-      return ptr(mutator)->mutate(static_cast<ReductionOp*>(expr));
+      return ptr(mutator)->mutate(expr->as<ReductionOp>());
     case ExprType::BroadcastOp:
-      return ptr(mutator)->mutate(static_cast<BroadcastOp*>(expr));
+      return ptr(mutator)->mutate(expr->as<BroadcastOp>());
     case ExprType::ForLoop:
-      return ptr(mutator)->mutate(static_cast<ForLoop*>(expr));
+      return ptr(mutator)->mutate(expr->as<ForLoop>());
     case ExprType::IfThenElse:
-      return ptr(mutator)->mutate(static_cast<IfThenElse*>(expr));
+      return ptr(mutator)->mutate(expr->as<IfThenElse>());
     case ExprType::Allocate:
-      return ptr(mutator)->mutate(static_cast<Allocate*>(expr));
+      return ptr(mutator)->mutate(expr->as<Allocate>());
     default:
       TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
   }
@@ -294,10 +294,10 @@ Statement* Expr::mutatorDispatch(T mutator, Expr* expr) {
 template <typename T>
 Statement* Statement::mutatorDispatch(T mutator, Statement* stmt) {
   if (stmt->isVal()) {
-    return ptr(mutator)->mutate(static_cast<Val*>(stmt));
+    return ptr(mutator)->mutate(stmt->as<Val>());
   }
   if (stmt->isExpr()) {
-    return ptr(mutator)->mutate(static_cast<Expr*>(stmt));
+    return ptr(mutator)->mutate(stmt->as<Expr>());
   }
   TORCH_INTERNAL_ASSERT(false, "Unknown stmttype in dispatch!");
 }
@@ -321,27 +321,19 @@ template void Val::dispatch(OptInDispatch*, Val*);
 template void Expr::dispatch(OptInDispatch, Expr*);
 template void Expr::dispatch(OptInDispatch*, Expr*);
 
-template void Statement::constDispatch(
-    OptOutConstDispatch,
-    const Statement* const);
-template void Statement::constDispatch(
-    OptOutConstDispatch*,
-    const Statement* const);
-template void Val::constDispatch(OptOutConstDispatch, const Val* const);
-template void Val::constDispatch(OptOutConstDispatch*, const Val* const);
-template void Expr::constDispatch(OptOutConstDispatch, const Expr* const);
-template void Expr::constDispatch(OptOutConstDispatch*, const Expr* const);
-
-template void Statement::constDispatch(
-    OptInConstDispatch,
-    const Statement* const);
-template void Statement::constDispatch(
-    OptInConstDispatch*,
-    const Statement* const);
-template void Val::constDispatch(OptInConstDispatch, const Val* const);
-template void Val::constDispatch(OptInConstDispatch*, const Val* const);
-template void Expr::constDispatch(OptInConstDispatch, const Expr* const);
-template void Expr::constDispatch(OptInConstDispatch*, const Expr* const);
+template void Statement::constDispatch(OptOutConstDispatch, const Statement*);
+template void Statement::constDispatch(OptOutConstDispatch*, const Statement*);
+template void Val::constDispatch(OptOutConstDispatch, const Val*);
+template void Val::constDispatch(OptOutConstDispatch*, const Val*);
+template void Expr::constDispatch(OptOutConstDispatch, const Expr*);
+template void Expr::constDispatch(OptOutConstDispatch*, const Expr*);
+
+template void Statement::constDispatch(OptInConstDispatch, const Statement*);
+template void Statement::constDispatch(OptInConstDispatch*, const Statement*);
+template void Val::constDispatch(OptInConstDispatch, const Val*);
+template void Val::constDispatch(OptInConstDispatch*, const Val*);
+template void Expr::constDispatch(OptInConstDispatch, const Expr*);
+template void Expr::constDispatch(OptInConstDispatch*, const Expr*);
 
 template Statement* Statement::mutatorDispatch(OptOutMutator, Statement*);
 template Statement* Statement::mutatorDispatch(OptOutMutator*, Statement*);
@@ -360,9 +352,11 @@ template Statement* Expr::mutatorDispatch(OptInMutator*, Expr*);
 void OptOutDispatch::handle(Statement* s) {
   Statement::dispatch(this, s);
 }
+
 void OptOutDispatch::handle(Expr* e) {
   Expr::dispatch(this, e);
 }
+
 void OptOutDispatch::handle(Val* v) {
   Val::dispatch(this, v);
 }
@@ -370,30 +364,36 @@ void OptOutDispatch::handle(Val* v) {
 void OptInDispatch::handle(Statement* s) {
   Statement::dispatch(this, s);
 }
+
 void OptInDispatch::handle(Expr* e) {
   Expr::dispatch(this, e);
 }
+
 void OptInDispatch::handle(Val* v) {
   Val::dispatch(this, v);
 }
 
-void OptOutConstDispatch::handle(const Statement* const s) {
+void OptOutConstDispatch::handle(const Statement* s) {
   Statement::constDispatch(this, s);
 }
-void OptOutConstDispatch::handle(const Expr* const e) {
+
+void OptOutConstDispatch::handle(const Expr* e) {
   Expr::constDispatch(this, e);
 }
-void OptOutConstDispatch::handle(const Val* const v) {
+
+void OptOutConstDispatch::handle(const Val* v) {
   Val::constDispatch(this, v);
 }
 
-void OptInConstDispatch::handle(const Statement* const s) {
+void OptInConstDispatch::handle(const Statement* s) {
   Statement::constDispatch(this, s);
 }
-void OptInConstDispatch::handle(const Expr* const e) {
+
+void OptInConstDispatch::handle(const Expr* e) {
   Expr::constDispatch(this, e);
 }
-void OptInConstDispatch::handle(const Val* const v) {
+
+void OptInConstDispatch::handle(const Val* v) {
   Val::constDispatch(this, v);
 }
 
@@ -415,9 +415,11 @@ Statement* OptInMutator::mutate(Val* v) {
 Statement* OptOutMutator::mutate(Statement* s) {
   return Statement::mutatorDispatch(this, s);
 }
+
 Statement* OptOutMutator::mutate(Expr* e) {
   return Expr::mutatorDispatch(this, e);
 }
+
 Statement* OptOutMutator::mutate(Val* v) {
   // If value is already mutated, return the mutation
   if (mutations.find(v) != mutations.end())
diff --git a/torch/csrc/jit/codegen/cuda/dispatch.h b/torch/csrc/jit/codegen/cuda/dispatch.h
index 5a3aa89d8275..b3575cf44d33 100644
--- a/torch/csrc/jit/codegen/cuda/dispatch.h
+++ b/torch/csrc/jit/codegen/cuda/dispatch.h
@@ -93,32 +93,32 @@ struct TORCH_CUDA_API OptOutConstDispatch {
   OptOutConstDispatch& operator=(OptOutConstDispatch&& other) = default;
 
   // Hierarchal dispatch functions for handle
-  virtual void handle(const Statement* const);
-  virtual void handle(const Expr* const);
-  virtual void handle(const Val* const);
+  virtual void handle(const Statement*);
+  virtual void handle(const Expr*);
+  virtual void handle(const Val*);
 
   // Vals
-  virtual void handle(const IterDomain* const) {}
-  virtual void handle(const TensorDomain* const) {}
-  virtual void handle(const TensorView* const) {}
-  virtual void handle(const TensorIndex* const) {}
-  virtual void handle(const Bool* const) {}
-  virtual void handle(const Float* const) {}
-  virtual void handle(const Half* const) {}
-  virtual void handle(const Int* const) {}
-  virtual void handle(const NamedScalar* const) {}
+  virtual void handle(const IterDomain*) {}
+  virtual void handle(const TensorDomain*) {}
+  virtual void handle(const TensorView*) {}
+  virtual void handle(const TensorIndex*) {}
+  virtual void handle(const Bool*) {}
+  virtual void handle(const Float*) {}
+  virtual void handle(const Half*) {}
+  virtual void handle(const Int*) {}
+  virtual void handle(const NamedScalar*) {}
 
   // Exprs
-  virtual void handle(const Split* const) {}
-  virtual void handle(const Merge* const) {}
-  virtual void handle(const UnaryOp* const) {}
-  virtual void handle(const BinaryOp* const) {}
-  virtual void handle(const TernaryOp* const) {}
-  virtual void handle(const ReductionOp* const) {}
-  virtual void handle(const BroadcastOp* const) {}
-  virtual void handle(const ForLoop* const) {}
-  virtual void handle(const IfThenElse* const) {}
-  virtual void handle(const Allocate* const) {}
+  virtual void handle(const Split*) {}
+  virtual void handle(const Merge*) {}
+  virtual void handle(const UnaryOp*) {}
+  virtual void handle(const BinaryOp*) {}
+  virtual void handle(const TernaryOp*) {}
+  virtual void handle(const ReductionOp*) {}
+  virtual void handle(const BroadcastOp*) {}
+  virtual void handle(const ForLoop*) {}
+  virtual void handle(const IfThenElse*) {}
+  virtual void handle(const Allocate*) {}
 };
 
 struct TORCH_CUDA_API OptOutDispatch {
@@ -171,68 +171,68 @@ struct TORCH_CUDA_API OptInConstDispatch {
   OptInConstDispatch& operator=(OptInConstDispatch&& other) = default;
 
   // Hierarchal dispatch functions for handle
-  virtual void handle(const Statement* const);
-  virtual void handle(const Expr* const);
-  virtual void handle(const Val* const);
+  virtual void handle(const Statement*);
+  virtual void handle(const Expr*);
+  virtual void handle(const Val*);
 
   // Vals
-  virtual void handle(const IterDomain* const) {
+  virtual void handle(const IterDomain*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for IterDomain.");
   }
-  virtual void handle(const TensorDomain* const) {
+  virtual void handle(const TensorDomain*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for TensorDomain.");
   }
-  virtual void handle(const TensorView* const) {
+  virtual void handle(const TensorView*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for TensorView.");
   }
-  virtual void handle(const TensorIndex* const) {
+  virtual void handle(const TensorIndex*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for TensorIndex.");
   }
-  virtual void handle(const Bool* const) {
+  virtual void handle(const Bool*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Bool.");
   }
-  virtual void handle(const Float* const) {
+  virtual void handle(const Float*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Float.");
   }
-  virtual void handle(const Half* const) {
+  virtual void handle(const Half*) {
    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Half.");
   }
-  virtual void handle(const Int* const) {
+  virtual void handle(const Int*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Int.");
   }
-  virtual void handle(const NamedScalar* const) {
+  virtual void handle(const NamedScalar*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for NamedScalar.");
   }
 
   // Exprs
-  virtual void handle(const Split* const) {
+  virtual void handle(const Split*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Split.");
   }
-  virtual void handle(const Merge* const) {
+  virtual void handle(const Merge*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Merge.");
   }
-  virtual void handle(const UnaryOp* const) {
+  virtual void handle(const UnaryOp*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for UnaryOp.");
   }
-  virtual void handle(const BinaryOp* const) {
+  virtual void handle(const BinaryOp*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for BinaryOp.");
   }
-  virtual void handle(const TernaryOp* const) {
+  virtual void handle(const TernaryOp*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for TernaryOp.");
   }
-  virtual void handle(const ReductionOp* const) {
+  virtual void handle(const ReductionOp*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for ReductionOp.");
   }
-  virtual void handle(const BroadcastOp* const) {
+  virtual void handle(const BroadcastOp*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for BroadcastOp.");
   }
-  virtual void handle(const ForLoop* const) {
+  virtual void handle(const ForLoop*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for ForLoop.");
   }
-  virtual void handle(const Allocate* const) {
+  virtual void handle(const Allocate*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Allocate.");
   }
-  virtual void handle(const IfThenElse* const) {
+  virtual void handle(const IfThenElse*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for IfThenElse.");
   }
 };
@@ -331,13 +331,11 @@ struct TORCH_CUDA_API OptOutMutator {
   virtual Statement* mutate(Expr* e);
   virtual Statement* mutate(Val* v);
 
-  /*
-   * We always want to dispatch through a Val, so we can capture and dispatch
-   * correctly members of nodes like Split->TensorDomain If we don't call the
-   * below function or manually cast to use mutate(Val* v) we can't intercept
-   * and mutate by capturing mutate(Val* v), which is what we do when we want to
-   * replace all instances of a value.
-   */
+  // We always want to dispatch through a Val, so we can capture and dispatch
+  // correctly members of nodes like Split->TensorDomain If we don't call the
+  // below function or manually cast to use mutate(Val* v) we can't intercept
+  // and mutate by capturing mutate(Val* v), which is what we do when we want to
+  // replace all instances of a value.
   Statement* mutateAsVal(Val* v) {
     return mutate(v);
   }
@@ -352,7 +350,8 @@ struct TORCH_CUDA_API OptOutMutator {
 
   std::unordered_map<Val*, Statement*> mutations;
 
-  //****Functions below defined in mutator.cpp*****///
+  //****Functions below defined in mutator.cpp*****
+
   // Vals
   virtual Statement* mutate(IterDomain*);
   virtual Statement* mutate(TensorDomain*);
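
Note (illustrative sketch, not part of the patch): the cleanup above swaps bare
static_cast calls for the checked as<T>() helper on the IR Statement base class,
while keeping the double-dispatch structure unchanged. The standalone program
below sketches both ideas with simplified, hypothetical names (PrintDispatch,
dispatch()); it is not the real nvFuser hierarchy, and the real as<T>() asserts
via TORCH_INTERNAL_ASSERT rather than <cassert>.

// dispatch_sketch.cpp -- minimal illustration under the assumptions above
#include <cassert>
#include <iostream>

struct Statement {
  virtual ~Statement() = default;

  // Checked downcast in the spirit of as<T>(): the cast is verified at the
  // point of use instead of being a blind static_cast at every dispatch site.
  template <typename T>
  T* as() {
    T* downcast = dynamic_cast<T*>(this);
    assert(downcast != nullptr && "invalid downcast");
    return downcast;
  }

  virtual bool isVal() const { return false; }
  virtual bool isExpr() const { return false; }
};

struct Val : Statement {
  bool isVal() const override { return true; }
};

struct Expr : Statement {
  bool isExpr() const override { return true; }
};

// A handler in the OptOutDispatch style: node kinds without an override are
// silently ignored by the real base class; this one handles both kinds.
struct PrintDispatch {
  void handle(Val*) { std::cout << "Val\n"; }
  void handle(Expr*) { std::cout << "Expr\n"; }
};

// Dispatch inspects the runtime kind once and hands the handler an
// already-typed pointer, so handler code never casts.
template <typename T>
void dispatch(T* handler, Statement* stmt) {
  if (stmt->isVal()) {
    handler->handle(stmt->as<Val>());
  } else if (stmt->isExpr()) {
    handler->handle(stmt->as<Expr>());
  }
}

int main() {
  Val v;
  Expr e;
  PrintDispatch printer;
  dispatch(&printer, static_cast<Statement*>(&v)); // prints "Val"
  dispatch(&printer, static_cast<Statement*>(&e)); // prints "Expr"
  return 0;
}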