diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h index 6360d052fa6d..223a0b3e6be9 100644 --- a/mypyc/lib-rt/CPy.h +++ b/mypyc/lib-rt/CPy.h @@ -377,6 +377,7 @@ static inline char CPyDict_CheckSize(PyObject *dict, CPyTagged size) { PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index); PyObject *CPyStr_Split(PyObject *str, PyObject *sep, CPyTagged max_split); +PyObject *CPyStr_Replace(PyObject *str, PyObject *old_substr, PyObject *new_substr, CPyTagged max_replace); PyObject *CPyStr_Append(PyObject *o1, PyObject *o2); PyObject *CPyStr_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end); bool CPyStr_Startswith(PyObject *self, PyObject *subobj); diff --git a/mypyc/lib-rt/str_ops.c b/mypyc/lib-rt/str_ops.c index 87e473e27574..16edc21478dd 100644 --- a/mypyc/lib-rt/str_ops.c +++ b/mypyc/lib-rt/str_ops.c @@ -53,6 +53,16 @@ PyObject *CPyStr_Split(PyObject *str, PyObject *sep, CPyTagged max_split) return PyUnicode_Split(str, sep, temp_max_split); } +PyObject *CPyStr_Replace(PyObject *str, PyObject *old_substr, PyObject *new_substr, CPyTagged max_replace) +{ + Py_ssize_t temp_max_replace = CPyTagged_AsSsize_t(max_replace); + if (temp_max_replace == -1 && PyErr_Occurred()) { + PyErr_SetString(PyExc_OverflowError, "Python int too large to convert to C ssize_t"); + return NULL; + } + return PyUnicode_Replace(str, old_substr, new_substr, temp_max_replace); +} + bool CPyStr_Startswith(PyObject *self, PyObject *subobj) { Py_ssize_t start = 0; Py_ssize_t end = PyUnicode_GET_LENGTH(self); diff --git a/mypyc/primitives/str_ops.py b/mypyc/primitives/str_ops.py index 6eb24be33e51..f12cf997e4a5 100644 --- a/mypyc/primitives/str_ops.py +++ b/mypyc/primitives/str_ops.py @@ -109,3 +109,20 @@ return_type=object_rprimitive, c_function_name='CPyStr_GetSlice', error_kind=ERR_MAGIC) + +# str.replace(old, new) +method_op( + name='replace', + arg_types=[str_rprimitive, str_rprimitive, str_rprimitive], + return_type=str_rprimitive, + c_function_name="PyUnicode_Replace", + error_kind=ERR_MAGIC, + extra_int_constants=[(-1, c_int_rprimitive)]) + +# str.replace(old, new, count) +method_op( + name='replace', + arg_types=[str_rprimitive, str_rprimitive, str_rprimitive, int_rprimitive], + return_type=str_rprimitive, + c_function_name="CPyStr_Replace", + error_kind=ERR_MAGIC) diff --git a/mypyc/test-data/fixtures/ir.py b/mypyc/test-data/fixtures/ir.py index bb85a41fbe24..b4a379e86c01 100644 --- a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -75,6 +75,7 @@ def format(self, *args: Any, **kwargs: Any) -> str: ... def upper(self) -> str: pass def startswith(self, x: str, start: int=..., end: int=...) -> bool: pass def endswith(self, x: str, start: int=..., end: int=...) -> bool: pass + def replace(self, old: str, new: str, maxcount: Optional[int] = None) -> str: pass class float: def __init__(self, x: object) -> None: pass diff --git a/mypyc/test-data/irbuild-str.test b/mypyc/test-data/irbuild-str.test index c573871d15a4..4a5774c67d64 100644 --- a/mypyc/test-data/irbuild-str.test +++ b/mypyc/test-data/irbuild-str.test @@ -103,3 +103,39 @@ L3: r5 = r0 != 0 return r5 +[case testStrReplace] +from typing import Optional + +def do_replace(s: str, old_substr: str, new_substr: str, max_count: Optional[int] = None) -> str: + if max_count is not None: + return s.replace(old_substr, new_substr, max_count) + else: + return s.replace(old_substr, new_substr) +[out] +def do_replace(s, old_substr, new_substr, max_count): + s, old_substr, new_substr :: str + max_count :: union[int, None] + r0, r1 :: object + r2, r3 :: bit + r4 :: int + r5, r6 :: str +L0: + if is_error(max_count) goto L1 else goto L2 +L1: + r0 = box(None, 1) + max_count = r0 +L2: + r1 = box(None, 1) + r2 = max_count == r1 + r3 = r2 ^ 1 + if r3 goto L3 else goto L4 :: bool +L3: + r4 = unbox(int, max_count) + r5 = CPyStr_Replace(s, old_substr, new_substr, r4) + return r5 +L4: + r6 = PyUnicode_Replace(s, old_substr, new_substr, -1) + return r6 +L5: + unreachable + \ No newline at end of file diff --git a/mypyc/test-data/run-strings.test b/mypyc/test-data/run-strings.test index 5e5fee818a59..7740e77eb8e3 100644 --- a/mypyc/test-data/run-strings.test +++ b/mypyc/test-data/run-strings.test @@ -150,3 +150,12 @@ def test_slicing() -> None: assert s[1:big_int] == "oobar" assert s[big_int:] == "" assert s[-big_int:-1] == "fooba" + +def test_str_replace() -> None: + a = "foofoofoo" + assert a.replace("foo", "bar") == "barbarbar" + assert a.replace("foo", "bar", -1) == "barbarbar" + assert a.replace("foo", "bar", 1) == "barfoofoo" + assert a.replace("foo", "bar", 4) == "barbarbar" + assert a.replace("aaa", "bar") == "foofoofoo" + assert a.replace("ofo", "xyzw") == "foxyzwxyzwo"