Skip to content

Align AT_FORALL macros with AT_DISPATCH macros. #23339

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 67 additions & 113 deletions aten/src/ATen/Dispatch.h

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions aten/src/ATen/native/TensorFactories.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ Tensor& empty_out(
return self.to(ScalarType::n, non_blocking); \
}

AT_FORALL_SCALAR_TYPES_AND_BOOL(DEFINE_CAST_OP)
AT_FORALL_SCALAR_TYPES_AND2(Bool, BFloat16, DEFINE_CAST_OP)

#undef DEFINE_CAST_OP

Expand Down Expand Up @@ -806,7 +806,7 @@ Tensor tensor_cuda(ArrayRef<T> values, const TensorOptions& options) {
return tensor_cpu(values, options); \
} \
}
AT_FORALL_SCALAR_TYPES_AND_BOOL(TENSOR)
AT_FORALL_SCALAR_TYPES_AND2(Bool, BFloat16, TENSOR)
#undef TENSOR

Tensor from_file(std::string filename, c10::optional<bool> shared, c10::optional<int64_t> size, const TensorOptions& options) {
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/templates/NativeFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ namespace native {
inline Tensor tensor(T value) { \
return native::tensor(ArrayRef<T>(value)); \
}
AT_FORALL_SCALAR_TYPES_AND_BOOL(TENSOR)
AT_FORALL_SCALAR_TYPES_AND2(Bool, BFloat16, TENSOR)
#undef TENSOR

${native_function_declarations}
Expand Down
10 changes: 5 additions & 5 deletions aten/src/TH/THBlasUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ inline void THBlas_axpy(int64_t n, T a, T *x, int64_t incx, T *y, int64_t incy);
TH ## name ## Blas_axpy(n, a, x, incx, y, incy); \
}

AT_FORALL_SCALAR_TYPES_EXCEPT_HALF_AND_BFLOAT16(AXPY_SPECIALIZATION)
AT_FORALL_SCALAR_TYPES(AXPY_SPECIALIZATION)


template<typename T>
Expand All @@ -29,7 +29,7 @@ inline void THBlas_copy(int64_t n, T *x, int64_t incx, T *y, int64_t incy);
TH ## name ## Blas_copy(n, x, incx, y, incy); \
}

AT_FORALL_SCALAR_TYPES_EXCEPT_HALF_AND_BFLOAT16(COPY_SPECIALIZATION)
AT_FORALL_SCALAR_TYPES(COPY_SPECIALIZATION)

template<typename T>
inline T THBlas_dot(int64_t n, T *x, int64_t incx, T *y, int64_t incy);
Expand All @@ -40,7 +40,7 @@ inline T THBlas_dot(int64_t n, T *x, int64_t incx, T *y, int64_t incy);
return TH ## name ## Blas_dot(n, x, incx, y, incy); \
}

AT_FORALL_SCALAR_TYPES_EXCEPT_HALF_AND_BFLOAT16(DOT_SPECIALIZATION)
AT_FORALL_SCALAR_TYPES(DOT_SPECIALIZATION)

template<typename T>
inline void THBlas_gemm(
Expand Down Expand Up @@ -78,7 +78,7 @@ inline void THBlas_gemm(
transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); \
}

AT_FORALL_SCALAR_TYPES_EXCEPT_HALF_AND_BFLOAT16(GEMM_SPECIALIZATION)
AT_FORALL_SCALAR_TYPES(GEMM_SPECIALIZATION)

