
Commit 4c254c0

Fix arange when step is negative (#1942)
* The device version of ceilDiv assumes positive inputs, so when step is negative it gives an incorrect result. For example, FusionStandAloneArange results in a write error under compute-sanitizer when start = 0, stop = -1, step = -1.5, and dtype = kLong.
1 parent 89330aa commit 4c254c0
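
To see why a positive-only ceilDiv goes wrong here, consider the common integer formulation (a + b - 1) / b. The standalone sketch below (illustrative only, not the fuser's actual device helper) shows it miscounting for negative operands, and how taking absolute values first restores the correct count:

#include <cmath>
#include <cstdio>
#include <cstdlib>

// Positive-only ceilDiv as it is often written for non-negative integers.
// Illustrative stand-in, not the device helper from the fuser runtime.
long long ceilDivPositiveOnly(long long a, long long b) {
  return (a + b - 1) / b;
}

int main() {
  long long a = -5, b = -2;  // e.g. end - start = -5, step = -2
  long long wrong = ceilDivPositiveOnly(a, b);                    // (-5 + -2 - 1) / -2 = 4
  long long right = (long long)std::ceil((double)a / (double)b);  // ceil(2.5) = 3
  // Taking abs of both operands before dividing matches the correct count.
  long long fixed = ceilDivPositiveOnly(std::llabs(a), std::llabs(b));  // 3
  std::printf("wrong=%lld right=%lld fixed=%lld\n", wrong, right, fixed);
  return 0;
}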

6 files changed (+27, -1 lines)


torch/csrc/jit/codegen/cuda/arith.cpp

Lines changed: 3 additions & 1 deletion
@@ -471,7 +471,9 @@ TensorView* arange(Val* start, Val* end, Val* step, DataType dtype) {
     end = castOp(DataType::Double, end);
     step = castOp(DataType::Double, step);
   }
-  auto size = castOp(DataType::Int, ceilDiv(sub(end, start), step));
+  // Make sure no negative value is passed to ceilDiv as the device
+  // implementation of ceilDiv assumes positive inputs
+  auto size = castOp(DataType::Int, ceilDiv(abs(sub(end, start)), abs(step)));
   auto out = TensorViewBuilder()
                  .ndims(1)
                  .dtype(dtype)
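
For the failing case from the commit message (start = 0, stop = -1, step = -1.5), the fixed expression gives ceil(|-1 - 0| / |-1.5|) = ceil(0.667) = 1, so exactly one element is allocated and written. A host-side sketch of the same arithmetic (plain C++, not the fuser expression API):

#include <cmath>
#include <cstdio>

int main() {
  double start = 0.0, end = -1.0, step = -1.5;
  // Size formula after the fix: ceilDiv(|end - start|, |step|)
  long long size =
      static_cast<long long>(std::ceil(std::abs(end - start) / std::abs(step)));
  std::printf("size = %lld\n", size);  // prints: size = 1
  return 0;
}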

torch/csrc/jit/codegen/cuda/arith.h

Lines changed: 4 additions & 0 deletions
@@ -125,6 +125,10 @@ TORCH_CUDA_CU_API WelfordResult Welford(
 TORCH_CUDA_CU_API TensorView* rand(
     const std::vector<Val*>& shape,
     DataType dtype);
+
+//! WARNING: giving invalid combinations of the start, end and step
+//! arguments can result in undefined behavior. Specifically, the
+//! signs of `end - start` and step must be the same.
 TORCH_CUDA_CU_API TensorView* arange(Val* end, DataType dtype = DataType::Int);
 TORCH_CUDA_CU_API TensorView* arange(
     Val* start,
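
The constraint documented above can be checked on the caller side before building the fusion. A minimal sketch with plain scalars (validArangeArgs is a hypothetical helper, not part of the fuser API):

#include <cassert>

// Valid arange inputs must have (end - start) and step pointing the same
// way, i.e. sign(end - start) == sign(step), and step must be nonzero.
bool validArangeArgs(double start, double end, double step) {
  return step != 0.0 && (end - start) * step >= 0.0;
}

int main() {
  assert(validArangeArgs(0.0, -1.0, -1.5));  // counting down: fine
  assert(!validArangeArgs(0.0, 5.0, -1.0));  // signs disagree: undefined behavior
  return 0;
}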

torch/csrc/jit/codegen/cuda/dynamic_type.h

Lines changed: 8 additions & 0 deletions
@@ -296,6 +296,14 @@ inline IntOrDouble min(const IntOrDouble& a, const IntOrDouble& b) {
   return (a < b ? a : b).cast<double>();
 }
 
+inline IntOrDouble abs(const IntOrDouble& a) {
+  if (a.is_int()) {
+    return IntOrDouble(std::abs(a.as<int64_t>()));
+  } else {
+    return IntOrDouble(std::abs(a.as<double>()));
+  }
+}
+
 } // namespace IntOrDouble_functions
 
 } // namespace cuda
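
The new helper is a type-preserving absolute value over the int/double variant. A self-contained sketch of the same pattern, using std::variant as a stand-in since this page does not show the full IntOrDouble class:

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <variant>

using IntOrDoubleLike = std::variant<int64_t, double>;  // stand-in for IntOrDouble

// Type-preserving abs, mirroring the new IntOrDouble_functions::abs overload.
IntOrDoubleLike absLike(const IntOrDoubleLike& a) {
  if (std::holds_alternative<int64_t>(a)) {
    return IntOrDoubleLike(std::abs(std::get<int64_t>(a)));
  }
  return IntOrDoubleLike(std::abs(std::get<double>(a)));
}

int main() {
  auto i = absLike(IntOrDoubleLike(int64_t{-3}));  // holds int64_t 3
  auto d = absLike(IntOrDoubleLike(-2.5));         // holds double 2.5
  std::printf("%lld %f\n", (long long)std::get<int64_t>(i), std::get<double>(d));
  return 0;
}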

torch/csrc/jit/codegen/cuda/evaluator_common.cpp

Lines changed: 4 additions & 0 deletions
@@ -295,6 +295,7 @@ void NaiveValueMachine<IRContext>::runInstruction(int index) {
 
 template <typename IRContext>
 void NaiveValueMachine<IRContext>::runUnaryOp(int index) {
+  using namespace IntOrDouble_functions;
   int src_index = src0_[index];
   bool src_defined = precomputed_values_.defined_[src_index];
   bool src_is_const = precomputed_values_.is_constant_[src_index];
@@ -323,6 +324,9 @@ void NaiveValueMachine<IRContext>::runUnaryOp(int index) {
         TORCH_INTERNAL_ASSERT(false, "dtype not supported in evaluator");
       }
       break;
+    case UnaryOpType::Abs:
+      dest = abs(src);
+      break;
     default:
       TORCH_CHECK(!"Unexpected operator type ", uop_type_[index]);
   }

torch/csrc/jit/codegen/cuda/expr_evaluator.cpp

Lines changed: 4 additions & 0 deletions
@@ -105,6 +105,7 @@ c10::optional<IntOrDouble> ExpressionEvaluator::getValue(Val* value) {
 }
 
 void ExpressionEvaluator::handle(UnaryOp* uop) {
+  using namespace IntOrDouble_functions;
   const auto in = evaluate(uop->in());
   if (in.has_value()) {
     switch (uop->getUnaryOpType()) {
@@ -123,6 +124,9 @@ void ExpressionEvaluator::handle(UnaryOp* uop) {
           TORCH_INTERNAL_ASSERT(false, "dtype not supported in evaluator");
         }
         break;
+      case UnaryOpType::Abs:
+        known_values_[uop->out()] = abs(*in);
+        break;
       default:
         TORCH_CHECK(
             !"Unexpected operator type ",

torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.cpp

Lines changed: 4 additions & 0 deletions
@@ -132,6 +132,7 @@ void ExpressionEvaluator::handle(const NamedScalar* named_scalar) {
 }
 
 void ExpressionEvaluator::handle(const UnaryOp* unary_op) {
+  using namespace IntOrDouble_functions;
   const auto in = evaluate(unary_op->in());
   if (in.has_value()) {
     switch (unary_op->getUnaryOpType()) {
@@ -150,6 +151,9 @@ void ExpressionEvaluator::handle(const UnaryOp* unary_op) {
           TORCH_INTERNAL_ASSERT(false, "dtype not supported in evaluator");
         }
         break;
+      case UnaryOpType::Abs:
+        known_values_[unary_op->out()] = abs(*in);
+        break;
       default:
         TORCH_CHECK(
             false,
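
All three evaluator changes (evaluator_common.cpp, expr_evaluator.cpp, kernel_expr_evaluator.cpp) follow the same shape: resolve the input, then switch on the unary op type, with Abs as the newly handled case. A compressed standalone sketch of that dispatch pattern (illustrative names, not the fuser's classes):

#include <cmath>
#include <cstdio>
#include <optional>

enum class UnaryOpTypeLike { Neg, Cast, Abs };  // stand-in for UnaryOpType

// Illustrative stand-in for the evaluators' unary-op handling.
std::optional<double> evalUnary(UnaryOpTypeLike op, std::optional<double> in) {
  if (!in.has_value()) {
    return std::nullopt;  // input not known yet: leave the value unevaluated
  }
  switch (op) {
    case UnaryOpTypeLike::Neg:
      return -*in;
    case UnaryOpTypeLike::Abs:  // the case added by this commit
      return std::abs(*in);
    default:
      return std::nullopt;  // the real code asserts/checks on unexpected ops
  }
}

int main() {
  std::printf("%f\n", *evalUnary(UnaryOpTypeLike::Abs, -1.5));  // prints 1.500000
  return 0;
}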
