diff --git a/torch/csrc/jit/codegen/cuda/dispatch.cpp b/torch/csrc/jit/codegen/cuda/dispatch.cpp
index 8c871ba15a38d..134100342a2e9 100644
--- a/torch/csrc/jit/codegen/cuda/dispatch.cpp
+++ b/torch/csrc/jit/codegen/cuda/dispatch.cpp
@@ -53,6 +53,7 @@ void Val::dispatch(T handler, Val* val) {
         default:
           break;
       }
+      break;
     case ValType::IterDomain:
       ptr(handler)->handle(static_cast<IterDomain*>(val));
       return;
@@ -122,28 +123,29 @@ void Val::constDispatch(T handler, const Val* const val) {
     case ValType::Scalar:
       switch (*(val->getDataType())) {
        case DataType::Float:
-          ptr(handler)->handle(static_cast<const Float* const>(val));
+          ptr(handler)->handle(static_cast<const Float*>(val));
           return;
         case DataType::Int:
-          ptr(handler)->handle(static_cast<const Int* const>(val));
+          ptr(handler)->handle(static_cast<const Int*>(val));
           return;
         default:
           break;
       }
+      break;
     case ValType::IterDomain:
-      ptr(handler)->handle(static_cast<const IterDomain* const>(val));
+      ptr(handler)->handle(static_cast<const IterDomain*>(val));
       return;
     case ValType::TensorDomain:
-      ptr(handler)->handle(static_cast<const TensorDomain* const>(val));
+      ptr(handler)->handle(static_cast<const TensorDomain*>(val));
       return;
     case ValType::TensorView:
-      ptr(handler)->handle(static_cast<const TensorView* const>(val));
+      ptr(handler)->handle(static_cast<const TensorView*>(val));
       return;
     case ValType::TensorIndex:
-      ptr(handler)->handle(static_cast<const TensorIndex* const>(val));
+      ptr(handler)->handle(static_cast<const TensorIndex*>(val));
       return;
     case ValType::NamedScalar:
-      ptr(handler)->handle(static_cast<const NamedScalar* const>(val));
+      ptr(handler)->handle(static_cast<const NamedScalar*>(val));
       return;
     default:
       break;
@@ -152,31 +154,31 @@ void Val::constDispatch(T handler, const Val* const val) {
 }
 
 template <typename T>
-void Expr::constDispatch(T handler, const Expr* const expr) {
+void Expr::constDispatch(T handler, const Expr* expr) {
   switch (*(expr->getExprType())) {
     case ExprType::Split:
-      ptr(handler)->handle(static_cast<const Split* const>(expr));
+      ptr(handler)->handle(static_cast<const Split*>(expr));
       return;
     case ExprType::Merge:
-      ptr(handler)->handle(static_cast<const Merge* const>(expr));
+      ptr(handler)->handle(static_cast<const Merge*>(expr));
      return;
     case ExprType::Reorder:
-      ptr(handler)->handle(static_cast<const Reorder* const>(expr));
+      ptr(handler)->handle(static_cast<const Reorder*>(expr));
       return;
     case ExprType::UnaryOp:
-      ptr(handler)->handle(static_cast<const UnaryOp* const>(expr));
+      ptr(handler)->handle(static_cast<const UnaryOp*>(expr));
       return;
     case ExprType::BinaryOp:
-      ptr(handler)->handle(static_cast<const BinaryOp* const>(expr));
+      ptr(handler)->handle(static_cast<const BinaryOp*>(expr));
       return;
     case ExprType::ForLoop:
-      ptr(handler)->handle(static_cast<const ForLoop* const>(expr));
+      ptr(handler)->handle(static_cast<const ForLoop*>(expr));
       return;
     case ExprType::IfThenElse:
-      ptr(handler)->handle(static_cast<const IfThenElse* const>(expr));
+      ptr(handler)->handle(static_cast<const IfThenElse*>(expr));
       return;
     case ExprType::Allocate:
-      ptr(handler)->handle(static_cast<const Allocate* const>(expr));
+      ptr(handler)->handle(static_cast<const Allocate*>(expr));
       return;
     default:
       TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
@@ -184,11 +186,11 @@ void Expr::constDispatch(T handler, const Expr* const expr) {
 }
 
 template <typename T>
-void Statement::constDispatch(T handler, const Statement* const stmt) {
+void Statement::constDispatch(T handler, const Statement* stmt) {
   if (stmt->isVal()) {
-    ptr(handler)->handle(static_cast<const Val* const>(stmt));
+    ptr(handler)->handle(static_cast<const Val*>(stmt));
   } else if (stmt->isExpr()) {
-    ptr(handler)->handle(static_cast<const Expr* const>(stmt));
+    ptr(handler)->handle(static_cast<const Expr*>(stmt));
   } else
     TORCH_INTERNAL_ASSERT(false, "Unknown stmttype in dispatch!");
 }
@@ -216,6 +218,7 @@ Statement* Val::mutatorDispatch(T mutator, Val* val) {
         default:
           break;
       }
+      break;
     case ValType::IterDomain:
       return ptr(mutator)->mutate(static_cast<IterDomain*>(val));
     case ValType::TensorDomain:
diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h
index 884df4ebcd2a2..44b4b42fe7d15 100644
--- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h
+++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h
@@ -185,9 +185,9 @@ struct TORCH_CUDA_API Val : public Statement {
   // was found
   Expr* getOrigin();
 
-  virtual bool sameType(const Statement* const other) {
+  virtual bool sameType(const Statement* other) {
     return Statement::sameType(other) &&
-        getDataType() == static_cast<const Val* const>(other)->getDataType();
+        getDataType() == static_cast<const Val*>(other)->getDataType();
   }
 
   // TODO: Make this more sophisticated. A value being the same as another value