14 changes: 7 additions & 7 deletions torch/csrc/jit/codegen/cuda/arith.cpp
@@ -9,10 +9,10 @@ namespace fuser {
 
 // Will return a new value of type val with the DataType dtype, if it's a
 // tensorview it will propagate the shape information from val.
-TORCH_CUDA_API Val* newValLike(const Val* const val, DataType dtype) {
+TORCH_CUDA_API Val* newValLike(const Val* val, DataType dtype) {
   switch (val->getValType().value()) {
     case (ValType::TensorView):
-      return static_cast<const TensorView* const>(val)->newForOutput(dtype);
+      return val->as<TensorView>()->newForOutput(dtype);
     case (ValType::NamedScalar):
     case (ValType::Scalar):
       switch (dtype) {
@@ -39,7 +39,7 @@ TORCH_CUDA_API Val* newValLike(const Val* const val, DataType dtype) {
       val->getDataType().value());
 }
 
-TORCH_CUDA_API Val* newValLike(const Val* const val) {
+TORCH_CUDA_API Val* newValLike(const Val* val) {
   return newValLike(val, val->getDataType().value());
 }

@@ -112,7 +112,7 @@ TORCH_CUDA_API Val* castOp(DataType dtype, Val* v1) {
 }
 
 TORCH_CUDA_API TensorView* castOp(DataType dtype, TensorView* v1) {
-  return castOp(dtype, static_cast<Val*>(v1))->as<TensorView>();
+  return castOp(dtype, v1->as<Val>())->as<TensorView>();
 }
 
 // UNARY OPERATIONS
@@ -124,7 +124,7 @@ TORCH_CUDA_API Val* unaryOp(UnaryOpType type, Val* v1) {
 }
 
 TORCH_CUDA_API TensorView* unaryOp(UnaryOpType type, TensorView* v1) {
-  return unaryOp(type, static_cast<Val*>(v1))->as<TensorView>();
+  return unaryOp(type, v1->as<Val>())->as<TensorView>();
 }
 
 TORCH_CUDA_API Val* neg(Val* v) {
@@ -551,7 +551,7 @@ TORCH_CUDA_API Val* threshold(Val* in, Val* thresh, Val* value) {
 }
 
 TORCH_CUDA_API TensorView* threshold(TensorView* in, Val* thresh, Val* value) {
-  return threshold(static_cast<Val*>(in), thresh, value)->as<TensorView>();
+  return threshold(in->as<Val>(), thresh, value)->as<TensorView>();
 }
 
 TORCH_CUDA_API Val* clamp(Val* in, Val* min_val, Val* max_val) {
@@ -572,7 +572,7 @@ TORCH_CUDA_API Val* clamp(Val* in, Val* min_val, Val* max_val) {
 }
 
 TORCH_CUDA_API TensorView* clamp(TensorView* in, Val* min_val, Val* max_val) {
-  return clamp(static_cast<Val*>(in), min_val, max_val)->as<TensorView>();
+  return clamp(in->as<Val>(), min_val, max_val)->as<TensorView>();
 }
 
 } // namespace fuser
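
The arith.cpp changes swap raw static_casts for the IR's as<T>() helper, which puts the downcast and its runtime-type check in one place instead of at every call site. A minimal sketch of that idiom, assuming a polymorphic base class and using dynamic_cast plus an assert as the check; the names here are illustrative, not the fuser's exact implementation:

#include <cassert>
#include <iostream>

// Sketch of a checked-downcast helper in the style of Val::as<T>().
// Statement/Val/TensorView mimic the fuser hierarchy; the dynamic_cast
// check is an assumption standing in for whatever validation the real
// helper performs.
struct Statement {
  virtual ~Statement() = default;

  template <typename T>
  T* as() {
    T* downcast = dynamic_cast<T*>(this);
    assert(downcast != nullptr && "as<T>(): wrong runtime type");
    return downcast;
  }

  template <typename T>
  const T* as() const {
    const T* downcast = dynamic_cast<const T*>(this);
    assert(downcast != nullptr && "as<T>(): wrong runtime type");
    return downcast;
  }
};

struct Val : Statement {};
struct TensorView : Val {
  void newForOutput() const { std::cout << "TensorView::newForOutput\n"; }
};

int main() {
  Val* v = new TensorView();
  // Reads like the diff: one checked call instead of a bare static_cast.
  v->as<TensorView>()->newForOutput();
  delete v;
}

With the helper, each call site reads as val->as<TensorView>()->... and the validity check cannot be silently skipped, which is the readability and safety win the diff is after.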
10 changes: 5 additions & 5 deletions torch/csrc/jit/codegen/cuda/dispatch.cpp
@@ -130,18 +130,18 @@ void Statement::dispatch(T handler, Statement* stmt) {
 }
 
 template <typename T>
-void Val::constDispatch(T handler, const Val* const val) {
+void Val::constDispatch(T handler, const Val* val) {
   switch (*(val->getValType())) {
     case ValType::Scalar:
       switch (*(val->getDataType())) {
         case DataType::Bool:
-          ptr(handler)->handle(static_cast<const Bool* const>(val));
+          ptr(handler)->handle(static_cast<const Bool*>(val));
           return;
         case DataType::Float:
           ptr(handler)->handle(static_cast<const Float*>(val));
           return;
         case DataType::Half:
-          ptr(handler)->handle(static_cast<const Half* const>(val));
+          ptr(handler)->handle(static_cast<const Half*>(val));
           return;
         case DataType::Int:
           ptr(handler)->handle(static_cast<const Int*>(val));
@@ -190,10 +190,10 @@ void Expr::constDispatch(T handler, const Expr* expr) {
       ptr(handler)->handle(static_cast<const BinaryOp*>(expr));
       return;
     case ExprType::TernaryOp:
-      ptr(handler)->handle(static_cast<const TernaryOp* const>(expr));
+      ptr(handler)->handle(static_cast<const TernaryOp*>(expr));
       return;
     case ExprType::ReductionOp:
-      ptr(handler)->handle(static_cast<const ReductionOp* const>(expr));
+      ptr(handler)->handle(static_cast<const ReductionOp*>(expr));
      return;
     case ExprType::ForLoop:
       ptr(handler)->handle(static_cast<const ForLoop*>(expr));
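
The remaining edits in both files drop the redundant top-level const from pointer parameters and cast targets. That const qualifies only the pointer variable itself, so the language ignores it in a function's signature and on a cast's result; `const Bool* const` and `const Bool*` name the same thing in those positions. A small standalone illustration, where Bool is a stand-in type rather than the fuser class:

#include <type_traits>

struct Bool {};  // stand-in, not torch::jit::fuser's Bool

// `const Bool* const` differs from `const Bool*` only by top-level const,
// which std::remove_const strips away.
static_assert(
    std::is_same<std::remove_const<const Bool* const>::type,
                 const Bool*>::value,
    "same type once top-level const is removed");

// Top-level const on a parameter is not part of the function type:
// these two lines declare the same function, not an overload pair.
void handle(const Bool* v);
void handle(const Bool* const v);

int main() { return 0; }

The same reasoning applies to static_cast<const Bool* const>(val): the cast yields a prvalue, whose top-level const is discarded, so the shorter spelling in the diff loses nothing.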