Skip to content

Commit 72340d1

Browse files
belm0 and Stefan Krah authored
gh-114563: C decimal falls back to pydecimal for unsupported format strings (GH-114879)
Immediate merits:
* eliminate complex workarounds for 'z' format support (NOTE: mpdecimal recently added 'z' support, so this becomes efficient in the long term.)
* fix 'z' format memory leak
* fix 'z' format applied to 'F'
* fix missing '#' format support

Suggested and prototyped by Stefan Krah.

Fixes gh-114563, gh-91060

Co-authored-by: Stefan Krah <[email protected]>
1 parent 235cacf commit 72340d1

File tree

3 files changed

+88
-122
lines changed

3 files changed

+88
-122
lines changed

Lib/test/test_decimal.py

+22
Original file line numberDiff line numberDiff line change
@@ -1110,6 +1110,13 @@ def test_formatting(self):
11101110
('z>z6.1f', '-0.', 'zzz0.0'),
11111111
('x>z6.1f', '-0.', 'xxx0.0'),
11121112
('🖤>z6.1f', '-0.', '🖤🖤🖤0.0'), # multi-byte fill char
1113+
('\x00>z6.1f', '-0.', '\x00\x00\x000.0'), # null fill char
1114+
1115+
# issue 114563 ('z' format on F type in cdecimal)
1116+
('z3,.10F', '-6.24E-323', '0.0000000000'),
1117+
1118+
# issue 91060 ('#' format in cdecimal)
1119+
('#', '0', '0.'),
11131120

11141121
# issue 6850
11151122
('a=-7.0', '0.12345', 'aaaa0.1'),
@@ -5726,6 +5733,21 @@ def test_c_signaldict_segfault(self):
57265733
with self.assertRaisesRegex(ValueError, err_msg):
57275734
sd.copy()
57285735

5736+
def test_format_fallback_capitals(self):
5737+
# Fallback to _pydecimal formatting (triggered by `#` format which
5738+
# is unsupported by mpdecimal) should honor the current context.
5739+
x = C.Decimal('6.09e+23')
5740+
# Ambient context: the exponent marker is expected in upper case.
self.assertEqual(format(x, '#'), '6.09E+23')
5741+
with C.localcontext(capitals=0):
5742+
# capitals=0 must carry through the fallback and give a lower-case 'e'.
self.assertEqual(format(x, '#'), '6.09e+23')
5743+
5744+
def test_format_fallback_rounding(self):
# The `#` flag is formatted through the _pydecimal fallback; the result
# must follow the rounding mode of the current context.
5745+
y = C.Decimal('6.09')
5746+
# Ambient context: 6.09 is expected to round to 6.1 at one decimal place.
self.assertEqual(format(y, '#.1f'), '6.1')
5747+
with C.localcontext(rounding=C.ROUND_DOWN):
5748+
# ROUND_DOWN must carry through the fallback and truncate to 6.0.
self.assertEqual(format(y, '#.1f'), '6.0')
5749+
5750+
57295751
@requires_docstrings
57305752
@requires_cdecimal
57315753
class SignatureTest(unittest.TestCase):
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
Fix several :func:`format()` bugs when using the C implementation of :class:`~decimal.Decimal`:
2+
* memory leak in some rare cases when using the ``z`` format option (coerce negative 0)
3+
* incorrect output when applying the ``z`` format option to type ``F`` (fixed-point with capital ``NAN`` / ``INF``)
4+
* incorrect output when applying the ``#`` format option (alternate form)

Modules/_decimal/_decimal.c

+62-122
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ typedef struct {
8282
/* Convert rationals for comparison */
8383
PyObject *Rational;
8484

85+
/* Invariant: NULL or pointer to _pydecimal.Decimal */
86+
PyObject *PyDecimal;
87+
8588
PyObject *SignalTuple;
8689

8790
struct DecCondMap *signal_map;
@@ -3336,56 +3339,6 @@ dotsep_as_utf8(const char *s)
33363339
return utf8;
33373340
}
33383341

