
Commit 0c6c2e2

Merge pull request #171 from iotamudelta/ifu
Merge from upstream
2 parents 9abdcf5 + ab9996c commit 0c6c2e2

File tree

19 files changed, +155 -81 lines changed

aten/src/ATen/DLConvertor.cpp

Lines changed: 6 additions & 0 deletions

@@ -37,6 +37,12 @@ static DLDataType getDLDataType(const Type& type) {
     case ScalarType::Half:
       dtype.code = DLDataTypeCode::kDLFloat;
       break;
+    case ScalarType::ComplexHalf:
+      throw std::logic_error("ComplexHalf is not supported by dlpack");
+    case ScalarType::ComplexFloat:
+      throw std::logic_error("ComplexFloat is not supported by dlpack");
+    case ScalarType::ComplexDouble:
+      throw std::logic_error("ComplexDouble is not supported by dlpack");
     case ScalarType::Undefined:
       throw std::logic_error("Undefined is not a valid ScalarType");
     case ScalarType::NumOptions:
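A note on the new cases: each throw exits the switch immediately, so no trailing break is needed and no fall-through can occur between the complex cases. A standalone sketch of the pattern (names and types are illustrative, not PyTorch API):

#include <stdexcept>

enum class Kind { Supported, UnsupportedA, UnsupportedB };

int code_for(Kind k) {
  switch (k) {
    case Kind::Supported:
      return 0;               // normal exit from the switch
    case Kind::UnsupportedA:  // no break needed: the throw leaves the switch
      throw std::logic_error("UnsupportedA is not supported");
    case Kind::UnsupportedB:
      throw std::logic_error("UnsupportedB is not supported");
  }
  throw std::logic_error("unreachable");
}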

aten/src/ATen/TensorOptions.h

Lines changed: 0 additions & 13 deletions

@@ -1,12 +1,9 @@
 #pragma once

 #include <ATen/core/Backend.h>
-#include <ATen/Context.h>
 #include <ATen/core/Device.h>
-#include <ATen/DeviceGuard.h>
 #include <ATen/core/Layout.h>
 #include <ATen/core/ScalarType.h>
-#include <ATen/Type.h>

 #include <cstddef>
 #include <iosfwd>
@@ -62,16 +59,6 @@ struct AT_API TensorOptions {
   /// - requires_grad: false
   explicit TensorOptions(bool use_thread_local_default_options);

-  /// Constructs the `TensorOptions` from a type and a `device_index`.
-  /* implicit */ TensorOptions(
-      const Type& type,
-      int32_t device_index = -1) {
-    this->dtype(type.scalarType());
-    this->device({backendToDeviceType(type.backend()), device_index});
-    this->layout(type.layout());
-    this->is_variable(type.is_variable());
-  }
-
   /// Constructs a `TensorOptions` object with the given layout.
   /* implicit */ TensorOptions(Layout layout) : TensorOptions() {
     this->layout(layout);
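The constructor deleted here is not lost: it reappears as Type::options() in aten/src/ATen/templates/Type.h below, which is what lets this header drop its includes of ATen/Context.h, ATen/DeviceGuard.h, and ATen/Type.h.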

aten/src/ATen/core/Half.h

Lines changed: 8 additions & 0 deletions

@@ -68,6 +68,14 @@ struct alignas(2) Half {
 #endif
 };

+// This is just a placeholder for whatever complex representation we
+// end up deciding to use for half-precision complex numbers.
+struct alignas(4) ComplexHalf {
+  Half real_;
+  Half imag_;
+  ComplexHalf() = default;
+};
+
 template <typename To, typename From>
 To convert(From f) {
   return static_cast<To>(f);
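A standalone sketch of the layout this placeholder implies, with a stand-in for at::Half's 16-bit storage; the static_assert states a packing assumption, not something the header itself checks:

#include <cstdint>

struct alignas(2) Half { uint16_t x; };  // stand-in for at::Half's storage

struct alignas(4) ComplexHalf {
  Half real_;
  Half imag_;
};

// Assumption: two 16-bit halves pack into four bytes with no padding.
static_assert(sizeof(ComplexHalf) == 4, "expected 4-byte ComplexHalf");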

aten/src/ATen/core/ScalarType.h

Lines changed: 44 additions & 19 deletions

@@ -6,20 +6,34 @@

 #include <cstdint>
 #include <iostream>
+#include <complex>

 namespace at {

 // NB: Order matters for this macro; it is relied upon in
 // _promoteTypesLookup and the serialization format.
-#define AT_FORALL_SCALAR_TYPES(_) \
+#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \
 _(uint8_t,Byte,i) /* 0 */ \
 _(int8_t,Char,i) /* 1 */ \
 _(int16_t,Short,i) /* 2 */ \
 _(int,Int,i) /* 3 */ \
 _(int64_t,Long,i) /* 4 */ \
 _(at::Half,Half,d) /* 5 */ \
 _(float,Float,d) /* 6 */ \
-_(double,Double,d) /* 7 */
+_(double,Double,d) /* 7 */ \
+_(at::ComplexHalf,ComplexHalf,z) /* 8 */ \
+_(std::complex<float>,ComplexFloat,z) /* 9 */ \
+_(std::complex<double>,ComplexDouble,z) /* 10 */
+
+#define AT_FORALL_SCALAR_TYPES(_) \
+_(uint8_t,Byte,i) \
+_(int8_t,Char,i) \
+_(int16_t,Short,i) \
+_(int,Int,i) \
+_(int64_t,Long,i) \
+_(at::Half,Half,d) \
+_(float,Float,d) \
+_(double,Double,d)

 #define AT_FORALL_SCALAR_TYPES_EXCEPT_HALF(_) \
 _(uint8_t,Byte,i) \
@@ -33,9 +47,9 @@ _(double,Double,d)
 enum class ScalarType {
 #define DEFINE_ENUM(_1,n,_2) \
   n,
-  AT_FORALL_SCALAR_TYPES(DEFINE_ENUM)
+  AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_ENUM)
 #undef DEFINE_ENUM
-  Undefined, // 8
+  Undefined,
   NumOptions
 };

@@ -44,7 +58,7 @@ static inline DataType scalarTypeToDataType(ScalarType scalar_type) {
   case ScalarType:: name : return caffe2::TypeMeta::Id<ctype>();

   switch(scalar_type) {
-    AT_FORALL_SCALAR_TYPES(DEFINE_CASE)
+    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE)
     case ScalarType::Undefined: return DataType::uninitialized();
     default: AT_ERROR("Unrecognized Scalartype ", scalar_type, " (please report this error)");
   }
@@ -56,7 +70,7 @@ static inline ScalarType dataTypeToScalarType(DataType dtype) {
   if (dtype == caffe2::TypeMeta::Id<ctype>()) { \
     return ScalarType:: name; \
   }
-  AT_FORALL_SCALAR_TYPES(DEFINE_IF)
+  AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_IF)
 #undef DEFINE_IF
   if (dtype == at::DataType::uninitialized()) {
     return ScalarType::Undefined;
@@ -67,15 +81,15 @@ static inline ScalarType dataTypeToScalarType(DataType dtype) {
 #define DEFINE_CONSTANT(_,name,_2) \
 constexpr ScalarType k##name = ScalarType::name;

-AT_FORALL_SCALAR_TYPES(DEFINE_CONSTANT)
+AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CONSTANT)
 #undef DEFINE_CONSTANT

 static inline const char * toString(ScalarType t) {
 #define DEFINE_CASE(_,name,_2) \
   case ScalarType:: name : return #name;

   switch(t) {
-    AT_FORALL_SCALAR_TYPES(DEFINE_CASE)
+    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE)
     default:
       return "UNKNOWN_SCALAR";
   }
@@ -87,7 +101,7 @@ static inline size_t elementSize(ScalarType t) {
   case ScalarType:: name : return sizeof(ctype);

   switch(t) {
-    AT_FORALL_SCALAR_TYPES(CASE_ELEMENTSIZE_CASE)
+    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(CASE_ELEMENTSIZE_CASE)
     default:
       AT_ERROR("Unknown ScalarType");
   }
@@ -108,6 +122,12 @@ static inline bool isFloatingType(ScalarType t) {
           t == ScalarType::Half);
 }

+static inline bool isComplexType(ScalarType t) {
+  return (t == ScalarType::ComplexHalf ||
+          t == ScalarType::ComplexFloat ||
+          t == ScalarType::ComplexDouble);
+}
+
 static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
   // This is generated according to NumPy's promote_types
   constexpr auto u1 = ScalarType::Byte;
@@ -119,19 +139,24 @@ static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
   constexpr auto f4 = ScalarType::Float;
   constexpr auto f8 = ScalarType::Double;
   constexpr auto ud = ScalarType::Undefined;
+  if (a == ud || b == ud) {
+    return ScalarType::Undefined;
+  }
+  if (isComplexType(a) || isComplexType(b)) {
+    AT_ERROR("promoteTypes with complex numbers is not handled yet; figure out what the correct rules should be");
+  }
   static constexpr ScalarType _promoteTypesLookup
       [static_cast<int>(ScalarType::NumOptions)]
       [static_cast<int>(ScalarType::NumOptions)] = {
-    /*          u1  i1  i2  i4  i8  f2  f4  f8, ud */
-    /* u1 */ { u1, i2, i2, i4, i8, f2, f4, f8, ud },
-    /* i1 */ { i2, i1, i2, i4, i8, f2, f4, f8, ud },
-    /* i2 */ { i2, i2, i2, i4, i8, f4, f4, f8, ud },
-    /* i4 */ { i4, i4, i4, i4, i8, f8, f4, f8, ud },
-    /* i8 */ { i8, i8, i8, i8, i8, f8, f4, f8, ud },
-    /* f2 */ { f2, f2, f4, f8, f8, f2, f4, f8, ud },
-    /* f4 */ { f4, f4, f4, f4, f4, f4, f4, f8, ud },
-    /* f8 */ { f8, f8, f8, f8, f8, f8, f8, f8, ud },
-    /* ud */ { ud, ud, ud, ud, ud, ud, ud, ud, ud },
+    /*          u1  i1  i2  i4  i8  f2  f4  f8 */
+    /* u1 */ { u1, i2, i2, i4, i8, f2, f4, f8 },
+    /* i1 */ { i2, i1, i2, i4, i8, f2, f4, f8 },
+    /* i2 */ { i2, i2, i2, i4, i8, f4, f4, f8 },
+    /* i4 */ { i4, i4, i4, i4, i8, f8, f4, f8 },
+    /* i8 */ { i8, i8, i8, i8, i8, f8, f4, f8 },
+    /* f2 */ { f2, f2, f4, f8, f8, f2, f4, f8 },
+    /* f4 */ { f4, f4, f4, f4, f4, f4, f4, f8 },
+    /* f8 */ { f8, f8, f8, f8, f8, f8, f8, f8 },
   };
   return _promoteTypesLookup[static_cast<int>(a)][static_cast<int>(b)];
 }
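For readers unfamiliar with the X-macro idiom these lists rely on, here is a trimmed, self-contained sketch with just two types; the third macro argument (the i/d/z tag above) carries a per-type classification that call sites can inspect:

#include <cstdio>

// The list macro invokes the supplied macro once per (ctype, Name, tag).
#define FORALL_TYPES(_) \
  _(float, Float, d)    \
  _(double, Double, d)

enum class ScalarType {
#define DEFINE_ENUM(_1, n, _2) n,
  FORALL_TYPES(DEFINE_ENUM)  // expands to: Float, Double,
#undef DEFINE_ENUM
  Undefined,
};

const char* toString(ScalarType t) {
  switch (t) {
#define DEFINE_CASE(_1, n, _2) case ScalarType::n: return #n;
    FORALL_TYPES(DEFINE_CASE)  // one case per listed type
#undef DEFINE_CASE
    default: return "UNKNOWN_SCALAR";
  }
}

int main() {
  std::printf("%s\n", toString(ScalarType::Double));  // prints "Double"
}

Adding a type to the list automatically extends every enum, switch, and trait built from it, which is why the complex types only need to be spelled out once in AT_FORALL_SCALAR_TYPES_WITH_COMPLEX.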

aten/src/ATen/core/ScalarTypeUtils.h

Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@ template <> \
 struct CTypeToScalarType<ct> { \
   static inline at::ScalarType to() { return at::ScalarType::st; } \
 };
-AT_FORALL_SCALAR_TYPES(DEFINE_TO_SCALAR_TYPE)
+AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_TO_SCALAR_TYPE)
 #undef DEFINE_TO_SCALAR_TYPE

 } // namespace at

aten/src/ATen/core/typeid.h

Lines changed: 22 additions & 18 deletions

@@ -10,6 +10,7 @@
 #include <unordered_map>
 #include <unordered_set>
 #include <vector>
+#include <complex>
 #ifdef __GXX_RTTI
 #include <typeinfo>
 #endif
@@ -466,26 +467,29 @@
 CAFFE_DECLARE_KNOWN_TYPE(4, int64_t)
 CAFFE_DECLARE_KNOWN_TYPE(5, at::Half)
 CAFFE_DECLARE_KNOWN_TYPE(6, float)
 CAFFE_DECLARE_KNOWN_TYPE(7, double)
-// 8 = undefined type id
-
-CAFFE_DECLARE_KNOWN_TYPE(9, Tensor)
-CAFFE_DECLARE_KNOWN_TYPE(10, std::string)
-CAFFE_DECLARE_KNOWN_TYPE(11, bool)
-CAFFE_DECLARE_KNOWN_TYPE(12, uint16_t)
-CAFFE_DECLARE_KNOWN_TYPE(13, char)
-CAFFE_DECLARE_KNOWN_TYPE(14, std::unique_ptr<std::mutex>)
-CAFFE_DECLARE_KNOWN_TYPE(15, std::unique_ptr<std::atomic<bool>>)
-CAFFE_DECLARE_KNOWN_TYPE(16, std::vector<int32_t>)
-CAFFE_DECLARE_KNOWN_TYPE(17, std::vector<int64_t>)
-CAFFE_DECLARE_KNOWN_TYPE(18, std::vector<unsigned long>)
-CAFFE_DECLARE_KNOWN_TYPE(19, bool*)
-CAFFE_DECLARE_KNOWN_TYPE(20, char*)
-CAFFE_DECLARE_KNOWN_TYPE(21, int*)
+CAFFE_DECLARE_KNOWN_TYPE(8, at::ComplexHalf)
+CAFFE_DECLARE_KNOWN_TYPE(9, std::complex<float>)
+CAFFE_DECLARE_KNOWN_TYPE(10, std::complex<double>)
+// 11 = undefined type id
+
+CAFFE_DECLARE_KNOWN_TYPE(12, Tensor)
+CAFFE_DECLARE_KNOWN_TYPE(13, std::string)
+CAFFE_DECLARE_KNOWN_TYPE(14, bool)
+CAFFE_DECLARE_KNOWN_TYPE(15, uint16_t)
+CAFFE_DECLARE_KNOWN_TYPE(16, char)
+CAFFE_DECLARE_KNOWN_TYPE(17, std::unique_ptr<std::mutex>)
+CAFFE_DECLARE_KNOWN_TYPE(18, std::unique_ptr<std::atomic<bool>>)
+CAFFE_DECLARE_KNOWN_TYPE(19, std::vector<int32_t>)
+CAFFE_DECLARE_KNOWN_TYPE(20, std::vector<int64_t>)
+CAFFE_DECLARE_KNOWN_TYPE(21, std::vector<unsigned long>)
+CAFFE_DECLARE_KNOWN_TYPE(22, bool*)
+CAFFE_DECLARE_KNOWN_TYPE(23, char*)
+CAFFE_DECLARE_KNOWN_TYPE(24, int*)

 #ifdef CAFFE2_UNIQUE_LONG_TYPEMETA
-CAFFE_DECLARE_KNOWN_TYPE(22, long)
-CAFFE_DECLARE_KNOWN_TYPE(23, std::vector<long>)
+CAFFE_DECLARE_KNOWN_TYPE(25, long)
+CAFFE_DECLARE_KNOWN_TYPE(26, std::vector<long>)
 #endif // CAFFE2_UNIQUE_LONG_TYPEMETA

-CAFFE_DECLARE_KNOWN_TYPE(24, _CaffeHighestPreallocatedTypeId)
+CAFFE_DECLARE_KNOWN_TYPE(27, _CaffeHighestPreallocatedTypeId)
 } // namespace caffe2
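With the three complex ids inserted after double (8, 9, 10), every preallocated id from Tensor onward shifts up by three, and the highest preallocated id moves from 24 to 27 to match.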

aten/src/ATen/function_wrapper.py

Lines changed: 2 additions & 3 deletions

@@ -89,9 +89,8 @@ def TypedDict(name, attrs, total=True):  # type: ignore
 """)
 DEPRECATED_TYPE_METHOD_DEFINITION_CONCRETE = CodeTemplate("""\
 ${return_type} TypeDefault::${api_name}(${type_method_formals}) const {
-    TensorOptions options(*this);
     ${device_guard_declaration}
-    return at::native::${api_name}(${type_method_actuals}, options);
+    return at::native::${api_name}(${type_method_actuals}, options());
 }
 """)
 # 4. add virtual override to TypeDerived.h
@@ -165,7 +164,7 @@ def TypedDict(name, attrs, total=True):  # type: ignore
 # special method definition for *deprecated* factory functions in Functions.h
 DEPRECATED_FACTORY_DEFINITION = CodeTemplate("""\
 static inline ${return_type} ${api_name}(${formals}) {
-  return at::${api_name}(${type_method_actuals}, TensorOptions(${inferred_type}));
+  return at::${api_name}(${type_method_actuals}, ${inferred_type}.options());
 }
 """)
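As a rough illustration, the first template might now expand to something like the following for a hypothetical api_name of zeros (the signature and formals are placeholders, not copied from the generated sources):

// Hypothetical expansion; `zeros` and its parameter list are illustrative.
Tensor TypeDefault::zeros(IntList size) const {
  /* ${device_guard_declaration} expands here */
  // options() is the Type::options() helper added in templates/Type.h,
  // replacing the removed implicit TensorOptions(const Type&) constructor.
  return at::native::zeros(size, options());
}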

aten/src/ATen/native/cuda/SpectralOps.cu

Lines changed: 4 additions & 0 deletions

@@ -37,6 +37,10 @@ struct cnt_to_dst_idx_functor : public thrust::unary_function<int64_t, int64_t>
     last_dim_size(last_dim_size), last_dim_start_slice(last_dim_start_slice),
     last_dim_to_fill_size(last_dim_size - last_dim_start_slice) {}

+  // HIP wants __host__ __device__ tag, CUDA does not
+#ifdef __HIP_PLATFORM_HCC__
+  __host__ __device__
+#endif
   cnt_to_dst_idx_functor & operator=(const cnt_to_dst_idx_functor&) = default;

   __host__ __device__ __forceinline__

aten/src/ATen/templates/TensorMethods.h

Lines changed: 2 additions & 0 deletions

@@ -7,6 +7,8 @@
 #include "ATen/core/SparseTensorRef.h"
 #include "ATen/Type.h"
 #include "ATen/TensorOptions.h"
+#include "ATen/DeviceGuard.h"
+#include "ATen/Context.h"

 namespace at {

aten/src/ATen/templates/Type.h

Lines changed: 15 additions & 0 deletions

@@ -15,6 +15,7 @@
 #include "ATen/core/Half.h"
 #include "ATen/core/TensorTypeIdRegistration.h"
 #include "ATen/core/Reduction.h"
+#include "ATen/TensorOptions.h"

 #include <array>
 #include <cstddef>
@@ -110,6 +111,20 @@ struct AT_API Type {
     return this != &other;
   }

+  /// Constructs the `TensorOptions` from a type and a `device_index`.
+  TensorOptions options(int32_t device_index = -1) const {
+    TensorOptions r;
+    r.dtype(scalarType());
+    r.device({backendToDeviceType(backend()), device_index});
+    r.layout(layout());
+    r.is_variable(is_variable());
+    return r;
+  }
+
+  operator TensorOptions() const {
+    return options();
+  }
+
   // example
   // virtual Tensor * add(Tensor & a, Tensor & b) = 0;
   ${pure_virtual_type_method_declarations}
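The implicit conversion operator is what keeps old call sites that pass a Type where a TensorOptions is expected compiling once the TensorOptions(const Type&) constructor is gone. A reduced standalone sketch of the pattern (all names are illustrative):

struct Options { int device = -1; };

struct Type {
  // Build the options object explicitly, optionally with a device index.
  Options options(int device_index = -1) const {
    Options r;
    r.device = device_index;
    return r;
  }
  // Implicit conversion keeps `consume(t)`-style call sites compiling.
  operator Options() const { return options(); }
};

void consume(Options) {}

int main() {
  Type t;
  consume(t);             // old style: implicit Type -> Options conversion
  consume(t.options(5));  // new style: explicit, with a device index
}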

caffe2/operators/softsign_op.cu

Lines changed: 5 additions & 4 deletions

@@ -22,10 +22,11 @@ inline __device__ float typed_abs(float x) {
   return fabsf(x);
 }

-template <>
-inline __device__ double typed_abs(double x) {
-  return fabs(x);
-}
+// Avoid compiler warning. <double> specialization currently not used.
+// template <>
+// inline __device__ double typed_abs(double x) {
+//   return fabs(x);
+// }

 template <typename T>
 __global__ void SoftsignCUDAKernel(const int N, const T* X, T* Y) {

test/cpp/api/serialization.cpp

Lines changed: 3 additions & 0 deletions

@@ -52,6 +52,9 @@ TEST_CASE("serialization") {
       // XXX can't serialize half tensors at the moment since contiguous() is
       // not implemented for this type;
       continue;
+    } else if (at::isComplexType(static_cast<torch::Dtype>(i))) {
+      // Not supported yet
+      continue;
     } else if (i == static_cast<int>(torch::Dtype::Undefined)) {
       // We can't construct a tensor for this type. This is tested in
       // serialization/undefined anyway.

test/cpp/api/tensor_options_cuda.cpp

Lines changed: 6 additions & 6 deletions

@@ -22,22 +22,22 @@ using namespace at;
   REQUIRE(tensor.type().layout() == (layout_))

 TEST_CASE("TensorOptions/ConstructsWellFromCUDATypes", "[cuda]") {
-  auto options = TensorOptions(CUDA(kFloat));
+  auto options = CUDA(kFloat).options();
   REQUIRE_OPTIONS(kCUDA, -1, kFloat, kStrided);

-  options = TensorOptions(CUDA(kInt));
+  options = CUDA(kInt).options();
   REQUIRE_OPTIONS(kCUDA, -1, kInt, kStrided);

-  options = TensorOptions(getNonVariableType(Backend::SparseCUDA, kFloat));
+  options = getNonVariableType(Backend::SparseCUDA, kFloat).options();
   REQUIRE_OPTIONS(kCUDA, -1, kFloat, kSparse);

-  options = TensorOptions(getNonVariableType(Backend::SparseCUDA, kByte));
+  options = getNonVariableType(Backend::SparseCUDA, kByte).options();
   REQUIRE_OPTIONS(kCUDA, -1, kByte, kSparse);

-  options = TensorOptions(CUDA(kFloat), /*device=*/5);
+  options = CUDA(kFloat).options(/*device=*/5);
   REQUIRE_OPTIONS(kCUDA, 5, kFloat, kStrided);

-  options = TensorOptions(getNonVariableType(Backend::SparseCUDA, kFloat), /*device=*/5);
+  options = getNonVariableType(Backend::SparseCUDA, kFloat).options(/*device=*/5);
   REQUIRE_OPTIONS(kCUDA, 5, kFloat, kSparse);
 }
