Source code for Ligare.database.engine.postgresql

"""
Integrations with SQLAlchemy's `PostgreSQL <https://docs.sqlalchemy.org/en/14/core/engines.html#postgresql>`_ engine.
"""

from importlib.util import find_spec
from typing import Any, Callable, Union

from Ligare.database.config import DatabaseConnectArgsConfig
from Ligare.database.types import IScopedSessionFactory, MetaBase
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.orm.scoping import ScopedSession
from sqlalchemy.orm.session import sessionmaker
from typing_extensions import override


class PostgreSQLScopedSession(
    ScopedSession, IScopedSessionFactory["PostgreSQLScopedSession"]
):
    """
    A :class:`ScopedSession` with a :meth:`create` factory that builds the
    engine from a connection string and requires ``psycopg2`` to be installed.
    """

    @override
    @staticmethod
    def create(
        connection_string: str,
        echo: bool = False,
        execution_options: dict[str, Any] | None = None,
        connect_args: DatabaseConnectArgsConfig | None = None,
        bases: list[MetaBase | type[MetaBase]] | None = None,
    ) -> "PostgreSQLScopedSession":
        """
        Create an engine from ``connection_string`` and return a scoped session bound to it.

        :param connection_string: Passed to SQLAlchemy's ``create_engine``.
        :param echo: Passed to ``create_engine`` as ``echo``.
        :param execution_options: Passed to ``create_engine``; ``None`` is
            replaced with an empty dictionary.
        :param connect_args: If given, its ``model_dump()`` result is passed to
            ``create_engine`` as ``connect_args``; otherwise an empty dictionary is used.
        :param bases: If non-empty, the table names of these bases are
            normalized with :meth:`_alter_base_schemas` before the session is created.
        :raises ModuleNotFoundError: If the ``psycopg2`` module cannot be found.
        :return: A session whose factory has ``autocommit`` and ``autoflush`` disabled.
        """
        # Fail early with installation guidance rather than letting
        # engine creation fail on the missing driver.
        if find_spec("psycopg2") is None:
            raise ModuleNotFoundError(
                "No module named 'psycopg2'. Install PostgreSQL support through `Ligare.database[postgres]` or `Ligare.database[postgres-binary]`."
            )

        engine = create_engine(
            connection_string,
            echo=echo,
            execution_options=execution_options or {},
            connect_args=connect_args.model_dump() if connect_args is not None else {},
        )

        if bases:
            PostgreSQLScopedSession._alter_base_schemas(engine, bases)

        return PostgreSQLScopedSession(
            sessionmaker(autocommit=False, autoflush=False, bind=engine)
        )

    @staticmethod
    def _alter_base_schemas(engine: Engine, bases: list[MetaBase | type[MetaBase]]):
        """
        Strip schema-name prefixes from the table names of every base in ``bases``.

        For each base, its metadata is reflected against ``engine``. Then:

        * if the base's ``__table_args__`` is a dictionary with a truthy
          ``"schema"`` entry, leading occurrences of that schema are removed
          from each direct subclass's ``__tablename__``;
        * for every table in the base's metadata, leading occurrences of the
          table's own schema are removed from ``table.name`` and
          ``table.fullname`` is rebuilt to match.

        At least one name component is always kept, so a table whose name is
        identical to its schema name is left unchanged.

        :param engine: The engine used to reflect each base's metadata.
        :param bases: The bases whose tables are renamed in place.
        """
        # This renames all tables to undo any renaming that previously happened
        # from, e.g., our SQLite engine.
        for metadata_base in bases:
            metadata_base.metadata.reflect(bind=engine)

            # The schema is a property of the base, not of each subclass,
            # so resolve it once per base.
            table_args = getattr(metadata_base, "__table_args__", None)
            schema: str | None = (
                table_args.get("schema") if isinstance(table_args, dict) else None
            )

            if schema:
                for table_subclass in type(metadata_base).__subclasses__(metadata_base):
                    table_name: list[str] = table_subclass.__tablename__.split(".")
                    # Trim all prepended schema names, but never the final
                    # component: a table named exactly like its schema would
                    # otherwise leave an empty list and raise IndexError.
                    while len(table_name) > 1 and table_name[0] == schema:
                        table_name = table_name[1:]
                    table_subclass.__tablename__ = table_name[0]

            for table in metadata_base.metadata.sorted_tables:
                table_name = table.name.split(".")
                while len(table_name) > 1 and table_name[0] == table.schema:
                    table_name = table_name[1:]
                table.name = ".".join(table_name)
                # Only prefix the schema when there is one; otherwise the
                # full name would become the literal string "None.<name>".
                table.fullname = (
                    f"{table.schema}.{table.name}"
                    if table.schema is not None
                    else table.name
                )

    def __init__(
        self,
        session_factory: Union[Callable[..., Any], "sessionmaker[Any]"],
        scopefunc: Any = None,
    ) -> None:
        """
        Forward ``session_factory`` and ``scopefunc`` to the ``ScopedSession`` constructor.

        :param session_factory: The factory used to create sessions.
        :param scopefunc: Optional scope function passed through unchanged.
        """
        super().__init__(session_factory, scopefunc)