4 changes: 3 additions & 1 deletion torch/csrc/jit/codegen/cuda/arith.cpp
@@ -471,7 +471,9 @@ TensorView* arange(Val* start, Val* end, Val* step, DataType dtype) {
end = castOp(DataType::Double, end);
step = castOp(DataType::Double, step);
}
auto size = castOp(DataType::Int, ceilDiv(sub(end, start), step));
// Make sure no negative value is passed to ceilDiv as the device
// implementation of ceilDiv assumes positive inputs
auto size = castOp(DataType::Int, ceilDiv(abs(sub(end, start)), abs(step)));
Collaborator Author
Actually, I'm not really sure what the right behavior should be when the signs of `end - start` and `step` are different, which I think is invalid input. Do we need to do error checking? Doing so in device code could be very costly, so we would probably want to do it on the host side before launching kernels.

In any case, unless we want to check it on device, the above workaround should be fine.
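For reference, here is a minimal standalone sketch of why the raw operands misbehave, assuming the device helper uses the usual positive-only `(a + b - 1) / b` formulation (that helper is not part of this diff, and the numbers are only illustrative):

```cpp
#include <cassert>
#include <cstdint>
#include <cstdlib>

// Assumed positive-only formulation of ceilDiv, for illustration only.
int64_t ceilDivPositive(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

int main() {
  // Fine for non-negative operands: ceil(7 / 2) == 4.
  assert(ceilDivPositive(7, 2) == 4);

  // arange(10, 0, -2) has 5 elements (10, 8, 6, 4, 2), but feeding the raw
  // difference and step gives (-10 + -2 - 1) / -2 == 6 with C++ truncation.
  assert(ceilDivPositive(0 - 10, -2) == 6);

  // Taking the absolute value of both operands, as in the change above,
  // recovers the expected count.
  assert(ceilDivPositive(std::abs(0 - 10), std::abs(-2)) == 5);
  return 0;
}
```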

Collaborator
I think we do need to check it on the host side, but I am not sure how. Would it be a good idea to have both a BinaryOpType::CeilDiv and a BinaryOpType::CeilDivMaybeNegative, using CeilDivMaybeNegative for arange and CeilDiv for everything else?

Collaborator Author (@naoyam, Aug 29, 2022)
I don't know if we would need additional ceilDiv expr types. It should be possible to determine whether a ceilDiv is guaranteed to be safe by looking at the Expr nodes, and we could do host-side checks only for those ceilDiv exprs that cannot be proven safe.

That said, I don't think it's important at this point. For now, it seems reasonable for nvFuser to work correctly when the given program is valid and to have undefined behavior when an unsound program is given. We would definitely want to leave a note about the lack of the check.

Collaborator Author
But if it should be validated anyway, I think we would want to extend KernelIrScanner to mark potentially unsafe ceilDiv exprs. As I mentioned above, we should be able to just analyze the input vals of those exprs and see if they are composed only of non-negative values such as domain extents.

Collaborator
A note is fine for now. We can revisit in the future if needed.
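For what it's worth, a hypothetical host-side guard along the lines discussed above could be as simple as the sketch below (validateArangeArgs is an invented name, not part of this PR; only the TORCH_CHECK macro is an existing PyTorch facility):

```cpp
#include <c10/util/Exception.h>

// Hypothetical pre-launch check, not part of this change: reject start/end/step
// combinations whose signs disagree, so the positive-only device ceilDiv never
// sees a negative operand.
void validateArangeArgs(double start, double end, double step) {
  TORCH_CHECK(step != 0, "arange: step must be non-zero");
  TORCH_CHECK(
      (end - start) * step >= 0,
      "arange: the signs of (end - start) and step must match");
}
```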

auto out = TensorViewBuilder()
.ndims(1)
.dtype(dtype)
4 changes: 4 additions & 0 deletions torch/csrc/jit/codegen/cuda/arith.h
@@ -125,6 +125,10 @@ TORCH_CUDA_CU_API WelfordResult Welford(
TORCH_CUDA_CU_API TensorView* rand(
const std::vector<Val*>& shape,
DataType dtype);

//! WARNING: giving invalid combinations of the start, end and step
//! arguments can result in undefined behavior. Specifically, the
//! signs of `end - start` and step must be the same.
TORCH_CUDA_CU_API TensorView* arange(Val* end, DataType dtype = DataType::Int);
TORCH_CUDA_CU_API TensorView* arange(
Val* start,
8 changes: 8 additions & 0 deletions torch/csrc/jit/codegen/cuda/dynamic_type.h
@@ -296,6 +296,14 @@ inline IntOrDouble min(const IntOrDouble& a, const IntOrDouble& b) {
return (a < b ? a : b).cast<double>();
}

inline IntOrDouble abs(const IntOrDouble& a) {
if (a.is_int()) {
return IntOrDouble(std::abs(a.as<int64_t>()));
} else {
return IntOrDouble(std::abs(a.as<double>()));
}
}

} // namespace IntOrDouble_functions

} // namespace cuda
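As a quick illustration of the new helper, the sketch below relies only on the IntOrDouble members visible in this diff (construction from int64_t/double, is_int(), as<T>()); the include path and namespace qualification are assumptions:

```cpp
#include <cassert>
#include <cstdint>

#include <torch/csrc/jit/codegen/cuda/dynamic_type.h>

// Namespaces assumed from the file's location; illustration only.
using namespace torch::jit::fuser::cuda;
using namespace torch::jit::fuser::cuda::IntOrDouble_functions;

int main() {
  IntOrDouble i(int64_t{-3});
  IntOrDouble d(-2.5);

  // abs() dispatches on the stored type, so the result keeps the input's
  // int/double flavor.
  assert(abs(i).is_int() && abs(i).as<int64_t>() == 3);
  assert(!abs(d).is_int() && abs(d).as<double>() == 2.5);
  return 0;
}
```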
4 changes: 4 additions & 0 deletions torch/csrc/jit/codegen/cuda/evaluator_common.cpp
@@ -295,6 +295,7 @@ void NaiveValueMachine<IRContext>::runInstruction(int index) {

template <typename IRContext>
void NaiveValueMachine<IRContext>::runUnaryOp(int index) {
using namespace IntOrDouble_functions;
int src_index = src0_[index];
bool src_defined = precomputed_values_.defined_[src_index];
bool src_is_const = precomputed_values_.is_constant_[src_index];
@@ -323,6 +324,9 @@ void NaiveValueMachine<IRContext>::runUnaryOp(int index) {
TORCH_INTERNAL_ASSERT(false, "dtype not supported in evaluator");
}
break;
case UnaryOpType::Abs:
dest = abs(src);
break;
default:
TORCH_CHECK(!"Unexpected operator type ", uop_type_[index]);
}
4 changes: 4 additions & 0 deletions torch/csrc/jit/codegen/cuda/expr_evaluator.cpp
@@ -105,6 +105,7 @@ c10::optional<IntOrDouble> ExpressionEvaluator::getValue(Val* value) {
}

void ExpressionEvaluator::handle(UnaryOp* uop) {
using namespace IntOrDouble_functions;
const auto in = evaluate(uop->in());
if (in.has_value()) {
switch (uop->getUnaryOpType()) {
@@ -123,6 +124,9 @@ void ExpressionEvaluator::handle(UnaryOp* uop) {
TORCH_INTERNAL_ASSERT(false, "dtype not supported in evaluator");
}
break;
case UnaryOpType::Abs:
known_values_[uop->out()] = abs(*in);
break;
default:
TORCH_CHECK(
!"Unexpected operator type ",
4 changes: 4 additions & 0 deletions torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.cpp
@@ -132,6 +132,7 @@ void ExpressionEvaluator::handle(const NamedScalar* named_scalar) {
}

void ExpressionEvaluator::handle(const UnaryOp* unary_op) {
using namespace IntOrDouble_functions;
const auto in = evaluate(unary_op->in());
if (in.has_value()) {
switch (unary_op->getUnaryOpType()) {
@@ -150,6 +151,9 @@ void ExpressionEvaluator::handle(const UnaryOp* unary_op) {
TORCH_INTERNAL_ASSERT(false, "dtype not supported in evaluator");
}
break;
case UnaryOpType::Abs:
known_values_[unary_op->out()] = abs(*in);
break;
default:
TORCH_CHECK(
false,