diff --git a/torch/csrc/jit/codegen/cuda/arith.cpp b/torch/csrc/jit/codegen/cuda/arith.cpp index de282dfc8182..27d524389056 100644 --- a/torch/csrc/jit/codegen/cuda/arith.cpp +++ b/torch/csrc/jit/codegen/cuda/arith.cpp @@ -471,7 +471,9 @@ TensorView* arange(Val* start, Val* end, Val* step, DataType dtype) { end = castOp(DataType::Double, end); step = castOp(DataType::Double, step); } - auto size = castOp(DataType::Int, ceilDiv(sub(end, start), step)); + // Make sure no negative value is passed to ceilDiv as the device + // implementation of ceilDiv assumes positive inputs + auto size = castOp(DataType::Int, ceilDiv(abs(sub(end, start)), abs(step))); auto out = TensorViewBuilder() .ndims(1) .dtype(dtype) diff --git a/torch/csrc/jit/codegen/cuda/arith.h b/torch/csrc/jit/codegen/cuda/arith.h index d8e6b6588214..2d3f1f32f23f 100644 --- a/torch/csrc/jit/codegen/cuda/arith.h +++ b/torch/csrc/jit/codegen/cuda/arith.h @@ -125,6 +125,10 @@ TORCH_CUDA_CU_API WelfordResult Welford( TORCH_CUDA_CU_API TensorView* rand( const std::vector<Val*>& shape, DataType dtype); + +//! WARNING: giving invalid combinations of the start, end and step +//! arguments can result in undefined behavior. Specifically, the +//! signs of `end - start` and step must be the same. TORCH_CUDA_CU_API TensorView* arange(Val* end, DataType dtype = DataType::Int); TORCH_CUDA_CU_API TensorView* arange( Val* start, diff --git a/torch/csrc/jit/codegen/cuda/dynamic_type.h b/torch/csrc/jit/codegen/cuda/dynamic_type.h index aba725e0ea60..5cf9f0930929 100644 --- a/torch/csrc/jit/codegen/cuda/dynamic_type.h +++ b/torch/csrc/jit/codegen/cuda/dynamic_type.h @@ -296,6 +296,14 @@ inline IntOrDouble min(const IntOrDouble& a, const IntOrDouble& b) { return (a < b ? 
a : b).cast(); } +inline IntOrDouble abs(const IntOrDouble& a) { + if (a.is_int()) { + return IntOrDouble(std::abs(a.as<int64_t>())); + } else { + return IntOrDouble(std::abs(a.as<double>())); + } +} + } // namespace IntOrDouble_functions } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/evaluator_common.cpp b/torch/csrc/jit/codegen/cuda/evaluator_common.cpp index bab8586247bf..237881b49b2d 100644 --- a/torch/csrc/jit/codegen/cuda/evaluator_common.cpp +++ b/torch/csrc/jit/codegen/cuda/evaluator_common.cpp @@ -295,6 +295,7 @@ void NaiveValueMachine<IRContext>::runInstruction(int index) { template <typename IRContext> void NaiveValueMachine<IRContext>::runUnaryOp(int index) { + using namespace IntOrDouble_functions; int src_index = src0_[index]; bool src_defined = precomputed_values_.defined_[src_index]; bool src_is_const = precomputed_values_.is_constant_[src_index]; @@ -323,6 +324,9 @@ void NaiveValueMachine<IRContext>::runUnaryOp(int index) { TORCH_INTERNAL_ASSERT(false, "dtype not supported in evaluator"); } break; + case UnaryOpType::Abs: + dest = abs(src); + break; default: TORCH_CHECK(!"Unexpected operator type ", uop_type_[index]); } diff --git a/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp b/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp index 7bda8682189e..7dda464a4fac 100644 --- a/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp +++ b/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp @@ -105,6 +105,7 @@ c10::optional<IntOrDouble> ExpressionEvaluator::getValue(Val* value) { } void ExpressionEvaluator::handle(UnaryOp* uop) { + using namespace IntOrDouble_functions; const auto in = evaluate(uop->in()); if (in.has_value()) { switch (uop->getUnaryOpType()) { @@ -123,6 +124,9 @@ void ExpressionEvaluator::handle(UnaryOp* uop) { TORCH_INTERNAL_ASSERT(false, "dtype not supported in evaluator"); } break; + case UnaryOpType::Abs: + known_values_[uop->out()] = abs(*in); + break; default: TORCH_CHECK( !"Unexpected operator type ", diff --git a/torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.cpp 
b/torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.cpp index a4a823ab5560..15a18a6bca83 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.cpp @@ -132,6 +132,7 @@ void ExpressionEvaluator::handle(const NamedScalar* named_scalar) { } void ExpressionEvaluator::handle(const UnaryOp* unary_op) { + using namespace IntOrDouble_functions; const auto in = evaluate(unary_op->in()); if (in.has_value()) { switch (unary_op->getUnaryOpType()) { @@ -150,6 +151,9 @@ void ExpressionEvaluator::handle(const UnaryOp* unary_op) { TORCH_INTERNAL_ASSERT(false, "dtype not supported in evaluator"); } break; + case UnaryOpType::Abs: + known_values_[unary_op->out()] = abs(*in); + break; default: TORCH_CHECK( false,