diff --git a/Lib/sqlite3/test/test_regression.py b/Lib/sqlite3/test/test_regression.py index 3d71809d9c11cf..f7f04b73f8ddb7 100644 --- a/Lib/sqlite3/test/test_regression.py +++ b/Lib/sqlite3/test/test_regression.py @@ -27,7 +27,7 @@ import weakref import functools from test import support - +from test.support.os_helper import temp_dir from .test_dbapi import managed_connect class RegressionTests(unittest.TestCase): @@ -486,6 +486,19 @@ def test_executescript_step_through_select(self): con.executescript("select step(t) from t") self.assertEqual(steps, values) + def test_connection_bad_reinit(self): + cx = sqlite.connect(":memory:") + with cx: + cx.execute("create table t(t)") + with temp_dir() as db: + self.assertRaisesRegex(sqlite.OperationalError, + "unable to open database file", + cx.__init__, db) + self.assertRaisesRegex(sqlite.ProgrammingError, + "Base Connection.__init__ not called", + cx.executemany, "insert into t values(?)", + ((v,) for v in range(3))) + if __name__ == "__main__": unittest.main() diff --git a/Misc/NEWS.d/next/Library/2021-10-22-15-05-10.bpo-45126.roRvx-.rst b/Misc/NEWS.d/next/Library/2021-10-22-15-05-10.bpo-45126.roRvx-.rst new file mode 100644 index 00000000000000..bba8876cafa72e --- /dev/null +++ b/Misc/NEWS.d/next/Library/2021-10-22-15-05-10.bpo-45126.roRvx-.rst @@ -0,0 +1,6 @@ +Prevent segfault if :class:`sqlite3.Connection` reinitialisation fails. + +.. warning:: + + :class:`sqlite3.Connection` is not adviced, as it may produce undesired + side-effects. Creating a new object is preferred to reinitialisation. diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c index e94c4cbb4e8c3a..e0b5e6f1dab0e3 100644 --- a/Modules/_sqlite/connection.c +++ b/Modules/_sqlite/connection.c @@ -159,6 +159,7 @@ pysqlite_connection_init_impl(pysqlite_Connection *self, Py_INCREF(&PyUnicode_Type); Py_XSETREF(self->text_factory, (PyObject*)&PyUnicode_Type); + self->db = NULL; Py_BEGIN_ALLOW_THREADS rc = sqlite3_open_v2(database, &self->db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | @@ -167,13 +168,13 @@ pysqlite_connection_init_impl(pysqlite_Connection *self, if (rc != SQLITE_OK) { _pysqlite_seterror(state, self->db); - return -1; + goto error; } if (!isolation_level) { isolation_level = PyUnicode_FromString(""); if (!isolation_level) { - return -1; + goto error; } } else { Py_INCREF(isolation_level); @@ -181,16 +182,16 @@ pysqlite_connection_init_impl(pysqlite_Connection *self, Py_CLEAR(self->isolation_level); if (pysqlite_connection_set_isolation_level(self, isolation_level, NULL) != 0) { Py_DECREF(isolation_level); - return -1; + goto error; } Py_DECREF(isolation_level); self->statement_cache = new_statement_cache(self, cached_statements); if (self->statement_cache == NULL) { - return -1; + goto error; } if (PyErr_Occurred()) { - return -1; + goto error; } self->created_cursors = 0; @@ -198,7 +199,7 @@ pysqlite_connection_init_impl(pysqlite_Connection *self, /* Create list of weak references to cursors */ self->cursors = PyList_New(0); if (self->cursors == NULL) { - return -1; + goto error; } self->detect_types = detect_types; @@ -222,12 +223,16 @@ pysqlite_connection_init_impl(pysqlite_Connection *self, self->NotSupportedError = state->NotSupportedError; if (PySys_Audit("sqlite3.connect/handle", "O", self) < 0) { - return -1; + goto error; } self->initialized = 1; - return 0; + +error: + self->initialized = 0; + self->db = 0; + return -1; } static void