Skip to content

Commit 1e6a7cc

Browse files
ezyangPenghuiCheng
authored andcommitted
Move ATen/Registry.h to ATen/core/Registry.h (pytorch#11270)
Summary: Pull Request resolved: pytorch#11270 Still need to deduplicate this with caffe2/core/registry.h, but this will be a bit tricky because the current formulation of the macro is namespace sensitive (i.e., the macro for classes defined in at:: namespace won't work if you call from caffe2:: namespace). Reviewed By: gchanan Differential Revision: D9654871 fbshipit-source-id: 2207d1f2cc6d50bd41bf64ce0eb0b8523b05d9d9
1 parent 4ff98c2 commit 1e6a7cc

File tree

4 files changed

+219
-217
lines changed

4 files changed

+219
-217
lines changed

aten/src/ATen/Registry.h

Lines changed: 1 addition & 215 deletions
Original file line numberDiff line numberDiff line change
@@ -1,216 +1,2 @@
11
#pragma once
2-
3-
/**
4-
* Simple registry implementation that uses static variables to
5-
* register object creators during program initialization time.
6-
*/
7-
8-
// NB: This Registry works poorly when you have other namespaces.
9-
// Make all macro invocations from inside the at namespace.
10-
11-
#include <algorithm>
12-
#include <cstdio>
13-
#include <cstdlib>
14-
#include <functional>
15-
#include <memory>
16-
#include <mutex>
17-
#include <unordered_map>
18-
#include <string>
19-
#include <vector>
20-
21-
#include <ATen/core/ATenGeneral.h>
22-
#include <ATen/core/Backtrace.h>
23-
24-
namespace at {
25-
26-
template <typename KeyType>
27-
inline void PrintOffendingKey(const KeyType& /*key*/) {
28-
printf("[key type printing not supported]\n");
29-
}
30-
31-
template <>
32-
inline void PrintOffendingKey(const std::string& key) {
33-
printf("Offending key: %s.\n", key.c_str());
34-
}
35-
36-
/**
37-
* @brief A template class that allows one to register classes by keys.
38-
*
39-
* The keys are usually a std::string specifying the name, but can be anything that
40-
* can be used in a std::map.
41-
*
42-
* You should most likely not use the Registry class explicitly, but use the
43-
* helper macros below to declare specific registries as well as registering
44-
* objects.
45-
*/
46-
template <class SrcType, class ObjectPtrType, class... Args>
47-
class AT_API Registry {
48-
public:
49-
typedef std::function<ObjectPtrType(Args...)> Creator;
50-
51-
Registry() : registry_() {}
52-
53-
void Register(const SrcType& key, Creator creator) {
54-
// The if statement below is essentially the same as the following line:
55-
// CHECK_EQ(registry_.count(key), 0) << "Key " << key
56-
// << " registered twice.";
57-
// However, CHECK_EQ depends on google logging, and since registration is
58-
// carried out at static initialization time, we do not want to have an
59-
// explicit dependency on glog's initialization function.
60-
std::lock_guard<std::mutex> lock(register_mutex_);
61-
if (registry_.count(key) != 0) {
62-
printf("Key already registered.\n");
63-
PrintOffendingKey(key);
64-
std::exit(1);
65-
}
66-
registry_[key] = creator;
67-
}
68-
69-
void Register(const SrcType& key, Creator creator, const std::string& help_msg) {
70-
Register(key, creator);
71-
help_message_[key] = help_msg;
72-
}
73-
74-
inline bool Has(const SrcType& key) { return (registry_.count(key) != 0); }
75-
76-
ObjectPtrType Create(const SrcType& key, Args... args) {
77-
if (registry_.count(key) == 0) {
78-
// Returns nullptr if the key is not registered.
79-
return nullptr;
80-
}
81-
return registry_[key](args...);
82-
}
83-
84-
/**
85-
* Returns the keys currently registered as a std::vector.
86-
*/
87-
std::vector<SrcType> Keys() {
88-
std::vector<SrcType> keys;
89-
for (const auto& it : registry_) {
90-
keys.push_back(it.first);
91-
}
92-
return keys;
93-
}
94-
95-
const std::unordered_map<SrcType, std::string>& HelpMessage() const {
96-
return help_message_;
97-
}
98-
99-
const char* HelpMessage(const SrcType& key) const {
100-
auto it = help_message_.find(key);
101-
if (it == help_message_.end()) {
102-
return nullptr;
103-
}
104-
return it->second.c_str();
105-
}
106-
107-
private:
108-
std::unordered_map<SrcType, Creator> registry_;
109-
std::unordered_map<SrcType, std::string> help_message_;
110-
std::mutex register_mutex_;
111-
112-
Registry(const Registry&) = delete;
113-
Registry& operator=(const Registry&) = delete;
114-
};
115-
116-
template <class SrcType, class ObjectPtrType, class... Args>
117-
class AT_API Registerer {
118-
public:
119-
Registerer(
120-
const SrcType& key,
121-
Registry<SrcType, ObjectPtrType, Args...>* registry,
122-
typename Registry<SrcType, ObjectPtrType, Args...>::Creator creator,
123-
const std::string& help_msg = "") {
124-
registry->Register(key, creator, help_msg);
125-
}
126-
127-
template <class DerivedType>
128-
static ObjectPtrType DefaultCreator(Args... args) {
129-
// TODO(jiayq): old versions of NVCC does not handle make_unique well
130-
// so we are forced to use a unique_ptr constructor here. Check if it is
131-
// fine to use make_unique in the future.
132-
// return make_unique<DerivedType>(args...);
133-
return ObjectPtrType(new DerivedType(args...));
134-
}
135-
};
136-
137-
/**
138-
* AT_ANONYMOUS_VARIABLE(str) introduces an identifier starting with
139-
* str and ending with a number that varies with the line.
140-
* Pretty much a copy from 'folly/Preprocessor.h'
141-
*/
142-
#define AT_CONCATENATE_IMPL(s1, s2) s1##s2
143-
#define AT_CONCATENATE(s1, s2) AT_CONCATENATE_IMPL(s1, s2)
144-
#ifdef __COUNTER__
145-
#define AT_ANONYMOUS_VARIABLE(str) AT_CONCATENATE(str, __COUNTER__)
146-
#else
147-
#define AT_ANONYMOUS_VARIABLE(str) AT_CONCATENATE(str, __LINE__)
148-
#endif
149-
150-
/**
151-
* AT_DECLARE_TYPED_REGISTRY is a macro that expands to a function
152-
* declaration, as well as creating a convenient typename for its corresponding
153-
* registerer.
154-
*/
155-
#define AT_DECLARE_TYPED_REGISTRY( \
156-
RegistryName, SrcType, ObjectType, PtrType, ...) \
157-
AT_API Registry<SrcType, PtrType<ObjectType>, __VA_ARGS__>* RegistryName(); \
158-
typedef Registerer<SrcType, PtrType<ObjectType>, __VA_ARGS__> \
159-
Registerer##RegistryName; \
160-
extern template class Registerer<SrcType, PtrType<ObjectType>, __VA_ARGS__>;
161-
162-
#define AT_DEFINE_TYPED_REGISTRY( \
163-
RegistryName, SrcType, ObjectType, PtrType, ...) \
164-
Registry<SrcType, PtrType<ObjectType>, __VA_ARGS__>* RegistryName() { \
165-
static Registry<SrcType, PtrType<ObjectType>, __VA_ARGS__>* registry = \
166-
new Registry<SrcType, PtrType<ObjectType>, __VA_ARGS__>(); \
167-
return registry; \
168-
} \
169-
template class Registerer<SrcType, PtrType<ObjectType>, __VA_ARGS__>;
170-
171-
// Note(Yangqing): The __VA_ARGS__ below allows one to specify a templated
172-
// creator with comma in its templated arguments.
173-
#define AT_REGISTER_TYPED_CREATOR(RegistryName, key, ...) \
174-
namespace { \
175-
Registerer##RegistryName AT_ANONYMOUS_VARIABLE(g_##RegistryName)( \
176-
key, RegistryName(), __VA_ARGS__); \
177-
}
178-
179-
#define AT_REGISTER_TYPED_CLASS(RegistryName, key, ...) \
180-
namespace { \
181-
Registerer##RegistryName AT_ANONYMOUS_VARIABLE(g_##RegistryName)( \
182-
key, \
183-
RegistryName(), \
184-
Registerer##RegistryName::DefaultCreator<__VA_ARGS__>, \
185-
::at::demangle_type<__VA_ARGS__>()); \
186-
}
187-
188-
// AT_DECLARE_REGISTRY and AT_DEFINE_REGISTRY are hard-wired to use std::string
189-
// as the key
190-
// type, because that is the most commonly used cases.
191-
#define AT_DECLARE_REGISTRY(RegistryName, ObjectType, ...) \
192-
AT_DECLARE_TYPED_REGISTRY( \
193-
RegistryName, std::string, ObjectType, std::unique_ptr, __VA_ARGS__)
194-
195-
#define AT_DEFINE_REGISTRY(RegistryName, ObjectType, ...) \
196-
AT_DEFINE_TYPED_REGISTRY( \
197-
RegistryName, std::string, ObjectType, std::unique_ptr, __VA_ARGS__)
198-
199-
#define AT_DECLARE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \
200-
AT_DECLARE_TYPED_REGISTRY( \
201-
RegistryName, std::string, ObjectType, std::shared_ptr, __VA_ARGS__)
202-
203-
#define AT_DEFINE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \
204-
AT_DEFINE_TYPED_REGISTRY( \
205-
RegistryName, std::string, ObjectType, std::shared_ptr, __VA_ARGS__)
206-
207-
// AT_REGISTER_CREATOR and AT_REGISTER_CLASS are hard-wired to use std::string
208-
// as the key
209-
// type, because that is the most commonly used cases.
210-
#define AT_REGISTER_CREATOR(RegistryName, key, ...) \
211-
AT_REGISTER_TYPED_CREATOR(RegistryName, #key, __VA_ARGS__)
212-
213-
#define AT_REGISTER_CLASS(RegistryName, key, ...) \
214-
AT_REGISTER_TYPED_CLASS(RegistryName, #key, __VA_ARGS__)
215-
216-
} // namespace at
2+
#include <ATen/core/Registry.h>

0 commit comments

Comments
 (0)