Skip to content

Commit 347d7e9

Browse files
eellisonElias Ellison
and
Elias Ellison
authored
Add fix for writing to closures (pytorch#233)
* Add fix for writing to closures * run black * one more time Co-authored-by: Elias Ellison <[email protected]>
1 parent fd0c103 commit 347d7e9

File tree

4 files changed

+82
-11
lines changed

4 files changed

+82
-11
lines changed

tests/test_misc.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -947,6 +947,38 @@ def fn3():
947947
self.assertEqual(cnts.frame_count, 2)
948948
self.assertEqual(cnts.op_count, 11)
949949

950+
def test_write_to_closures_in_inlining(self):
951+
out = []
952+
for use_dynamo in [False, True]:
953+
954+
def make_counter():
955+
x = torch.randn(10)
956+
957+
def counter():
958+
nonlocal x
959+
x = x + 1
960+
return x
961+
962+
return counter
963+
964+
torch.manual_seed(0)
965+
counter = make_counter()
966+
if not use_dynamo:
967+
out.append(counter() + counter())
968+
else:
969+
cnts = torchdynamo.testing.CompileCounter()
970+
971+
@torchdynamo.optimize(cnts, nopython=True)
972+
def fn(counter):
973+
return counter() + counter()
974+
975+
out.append(fn(counter))
976+
self.assertEqual(cnts.frame_count, 1)
977+
self.assertEqual(cnts.op_count, 3)
978+
self.assertFalse(same(counter() + counter(), out[-1]))
979+
980+
self.assertTrue(same(out[0], out[1]))
981+
950982
def test_top_package_import(self):
951983
def fn(x):
952984
import torch.fx

torchdynamo/side_effects.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,14 @@ def track_cell_new(
187187
self.keepalive.append(obj)
188188
return variable
189189

190+
def track_cell_existing(self, source: Source, item: Any):
191+
variable = variables.NewCellVariable(
192+
mutable_local=AttributeMutationExisting(source),
193+
)
194+
self.id_to_variable[id(item)] = variable
195+
self.keepalive.append(item)
196+
return variable
197+
190198
def prune_dead_object_new(self, tx):
191199
live_new_objects = set()
192200
skip_obj = None
@@ -232,13 +240,14 @@ def codegen(self, cg: PyCodegen):
232240
]
233241

234242
for var in modified_vars:
235-
if isinstance(var.mutable_local, AttributeMutationNew) and isinstance(
236-
var, variables.NewCellVariable
237-
):
243+
if isinstance(
244+
var.mutable_local, (AttributeMutationExisting, AttributeMutationNew)
245+
) and isinstance(var, variables.NewCellVariable):
238246
cg.load_import_from(utils.__name__, "make_cell")
239247
cg.extend_output([create_instruction("CALL_FUNCTION", 0)])
240248
cg.add_cache(var)
241-
var.mutable_local.source = LocalSource(cg.tempvars[var])
249+
if isinstance(var.mutable_local, AttributeMutationNew):
250+
var.mutable_local.source = LocalSource(cg.tempvars[var])
242251
elif isinstance(var.mutable_local, AttributeMutationNew):
243252
cg.load_import_from(utils.__name__, "object_new")
244253
cg(var.mutable_local.cls_source)

torchdynamo/symbolic_convert.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1332,7 +1332,15 @@ def STORE_DEREF(self, inst):
13321332
else:
13331333
self.output.side_effects.store_cell(cell, val)
13341334
else:
1335-
unimplemented("write to __closure__ while inlining")
1335+
if isinstance(
1336+
self.symbolic_locals.get(inst.argval),
1337+
torchdynamo.variables.NewCellVariable,
1338+
):
1339+
self.output.side_effects.store_cell(
1340+
self.symbolic_locals[inst.argval], self.pop()
1341+
)
1342+
else:
1343+
unimplemented("write to __closure__ while inlining")
13361344

13371345
def LOAD_DEREF(self, inst):
13381346
if inst.argval in self.closure_cells:
@@ -1342,7 +1350,11 @@ def LOAD_DEREF(self, inst):
13421350
else:
13431351
self.push(self.output.side_effects.load_cell(cell))
13441352
else:
1345-
super().LOAD_DEREF(inst)
1353+
maybe_sym_local = self.symbolic_locals.get(inst.argval, None)
1354+
if isinstance(maybe_sym_local, torchdynamo.variables.NewCellVariable):
1355+
self.push(self.output.side_effects.load_cell(maybe_sym_local))
1356+
else:
1357+
super().LOAD_DEREF(inst)
13461358

13471359
def LOAD_CLOSURE(self, inst):
13481360
assert inst.argval in self.cell_and_freevars()

torchdynamo/variables/functions.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -140,11 +140,29 @@ def bind_args(self, parent, args, kwargs):
140140
elif self.source:
141141
from .builder import VariableBuilder
142142

143-
source = AttrSource(
144-
GetItemSource(AttrSource(self.source, "__closure__"), idx),
145-
"cell_contents",
146-
)
147-
result[name] = VariableBuilder(parent, source)(cell.cell_contents)
143+
side_effects = parent.output.side_effects
144+
if cell in side_effects:
145+
out = side_effects[cell]
146+
else:
147+
closure_cell = GetItemSource(
148+
AttrSource(self.source, "__closure__"), idx
149+
)
150+
closure_cell_contents = AttrSource(
151+
closure_cell, "cell_contents"
152+
)
153+
154+
# cells are written to with "cell_contents",
155+
# so the source should just be the closure_cell, not its contents
156+
out = side_effects.track_cell_existing(closure_cell, cell)
157+
side_effects.store_cell(
158+
out,
159+
VariableBuilder(parent, closure_cell_contents)(
160+
cell.cell_contents
161+
),
162+
)
163+
164+
result[name] = out
165+
148166
else:
149167
unimplemented("inline with __closure__")
150168

0 commit comments

Comments
 (0)