Skip to content

Commit 1a55e31

Browse files
committed
Support pickling of extension classes
This operates by providing default implementations of `__getstate__` and `__setstate__` for extension classes. Our implementations work by storing a `__mypyc_attrs__` tuple in each class that we generate and collecting all of the attributes in it into a dict. Fixes #697.
1 parent 9f1b8e9 commit 1a55e31

File tree

8 files changed

+237
-71
lines changed

8 files changed

+237
-71
lines changed

mypy/types.py

+3-8
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Classes for representing mypy types."""
22

3+
import copy
34
import sys
45
from abc import abstractmethod
56
from collections import OrderedDict
@@ -17,7 +18,7 @@
1718
FuncDef,
1819
)
1920
from mypy.sharedparse import argument_elide_name
20-
from mypy.util import IdMapper, replace_object_state
21+
from mypy.util import IdMapper
2122
from mypy.bogus_type import Bogus
2223

2324

@@ -2077,13 +2078,7 @@ def copy_type(t: TP) -> TP:
20772078
"""
20782079
Build a copy of the type; used to mutate the copy with truthiness information
20792080
"""
2080-
# We'd like to just do a copy.copy(), but mypyc types aren't
2081-
# pickleable so we hack around it by manually creating a new type
2082-
# and copying everything in with replace_object_state.
2083-
typ = type(t)
2084-
nt = typ.__new__(typ)
2085-
replace_object_state(nt, t, copy_dict=True)
2086-
return nt
2081+
return copy.copy(t)
20872082

20882083

20892084
def function_type(func: mypy.nodes.FuncBase, fallback: Instance) -> FunctionLike:

mypyc/emitclass.py

+8
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,14 @@ def generate_methods_table(cl: ClassIR,
550550
flags.append('METH_CLASS')
551551

552552
emitter.emit_line(' {}, NULL}},'.format(' | '.join(flags)))
553+
554+
# Provide a default __getstate__ and __setstate__
555+
if not cl.has_method('__setstate__') and not cl.has_method('__getstate__'):
556+
emitter.emit_lines(
557+
'{"__setstate__", (PyCFunction)CPyPickle_SetState, METH_O, NULL},',
558+
'{"__getstate__", (PyCFunction)CPyPickle_GetState, METH_NOARGS, NULL},',
559+
)
560+
553561
emitter.emit_line('{NULL} /* Sentinel */')
554562
emitter.emit_line('};')
555563

mypyc/genops.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -533,7 +533,7 @@ def prepare_class_def(path: str, module_name: str, cdef: ClassDef,
533533

534534
for name, node in info.names.items():
535535
if isinstance(node.node, Var):
536-
assert node.node.type, "Class member missing type"
536+
assert node.node.type, "Class member %s missing type" % name
537537
if not node.node.is_classvar and name != '__slots__':
538538
ir.attributes[name] = mapper.type_to_rtype(node.node.type)
539539
elif isinstance(node.node, (FuncDef, Decorator)):
@@ -595,6 +595,8 @@ def prepare_class_def(path: str, module_name: str, cdef: ClassDef,
595595
ir.inherits_python = True
596596
continue
597597
base_ir = mapper.type_to_ir[cls]
598+
if not base_ir.is_ext_class:
599+
ir.inherits_python = True
598600
if not base_ir.is_trait:
599601
base_mro.append(base_ir)
600602
mro.append(base_ir)
@@ -1477,6 +1479,14 @@ def visit_class_def(self, cdef: ClassDef) -> None:
14771479
# Set this attribute back to None until the next non-extension class is visited.
14781480
self.non_ext_info = None
14791481

1482+
def create_mypyc_attrs_tuple(self, ir: ClassIR, line: int) -> Value:
1483+
attrs = [name for ancestor in ir.mro for name in ancestor.attributes]
1484+
if ir.inherits_python:
1485+
attrs.append('__dict__')
1486+
return self.primitive_op(new_tuple_op,
1487+
[self.load_static_unicode(attr) for attr in attrs],
1488+
line)
1489+
14801490
def allocate_class(self, cdef: ClassDef) -> None:
14811491
# OK AND NOW THE FUN PART
14821492
base_exprs = cdef.base_type_exprs + cdef.removed_base_type_exprs
@@ -1496,6 +1506,12 @@ def allocate_class(self, cdef: ClassDef) -> None:
14961506
FuncDecl(cdef.name + '_trait_vtable_setup',
14971507
None, self.module_name,
14981508
FuncSignature([], bool_rprimitive)), [], -1))
1509+
# Populate a '__mypyc_attrs__' field containing the list of attrs
1510+
self.primitive_op(py_setattr_op, [
1511+
tp, self.load_static_unicode('__mypyc_attrs__'),
1512+
self.create_mypyc_attrs_tuple(self.mapper.type_to_ir[cdef.info], cdef.line)],
1513+
cdef.line)
1514+
14991515
# Save the class
15001516
self.add(InitStatic(tp, cdef.name, self.module_name, NAMESPACE_TYPE))
15011517

mypyc/lib-rt/CPy.h

+62
Original file line numberDiff line numberDiff line change
@@ -1333,6 +1333,68 @@ static int CPy_YieldFromErrorHandle(PyObject *iter, PyObject **outp)
13331333
return 2;
13341334
}
13351335

1336+
// Support for pickling; reusable getstate and setstate functions
1337+
static PyObject *
1338+
CPyPickle_SetState(PyObject *obj, PyObject *state)
1339+
{
1340+
Py_ssize_t pos = 0;
1341+
PyObject *key, *value;
1342+
while (PyDict_Next(state, &pos, &key, &value)) {
1343+
if (PyObject_SetAttr(obj, key, value) != 0) {
1344+
return NULL;
1345+
}
1346+
}
1347+
Py_RETURN_NONE;
1348+
}
1349+
1350+
static PyObject *
1351+
CPyPickle_GetState(PyObject *obj)
1352+
{
1353+
PyObject *attrs = NULL, *state = NULL;
1354+
1355+
attrs = PyObject_GetAttrString((PyObject *)Py_TYPE(obj), "__mypyc_attrs__");
1356+
if (!attrs) {
1357+
goto fail;
1358+
}
1359+
if (!PyTuple_Check(attrs)) {
1360+
PyErr_SetString(PyExc_TypeError, "__mypyc_attrs__ is not a tuple");
1361+
goto fail;
1362+
}
1363+
state = PyDict_New();
1364+
if (!state) {
1365+
goto fail;
1366+
}
1367+
1368+
// Collect all the values of attributes in __mypyc_attrs__
1369+
// Attributes that are missing we just ignore
1370+
int i;
1371+
for (i = 0; i < PyTuple_GET_SIZE(attrs); i++) {
1372+
PyObject *key = PyTuple_GET_ITEM(attrs, i);
1373+
PyObject *value = PyObject_GetAttr(obj, key);
1374+
if (!value) {
1375+
if (PyErr_ExceptionMatches(PyExc_AttributeError)) {
1376+
PyErr_Clear();
1377+
continue;
1378+
}
1379+
goto fail;
1380+
}
1381+
int result = PyDict_SetItem(state, key, value);
1382+
Py_DECREF(value);
1383+
if (result != 0) {
1384+
goto fail;
1385+
}
1386+
}
1387+
1388+
Py_DECREF(attrs);
1389+
1390+
return state;
1391+
fail:
1392+
Py_XDECREF(attrs);
1393+
Py_XDECREF(state);
1394+
return NULL;
1395+
}
1396+
1397+
13361398
int CPyArg_ParseTupleAndKeywords(PyObject *, PyObject *,
13371399
const char *, char **, ...);
13381400

mypyc/test-data/genops-basic.test

+3-3
Original file line numberDiff line numberDiff line change
@@ -1588,11 +1588,11 @@ def g(a):
15881588
r4 :: str
15891589
r5, r6 :: None
15901590
L0:
1591-
r0 = unicode_3 :: static ('a')
1591+
r0 = unicode_4 :: static ('a')
15921592
r1 = 0
15931593
r2 = a.f(r1, r0)
15941594
r3 = 1
1595-
r4 = unicode_4 :: static ('b')
1595+
r4 = unicode_5 :: static ('b')
15961596
r5 = a.f(r3, r4)
15971597
r6 = None
15981598
return r6
@@ -1828,7 +1828,7 @@ L1:
18281828
L2:
18291829
if is_error(z) goto L3 else goto L4
18301830
L3:
1831-
r1 = unicode_3 :: static ('test')
1831+
r1 = unicode_4 :: static ('test')
18321832
z = r1
18331833
L4:
18341834
r2 = None

mypyc/test-data/genops-classes.test

+76-58
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def g(a):
9494
r3 :: None
9595
L0:
9696
r0 = 1
97-
r1 = unicode_3 :: static ('hi')
97+
r1 = unicode_4 :: static ('hi')
9898
r2 = a.f(r0, r1)
9999
r3 = None
100100
return r3
@@ -360,31 +360,40 @@ def __top_level__():
360360
r40 :: str
361361
r41, r42 :: object
362362
r43 :: bool
363-
r44 :: dict
364-
r45 :: str
363+
r44 :: str
364+
r45 :: tuple
365365
r46 :: bool
366-
r47 :: object
366+
r47 :: dict
367367
r48 :: str
368-
r49, r50 :: object
369-
r51 :: bool
370-
r52 :: dict
371-
r53 :: str
368+
r49 :: bool
369+
r50 :: object
370+
r51 :: str
371+
r52, r53 :: object
372372
r54 :: bool
373-
r55, r56 :: object
374-
r57 :: dict
375-
r58 :: str
376-
r59 :: object
377-
r60 :: dict
378-
r61 :: str
379-
r62, r63 :: object
380-
r64 :: tuple
381-
r65 :: str
382-
r66, r67 :: object
383-
r68 :: bool
384-
r69 :: dict
385-
r70 :: str
386-
r71 :: bool
387-
r72 :: None
373+
r55 :: str
374+
r56 :: tuple
375+
r57 :: bool
376+
r58 :: dict
377+
r59 :: str
378+
r60 :: bool
379+
r61, r62 :: object
380+
r63 :: dict
381+
r64 :: str
382+
r65 :: object
383+
r66 :: dict
384+
r67 :: str
385+
r68, r69 :: object
386+
r70 :: tuple
387+
r71 :: str
388+
r72, r73 :: object
389+
r74 :: bool
390+
r75, r76 :: str
391+
r77 :: tuple
392+
r78 :: bool
393+
r79 :: dict
394+
r80 :: str
395+
r81 :: bool
396+
r82 :: None
388397
L0:
389398
r0 = builtins.module :: static
390399
r1 = builtins.None :: object
@@ -442,39 +451,49 @@ L6:
442451
r41 = __main__.C_template :: type
443452
r42 = pytype_from_template(r41, r39, r40)
444453
r43 = C_trait_vtable_setup()
454+
r44 = unicode_8 :: static ('__mypyc_attrs__')
455+
r45 = () :: tuple
456+
r46 = setattr r42, r44, r45
445457
__main__.C = r42 :: type
446-
r44 = __main__.globals :: static
447-
r45 = unicode_8 :: static ('C')
448-
r46 = r44.__setitem__(r45, r42) :: dict
449-
r47 = <error> :: object
450-
r48 = unicode_7 :: static ('__main__')
451-
r49 = __main__.S_template :: type
452-
r50 = pytype_from_template(r49, r47, r48)
453-
r51 = S_trait_vtable_setup()
454-
__main__.S = r50 :: type
455-
r52 = __main__.globals :: static
456-
r53 = unicode_9 :: static ('S')
457-
r54 = r52.__setitem__(r53, r50) :: dict
458-
r55 = __main__.C :: type
459-
r56 = __main__.S :: type
460-
r57 = __main__.globals :: static
461-
r58 = unicode_3 :: static ('Generic')
462-
r59 = r57[r58] :: dict
463-
r60 = __main__.globals :: static
464-
r61 = unicode_6 :: static ('T')
465-
r62 = r60[r61] :: dict
466-
r63 = r59[r62] :: object
467-
r64 = (r55, r56, r63) :: tuple
468-
r65 = unicode_7 :: static ('__main__')
469-
r66 = __main__.D_template :: type
470-
r67 = pytype_from_template(r66, r64, r65)
471-
r68 = D_trait_vtable_setup()
472-
__main__.D = r67 :: type
473-
r69 = __main__.globals :: static
474-
r70 = unicode_10 :: static ('D')
475-
r71 = r69.__setitem__(r70, r67) :: dict
476-
r72 = None
477-
return r72
458+
r47 = __main__.globals :: static
459+
r48 = unicode_9 :: static ('C')
460+
r49 = r47.__setitem__(r48, r42) :: dict
461+
r50 = <error> :: object
462+
r51 = unicode_7 :: static ('__main__')
463+
r52 = __main__.S_template :: type
464+
r53 = pytype_from_template(r52, r50, r51)
465+
r54 = S_trait_vtable_setup()
466+
r55 = unicode_8 :: static ('__mypyc_attrs__')
467+
r56 = () :: tuple
468+
r57 = setattr r53, r55, r56
469+
__main__.S = r53 :: type
470+
r58 = __main__.globals :: static
471+
r59 = unicode_10 :: static ('S')
472+
r60 = r58.__setitem__(r59, r53) :: dict
473+
r61 = __main__.C :: type
474+
r62 = __main__.S :: type
475+
r63 = __main__.globals :: static
476+
r64 = unicode_3 :: static ('Generic')
477+
r65 = r63[r64] :: dict
478+
r66 = __main__.globals :: static
479+
r67 = unicode_6 :: static ('T')
480+
r68 = r66[r67] :: dict
481+
r69 = r65[r68] :: object
482+
r70 = (r61, r62, r69) :: tuple
483+
r71 = unicode_7 :: static ('__main__')
484+
r72 = __main__.D_template :: type
485+
r73 = pytype_from_template(r72, r70, r71)
486+
r74 = D_trait_vtable_setup()
487+
r75 = unicode_8 :: static ('__mypyc_attrs__')
488+
r76 = unicode_11 :: static ('__dict__')
489+
r77 = (r76) :: tuple
490+
r78 = setattr r73, r75, r77
491+
__main__.D = r73 :: type
492+
r79 = __main__.globals :: static
493+
r80 = unicode_12 :: static ('D')
494+
r81 = r79.__setitem__(r80, r73) :: dict
495+
r82 = None
496+
return r82
478497

479498
[case testIsInstance]
480499
class A: pass
@@ -785,12 +804,11 @@ def f():
785804
r3 :: int
786805
L0:
787806
r0 = __main__.A :: type
788-
r1 = unicode_5 :: static ('x')
807+
r1 = unicode_6 :: static ('x')
789808
r2 = getattr r0, r1
790809
r3 = unbox(int, r2)
791810
return r3
792811

793-
794812
[case testNoEqDefined]
795813
class A:
796814
pass
@@ -1014,7 +1032,7 @@ L0:
10141032
r0 = 10
10151033
__mypyc_self__.x = r0; r1 = is_error
10161034
r2 = __main__.globals :: static
1017-
r3 = unicode_7 :: static ('LOL')
1035+
r3 = unicode_9 :: static ('LOL')
10181036
r4 = r2[r3] :: dict
10191037
r5 = cast(str, r4)
10201038
__mypyc_self__.y = r5; r6 = is_error

mypyc/test-data/genops-optional.test

+1-1
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ def set(o, s):
384384
r1 :: bool
385385
r2 :: None
386386
L0:
387-
r0 = unicode_6 :: static ('a')
387+
r0 = unicode_5 :: static ('a')
388388
r1 = setattr o, r0, s
389389
r2 = None
390390
return r2

0 commit comments

Comments
 (0)