Skip to content

bpo-45126: Harden sqlite3 connection initialisation #28227

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Nov 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions Lib/test/test_sqlite3/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,44 @@ def test_connection_init_good_isolation_levels(self):
with memory_database(isolation_level=level) as cx:
cx.execute("select 'ok'")

def test_connection_reinit(self):
db = ":memory:"
cx = sqlite.connect(db)
cx.text_factory = bytes
cx.row_factory = sqlite.Row
cu = cx.cursor()
cu.execute("create table foo (bar)")
cu.executemany("insert into foo (bar) values (?)",
((str(v),) for v in range(4)))
cu.execute("select bar from foo")

rows = [r for r in cu.fetchmany(2)]
self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows))
self.assertEqual([r[0] for r in rows], [b"0", b"1"])

cx.__init__(db)
cx.execute("create table foo (bar)")
cx.executemany("insert into foo (bar) values (?)",
((v,) for v in ("a", "b", "c", "d")))

# This uses the old database, old row factory, but new text factory
rows = [r for r in cu.fetchall()]
self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows))
self.assertEqual([r[0] for r in rows], ["2", "3"])

def test_connection_bad_reinit(self):
cx = sqlite.connect(":memory:")
with cx:
cx.execute("create table t(t)")
with temp_dir() as db:
self.assertRaisesRegex(sqlite.OperationalError,
"unable to open database file",
cx.__init__, db)
self.assertRaisesRegex(sqlite.ProgrammingError,
"Base Connection.__init__ not called",
cx.executemany, "insert into t values(?)",
((v,) for v in range(3)))


class UninitialisedConnectionTests(unittest.TestCase):
def setUp(self):
Expand Down
12 changes: 6 additions & 6 deletions Modules/_sqlite/clinic/connection.c.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ pysqlite_connection_init_impl(pysqlite_Connection *self,
const char *database, double timeout,
int detect_types, const char *isolation_level,
int check_same_thread, PyObject *factory,
int cached_statements, int uri);
int cache_size, int uri);

static int
pysqlite_connection_init(PyObject *self, PyObject *args, PyObject *kwargs)
Expand All @@ -25,7 +25,7 @@ pysqlite_connection_init(PyObject *self, PyObject *args, PyObject *kwargs)
const char *isolation_level = "";
int check_same_thread = 1;
PyObject *factory = (PyObject*)clinic_state()->ConnectionType;
int cached_statements = 128;
int cache_size = 128;
int uri = 0;

