setup multy db suport

This commit is contained in:
2026-03-07 11:19:09 -05:00
parent 66acc010ca
commit 69f5b87e5f
17 changed files with 315 additions and 227 deletions

View File

@@ -7,7 +7,22 @@ requires-python = "~=3.13.0"
readme = "README.md" readme = "README.md"
license = "MIT" license = "MIT"
# these dependencies are a best effort and aren't guaranteed to work # these dependencies are a best effort and aren't guaranteed to work
dependencies = ["apprise", "apscheduler", "httpx", "polars", "pydantic", "pyyaml", "requests", "typer"] dependencies = [
"alembic",
"apprise",
"apscheduler",
"httpx",
"polars",
"psycopg[binary]",
"pydantic",
"pyyaml",
"requests",
"sqlalchemy",
"typer",
]
[project.scripts]
database = "python.database_cli:app"
[dependency-groups] [dependency-groups]
dev = [ dev = [

View File

@@ -1,109 +0,0 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts
# Use forward slashes (/) also on windows to provide an os agnostic path
script_location = python/alembic
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
file_template = %%(year)d_%%(month).2d_%%(day).2d-%%(slug)s_%%(rev)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to alembic/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
# version_path_separator = newline
#
# Use os.pathsep. Default configuration used for new projects.
version_path_separator = os
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
sqlalchemy.url = driver://user:pass@localhost/dbname
revision_environment = true
[post_write_hooks]
hooks = dynamic_schema,ruff
dynamic_schema.type = dynamic_schema
ruff.type = ruff
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARNING
handlers = console
qualname =
[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

View File

@@ -9,20 +9,24 @@ from typing import TYPE_CHECKING, Any, Literal
from alembic import context from alembic import context
from alembic.script import write_hooks from alembic.script import write_hooks
from sqlalchemy.schema import CreateSchema
from python.common import bash_wrapper from python.common import bash_wrapper
from python.orm import RichieBase from python.orm.common import get_postgres_engine
from python.orm.base import get_postgres_engine
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import MutableMapping from collections.abc import MutableMapping
# this is the Alembic Config object, which provides from sqlalchemy.orm import DeclarativeBase
# access to the values within the .ini file in use.
config = context.config config = context.config
base_class: type[DeclarativeBase] = config.attributes.get("base")
if base_class is None:
error = "No base class provided. Use the database CLI to run alembic commands."
raise RuntimeError(error)
target_metadata = RichieBase.metadata target_metadata = base_class.metadata
logging.basicConfig( logging.basicConfig(
level="DEBUG", level="DEBUG",
datefmt="%Y-%m-%dT%H:%M:%S%z", datefmt="%Y-%m-%dT%H:%M:%S%z",
@@ -35,8 +39,9 @@ logging.basicConfig(
def dynamic_schema(filename: str, _options: dict[Any, Any]) -> None: def dynamic_schema(filename: str, _options: dict[Any, Any]) -> None:
"""Dynamic schema.""" """Dynamic schema."""
original_file = Path(filename).read_text() original_file = Path(filename).read_text()
dynamic_schema_file_part1 = original_file.replace(f"schema='{RichieBase.schema_name}'", "schema=schema") schema_name = base_class.schema_name
dynamic_schema_file = dynamic_schema_file_part1.replace(f"'{RichieBase.schema_name}.", "f'{schema}.") dynamic_schema_file_part1 = original_file.replace(f"schema='{schema_name}'", "schema=schema")
dynamic_schema_file = dynamic_schema_file_part1.replace(f"'{schema_name}.", "f'{schema}.")
Path(filename).write_text(dynamic_schema_file) Path(filename).write_text(dynamic_schema_file)
@@ -52,12 +57,12 @@ def include_name(
type_: Literal["schema", "table", "column", "index", "unique_constraint", "foreign_key_constraint"], type_: Literal["schema", "table", "column", "index", "unique_constraint", "foreign_key_constraint"],
_parent_names: MutableMapping[Literal["schema_name", "table_name", "schema_qualified_table_name"], str | None], _parent_names: MutableMapping[Literal["schema_name", "table_name", "schema_qualified_table_name"], str | None],
) -> bool: ) -> bool:
"""This filter table to be included in the migration. """Filter tables to be included in the migration.
Args: Args:
name (str): The name of the table. name (str): The name of the table.
type_ (str): The type of the table. type_ (str): The type of the table.
parent_names (list[str]): The names of the parent tables. _parent_names (MutableMapping): The names of the parent tables.
Returns: Returns:
bool: True if the table should be included, False otherwise. bool: True if the table should be included, False otherwise.
@@ -67,7 +72,6 @@ def include_name(
return name == target_metadata.schema return name == target_metadata.schema
return True return True
def run_migrations_online() -> None: def run_migrations_online() -> None:
"""Run migrations in 'online' mode. """Run migrations in 'online' mode.
@@ -75,14 +79,24 @@ def run_migrations_online() -> None:
and associate a connection with the context. and associate a connection with the context.
""" """
connectable = get_postgres_engine() env_prefix = config.attributes.get("env_prefix", "POSTGRES")
connectable = get_postgres_engine(name=env_prefix)
with connectable.connect() as connection: with connectable.connect() as connection:
schema = base_class.schema_name
if not connectable.dialect.has_schema(connection, schema):
answer = input(f"Schema {schema!r} does not exist. Create it? [y/N] ")
if answer.lower() != "y":
error = f"Schema {schema!r} does not exist. Exiting."
raise SystemExit(error)
connection.execute(CreateSchema(schema))
connection.commit()
context.configure( context.configure(
connection=connection, connection=connection,
target_metadata=target_metadata, target_metadata=target_metadata,
include_schemas=True, include_schemas=True,
version_table_schema=RichieBase.schema_name, version_table_schema=schema,
include_name=include_name, include_name=include_name,
) )

View File

@@ -13,7 +13,7 @@ from typing import TYPE_CHECKING
import sqlalchemy as sa import sqlalchemy as sa
from alembic import op from alembic import op
from python.orm import RichieBase from python.orm import ${config.attributes["base"].__name__}
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Sequence from collections.abc import Sequence
@@ -24,7 +24,7 @@ down_revision: str | None = ${repr(down_revision)}
branch_labels: str | Sequence[str] | None = ${repr(branch_labels)} branch_labels: str | Sequence[str] | None = ${repr(branch_labels)}
depends_on: str | Sequence[str] | None = ${repr(depends_on)} depends_on: str | Sequence[str] | None = ${repr(depends_on)}
schema=RichieBase.schema_name schema=${config.attributes["base"].__name__}.schema_name
def upgrade() -> None: def upgrade() -> None:
"""Upgrade.""" """Upgrade."""

View File

@@ -16,7 +16,7 @@ from fastapi import FastAPI
from python.api.routers import contact_router, create_frontend_router from python.api.routers import contact_router, create_frontend_router
from python.common import configure_logger from python.common import configure_logger
from python.orm.base import get_postgres_engine from python.orm.common import get_postgres_engine
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -6,7 +6,7 @@ from sqlalchemy import select
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
from python.api.dependencies import DbSession from python.api.dependencies import DbSession
from python.orm.contact import Contact, ContactRelationship, Need, RelationshipType from python.orm.richie.contact import Contact, ContactRelationship, Need, RelationshipType
class NeedBase(BaseModel): class NeedBase(BaseModel):

114
python/database_cli.py Normal file
View File

@@ -0,0 +1,114 @@
"""CLI wrapper around alembic for multi-database support.
Usage:
database <db_name> <command> [args...]
Examples:
database van_inventory upgrade head
database van_inventory downgrade head-1
database van_inventory revision --autogenerate -m "add meals table"
database van_inventory check
database richie check
database richie upgrade head
"""
from __future__ import annotations
from dataclasses import dataclass
from importlib import import_module
from typing import TYPE_CHECKING, Annotated
import typer
from alembic.config import CommandLine, Config
if TYPE_CHECKING:
from sqlalchemy.orm import DeclarativeBase
@dataclass(frozen=True)
class DatabaseConfig:
"""Configuration for a database."""
env_prefix: str
version_location: str
base_module: str
base_class_name: str
models_module: str
script_location: str = "python/alembic"
file_template: str = "%%(year)d_%%(month).2d_%%(day).2d-%%(slug)s_%%(rev)s"
def get_base(self) -> type[DeclarativeBase]:
"""Import and return the Base class."""
module = import_module(self.base_module)
return getattr(module, self.base_class_name)
def import_models(self) -> None:
"""Import ORM models so alembic autogenerate can detect them."""
import_module(self.models_module)
def alembic_config(self) -> Config:
"""Build an alembic Config for this database."""
# Runtime import needed — Config is in TYPE_CHECKING for the return type annotation
from alembic.config import Config as AlembicConfig # noqa: PLC0415
cfg = AlembicConfig()
cfg.set_main_option("script_location", self.script_location)
cfg.set_main_option("file_template", self.file_template)
cfg.set_main_option("prepend_sys_path", ".")
cfg.set_main_option("version_path_separator", "os")
cfg.set_main_option("version_locations", self.version_location)
cfg.set_main_option("revision_environment", "true")
cfg.set_section_option("post_write_hooks", "hooks", "dynamic_schema,ruff")
cfg.set_section_option("post_write_hooks", "dynamic_schema.type", "dynamic_schema")
cfg.set_section_option("post_write_hooks", "ruff.type", "ruff")
cfg.attributes["base"] = self.get_base()
cfg.attributes["env_prefix"] = self.env_prefix
self.import_models()
return cfg
DATABASES: dict[str, DatabaseConfig] = {
"richie": DatabaseConfig(
env_prefix="RICHIE",
version_location="python/alembic/richie/versions",
base_module="python.orm.richie.base",
base_class_name="RichieBase",
models_module="python.orm.richie.contact",
),
"van_inventory": DatabaseConfig(
env_prefix="VAN_INVENTORY",
version_location="python/alembic/van_inventory/versions",
base_module="python.orm.van_inventory.base",
base_class_name="VanInventoryBase",
models_module="python.orm.van_inventory.models",
),
}
app = typer.Typer(help="Multi-database alembic wrapper.")
@app.command(
context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)
def main(
ctx: typer.Context,
db_name: Annotated[str, typer.Argument(help=f"Database name. Options: {', '.join(DATABASES)}")],
command: Annotated[str, typer.Argument(help="Alembic command (upgrade, downgrade, revision, check, etc.)")],
) -> None:
"""Run an alembic command against the specified database."""
db_config = DATABASES.get(db_name)
if not db_config:
typer.echo(f"Unknown database: {db_name!r}. Available: {', '.join(DATABASES)}", err=True)
raise typer.Exit(code=1)
alembic_cfg = db_config.alembic_config()
cmd_line = CommandLine()
options = cmd_line.parser.parse_args([command, *ctx.args])
cmd_line.run_cmd(alembic_cfg, options)
if __name__ == "__main__":
app()

View File

@@ -1,22 +1,2 @@
"""ORM package exports.""" """ORM package exports."""
from __future__ import annotations
from python.orm.base import RichieBase, TableBase
from python.orm.contact import (
Contact,
ContactNeed,
ContactRelationship,
Need,
RelationshipType,
)
__all__ = [
"Contact",
"ContactNeed",
"ContactRelationship",
"Need",
"RelationshipType",
"RichieBase",
"TableBase",
]

View File

@@ -1,80 +0,0 @@
"""Base ORM definitions."""
from __future__ import annotations
from datetime import datetime
from os import getenv
from typing import cast
from sqlalchemy import DateTime, MetaData, create_engine, func
from sqlalchemy.engine import URL, Engine
from sqlalchemy.ext.declarative import AbstractConcreteBase
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
class RichieBase(DeclarativeBase):
"""Base class for all ORM models."""
schema_name = "main"
metadata = MetaData(
schema=schema_name,
naming_convention={
"ix": "ix_%(table_name)s_%(column_0_name)s",
"uq": "uq_%(table_name)s_%(column_0_name)s",
"ck": "ck_%(table_name)s_%(constraint_name)s",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"pk": "pk_%(table_name)s",
},
)
class TableBase(AbstractConcreteBase, RichieBase):
"""Abstract concrete base for tables with IDs and timestamps."""
__abstract__ = True
id: Mapped[int] = mapped_column(primary_key=True)
created: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
)
updated: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
)
def get_connection_info() -> tuple[str, str, str, str, str | None]:
"""Get connection info from environment variables."""
database = getenv("POSTGRES_DB")
host = getenv("POSTGRES_HOST")
port = getenv("POSTGRES_PORT")
username = getenv("POSTGRES_USER")
password = getenv("POSTGRES_PASSWORD")
if None in (database, host, port, username):
error = f"Missing environment variables for Postgres connection.\n{database=}\n{host=}\n{port=}\n{username=}\n"
raise ValueError(error)
return cast("tuple[str, str, str, str, str | None]", (database, host, port, username, password))
def get_postgres_engine(*, pool_pre_ping: bool = True) -> Engine:
"""Create a SQLAlchemy engine from environment variables."""
database, host, port, username, password = get_connection_info()
url = URL.create(
drivername="postgresql+psycopg",
username=username,
password=password,
host=host,
port=int(port),
database=database,
)
return create_engine(
url=url,
pool_pre_ping=pool_pre_ping,
pool_recycle=1800,
)

51
python/orm/common.py Normal file
View File

@@ -0,0 +1,51 @@
"""Shared ORM definitions."""
from __future__ import annotations
from os import getenv
from typing import cast
from sqlalchemy import create_engine
from sqlalchemy.engine import URL, Engine
NAMING_CONVENTION = {
"ix": "ix_%(table_name)s_%(column_0_name)s",
"uq": "uq_%(table_name)s_%(column_0_name)s",
"ck": "ck_%(table_name)s_%(constraint_name)s",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"pk": "pk_%(table_name)s",
}
def get_connection_info(name: str) -> tuple[str, str, str, str, str | None]:
"""Get connection info from environment variables."""
database = getenv(f"{name}_DB")
host = getenv(f"{name}_HOST")
port = getenv(f"{name}_PORT")
username = getenv(f"{name}_USER")
password = getenv(f"{name}_PASSWORD")
if None in (database, host, port, username):
error = f"Missing environment variables for Postgres connection.\n{database=}\n{host=}\n{port=}\n{username=}\n"
raise ValueError(error)
return cast("tuple[str, str, str, str, str | None]", (database, host, port, username, password))
def get_postgres_engine(*, name: str = "POSTGRES", pool_pre_ping: bool = True) -> Engine:
"""Create a SQLAlchemy engine from environment variables."""
database, host, port, username, password = get_connection_info(name)
url = URL.create(
drivername="postgresql+psycopg",
username=username,
password=password,
host=host,
port=int(port),
database=database,
)
return create_engine(
url=url,
pool_pre_ping=pool_pre_ping,
pool_recycle=1800,
)

View File

@@ -0,0 +1,22 @@
"""Richie database ORM exports."""
from __future__ import annotations
from python.orm.richie.base import RichieBase, TableBase
from python.orm.richie.contact import (
Contact,
ContactNeed,
ContactRelationship,
Need,
RelationshipType,
)
__all__ = [
"Contact",
"ContactNeed",
"ContactRelationship",
"Need",
"RelationshipType",
"RichieBase",
"TableBase",
]

39
python/orm/richie/base.py Normal file
View File

@@ -0,0 +1,39 @@
"""Richie database ORM base."""
from __future__ import annotations
from datetime import datetime
from sqlalchemy import DateTime, MetaData, func
from sqlalchemy.ext.declarative import AbstractConcreteBase
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from python.orm.common import NAMING_CONVENTION
class RichieBase(DeclarativeBase):
"""Base class for richie database ORM models."""
schema_name = "main"
metadata = MetaData(
schema=schema_name,
naming_convention=NAMING_CONVENTION,
)
class TableBase(AbstractConcreteBase, RichieBase):
"""Abstract concrete base for richie tables with IDs and timestamps."""
__abstract__ = True
id: Mapped[int] = mapped_column(primary_key=True)
created: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
)
updated: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
)

View File

@@ -7,7 +7,7 @@ from enum import Enum
from sqlalchemy import ForeignKey, String from sqlalchemy import ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
from python.orm.base import RichieBase, TableBase from python.orm.richie.base import RichieBase, TableBase
class RelationshipType(str, Enum): class RelationshipType(str, Enum):

View File

@@ -0,0 +1,2 @@
"""Van inventory database ORM exports."""

View File

@@ -0,0 +1,39 @@
"""Van inventory database ORM base."""
from __future__ import annotations
from datetime import datetime
from sqlalchemy import DateTime, MetaData, func
from sqlalchemy.ext.declarative import AbstractConcreteBase
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from python.orm.common import NAMING_CONVENTION
class VanInventoryBase(DeclarativeBase):
"""Base class for van_inventory database ORM models."""
schema_name = "main"
metadata = MetaData(
schema=schema_name,
naming_convention=NAMING_CONVENTION,
)
class VanTableBase(AbstractConcreteBase, VanInventoryBase):
"""Abstract concrete base for van_inventory tables with IDs and timestamps."""
__abstract__ = True
id: Mapped[int] = mapped_column(primary_key=True)
created: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
)
updated: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
)

View File

@@ -76,6 +76,7 @@
ensureDatabases = [ ensureDatabases = [
"hass" "hass"
"richie" "richie"
"van_inventory"
]; ];
# Thank you NotAShelf # Thank you NotAShelf
# https://github.com/NotAShelf/nyx/blob/d407b4d6e5ab7f60350af61a3d73a62a5e9ac660/modules/core/roles/server/system/services/databases/postgresql.nix#L74 # https://github.com/NotAShelf/nyx/blob/d407b4d6e5ab7f60350af61a3d73a62a5e9ac660/modules/core/roles/server/system/services/databases/postgresql.nix#L74