Skip to content

Commit 3bbe3b7

Browse files
authored
gh-110222: Add support of PyStructSequence in copy.replace() (GH-110223)
1 parent 9561648 commit 3bbe3b7

File tree

3 files changed

+155
-1
lines changed

3 files changed

+155
-1
lines changed

Lib/test/test_structseq.py

+78
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,84 @@ def test_match_args_with_unnamed_fields(self):
264264
self.assertEqual(os.stat_result.n_unnamed_fields, 3)
265265
self.assertEqual(os.stat_result.__match_args__, expected_args)
266266

267+
def test_copy_replace_all_fields_visible(self):
268+
assert os.times_result.n_unnamed_fields == 0
269+
assert os.times_result.n_sequence_fields == os.times_result.n_fields
270+
271+
t = os.times()
272+
273+
# visible fields
274+
self.assertEqual(copy.replace(t), t)
275+
self.assertIsInstance(copy.replace(t), os.times_result)
276+
self.assertEqual(copy.replace(t, user=1.5), (1.5, *t[1:]))
277+
self.assertEqual(copy.replace(t, system=2.5), (t[0], 2.5, *t[2:]))
278+
self.assertEqual(copy.replace(t, user=1.5, system=2.5), (1.5, 2.5, *t[2:]))
279+
280+
# unknown fields
281+
with self.assertRaisesRegex(TypeError, 'unexpected field name'):
282+
copy.replace(t, error=-1)
283+
with self.assertRaisesRegex(TypeError, 'unexpected field name'):
284+
copy.replace(t, user=1, error=-1)
285+
286+
def test_copy_replace_with_invisible_fields(self):
287+
assert time.struct_time.n_unnamed_fields == 0
288+
assert time.struct_time.n_sequence_fields < time.struct_time.n_fields
289+
290+
t = time.gmtime(0)
291+
292+
# visible fields
293+
t2 = copy.replace(t)
294+
self.assertEqual(t2, (1970, 1, 1, 0, 0, 0, 3, 1, 0))
295+
self.assertIsInstance(t2, time.struct_time)
296+
t3 = copy.replace(t, tm_year=2000)
297+
self.assertEqual(t3, (2000, 1, 1, 0, 0, 0, 3, 1, 0))
298+
self.assertEqual(t3.tm_year, 2000)
299+
t4 = copy.replace(t, tm_mon=2)
300+
self.assertEqual(t4, (1970, 2, 1, 0, 0, 0, 3, 1, 0))
301+
self.assertEqual(t4.tm_mon, 2)
302+
t5 = copy.replace(t, tm_year=2000, tm_mon=2)
303+
self.assertEqual(t5, (2000, 2, 1, 0, 0, 0, 3, 1, 0))
304+
self.assertEqual(t5.tm_year, 2000)
305+
self.assertEqual(t5.tm_mon, 2)
306+
307+
# named invisible fields
308+
self.assertTrue(hasattr(t, 'tm_zone'), f"{t} has no attribute 'tm_zone'")
309+
with self.assertRaisesRegex(AttributeError, 'readonly attribute'):
310+
t.tm_zone = 'some other zone'
311+
self.assertEqual(t2.tm_zone, t.tm_zone)
312+
self.assertEqual(t3.tm_zone, t.tm_zone)
313+
self.assertEqual(t4.tm_zone, t.tm_zone)
314+
t6 = copy.replace(t, tm_zone='some other zone')
315+
self.assertEqual(t, t6)
316+
self.assertEqual(t6.tm_zone, 'some other zone')
317+
t7 = copy.replace(t, tm_year=2000, tm_zone='some other zone')
318+
self.assertEqual(t7, (2000, 1, 1, 0, 0, 0, 3, 1, 0))
319+
self.assertEqual(t7.tm_year, 2000)
320+
self.assertEqual(t7.tm_zone, 'some other zone')
321+
322+
# unknown fields
323+
with self.assertRaisesRegex(TypeError, 'unexpected field name'):
324+
copy.replace(t, error=2)
325+
with self.assertRaisesRegex(TypeError, 'unexpected field name'):
326+
copy.replace(t, tm_year=2000, error=2)
327+
with self.assertRaisesRegex(TypeError, 'unexpected field name'):
328+
copy.replace(t, tm_zone='some other zone', error=2)
329+
330+
def test_copy_replace_with_unnamed_fields(self):
331+
assert os.stat_result.n_unnamed_fields > 0
332+
333+
r = os.stat_result(range(os.stat_result.n_sequence_fields))
334+
335+
error_message = re.escape('__replace__() is not supported')
336+
with self.assertRaisesRegex(TypeError, error_message):
337+
copy.replace(r)
338+
with self.assertRaisesRegex(TypeError, error_message):
339+
copy.replace(r, st_mode=1)
340+
with self.assertRaisesRegex(TypeError, error_message):
341+
copy.replace(r, error=2)
342+
with self.assertRaisesRegex(TypeError, error_message):
343+
copy.replace(r, st_mode=1, error=2)
344+
267345

