Skip to content

Commit b86fdf2

Browse files
belm0 and Stefan Krah authored
[3.11] gh-114563: C decimal falls back to pydecimal for unsupported format strings (GH-114879) (GH-115384)
Immediate merits: * eliminate complex workarounds for 'z' format support (NOTE: mpdecimal recently added 'z' support, so this becomes efficient in the long term.) * fix 'z' format memory leak * fix 'z' format applied to 'F' * fix missing '#' format support Suggested and prototyped by Stefan Krah. Fixes gh-114563, gh-91060 (cherry picked from commit 72340d1) (cherry picked from commit 09c98e4) Co-authored-by: Stefan Krah <[email protected]>
1 parent d87a022 commit b86fdf2

File tree

4 files changed

+87
-122
lines changed

4 files changed

+87
-122
lines changed

Lib/test/test_decimal.py

+22
Original file line numberDiff line numberDiff line change
@@ -1121,6 +1121,13 @@ def test_formatting(self):
11211121
('z>z6.1f', '-0.', 'zzz0.0'),
11221122
('x>z6.1f', '-0.', 'xxx0.0'),
11231123
('🖤>z6.1f', '-0.', '🖤🖤🖤0.0'), # multi-byte fill char
1124+
('\x00>z6.1f', '-0.', '\x00\x00\x000.0'), # null fill char
1125+
1126+
# issue 114563 ('z' format on F type in cdecimal)
1127+
('z3,.10F', '-6.24E-323', '0.0000000000'),
1128+
1129+
# issue 91060 ('#' format in cdecimal)
1130+
('#', '0', '0.'),
11241131

11251132
# issue 6850
11261133
('a=-7.0', '0.12345', 'aaaa0.1'),
@@ -5712,6 +5719,21 @@ def test_c_signaldict_segfault(self):
57125719
with self.assertRaisesRegex(ValueError, err_msg):
57135720
sd.copy()
57145721

5722+
def test_format_fallback_capitals(self):
5723+
# Fallback to _pydecimal formatting (triggered by `#` format which
5724+
# is unsupported by mpdecimal) should honor the current context.
5725+
x = C.Decimal('6.09e+23')
5726+
self.assertEqual(format(x, '#'), '6.09E+23')
5727+
with C.localcontext(capitals=0):
5728+
self.assertEqual(format(x, '#'), '6.09e+23')
5729+
5730+
def test_format_fallback_rounding(self):
5731+
y = C.Decimal('6.09')
5732+
self.assertEqual(format(y, '#.1f'), '6.1')
5733+
with C.localcontext(rounding=C.ROUND_DOWN):
5734+
self.assertEqual(format(y, '#.1f'), '6.0')
5735+
5736+
57155737
@requires_docstrings
57165738
@requires_cdecimal
57175739
class SignatureTest(unittest.TestCase):
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
Fix several :func:`format()` bugs when using the C implementation of :class:`~decimal.Decimal`:
2+
* memory leak in some rare cases when using the ``z`` format option (coerce negative 0)
3+
* incorrect output when applying the ``z`` format option to type ``F`` (fixed-point with capital ``NAN`` / ``INF``)
4+
* incorrect output when applying the ``#`` format option (alternate form)

Modules/_decimal/_decimal.c

+60-122
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,8 @@ static PyObject *default_context_template = NULL;
144144
static PyObject *basic_context_template = NULL;
145145
static PyObject *extended_context_template = NULL;
146146

147+
/* Invariant: NULL or pointer to _pydecimal.Decimal */
148+
static PyObject *PyDecimal = NULL;
147149

148150
/* Error codes for functions that return signals or conditions */
149151
#define DEC_INVALID_SIGNALS (MPD_Max_status+1U)
@@ -3245,56 +3247,6 @@ dotsep_as_utf8(const char *s)
32453247
return utf8;
32463248
}
32473249

