diff --git a/torch/csrc/jit/codegen/cuda/dispatch.cpp b/torch/csrc/jit/codegen/cuda/dispatch.cpp
index 2ec01fc34f8c..197c5e68e7e3 100644
--- a/torch/csrc/jit/codegen/cuda/dispatch.cpp
+++ b/torch/csrc/jit/codegen/cuda/dispatch.cpp
@@ -36,7 +36,7 @@ T* ptr(T* obj) {
  * }
  *
  * And therefore dispatch should never call:
- * ptr(mutator)->handle(static_cast<Statement*>(this));
+ * ptr(mutator)->handle(this->as<Statement>());
  */
 
 template <typename T>
@@ -45,35 +45,35 @@ void Val::dispatch(T handler, Val* val) {
     case ValType::Scalar:
       switch (*(val->getDataType())) {
         case DataType::Bool:
-          ptr(handler)->handle(static_cast<Bool*>(val));
+          ptr(handler)->handle(val->as<Bool>());
           return;
         case DataType::Float:
-          ptr(handler)->handle(static_cast<Float*>(val));
+          ptr(handler)->handle(val->as<Float>());
           return;
         case DataType::Half:
-          ptr(handler)->handle(static_cast<Half*>(val));
+          ptr(handler)->handle(val->as<Half>());
           return;
         case DataType::Int:
-          ptr(handler)->handle(static_cast<Int*>(val));
+          ptr(handler)->handle(val->as<Int>());
           return;
         default:
          break;
       }
       break;
     case ValType::IterDomain:
-      ptr(handler)->handle(static_cast<IterDomain*>(val));
+      ptr(handler)->handle(val->as<IterDomain>());
       return;
     case ValType::TensorDomain:
-      ptr(handler)->handle(static_cast<TensorDomain*>(val));
+      ptr(handler)->handle(val->as<TensorDomain>());
       return;
     case ValType::TensorView:
-      ptr(handler)->handle(static_cast<TensorView*>(val));
+      ptr(handler)->handle(val->as<TensorView>());
       return;
     case ValType::TensorIndex:
-      ptr(handler)->handle(static_cast<TensorIndex*>(val));
+      ptr(handler)->handle(val->as<TensorIndex>());
       return;
     case ValType::NamedScalar:
-      ptr(handler)->handle(static_cast<NamedScalar*>(val));
+      ptr(handler)->handle(val->as<NamedScalar>());
       return;
     default:
       break;
@@ -85,34 +85,34 @@ template <typename T>
 void Expr::dispatch(T handler, Expr* expr) {
   switch (*(expr->getExprType())) {
     case ExprType::Split:
-      ptr(handler)->handle(static_cast<Split*>(expr));
+      ptr(handler)->handle(expr->as<Split>());
       return;
     case ExprType::Merge:
-      ptr(handler)->handle(static_cast<Merge*>(expr));
+      ptr(handler)->handle(expr->as<Merge>());
       return;
     case ExprType::UnaryOp:
-      ptr(handler)->handle(static_cast<UnaryOp*>(expr));
+      ptr(handler)->handle(expr->as<UnaryOp>());
       return;
     case ExprType::BinaryOp:
-      ptr(handler)->handle(static_cast<BinaryOp*>(expr));
+      ptr(handler)->handle(expr->as<BinaryOp>());
       return;
     case ExprType::TernaryOp:
-      ptr(handler)->handle(static_cast<TernaryOp*>(expr));
+      ptr(handler)->handle(expr->as<TernaryOp>());
       return;
     case ExprType::ReductionOp:
-      ptr(handler)->handle(static_cast<ReductionOp*>(expr));
+      ptr(handler)->handle(expr->as<ReductionOp>());
       return;
     case ExprType::BroadcastOp:
-      ptr(handler)->handle(static_cast<BroadcastOp*>(expr));
+      ptr(handler)->handle(expr->as<BroadcastOp>());
       return;
     case ExprType::ForLoop:
-      ptr(handler)->handle(static_cast<ForLoop*>(expr));
+      ptr(handler)->handle(expr->as<ForLoop>());
       return;
     case ExprType::IfThenElse:
-      ptr(handler)->handle(static_cast<IfThenElse*>(expr));
+      ptr(handler)->handle(expr->as<IfThenElse>());
       return;
     case ExprType::Allocate:
-      ptr(handler)->handle(static_cast<Allocate*>(expr));
+      ptr(handler)->handle(expr->as<Allocate>());
       return;
     default:
       TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
@@ -122,9 +122,9 @@ void Expr::dispatch(T handler, Expr* expr) {
 template <typename T>
 void Statement::dispatch(T handler, Statement* stmt) {
   if (stmt->isVal()) {
-    ptr(handler)->handle(static_cast<Val*>(stmt));
+    ptr(handler)->handle(stmt->as<Val>());
   } else if (stmt->isExpr()) {
-    ptr(handler)->handle(static_cast<Expr*>(stmt));
+    ptr(handler)->handle(stmt->as<Expr>());
   } else
     TORCH_INTERNAL_ASSERT(false, "Unknown stmttype in dispatch!");
 }
@@ -135,35 +135,35 @@ void Val::constDispatch(T handler, const Val* val) {
     case ValType::Scalar:
       switch (*(val->getDataType())) {
         case DataType::Bool:
-          ptr(handler)->handle(static_cast<const Bool*>(val));
+          ptr(handler)->handle(val->as<Bool>());
           return;
         case DataType::Float:
-          ptr(handler)->handle(static_cast<const Float*>(val));
+          ptr(handler)->handle(val->as<Float>());
           return;
         case DataType::Half:
-          ptr(handler)->handle(static_cast<const Half*>(val));
+          ptr(handler)->handle(val->as<Half>());
           return;
         case DataType::Int:
-          ptr(handler)->handle(static_cast<const Int*>(val));
+          ptr(handler)->handle(val->as<Int>());
           return;
         default:
          break;
       }
       break;
     case ValType::IterDomain:
-      ptr(handler)->handle(static_cast<const IterDomain*>(val));
+      ptr(handler)->handle(val->as<IterDomain>());
       return;
     case ValType::TensorDomain:
-      ptr(handler)->handle(static_cast<const TensorDomain*>(val));
+      ptr(handler)->handle(val->as<TensorDomain>());
       return;
     case ValType::TensorView:
-      ptr(handler)->handle(static_cast<const TensorView*>(val));
+      ptr(handler)->handle(val->as<TensorView>());
       return;
     case ValType::TensorIndex:
-      ptr(handler)->handle(static_cast<const TensorIndex*>(val));
+      ptr(handler)->handle(val->as<TensorIndex>());
       return;
     case ValType::NamedScalar:
-      ptr(handler)->handle(static_cast<const NamedScalar*>(val));
+      ptr(handler)->handle(val->as<NamedScalar>());
       return;
     default:
       break;
@@ -175,34 +175,34 @@ template <typename T>
 void Expr::constDispatch(T handler, const Expr* expr) {
   switch (*(expr->getExprType())) {
     case ExprType::Split:
-      ptr(handler)->handle(static_cast<const Split*>(expr));
+      ptr(handler)->handle(expr->as<Split>());
       return;
     case ExprType::Merge:
-      ptr(handler)->handle(static_cast<const Merge*>(expr));
+      ptr(handler)->handle(expr->as<Merge>());
       return;
     case ExprType::UnaryOp:
-      ptr(handler)->handle(static_cast<const UnaryOp*>(expr));
+      ptr(handler)->handle(expr->as<UnaryOp>());
       return;
     case ExprType::BinaryOp:
-      ptr(handler)->handle(static_cast<const BinaryOp*>(expr));
+      ptr(handler)->handle(expr->as<BinaryOp>());
       return;
     case ExprType::TernaryOp:
-      ptr(handler)->handle(static_cast<const TernaryOp*>(expr));
+      ptr(handler)->handle(expr->as<TernaryOp>());
       return;
     case ExprType::ReductionOp:
-      ptr(handler)->handle(static_cast<const ReductionOp*>(expr));
+      ptr(handler)->handle(expr->as<ReductionOp>());
       return;
     case ExprType::BroadcastOp:
-      ptr(handler)->handle(static_cast<const BroadcastOp*>(expr));
+      ptr(handler)->handle(expr->as<BroadcastOp>());
       return;
     case ExprType::ForLoop:
-      ptr(handler)->handle(static_cast<const ForLoop*>(expr));
+      ptr(handler)->handle(expr->as<ForLoop>());
       return;
     case ExprType::IfThenElse:
-      ptr(handler)->handle(static_cast<const IfThenElse*>(expr));
+      ptr(handler)->handle(expr->as<IfThenElse>());
       return;
     case ExprType::Allocate:
-      ptr(handler)->handle(static_cast<const Allocate*>(expr));
+      ptr(handler)->handle(expr->as<Allocate>());
       return;
     default:
       TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
@@ -212,9 +212,9 @@ void Expr::constDispatch(T handler, const Expr* expr) {
 template <typename T>
 void Statement::constDispatch(T handler, const Statement* stmt) {
   if (stmt->isVal()) {
-    ptr(handler)->handle(static_cast<const Val*>(stmt));
+    ptr(handler)->handle(stmt->as<Val>());
   } else if (stmt->isExpr()) {
-    ptr(handler)->handle(static_cast<const Expr*>(stmt));
+    ptr(handler)->handle(stmt->as<Expr>());
   } else
     TORCH_INTERNAL_ASSERT(false, "Unknown stmttype in dispatch!");
 }
@@ -228,7 +228,7 @@ void Statement::constDispatch(T handler, const Statement* stmt) {
  * implement Statement* mutate(Statement* stmt){ stmt->mutatorDispatch(this);
  * }
  * And therefore dispatch should never call:
- * ptr(mutator)->mutate(static_cast<Statement*>(this));
+ * ptr(mutator)->mutate(this->as<Statement>());
  */
 template <typename T>
 Statement* Val::mutatorDispatch(T mutator, Val* val) {
@@ -236,27 +236,27 @@ Statement* Val::mutatorDispatch(T mutator, Val* val) {
     case ValType::Scalar:
       switch (*(val->getDataType())) {
         case DataType::Bool:
-          return ptr(mutator)->mutate(static_cast<Bool*>(val));
+          return ptr(mutator)->mutate(val->as<Bool>());
         case DataType::Float:
-          return ptr(mutator)->mutate(static_cast<Float*>(val));
+          return ptr(mutator)->mutate(val->as<Float>());
        case DataType::Half:
-          return ptr(mutator)->mutate(static_cast<Half*>(val));
+          return ptr(mutator)->mutate(val->as<Half>());
        case DataType::Int:
-          return ptr(mutator)->mutate(static_cast<Int*>(val));
+          return ptr(mutator)->mutate(val->as<Int>());
         default:
           break;
       }
       break;
     case ValType::IterDomain:
-      return ptr(mutator)->mutate(static_cast<IterDomain*>(val));
+      return ptr(mutator)->mutate(val->as<IterDomain>());
     case ValType::TensorDomain:
-      return ptr(mutator)->mutate(static_cast<TensorDomain*>(val));
+      return ptr(mutator)->mutate(val->as<TensorDomain>());
     case ValType::TensorView:
-      return ptr(mutator)->mutate(static_cast<TensorView*>(val));
+      return ptr(mutator)->mutate(val->as<TensorView>());
     case ValType::TensorIndex:
-      return ptr(mutator)->mutate(static_cast<TensorIndex*>(val));
+      return ptr(mutator)->mutate(val->as<TensorIndex>());
     case ValType::NamedScalar:
-      return ptr(mutator)->mutate(static_cast<NamedScalar*>(val));
+      return ptr(mutator)->mutate(val->as<NamedScalar>());
     default:
       break;
   }
@@ -267,25 +267,25 @@ template <typename T>
 Statement* Expr::mutatorDispatch(T mutator, Expr* expr) {
   switch (*(expr->getExprType())) {
     case ExprType::Split:
-      return ptr(mutator)->mutate(static_cast<Split*>(expr));
+      return ptr(mutator)->mutate(expr->as<Split>());
     case ExprType::Merge:
-      return ptr(mutator)->mutate(static_cast<Merge*>(expr));
+      return ptr(mutator)->mutate(expr->as<Merge>());
     case ExprType::UnaryOp:
-      return ptr(mutator)->mutate(static_cast<UnaryOp*>(expr));
+      return ptr(mutator)->mutate(expr->as<UnaryOp>());
     case ExprType::BinaryOp:
-      return ptr(mutator)->mutate(static_cast<BinaryOp*>(expr));
+      return ptr(mutator)->mutate(expr->as<BinaryOp>());
     case ExprType::TernaryOp:
-      return ptr(mutator)->mutate(static_cast<TernaryOp*>(expr));
+      return ptr(mutator)->mutate(expr->as<TernaryOp>());
     case ExprType::ReductionOp:
-      return ptr(mutator)->mutate(static_cast<ReductionOp*>(expr));
+      return ptr(mutator)->mutate(expr->as<ReductionOp>());
     case ExprType::BroadcastOp:
-      return ptr(mutator)->mutate(static_cast<BroadcastOp*>(expr));
+      return ptr(mutator)->mutate(expr->as<BroadcastOp>());
     case ExprType::ForLoop:
-      return ptr(mutator)->mutate(static_cast<ForLoop*>(expr));
+      return ptr(mutator)->mutate(expr->as<ForLoop>());
     case ExprType::IfThenElse:
-      return ptr(mutator)->mutate(static_cast<IfThenElse*>(expr));
+      return ptr(mutator)->mutate(expr->as<IfThenElse>());
     case ExprType::Allocate:
-      return ptr(mutator)->mutate(static_cast<Allocate*>(expr));
+      return ptr(mutator)->mutate(expr->as<Allocate>());
     default:
       TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
   }
@@ -294,10 +294,10 @@ Statement* Expr::mutatorDispatch(T mutator, Expr* expr) {
 template <typename T>
 Statement* Statement::mutatorDispatch(T mutator, Statement* stmt) {
   if (stmt->isVal()) {
-    return ptr(mutator)->mutate(static_cast<Val*>(stmt));
+    return ptr(mutator)->mutate(stmt->as<Val>());
   }
   if (stmt->isExpr()) {
-    return ptr(mutator)->mutate(static_cast<Expr*>(stmt));
+    return ptr(mutator)->mutate(stmt->as<Expr>());
   }
   TORCH_INTERNAL_ASSERT(false, "Unknown stmttype in dispatch!");
 }
@@ -352,9 +352,11 @@ template Statement* Expr::mutatorDispatch(OptInMutator*, Expr*);
 void OptOutDispatch::handle(Statement* s) {
   Statement::dispatch(this, s);
 }
+
 void OptOutDispatch::handle(Expr* e) {
   Expr::dispatch(this, e);
 }
+
 void OptOutDispatch::handle(Val* v) {
   Val::dispatch(this, v);
 }
@@ -362,9 +364,11 @@ void OptOutDispatch::handle(Val* v) {
 void OptInDispatch::handle(Statement* s) {
   Statement::dispatch(this, s);
 }
+
 void OptInDispatch::handle(Expr* e) {
   Expr::dispatch(this, e);
 }
+
 void OptInDispatch::handle(Val* v) {
   Val::dispatch(this, v);
 }
@@ -372,9 +376,11 @@ void OptInDispatch::handle(Val* v) {
 void OptOutConstDispatch::handle(const Statement* s) {
   Statement::constDispatch(this, s);
 }
+
 void OptOutConstDispatch::handle(const Expr* e) {
   Expr::constDispatch(this, e);
 }
+
 void OptOutConstDispatch::handle(const Val* v) {
   Val::constDispatch(this, v);
 }
@@ -382,9 +388,11 @@ void OptOutConstDispatch::handle(const Val* v) {
 void OptInConstDispatch::handle(const Statement* s) {
   Statement::constDispatch(this, s);
 }
+
 void OptInConstDispatch::handle(const Expr* e) {
   Expr::constDispatch(this, e);
 }
+
 void OptInConstDispatch::handle(const Val* v) {
   Val::constDispatch(this, v);
 }
@@ -407,9 +415,11 @@ Statement* OptInMutator::mutate(Val* v) {
 Statement* OptOutMutator::mutate(Statement* s) {
   return Statement::mutatorDispatch(this, s);
 }
+
 Statement* OptOutMutator::mutate(Expr* e) {
   return Expr::mutatorDispatch(this, e);
 }
+
 Statement* OptOutMutator::mutate(Val* v) {
   // If value is already mutated, return the mutation
   if (mutations.find(v) != mutations.end())
diff --git a/torch/csrc/jit/codegen/cuda/dispatch.h b/torch/csrc/jit/codegen/cuda/dispatch.h
index f2ad5f7128dd..b3575cf44d33 100644
--- a/torch/csrc/jit/codegen/cuda/dispatch.h
+++ b/torch/csrc/jit/codegen/cuda/dispatch.h
@@ -331,13 +331,11 @@ struct TORCH_CUDA_API OptOutMutator {
   virtual Statement* mutate(Expr* e);
   virtual Statement* mutate(Val* v);
 
-  /*
-   * We always want to dispatch through a Val, so we can capture and dispatch
-   * correctly members of nodes like Split->TensorDomain If we don't call the
-   * below function or manually cast to use mutate(Val* v) we can't intercept
-   * and mutate by capturing mutate(Val* v), which is what we do when we want to
-   * replace all instances of a value.
-   */
+  // We always want to dispatch through a Val, so we can capture and dispatch
+  // correctly members of nodes like Split->TensorDomain If we don't call the
+  // below function or manually cast to use mutate(Val* v) we can't intercept
+  // and mutate by capturing mutate(Val* v), which is what we do when we want to
+  // replace all instances of a value.
   Statement* mutateAsVal(Val* v) {
     return mutate(v);
   }
@@ -352,7 +350,8 @@ struct TORCH_CUDA_API OptOutMutator {
 
   std::unordered_map<Val*, Statement*> mutations;
 
-  //****Functions below defined in mutator.cpp*****///
+  //****Functions below defined in mutator.cpp*****
+
   // Vals
   virtual Statement* mutate(IterDomain*);
   virtual Statement* mutate(TensorDomain*);
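
For readers outside the fuser codebase: the mechanical change above replaces unchecked `static_cast<T*>(node)` calls with the `Statement::as<T>()` helper, so the target type appears only once at each call site, next to the `case` label that guarantees it. The sketch below is a minimal, self-contained illustration of that pattern, not the fuser's actual implementation: the class bodies, the `assert`-based guard, and `main` are assumptions for the example (the real helper would live in the fuser IR headers and would likely use `TORCH_INTERNAL_ASSERT` for its check).

```cpp
// Minimal sketch of a checked-downcast helper in the style of as<T>().
// Only the shape of the as<T>() call matches the diff above; everything
// else here is illustrative.
#include <cassert>
#include <iostream>

struct Statement {
  virtual ~Statement() = default; // polymorphic base so dynamic_cast works

  // Downcast with a debug-build sanity check; in release builds the
  // assert compiles away and only the cheap static_cast remains.
  template <typename T>
  T* as() {
    assert(dynamic_cast<T*>(this) != nullptr && "bad as<T>() downcast");
    return static_cast<T*>(this);
  }

  template <typename T>
  const T* as() const {
    assert(dynamic_cast<const T*>(this) != nullptr && "bad as<T>() downcast");
    return static_cast<const T*>(this);
  }
};

// Hypothetical stand-ins for the fuser's Val/Bool IR nodes.
struct Val : Statement {};
struct Bool : Val {
  bool value = true;
};

int main() {
  Bool b;
  Val* val = &b;
  // Call sites read like the "+" lines in the diff: the requested type is
  // stated once, and debug builds verify it actually matches the object.
  std::cout << val->as<Bool>()->value << "\n";
  return 0;
}
```

The practical win over the old `static_cast` lines is that a mismatched `case` label and pointer type can no longer silently produce a mis-typed pointer; under the assumed guard it trips an assertion in debug builds while keeping release-build dispatch as cheap as before.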