3339-
/* copy of libmpdec _mpd_round() */
3340-
static void
3341-
_mpd_round(mpd_t *result, const mpd_t *a, mpd_ssize_t prec,
3342-
const mpd_context_t *ctx, uint32_t *status)
3343-
{
3344-
mpd_ssize_t exp = a->exp + a->digits - prec;
3345-
3346-
if (prec <= 0) {
3347-
mpd_seterror(result, MPD_Invalid_operation, status);
3348-
return;
3349-
}
3350-
if (mpd_isspecial(a) || mpd_iszero(a)) {
3351-
mpd_qcopy(result, a, status);
3352-
return;
3353-
}
3354-
3355-
mpd_qrescale_fmt(result, a, exp, ctx, status);
3356-
if (result->digits > prec) {
3357-
mpd_qrescale_fmt(result, result, exp+1, ctx, status);
3358-
}
3359-
}
3360-
3361-
/* Locate negative zero "z" option within a UTF-8 format spec string.
3362-
* Returns pointer to "z", else NULL.
3363-
* The portion of the spec we're working with is [[fill]align][sign][z] */
3364-
static const char *
3365-
format_spec_z_search(char const *fmt, Py_ssize_t size) {
3366-
char const *pos = fmt;
3367-
char const *fmt_end = fmt + size;
3368-
/* skip over [[fill]align] (fill may be multi-byte character) */
3369-
pos += 1;
3370-
while (pos < fmt_end && *pos & 0x80) {
3371-
pos += 1;
3372-
}
3373-
if (pos < fmt_end && strchr("<>=^", *pos) != NULL) {
3374-
pos += 1;
3375-
} else {
3376-
/* fill not present-- skip over [align] */
3377-
pos = fmt;
3378-
if (pos < fmt_end && strchr("<>=^", *pos) != NULL) {
3379-
pos += 1;
3380-
}
3381-
}
3382-
/* skip over [sign] */
3383-
if (pos < fmt_end && strchr("+- ", *pos) != NULL) {
3384-
pos += 1;
3385-
}
3386-
return pos < fmt_end && *pos == 'z' ? pos : NULL;
3387-
}
3388-
33893342
static int
33903343
dict_get_item_string(PyObject *dict, const char *key, PyObject **valueobj, const char **valuestr)
33913344
{
@@ -3411,6 +3364,48 @@ dict_get_item_string(PyObject *dict, const char *key, PyObject **valueobj, const
34113364
return 0;
34123365
}
34133366

3367+
/*
3368+
* Fallback _pydecimal formatting for new format specifiers that mpdecimal does
3369+
* not yet support. As documented, libmpdec follows the PEP-3101 format language:
3370+
* https://www.bytereef.org/mpdecimal/doc/libmpdec/assign-convert.html#to-string
*
* How it works: the result of dec_str(dec) is passed as the sole argument to
* _pydecimal.Decimal, and the resulting object's __format__(fmt, context)
* method produces the return value. The _pydecimal.Decimal class is looked
* up on first use and cached in state->PyDecimal.
*
* Parameters:
*   dec      the decimal object to format (handed to dec_str())
*   context  context object forwarded to __format__
*   fmt      format specifier object forwarded to __format__
*   state    module state holding the cached _pydecimal.Decimal class
*
* Returns the object produced by __format__, or NULL with an exception set.
* A ValueError raised by __format__ is replaced by a ValueError with the
* message "invalid format string"; other exceptions propagate unchanged.
3371+
*/
3372+
static PyObject *
3373+
pydec_format(PyObject *dec, PyObject *context, PyObject *fmt, decimal_state *state)
3374+
{
3375+
PyObject *result;
3376+
PyObject *pydec;
3377+
PyObject *u;
3378+
3379+
/* Look up _pydecimal.Decimal only on first use; later calls reuse the
   class cached in the module state. */
if (state->PyDecimal == NULL) {
3380+
state->PyDecimal = _PyImport_GetModuleAttrString("_pydecimal", "Decimal");
3381+
if (state->PyDecimal == NULL) {
3382+
return NULL;
3383+
}
3384+
}
3385+
3386+
/* Intermediate value used to construct the _pydecimal.Decimal below. */
u = dec_str(dec);
3387+
if (u == NULL) {
3388+
return NULL;
3389+
}
3390+
3391+
pydec = PyObject_CallOneArg(state->PyDecimal, u);
3392+
/* u is only needed for the constructor call, on success or failure. */
Py_DECREF(u);
3393+
if (pydec == NULL) {
3394+
return NULL;
3395+
}
3396+
3397+
/* Delegate the actual formatting to the Python implementation. */
result = PyObject_CallMethod(pydec, "__format__", "(OO)", fmt, context);
3398+
Py_DECREF(pydec);
3399+
3400+
if (result == NULL && PyErr_ExceptionMatches(PyExc_ValueError)) {
3401+
/* Do not confuse users with the _pydecimal exception */
3402+
PyErr_Clear();
3403+
PyErr_SetString(PyExc_ValueError, "invalid format string");
3404+
}
3405+
3406+
return result;
3407+
}
3408+
34143409
/* Formatted representation of a PyDecObject. */
34153410
static PyObject *
34163411
dec_format(PyObject *dec, PyObject *args)
@@ -3423,16 +3418,11 @@ dec_format(PyObject *dec, PyObject *args)
34233418
PyObject *fmtarg;
34243419
PyObject *context;
34253420
mpd_spec_t spec;
3426-
char const *fmt;
3427-
char *fmt_copy = NULL;
3421+
char *fmt;
34283422
char *decstring = NULL;
34293423
uint32_t status = 0;
34303424
int replace_fillchar = 0;
3431-
int no_neg_0 = 0;
34323425
Py_ssize_t size;
3433-
mpd_t *mpd = MPD(dec);
3434-
mpd_uint_t dt[MPD_MINALLOC_MAX];
3435-
mpd_t tmp = {MPD_STATIC|MPD_STATIC_DATA,0,0,0,MPD_MINALLOC_MAX,dt};
34363426

34373427

34383428
decimal_state *state = get_module_state_by_def(Py_TYPE(dec));
@@ -3442,7 +3432,7 @@ dec_format(PyObject *dec, PyObject *args)
34423432
}
34433433