template <typename T>
inline void THBlas_gemv(
Expand Down Expand Up @@ -111,4 +111,4 @@ inline void THBlas_gemv(
TH ## name ## Blas_gemv(transa, m, n, alpha, a, lda, x, incx, beta, y, incy); \
}

AT_FORALL_SCALAR_TYPES_EXCEPT_HALF_AND_BFLOAT16(GEMV_SPECIALIZATION)
AT_FORALL_SCALAR_TYPES(GEMV_SPECIALIZATION)
28 changes: 20 additions & 8 deletions c10/core/Scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,15 @@ namespace c10 {
* Scalar (which is why, for example, we provide both add(Tensor) and
* add(Scalar) overloads for many operations). It may also be used in
* circumstances where you statically know a tensor is 0-dim and single size,
* but don't know it's type.
* but don't know its type.
*/
class C10_API Scalar {
public:
Scalar() : Scalar(int64_t(0)) {}

#define DEFINE_IMPLICIT_CTOR(type, name, member) \
Scalar(type vv) : tag(Tag::HAS_##member) { \
v.member = convert<decltype(v.member), type>(vv); \
}
// We can't set v in the initializer list using the
// syntax v{ .member = ... } because it doesn't work on MSVC

AT_FORALL_SCALAR_TYPES(DEFINE_IMPLICIT_CTOR)
Scalar(type vv) : Scalar(vv, true) { }
AT_FORALL_SCALAR_TYPES_AND(c10::ScalarType::BFloat16, DEFINE_IMPLICIT_CTOR)

#undef DEFINE_IMPLICIT_CTOR

Expand Down Expand Up @@ -92,6 +87,23 @@ class C10_API Scalar {
Scalar operator-() const;

private:
template<typename T,
typename std::enable_if<std::numeric_limits<T>::is_integer, bool>::type* =
nullptr>
Scalar(T vv, bool) : tag(Tag::HAS_i) {
v.i = convert<decltype(v.i), T>(vv);
}

template<typename T,
typename std::enable_if<!std::numeric_limits<T>::is_integer, bool>::type* =
nullptr>
Scalar(T vv, bool) : tag(Tag::HAS_d) {
v.d = convert<decltype(v.d), T>(vv);
}

// We can't set v in the initializer list using the
// syntax v{ .member = ... } because it doesn't work on MSVC

enum class Tag { HAS_d, HAS_i, HAS_z };
Tag tag;
union {
Expand Down
189 changes: 124 additions & 65 deletions c10/core/ScalarType.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,80 +11,53 @@

namespace c10 {

// TODO: check all usages of these macros and make sure
// the use case makes sense for qint

// NB: Order matters for this macro; it is relied upon in
// _promoteTypesLookup and the serialization format.
// For the macros below:
// NB: QInt ScalarTypes are referred to as "STUBS" here since they do not
// contain complete information to determine the tensor value of the data,
// they are just stubs for dispatch / quantization.
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_STUBS(_) \
_(uint8_t, Byte, i) /* 0 */ \
_(int8_t, Char, i) /* 1 */ \
_(int16_t, Short, i) /* 2 */ \
_(int, Int, i) /* 3 */ \
_(int64_t, Long, i) /* 4 */ \
_(at::Half, Half, d) /* 5 */ \
_(float, Float, d) /* 6 */ \
_(double, Double, d) /* 7 */ \
_(at::ComplexHalf, ComplexHalf, z) /* 8 */ \
_(std::complex<float>, ComplexFloat, z) /* 9 */ \
_(std::complex<double>, ComplexDouble, z) /* 10 */ \
_(bool, Bool, i) /* 11 */ \
_(c10::qint8, QInt8, i) /* 12 */ \
_(c10::quint8, QUInt8, i) /* 13 */ \
_(c10::qint32, QInt32, i) /* 14 */ \
_(at::BFloat16, BFloat16, d) /* 15 */
// NB: If you want to macro some code for all non-stub scalar types, you
// probably want one of the AT_FORALL_SCALAR_TYPES / AT_FORALL_SCALAR_TYPES_AND
// macros below, which are designed to behave similarly to the Dispatch macros
// with the same name.

// NB: Order matters for this macro; it is relied upon in
// _promoteTypesLookup and the serialization format.
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_STUBS(_) \
_(uint8_t, Byte, __) /* 0 */ \
_(int8_t, Char, __) /* 1 */ \
_(int16_t, Short, __) /* 2 */ \
_(int, Int, __) /* 3 */ \
_(int64_t, Long, __) /* 4 */ \
_(at::Half, Half, __) /* 5 */ \
_(float, Float, __) /* 6 */ \
_(double, Double, __) /* 7 */ \
_(at::ComplexHalf, ComplexHalf, __) /* 8 */ \
_(std::complex<float>, ComplexFloat, __) /* 9 */ \
_(std::complex<double>, ComplexDouble, __) /* 10 */ \
_(bool, Bool, __) /* 11 */ \
_(c10::qint8, QInt8, __) /* 12 */ \
_(c10::quint8, QUInt8, __) /* 13 */ \
_(c10::qint32, QInt32, __) /* 14 */ \
_(at::BFloat16, BFloat16, __) /* 15 */


// If you want to support ComplexHalf for real, add ComplexHalf
// into this macro (and change the name). But beware: convert()
// doesn't work for all the conversions you need...
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(_) \
_(uint8_t, Byte, i) \
_(int8_t, Char, i) \
_(int16_t, Short, i) \
_(int, Int, i) \
_(int64_t, Long, i) \
_(at::Half, Half, d) \
_(float, Float, d) \
_(double, Double, d) \
_(std::complex<float>, ComplexFloat, z) \
_(std::complex<double>, ComplexDouble, z) \
_(bool, Bool, i) \
_(at::BFloat16, BFloat16, d)
_(uint8_t, Byte, __) \
_(int8_t, Char, __) \
_(int16_t, Short, __) \
_(int, Int, __) \
_(int64_t, Long, __) \
_(at::Half, Half, __) \
_(float, Float, __) \
_(double, Double, __) \
_(std::complex<float>, ComplexFloat, __) \
_(std::complex<double>, ComplexDouble, __) \
_(bool, Bool, __) \
_(at::BFloat16, BFloat16, __)

#define AT_FORALL_SCALAR_TYPES(_) \
_(uint8_t, Byte, i) \
_(int8_t, Char, i) \
_(int16_t, Short, i) \
_(int, Int, i) \
_(int64_t, Long, i) \
_(at::Half, Half, d) \
_(float, Float, d) \
_(double, Double, d) \
_(at::BFloat16, BFloat16, d)

#define AT_FORALL_SCALAR_TYPES_AND_BOOL(_) \
_(uint8_t, Byte, i) \
_(int8_t, Char, i) \
_(int16_t, Short, i) \
_(int, Int, i) \
_(int64_t, Long, i) \
_(at::Half, Half, d) \
_(float, Float, d) \
_(double, Double, d) \
_(bool, Bool, i) \
_(at::BFloat16, BFloat16, d)

#define AT_FORALL_SCALAR_TYPES_EXCEPT_HALF_AND_BFLOAT16(_) \
_(uint8_t, Byte, i) \
_(int8_t, Char, i) \
_(int16_t, Short, i) \
_(int, Int, i) \
_(int64_t, Long, i) \
_(float, Float, d) \
_(double, Double, d)

#define AT_FORALL_QINT_TYPES(_) \
_(c10::qint8, QInt8, i) \
Expand All @@ -99,6 +72,92 @@ enum class ScalarType : int8_t {
NumOptions
};

namespace impl {

// These are used to map ScalarTypes to C++ types. Feel free to add more or even
// macro generate this; the examples here are just those we have found to be
// necessary.

template <c10::ScalarType N>
struct ScalarTypeToCPPType;

template<>
struct ScalarTypeToCPPType<c10::ScalarType::Half> {
using type = c10::Half;

// This is a workaround for the CUDA bug which prevents ::detail::ScalarTypeToCType<T>::type being used directly
// due to ambiguous reference which can't be resolved. For some reason it can't pick between at::detail and at::cuda::detail.
// For repro example, please see: https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba
// TODO: remove once the bug is fixed.
static type t;
};

template<>
struct ScalarTypeToCPPType<c10::ScalarType::BFloat16> {
using type = c10::BFloat16;

// This is a workaround for the CUDA bug which prevents ::detail::ScalarTypeToCType<T>::type being used directly
// due to ambiguous reference which can't be resolved. For some reason it can't pick between at::detail and at::cuda::detail.
// For repro example, please see: https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba
// TODO: remove once the bug is fixed.
static type t;
};

template<>
struct ScalarTypeToCPPType<c10::ScalarType::Bool> {
using type = bool;

// This is a workaround for the CUDA bug which prevents ::detail::ScalarTypeToCType<T>::type being used directly
// due to ambiguous reference which can't be resolved. For some reason it can't pick between at::detail and at::cuda::detail.
// For repro example, please see: https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba
// TODO: remove once the bug is fixed.
static type t;
};

template<>
struct ScalarTypeToCPPType<c10::ScalarType::Long> {
using type = int64_t;

// This is a workaround for the CUDA bug which prevents ::detail::ScalarTypeToCType<T>::type being used directly
// due to ambiguous reference which can't be resolved. For some reason it can't pick between at::detail and at::cuda::detail.
// For repro example, please see: https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba
// TODO: remove once the bug is fixed.
static type t;
};
}

#define AT_FORALL_SCALAR_TYPES(_) \
_(uint8_t, Byte, __) \
_(int8_t, Char, __) \
_(int16_t, Short, __) \
_(int, Int, __) \
_(int64_t, Long, __) \
_(float, Float, __) \
_(double, Double, __)

#define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \
_(uint8_t, Byte, __) \
_(int8_t, Char, __) \
_(int16_t, Short, __) \
_(int, Int, __) \
_(int64_t, Long, __) \
_(at::Half, Half, __) \
_(float, Float, __) \
_(double, Double, __) \
_(decltype(::c10::impl::ScalarTypeToCPPType<SCALARTYPE>::t), SCALARTYPE, __)

#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \
_(uint8_t, Byte, __) \
_(int8_t, Char, __) \
_(int16_t, Short, __) \
_(int, Int, __) \
_(int64_t, Long, __) \
_(at::Half, Half, __) \
_(float, Float, __) \
_(double, Double, __) \
_(decltype(::c10::impl::ScalarTypeToCPPType<c10::ScalarType::SCALARTYPE1>::t), SCALARTYPE1, __) \
_(decltype(::c10::impl::ScalarTypeToCPPType<c10::ScalarType::SCALARTYPE2>::t), SCALARTYPE2, __)

static inline caffe2::TypeMeta scalarTypeToTypeMeta(ScalarType scalar_type) {
#define DEFINE_CASE(ctype, name, _) \
case ScalarType::name: \
Expand Down
1 change: 1 addition & 0 deletions c10/util/BFloat16-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ template <>
class numeric_limits<c10::BFloat16> {
public:
static constexpr bool is_signed = true;
static constexpr bool is_integer = false;
static constexpr bool has_infinity = true;
static constexpr bool has_quiet_NaN = true;
static constexpr c10::BFloat16 lowest() {
Expand Down
31 changes: 16 additions & 15 deletions caffe2/contrib/aten/aten_op_template.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ static std::unordered_map<std::string, int> op_to_key = {

namespace caffe2 {

using at::Half; // for AT_FORALL_SCALAR_TYPES_AND_BOOL
using at::Half; // for AT_FORALL_SCALAR_TYPES_AND2(Bool, BFloat16, ...)

template <class Context>
class ATenOp : public Operator<Context> {
Expand Down Expand Up @@ -47,7 +47,7 @@ class ATenOp : public Operator<Context> {
case at::k##aten_name: \
return TypeMeta::Make<ctype>();
switch(st) {
AT_FORALL_SCALAR_TYPES_AND_BOOL(DEFINE_CASE)
AT_FORALL_SCALAR_TYPES_AND2(Bool, BFloat16, DEFINE_CASE)
default:
CAFFE_THROW("Unknown ATen Type");
}
Expand Down Expand Up @@ -118,27 +118,28 @@ class ATenOp : public Operator<Context> {
}
}

// the AT_FORALL_SCALAR_TYPES_AND_BOOL macro just gives a 'i' or
// 'd' argument for each type to specify if it is stored as a integer or a
// double. We need this workaround here to extract the value in the scalar
// losslessly because in some cases like 'sum' Torch promotes float to double
// and will complain if we downcast it with toFloat, causing it
// to lose precision
double extract_d(const at::Scalar & s) {
return s.toDouble();
}
int64_t extract_i(const at::Scalar & s) {
template<typename T,
typename std::enable_if<std::numeric_limits<T>::is_integer, bool>::type* =
nullptr>
int64_t extract(const at::Scalar &s) {
return s.toLong();
}

template<typename T,
typename std::enable_if<!std::numeric_limits<T>::is_integer, bool>::type* =
nullptr>
int64_t extract(const at::Scalar &s) {
return s.toDouble();
}

void assignTo(Tensor* dst, at::ScalarType scalar_type, at::Scalar scalar) {
switch(scalar_type) {
#define DEFINE_CASE(ctype,aten_name,native) \
#define DEFINE_CASE(ctype,aten_name,_1) \
case at::k##aten_name: { \
auto value = extract_##native(scalar); \
auto value = extract<ctype>(scalar); \
assignToValue<ctype>(dst, at::convert<ctype,decltype(value)>(value)); \
} break;
AT_FORALL_SCALAR_TYPES_AND_BOOL(DEFINE_CASE)
AT_FORALL_SCALAR_TYPES_AND2(Bool, BFloat16, DEFINE_CASE)
#undef DEFINE_CASE
default:
CAFFE_THROW("Unknown ATen Type");
Expand Down
Loading