|
1 | 1 | #pragma once
|
2 |
| - |
3 |
| -/** |
4 |
| - * Simple registry implementation that uses static variables to |
5 |
| - * register object creators during program initialization time. |
6 |
| - */ |
7 |
| - |
8 |
| -// NB: This Registry works poorly when you have other namespaces. |
9 |
| -// Make all macro invocations from inside the at namespace. |
10 |
| - |
11 |
| -#include <algorithm> |
12 |
| -#include <cstdio> |
13 |
| -#include <cstdlib> |
14 |
| -#include <functional> |
15 |
| -#include <memory> |
16 |
| -#include <mutex> |
17 |
| -#include <unordered_map> |
18 |
| -#include <string> |
19 |
| -#include <vector> |
20 |
| - |
21 |
| -#include <ATen/core/ATenGeneral.h> |
22 |
| -#include <ATen/core/Backtrace.h> |
23 |
| - |
24 |
| -namespace at { |
25 |
| - |
26 |
| -template <typename KeyType> |
27 |
| -inline void PrintOffendingKey(const KeyType& /*key*/) { |
28 |
| - printf("[key type printing not supported]\n"); |
29 |
| -} |
30 |
| - |
31 |
| -template <> |
32 |
| -inline void PrintOffendingKey(const std::string& key) { |
33 |
| - printf("Offending key: %s.\n", key.c_str()); |
34 |
| -} |
35 |
| - |
36 |
| -/** |
37 |
| - * @brief A template class that allows one to register classes by keys. |
38 |
| - * |
39 |
| - * The keys are usually a std::string specifying the name, but can be anything that |
40 |
| - * can be used in a std::map. |
41 |
| - * |
42 |
| - * You should most likely not use the Registry class explicitly, but use the |
43 |
| - * helper macros below to declare specific registries as well as registering |
44 |
| - * objects. |
45 |
| - */ |
46 |
| -template <class SrcType, class ObjectPtrType, class... Args> |
47 |
| -class AT_API Registry { |
48 |
| - public: |
49 |
| - typedef std::function<ObjectPtrType(Args...)> Creator; |
50 |
| - |
51 |
| - Registry() : registry_() {} |
52 |
| - |
53 |
| - void Register(const SrcType& key, Creator creator) { |
54 |
| - // The if statement below is essentially the same as the following line: |
55 |
| - // CHECK_EQ(registry_.count(key), 0) << "Key " << key |
56 |
| - // << " registered twice."; |
57 |
| - // However, CHECK_EQ depends on google logging, and since registration is |
58 |
| - // carried out at static initialization time, we do not want to have an |
59 |
| - // explicit dependency on glog's initialization function. |
60 |
| - std::lock_guard<std::mutex> lock(register_mutex_); |
61 |
| - if (registry_.count(key) != 0) { |
62 |
| - printf("Key already registered.\n"); |
63 |
| - PrintOffendingKey(key); |
64 |
| - std::exit(1); |
65 |
| - } |
66 |
| - registry_[key] = creator; |
67 |
| - } |
68 |
| - |
69 |
| - void Register(const SrcType& key, Creator creator, const std::string& help_msg) { |
70 |
| - Register(key, creator); |
71 |
| - help_message_[key] = help_msg; |
72 |
| - } |
73 |
| - |
74 |
| - inline bool Has(const SrcType& key) { return (registry_.count(key) != 0); } |
75 |
| - |
76 |
| - ObjectPtrType Create(const SrcType& key, Args... args) { |
77 |
| - if (registry_.count(key) == 0) { |
78 |
| - // Returns nullptr if the key is not registered. |
79 |
| - return nullptr; |
80 |
| - } |
81 |
| - return registry_[key](args...); |
82 |
| - } |
83 |
| - |
84 |
| - /** |
85 |
| - * Returns the keys currently registered as a std::vector. |
86 |
| - */ |
87 |
| - std::vector<SrcType> Keys() { |
88 |
| - std::vector<SrcType> keys; |
89 |
| - for (const auto& it : registry_) { |
90 |
| - keys.push_back(it.first); |
91 |
| - } |
92 |
| - return keys; |
93 |
| - } |
94 |
| - |
95 |
| - const std::unordered_map<SrcType, std::string>& HelpMessage() const { |
96 |
| - return help_message_; |
97 |
| - } |
98 |
| - |
99 |
| - const char* HelpMessage(const SrcType& key) const { |
100 |
| - auto it = help_message_.find(key); |
101 |
| - if (it == help_message_.end()) { |
102 |
| - return nullptr; |
103 |
| - } |
104 |
| - return it->second.c_str(); |
105 |
| - } |
106 |
| - |
107 |
| - private: |
108 |
| - std::unordered_map<SrcType, Creator> registry_; |
109 |
| - std::unordered_map<SrcType, std::string> help_message_; |
110 |
| - std::mutex register_mutex_; |
111 |
| - |
112 |
| - Registry(const Registry&) = delete; |
113 |
| - Registry& operator=(const Registry&) = delete; |
114 |
| -}; |
115 |
| - |
116 |
| -template <class SrcType, class ObjectPtrType, class... Args> |
117 |
| -class AT_API Registerer { |
118 |
| - public: |
119 |
| - Registerer( |
120 |
| - const SrcType& key, |
121 |
| - Registry<SrcType, ObjectPtrType, Args...>* registry, |
122 |
| - typename Registry<SrcType, ObjectPtrType, Args...>::Creator creator, |
123 |
| - const std::string& help_msg = "") { |
124 |
| - registry->Register(key, creator, help_msg); |
125 |
| - } |
126 |
| - |
127 |
| - template <class DerivedType> |
128 |
| - static ObjectPtrType DefaultCreator(Args... args) { |
129 |
| - // TODO(jiayq): old versions of NVCC does not handle make_unique well |
130 |
| - // so we are forced to use a unique_ptr constructor here. Check if it is |
131 |
| - // fine to use make_unique in the future. |
132 |
| - // return make_unique<DerivedType>(args...); |
133 |
| - return ObjectPtrType(new DerivedType(args...)); |
134 |
| - } |
135 |
| -}; |
136 |
| - |
137 |
| -/** |
138 |
| - * AT_ANONYMOUS_VARIABLE(str) introduces an identifier starting with |
139 |
| - * str and ending with a number that varies with the line. |
140 |
| - * Pretty much a copy from 'folly/Preprocessor.h' |
141 |
| - */ |
142 |
| -#define AT_CONCATENATE_IMPL(s1, s2) s1##s2 |
143 |
| -#define AT_CONCATENATE(s1, s2) AT_CONCATENATE_IMPL(s1, s2) |
144 |
| -#ifdef __COUNTER__ |
145 |
| -#define AT_ANONYMOUS_VARIABLE(str) AT_CONCATENATE(str, __COUNTER__) |
146 |
| -#else |
147 |
| -#define AT_ANONYMOUS_VARIABLE(str) AT_CONCATENATE(str, __LINE__) |
148 |
| -#endif |
149 |
| - |
150 |
| -/** |
151 |
| - * AT_DECLARE_TYPED_REGISTRY is a macro that expands to a function |
152 |
| - * declaration, as well as creating a convenient typename for its corresponding |
153 |
| - * registerer. |
154 |
| - */ |
155 |
| -#define AT_DECLARE_TYPED_REGISTRY( \ |
156 |
| - RegistryName, SrcType, ObjectType, PtrType, ...) \ |
157 |
| - AT_API Registry<SrcType, PtrType<ObjectType>, __VA_ARGS__>* RegistryName(); \ |
158 |
| - typedef Registerer<SrcType, PtrType<ObjectType>, __VA_ARGS__> \ |
159 |
| - Registerer##RegistryName; \ |
160 |
| - extern template class Registerer<SrcType, PtrType<ObjectType>, __VA_ARGS__>; |
161 |
| - |
162 |
| -#define AT_DEFINE_TYPED_REGISTRY( \ |
163 |
| - RegistryName, SrcType, ObjectType, PtrType, ...) \ |
164 |
| - Registry<SrcType, PtrType<ObjectType>, __VA_ARGS__>* RegistryName() { \ |
165 |
| - static Registry<SrcType, PtrType<ObjectType>, __VA_ARGS__>* registry = \ |
166 |
| - new Registry<SrcType, PtrType<ObjectType>, __VA_ARGS__>(); \ |
167 |
| - return registry; \ |
168 |
| - } \ |
169 |
| - template class Registerer<SrcType, PtrType<ObjectType>, __VA_ARGS__>; |
170 |
| - |
171 |
| -// Note(Yangqing): The __VA_ARGS__ below allows one to specify a templated |
172 |
| -// creator with comma in its templated arguments. |
173 |
| -#define AT_REGISTER_TYPED_CREATOR(RegistryName, key, ...) \ |
174 |
| - namespace { \ |
175 |
| - Registerer##RegistryName AT_ANONYMOUS_VARIABLE(g_##RegistryName)( \ |
176 |
| - key, RegistryName(), __VA_ARGS__); \ |
177 |
| - } |
178 |
| - |
179 |
| -#define AT_REGISTER_TYPED_CLASS(RegistryName, key, ...) \ |
180 |
| - namespace { \ |
181 |
| - Registerer##RegistryName AT_ANONYMOUS_VARIABLE(g_##RegistryName)( \ |
182 |
| - key, \ |
183 |
| - RegistryName(), \ |
184 |
| - Registerer##RegistryName::DefaultCreator<__VA_ARGS__>, \ |
185 |
| - ::at::demangle_type<__VA_ARGS__>()); \ |
186 |
| - } |
187 |
| - |
188 |
| -// AT_DECLARE_REGISTRY and AT_DEFINE_REGISTRY are hard-wired to use std::string |
189 |
| -// as the key |
190 |
| -// type, because that is the most commonly used cases. |
191 |
| -#define AT_DECLARE_REGISTRY(RegistryName, ObjectType, ...) \ |
192 |
| - AT_DECLARE_TYPED_REGISTRY( \ |
193 |
| - RegistryName, std::string, ObjectType, std::unique_ptr, __VA_ARGS__) |
194 |
| - |
195 |
| -#define AT_DEFINE_REGISTRY(RegistryName, ObjectType, ...) \ |
196 |
| - AT_DEFINE_TYPED_REGISTRY( \ |
197 |
| - RegistryName, std::string, ObjectType, std::unique_ptr, __VA_ARGS__) |
198 |
| - |
199 |
| -#define AT_DECLARE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \ |
200 |
| - AT_DECLARE_TYPED_REGISTRY( \ |
201 |
| - RegistryName, std::string, ObjectType, std::shared_ptr, __VA_ARGS__) |
202 |
| - |
203 |
| -#define AT_DEFINE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \ |
204 |
| - AT_DEFINE_TYPED_REGISTRY( \ |
205 |
| - RegistryName, std::string, ObjectType, std::shared_ptr, __VA_ARGS__) |
206 |
| - |
207 |
| -// AT_REGISTER_CREATOR and AT_REGISTER_CLASS are hard-wired to use std::string |
208 |
| -// as the key |
209 |
| -// type, because that is the most commonly used cases. |
210 |
| -#define AT_REGISTER_CREATOR(RegistryName, key, ...) \ |
211 |
| - AT_REGISTER_TYPED_CREATOR(RegistryName, #key, __VA_ARGS__) |
212 |
| - |
213 |
| -#define AT_REGISTER_CLASS(RegistryName, key, ...) \ |
214 |
| - AT_REGISTER_TYPED_CLASS(RegistryName, #key, __VA_ARGS__) |
215 |
| - |
216 |
| -} // namespace at |
| 2 | +#include <ATen/core/Registry.h> |
0 commit comments