34443434
if (PyUnicode_Check(fmtarg)) {
3445-
fmt = PyUnicode_AsUTF8AndSize(fmtarg, &size);
3435+
fmt = (char *)PyUnicode_AsUTF8AndSize(fmtarg, &size);
34463436
if (fmt == NULL) {
34473437
return NULL;
34483438
}
@@ -3454,35 +3444,15 @@ dec_format(PyObject *dec, PyObject *args)
34543444
}
34553445
}
34563446

3457-
/* NOTE: If https://github.com/python/cpython/pull/29438 lands, the
3458-
* format string manipulation below can be eliminated by enhancing
3459-
* the forked mpd_parse_fmt_str(). */
34603447
if (size > 0 && fmt[0] == '\0') {
34613448
/* NUL fill character: must be replaced with a valid UTF-8 char
34623449
before calling mpd_parse_fmt_str(). */
34633450
replace_fillchar = 1;
3464-
fmt = fmt_copy = dec_strdup(fmt, size);
3465-
if (fmt_copy == NULL) {
3451+
fmt = dec_strdup(fmt, size);
3452+
if (fmt == NULL) {
34663453
return NULL;
34673454
}
3468-
fmt_copy[0] = '_';
3469-
}
3470-
/* Strip 'z' option, which isn't understood by mpd_parse_fmt_str().
3471-
* NOTE: fmt is always null terminated by PyUnicode_AsUTF8AndSize() */
3472-
char const *z_position = format_spec_z_search(fmt, size);
3473-
if (z_position != NULL) {
3474-
no_neg_0 = 1;
3475-
size_t z_index = z_position - fmt;
3476-
if (fmt_copy == NULL) {
3477-
fmt = fmt_copy = dec_strdup(fmt, size);
3478-
if (fmt_copy == NULL) {
3479-
return NULL;
3480-
}
3481-
}
3482-
/* Shift characters (including null terminator) left,
3483-
overwriting the 'z' option. */
3484-
memmove(fmt_copy + z_index, fmt_copy + z_index + 1, size - z_index);
3485-
size -= 1;
3455+
fmt[0] = '_';
34863456
}
34873457
}
34883458
else {
@@ -3492,10 +3462,13 @@ dec_format(PyObject *dec, PyObject *args)
34923462
}
34933463

