From 0c52510037f916137cd639f0ea26c58e38645155 Mon Sep 17 00:00:00 2001 From: Serhiy Storchaka Date: Tue, 4 Feb 2025 11:44:39 +0200 Subject: [PATCH] Fix use-after-free in the unicode-escape decoder with error handler If the error handler is used, a new bytes object is created to set as the object attribute of UnicodeDecodeError, and that bytes object then replaces the original data. A pointer to the decoded data will became invalid after destroying that temporary bytes object. So we need other way to return the first invalid escape from _PyUnicode_DecodeUnicodeEscapeInternal(). _PyBytes_DecodeEscape() does not have such issue, because it does not use the error handlers registry, but it should be changed for compatibility with _PyUnicode_DecodeUnicodeEscapeInternal(). --- Include/internal/pycore_bytesobject.h | 5 ++ Include/internal/pycore_unicodeobject.h | 16 ++++++ Lib/test/test_codeccallbacks.py | 39 +++++++++++++- Lib/test/test_codecs.py | 52 +++++++++++++++---- Lib/test/test_codeop.py | 4 +- Lib/test/test_string_literals.py | 8 +-- Lib/test/test_unparse.py | 2 +- ...-05-09-20-22-54.gh-issue-133767.kN2i3Q.rst | 2 + Objects/bytesobject.c | 43 ++++++++------- Objects/unicodeobject.c | 46 +++++++++------- Parser/string_parser.c | 32 +++++++----- 11 files changed, 182 insertions(+), 67 deletions(-) create mode 100644 Misc/NEWS.d/next/Security/2025-05-09-20-22-54.gh-issue-133767.kN2i3Q.rst diff --git a/Include/internal/pycore_bytesobject.h b/Include/internal/pycore_bytesobject.h index d36fa9569d64a5..350f54a3c01ab6 100644 --- a/Include/internal/pycore_bytesobject.h +++ b/Include/internal/pycore_bytesobject.h @@ -8,6 +8,11 @@ extern "C" { # error "this header requires Py_BUILD_CORE define" #endif +// Helper for PyBytes_DecodeEscape that detects invalid escape chars. +// Export for test_peg_generator. +PyAPI_FUNC(PyObject*) _PyBytes_DecodeEscape2(const char *, Py_ssize_t, + const char *, + int *, const char **); /* Substring Search. diff --git a/Include/internal/pycore_unicodeobject.h b/Include/internal/pycore_unicodeobject.h index cecdabe4155063..b25b821af6591d 100644 --- a/Include/internal/pycore_unicodeobject.h +++ b/Include/internal/pycore_unicodeobject.h @@ -79,6 +79,22 @@ extern void _PyUnicode_ClearInterned(PyInterpreterState *interp); // Like PyUnicode_AsUTF8(), but check for embedded null characters. extern const char* _PyUnicode_AsUTF8NoNUL(PyObject *); +// Helper for PyUnicode_DecodeUnicodeEscape that detects invalid escape +// chars. +// Export for test_peg_generator. +PyAPI_FUNC(PyObject*) _PyUnicode_DecodeUnicodeEscapeInternal2( + const char *string, /* Unicode-Escape encoded string */ + Py_ssize_t length, /* size of string */ + const char *errors, /* error handling */ + Py_ssize_t *consumed, /* bytes consumed */ + int *first_invalid_escape_char, /* on return, if not -1, contain the first + invalid escaped char (<= 0xff) or invalid + octal escape (> 0xff) in string. */ + const char **first_invalid_escape_ptr); /* on return, if not NULL, may + point to the first invalid escaped + char in string. + May be NULL if errors is not NULL. */ + #ifdef __cplusplus } diff --git a/Lib/test/test_codeccallbacks.py b/Lib/test/test_codeccallbacks.py index 4991330489d139..a169458df1c607 100644 --- a/Lib/test/test_codeccallbacks.py +++ b/Lib/test/test_codeccallbacks.py @@ -1,6 +1,7 @@ import codecs import html.entities import itertools +import re import sys import unicodedata import unittest @@ -1124,7 +1125,7 @@ def test_bug828737(self): text = 'abcghi'*n text.translate(charmap) - def test_mutatingdecodehandler(self): + def test_mutating_decode_handler(self): baddata = [ ("ascii", b"\xff"), ("utf-7", b"++"), @@ -1159,6 +1160,42 @@ def mutating(exc): for (encoding, data) in baddata: self.assertEqual(data.decode(encoding, "test.mutating"), "\u4242") + def test_mutating_decode_handler_unicode_escape(self): + decode = codecs.unicode_escape_decode + def mutating(exc): + if isinstance(exc, UnicodeDecodeError): + r = data.get(exc.object[:exc.end]) + if r is not None: + exc.object = r[0] + exc.object[exc.end:] + return ('\u0404', r[1]) + raise AssertionError("don't know how to handle %r" % exc) + + codecs.register_error('test.mutating2', mutating) + data = { + br'\x0': (b'\\', 0), + br'\x3': (b'xxx\\', 3), + br'\x5': (b'x\\', 1), + } + def check(input, expected, msg): + with self.assertWarns(DeprecationWarning) as cm: + self.assertEqual(decode(input, 'test.mutating2'), (expected, len(input))) + self.assertIn(msg, str(cm.warning)) + + check(br'\x0n\z', '\u0404\n\\z', r"invalid escape sequence") + check(br'\x0n\501', '\u0404\n\u0141', r'invalid octal escape sequence') + check(br'\x0z', '\u0404\\z', r'invalid escape sequence') + + check(br'\x3n\zr', '\u0404\n\\zr', r'invalid escape sequence') + check(br'\x3zr', '\u0404\\zr', r'invalid escape sequence') + check(br'\x3z5', '\u0404\\z5', r'invalid escape sequence') + check(memoryview(br'\x3z5x')[:-1], '\u0404\\z5', r'invalid escape sequence') + check(memoryview(br'\x3z5xy')[:-2], '\u0404\\z5', r'invalid escape sequence') + + check(br'\x5n\z', '\u0404\n\\z', r'invalid escape sequence') + check(br'\x5n\501', '\u0404\n\u0141', r'invalid octal escape sequence') + check(br'\x5z', '\u0404\\z', r'invalid escape sequence') + check(memoryview(br'\x5zy')[:-1], '\u0404\\z', r'invalid escape sequence') + # issue32583 def test_crashing_decode_handler(self): # better generating one more character to fill the extra space slot diff --git a/Lib/test/test_codecs.py b/Lib/test/test_codecs.py index f683f069ae1397..14e71f4645c352 100644 --- a/Lib/test/test_codecs.py +++ b/Lib/test/test_codecs.py @@ -1196,23 +1196,39 @@ def test_escape(self): check(br"[\1010]", b"[A0]") check(br"[\x41]", b"[A]") check(br"[\x410]", b"[A0]") + + def test_warnings(self): + decode = codecs.escape_decode + check = coding_checker(self, decode) for i in range(97, 123): b = bytes([i]) if b not in b'abfnrtvx': - with self.assertWarns(DeprecationWarning): + with self.assertWarnsRegex(DeprecationWarning, + r"'\\%c' is an invalid escape sequence" % i): check(b"\\" + b, b"\\" + b) - with self.assertWarns(DeprecationWarning): + with self.assertWarnsRegex(DeprecationWarning, + r"invalid escape sequence"): check(b"\\" + b.upper(), b"\\" + b.upper()) - with self.assertWarns(DeprecationWarning): + with self.assertWarnsRegex(DeprecationWarning, + r"'\\8' is an invalid escape sequence"): check(br"\8", b"\\8") with self.assertWarns(DeprecationWarning): check(br"\9", b"\\9") - with self.assertWarns(DeprecationWarning): + with self.assertWarnsRegex(DeprecationWarning, + r'invalid escape sequence') as cm: check(b"\\\xfa", b"\\\xfa") for i in range(0o400, 0o1000): - with self.assertWarns(DeprecationWarning): + with self.assertWarnsRegex(DeprecationWarning, + r'invalid octal escape sequence'): check(rb'\%o' % i, bytes([i & 0o377])) + with self.assertWarnsRegex(DeprecationWarning, + r'invalid escape sequence'): + self.assertEqual(decode(br'\x\z', 'ignore'), (b'\\z', 4)) + with self.assertWarnsRegex(DeprecationWarning, + r'invalid octal escape sequence'): + self.assertEqual(decode(br'\x\501', 'ignore'), (b'A', 6)) + def test_errors(self): decode = codecs.escape_decode self.assertRaises(ValueError, decode, br"\x") @@ -2479,24 +2495,40 @@ def test_escape_decode(self): check(br"[\x410]", "[A0]") check(br"\u20ac", "\u20ac") check(br"\U0001d120", "\U0001d120") + + def test_decode_warnings(self): + decode = codecs.unicode_escape_decode + check = coding_checker(self, decode) for i in range(97, 123): b = bytes([i]) if b not in b'abfnrtuvx': - with self.assertWarns(DeprecationWarning): + with self.assertWarnsRegex(DeprecationWarning, + r'invalid escape sequence'): check(b"\\" + b, "\\" + chr(i)) if b.upper() not in b'UN': - with self.assertWarns(DeprecationWarning): + with self.assertWarnsRegex(DeprecationWarning, + 'invalid escape sequence'): check(b"\\" + b.upper(), "\\" + chr(i-32)) - with self.assertWarns(DeprecationWarning): + with self.assertWarnsRegex(DeprecationWarning, + r'invalid escape sequence'): check(br"\8", "\\8") with self.assertWarns(DeprecationWarning): check(br"\9", "\\9") - with self.assertWarns(DeprecationWarning): + with self.assertWarnsRegex(DeprecationWarning, + r'invalid escape sequence') as cm: check(b"\\\xfa", "\\\xfa") for i in range(0o400, 0o1000): - with self.assertWarns(DeprecationWarning): + with self.assertWarnsRegex(DeprecationWarning, + r'invalid octal escape sequence'): check(rb'\%o' % i, chr(i)) + with self.assertWarnsRegex(DeprecationWarning, + r'invalid escape sequence'): + self.assertEqual(decode(br'\x\z', 'ignore'), ('\\z', 4)) + with self.assertWarnsRegex(DeprecationWarning, + r'invalid octal escape sequence'): + self.assertEqual(decode(br'\x\501', 'ignore'), ('\u0141', 6)) + def test_decode_errors(self): decode = codecs.unicode_escape_decode for c, d in (b'x', 2), (b'u', 4), (b'U', 4): diff --git a/Lib/test/test_codeop.py b/Lib/test/test_codeop.py index 787bd1b6a79e20..b418ffa826b9bd 100644 --- a/Lib/test/test_codeop.py +++ b/Lib/test/test_codeop.py @@ -281,8 +281,8 @@ def test_filename(self): def test_warning(self): # Test that the warning is only returned once. with warnings_helper.check_warnings( - ('"is" with \'str\' literal', SyntaxWarning), - ("invalid escape sequence", SyntaxWarning), + (r'"is" with.*literal', SyntaxWarning), + (r'invalid escape sequence', SyntaxWarning), ) as w: compile_command(r"'\e' is 0") self.assertEqual(len(w.warnings), 2) diff --git a/Lib/test/test_string_literals.py b/Lib/test/test_string_literals.py index 849efaba6180f7..3e636b09056915 100644 --- a/Lib/test/test_string_literals.py +++ b/Lib/test/test_string_literals.py @@ -116,7 +116,7 @@ def test_eval_str_invalid_escape(self): warnings.simplefilter('always', category=SyntaxWarning) eval("'''\n\\z'''") self.assertEqual(len(w), 1) - self.assertEqual(str(w[0].message), r"invalid escape sequence '\z'") + self.assertEqual(str(w[0].message), r"'\z' is an invalid escape sequence. ") self.assertEqual(w[0].filename, '') self.assertEqual(w[0].lineno, 2) @@ -153,7 +153,7 @@ def test_eval_str_invalid_octal_escape(self): eval("'''\n\\407'''") self.assertEqual(len(w), 1) self.assertEqual(str(w[0].message), - r"invalid octal escape sequence '\407'") + r"'\407' is an invalid octal escape sequence. ") self.assertEqual(w[0].filename, '') self.assertEqual(w[0].lineno, 2) @@ -228,7 +228,7 @@ def test_eval_bytes_invalid_escape(self): warnings.simplefilter('always', category=SyntaxWarning) eval("b'''\n\\z'''") self.assertEqual(len(w), 1) - self.assertEqual(str(w[0].message), r"invalid escape sequence '\z'") + self.assertEqual(str(w[0].message), r"'\z' is an invalid escape sequence. ") self.assertEqual(w[0].filename, '') self.assertEqual(w[0].lineno, 2) @@ -252,7 +252,7 @@ def test_eval_bytes_invalid_octal_escape(self): eval("b'''\n\\407'''") self.assertEqual(len(w), 1) self.assertEqual(str(w[0].message), - r"invalid octal escape sequence '\407'") + r"'\407' is an invalid octal escape sequence. ") self.assertEqual(w[0].filename, '') self.assertEqual(w[0].lineno, 2) diff --git a/Lib/test/test_unparse.py b/Lib/test/test_unparse.py index 496ccce261aa3e..d90af722bb555e 100644 --- a/Lib/test/test_unparse.py +++ b/Lib/test/test_unparse.py @@ -653,7 +653,7 @@ def test_multiquote_joined_string(self): def test_backslash_in_format_spec(self): import re - msg = re.escape("invalid escape sequence '\\ '") + msg = re.escape("invalid escape sequence") with self.assertWarnsRegex(SyntaxWarning, msg): self.check_ast_roundtrip("""f"{x:\\ }" """) self.check_ast_roundtrip("""f"{x:\\n}" """) diff --git a/Misc/NEWS.d/next/Security/2025-05-09-20-22-54.gh-issue-133767.kN2i3Q.rst b/Misc/NEWS.d/next/Security/2025-05-09-20-22-54.gh-issue-133767.kN2i3Q.rst new file mode 100644 index 00000000000000..39d2f1e1a892cf --- /dev/null +++ b/Misc/NEWS.d/next/Security/2025-05-09-20-22-54.gh-issue-133767.kN2i3Q.rst @@ -0,0 +1,2 @@ +Fix use-after-free in the "unicode-escape" decoder with a non-"strict" error +handler. diff --git a/Objects/bytesobject.c b/Objects/bytesobject.c index f3a978c86c3606..93db4235d71c49 100644 --- a/Objects/bytesobject.c +++ b/Objects/bytesobject.c @@ -1048,10 +1048,11 @@ _PyBytes_FormatEx(const char *format, Py_ssize_t format_len, } /* Unescape a backslash-escaped string. */ -PyObject *_PyBytes_DecodeEscape(const char *s, +PyObject *_PyBytes_DecodeEscape2(const char *s, Py_ssize_t len, const char *errors, - const char **first_invalid_escape) + int *first_invalid_escape_char, + const char **first_invalid_escape_ptr) { int c; char *p; @@ -1065,7 +1066,8 @@ PyObject *_PyBytes_DecodeEscape(const char *s, return NULL; writer.overallocate = 1; - *first_invalid_escape = NULL; + *first_invalid_escape_char = -1; + *first_invalid_escape_ptr = NULL; end = s + len; while (s < end) { @@ -1103,9 +1105,10 @@ PyObject *_PyBytes_DecodeEscape(const char *s, c = (c<<3) + *s++ - '0'; } if (c > 0377) { - if (*first_invalid_escape == NULL) { - *first_invalid_escape = s-3; /* Back up 3 chars, since we've - already incremented s. */ + if (*first_invalid_escape_char == -1) { + *first_invalid_escape_char = c; + /* Back up 3 chars, since we've already incremented s. */ + *first_invalid_escape_ptr = s - 3; } } *p++ = c; @@ -1146,9 +1149,10 @@ PyObject *_PyBytes_DecodeEscape(const char *s, break; default: - if (*first_invalid_escape == NULL) { - *first_invalid_escape = s-1; /* Back up one char, since we've - already incremented s. */ + if (*first_invalid_escape_char == -1) { + *first_invalid_escape_char = (unsigned char)s[-1]; + /* Back up one char, since we've already incremented s. */ + *first_invalid_escape_ptr = s - 1; } *p++ = '\\'; s--; @@ -1168,17 +1172,18 @@ PyObject *PyBytes_DecodeEscape(const char *s, Py_ssize_t Py_UNUSED(unicode), const char *Py_UNUSED(recode_encoding)) { - const char* first_invalid_escape; - PyObject *result = _PyBytes_DecodeEscape(s, len, errors, - &first_invalid_escape); + int first_invalid_escape_char; + const char *first_invalid_escape_ptr; + PyObject *result = _PyBytes_DecodeEscape2(s, len, errors, + &first_invalid_escape_char, + &first_invalid_escape_ptr); if (result == NULL) return NULL; - if (first_invalid_escape != NULL) { - unsigned char c = *first_invalid_escape; - if ('4' <= c && c <= '7') { + if (first_invalid_escape_char != -1) { + if (first_invalid_escape_char > 0xff) { if (PyErr_WarnFormat(PyExc_DeprecationWarning, 1, - "invalid octal escape sequence '\\%.3s'", - first_invalid_escape) < 0) + "'\\%o' is an invalid octal escape sequence. ", + first_invalid_escape_char) < 0) { Py_DECREF(result); return NULL; @@ -1186,8 +1191,8 @@ PyObject *PyBytes_DecodeEscape(const char *s, } else { if (PyErr_WarnFormat(PyExc_DeprecationWarning, 1, - "invalid escape sequence '\\%c'", - c) < 0) + "'\\%c' is an invalid escape sequence. ", + first_invalid_escape_char) < 0) { Py_DECREF(result); return NULL; diff --git a/Objects/unicodeobject.c b/Objects/unicodeobject.c index 05562ad9927989..abeac3b4e96462 100644 --- a/Objects/unicodeobject.c +++ b/Objects/unicodeobject.c @@ -6046,13 +6046,15 @@ PyUnicode_AsUTF16String(PyObject *unicode) /* --- Unicode Escape Codec ----------------------------------------------- */ PyObject * -_PyUnicode_DecodeUnicodeEscapeInternal(const char *s, +_PyUnicode_DecodeUnicodeEscapeInternal2(const char *s, Py_ssize_t size, const char *errors, Py_ssize_t *consumed, - const char **first_invalid_escape) + int *first_invalid_escape_char, + const char **first_invalid_escape_ptr) { const char *starts = s; + const char *initial_starts = starts; _PyUnicodeWriter writer; const char *end; PyObject *errorHandler = NULL; @@ -6061,7 +6063,8 @@ _PyUnicode_DecodeUnicodeEscapeInternal(const char *s, PyInterpreterState *interp = _PyInterpreterState_Get(); // so we can remember if we've seen an invalid escape char or not - *first_invalid_escape = NULL; + *first_invalid_escape_char = -1; + *first_invalid_escape_ptr = NULL; if (size == 0) { if (consumed) { @@ -6149,9 +6152,12 @@ _PyUnicode_DecodeUnicodeEscapeInternal(const char *s, } } if (ch > 0377) { - if (*first_invalid_escape == NULL) { - *first_invalid_escape = s-3; /* Back up 3 chars, since we've - already incremented s. */ + if (*first_invalid_escape_char == -1) { + *first_invalid_escape_char = ch; + if (starts == initial_starts) { + /* Back up 3 chars, since we've already incremented s. */ + *first_invalid_escape_ptr = s - 3; + } } } WRITE_CHAR(ch); @@ -6252,9 +6258,12 @@ _PyUnicode_DecodeUnicodeEscapeInternal(const char *s, goto error; default: - if (*first_invalid_escape == NULL) { - *first_invalid_escape = s-1; /* Back up one char, since we've - already incremented s. */ + if (*first_invalid_escape_char == -1) { + *first_invalid_escape_char = c; + if (starts == initial_starts) { + /* Back up one char, since we've already incremented s. */ + *first_invalid_escape_ptr = s - 1; + } } WRITE_ASCII_CHAR('\\'); WRITE_CHAR(c); @@ -6299,18 +6308,19 @@ _PyUnicode_DecodeUnicodeEscapeStateful(const char *s, const char *errors, Py_ssize_t *consumed) { - const char *first_invalid_escape; - PyObject *result = _PyUnicode_DecodeUnicodeEscapeInternal(s, size, errors, + int first_invalid_escape_char; + const char *first_invalid_escape_ptr; + PyObject *result = _PyUnicode_DecodeUnicodeEscapeInternal2(s, size, errors, consumed, - &first_invalid_escape); + &first_invalid_escape_char, + &first_invalid_escape_ptr); if (result == NULL) return NULL; - if (first_invalid_escape != NULL) { - unsigned char c = *first_invalid_escape; - if ('4' <= c && c <= '7') { + if (first_invalid_escape_char != -1) { + if (first_invalid_escape_char > 0xff) { if (PyErr_WarnFormat(PyExc_DeprecationWarning, 1, - "invalid octal escape sequence '\\%.3s'", - first_invalid_escape) < 0) + "invalid octal escape sequence '\\%o'", + first_invalid_escape_char) < 0) { Py_DECREF(result); return NULL; @@ -6319,7 +6329,7 @@ _PyUnicode_DecodeUnicodeEscapeStateful(const char *s, else { if (PyErr_WarnFormat(PyExc_DeprecationWarning, 1, "invalid escape sequence '\\%c'", - c) < 0) + first_invalid_escape_char) < 0) { Py_DECREF(result); return NULL; diff --git a/Parser/string_parser.c b/Parser/string_parser.c index 8607885f2e46bd..815c0bde03f887 100644 --- a/Parser/string_parser.c +++ b/Parser/string_parser.c @@ -1,4 +1,6 @@ #include +#include "pycore_bytesobject.h" // _PyBytes_DecodeEscape() +#include "pycore_unicodeobject.h" // _PyUnicode_DecodeUnicodeEscapeInternal() #include "tokenizer.h" #include "pegen.h" @@ -25,9 +27,9 @@ warn_invalid_escape_sequence(Parser *p, const char* buffer, const char *first_in int octal = ('4' <= c && c <= '7'); PyObject *msg = octal - ? PyUnicode_FromFormat("invalid octal escape sequence '\\%.3s'", + ? PyUnicode_FromFormat("'\\%.3s' is an invalid octal escape sequence. ", first_invalid_escape) - : PyUnicode_FromFormat("invalid escape sequence '\\%c'", c); + : PyUnicode_FromFormat("'\\%c' is an invalid escape sequence. ", c); if (msg == NULL) { return -1; } @@ -181,15 +183,18 @@ decode_unicode_with_escapes(Parser *parser, const char *s, size_t len, Token *t) len = p - buf; s = buf; - const char *first_invalid_escape; - v = _PyUnicode_DecodeUnicodeEscapeInternal(s, len, NULL, NULL, &first_invalid_escape); + int first_invalid_escape_char; + const char *first_invalid_escape_ptr; + v = _PyUnicode_DecodeUnicodeEscapeInternal2(s, (Py_ssize_t)len, NULL, NULL, + &first_invalid_escape_char, + &first_invalid_escape_ptr); // HACK: later we can simply pass the line no, since we don't preserve the tokens // when we are decoding the string but we preserve the line numbers. - if (v != NULL && first_invalid_escape != NULL && t != NULL) { - if (warn_invalid_escape_sequence(parser, s, first_invalid_escape, t) < 0) { - /* We have not decref u before because first_invalid_escape points - inside u. */ + if (v != NULL && first_invalid_escape_ptr != NULL && t != NULL) { + if (warn_invalid_escape_sequence(parser, s, first_invalid_escape_ptr, t) < 0) { + /* We have not decref u before because first_invalid_escape_ptr + points inside u. */ Py_XDECREF(u); Py_DECREF(v); return NULL; @@ -202,14 +207,17 @@ decode_unicode_with_escapes(Parser *parser, const char *s, size_t len, Token *t) static PyObject * decode_bytes_with_escapes(Parser *p, const char *s, Py_ssize_t len, Token *t) { - const char *first_invalid_escape; - PyObject *result = _PyBytes_DecodeEscape(s, len, NULL, &first_invalid_escape); + int first_invalid_escape_char; + const char *first_invalid_escape_ptr; + PyObject *result = _PyBytes_DecodeEscape2(s, len, NULL, + &first_invalid_escape_char, + &first_invalid_escape_ptr); if (result == NULL) { return NULL; } - if (first_invalid_escape != NULL) { - if (warn_invalid_escape_sequence(p, s, first_invalid_escape, t) < 0) { + if (first_invalid_escape_ptr != NULL) { + if (warn_invalid_escape_sequence(p, s, first_invalid_escape_ptr, t) < 0) { Py_DECREF(result); return NULL; }