Skip to content

Commit 20ed344

Browse files
zdevitoElias Ellison
authored and
Elias Ellison
committed
Allow autograd to work even when the shape of values cannot be determined (pytorch#8641)
This commit implements the solution proposed in pytorch#8410 to work around the need to create zero tensors with the same shape as inputs. It introduces the concept of a LinearBlock, which marks places in the code where we know that if all the inputs to the node are zero, then the outputs of the node are also zero. Autodiff introduces LinearBlocks around backward functions, which have this property. specializeUndef then propagates Undef nodes using this information. Notes: (1) Since we do not always specialize, we have a pass, LowerLinearBlocks, that replaces the block with an if statement that dynamically guards the Undef case. (2) We introduce AutogradAdd, which is addition that still works when its inputs might be undefined. In cases where we specialize, this will get removed in favor of a normal add, but there are cases where gradient graphs do not specialize (e.g. when they are not differentiable, but a derivative is required), so it is important for this op to be executable.
1 parent 9a53368 commit 20ed344

File tree

1 file changed

+113
-64
lines changed

1 file changed

+113
-64
lines changed

torch/csrc/jit/interned_strings.h

Lines changed: 113 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,27 @@
99

1010
namespace torch { namespace jit {
1111

12-
#define FORALL_NS_SYMBOLS(_) \
13-
_(namespaces, prim) \
14-
_(namespaces, aten) \
15-
_(namespaces, onnx) \
16-
_(namespaces, attr) \
17-
_(namespaces, scope) \
18-
_(namespaces, namespaces) \
12+
// Every symbol is classified in a namespace, specifying what kind of symbol it
13+
// is. Unsigned char to ensure widening to unique_t (also an unsigned type)
14+
enum class SymbolNamespace : unsigned char {
15+
onnx = 'o',
16+
prim = 'p',
17+
aten = 't',
18+
// NB: ONNX and ATen attributes all live in a unified namespace, as
19+
// their interpretation depends on the operator name (which is namespaced)
20+
attr = 'a',
21+
// TODO: eliminate me
22+
scope = 's'
23+
};
24+
25+
// Primitive symbols are synthetic operators that occur only in the IR
26+
// and don't have corresponding implementations in ATen.
27+
//
28+
// TODO: We need documentation for all of these symbols.
29+
//
30+
// TODO: Consider moving the synthetic onnx operators to their own
31+
// namespace.
32+
#define FORALL_PRIM_SYMBOLS(_) \
1933
_(prim, Assign) \
2034
_(prim, Constant) \
2135
_(prim, CppOp) \
@@ -48,9 +62,32 @@ _(prim, NumToTensor) \
4862
_(prim, TensorToNum) \
4963
_(prim, AutogradAdd) \
5064
_(prim, GradOf) \
51-
_(prim, AnyDefined) \
65+
_(prim, AnyDefined)
66+
/* end */
67+
68+
// Workaround for some not-yet-defined ATen symbols, see
69+
// - __not__: https://github.com/pytorch/pytorch/issues/5495
70+
// - ones, zeros: https://github.com/pytorch/pytorch/issues/5496
71+
72+
#define FORALL_ATEN_EXTRA_SYMBOLS(_) \
5273
_(aten, __not__) \
74+
/* end */
75+
76+
#define FORALL_ATEN_SYMBOLS(_) \
5377
FORALL_ATEN_BASE_SYMBOLS(_) \
78+
FORALL_ATEN_EXTRA_SYMBOLS(_)
79+
80+
// These symbols correspond to ONNX operators. Their semantics
81+
// are defined in https://github.com/onnx/onnx/blob/master/docs/Operators.md
82+
// The particular version we are targeting is specified by '_onnx_opset_version'
83+
// in torch.onnx.symbolic
84+
//
85+
// In general, most ONNX operators won't get an entry here, because they
86+
// are handled from the Python end. However, you may occasionally need
87+
// to intern an ONNX symbol here so that you can conveniently write an
88+
// optimization on ONNX operations.
89+
90+
#define FORALL_ONNX_SYMBOLS(_) \
5491
_(onnx, Add) \
5592
_(onnx, Concat) \
5693
_(onnx, Constant) \
@@ -72,8 +109,15 @@ _(onnx, Transpose) \
72109
_(onnx, Unsqueeze) \
73110
_(onnx, Loop) \
74111
_(onnx, If) \
75-
_(onnx, Reshape) \
76-
FORALL_ATTR_BASE_SYMBOLS(_) \
112+
_(onnx, Reshape)
113+
/* end */
114+
115+
// These symbols are attribute keys. They are shared between both ONNX and ATen
116+
// operators (you disambiguate their meaning by looking at the operator itself).
117+
// In general, you only need to define attribute keys that are used by
118+
// onnx or prim; ATen attributes are automatically generated in FORALL_ATTR_BASE_SYMBOLS.
119+
120+
#define FORALL_ATTR_EXTRA_SYMBOLS(_) \
77121
_(attr, Subgraph) \
78122
_(attr, axes) \
79123
_(attr, axis) \
@@ -89,33 +133,40 @@ _(attr, starts) \
89133
_(attr, transA) \
90134
_(attr, transB) \
91135
_(attr, name)
136+
/* end */
92137

93-
// 'prim' symbols are synthetic operators that occur only in the IR
94-
// and don't have corresponding implementations in ATen.
95-
96-
// 'onnx' symbols correspond to ONNX operators. Their semantics
97-
// are defined in https://github.com/onnx/onnx/blob/master/docs/Operators.md
98-
// The particular version we are targeting is specified by '_onnx_opset_version'
99-
// in torch.onnx.symbolic
100-
//
101-
// In general, most ONNX operators won't get an entry here, because they
102-
// are handled from the Python end. However, you may occasionally need
103-
// to intern an ONNX symbol here so that you can conveniently write an
104-
// optimization on ONNX operations.
138+
#define FORALL_ATTR_SYMBOLS(_) \
139+
FORALL_ATTR_BASE_SYMBOLS(_) \
140+
FORALL_ATTR_EXTRA_SYMBOLS(_)
105141

106-
// 'attr' symbols are attribute keys. They are shared between both ONNX and ATen
107-
// operators (you disambiguate their meaning by looking at the operator itself).
108-
// In general, you only need to define attribute keys that are used by
109-
// onnx or prim; ATen attributes are automatically generated in FORALL_ATTR_BASE_SYMBOLS.
142+
#define FORALL_BUILTIN_SYMBOLS(_) \
143+
FORALL_ONNX_SYMBOLS(_) \
144+
FORALL_ATEN_SYMBOLS(_) \
145+
FORALL_ATTR_SYMBOLS(_) \
146+
FORALL_PRIM_SYMBOLS(_) \
147+
/* end */
110148

111149
// Note [Symbol allocation]
112150
// ~~~~~~~~~~~~~~~~~~~~~~~~
113151
//
114-
// 1. Symbol namespace is split up into namespaces.
152+
// 1. Symbol namespace is split up into namespaces. The hex structure
153+
// of our symbols is TTUUUUUU, where TT is the tag byte and U are the unique
154+
// bytes.
115155
//
116-
// 2. The intended access pattern for built-in symbols is onnx::MatMul
156+
// 2. We only maintain a single counter for the unique bytes, which means that
157+
// we take 256 times more space than we would have if we maintained multiple
158+
// counters.
159+
//
160+
// 3. The first unique_start symbols are reserved for "built-in" symbols.
161+
// These symbols are allocated at compile time and get put into the intern
162+
// table at process startup time. Since it's pretty easy to maintain a
163+
// distinct counter for every built-in namespace, we let the unique bytes of
164+
// built-in symbols overlap (this is why unique_start is a max)
165+
//
166+
// 4. The intended access pattern for built-in symbols is onnx::MatMul
117167
// in the torch::jit namespace (this is a Symbol).
118168
//
169+
// The code here is not very economical but it gets the job done.
119170

120171

121172
// Built-in constant definition strategy:
@@ -129,15 +180,19 @@ _(attr, name)
129180

130181
typedef uint32_t unique_t;
131182

183+
constexpr size_t unique_tag_bits = 8;
184+
constexpr size_t unique_bits = sizeof(unique_t) * 8 - unique_tag_bits;
185+
constexpr unique_t unique_mask = (1ULL << unique_bits) - 1;
186+
132187
static const std::string domain_prefix = "org.pytorch.";
133188

134189
// A Symbol is like an interned string, but with a little extra
135190
// structure; it is namespaced via SymbolNamespace and the resulting
136191
// intern pointers support efficient namespace testing.
137192
struct Symbol {
138193
explicit constexpr Symbol() : value(0) {};
139-
explicit constexpr Symbol(unique_t uniq)
140-
: value(uniq) {}
194+
explicit constexpr Symbol(SymbolNamespace ns, uint32_t uniq)
195+
: value((static_cast<uint32_t>(ns) << unique_bits) | (uniq & unique_mask)) {};
141196

142197
// Get a Symbol for a qualified string like "attr::bar"
143198
static Symbol fromQualString(const std::string & s);
@@ -150,24 +205,26 @@ struct Symbol {
150205
// argument "foo", and then attempt to intern it. DO NOT USE THIS
151206
// with a string literal; attr::foo should be available in that case
152207
// (and if it's not, you should add it to the built-ins list above.)
153-
static Symbol attr(const std::string & s);
154-
static Symbol aten(const std::string & s);
155-
static Symbol onnx(const std::string & s);
156-
static Symbol prim(const std::string & s);
208+
static Symbol attr(const std::string & s) { return Symbol(SymbolNamespace::attr, s); };
209+
static Symbol aten(const std::string & s) { return Symbol(SymbolNamespace::aten, s); };
210+
static Symbol onnx(const std::string & s) { return Symbol(SymbolNamespace::onnx, s); };
211+
static Symbol prim(const std::string & s) { return Symbol(SymbolNamespace::prim, s); };
157212
// TODO: eliminate me
158-
static Symbol scope(const std::string & s);
213+
static Symbol scope(const std::string & s) { return Symbol(SymbolNamespace::scope, s); };
159214

160-
bool is_attr() const;
161-
bool is_aten() const;
162-
bool is_prim() const;
163-
bool is_onnx() const;
215+
constexpr bool is_attr() const { return ns() == SymbolNamespace::attr; };
216+
constexpr bool is_aten() const { return ns() == SymbolNamespace::aten; };
217+
constexpr bool is_prim() const { return ns() == SymbolNamespace::prim; };
218+
constexpr bool is_onnx() const { return ns() == SymbolNamespace::onnx; };
164219

165220
// So we can switch on this
166221
constexpr operator unique_t() const {
167222
return value;
168223
}
169224

170-
Symbol ns() const;
225+
constexpr SymbolNamespace ns() const {
226+
return static_cast<SymbolNamespace>(value >> unique_bits);
227+
}
171228

172229
// Give a string corresponding to the unqualified version of this name, e.g.,
173230
// "mm". Use this in a context where the intended namespace of the string is
@@ -189,41 +246,33 @@ struct Symbol {
189246
std::string domainString() const;
190247

191248
private:
192-
explicit Symbol(Symbol ns, const std::string & s);
249+
explicit Symbol(SymbolNamespace ns, const std::string & s);
193250
unique_t value;
194251
};
195252

196253
static inline bool operator==(Symbol lhs, Symbol rhs) {
197254
return static_cast<unique_t>(lhs) == static_cast<unique_t>(rhs);
198255
}
199256

200-
enum class _keys : unique_t {
201-
#define DEFINE_KEY(ns, s) ns##_##s,
202-
FORALL_NS_SYMBOLS(DEFINE_KEY)
203-
#undef DEFINE_KEY
204-
num_symbols
205-
};
206-
207-
#define DEFINE_SYMBOL(s) \
208-
constexpr Symbol s(static_cast<unique_t>(_keys::s));
257+
#define DEFINE_KEY(ns, s) s,
258+
#define DEFINE_SYMBOL(ns, s) constexpr Symbol s(SymbolNamespace::ns, static_cast<unique_t>(_keys::s));
259+
#define DEFINE_BUILTINS(ns, forall_symbols) \
260+
namespace ns { \
261+
enum class _keys : unique_t { \
262+
forall_symbols(DEFINE_KEY) \
263+
num_symbols \
264+
}; \
265+
forall_symbols(DEFINE_SYMBOL) \
266+
}
209267

210-
#undef DEFINE_SYMBOL
268+
DEFINE_BUILTINS(onnx, FORALL_ONNX_SYMBOLS)
269+
DEFINE_BUILTINS(aten, FORALL_ATEN_SYMBOLS)
270+
DEFINE_BUILTINS(attr, FORALL_ATTR_SYMBOLS)
271+
DEFINE_BUILTINS(prim, FORALL_PRIM_SYMBOLS)
211272

212-
#define DEFINE_SYMBOL(ns, s) \
213-
namespace ns { constexpr Symbol s(static_cast<unique_t>(_keys::ns##_##s)); }
214-
FORALL_NS_SYMBOLS(DEFINE_SYMBOL)
273+
#undef DEFINE_KEY
215274
#undef DEFINE_SYMBOL
216275

217-
inline Symbol Symbol::attr(const std::string & s) { return Symbol::fromQualString("attr::" + s); }
218-
inline Symbol Symbol::aten(const std::string & s) { return Symbol::fromQualString("aten::" + s); }
219-
inline Symbol Symbol::onnx(const std::string & s) { return Symbol::fromQualString("onnx::" + s); }
220-
inline Symbol Symbol::prim(const std::string & s) { return Symbol::fromQualString("prim::" + s); }
221-
inline Symbol Symbol::scope(const std::string & s) { return Symbol::fromQualString("scope::" + s); }
222-
inline bool Symbol::is_attr() const { return ns() == namespaces::attr; }
223-
inline bool Symbol::is_aten() const { return ns() == namespaces::aten; }
224-
inline bool Symbol::is_prim() const { return ns() == namespaces::prim; }
225-
inline bool Symbol::is_onnx() const { return ns() == namespaces::onnx; }
226-
227276
}} // namespace torch::jit
228277

229278
// make symbol behave like an integer in hash tables

0 commit comments

Comments
 (0)