34943464
if (!mpd_parse_fmt_str(&spec, fmt, CtxCaps(context))) {
3495-
PyErr_SetString(PyExc_ValueError,
3496-
"invalid format string");
3497-
goto finish;
3465+
if (replace_fillchar) {
3466+
PyMem_Free(fmt);
3467+
}
3468+
3469+
return pydec_format(dec, context, fmtarg, state);
34983470
}
3471+
34993472
if (replace_fillchar) {
35003473
/* In order to avoid clobbering parts of UTF-8 thousands separators or
35013474
decimal points when the substitution is reversed later, the actual
@@ -3548,45 +3521,8 @@ dec_format(PyObject *dec, PyObject *args)
35483521
}
35493522
}
35503523

3551-
if (no_neg_0 && mpd_isnegative(mpd) && !mpd_isspecial(mpd)) {
3552-
/* Round into a temporary (carefully mirroring the rounding
3553-
of mpd_qformat_spec()), and check if the result is negative zero.
3554-
If so, clear the sign and format the resulting positive zero. */
3555-
mpd_ssize_t prec;
3556-
mpd_qcopy(&tmp, mpd, &status);
3557-
if (spec.prec >= 0) {
3558-
switch (spec.type) {
3559-
case 'f':
3560-
mpd_qrescale(&tmp, &tmp, -spec.prec, CTX(context), &status);
3561-
break;
3562-
case '%':
3563-
tmp.exp += 2;
3564-
mpd_qrescale(&tmp, &tmp, -spec.prec, CTX(context), &status);
3565-
break;
3566-
case 'g':
3567-
prec = (spec.prec == 0) ? 1 : spec.prec;
3568-
if (tmp.digits > prec) {
3569-
_mpd_round(&tmp, &tmp, prec, CTX(context), &status);
3570-
}
3571-
break;
3572-
case 'e':
3573-
if (!mpd_iszero(&tmp)) {
3574-
_mpd_round(&tmp, &tmp, spec.prec+1, CTX(context), &status);
3575-
}
3576-
break;
3577-
}
3578-
}
3579-
if (status & MPD_Errors) {
3580-
PyErr_SetString(PyExc_ValueError, "unexpected error when rounding");
3581-
goto finish;
3582-
}
3583-
if (mpd_iszero(&tmp)) {
3584-
mpd_set_positive(&tmp);
3585-
mpd = &tmp;
3586-
}
3587-
}
35883524

3589-
decstring = mpd_qformat_spec(mpd, &spec, CTX(context), &status);
3525+
decstring = mpd_qformat_spec(MPD(dec), &spec, CTX(context), &status);
35903526
if (decstring == NULL) {
35913527
if (status & MPD_Malloc_error) {
35923528
PyErr_NoMemory();
@@ -3609,7 +3545,7 @@ dec_format(PyObject *dec, PyObject *args)
36093545
Py_XDECREF(grouping);
36103546
Py_XDECREF(sep);
36113547
Py_XDECREF(dot);
3612-
if (fmt_copy) PyMem_Free(fmt_copy);
3548+
if (replace_fillchar) PyMem_Free(fmt);
36133549
if (decstring) mpd_free(decstring);
36143550
return result;
36153551
}
@@ -5987,6 +5923,9 @@ _decimal_exec(PyObject *m)
59875923
Py_CLEAR(collections_abc);
59885924
Py_CLEAR(MutableMapping);
59895925

5926+
/* For format specifiers not yet supported by libmpdec */
5927+
state->PyDecimal = NULL;
5928+
59905929
/* Add types to the module */
59915930
CHECK_INT(PyModule_AddType(m, state->PyDec_Type));
59925931
CHECK_INT(PyModule_AddType(m, state->PyDecContext_Type));
@@ -6192,6 +6131,7 @@ decimal_clear(PyObject *module)
61926131
Py_CLEAR(state->extended_context_template);
61936132
Py_CLEAR(state->Rational);
61946133
Py_CLEAR(state->SignalTuple);
6134+
Py_CLEAR(state->PyDecimal);
61956135

61966136
PyMem_Free(state->signal_map);
61976137
PyMem_Free(state->cond_map);

0 commit comments

Comments
 (0)