"""
Integrations with SQLAlchemy's `PostgreSQL <https://docs.sqlalchemy.org/en/14/core/engines.html#postgresql>`_ engine.
"""
from importlib.util import find_spec
from typing import Any, Callable, Union
from Ligare.database.config import DatabaseConnectArgsConfig
from Ligare.database.types import IScopedSessionFactory, MetaBase
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.orm.scoping import ScopedSession
from sqlalchemy.orm.session import sessionmaker
from typing_extensions import override
[docs]
class PostgreSQLScopedSession(
    ScopedSession, IScopedSessionFactory["PostgreSQLScopedSession"]
):
    """Scoped session whose sessions are bound to a PostgreSQL engine.

    Use :meth:`create` to build an instance from a connection string.
    """

    @override
    @staticmethod
    def create(
        connection_string: str,
        echo: bool = False,
        execution_options: dict[str, Any] | None = None,
        connect_args: DatabaseConnectArgsConfig | None = None,
        bases: list[MetaBase | type[MetaBase]] | None = None,
    ) -> "PostgreSQLScopedSession":
        """Create an engine for ``connection_string`` and wrap it in a scoped session.

        :param connection_string: URL handed to ``sqlalchemy.create_engine``.
        :param echo: Forwarded to ``create_engine`` as ``echo``.
        :param execution_options: Forwarded to ``create_engine``; ``None``
            is treated as an empty dict.
        :param connect_args: When given, its ``model_dump()`` result is
            forwarded to ``create_engine`` as ``connect_args``; ``None`` is
            treated as an empty dict.
        :param bases: When non-empty, the table names registered on these
            bases are normalized through :meth:`_alter_base_schemas`
            before the session is created.
        :raises ModuleNotFoundError: If the ``psycopg2`` module cannot be found.
        :return: A scoped session whose session factory is bound to the new
            engine with ``autocommit`` and ``autoflush`` disabled.
        """
        # Fail early with an actionable message instead of letting
        # SQLAlchemy fail later when it tries to load the driver.
        if find_spec("psycopg2") is None:
            raise ModuleNotFoundError(
                "No module named 'psycopg2'. Install PostgreSQL support through `Ligare.database[postgres]` or `Ligare.database[postgres-binary]`."
            )

        engine = create_engine(
            connection_string,
            echo=echo,
            execution_options=execution_options or {},
            connect_args=connect_args.model_dump() if connect_args is not None else {},
        )

        if bases:
            PostgreSQLScopedSession._alter_base_schemas(engine, bases)

        return PostgreSQLScopedSession(
            sessionmaker(autocommit=False, autoflush=False, bind=engine)
        )

    @staticmethod
    def _strip_schema_prefix(name: str, schema: str | None) -> list[str]:
        """Split ``name`` on ``.`` and drop leading parts equal to ``schema``.

        The final part is always kept, so a table whose name equals its
        schema yields ``[name]`` rather than an empty list.
        """
        parts = name.split(".")
        # `len(parts) > 1` keeps the last part; without this guard a name
        # made up solely of the schema would be trimmed to nothing and the
        # next `parts[0]` lookup would raise IndexError.
        while len(parts) > 1 and parts[0] == schema:
            parts = parts[1:]
        return parts

    @staticmethod
    def _alter_base_schemas(
        engine: Engine, bases: list[MetaBase | type[MetaBase]]
    ) -> None:
        """Remove schema prefixes that were prepended to table names.

        This renames all tables to undo any renaming that previously happened
        from, e.g., our SQLite engine.

        :param engine: Engine used to reflect each base's metadata.
        :param bases: Declarative bases whose direct subclasses and metadata
            tables are renamed in place.
        """
        for metadata_base in bases:
            metadata_base.metadata.reflect(bind=engine)

            # The schema is declared on the base itself, so it is the same
            # for every subclass; look it up once per base.
            schema: str | None = None
            if hasattr(metadata_base, "__table_args__") and isinstance(
                metadata_base.__table_args__, dict
            ):
                schema = metadata_base.__table_args__.get("schema")

            for table_subclass in type(metadata_base).__subclasses__(metadata_base):
                if schema:
                    # Trim all prepended schema names
                    table_subclass.__tablename__ = (
                        PostgreSQLScopedSession._strip_schema_prefix(
                            table_subclass.__tablename__, schema
                        )[0]
                    )

            for table in metadata_base.metadata.sorted_tables:
                table.name = ".".join(
                    PostgreSQLScopedSession._strip_schema_prefix(
                        table.name, table.schema
                    )
                )
                # A table without a schema must not end up with a
                # "None.<name>" full name; use the bare name instead.
                table.fullname = (
                    f"{table.schema}.{table.name}" if table.schema else table.name
                )

    def __init__(
        self,
        session_factory: Union[Callable[..., Any], "sessionmaker[Any]"],
        scopefunc: Any = None,
    ) -> None:
        """Forward ``session_factory`` and ``scopefunc`` to ``ScopedSession``.

        :param session_factory: Callable (typically a ``sessionmaker``) used
            to create sessions.
        :param scopefunc: Optional scope function passed through unchanged.
        """
        super().__init__(session_factory, scopefunc)