Skip to content

Commit 56c00fd

Browse files
authored
Double support on all expression evaluators (#1937)
1 parent 371f282 commit 56c00fd

29 files changed: +364 −254 lines changed

torch/csrc/jit/codegen/cuda/codegen.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -597,7 +597,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
597597
vector_size_optional.has_value(),
598598
"Could not evaluate constant value bound to vectorized dim.");
599599

600-
vector_word_size = vector_size_optional.value();
600+
vector_word_size = vector_size_optional->as<int64_t>();
601601

602602
vectorize_op = id->getParallelType() == ParallelType::Vectorize;
603603
misaligned_op =
@@ -1267,7 +1267,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
12671267
TORCH_INTERNAL_ASSERT(
12681268
id->getParallelType() != ParallelType::MisalignedVectorize,
12691269
"LoadStoreOp: no support yet for mis-aligned vectorization");
1270-
vector_word_size = vector_size_optional.value();
1270+
vector_word_size = vector_size_optional->as<int64_t>();
12711271
vectorize_op = true;
12721272
break;
12731273
}

torch/csrc/jit/codegen/cuda/dynamic_type.h

Lines changed: 26 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,8 @@ class TORCH_CUDA_CU_API IntOrDouble {
4040

4141
template <typename T>
4242
T as() const {
43-
TORCH_CHECK(c10::holds_alternative<T>(value_), "wrong type");
43+
TORCH_CHECK(
44+
c10::holds_alternative<T>(value_), "dtype not supported in evaluator");
4445
return c10::get<T>(value_);
4546
}
4647

@@ -145,8 +146,19 @@ class TORCH_CUDA_CU_API IntOrDouble {
145146
} \
146147
TORCH_INTERNAL_ASSERT(false); \
147148
} \
148-
template <typename T> \
149-
bool operator op(T other) { \
149+
bool operator op(double other) { \
150+
if (is_int()) { \
151+
return as<int64_t>() op other; \
152+
} \
153+
return as<double>() op other; \
154+
} \
155+
bool operator op(int64_t other) { \
156+
if (is_int()) { \
157+
return as<int64_t>() op other; \
158+
} \
159+
return as<double>() op other; \
160+
} \
161+
bool operator op(int other) { \
150162
if (is_int()) { \
151163
return as<int64_t>() op other; \
152164
} \
@@ -169,21 +181,10 @@ class TORCH_CUDA_CU_API IntOrDouble {
169181
return IntOrDouble(-as<double>());
170182
}
171183

172-
template <typename T>
173-
bool operator==(T val) const {
174-
return operator==(IntOrDouble(val));
175-
}
176-
177-
template <typename T>
178-
bool operator!=(T val) const {
179-
return operator!=(IntOrDouble(val));
180-
}
181-
182-
operator double() const;
183-
184-
operator int64_t() const;
185-
operator size_t() const;
186-
operator int() const;
184+
explicit operator double() const;
185+
explicit operator int64_t() const;
186+
explicit operator size_t() const;
187+
explicit operator int() const;
187188
};
188189

189190
#define DEFINE_ARITHMETIC_OP(op) \
@@ -269,7 +270,13 @@ namespace IntOrDouble_functions {
269270

270271
inline IntOrDouble ceildiv(const IntOrDouble& a, const IntOrDouble& b) {
271272
if (a.is_int() && b.is_int()) {
272-
return (a.as<int64_t>() + b.as<int64_t>() - 1) / b.as<int64_t>();
273+
auto aa = a.as<int64_t>();
274+
auto bb = b.as<int64_t>();
275+
if (bb > 0) {
276+
return (aa + bb - 1) / bb;
277+
} else {
278+
return (aa + bb + 1) / bb;
279+
}
273280
}
274281
return std::ceil((a / b).as<double>());
275282
}

0 commit comments

Comments
 (0)