3248-
/* copy of libmpdec _mpd_round() */
3249-
static void
3250-
_mpd_round(mpd_t *result, const mpd_t *a, mpd_ssize_t prec,
3251-
const mpd_context_t *ctx, uint32_t *status)
3252-
{
3253-
mpd_ssize_t exp = a->exp + a->digits - prec;
3254-
3255-
if (prec <= 0) {
3256-
mpd_seterror(result, MPD_Invalid_operation, status);
3257-
return;
3258-
}
3259-
if (mpd_isspecial(a) || mpd_iszero(a)) {
3260-
mpd_qcopy(result, a, status);
3261-
return;
3262-
}
3263-
3264-
mpd_qrescale_fmt(result, a, exp, ctx, status);
3265-
if (result->digits > prec) {
3266-
mpd_qrescale_fmt(result, result, exp+1, ctx, status);
3267-
}
3268-
}
3269-
3270-
/* Locate negative zero "z" option within a UTF-8 format spec string.
3271-
* Returns pointer to "z", else NULL.
3272-
* The portion of the spec we're working with is [[fill]align][sign][z] */
3273-
static const char *
3274-
format_spec_z_search(char const *fmt, Py_ssize_t size) {
3275-
char const *pos = fmt;
3276-
char const *fmt_end = fmt + size;
3277-
/* skip over [[fill]align] (fill may be multi-byte character) */
3278-
pos += 1;
3279-
while (pos < fmt_end && *pos & 0x80) {
3280-
pos += 1;
3281-
}
3282-
if (pos < fmt_end && strchr("<>=^", *pos) != NULL) {
3283-
pos += 1;
3284-
} else {
3285-
/* fill not present-- skip over [align] */
3286-
pos = fmt;
3287-
if (pos < fmt_end && strchr("<>=^", *pos) != NULL) {
3288-
pos += 1;
3289-
}
3290-
}
3291-
/* skip over [sign] */
3292-
if (pos < fmt_end && strchr("+- ", *pos) != NULL) {
3293-
pos += 1;
3294-
}
3295-
return pos < fmt_end && *pos == 'z' ? pos : NULL;
3296-
}
3297-
32983250
static int
32993251
dict_get_item_string(PyObject *dict, const char *key, PyObject **valueobj, const char **valuestr)
33003252
{
@@ -3320,6 +3272,48 @@ dict_get_item_string(PyObject *dict, const char *key, PyObject **valueobj, const
33203272
return 0;
33213273
}
33223274

3275+
/*
3276+
* Fallback _pydecimal formatting for new format specifiers that mpdecimal does
3277+
* not yet support. As documented, libmpdec follows the PEP-3101 format language:
3278+
* https://www.bytereef.org/mpdecimal/doc/libmpdec/assign-convert.html#to-string
3279+
*/
3280+
static PyObject *
3281+
pydec_format(PyObject *dec, PyObject *context, PyObject *fmt)
3282+
{
3283+
PyObject *result;
3284+
PyObject *pydec;
3285+
PyObject *u;
3286+
3287+
if (PyDecimal == NULL) {
3288+
PyDecimal = _PyImport_GetModuleAttrString("_pydecimal", "Decimal");
3289+
if (PyDecimal == NULL) {
3290+
return NULL;
3291+
}
3292+
}
3293+
3294+
u = dec_str(dec);
3295+
if (u == NULL) {
3296+
return NULL;
3297+
}
3298+
3299+
pydec = PyObject_CallOneArg(PyDecimal, u);
3300+
Py_DECREF(u);
3301+
if (pydec == NULL) {
3302+
return NULL;
3303+
}
3304+
3305+
result = PyObject_CallMethod(pydec, "__format__", "(OO)", fmt, context);
3306+
Py_DECREF(pydec);
3307+
3308+
if (result == NULL && PyErr_ExceptionMatches(PyExc_ValueError)) {
3309+
/* Do not confuse users with the _pydecimal exception */
3310+
PyErr_Clear();
3311+
PyErr_SetString(PyExc_ValueError, "invalid format string");
3312+
}
3313+
3314+
return result;
3315+
}
3316+
33233317
/* Formatted representation of a PyDecObject. */
33243318
static PyObject *
33253319
dec_format(PyObject *dec, PyObject *args)
@@ -3332,16 +3326,11 @@ dec_format(PyObject *dec, PyObject *args)
33323326
PyObject *fmtarg;
33333327
PyObject *context;
33343328
mpd_spec_t spec;
3335-
char const *fmt;
3336-
char *fmt_copy = NULL;
3329+
char *fmt;
33373330
char *decstring = NULL;
33383331
uint32_t status = 0;
33393332
int replace_fillchar = 0;
3340-
int no_neg_0 = 0;
33413333
Py_ssize_t size;
3342-
mpd_t *mpd = MPD(dec);
3343-
mpd_uint_t dt[MPD_MINALLOC_MAX];
3344-
mpd_t tmp = {MPD_STATIC|MPD_STATIC_DATA,0,0,0,MPD_MINALLOC_MAX,dt};
33453334

33463335

33473336
CURRENT_CONTEXT(context);
@@ -3350,39 +3339,20 @@ dec_format(PyObject *dec, PyObject *args)
33503339
}
33513340