268346
if __name__ == "__main__":
269347
unittest.main()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Add support of struct sequence objects in :func:`copy.replace`.
2+
Patched by Xuehai Pan.

Objects/structseq.c

+75-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
*/
99

1010
#include "Python.h"
11+
#include "pycore_dict.h" // _PyDict_Pop()
1112
#include "pycore_tuple.h" // _PyTuple_FromArray()
1213
#include "pycore_object.h" // _PyObject_GC_TRACK()
1314

@@ -380,9 +381,82 @@ structseq_reduce(PyStructSequence* self, PyObject *Py_UNUSED(ignored))
380381
return NULL;
381382
}
382383

384+
385+
static PyObject *
386+
structseq_replace(PyStructSequence *self, PyObject *args, PyObject *kwargs)
387+
{
388+
PyStructSequence *result = NULL;
389+
Py_ssize_t n_fields, n_unnamed_fields, i;
390+
391+
if (!_PyArg_NoPositional("__replace__", args)) {
392+
return NULL;
393+
}
394+
395+
n_fields = REAL_SIZE(self);
396+
if (n_fields < 0) {
397+
return NULL;
398+
}
399+
n_unnamed_fields = UNNAMED_FIELDS(self);
400+
if (n_unnamed_fields < 0) {
401+
return NULL;
402+
}
403+
if (n_unnamed_fields > 0) {
404+
PyErr_Format(PyExc_TypeError,
405+
"__replace__() is not supported for %.500s "
406+
"because it has unnamed field(s)",
407+
Py_TYPE(self)->tp_name);
408+
return NULL;
409+
}
410+
411+
result = (PyStructSequence *) PyStructSequence_New(Py_TYPE(self));
412+
if (!result) {
413+
return NULL;
414+
}
415+
416+
if (kwargs != NULL) {
417+
// We do not support types with unnamed fields, so we can iterate over
418+
// i >= n_visible_fields case without slicing with (i - n_unnamed_fields).
419+
for (i = 0; i < n_fields; ++i) {
420+
PyObject *key = PyUnicode_FromString(Py_TYPE(self)->tp_members[i].name);
421+
if (!key) {
422+
goto error;
423+
}
424+
PyObject *ob = _PyDict_Pop(kwargs, key, self->ob_item[i]);
425+
Py_DECREF(key);
426+
if (!ob) {
427+
goto error;
428+
}
429+
result->ob_item[i] = ob;
430+
}
431+
// Check if there are any unexpected fields.
432+
if (PyDict_GET_SIZE(kwargs) > 0) {
433+
PyObject *names = PyDict_Keys(kwargs);
434+
if (names) {
435+
PyErr_Format(PyExc_TypeError, "Got unexpected field name(s): %R", names);
436+
Py_DECREF(names);
437+
}
438+
goto error;
439+
}
440+
}
441+
else
442+
{
443+
// Just create a copy of the original.
444+
for (i = 0; i < n_fields; ++i) {
445+
result->ob_item[i] = Py_NewRef(self->ob_item[i]);
446+
}
447+
}
448+
449+
return (PyObject *)result;
450+
451+
error:
452+
Py_DECREF(result);
453+
return NULL;
454+
}
455+
383456
static PyMethodDef structseq_methods[] = {
384457
{"__reduce__", (PyCFunction)structseq_reduce, METH_NOARGS, NULL},
385-
{NULL, NULL}
458+
{"__replace__", _PyCFunction_CAST(structseq_replace), METH_VARARGS | METH_KEYWORDS, NULL},
459+
{NULL, NULL} // sentinel
386460
};
387461

388462
static Py_ssize_t

0 commit comments

Comments
 (0)