9
9
10
10
namespace torch { namespace jit {
11
11
12
- #define FORALL_NS_SYMBOLS (_ ) \
13
- _ (namespaces, prim) \
14
- _ (namespaces, aten) \
15
- _ (namespaces, onnx) \
16
- _ (namespaces, attr) \
17
- _ (namespaces, scope) \
18
- _ (namespaces, namespaces) \
12
+ // Every symbol is classified in a namespace, specifying what kind of symbol it
13
+ // is. The underlying type is unsigned char so that it widens cleanly to unique_t (also an unsigned type).
14
+ enum class SymbolNamespace : unsigned char {
15
+ onnx = ' o' ,
16
+ prim = ' p' ,
17
+ aten = ' t' ,
18
+ // NB: ONNX and ATen attributes all live in a unified namespace, as
19
+ // their interpretation depends on the operator name (which is namespaced)
20
+ attr = ' a' ,
21
+ // TODO: eliminate me
22
+ scope = ' s'
23
+ };
24
+
25
+ // Primitive symbols are synthetic operators that occur only in the IR
26
+ // and don't have corresponding implementations in ATen.
27
+ //
28
+ // TODO: We need documentation for all of these symbols.
29
+ //
30
+ // TODO: Consider moving the synthetic onnx operators to their own
31
+ // namespace.
32
+ #define FORALL_PRIM_SYMBOLS (_ ) \
19
33
_ (prim, Assign) \
20
34
_ (prim, Constant) \
21
35
_ (prim, CppOp) \
@@ -48,9 +62,32 @@ _(prim, NumToTensor) \
48
62
_ (prim, TensorToNum) \
49
63
_ (prim, AutogradAdd) \
50
64
_ (prim, GradOf) \
51
- _ (prim, AnyDefined) \
65
+ _ (prim, AnyDefined)
66
+ /* end */
67
+
68
+ // Workaround for some not-yet-defined ATen symbols, see
69
+ // - __not__: https://github.com/pytorch/pytorch/issues/5495
70
+ // - ones, zeros: https://github.com/pytorch/pytorch/issues/5496
71
+
72
+ #define FORALL_ATEN_EXTRA_SYMBOLS (_ ) \
52
73
_ (aten, __not__) \
74
+ /* end */
75
+
76
+ #define FORALL_ATEN_SYMBOLS (_ ) \
53
77
FORALL_ATEN_BASE_SYMBOLS (_) \
78
+ FORALL_ATEN_EXTRA_SYMBOLS (_)
79
+
80
+ // These symbols correspond to ONNX operators. Their semantics
81
+ // are defined in https://github.com/onnx/onnx/blob/master/docs/Operators.md
82
+ // The particular version we are targeting is specified by '_onnx_opset_version'
83
+ // in torch.onnx.symbolic
84
+ //
85
+ // In general, most ONNX operators won't get an entry here, because they
86
+ // are handled from the Python end. However, you may occasionally need
87
+ // to intern an ONNX symbol here so that you can conveniently write an
88
+ // optimization on ONNX operations.
89
+
90
+ #define FORALL_ONNX_SYMBOLS (_ ) \
54
91
_ (onnx, Add) \
55
92
_ (onnx, Concat) \
56
93
_ (onnx, Constant) \
@@ -72,8 +109,15 @@ _(onnx, Transpose) \
72
109
_ (onnx, Unsqueeze) \
73
110
_ (onnx, Loop) \
74
111
_ (onnx, If) \
75
- _ (onnx, Reshape) \
76
- FORALL_ATTR_BASE_SYMBOLS (_) \
112
+ _ (onnx, Reshape)
113
+ /* end */
114
+
115
+ // These symbols are attribute keys. They are shared between both ONNX and ATen
116
+ // operators (you disambiguate their meaning by looking at the operator itself).
117
+ // In general, you only need to define attribute keys that are used by
118
+ // onnx or prim; ATen attributes are automatically generated in FORALL_ATTR_BASE_SYMBOLS.
119
+
120
+ #define FORALL_ATTR_EXTRA_SYMBOLS (_ ) \
77
121
_ (attr, Subgraph) \
78
122
_ (attr, axes) \
79
123
_ (attr, axis) \
@@ -89,33 +133,40 @@ _(attr, starts) \
89
133
_ (attr, transA) \
90
134
_ (attr, transB) \
91
135
_ (attr, name)
136
+ /* end */
92
137
93
- // 'prim' symbols are synthetic operators that occur only in the IR
94
- // and don't have corresponding implementations in ATen.
95
-
96
- // 'onnx' symbols correspond to ONNX operators. Their semantics
97
- // are defined in https://github.com/onnx/onnx/blob/master/docs/Operators.md
98
- // The particular version we are targeting is specified by '_onnx_opset_version'
99
- // in torch.onnx.symbolic
100
- //
101
- // In general, most ONNX operators won't get an entry here, because they
102
- // are handled from the Python end. However, you may occasionally need
103
- // to intern an ONNX symbol here so that you can conveniently write an
104
- // optimization on ONNX operations.
138
+ #define FORALL_ATTR_SYMBOLS (_ ) \
139
+ FORALL_ATTR_BASE_SYMBOLS (_) \
140
+ FORALL_ATTR_EXTRA_SYMBOLS (_)
105
141
106
- // 'attr' symbols are attribute keys. They are shared between both ONNX and ATen
107
- // operators (you disambiguate their meaning by looking at the operator itself).
108
- // In general, you only need to define attribute keys that are used by
109
- // onnx or prim; ATen attributes are automatically generated in FORALL_ATTR_BASE_SYMBOLS.
142
+ #define FORALL_BUILTIN_SYMBOLS (_ ) \
143
+ FORALL_ONNX_SYMBOLS (_) \
144
+ FORALL_ATEN_SYMBOLS (_) \
145
+ FORALL_ATTR_SYMBOLS (_) \
146
+ FORALL_PRIM_SYMBOLS (_) \
147
+ /* end */
110
148
111
149
// Note [Symbol allocation]
112
150
// ~~~~~~~~~~~~~~~~~~~~~~~~
113
151
//
114
- // 1. Symbol namespace is split up into namespaces.
152
+ // 1. Symbol namespace is split up into namespaces. The hex structure
153
+ // of our symbols is TTUUUUUU, where TT is the tag byte and UUUUUU are the unique
154
+ // bytes.
115
155
//
116
- // 2. The intended access pattern for built-in symbols is onnx::MatMul
156
+ // 2. We only maintain a single counter for the unique bytes, which means that
157
+ // we take 256x more space than we would have if we maintained multiple
158
+ // counters.
159
+ //
160
+ // 3. The first unique_start symbols are reserved for "built-in" symbols.
161
+ // These symbols are allocated at compile time and get put into the intern
162
+ // table at process startup time. Since it's pretty easy to maintain a
163
+ // distinct counter for every built-in namespace, we let the unique bytes of
164
+ // built-in symbols overlap (this is why unique_start is a max).
165
+ //
166
+ // 4. The intended access pattern for built-in symbols is onnx::MatMul
117
167
// in the torch::jit namespace (this is a Symbol).
118
168
//
169
+ // The code here is not very economical but it gets the job done.
119
170
120
171
121
172
// Built-in constant definition strategy:
@@ -129,15 +180,19 @@ _(attr, name)
129
180
130
181
typedef uint32_t unique_t ;
131
182
183
+ constexpr size_t unique_tag_bits = 8 ;
184
+ constexpr size_t unique_bits = sizeof (unique_t ) * 8 - unique_tag_bits;
185
+ constexpr unique_t unique_mask = (1ULL << unique_bits) - 1 ;
186
+
132
187
static const std::string domain_prefix = " org.pytorch." ;
133
188
134
189
// A Symbol is like an interned string, but with a little extra
135
190
// structure; it is namespaced via SymbolNamespace and the resulting
136
191
// intern pointers support efficient namespace testing.
137
192
struct Symbol {
138
193
explicit constexpr Symbol () : value(0 ) {};
139
- explicit constexpr Symbol (unique_t uniq)
140
- : value(uniq) {}
194
+ explicit constexpr Symbol (SymbolNamespace ns, uint32_t uniq)
195
+ : value(( static_cast < uint32_t >(ns) << unique_bits) | ( uniq & unique_mask)) {};
141
196
142
197
// Get a Symbol for a qualified string like "attr::bar"
143
198
static Symbol fromQualString (const std::string & s);
@@ -150,24 +205,26 @@ struct Symbol {
150
205
// argument "foo", and then attempt to intern it. DO NOT USE THIS
151
206
// with a string literal; attr::foo should be available in that case
152
207
// (and if it's not, you should add it to the built-ins list above.)
153
- static Symbol attr (const std::string & s);
154
- static Symbol aten (const std::string & s);
155
- static Symbol onnx (const std::string & s);
156
- static Symbol prim (const std::string & s);
208
+ static Symbol attr (const std::string & s) { return Symbol (SymbolNamespace::attr, s); } ;
209
+ static Symbol aten (const std::string & s) { return Symbol (SymbolNamespace::aten, s); } ;
210
+ static Symbol onnx (const std::string & s) { return Symbol (SymbolNamespace::onnx, s); } ;
211
+ static Symbol prim (const std::string & s) { return Symbol (SymbolNamespace::prim, s); } ;
157
212
// TODO: eliminate me
158
- static Symbol scope (const std::string & s);
213
+ static Symbol scope (const std::string & s) { return Symbol (SymbolNamespace::scope, s); } ;
159
214
160
- bool is_attr () const ;
161
- bool is_aten () const ;
162
- bool is_prim () const ;
163
- bool is_onnx () const ;
215
+ constexpr bool is_attr () const { return ns () == SymbolNamespace::attr; } ;
216
+ constexpr bool is_aten () const { return ns () == SymbolNamespace::aten; } ;
217
+ constexpr bool is_prim () const { return ns () == SymbolNamespace::prim; } ;
218
+ constexpr bool is_onnx () const { return ns () == SymbolNamespace::onnx; } ;
164
219
165
220
// So we can switch on this
166
221
constexpr operator unique_t () const {
167
222
return value;
168
223
}
169
224
170
- Symbol ns () const ;
225
+ constexpr SymbolNamespace ns () const {
226
+ return static_cast <SymbolNamespace>(value >> unique_bits);
227
+ }
171
228
172
229
// Give a string corresponding to the unqualified version of this name, e.g.,
173
230
// "mm". Use this in a context where the intended namespace of the string is
@@ -189,41 +246,33 @@ struct Symbol {
189
246
std::string domainString () const ;
190
247
191
248
private:
192
- explicit Symbol (Symbol ns, const std::string & s);
249
+ explicit Symbol (SymbolNamespace ns, const std::string & s);
193
250
unique_t value;
194
251
};
195
252
196
253
static inline bool operator ==(Symbol lhs, Symbol rhs) {
197
254
return static_cast <unique_t >(lhs) == static_cast <unique_t >(rhs);
198
255
}
199
256
200
- enum class _keys : unique_t {
201
- #define DEFINE_KEY (ns, s ) ns##_##s,
202
- FORALL_NS_SYMBOLS (DEFINE_KEY)
203
- #undef DEFINE_KEY
204
- num_symbols
205
- };
206
-
207
- #define DEFINE_SYMBOL (s ) \
208
- constexpr Symbol s (static_cast <unique_t >(_keys::s));
257
+ #define DEFINE_KEY (ns, s ) s,
258
+ #define DEFINE_SYMBOL (ns, s ) constexpr Symbol s (SymbolNamespace::ns, static_cast <unique_t >(_keys::s));
259
+ #define DEFINE_BUILTINS (ns, forall_symbols ) \
260
+ namespace ns { \
261
+ enum class _keys : unique_t { \
262
+ forall_symbols (DEFINE_KEY) \
263
+ num_symbols \
264
+ }; \
265
+ forall_symbols (DEFINE_SYMBOL) \
266
+ }
209
267
210
- #undef DEFINE_SYMBOL
268
+ DEFINE_BUILTINS (onnx, FORALL_ONNX_SYMBOLS)
269
+ DEFINE_BUILTINS (aten, FORALL_ATEN_SYMBOLS)
270
+ DEFINE_BUILTINS (attr, FORALL_ATTR_SYMBOLS)
271
+ DEFINE_BUILTINS (prim, FORALL_PRIM_SYMBOLS)
211
272
212
- #define DEFINE_SYMBOL (ns, s ) \
213
- namespace ns { constexpr Symbol s (static_cast <unique_t >(_keys::ns##_##s)); }
214
- FORALL_NS_SYMBOLS (DEFINE_SYMBOL)
273
+ #undef DEFINE_KEY
215
274
#undef DEFINE_SYMBOL
216
275
217
- inline Symbol Symbol::attr (const std::string & s) { return Symbol::fromQualString (" attr::" + s); }
218
- inline Symbol Symbol::aten (const std::string & s) { return Symbol::fromQualString (" aten::" + s); }
219
- inline Symbol Symbol::onnx (const std::string & s) { return Symbol::fromQualString (" onnx::" + s); }
220
- inline Symbol Symbol::prim (const std::string & s) { return Symbol::fromQualString (" prim::" + s); }
221
- inline Symbol Symbol::scope (const std::string & s) { return Symbol::fromQualString (" scope::" + s); }
222
- inline bool Symbol::is_attr () const { return ns () == namespaces::attr; }
223
- inline bool Symbol::is_aten () const { return ns () == namespaces::aten; }
224
- inline bool Symbol::is_prim () const { return ns () == namespaces::prim; }
225
- inline bool Symbol::is_onnx () const { return ns () == namespaces::onnx; }
226
-
227
276
}} // namespace torch::jit
228
277
229
278
// make symbol behave like an integer in hash tables
0 commit comments