@@ -40,7 +40,8 @@ class TORCH_CUDA_CU_API IntOrDouble {
40
40
41
41
template <typename T>
42
42
T as () const {
43
- TORCH_CHECK (c10::holds_alternative<T>(value_), " wrong type" );
43
+ TORCH_CHECK (
44
+ c10::holds_alternative<T>(value_), " dtype not supported in evaluator" );
44
45
return c10::get<T>(value_);
45
46
}
46
47
@@ -145,8 +146,19 @@ class TORCH_CUDA_CU_API IntOrDouble {
145
146
} \
146
147
TORCH_INTERNAL_ASSERT (false ); \
147
148
} \
148
- template <typename T> \
149
- bool operator op (T other) { \
149
+ bool operator op (double other) { \
150
+ if (is_int ()) { \
151
+ return as<int64_t >() op other; \
152
+ } \
153
+ return as<double >() op other; \
154
+ } \
155
+ bool operator op (int64_t other) { \
156
+ if (is_int ()) { \
157
+ return as<int64_t >() op other; \
158
+ } \
159
+ return as<double >() op other; \
160
+ } \
161
+ bool operator op (int other) { \
150
162
if (is_int ()) { \
151
163
return as<int64_t >() op other; \
152
164
} \
@@ -169,21 +181,10 @@ class TORCH_CUDA_CU_API IntOrDouble {
169
181
return IntOrDouble (-as<double >());
170
182
}
171
183
172
- template <typename T>
173
- bool operator ==(T val) const {
174
- return operator ==(IntOrDouble (val));
175
- }
176
-
177
- template <typename T>
178
- bool operator !=(T val) const {
179
- return operator !=(IntOrDouble (val));
180
- }
181
-
182
- operator double () const ;
183
-
184
- operator int64_t () const ;
185
- operator size_t () const ;
186
- operator int () const ;
184
+ explicit operator double () const ;
185
+ explicit operator int64_t () const ;
186
+ explicit operator size_t () const ;
187
+ explicit operator int () const ;
187
188
};
188
189
189
190
#define DEFINE_ARITHMETIC_OP (op ) \
@@ -269,7 +270,13 @@ namespace IntOrDouble_functions {
269
270
270
271
inline IntOrDouble ceildiv (const IntOrDouble& a, const IntOrDouble& b) {
271
272
if (a.is_int () && b.is_int ()) {
272
- return (a.as <int64_t >() + b.as <int64_t >() - 1 ) / b.as <int64_t >();
273
+ auto aa = a.as <int64_t >();
274
+ auto bb = b.as <int64_t >();
275
+ if (bb > 0 ) {
276
+ return (aa + bb - 1 ) / bb;
277
+ } else {
278
+ return (aa + bb + 1 ) / bb;
279
+ }
273
280
}
274
281
return std::ceil ((a / b).as <double >());
275
282
}
0 commit comments