33523341
if (PyUnicode_Check(fmtarg)) {
3353-
fmt = PyUnicode_AsUTF8AndSize(fmtarg, &size);
3342+
fmt = (char *)PyUnicode_AsUTF8AndSize(fmtarg, &size);
33543343
if (fmt == NULL) {
33553344
return NULL;
33563345
}
3357-
/* NOTE: If https://github.com/python/cpython/pull/29438 lands, the
3358-
* format string manipulation below can be eliminated by enhancing
3359-
* the forked mpd_parse_fmt_str(). */
3346+
33603347
if (size > 0 && fmt[0] == '\0') {
33613348
/* NUL fill character: must be replaced with a valid UTF-8 char
33623349
before calling mpd_parse_fmt_str(). */
33633350
replace_fillchar = 1;
3364-
fmt = fmt_copy = dec_strdup(fmt, size);
3365-
if (fmt_copy == NULL) {
3351+
fmt = dec_strdup(fmt, size);
3352+
if (fmt == NULL) {
33663353
return NULL;
33673354
}
3368-
fmt_copy[0] = '_';
3369-
}
3370-
/* Strip 'z' option, which isn't understood by mpd_parse_fmt_str().
3371-
* NOTE: fmt is always null terminated by PyUnicode_AsUTF8AndSize() */
3372-
char const *z_position = format_spec_z_search(fmt, size);
3373-
if (z_position != NULL) {
3374-
no_neg_0 = 1;
3375-
size_t z_index = z_position - fmt;
3376-
if (fmt_copy == NULL) {
3377-
fmt = fmt_copy = dec_strdup(fmt, size);
3378-
if (fmt_copy == NULL) {
3379-
return NULL;
3380-
}
3381-
}
3382-
/* Shift characters (including null terminator) left,
3383-
overwriting the 'z' option. */
3384-
memmove(fmt_copy + z_index, fmt_copy + z_index + 1, size - z_index);
3385-
size -= 1;
3355+
fmt[0] = '_';
33863356
}
33873357
}
33883358
else {
@@ -3392,10 +3362,13 @@ dec_format(PyObject *dec, PyObject *args)
33923362
}
33933363

33943364
if (!mpd_parse_fmt_str(&spec, fmt, CtxCaps(context))) {
3395-
PyErr_SetString(PyExc_ValueError,
3396-
"invalid format string");
3397-
goto finish;
3365+
if (replace_fillchar) {
3366+
PyMem_Free(fmt);
3367+
}
3368+
3369+
return pydec_format(dec, context, fmtarg);
33983370
}
3371+
33993372
if (replace_fillchar) {
34003373
/* In order to avoid clobbering parts of UTF-8 thousands separators or
34013374
decimal points when the substitution is reversed later, the actual
@@ -3448,45 +3421,8 @@ dec_format(PyObject *dec, PyObject *args)
34483421
}
34493422
}
34503423

3451-
if (no_neg_0 && mpd_isnegative(mpd) && !mpd_isspecial(mpd)) {
3452-
/* Round into a temporary (carefully mirroring the rounding
3453-
of mpd_qformat_spec()), and check if the result is negative zero.
3454-
If so, clear the sign and format the resulting positive zero. */
3455-
mpd_ssize_t prec;
3456-
mpd_qcopy(&tmp, mpd, &status);
3457-
if (spec.prec >= 0) {
3458-
switch (spec.type) {
3459-
case 'f':
3460-
mpd_qrescale(&tmp, &tmp, -spec.prec, CTX(context), &status);
3461-
break;
3462-
case '%':
3463-
tmp.exp += 2;
3464-
mpd_qrescale(&tmp, &tmp, -spec.prec, CTX(context), &status);
3465-
break;
3466-
case 'g':
3467-
prec = (spec.prec == 0) ? 1 : spec.prec;
3468-
if (tmp.digits > prec) {
3469-
_mpd_round(&tmp, &tmp, prec, CTX(context), &status);
3470-
}
3471-
break;
3472-
case 'e':
3473-
if (!mpd_iszero(&tmp)) {
3474-
_mpd_round(&tmp, &tmp, spec.prec+1, CTX(context), &status);
3475-
}
3476-
break;
3477-
}
3478-
}
3479-
if (status & MPD_Errors) {
3480-
PyErr_SetString(PyExc_ValueError, "unexpected error when rounding");
3481-
goto finish;
3482-
}
3483-
if (mpd_iszero(&tmp)) {
3484-
mpd_set_positive(&tmp);
3485-
mpd = &tmp;
3486-
}
3487-
}
34883424

3489-
decstring = mpd_qformat_spec(mpd, &spec, CTX(context), &status);
3425+
decstring = mpd_qformat_spec(MPD(dec), &spec, CTX(context), &status);
34903426
if (decstring == NULL) {
34913427
if (status & MPD_Malloc_error) {
34923428
PyErr_NoMemory();
@@ -3509,7 +3445,7 @@ dec_format(PyObject *dec, PyObject *args)
35093445
Py_XDECREF(grouping);
35103446
Py_XDECREF(sep);
35113447
Py_XDECREF(dot);
3512-
if (fmt_copy) PyMem_Free(fmt_copy);
3448+
if (replace_fillchar) PyMem_Free(fmt);
35133449
if (decstring) mpd_free(decstring);
35143450
return result;
35153451
}
@@ -5944,6 +5880,8 @@ PyInit__decimal(void)
59445880
/* Create the module */
59455881
ASSIGN_PTR(m, PyModule_Create(&_decimal_module));
59465882

5883+
/* For format specifiers not yet supported by libmpdec */
5884+
PyDecimal = NULL;
59475885

59485886
/* Add types to the module */
59495887
CHECK_INT(PyModule_AddObjectRef(m, "Decimal", (PyObject *)&PyDec_Type));

Tools/c-analyzer/cpython/globals-to-fix.tsv

+1
Original file line numberDiff line numberDiff line change
@@ -1273,6 +1273,7 @@ Modules/_decimal/_decimal.c - basic_context_template -
12731273
Modules/_decimal/_decimal.c - current_context_var -
12741274
Modules/_decimal/_decimal.c - default_context_template -
12751275
Modules/_decimal/_decimal.c - extended_context_template -
1276+
Modules/_decimal/_decimal.c - PyDecimal -
12761277
Modules/_decimal/_decimal.c - round_map -
12771278
Modules/_decimal/_decimal.c - Rational -
12781279
Modules/_decimal/_decimal.c - SignalTuple -

0 commit comments

Comments
 (0)