Skip to content

Commit f463a39

Browse files
[mypyc] Implement dict setdefault primitive (#10286)
Related issue: mypyc/mypyc#644 Add a new primitive for dict.setdefault(). Move some dict creation primitives to beginning of dict_ops.py for better code structure.
1 parent 3cebc97 commit f463a39

File tree

6 files changed

+114
-31
lines changed

6 files changed

+114
-31
lines changed

mypyc/lib-rt/CPy.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,8 @@ PyObject *CPyDict_GetItem(PyObject *dict, PyObject *key);
339339
int CPyDict_SetItem(PyObject *dict, PyObject *key, PyObject *value);
340340
PyObject *CPyDict_Get(PyObject *dict, PyObject *key, PyObject *fallback);
341341
PyObject *CPyDict_GetWithNone(PyObject *dict, PyObject *key);
342+
PyObject *CPyDict_SetDefault(PyObject *dict, PyObject *key, PyObject *value);
343+
PyObject *CPyDict_SetDefaultWithNone(PyObject *dict, PyObject *key);
342344
PyObject *CPyDict_Build(Py_ssize_t size, ...);
343345
int CPyDict_Update(PyObject *dict, PyObject *stuff);
344346
int CPyDict_UpdateInDisplay(PyObject *dict, PyObject *stuff);

mypyc/lib-rt/dict_ops.c

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,17 @@ PyObject *CPyDict_GetWithNone(PyObject *dict, PyObject *key) {
6767
return CPyDict_Get(dict, key, Py_None);
6868
}
6969

70+
PyObject *CPyDict_SetDefault(PyObject *dict, PyObject *key, PyObject *value) {
71+
if (PyDict_CheckExact(dict)){
72+
return PyDict_SetDefault(dict, key, value);
73+
}
74+
return PyObject_CallMethod(dict, "setdefault", "(OO)", key, value);
75+
}
76+
77+
PyObject *CPyDict_SetDefaultWithNone(PyObject *dict, PyObject *key) {
78+
return CPyDict_SetDefault(dict, key, Py_None);
79+
}
80+
7081
int CPyDict_SetItem(PyObject *dict, PyObject *key, PyObject *value) {
7182
if (PyDict_CheckExact(dict)) {
7283
return PyDict_SetItem(dict, key, value);

mypyc/primitives/dict_ops.py

Lines changed: 48 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,40 @@
1717
type=object_rprimitive,
1818
src='PyDict_Type')
1919

20+
# Construct an empty dictionary.
21+
dict_new_op = custom_op(
22+
arg_types=[],
23+
return_type=dict_rprimitive,
24+
c_function_name='PyDict_New',
25+
error_kind=ERR_MAGIC)
26+
27+
# Construct a dictionary from keys and values.
28+
# Positional argument is the number of key-value pairs
29+
# Variable arguments are (key1, value1, ..., keyN, valueN).
30+
dict_build_op = custom_op(
31+
arg_types=[c_pyssize_t_rprimitive],
32+
return_type=dict_rprimitive,
33+
c_function_name='CPyDict_Build',
34+
error_kind=ERR_MAGIC,
35+
var_arg_type=object_rprimitive)
36+
37+
# Construct a dictionary from another dictionary.
38+
function_op(
39+
name='builtins.dict',
40+
arg_types=[dict_rprimitive],
41+
return_type=dict_rprimitive,
42+
c_function_name='PyDict_Copy',
43+
error_kind=ERR_MAGIC,
44+
priority=2)
45+
46+
# Generic one-argument dict constructor: dict(obj)
47+
function_op(
48+
name='builtins.dict',
49+
arg_types=[object_rprimitive],
50+
return_type=dict_rprimitive,
51+
c_function_name='CPyDict_FromAny',
52+
error_kind=ERR_MAGIC)
53+
2054
# dict[key]
2155
dict_get_item_op = method_op(
2256
name='__getitem__',
@@ -84,38 +118,22 @@
84118
c_function_name='CPyDict_GetWithNone',
85119
error_kind=ERR_MAGIC)
86120

87-
# Construct an empty dictionary.
88-
dict_new_op = custom_op(
89-
arg_types=[],
90-
return_type=dict_rprimitive,
91-
c_function_name='PyDict_New',
121+
# dict.setdefault(key, default)
122+
method_op(
123+
name='setdefault',
124+
arg_types=[dict_rprimitive, object_rprimitive, object_rprimitive],
125+
return_type=object_rprimitive,
126+
c_function_name='CPyDict_SetDefault',
127+
is_borrowed=True,
92128
error_kind=ERR_MAGIC)
93129

94-
# Construct a dictionary from keys and values.
95-
# Positional argument is the number of key-value pairs
96-
# Variable arguments are (key1, value1, ..., keyN, valueN).
97-
dict_build_op = custom_op(
98-
arg_types=[c_pyssize_t_rprimitive],
99-
return_type=dict_rprimitive,
100-
c_function_name='CPyDict_Build',
101-
error_kind=ERR_MAGIC,
102-
var_arg_type=object_rprimitive)
103-
104-
# Construct a dictionary from another dictionary.
105-
function_op(
106-
name='builtins.dict',
107-
arg_types=[dict_rprimitive],
108-
return_type=dict_rprimitive,
109-
c_function_name='PyDict_Copy',
110-
error_kind=ERR_MAGIC,
111-
priority=2)
112-
113-
# Generic one-argument dict constructor: dict(obj)
114-
function_op(
115-
name='builtins.dict',
116-
arg_types=[object_rprimitive],
117-
return_type=dict_rprimitive,
118-
c_function_name='CPyDict_FromAny',
130+
# dict.setdefault(key)
131+
method_op(
132+
name='setdefault',
133+
arg_types=[dict_rprimitive, object_rprimitive],
134+
return_type=object_rprimitive,
135+
c_function_name='CPyDict_SetDefaultWithNone',
136+
is_borrowed=True,
119137
error_kind=ERR_MAGIC)
120138

121139
# dict.keys()

mypyc/test-data/fixtures/ir.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ def values(self) -> Iterable[V]: pass
176176
def items(self) -> Iterable[Tuple[K, V]]: pass
177177
def clear(self) -> None: pass
178178
def copy(self) -> Dict[K, V]: pass
179+
def setdefault(self, k: K, v: Optional[V] = None) -> Optional[V]: pass
179180

180181
class set(Generic[T]):
181182
def __init__(self, i: Optional[Iterable[T]] = None) -> None: pass

mypyc/test-data/irbuild-dict.test

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,3 +335,18 @@ def f(d):
335335
L0:
336336
r0 = CPyDict_Copy(d)
337337
return r0
338+
339+
[case testDictSetdefault]
340+
from typing import Dict
341+
def f(d: Dict[object, object]) -> object:
342+
return d.setdefault('a', 'b')
343+
[out]
344+
def f(d):
345+
d :: dict
346+
r0, r1 :: str
347+
r2 :: object
348+
L0:
349+
r0 = 'a'
350+
r1 = 'b'
351+
r2 = CPyDict_SetDefault(d, r0, r1)
352+
return r2

mypyc/test-data/run-dicts.test

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ else:
196196

197197
[case testDictMethods]
198198
from collections import defaultdict
199-
from typing import Dict
199+
from typing import Dict, Optional
200200

201201
def test_dict_clear() -> None:
202202
d = {'a': 1, 'b': 2}
@@ -217,3 +217,39 @@ def test_dict_copy() -> None:
217217
dd['a'] = 1
218218
assert dd.copy() == dd
219219
assert isinstance(dd.copy(), defaultdict)
220+
221+
class MyDict(dict):
222+
def __init__(self, *args, **kwargs):
223+
self.update(*args, **kwargs)
224+
225+
def setdefault(self, k, v=None):
226+
if v is None:
227+
if k in self.keys():
228+
return self[k]
229+
else:
230+
return None
231+
else:
232+
return super().setdefault(k, v) + 10
233+
234+
def test_dict_setdefault() -> None:
235+
d: Dict[str, int] = {'a': 1, 'b': 2}
236+
assert d.setdefault('a', 2) == 1
237+
assert d.setdefault('b', 2) == 2
238+
assert d.setdefault('c', 3) == 3
239+
assert d['a'] == 1
240+
assert d['c'] == 3
241+
assert d.setdefault('a') == 1
242+
assert d.setdefault('e') == None
243+
assert d.setdefault('e', 100) == None
244+
245+
def test_dict_subclass_setdefault() -> None:
246+
d = MyDict()
247+
d['a'] = 1
248+
assert d.setdefault('a', 2) == 11
249+
assert d.setdefault('b', 2) == 12
250+
assert d.setdefault('c', 3) == 13
251+
assert d['a'] == 1
252+
assert d['c'] == 3
253+
assert d.setdefault('a') == 1
254+
assert d.setdefault('e') == None
255+
assert d.setdefault('e', 100) == 110

0 commit comments

Comments
 (0)