From e5aaf18d443468a6f7c7005bea2159ad02db3b39 Mon Sep 17 00:00:00 2001 From: atom-andrew Date: Fri, 3 Oct 2025 19:22:50 -0700 Subject: [PATCH 1/4] Always dispose of engines --- sqlalchemy_utils/functions/database.py | 68 ++++++++++++++------------ 1 file changed, 36 insertions(+), 32 deletions(-) diff --git a/sqlalchemy_utils/functions/database.py b/sqlalchemy_utils/functions/database.py index bda84c2f..e389b779 100644 --- a/sqlalchemy_utils/functions/database.py +++ b/sqlalchemy_utils/functions/database.py @@ -562,35 +562,37 @@ def create_database(url, encoding='utf8', template=None): else: engine = sa.create_engine(url) - if dialect_name == 'postgresql': - if not template: - template = 'template1' - - with engine.begin() as conn: - text = "CREATE DATABASE {} ENCODING '{}' TEMPLATE {}".format( - quote(conn, database), encoding, quote(conn, template) - ) - conn.execute(sa.text(text)) + try: + if dialect_name == 'postgresql': + if not template: + template = 'template1' - elif dialect_name == 'mysql': - with engine.begin() as conn: - text = "CREATE DATABASE {} CHARACTER SET = '{}'".format( - quote(conn, database), encoding - ) - conn.execute(sa.text(text)) + with engine.begin() as conn: + text = "CREATE DATABASE {} ENCODING '{}' TEMPLATE {}".format( + quote(conn, database), encoding, quote(conn, template) + ) + conn.execute(sa.text(text)) - elif dialect_name == 'sqlite' and database != ':memory:': - if database: + elif dialect_name == 'mysql': with engine.begin() as conn: - conn.execute(sa.text('CREATE TABLE DB(id int)')) - conn.execute(sa.text('DROP TABLE DB')) + text = "CREATE DATABASE {} CHARACTER SET = '{}'".format( + quote(conn, database), encoding + ) + conn.execute(sa.text(text)) - else: - with engine.begin() as conn: - text = f'CREATE DATABASE {quote(conn, database)}' - conn.execute(sa.text(text)) + elif dialect_name == 'sqlite' and database != ':memory:': + if database: + with engine.begin() as conn: + conn.execute(sa.text('CREATE TABLE DB(id int)')) + conn.execute(sa.text('DROP TABLE DB')) + + else: + with engine.begin() as conn: + text = f'CREATE DATABASE {quote(conn, database)}' + conn.execute(sa.text(text)) - engine.dispose() + finally: + engine.dispose() def drop_database(url): @@ -631,12 +633,14 @@ def drop_database(url): else: engine = sa.create_engine(url) - if dialect_name == 'sqlite' and database != ':memory:': - if database: - os.remove(database) - else: - with engine.begin() as conn: - text = f'DROP DATABASE {quote(conn, database)}' - conn.execute(sa.text(text)) + try: + if dialect_name == 'sqlite' and database != ':memory:': + if database: + os.remove(database) + else: + with engine.begin() as conn: + text = f'DROP DATABASE {quote(conn, database)}' + conn.execute(sa.text(text)) - engine.dispose() + finally: + engine.dispose() From 2411f00cb0ba50f923c8e4ad652289b5ddb8803e Mon Sep 17 00:00:00 2001 From: atom-andrew Date: Sun, 5 Oct 2025 18:24:30 -0700 Subject: [PATCH 2/4] Factor out engine creation --- sqlalchemy_utils/functions/database.py | 134 ++++++++++++------------- tests/functions/test_database.py | 9 ++ 2 files changed, 76 insertions(+), 67 deletions(-) diff --git a/sqlalchemy_utils/functions/database.py b/sqlalchemy_utils/functions/database.py index e389b779..fada1992 100644 --- a/sqlalchemy_utils/functions/database.py +++ b/sqlalchemy_utils/functions/database.py @@ -1,6 +1,7 @@ import itertools import os from collections.abc import Mapping, Sequence +from contextlib import contextmanager from copy import copy import sqlalchemy as sa @@ -442,6 +443,36 @@ def _sqlite_file_exists(database): return header[:16] == b'SQLite format 3\x00' +@contextmanager +def _create_engine(url, use_primary_db=False, **engine_kwargs): + """A context manager that provies a SQLAlchemy engine. + + :param url: A SQLAlchemy engine URL. + :param use_primary_db: If True, connects to the primary database of the + given database server. This is necessary for operations such as creating + or dropping databases. If False, connects to the database specified in the url. + :param engine_kwargs: Additional keyword arguments passed to sqlalchemy's create_engine function. + """ + url = make_url(url) + dialect_name = url.get_dialect().name + + if use_primary_db: + if dialect_name == 'postgresql': + url = _set_url_database(url, database='postgres') + elif dialect_name == 'mssql': + url = _set_url_database(url, database='master') + elif dialect_name == 'cockroachdb': + url = _set_url_database(url, database='defaultdb') + elif not dialect_name == 'sqlite': + url = _set_url_database(url, database=None) + + engine = sa.create_engine(url, **engine_kwargs) + try: + yield engine + finally: + engine.dispose() + + def database_exists(url): """Check if a database exists. @@ -466,55 +497,48 @@ def database_exists(url): url = make_url(url) database = url.database dialect_name = url.get_dialect().name - engine = None - try: - if dialect_name == 'postgresql': - text = "SELECT 1 FROM pg_database WHERE datname='%s'" % database - for db in (database, 'postgres', 'template1', 'template0', None): - url = _set_url_database(url, database=db) - engine = sa.create_engine(url, isolation_level='AUTOCOMMIT') + if dialect_name == 'postgresql': + text = "SELECT 1 FROM pg_database WHERE datname='%s'" % database + for db in (database, 'postgres', 'template1', 'template0', None): + url = _set_url_database(url, database=db) + with _create_engine(url) as engine: try: return bool(_get_scalar_result(engine, sa.text(text))) except (ProgrammingError, OperationalError): pass - return False + return False - elif dialect_name == 'mysql': - url = _set_url_database(url, database=None) - engine = sa.create_engine(url) + elif dialect_name == 'mysql': + url = _set_url_database(url, database=None) + with _create_engine(url) as engine: text = ( 'SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA ' "WHERE SCHEMA_NAME = '%s'" % database ) return bool(_get_scalar_result(engine, sa.text(text))) - elif dialect_name == 'sqlite': - url = _set_url_database(url, database=None) - engine = sa.create_engine(url) - if database: - return database == ':memory:' or _sqlite_file_exists(database) - else: - # The default SQLAlchemy database is in memory, and :memory: is - # not required, thus we should support that use case. - return True - elif dialect_name == 'mssql': + elif dialect_name == 'sqlite': + if database: + return database == ':memory:' or _sqlite_file_exists(database) + else: + # The default SQLAlchemy database is in memory, and :memory: is + # not required, thus we should support that use case. + return True + elif dialect_name == 'mssql': + url = _set_url_database(url, database='master') + with _create_engine(url, isolation_level='AUTOCOMMIT') as engine: text = "SELECT 1 FROM sys.databases WHERE name = '%s'" % database - url = _set_url_database(url, database='master') - engine = sa.create_engine(url, isolation_level='AUTOCOMMIT') try: return bool(_get_scalar_result(engine, sa.text(text))) except (ProgrammingError, OperationalError): return False - else: + else: + with _create_engine(url) as engine: text = 'SELECT 1' try: - engine = sa.create_engine(url) return bool(_get_scalar_result(engine, sa.text(text))) except (ProgrammingError, OperationalError): return False - finally: - if engine: - engine.dispose() def create_database(url, encoding='utf8', template=None): @@ -544,25 +568,16 @@ def create_database(url, encoding='utf8', template=None): dialect_name = url.get_dialect().name dialect_driver = url.get_dialect().driver - if dialect_name == 'postgresql': - url = _set_url_database(url, database='postgres') - elif dialect_name == 'mssql': - url = _set_url_database(url, database='master') - elif dialect_name == 'cockroachdb': - url = _set_url_database(url, database='defaultdb') - elif not dialect_name == 'sqlite': - url = _set_url_database(url, database=None) - if (dialect_name == 'mssql' and dialect_driver in {'pymssql', 'pyodbc'}) or ( dialect_name == 'postgresql' and dialect_driver in {'asyncpg', 'pg8000', 'psycopg', 'psycopg2', 'psycopg2cffi'} ): - engine = sa.create_engine(url, isolation_level='AUTOCOMMIT') + engine_kwargs = {'isolation_level': 'AUTOCOMMIT'} else: - engine = sa.create_engine(url) + engine_kwargs = {} - try: + with _create_engine(url, use_primary_db=True, **engine_kwargs) as engine: if dialect_name == 'postgresql': if not template: template = 'template1' @@ -591,9 +606,6 @@ def create_database(url, encoding='utf8', template=None): text = f'CREATE DATABASE {quote(conn, database)}' conn.execute(sa.text(text)) - finally: - engine.dispose() - def drop_database(url): """Issue the appropriate DROP DATABASE statement. @@ -615,32 +627,20 @@ def drop_database(url): dialect_name = url.get_dialect().name dialect_driver = url.get_dialect().driver - if dialect_name == 'postgresql': - url = _set_url_database(url, database='postgres') - elif dialect_name == 'mssql': - url = _set_url_database(url, database='master') - elif dialect_name == 'cockroachdb': - url = _set_url_database(url, database='defaultdb') - elif not dialect_name == 'sqlite': - url = _set_url_database(url, database=None) - - if (dialect_name == 'mssql' and dialect_driver in {'pymssql', 'pyodbc'}) or ( - dialect_name == 'postgresql' - and dialect_driver - in {'asyncpg', 'pg8000', 'psycopg', 'psycopg2', 'psycopg2cffi'} - ): - engine = sa.create_engine(url, isolation_level='AUTOCOMMIT') + if dialect_name == 'sqlite' and database != ':memory:': + if database: + os.remove(database) else: - engine = sa.create_engine(url) - - try: - if dialect_name == 'sqlite' and database != ':memory:': - if database: - os.remove(database) + if (dialect_name == 'mssql' and dialect_driver in {'pymssql', 'pyodbc'}) or ( + dialect_name == 'postgresql' + and dialect_driver + in {'asyncpg', 'pg8000', 'psycopg', 'psycopg2', 'psycopg2cffi'} + ): + engine_kwargs = {'isolation_level': 'AUTOCOMMIT'} else: + engine_kwargs = {} + + with _create_engine(url, use_primary_db=True, **engine_kwargs) as engine: with engine.begin() as conn: text = f'DROP DATABASE {quote(conn, database)}' conn.execute(sa.text(text)) - - finally: - engine.dispose() diff --git a/tests/functions/test_database.py b/tests/functions/test_database.py index cdd631ad..7b927031 100644 --- a/tests/functions/test_database.py +++ b/tests/functions/test_database.py @@ -3,6 +3,7 @@ from sqlalchemy_utils import create_database, database_exists, drop_database from sqlalchemy_utils.compat import get_sqlalchemy_version +from sqlalchemy_utils.functions.database import _create_engine pymysql = None try: @@ -163,3 +164,11 @@ class TestDatabaseMssql(DatabaseTest): def db_name(self): pytest.importorskip('pyodbc') return 'db_test_sqlalchemy_util' + + +def test_create_engine(sqlite_memory_dsn): + with _create_engine(sqlite_memory_dsn) as engine: + pool = engine.pool + + # a disposed engine should not have the same pool + assert engine.pool is not pool From 196a687f6bcb8c4e7ebfad66af25805074a73c32 Mon Sep 17 00:00:00 2001 From: atom-andrew Date: Sun, 5 Oct 2025 19:23:58 -0700 Subject: [PATCH 3/4] Better tests --- tests/functions/test_database.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/tests/functions/test_database.py b/tests/functions/test_database.py index 7b927031..2cce3d5e 100644 --- a/tests/functions/test_database.py +++ b/tests/functions/test_database.py @@ -167,8 +167,20 @@ def db_name(self): def test_create_engine(sqlite_memory_dsn): + """Test that engine creation context manager creates an engine and disposes of it""" with _create_engine(sqlite_memory_dsn) as engine: pool = engine.pool + with engine.connect() as conn: + assert conn.execute(sa.text('SELECT 1')).scalar() == 1 - # a disposed engine should not have the same pool - assert engine.pool is not pool + assert engine.pool is not pool, "Engine was not disposed because pool is the same" + + +def test_create_engine_always_disposes(sqlite_memory_dsn): + """Test that engine creation context manager still dispoes of an engine when an exception is raised.""" + with pytest.raises(RuntimeError, match='it failed'): + with _create_engine(sqlite_memory_dsn) as engine: + pool = engine.pool + raise RuntimeError('it failed') + + assert engine.pool is not pool, "Engine was not disposed because pool is the same" From 271efc7fc8686c26d6db3ce62e29e64dd049a7cf Mon Sep 17 00:00:00 2001 From: atom-andrew Date: Sun, 5 Oct 2025 19:24:22 -0700 Subject: [PATCH 4/4] Fix typo --- tests/functions/test_database.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/functions/test_database.py b/tests/functions/test_database.py index 2cce3d5e..ab256cc5 100644 --- a/tests/functions/test_database.py +++ b/tests/functions/test_database.py @@ -177,7 +177,7 @@ def test_create_engine(sqlite_memory_dsn): def test_create_engine_always_disposes(sqlite_memory_dsn): - """Test that engine creation context manager still dispoes of an engine when an exception is raised.""" + """Test that engine creation context manager still disposes of an engine when an exception is raised.""" with pytest.raises(RuntimeError, match='it failed'): with _create_engine(sqlite_memory_dsn) as engine: pool = engine.pool