diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h index 84b8bc81edec..953a496ad501 100644 --- a/mypyc/lib-rt/CPy.h +++ b/mypyc/lib-rt/CPy.h @@ -339,6 +339,8 @@ PyObject *CPyDict_GetItem(PyObject *dict, PyObject *key); int CPyDict_SetItem(PyObject *dict, PyObject *key, PyObject *value); PyObject *CPyDict_Get(PyObject *dict, PyObject *key, PyObject *fallback); PyObject *CPyDict_GetWithNone(PyObject *dict, PyObject *key); +PyObject *CPyDict_SetDefault(PyObject *dict, PyObject *key, PyObject *value); +PyObject *CPyDict_SetDefaultWithNone(PyObject *dict, PyObject *key); PyObject *CPyDict_Build(Py_ssize_t size, ...); int CPyDict_Update(PyObject *dict, PyObject *stuff); int CPyDict_UpdateInDisplay(PyObject *dict, PyObject *stuff); diff --git a/mypyc/lib-rt/dict_ops.c b/mypyc/lib-rt/dict_ops.c index 4f831d0ba850..c4dbe8d31c32 100644 --- a/mypyc/lib-rt/dict_ops.c +++ b/mypyc/lib-rt/dict_ops.c @@ -67,6 +67,17 @@ PyObject *CPyDict_GetWithNone(PyObject *dict, PyObject *key) { return CPyDict_Get(dict, key, Py_None); } +PyObject *CPyDict_SetDefault(PyObject *dict, PyObject *key, PyObject *value) { + if (PyDict_CheckExact(dict)){ + return PyDict_SetDefault(dict, key, value); + } + return PyObject_CallMethod(dict, "setdefault", "(OO)", key, value); +} + +PyObject *CPyDict_SetDefaultWithNone(PyObject *dict, PyObject *key) { + return CPyDict_SetDefault(dict, key, Py_None); +} + int CPyDict_SetItem(PyObject *dict, PyObject *key, PyObject *value) { if (PyDict_CheckExact(dict)) { return PyDict_SetItem(dict, key, value); diff --git a/mypyc/primitives/dict_ops.py b/mypyc/primitives/dict_ops.py index ff9b6482a782..267f9d79179d 100644 --- a/mypyc/primitives/dict_ops.py +++ b/mypyc/primitives/dict_ops.py @@ -17,6 +17,40 @@ type=object_rprimitive, src='PyDict_Type') +# Construct an empty dictionary. +dict_new_op = custom_op( + arg_types=[], + return_type=dict_rprimitive, + c_function_name='PyDict_New', + error_kind=ERR_MAGIC) + +# Construct a dictionary from keys and values. +# Positional argument is the number of key-value pairs +# Variable arguments are (key1, value1, ..., keyN, valueN). +dict_build_op = custom_op( + arg_types=[c_pyssize_t_rprimitive], + return_type=dict_rprimitive, + c_function_name='CPyDict_Build', + error_kind=ERR_MAGIC, + var_arg_type=object_rprimitive) + +# Construct a dictionary from another dictionary. +function_op( + name='builtins.dict', + arg_types=[dict_rprimitive], + return_type=dict_rprimitive, + c_function_name='PyDict_Copy', + error_kind=ERR_MAGIC, + priority=2) + +# Generic one-argument dict constructor: dict(obj) +function_op( + name='builtins.dict', + arg_types=[object_rprimitive], + return_type=dict_rprimitive, + c_function_name='CPyDict_FromAny', + error_kind=ERR_MAGIC) + # dict[key] dict_get_item_op = method_op( name='__getitem__', @@ -84,38 +118,22 @@ c_function_name='CPyDict_GetWithNone', error_kind=ERR_MAGIC) -# Construct an empty dictionary. -dict_new_op = custom_op( - arg_types=[], - return_type=dict_rprimitive, - c_function_name='PyDict_New', +# dict.setdefault(key, default) +method_op( + name='setdefault', + arg_types=[dict_rprimitive, object_rprimitive, object_rprimitive], + return_type=object_rprimitive, + c_function_name='CPyDict_SetDefault', + is_borrowed=True, error_kind=ERR_MAGIC) -# Construct a dictionary from keys and values. -# Positional argument is the number of key-value pairs -# Variable arguments are (key1, value1, ..., keyN, valueN). -dict_build_op = custom_op( - arg_types=[c_pyssize_t_rprimitive], - return_type=dict_rprimitive, - c_function_name='CPyDict_Build', - error_kind=ERR_MAGIC, - var_arg_type=object_rprimitive) - -# Construct a dictionary from another dictionary. -function_op( - name='builtins.dict', - arg_types=[dict_rprimitive], - return_type=dict_rprimitive, - c_function_name='PyDict_Copy', - error_kind=ERR_MAGIC, - priority=2) - -# Generic one-argument dict constructor: dict(obj) -function_op( - name='builtins.dict', - arg_types=[object_rprimitive], - return_type=dict_rprimitive, - c_function_name='CPyDict_FromAny', +# dict.setdefault(key) +method_op( + name='setdefault', + arg_types=[dict_rprimitive, object_rprimitive], + return_type=object_rprimitive, + c_function_name='CPyDict_SetDefaultWithNone', + is_borrowed=True, error_kind=ERR_MAGIC) # dict.keys() diff --git a/mypyc/test-data/fixtures/ir.py b/mypyc/test-data/fixtures/ir.py index 70ecd9d9ab2e..e97f00dd71b6 100644 --- a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -176,6 +176,7 @@ def values(self) -> Iterable[V]: pass def items(self) -> Iterable[Tuple[K, V]]: pass def clear(self) -> None: pass def copy(self) -> Dict[K, V]: pass + def setdefault(self, k: K, v: Optional[V] = None) -> Optional[V]: pass class set(Generic[T]): def __init__(self, i: Optional[Iterable[T]] = None) -> None: pass diff --git a/mypyc/test-data/irbuild-dict.test b/mypyc/test-data/irbuild-dict.test index 1225750b74df..da08ed79d5bd 100644 --- a/mypyc/test-data/irbuild-dict.test +++ b/mypyc/test-data/irbuild-dict.test @@ -335,3 +335,18 @@ def f(d): L0: r0 = CPyDict_Copy(d) return r0 + +[case testDictSetdefault] +from typing import Dict +def f(d: Dict[object, object]) -> object: + return d.setdefault('a', 'b') +[out] +def f(d): + d :: dict + r0, r1 :: str + r2 :: object +L0: + r0 = 'a' + r1 = 'b' + r2 = CPyDict_SetDefault(d, r0, r1) + return r2 diff --git a/mypyc/test-data/run-dicts.test b/mypyc/test-data/run-dicts.test index 329d186874f4..89188e09b1f1 100644 --- a/mypyc/test-data/run-dicts.test +++ b/mypyc/test-data/run-dicts.test @@ -196,7 +196,7 @@ else: [case testDictMethods] from collections import defaultdict -from typing import Dict +from typing import Dict, Optional def test_dict_clear() -> None: d = {'a': 1, 'b': 2} @@ -217,3 +217,39 @@ def test_dict_copy() -> None: dd['a'] = 1 assert dd.copy() == dd assert isinstance(dd.copy(), defaultdict) + +class MyDict(dict): + def __init__(self, *args, **kwargs): + self.update(*args, **kwargs) + + def setdefault(self, k, v=None): + if v is None: + if k in self.keys(): + return self[k] + else: + return None + else: + return super().setdefault(k, v) + 10 + +def test_dict_setdefault() -> None: + d: Dict[str, int] = {'a': 1, 'b': 2} + assert d.setdefault('a', 2) == 1 + assert d.setdefault('b', 2) == 2 + assert d.setdefault('c', 3) == 3 + assert d['a'] == 1 + assert d['c'] == 3 + assert d.setdefault('a') == 1 + assert d.setdefault('e') == None + assert d.setdefault('e', 100) == None + +def test_dict_subclass_setdefault() -> None: + d = MyDict() + d['a'] = 1 + assert d.setdefault('a', 2) == 11 + assert d.setdefault('b', 2) == 12 + assert d.setdefault('c', 3) == 13 + assert d['a'] == 1 + assert d['c'] == 3 + assert d.setdefault('a') == 1 + assert d.setdefault('e') == None + assert d.setdefault('e', 100) == 110