diff --git a/aten/src/ATen/Dispatch.h b/aten/src/ATen/Dispatch.h index f30a193d4af324..1a227f2136ad2e 100644 --- a/aten/src/ATen/Dispatch.h +++ b/aten/src/ATen/Dispatch.h @@ -22,53 +22,6 @@ namespace detail { -template -struct ScalarTypeToCType; - -template<> -struct ScalarTypeToCType { - using type = at::Half; - - // This is a workaround for the CUDA bug which prevents ::detail::ScalarTypeToCType::type being used directly - // due to ambiguous reference which can't to be resolved. For some reason it cant pick between at::detail and at::cuda::detail. - // For repro example, please see: https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba - // TODO: remove once the bug is fixed. - static type t; -}; - -template<> -struct ScalarTypeToCType { - using type = at::BFloat16; - - // This is a workaround for the CUDA bug which prevents ::detail::ScalarTypeToCType::type being used directly - // due to ambiguous reference which can't to be resolved. For some reason it cant pick between at::detail and at::cuda::detail. - // For repro example, please see: https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba - // TODO: remove once the bug is fixed. - static type t; -}; - -template<> -struct ScalarTypeToCType { - using type = bool; - - // This is a workaround for the CUDA bug which prevents ::detail::ScalarTypeToCType::type being used directly - // due to ambiguous reference which can't to be resolved. For some reason it cant pick between at::detail and at::cuda::detail. - // For repro example, please see: https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba - // TODO: remove once the bug is fixed. - static type t; -}; - -template<> -struct ScalarTypeToCType { - using type = int64_t; - - // This is a workaround for the CUDA bug which prevents ::detail::ScalarTypeToCType::type being used directly - // due to ambiguous reference which can't to be resolved. For some reason it cant pick between at::detail and at::cuda::detail. 
- // For repro example, please see: https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba - // TODO: remove once the bug is fixed. - static int64_t t; -}; - inline at::ScalarType scalar_type(at::ScalarType s) { return s; } @@ -178,7 +131,8 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} switch (_st) { \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(SCALARTYPE, decltype(::detail::ScalarTypeToCType::t), __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(SCALARTYPE, \ + decltype(c10::impl::ScalarTypeToCPPType::t), __VA_ARGS__) \ default: \ AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ } \ @@ -290,77 +244,77 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} } \ }() -#define AT_DISPATCH_ALL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \ - [&] { \ - switch (TYPE) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(SCALARTYPE, decltype(::detail::ScalarTypeToCType::t), __VA_ARGS__) \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } \ +#define AT_DISPATCH_ALL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) 
\ + [&] { \ + switch (TYPE) { \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(SCALARTYPE, decltype(c10::impl::ScalarTypeToCPPType::t), __VA_ARGS__) \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } \ }() -#define AT_DISPATCH_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \ - [&] { \ - switch (TYPE) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(SCALARTYPE1, decltype(::detail::ScalarTypeToCType::t), __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(SCALARTYPE2, decltype(::detail::ScalarTypeToCType::t), __VA_ARGS__) \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } \ +#define AT_DISPATCH_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) 
\ + [&] { \ + switch (TYPE) { \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(SCALARTYPE1, decltype(c10::impl::ScalarTypeToCPPType::t), __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(SCALARTYPE2, decltype(c10::impl::ScalarTypeToCPPType::t), __VA_ARGS__) \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } \ }() -#define AT_DISPATCH_ALL_TYPES_AND3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \ - [&] { \ - switch (TYPE) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(SCALARTYPE1, decltype(::detail::ScalarTypeToCType::t), __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(SCALARTYPE2, decltype(::detail::ScalarTypeToCType::t), __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(SCALARTYPE3, decltype(::detail::ScalarTypeToCType::t), __VA_ARGS__) \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } \ +#define AT_DISPATCH_ALL_TYPES_AND3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) 
\ + [&] { \ + switch (TYPE) { \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(SCALARTYPE1, decltype(c10::impl::ScalarTypeToCPPType::t), __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(SCALARTYPE2, decltype(c10::impl::ScalarTypeToCPPType::t), __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(SCALARTYPE3, decltype(c10::impl::ScalarTypeToCPPType::t), __VA_ARGS__) \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } \ }() -#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \ - [&] { \ - switch (TYPE) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(SCALARTYPE1, decltype(::detail::ScalarTypeToCType::t), __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(SCALARTYPE2, decltype(::detail::ScalarTypeToCType::t), __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(SCALARTYPE3, decltype(::detail::ScalarTypeToCType::t), __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexFloat, std::complex, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexDouble, std::complex, __VA_ARGS__) \ - default: \ - AT_ERROR(#NAME, " 
not implemented for '", TYPE, "'"); \ - } \ +#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \ + [&] { \ + switch (TYPE) { \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(SCALARTYPE1, decltype(c10::impl::ScalarTypeToCPPType::t), __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(SCALARTYPE2, decltype(c10::impl::ScalarTypeToCPPType::t), __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(SCALARTYPE3, decltype(c10::impl::ScalarTypeToCPPType::t), __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE( \ + at::ScalarType::ComplexFloat, std::complex, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE( \ + at::ScalarType::ComplexDouble, std::complex, __VA_ARGS__) \ + default: \ + AT_ERROR(#NAME, " not implemented for '", TYPE, "'"); \ + } \ }() // ---------------------------------------------------------------------------- diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp index fec29519cc5c7f..a9eb1cf1e44eb7 100644 --- a/aten/src/ATen/native/TensorFactories.cpp +++ b/aten/src/ATen/native/TensorFactories.cpp @@ -169,7 +169,7 @@ Tensor& empty_out( return self.to(ScalarType::n, non_blocking); \ } -AT_FORALL_SCALAR_TYPES_AND_BOOL(DEFINE_CAST_OP) +AT_FORALL_SCALAR_TYPES_AND2(Bool, BFloat16, DEFINE_CAST_OP) #undef DEFINE_CAST_OP @@ -806,7 +806,7 @@ Tensor tensor_cuda(ArrayRef values, const TensorOptions& options) { return tensor_cpu(values, options); \ } \ } -AT_FORALL_SCALAR_TYPES_AND_BOOL(TENSOR) +AT_FORALL_SCALAR_TYPES_AND2(Bool, BFloat16, TENSOR) #undef TENSOR Tensor 
from_file(std::string filename, c10::optional shared, c10::optional size, const TensorOptions& options) { diff --git a/aten/src/ATen/templates/NativeFunctions.h b/aten/src/ATen/templates/NativeFunctions.h index 9b5dc1f38175ae..50ad60384c5b99 100644 --- a/aten/src/ATen/templates/NativeFunctions.h +++ b/aten/src/ATen/templates/NativeFunctions.h @@ -44,7 +44,7 @@ namespace native { inline Tensor tensor(T value) { \ return native::tensor(ArrayRef(value)); \ } -AT_FORALL_SCALAR_TYPES_AND_BOOL(TENSOR) +AT_FORALL_SCALAR_TYPES_AND2(Bool, BFloat16, TENSOR) #undef TENSOR ${native_function_declarations} diff --git a/aten/src/TH/THBlasUtils.h b/aten/src/TH/THBlasUtils.h index 10b17a3851428a..47b74f85ac8248 100644 --- a/aten/src/TH/THBlasUtils.h +++ b/aten/src/TH/THBlasUtils.h @@ -16,7 +16,7 @@ inline void THBlas_axpy(int64_t n, T a, T *x, int64_t incx, T *y, int64_t incy); TH ## name ## Blas_axpy(n, a, x, incx, y, incy); \ } -AT_FORALL_SCALAR_TYPES_EXCEPT_HALF_AND_BFLOAT16(AXPY_SPECIALIZATION) +AT_FORALL_SCALAR_TYPES(AXPY_SPECIALIZATION) template @@ -29,7 +29,7 @@ inline void THBlas_copy(int64_t n, T *x, int64_t incx, T *y, int64_t incy); TH ## name ## Blas_copy(n, x, incx, y, incy); \ } -AT_FORALL_SCALAR_TYPES_EXCEPT_HALF_AND_BFLOAT16(COPY_SPECIALIZATION) +AT_FORALL_SCALAR_TYPES(COPY_SPECIALIZATION) template inline T THBlas_dot(int64_t n, T *x, int64_t incx, T *y, int64_t incy); @@ -40,7 +40,7 @@ inline T THBlas_dot(int64_t n, T *x, int64_t incx, T *y, int64_t incy); return TH ## name ## Blas_dot(n, x, incx, y, incy); \ } -AT_FORALL_SCALAR_TYPES_EXCEPT_HALF_AND_BFLOAT16(DOT_SPECIALIZATION) +AT_FORALL_SCALAR_TYPES(DOT_SPECIALIZATION) template inline void THBlas_gemm( @@ -78,7 +78,7 @@ inline void THBlas_gemm( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); \ } -AT_FORALL_SCALAR_TYPES_EXCEPT_HALF_AND_BFLOAT16(GEMM_SPECIALIZATION) +AT_FORALL_SCALAR_TYPES(GEMM_SPECIALIZATION) template inline void THBlas_gemv( @@ -111,4 +111,4 @@ inline void THBlas_gemv( TH ## name 
## Blas_gemv(transa, m, n, alpha, a, lda, x, incx, beta, y, incy); \ } - AT_FORALL_SCALAR_TYPES_EXCEPT_HALF_AND_BFLOAT16(GEMV_SPECIALIZATION) + AT_FORALL_SCALAR_TYPES(GEMV_SPECIALIZATION) diff --git a/c10/core/Scalar.h b/c10/core/Scalar.h index 5d66b506853ed0..afe36b5da98bdc 100644 --- a/c10/core/Scalar.h +++ b/c10/core/Scalar.h @@ -18,20 +18,15 @@ namespace c10 { * Scalar (which is why, for example, we provide both add(Tensor) and * add(Scalar) overloads for many operations). It may also be used in * circumstances where you statically know a tensor is 0-dim and single size, - * but don't know it's type. + * but don't know its type. */ class C10_API Scalar { public: Scalar() : Scalar(int64_t(0)) {} #define DEFINE_IMPLICIT_CTOR(type, name, member) \ - Scalar(type vv) : tag(Tag::HAS_##member) { \ - v.member = convert(vv); \ - } - // We can't set v in the initializer list using the - // syntax v{ .member = ... } because it doesn't work on MSVC - - AT_FORALL_SCALAR_TYPES(DEFINE_IMPLICIT_CTOR) + Scalar(type vv) : Scalar(vv, true) { } + AT_FORALL_SCALAR_TYPES_AND(c10::ScalarType::BFloat16, DEFINE_IMPLICIT_CTOR) #undef DEFINE_IMPLICIT_CTOR @@ -92,6 +87,23 @@ class C10_API Scalar { Scalar operator-() const; private: + template::is_integer, bool>::type* = + nullptr> + Scalar(T vv, bool) : tag(Tag::HAS_i) { + v.i = convert(vv); + } + + template::is_integer, bool>::type* = + nullptr> + Scalar(T vv, bool) : tag(Tag::HAS_d) { + v.d = convert(vv); + } + + // We can't set v in the initializer list using the + // syntax v{ .member = ... 
} because it doesn't work on MSVC + enum class Tag { HAS_d, HAS_i, HAS_z }; Tag tag; union { diff --git a/c10/core/ScalarType.h b/c10/core/ScalarType.h index 5a91624b1088f0..4f85d19de552d8 100644 --- a/c10/core/ScalarType.h +++ b/c10/core/ScalarType.h @@ -11,80 +11,53 @@ namespace c10 { -// TODO: check all usages of these macro and make sure -// the use case makes sense for qint - -// NB: Order matters for this macro; it is relied upon in -// _promoteTypesLookup and the serialization format. +// For the macros below: // NB: QInt ScalarTypes are referred to as "STUBS" here since they do not // contain complete information to determine the tensor value of the data, // they are just stubs for dispatch / quantization. -#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_STUBS(_) \ - _(uint8_t, Byte, i) /* 0 */ \ - _(int8_t, Char, i) /* 1 */ \ - _(int16_t, Short, i) /* 2 */ \ - _(int, Int, i) /* 3 */ \ - _(int64_t, Long, i) /* 4 */ \ - _(at::Half, Half, d) /* 5 */ \ - _(float, Float, d) /* 6 */ \ - _(double, Double, d) /* 7 */ \ - _(at::ComplexHalf, ComplexHalf, z) /* 8 */ \ - _(std::complex, ComplexFloat, z) /* 9 */ \ - _(std::complex, ComplexDouble, z) /* 10 */ \ - _(bool, Bool, i) /* 11 */ \ - _(c10::qint8, QInt8, i) /* 12 */ \ - _(c10::quint8, QUInt8, i) /* 13 */ \ - _(c10::qint32, QInt32, i) /* 14 */ \ - _(at::BFloat16, BFloat16, d) /* 15 */ +// NB: If you want to macro some code for all non-stub scalar types, you +// probably want one of the AT_FORALL_SCALAR_TYPES / AT_FORALL_SCALAR_TYPES_AND +// macros below, which are designed to behave similarly to the Dispatch macros +// with the same name. + +// NB: Order matters for this macro; it is relied upon in +// _promoteTypesLookup and the serialization format. 
+#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_STUBS(_) \ + _(uint8_t, Byte, __) /* 0 */ \ + _(int8_t, Char, __) /* 1 */ \ + _(int16_t, Short, __) /* 2 */ \ + _(int, Int, __) /* 3 */ \ + _(int64_t, Long, __) /* 4 */ \ + _(at::Half, Half, __) /* 5 */ \ + _(float, Float, __) /* 6 */ \ + _(double, Double, __) /* 7 */ \ + _(at::ComplexHalf, ComplexHalf, __) /* 8 */ \ + _(std::complex, ComplexFloat, __) /* 9 */ \ + _(std::complex, ComplexDouble, __) /* 10 */ \ + _(bool, Bool, __) /* 11 */ \ + _(c10::qint8, QInt8, __) /* 12 */ \ + _(c10::quint8, QUInt8, __) /* 13 */ \ + _(c10::qint32, QInt32, __) /* 14 */ \ + _(at::BFloat16, BFloat16, __) /* 15 */ + // If you want to support ComplexHalf for real, add ComplexHalf // into this macro (and change the name). But beware: convert() // doesn't work for all the conversions you need... #define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(_) \ - _(uint8_t, Byte, i) \ - _(int8_t, Char, i) \ - _(int16_t, Short, i) \ - _(int, Int, i) \ - _(int64_t, Long, i) \ - _(at::Half, Half, d) \ - _(float, Float, d) \ - _(double, Double, d) \ - _(std::complex, ComplexFloat, z) \ - _(std::complex, ComplexDouble, z) \ - _(bool, Bool, i) \ - _(at::BFloat16, BFloat16, d) + _(uint8_t, Byte, __) \ + _(int8_t, Char, __) \ + _(int16_t, Short, __) \ + _(int, Int, __) \ + _(int64_t, Long, __) \ + _(at::Half, Half, __) \ + _(float, Float, __) \ + _(double, Double, __) \ + _(std::complex, ComplexFloat, __) \ + _(std::complex, ComplexDouble, __) \ + _(bool, Bool, __) \ + _(at::BFloat16, BFloat16, __) -#define AT_FORALL_SCALAR_TYPES(_) \ - _(uint8_t, Byte, i) \ - _(int8_t, Char, i) \ - _(int16_t, Short, i) \ - _(int, Int, i) \ - _(int64_t, Long, i) \ - _(at::Half, Half, d) \ - _(float, Float, d) \ - _(double, Double, d) \ - _(at::BFloat16, BFloat16, d) - -#define AT_FORALL_SCALAR_TYPES_AND_BOOL(_) \ - _(uint8_t, Byte, i) \ - _(int8_t, Char, i) \ - _(int16_t, Short, i) \ - _(int, Int, i) \ - _(int64_t, Long, i) \ - _(at::Half, Half, d) \ - _(float, 
Float, d) \ - _(double, Double, d) \ - _(bool, Bool, i) \ - _(at::BFloat16, BFloat16, d) - -#define AT_FORALL_SCALAR_TYPES_EXCEPT_HALF_AND_BFLOAT16(_) \ - _(uint8_t, Byte, i) \ - _(int8_t, Char, i) \ - _(int16_t, Short, i) \ - _(int, Int, i) \ - _(int64_t, Long, i) \ - _(float, Float, d) \ - _(double, Double, d) #define AT_FORALL_QINT_TYPES(_) \ _(c10::qint8, QInt8, i) \ @@ -99,6 +72,92 @@ enum class ScalarType : int8_t { NumOptions }; +namespace impl { + +// These are used to map ScalarTypes to C++ types. Feel free to add more or even +// macro generate this; the examples here are just those we have found to be +// necessary. + +template +struct ScalarTypeToCPPType; + +template<> +struct ScalarTypeToCPPType { + using type = c10::Half; + + // This is a workaround for the CUDA bug which prevents ::detail::ScalarTypeToCType::type being used directly + // due to an ambiguous reference which can't be resolved. For some reason it can't pick between at::detail and at::cuda::detail. + // For repro example, please see: https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba + // TODO: remove once the bug is fixed. + static type t; +}; + +template<> +struct ScalarTypeToCPPType { + using type = c10::BFloat16; + + // This is a workaround for the CUDA bug which prevents ::detail::ScalarTypeToCType::type being used directly + // due to an ambiguous reference which can't be resolved. For some reason it can't pick between at::detail and at::cuda::detail. + // For repro example, please see: https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba + // TODO: remove once the bug is fixed. + static type t; +}; + +template<> +struct ScalarTypeToCPPType { + using type = bool; + + // This is a workaround for the CUDA bug which prevents ::detail::ScalarTypeToCType::type being used directly + // due to an ambiguous reference which can't be resolved. For some reason it can't pick between at::detail and at::cuda::detail. 
+ // For repro example, please see: https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba + // TODO: remove once the bug is fixed. + static type t; +}; + +template<> +struct ScalarTypeToCPPType { + using type = int64_t; + + // This is a workaround for the CUDA bug which prevents ::detail::ScalarTypeToCType::type being used directly + // due to an ambiguous reference which can't be resolved. For some reason it can't pick between at::detail and at::cuda::detail. + // For repro example, please see: https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba + // TODO: remove once the bug is fixed. + static type t; +}; +} // namespace impl + +#define AT_FORALL_SCALAR_TYPES(_) \ + _(uint8_t, Byte, __) \ + _(int8_t, Char, __) \ + _(int16_t, Short, __) \ + _(int, Int, __) \ + _(int64_t, Long, __) \ + _(float, Float, __) \ + _(double, Double, __) + +#define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \ + _(uint8_t, Byte, __) \ + _(int8_t, Char, __) \ + _(int16_t, Short, __) \ + _(int, Int, __) \ + _(int64_t, Long, __) \ + _(at::Half, Half, __) \ + _(float, Float, __) \ + _(double, Double, __) \ + _(decltype(::c10::impl::ScalarTypeToCPPType::t), SCALARTYPE, __) + +#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \ + _(uint8_t, Byte, __) \ + _(int8_t, Char, __) \ + _(int16_t, Short, __) \ + _(int, Int, __) \ + _(int64_t, Long, __) \ + _(at::Half, Half, __) \ + _(float, Float, __) \ + _(double, Double, __) \ + _(decltype(::c10::impl::ScalarTypeToCPPType::t), SCALARTYPE1, __) \ + _(decltype(::c10::impl::ScalarTypeToCPPType::t), SCALARTYPE2, __) + static inline caffe2::TypeMeta scalarTypeToTypeMeta(ScalarType scalar_type) { #define DEFINE_CASE(ctype, name, _) \ case ScalarType::name: \ diff --git a/c10/util/BFloat16-inl.h b/c10/util/BFloat16-inl.h index b2d8f68fe7ce2e..4c69f7266ce7bd 100644 --- a/c10/util/BFloat16-inl.h +++ b/c10/util/BFloat16-inl.h @@ -23,6 +23,7 @@ template <> class numeric_limits { public: static constexpr bool is_signed = true; + static constexpr 
bool is_integer = false; static constexpr bool has_infinity = true; static constexpr bool has_quiet_NaN = true; static constexpr c10::BFloat16 lowest() { diff --git a/caffe2/contrib/aten/aten_op_template.h b/caffe2/contrib/aten/aten_op_template.h index c2abdf7b4ced47..8eb3352c0685f2 100644 --- a/caffe2/contrib/aten/aten_op_template.h +++ b/caffe2/contrib/aten/aten_op_template.h @@ -15,7 +15,7 @@ static std::unordered_map op_to_key = { namespace caffe2 { -using at::Half; // for AT_FORALL_SCALAR_TYPES_AND_BOOL +using at::Half; // for AT_FORALL_SCALAR_TYPES_AND2(Bool, BFloat16, ...) template class ATenOp : public Operator { @@ -47,7 +47,7 @@ class ATenOp : public Operator { case at::k##aten_name: \ return TypeMeta::Make(); switch(st) { - AT_FORALL_SCALAR_TYPES_AND_BOOL(DEFINE_CASE) + AT_FORALL_SCALAR_TYPES_AND2(Bool, BFloat16, DEFINE_CASE) default: CAFFE_THROW("Unknown ATen Type"); } @@ -118,27 +118,28 @@ class ATenOp : public Operator { } } - // the AT_FORALL_SCALAR_TYPES_AND_BOOL macro just gives a 'i' or - // 'd' argument for each type to specify if it is stored as a integer or a - // double. 
We need this workaround here to extract the value in the scalar - // losslessly because in some cases like 'sum' Torch promotes float to double - // and will complain if we downcast it with toFloat, causing it - // to lose precision - double extract_d(const at::Scalar & s) { - return s.toDouble(); - } - int64_t extract_i(const at::Scalar & s) { + template<typename T, + typename std::enable_if<std::numeric_limits<T>::is_integer, bool>::type* = + nullptr> + int64_t extract(const at::Scalar &s) { return s.toLong(); } + template<typename T, + typename std::enable_if<!std::numeric_limits<T>::is_integer, bool>::type* = + nullptr> + double extract(const at::Scalar &s) { + return s.toDouble(); + } + void assignTo(Tensor* dst, at::ScalarType scalar_type, at::Scalar scalar) { switch(scalar_type) { - #define DEFINE_CASE(ctype,aten_name,native) \ + #define DEFINE_CASE(ctype,aten_name,_1) \ case at::k##aten_name: { \ - auto value = extract_##native(scalar); \ + auto value = extract<ctype>(scalar); \ assignToValue(dst, at::convert(value)); \ } break; - AT_FORALL_SCALAR_TYPES_AND_BOOL(DEFINE_CASE) + AT_FORALL_SCALAR_TYPES_AND2(Bool, BFloat16, DEFINE_CASE) #undef DEFINE_CASE default: CAFFE_THROW("Unknown ATen Type"); diff --git a/caffe2/operators/experimental/c10/cpu/cast_cpu.cc b/caffe2/operators/experimental/c10/cpu/cast_cpu.cc index 770de5dc095bd6..ccd69e2c76b6dc 100644 --- a/caffe2/operators/experimental/c10/cpu/cast_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/cast_cpu.cc @@ -80,7 +80,7 @@ void cast_op_cpu( int64_t to) { switch (input.scalar_type()) { #define CASE(ctype,name,_2) case ScalarType:: name : return cast_op_cpu_impl(input, output, to); - AT_FORALL_SCALAR_TYPES_AND_BOOL(CASE) + AT_FORALL_SCALAR_TYPES_AND2(Bool, BFloat16, CASE) #undef CASE default: throw std::runtime_error(string() + "Unsupported scalar type " + toString(input.scalar_type())); } diff --git a/tools/autograd/templates/variable_factories.h b/tools/autograd/templates/variable_factories.h index ec2927c4247ea2..97d266858f4ce6 100644 --- a/tools/autograd/templates/variable_factories.h +++ 
b/tools/autograd/templates/variable_factories.h @@ -44,7 +44,7 @@ namespace torch { inline at::Tensor tensor(T value) { \ return torch::tensor(at::ArrayRef(value)); \ } -AT_FORALL_SCALAR_TYPES_AND_BOOL(TENSOR) +AT_FORALL_SCALAR_TYPES_AND2(Bool, BFloat16, TENSOR) #undef TENSOR /// A generic deleter function.