Skip to content

Commit 12fa434

Browse files
authored
pickle ContextVars in process replay [pr] (tinygrad#8484)
* pickle ContextVars in process replay * add test_pickle_context_var [pr] * more realistic
1 parent bd4d7dc commit 12fa434

File tree

4 files changed

+11
-5
lines changed

4 files changed

+11
-5
lines changed

test/external/process_replay/helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,4 @@ def get_process_replay_ctx() -> Tuple[ProcessReplayContext, Dict]:
1414
loc = "\n".join(traceback.format_list(stack))
1515
try: head_sha = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode()
1616
except Exception: head_sha = ""
17-
return ProcessReplayContext(loc, head_sha, getenv("GITHUB_RUN_ID") or None), {k:v.value for k,v in ContextVar._cache.items()}
17+
return ProcessReplayContext(loc, head_sha, getenv("GITHUB_RUN_ID") or None), ContextVar._cache

test/external/process_replay/process_replay.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def diff(offset:int, name:str, fxn:Callable) -> Union[Tuple[int, int], bool]:
6060
continue
6161
# try recreate
6262
try:
63-
with Context(**{k:v for k,v in args[-2].items() if k in ContextVar._cache and k != "DEBUG"}): good = fxn(*args[:-2])
63+
with Context(**{k:v.value for k,v in args[-2].items() if k in ContextVar._cache and k != "DEBUG"}): good = fxn(*args[:-2])
6464
if good is None: continue
6565
except Exception as e:
6666
changed += 1

test/test_pickle.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import unittest, pickle, types
22
import numpy as np
33
from tinygrad import Tensor, TinyJit, Variable, dtypes
4-
from tinygrad.helpers import GlobalCounters
4+
from tinygrad.helpers import GlobalCounters, ContextVar, Context
55
from tinygrad.ops import PatternMatcher, UPat, UOp
66

77
class TestPickle(unittest.TestCase):
@@ -95,6 +95,13 @@ def add(a, b): return a.sum()+b+1
9595
out = add_fxn(x, y)
9696
np.testing.assert_equal(out.numpy(), 102)
9797

98+
def test_pickle_context_var(self):
99+
v = ContextVar("test_var", 0)
100+
with Context(test_var=1):
101+
vs = pickle.dumps(v)
102+
v2 = pickle.loads(vs)
103+
self.assertEqual(v2.value, 1)
104+
98105
def test_pickle_schedule(self):
99106
a = Tensor([1,2])
100107
out = a + 2

tinygrad/engine/schedule.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,8 +276,7 @@ def schedule_uop(pre:UOp, ctx:ScheduleContext) -> ScheduleItem:
276276
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
277277
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
278278
# capture process replay
279-
if getenv("RUN_PROCESS_REPLAY"):
280-
PROCESS_REPLAY_CAPTURE[str(pre.key)] = pickle.dumps((pre, si_ctx.assigns, {k:v.value for k,v in ContextVar._cache.items()}, sink))
279+
if getenv("RUN_PROCESS_REPLAY"): PROCESS_REPLAY_CAPTURE[str(pre.key)] = pickle.dumps((pre, si_ctx.assigns, ContextVar._cache, sink))
281280
return ScheduleItem(sink, tuple(u.buffer for u in si_ctx.bufs if u.size != 0), tuple(si_ctx.metadata),
282281
tuple(ubuf for ubuf,ops in si_ctx.assign_adj.items() if any(x.op is Ops.PRELOAD for x in ops)))
283282

0 commit comments

Comments
 (0)