Skip to content

Commit 0c93a0f

Browse files
authored
Fix Unicode support for ostream redirects (#2982)
* Crash when printing Unicode to redirected cout Add failing tests * Fix Unicode crashes redirected cout * pythonbuf::utf8_remainder check end iterator * Remove trailing whitespace and formatting iostream * Avoid buffer overflow if ostream redirect races This doesn't solve the actual race, but at least it now has a much lower probability of reading past the end of the buffer even when data races do occur.
1 parent 5443043 commit 0c93a0f

File tree

2 files changed

+153
-12
lines changed

2 files changed

+153
-12
lines changed

include/pybind11/iostream.h

+63-12
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
#include <string>
1717
#include <memory>
1818
#include <iostream>
19+
#include <cstring>
20+
#include <iterator>
21+
#include <algorithm>
1922

2023
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
2124
PYBIND11_NAMESPACE_BEGIN(detail)
@@ -38,25 +41,73 @@ class pythonbuf : public std::streambuf {
3841
return sync() == 0 ? traits_type::not_eof(c) : traits_type::eof();
3942
}
4043

44+
// Computes how many bytes at the end of the buffer are part of an
45+
// incomplete sequence of UTF-8 bytes.
46+
// Precondition: pbase() < pptr()
47+
size_t utf8_remainder() const {
48+
const auto rbase = std::reverse_iterator<char *>(pbase());
49+
const auto rpptr = std::reverse_iterator<char *>(pptr());
50+
auto is_ascii = [](char c) {
51+
return (static_cast<unsigned char>(c) & 0x80) == 0x00;
52+
};
53+
auto is_leading = [](char c) {
54+
return (static_cast<unsigned char>(c) & 0xC0) == 0xC0;
55+
};
56+
auto is_leading_2b = [](char c) {
57+
return static_cast<unsigned char>(c) <= 0xDF;
58+
};
59+
auto is_leading_3b = [](char c) {
60+
return static_cast<unsigned char>(c) <= 0xEF;
61+
};
62+
// If the last character is ASCII, there are no incomplete code points
63+
if (is_ascii(*rpptr))
64+
return 0;
65+
// Otherwise, work back from the end of the buffer and find the first
66+
// UTF-8 leading byte
67+
const auto rpend = rbase - rpptr >= 3 ? rpptr + 3 : rbase;
68+
const auto leading = std::find_if(rpptr, rpend, is_leading);
69+
if (leading == rbase)
70+
return 0;
71+
const auto dist = static_cast<size_t>(leading - rpptr);
72+
size_t remainder = 0;
73+
74+
if (dist == 0)
75+
remainder = 1; // 1-byte code point is impossible
76+
else if (dist == 1)
77+
remainder = is_leading_2b(*leading) ? 0 : dist + 1;
78+
else if (dist == 2)
79+
remainder = is_leading_3b(*leading) ? 0 : dist + 1;
80+
// else if (dist >= 3), at least 4 bytes before encountering an UTF-8
81+
// leading byte, either no remainder or invalid UTF-8.
82+
// Invalid UTF-8 will cause an exception later when converting
83+
// to a Python string, so that's not handled here.
84+
return remainder;
85+
}
86+
4187
// This function must be non-virtual to be called in a destructor. If the
4288
// rare MSVC test failure shows up with this version, then this should be
4389
// simplified to a fully qualified call.
4490
int _sync() {
45-
if (pbase() != pptr()) {
46-
47-
{
48-
gil_scoped_acquire tmp;
49-
91+
if (pbase() != pptr()) { // If buffer is not empty
92+
gil_scoped_acquire tmp;
93+
// Placed inside gil_scoped_acquire as a mutex to avoid a race.
94+
if (pbase() != pptr()) { // Check again under the lock
5095
// This subtraction cannot be negative, so dropping the sign.
51-
str line(pbase(), static_cast<size_t>(pptr() - pbase()));
52-
53-
pywrite(line);
54-
pyflush();
55-
56-
// Placed inside gil_scoped_aquire as a mutex to avoid a race
96+
auto size = static_cast<size_t>(pptr() - pbase());
97+
size_t remainder = utf8_remainder();
98+
99+
if (size > remainder) {
100+
str line(pbase(), size - remainder);
101+
pywrite(line);
102+
pyflush();
103+
}
104+
105+
// Copy the remainder at the end of the buffer to the beginning:
106+
if (remainder > 0)
107+
std::memmove(pbase(), pptr() - remainder, remainder);
57108
setp(pbase(), epptr());
109+
pbump(static_cast<int>(remainder));
58110
}
59-
60111
}
61112
return 0;
62113
}

tests/test_iostream.py

+90
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,96 @@ def test_captured_large_string(capsys):
6969
assert stderr == ""
7070

7171

72+
def test_captured_utf8_2byte_offset0(capsys):
73+
msg = "\u07FF"
74+
msg = "" + msg * (1024 // len(msg) + 1)
75+
76+
m.captured_output_default(msg)
77+
stdout, stderr = capsys.readouterr()
78+
assert stdout == msg
79+
assert stderr == ""
80+
81+
82+
def test_captured_utf8_2byte_offset1(capsys):
83+
msg = "\u07FF"
84+
msg = "1" + msg * (1024 // len(msg) + 1)
85+
86+
m.captured_output_default(msg)
87+
stdout, stderr = capsys.readouterr()
88+
assert stdout == msg
89+
assert stderr == ""
90+
91+
92+
def test_captured_utf8_3byte_offset0(capsys):
93+
msg = "\uFFFF"
94+
msg = "" + msg * (1024 // len(msg) + 1)
95+
96+
m.captured_output_default(msg)
97+
stdout, stderr = capsys.readouterr()
98+
assert stdout == msg
99+
assert stderr == ""
100+
101+
102+
def test_captured_utf8_3byte_offset1(capsys):
103+
msg = "\uFFFF"
104+
msg = "1" + msg * (1024 // len(msg) + 1)
105+
106+
m.captured_output_default(msg)
107+
stdout, stderr = capsys.readouterr()
108+
assert stdout == msg
109+
assert stderr == ""
110+
111+
112+
def test_captured_utf8_3byte_offset2(capsys):
113+
msg = "\uFFFF"
114+
msg = "12" + msg * (1024 // len(msg) + 1)
115+
116+
m.captured_output_default(msg)
117+
stdout, stderr = capsys.readouterr()
118+
assert stdout == msg
119+
assert stderr == ""
120+
121+
122+
def test_captured_utf8_4byte_offset0(capsys):
123+
msg = "\U0010FFFF"
124+
msg = "" + msg * (1024 // len(msg) + 1)
125+
126+
m.captured_output_default(msg)
127+
stdout, stderr = capsys.readouterr()
128+
assert stdout == msg
129+
assert stderr == ""
130+
131+
132+
def test_captured_utf8_4byte_offset1(capsys):
133+
msg = "\U0010FFFF"
134+
msg = "1" + msg * (1024 // len(msg) + 1)
135+
136+
m.captured_output_default(msg)
137+
stdout, stderr = capsys.readouterr()
138+
assert stdout == msg
139+
assert stderr == ""
140+
141+
142+
def test_captured_utf8_4byte_offset2(capsys):
143+
msg = "\U0010FFFF"
144+
msg = "12" + msg * (1024 // len(msg) + 1)
145+
146+
m.captured_output_default(msg)
147+
stdout, stderr = capsys.readouterr()
148+
assert stdout == msg
149+
assert stderr == ""
150+
151+
152+
def test_captured_utf8_4byte_offset3(capsys):
153+
msg = "\U0010FFFF"
154+
msg = "123" + msg * (1024 // len(msg) + 1)
155+
156+
m.captured_output_default(msg)
157+
stdout, stderr = capsys.readouterr()
158+
assert stdout == msg
159+
assert stderr == ""
160+
161+
72162
def test_guard_capture(capsys):
73163
msg = "I've been redirected to Python, I hope!"
74164
m.guard_output(msg)

0 commit comments

Comments
 (0)