fastargs = _PyArg_UnpackKeywords(_PyTuple_CAST(args)->ob_item, nargs, kwargs, NULL, &_parser, 1, 8, 0, argsbuf);
Expand Down Expand Up @@ -101,8 +101,8 @@ pysqlite_connection_init(PyObject *self, PyObject *args, PyObject *kwargs)
}
}
if (fastargs[6]) {
cached_statements = _PyLong_AsInt(fastargs[6]);
if (cached_statements == -1 && PyErr_Occurred()) {
cache_size = _PyLong_AsInt(fastargs[6]);
if (cache_size == -1 && PyErr_Occurred()) {
goto exit;
}
if (!--noptargs) {
Expand All @@ -114,7 +114,7 @@ pysqlite_connection_init(PyObject *self, PyObject *args, PyObject *kwargs)
goto exit;
}
skip_optional_pos:
return_value = pysqlite_connection_init_impl((pysqlite_Connection *)self, database, timeout, detect_types, isolation_level, check_same_thread, factory, cached_statements, uri);
return_value = pysqlite_connection_init_impl((pysqlite_Connection *)self, database, timeout, detect_types, isolation_level, check_same_thread, factory, cache_size, uri);

exit:
/* Cleanup for database */
Expand Down Expand Up @@ -851,4 +851,4 @@ getlimit(pysqlite_Connection *self, PyObject *arg)
#ifndef PYSQLITE_CONNECTION_LOAD_EXTENSION_METHODDEF
#define PYSQLITE_CONNECTION_LOAD_EXTENSION_METHODDEF
#endif /* !defined(PYSQLITE_CONNECTION_LOAD_EXTENSION_METHODDEF) */
/*[clinic end generated code: output=663b1e9e71128f19 input=a9049054013a1b77]*/
/*[clinic end generated code: output=6f267f20e77f92d0 input=a9049054013a1b77]*/
122 changes: 66 additions & 56 deletions Modules/_sqlite/connection.c
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,17 @@ static void _pysqlite_drop_unused_cursor_references(pysqlite_Connection* self);
static void free_callback_context(callback_context *ctx);
static void set_callback_context(callback_context **ctx_pp,
callback_context *ctx);
static void connection_close(pysqlite_Connection *self);

static PyObject *
new_statement_cache(pysqlite_Connection *self, int maxsize)
new_statement_cache(pysqlite_Connection *self, pysqlite_state *state,
int maxsize)
{
PyObject *args[] = { NULL, PyLong_FromLong(maxsize), };
if (args[1] == NULL) {
return NULL;
}
PyObject *lru_cache = self->state->lru_cache;
PyObject *lru_cache = state->lru_cache;
size_t nargsf = 1 | PY_VECTORCALL_ARGUMENTS_OFFSET;
PyObject *inner = PyObject_Vectorcall(lru_cache, args + 1, nargsf, NULL);
Py_DECREF(args[1]);
Expand Down Expand Up @@ -153,7 +155,7 @@ _sqlite3.Connection.__init__ as pysqlite_connection_init
isolation_level: str(accept={str, NoneType}) = ""
check_same_thread: bool(accept={int}) = True
factory: object(c_default='(PyObject*)clinic_state()->ConnectionType') = ConnectionType
cached_statements: int = 128
cached_statements as cache_size: int = 128
uri: bool = False
[clinic start generated code]*/

Expand All @@ -162,78 +164,82 @@ pysqlite_connection_init_impl(pysqlite_Connection *self,
const char *database, double timeout,
int detect_types, const char *isolation_level,
int check_same_thread, PyObject *factory,
int cached_statements, int uri)
/*[clinic end generated code: output=d8c37afc46d318b0 input=adfb29ac461f9e61]*/
int cache_size, int uri)
/*[clinic end generated code: output=7d640ae1d83abfd4 input=35e316f66d9f70fd]*/
{
int rc;

if (PySys_Audit("sqlite3.connect", "s", database) < 0) {
return -1;
}

pysqlite_state *state = pysqlite_get_state_by_type(Py_TYPE(self));
self->state = state;

Py_CLEAR(self->statement_cache);
Py_CLEAR(self->cursors);

Py_INCREF(Py_None);
Py_XSETREF(self->row_factory, Py_None);

Py_INCREF(&PyUnicode_Type);
Py_XSETREF(self->text_factory, (PyObject*)&PyUnicode_Type);
if (self->initialized) {
PyTypeObject *tp = Py_TYPE(self);
tp->tp_clear((PyObject *)self);
connection_close(self);
self->initialized = 0;
}

// Create and configure SQLite database object.
sqlite3 *db;
int rc;
Py_BEGIN_ALLOW_THREADS
rc = sqlite3_open_v2(database, &self->db,
rc = sqlite3_open_v2(database, &db,
SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE |
(uri ? SQLITE_OPEN_URI : 0), NULL);
if (rc == SQLITE_OK) {
(void)sqlite3_busy_timeout(db, (int)(timeout*1000));
}
Py_END_ALLOW_THREADS

if (self->db == NULL && rc == SQLITE_NOMEM) {
if (db == NULL && rc == SQLITE_NOMEM) {
PyErr_NoMemory();
return -1;
}

pysqlite_state *state = pysqlite_get_state_by_type(Py_TYPE(self));
if (rc != SQLITE_OK) {
_pysqlite_seterror(state, self->db);
_pysqlite_seterror(state, db);
return -1;
}

if (isolation_level) {
const char *stmt = get_begin_statement(isolation_level);
if (stmt == NULL) {
// Convert isolation level to begin statement.
const char *begin_statement = NULL;
if (isolation_level != NULL) {
begin_statement = get_begin_statement(isolation_level);
if (begin_statement == NULL) {
return -1;
}
self->begin_statement = stmt;
}
else {
self->begin_statement = NULL;
}

self->statement_cache = new_statement_cache(self, cached_statements);
if (self->statement_cache == NULL) {
return -1;
}
if (PyErr_Occurred()) {
// Create LRU statement cache; returns a new reference.
PyObject *statement_cache = new_statement_cache(self, state, cache_size);
if (statement_cache == NULL) {
return -1;
}

self->created_cursors = 0;

/* Create list of weak references to cursors */
self->cursors = PyList_New(0);
if (self->cursors == NULL) {
// Create list of weak references to cursors.
PyObject *cursors = PyList_New(0);
if (cursors == NULL) {
Py_DECREF(statement_cache);
return -1;
}

// Init connection state members.
self->db = db;
self->state = state;
self->detect_types = detect_types;
(void)sqlite3_busy_timeout(self->db, (int)(timeout*1000));
self->thread_ident = PyThread_get_thread_ident();
self->begin_statement = begin_statement;
self->check_same_thread = check_same_thread;
self->thread_ident = PyThread_get_thread_ident();
self->statement_cache = statement_cache;
self->cursors = cursors;
self->created_cursors = 0;
self->row_factory = Py_NewRef(Py_None);
self->text_factory = Py_NewRef(&PyUnicode_Type);
self->trace_ctx = NULL;
self->progress_ctx = NULL;
self->authorizer_ctx = NULL;

set_callback_context(&self->trace_ctx, NULL);
set_callback_context(&self->progress_ctx, NULL);
set_callback_context(&self->authorizer_ctx, NULL);

// Borrowed refs
self->Warning = state->Warning;
self->Error = state->Error;
self->InterfaceError = state->InterfaceError;
Expand All @@ -250,7 +256,6 @@ pysqlite_connection_init_impl(pysqlite_Connection *self,
}

self->initialized = 1;

return 0;
}

Expand Down Expand Up @@ -321,16 +326,6 @@ connection_clear(pysqlite_Connection *self)
return 0;
}

static void
connection_close(pysqlite_Connection *self)
{
if (self->db) {
int rc = sqlite3_close_v2(self->db);
assert(rc == SQLITE_OK), (void)rc;
self->db = NULL;
}
}

static void
free_callback_contexts(pysqlite_Connection *self)
{
Expand All @@ -339,6 +334,22 @@ free_callback_contexts(pysqlite_Connection *self)
set_callback_context(&self->authorizer_ctx, NULL);
}

static void
connection_close(pysqlite_Connection *self)
{
if (self->db) {
free_callback_contexts(self);

sqlite3 *db = self->db;
self->db = NULL;

Py_BEGIN_ALLOW_THREADS
int rc = sqlite3_close_v2(db);
assert(rc == SQLITE_OK), (void)rc;
Py_END_ALLOW_THREADS
}
}

static void
connection_dealloc(pysqlite_Connection *self)
{
Expand All @@ -348,7 +359,6 @@ connection_dealloc(pysqlite_Connection *self)

/* Clean up if user has not called .close() explicitly. */
connection_close(self);
free_callback_contexts(self);

tp->tp_free(self);
Py_DECREF(tp);
Expand Down