Skip to content

Commit eea1c6a

Browse files
authored
Merge pull request #170 from iotamudelta/ifu
Merge from upstream
2 parents 16e2f6a + df81592 commit eea1c6a

File tree

244 files changed

+3800
-3741
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

244 files changed

+3800
-3741
lines changed

aten/src/ATen/UndefinedType.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
namespace at {
55

66
UndefinedType::UndefinedType()
7-
: Type(UndefinedTensorId(), /*is_variable=*/false, /*is_undefined=*/true) {}
7+
: TypeDefault(UndefinedTensorId(), /*is_variable=*/false, /*is_undefined=*/true) {}
88
ScalarType UndefinedType::scalarType() const {
99
return ScalarType::Undefined;
1010
}
@@ -50,13 +50,13 @@ size_t UndefinedType::elementSizeInBytes() const {
5050

5151
Type & UndefinedType::toBackend(Backend b) const {
5252
if (b == Backend::Undefined) {
53-
return Type::toBackend(b);
53+
return TypeDefault::toBackend(b);
5454
}
5555
AT_ERROR("toBackend not implemented for UndefinedType to non-UndefinedType");
5656
}
5757
Type & UndefinedType::toScalarType(ScalarType s) const {
5858
if (s == ScalarType::Undefined) {
59-
return Type::toScalarType(s);
59+
return TypeDefault::toScalarType(s);
6060
}
6161
AT_ERROR("toScalarType not implemented for UndefinedType to non-UndefinedType");
6262
}

aten/src/ATen/UndefinedType.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#pragma once
22

3-
#include "ATen/Type.h"
3+
#include "ATen/TypeDefault.h"
44
#include "ATen/CheckGenerator.h"
55

66
#ifdef _MSC_VER
@@ -11,7 +11,7 @@
1111

1212
namespace at {
1313

14-
struct UndefinedType final : public Type {
14+
struct UndefinedType final : public TypeDefault {
1515
explicit UndefinedType();
1616
virtual ScalarType scalarType() const override;
1717
virtual Backend backend() const override;

aten/src/ATen/function_wrapper.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,11 @@ def TypedDict(name, attrs, total=True): # type: ignore
3939
# declaration under Type.h (right now, we call this template
4040
# BROADCAST but it also handles default arguments)
4141
TYPE_METHOD_DECLARATION_BROADCAST = CodeTemplate("""\
42-
${return_type} ${api_name}(${type_method_formals_with_defaults}) const;
42+
${return_type} ${api_name}(${type_method_formals_with_defaults}) const override;
4343
""")
4444
# 2. broadcasting functions are implemented in Type.cpp
4545
TYPE_METHOD_DEFINITION_BROADCAST = CodeTemplate("""\
46-
${return_type} Type::${api_name}(${type_method_formals}) const {
46+
${return_type} TypeDefault::${api_name}(${type_method_formals}) const {
4747
${device_guard_declaration}
4848
Tensor ${broadcast_returns};
4949
std::tie(${broadcast_returns}) = ${broadcast_function}(${broadcast_actuals}, "${api_name}");
@@ -59,36 +59,44 @@ def TypedDict(name, attrs, total=True): # type: ignore
5959
# actual implementation. At the moment, this situation *only* occurs
6060
# for 'native' declarations (so the native dispatch is hardcoded into
6161
# the template here.)
62+
PURE_VIRTUAL_TYPE_METHOD_DECLARATION = CodeTemplate("""\
63+
virtual ${return_type} ${method_prefix_derived}${api_name}(${type_method_formals_with_defaults}) const = 0;
64+
""")
65+
DEPRECATED_PURE_VIRTUAL_TYPE_METHOD_DECLARATION = CodeTemplate("""\
66+
AT_DEPRECATED(virtual ${return_type} \
67+
${method_prefix_derived}${api_name}(${type_method_formals_with_defaults}) const = 0);
68+
""")
69+
PURE_VIRTUAL_TYPE_METHOD_DECLARATION_BROADCAST = CodeTemplate("""\
70+
virtual ${return_type} ${api_name}(${type_method_formals_with_defaults}) const = 0;
71+
""")
72+
6273
TYPE_METHOD_DECLARATION_ABSTRACT = CodeTemplate("""\
63-
virtual ${return_type} ${method_prefix_derived}${api_name}(${type_method_formals_with_defaults}) const;
74+
${return_type} ${method_prefix_derived}${api_name}(${type_method_formals_with_defaults}) const override;
6475
""")
6576
TYPE_METHOD_DEFINITION_ABSTRACT = CodeTemplate("""\
66-
${return_type} Type::${method_prefix_derived}${api_name}(${type_method_formals}) const {
77+
${return_type} TypeDefault::${method_prefix_derived}${api_name}(${type_method_formals}) const {
6778
AT_ERROR("${method_prefix_derived}${api_name} is not implemented for type ", toString());
6879
}
6980
""")
7081
TYPE_METHOD_DECLARATION_CONCRETE = CodeTemplate("""\
71-
virtual ${return_type} ${api_name}(${type_method_formals_with_defaults}) const;
72-
""")
73-
DEPRECATED_TYPE_METHOD_DECLARATION_CONCRETE = CodeTemplate("""\
74-
AT_DEPRECATED(virtual ${return_type} ${api_name}(${type_method_formals_with_defaults}) const);
82+
${return_type} ${api_name}(${type_method_formals_with_defaults}) const override;
7583
""")
7684
TYPE_METHOD_DEFINITION_CONCRETE = CodeTemplate("""\
77-
${return_type} Type::${api_name}(${type_method_formals}) const {
85+
${return_type} TypeDefault::${api_name}(${type_method_formals}) const {
7886
${device_guard_declaration}
7987
${type_definition_body}
8088
}
8189
""")
8290
DEPRECATED_TYPE_METHOD_DEFINITION_CONCRETE = CodeTemplate("""\
83-
${return_type} Type::${api_name}(${type_method_formals}) const {
91+
${return_type} TypeDefault::${api_name}(${type_method_formals}) const {
8492
TensorOptions options(*this);
8593
${device_guard_declaration}
8694
return at::native::${api_name}(${type_method_actuals}, options);
8795
}
8896
""")
8997
# 4. add virtual override to TypeDerived.h
9098
TYPE_DERIVED_DECLARATION = CodeTemplate("""\
91-
virtual ${return_type} ${method_prefix_derived}${api_name}(${type_method_formals}) const override;
99+
${return_type} ${method_prefix_derived}${api_name}(${type_method_formals}) const override;
92100
""")
93101
# 5. add override definition to TypeDerived.cpp
94102
TYPE_DERIVED_DEFINITION = CodeTemplate("""\
@@ -382,6 +390,7 @@ def __getitem__(self, x):
382390
TopEnvironment = TypedDict('TopEnvironment', {
383391
'type_registrations': List[str],
384392
'type_headers': List[str],
393+
'pure_virtual_type_method_declarations': List[str],
385394
'type_method_declarations': List[str],
386395
'type_method_definitions': List[str],
387396
'type_method_inline_definitions': List[str],
@@ -815,18 +824,26 @@ def process_option(option, output_options):
815824
# NN function with no _forward/_backward suffix don't have cimpls.
816825
# They call the _forward function and discard any buffer returns
817826
abstract = False
827+
top_env['pure_virtual_type_method_declarations'].append(
828+
PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
818829
top_env['type_method_declarations'].append(
819830
TYPE_METHOD_DECLARATION_CONCRETE.substitute(env))
820831
body = emit_nn_body(option)
821832
top_env['type_method_definitions'].append(
822833
TYPE_METHOD_DEFINITION_CONCRETE.substitute(
823834
env, type_definition_body=body))
824835
elif broadcast_arg is None:
836+
top_env['pure_virtual_type_method_declarations'].append(
837+
PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
825838
top_env['type_method_declarations'].append(
826839
TYPE_METHOD_DECLARATION_ABSTRACT.substitute(env))
827840
top_env['type_method_definitions'].append(
828841
TYPE_METHOD_DEFINITION_ABSTRACT.substitute(env))
829842
else:
843+
top_env['pure_virtual_type_method_declarations'].append(
844+
PURE_VIRTUAL_TYPE_METHOD_DECLARATION_BROADCAST.substitute(env))
845+
top_env['pure_virtual_type_method_declarations'].append(
846+
PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
830847
top_env['type_method_declarations'].append(
831848
TYPE_METHOD_DECLARATION_BROADCAST.substitute(env))
832849
top_env['type_method_declarations'].append(
@@ -1031,9 +1048,12 @@ def find_formal(formal_name, formals):
10311048
# Factory methods are not dispatched over `Type`.
10321049
if not is_factory_method:
10331050
if option['deprecated']:
1034-
top_env['type_method_declarations'].append(DEPRECATED_TYPE_METHOD_DECLARATION_CONCRETE.substitute(env))
1051+
top_env['pure_virtual_type_method_declarations'].append(
1052+
DEPRECATED_PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
10351053
else:
1036-
top_env['type_method_declarations'].append(TYPE_METHOD_DECLARATION_CONCRETE.substitute(env))
1054+
top_env['pure_virtual_type_method_declarations'].append(
1055+
PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
1056+
top_env['type_method_declarations'].append(TYPE_METHOD_DECLARATION_CONCRETE.substitute(env))
10371057
dispatch = option['type_method_definition_dispatch']
10381058
option['native_type_method_dispatch'] = dispatch
10391059

aten/src/ATen/gen.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,8 @@ def check_all_files_written(self):
107107
SPARSE_TYPE_DERIVED_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/SparseTypeDerived.cpp")
108108
TYPE_DERIVED_H = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeDerived.h")
109109
TYPE_H = CodeTemplate.from_file(TEMPLATE_PATH + "/Type.h")
110-
TYPE_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/Type.cpp")
110+
TYPE_DEFAULT_H = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeDefault.h")
111+
TYPE_DEFAULT_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeDefault.cpp")
111112

112113
REGISTER_CPU_H = CodeTemplate.from_file(TEMPLATE_PATH + "/RegisterCPU.h")
113114
REGISTER_CPU_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/RegisterCPU.cpp")
@@ -166,6 +167,7 @@ def check_all_files_written(self):
166167
'cpu_type_headers': [],
167168
'cuda_type_registrations': [],
168169
'cuda_type_headers': [],
170+
'pure_virtual_type_method_declarations': [],
169171
'type_method_declarations': [],
170172
'type_method_definitions': [],
171173
'type_method_inline_definitions': [],
@@ -329,7 +331,7 @@ def iterate_types():
329331
# so that the script runs quickly when we are just querying the
330332
# outputs
331333
def declare_outputs():
332-
files = ['Declarations.yaml', 'Type.h', 'Type.cpp', 'Tensor.h',
334+
files = ['Declarations.yaml', 'Type.h', 'TypeDefault.cpp', 'TypeDefault.h', 'Tensor.h',
333335
'TensorMethods.h', 'Functions.h',
334336
'CPUCopy.cpp', 'NativeFunctions.h',
335337
'RegisterCPU.cpp', 'RegisterCPU.h']
@@ -399,7 +401,8 @@ def generate_outputs():
399401
backend, density, scalar_type, declarations))
400402

401403
file_manager.write('Type.h', TYPE_H, top_env)
402-
file_manager.write('Type.cpp', TYPE_CPP, top_env)
404+
file_manager.write('TypeDefault.h', TYPE_DEFAULT_H, top_env)
405+
file_manager.write('TypeDefault.cpp', TYPE_DEFAULT_CPP, top_env)
403406

404407
file_manager.write('RegisterCPU.h', REGISTER_CPU_H, top_env)
405408
file_manager.write('RegisterCPU.cpp', REGISTER_CPU_CPP, top_env)

aten/src/ATen/templates/SparseTypeDerived.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
namespace at {
2929

3030
${Type}::${Type}()
31-
: Type(${Backend}TensorId(), /*is_variable=*/false, /*is_undefined=*/false) {}
31+
: TypeDefault(${Backend}TensorId(), /*is_variable=*/false, /*is_undefined=*/false) {}
3232
ScalarType ${Type}::scalarType() const {
3333
return ScalarType::${ScalarName};
3434
}

aten/src/ATen/templates/Type.h

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ enum class TypeID {
4747
struct AT_API Type {
4848
explicit Type(TensorTypeId type_id, bool is_variable, bool is_undefined)
4949
: type_id_(type_id), is_variable_(is_variable), is_undefined_(is_undefined) {}
50+
5051
virtual ~Type() {}
5152
virtual ScalarType scalarType() const = 0;
5253
virtual Backend backend() const = 0;
@@ -65,8 +66,8 @@ struct AT_API Type {
6566
virtual Storage unsafeStorageFromTH(void * th_pointer, bool retain) const = 0;
6667
virtual const char * toString() const = 0;
6768
virtual size_t elementSizeInBytes() const = 0;
68-
virtual Type & toBackend(Backend b) const;
69-
virtual Type & toScalarType(ScalarType s) const;
69+
virtual Type & toBackend(Backend b) const = 0;
70+
virtual Type & toScalarType(ScalarType s) const = 0;
7071
Type & toSparse() const {
7172
return this->toBackend(at::toSparse(this->backend()));
7273
}
@@ -91,23 +92,27 @@ struct AT_API Type {
9192
return backendToDeviceType(backend());
9293
}
9394

94-
Tensor copy(const Tensor & src, bool non_blocking=false) const;
95-
Tensor & copy_(Tensor & self, const Tensor & src, bool non_blocking=false) const;
95+
virtual Tensor copy(const Tensor & src, bool non_blocking=false) const = 0;
96+
virtual Tensor & copy_(Tensor & self, const Tensor & src, bool non_blocking=false) const = 0;
9697
virtual Tensor & s_copy_(Tensor & self, const Tensor & src, bool non_blocking) const = 0;
9798
virtual Tensor & _s_copy_from(const Tensor & self, Tensor & dst, bool non_blocking) const = 0;
9899

99-
Tensor tensorFromBlob(void * data, IntList sizes, const std::function<void(void*)> & deleter=noop_deleter) const;
100-
Tensor tensorFromBlob(void * data, IntList sizes, IntList strides, const std::function<void(void*)> & deleter=noop_deleter) const;
101-
Tensor tensorWithAllocator(IntList sizes, Allocator* allocator) const;
102-
Tensor tensorWithAllocator(IntList sizes, IntList strides, Allocator* allocator) const;
103-
Tensor scalarTensor(Scalar s) const;
100+
virtual Tensor tensorFromBlob(void * data, IntList sizes, const std::function<void(void*)> & deleter=noop_deleter) const = 0;
101+
virtual Tensor tensorFromBlob(void * data, IntList sizes, IntList strides, const std::function<void(void*)> & deleter=noop_deleter) const = 0;
102+
virtual Tensor tensorWithAllocator(IntList sizes, Allocator* allocator) const = 0;
103+
virtual Tensor tensorWithAllocator(IntList sizes, IntList strides, Allocator* allocator) const = 0;
104+
virtual Tensor scalarTensor(Scalar s) const = 0;
104105

105-
bool operator==(const Type& other) const;
106-
bool operator!=(const Type& other) const;
106+
bool operator==(const Type& other) const {
107+
return this == &other;
108+
}
109+
bool operator!=(const Type& other) const {
110+
return this != &other;
111+
}
107112

108113
// example
109114
// virtual Tensor * add(Tensor & a, Tensor & b) = 0;
110-
${type_method_declarations}
115+
${pure_virtual_type_method_declarations}
111116
protected:
112117
TensorTypeId type_id_;
113118
bool is_variable_;

aten/src/ATen/templates/Type.cpp renamed to aten/src/ATen/templates/TypeDefault.cpp

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#include "ATen/Type.h"
1+
#include "ATen/TypeDefault.h"
22

33
// ${generated_comment}
44

@@ -13,13 +13,13 @@
1313

1414
namespace at {
1515

16-
Tensor & Type::copy_(Tensor & self, const Tensor & src, bool non_blocking) const {
16+
Tensor & TypeDefault::copy_(Tensor & self, const Tensor & src, bool non_blocking) const {
1717
Tensor b_src;
1818
std::tie(b_src) = expand_inplace(self, src, "copy");
1919
return s_copy_(self, b_src, non_blocking);
2020
}
2121

22-
Tensor Type::copy(const Tensor & src, bool non_blocking) const {
22+
Tensor TypeDefault::copy(const Tensor & src, bool non_blocking) const {
2323
// TODO(psag): have a DeviceGuard here
2424
AT_CHECK(src.defined(), "attempt to copy an undefined tensor");
2525
if (is_sparse()) {
@@ -37,10 +37,10 @@ Tensor Type::copy(const Tensor & src, bool non_blocking) const {
3737
}
3838
}
3939

40-
Type & Type::toBackend(Backend b) const {
40+
Type & TypeDefault::toBackend(Backend b) const {
4141
return at::globalContext().getNonVariableType(b,scalarType());
4242
}
43-
Type & Type::toScalarType(ScalarType s) const {
43+
Type & TypeDefault::toScalarType(ScalarType s) const {
4444
return at::globalContext().getNonVariableType(backend(),s);
4545
}
4646
static std::vector<int64_t> defaultStrides(IntList sizes) {
@@ -64,31 +64,24 @@ static int64_t computeStorageSize(IntList sizes, IntList strides) {
6464
}
6565
return size;
6666
}
67-
Tensor Type::tensorFromBlob(void * data, IntList sizes, const std::function<void(void*)> & deleter) const {
67+
Tensor TypeDefault::tensorFromBlob(void * data, IntList sizes, const std::function<void(void*)> & deleter) const {
6868
return tensorFromBlob(data, sizes, defaultStrides(sizes), deleter);
6969
}
70-
Tensor Type::tensorFromBlob(void * data, IntList sizes, IntList strides, const std::function<void(void*)> & deleter) const {
70+
Tensor TypeDefault::tensorFromBlob(void * data, IntList sizes, IntList strides, const std::function<void(void*)> & deleter) const {
7171
auto storage = storageFromBlob(data, computeStorageSize(sizes, strides), deleter);
7272
return tensor(storage, 0, sizes, strides);
7373
}
74-
Tensor Type::tensorWithAllocator(IntList sizes, Allocator* allocator) const {
74+
Tensor TypeDefault::tensorWithAllocator(IntList sizes, Allocator* allocator) const {
7575
return tensorWithAllocator(sizes, defaultStrides(sizes), std::move(allocator));
7676
}
77-
Tensor Type::tensorWithAllocator(IntList sizes, IntList strides, Allocator* allocator) const {
77+
Tensor TypeDefault::tensorWithAllocator(IntList sizes, IntList strides, Allocator* allocator) const {
7878
auto storage = storageWithAllocator(computeStorageSize(sizes, strides), std::move(allocator));
7979
return tensor(storage, 0, sizes, strides);
8080
}
81-
Tensor Type::scalarTensor(Scalar s) const {
81+
Tensor TypeDefault::scalarTensor(Scalar s) const {
8282
return tensor({}).fill_(s);
8383
}
8484

85-
bool Type::operator==(const Type& other) const {
86-
return this == &other;
87-
}
88-
bool Type::operator!=(const Type& other) const {
89-
return this != &other;
90-
}
91-
9285
${type_method_definitions}
9386

9487
}

aten/src/ATen/templates/TypeDefault.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#pragma once
2+
3+
// ${generated_comment}
4+
5+
#include "ATen/Type.h"
6+
7+
namespace at {
8+
9+
struct AT_API TypeDefault : public Type {
10+
explicit TypeDefault(TensorTypeId type_id, bool is_variable, bool is_undefined)
11+
: Type(type_id, is_variable, is_undefined) {}
12+
13+
// Make sure overload resolution considers the nullary virtual method.
14+
// (A single argument overload is generated in the list.)
15+
bool is_cuda() const override = 0;
16+
bool is_sparse() const override = 0;
17+
bool is_distributed() const override = 0;
18+
19+
Type & toBackend(Backend b) const override;
20+
Type & toScalarType(ScalarType s) const override;
21+
22+
Tensor copy(const Tensor & src, bool non_blocking=false) const override;
23+
Tensor & copy_(Tensor & self, const Tensor & src, bool non_blocking=false) const override;
24+
25+
Tensor tensorFromBlob(void * data, IntList sizes, const std::function<void(void*)> & deleter=noop_deleter) const override;
26+
Tensor tensorFromBlob(void * data, IntList sizes, IntList strides, const std::function<void(void*)> & deleter=noop_deleter) const override;
27+
Tensor tensorWithAllocator(IntList sizes, Allocator* allocator) const override;
28+
Tensor tensorWithAllocator(IntList sizes, IntList strides, Allocator* allocator) const override;
29+
Tensor scalarTensor(Scalar s) const override;
30+
31+
// example
32+
// virtual Tensor * add(Tensor & a, Tensor & b) = 0;
33+
${type_method_declarations}
34+
};
35+
36+
} // namespace at

aten/src/ATen/templates/TypeDerived.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ static int getPointerDevice(void* ptr) {
3939
#endif
4040

4141
${Type}::${Type}()
42-
: Type(${Backend}TensorId(), /*is_variable=*/false, /*is_undefined=*/false) {}
42+
: TypeDefault(${Backend}TensorId(), /*is_variable=*/false, /*is_undefined=*/false) {}
4343
ScalarType ${Type}::scalarType() const {
4444
return ScalarType::${ScalarName};
4545
}

aten/src/ATen/templates/TypeDerived.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
// ${generated_comment}
44

5-
#include "ATen/Type.h"
5+
#include "ATen/TypeDefault.h"
66
#include "ATen/Context.h"
77
#include "ATen/TensorMethods.h"
88
#include "ATen/CheckGenerator.h"
@@ -15,7 +15,7 @@
1515

1616
namespace at {
1717

18-
struct ${Type} final : public Type {
18+
struct ${Type} final : public TypeDefault {
1919
explicit ${Type}();
2020
virtual ScalarType scalarType() const override;
2121
virtual Backend backend() const override;

0 commit comments

Comments
 (0)