Compare commits
1 Commits
main
..
15c717f9fe
| Author | SHA1 | Date | |
|---|---|---|---|
| 15c717f9fe |
@@ -1,7 +1,6 @@
|
|||||||
name: pytest
|
name: pytest
|
||||||
|
|
||||||
on:
|
on:
|
||||||
workflow_dispatch:
|
|
||||||
push:
|
push:
|
||||||
branches:
|
branches:
|
||||||
- main
|
- main
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ jobs:
|
|||||||
lockfile:
|
lockfile:
|
||||||
runs-on: self-hosted
|
runs-on: self-hosted
|
||||||
permissions:
|
permissions:
|
||||||
actions: write
|
|
||||||
contents: write
|
contents: write
|
||||||
pull-requests: write
|
pull-requests: write
|
||||||
steps:
|
steps:
|
||||||
|
|||||||
@@ -0,0 +1,12 @@
|
|||||||
|
## Dev environment tips
|
||||||
|
|
||||||
|
- use treefmt to format all files
|
||||||
|
- make python code ruff compliant
|
||||||
|
- use pytest to test python code
|
||||||
|
- always use the minimum amount of complexity
|
||||||
|
- if judgment calls are easy to reverse make them. if not ask me first
|
||||||
|
- Match existing code style.
|
||||||
|
- Use builtin helpers getenv() over os.environ.get.
|
||||||
|
- Prefer single-purpose functions over “do everything” helpers.
|
||||||
|
- Avoid compatibility branches like PG_USER and POSTGRESQL_URL unless requested.
|
||||||
|
- Keep helpers only if reused or they simplify the code otherwise inline.
|
||||||
File diff suppressed because one or more lines are too long
Generated
+15
-15
@@ -8,11 +8,11 @@
|
|||||||
},
|
},
|
||||||
"locked": {
|
"locked": {
|
||||||
"dir": "pkgs/firefox-addons",
|
"dir": "pkgs/firefox-addons",
|
||||||
"lastModified": 1781150628,
|
"lastModified": 1780733803,
|
||||||
"narHash": "sha256-b4mp8l3qWuSCyYYo9HSngDtcB3PpecYiOXjULrjwwlw=",
|
"narHash": "sha256-QBJPq12P1DAXFGezoEJaSO/xPUrPlnaI3ddSaMG2JpM=",
|
||||||
"owner": "rycee",
|
"owner": "rycee",
|
||||||
"repo": "nur-expressions",
|
"repo": "nur-expressions",
|
||||||
"rev": "753319310f4673a2dabbfab87482187b40bf9bac",
|
"rev": "c80b0aa94392c5f3612ac797108f6d952752036d",
|
||||||
"type": "gitlab"
|
"type": "gitlab"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
@@ -29,11 +29,11 @@
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1781189114,
|
"lastModified": 1780679734,
|
||||||
"narHash": "sha256-5inaamLgUMWy+MOBE9ChF9QAF1o/74LFuHkI0W/9rqc=",
|
"narHash": "sha256-KmRNvpNOb7QEORa06bVgjW9kITcx0VhsI7w0vhmZyD8=",
|
||||||
"owner": "nix-community",
|
"owner": "nix-community",
|
||||||
"repo": "home-manager",
|
"repo": "home-manager",
|
||||||
"rev": "486595d2cf49cfcd649b58a284fa11ac0e34da22",
|
"rev": "b2b7db486e06e098711dc291bb25db82850e1d16",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
@@ -47,11 +47,11 @@
|
|||||||
"nixpkgs": "nixpkgs"
|
"nixpkgs": "nixpkgs"
|
||||||
},
|
},
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1781168557,
|
"lastModified": 1780310866,
|
||||||
"narHash": "sha256-LOnLQ2tpYF9gqIDDr3+j3DbpJJr/QCH6zPRT2GzEUOE=",
|
"narHash": "sha256-fPBRVf6A5xlACYcOI59shGrjURuvwu0lRsDoSCEXt/I=",
|
||||||
"owner": "nixos",
|
"owner": "nixos",
|
||||||
"repo": "nixos-hardware",
|
"repo": "nixos-hardware",
|
||||||
"rev": "6358ff76821101c178e3ab4919a62799bfe3652e",
|
"rev": "4ed851c979641e28597a05086332d75cdc9e395f",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
@@ -76,11 +76,11 @@
|
|||||||
},
|
},
|
||||||
"nixpkgs-master": {
|
"nixpkgs-master": {
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1781229721,
|
"lastModified": 1780798858,
|
||||||
"narHash": "sha256-ORvqDbb/LYxiJljGIejapjkc/kJbVote2N1WSb9W45I=",
|
"narHash": "sha256-4KLc5ZMjfMQosXA2JasUgZTk3i+c/i1zMH4custtmI0=",
|
||||||
"owner": "nixos",
|
"owner": "nixos",
|
||||||
"repo": "nixpkgs",
|
"repo": "nixpkgs",
|
||||||
"rev": "173d0ad7a974f8543a9ab01d2271b2e290341b33",
|
"rev": "92840095e65b9970125843175f4be974b71a92ad",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
@@ -108,11 +108,11 @@
|
|||||||
},
|
},
|
||||||
"nixpkgs_2": {
|
"nixpkgs_2": {
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1781074563,
|
"lastModified": 1780243769,
|
||||||
"narHash": "sha256-md8WlXOlfnIeHeOScMTTHFyf2d6iaTwPl2apR5EQ3P4=",
|
"narHash": "sha256-x5UQuRsH3MqI0U9afaXSNqzTPSeZlRLvFAav2Ux1pNw=",
|
||||||
"owner": "nixos",
|
"owner": "nixos",
|
||||||
"repo": "nixpkgs",
|
"repo": "nixpkgs",
|
||||||
"rev": "9ae611a455b90cf061d8f332b977e387bda8e1ca",
|
"rev": "331800de5053fcebacf6813adb5db9c9dca22a0c",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
|
|||||||
@@ -0,0 +1,24 @@
|
|||||||
|
# Logs
|
||||||
|
logs
|
||||||
|
*.log
|
||||||
|
npm-debug.log*
|
||||||
|
yarn-debug.log*
|
||||||
|
yarn-error.log*
|
||||||
|
pnpm-debug.log*
|
||||||
|
lerna-debug.log*
|
||||||
|
|
||||||
|
node_modules
|
||||||
|
dist
|
||||||
|
dist-ssr
|
||||||
|
*.local
|
||||||
|
|
||||||
|
# Editor directories and files
|
||||||
|
.vscode/*
|
||||||
|
!.vscode/extensions.json
|
||||||
|
.idea
|
||||||
|
.DS_Store
|
||||||
|
*.suo
|
||||||
|
*.ntvs*
|
||||||
|
*.njsproj
|
||||||
|
*.sln
|
||||||
|
*.sw?
|
||||||
@@ -38,6 +38,7 @@
|
|||||||
ruff
|
ruff
|
||||||
scalene
|
scalene
|
||||||
sqlalchemy
|
sqlalchemy
|
||||||
|
sqlalchemy
|
||||||
tenacity
|
tenacity
|
||||||
textual
|
textual
|
||||||
tiktoken
|
tiktoken
|
||||||
|
|||||||
+8
-7
@@ -3,7 +3,7 @@ name = "system_tools"
|
|||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
description = ""
|
description = ""
|
||||||
authors = [{ name = "Richie Cahill", email = "richie@tmmworkshop.com" }]
|
authors = [{ name = "Richie Cahill", email = "richie@tmmworkshop.com" }]
|
||||||
requires-python = "~=3.14.0"
|
requires-python = "~=3.13.0"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
# these dependencies are a best effort and aren't guaranteed to work
|
# these dependencies are a best effort and aren't guaranteed to work
|
||||||
@@ -12,22 +12,20 @@ dependencies = [
|
|||||||
"alembic",
|
"alembic",
|
||||||
"apprise",
|
"apprise",
|
||||||
"apscheduler",
|
"apscheduler",
|
||||||
"fastapi",
|
|
||||||
"fastapi-cli",
|
|
||||||
"httpx",
|
"httpx",
|
||||||
|
"python-multipart",
|
||||||
"polars",
|
"polars",
|
||||||
"psycopg[binary]",
|
"psycopg[binary]",
|
||||||
"pydantic",
|
"pydantic",
|
||||||
"python-multipart",
|
"pyyaml",
|
||||||
"sqlalchemy",
|
"sqlalchemy",
|
||||||
"tenacity",
|
|
||||||
"tinytuya",
|
|
||||||
"typer",
|
"typer",
|
||||||
"websockets",
|
"websockets",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
database = "python.database_cli:app"
|
database = "python.database_cli:app"
|
||||||
|
van-inventory = "python.van_inventory.main:serve"
|
||||||
whisper-transcribe = "python.tools.whisper.transcribe:main"
|
whisper-transcribe = "python.tools.whisper.transcribe:main"
|
||||||
|
|
||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
@@ -43,7 +41,7 @@ dev = [
|
|||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
|
|
||||||
target-version = "py314"
|
target-version = "py313"
|
||||||
|
|
||||||
line-length = 120
|
line-length = 120
|
||||||
|
|
||||||
@@ -86,6 +84,9 @@ lint.ignore = [
|
|||||||
"python/alembic/**" = [
|
"python/alembic/**" = [
|
||||||
"INP001", # (perm) this creates LSP issues for alembic
|
"INP001", # (perm) this creates LSP issues for alembic
|
||||||
]
|
]
|
||||||
|
"python/signal_bot/**" = [
|
||||||
|
"D107", # (perm) class docstrings cover __init__
|
||||||
|
]
|
||||||
|
|
||||||
[tool.ruff.lint.pydocstyle]
|
[tool.ruff.lint.pydocstyle]
|
||||||
convention = "google"
|
convention = "google"
|
||||||
|
|||||||
+1417
File diff suppressed because it is too large
Load Diff
+50
@@ -0,0 +1,50 @@
|
|||||||
|
"""adding FailedIngestion.
|
||||||
|
|
||||||
|
Revision ID: 2f43120e3ffc
|
||||||
|
Revises: f99be864fe69
|
||||||
|
Create Date: 2026-03-24 23:46:17.277897
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
from python.orm import DataScienceDevBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "2f43120e3ffc"
|
||||||
|
down_revision: str | None = "f99be864fe69"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
schema = DataScienceDevBase.schema_name
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table(
|
||||||
|
"failed_ingestion",
|
||||||
|
sa.Column("raw_line", sa.Text(), nullable=False),
|
||||||
|
sa.Column("error", sa.Text(), nullable=False),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||||
|
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_failed_ingestion")),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_table("failed_ingestion", schema=schema)
|
||||||
|
# ### end Alembic commands ###
|
||||||
+2770
File diff suppressed because it is too large
Load Diff
+1391
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,72 @@
|
|||||||
|
"""Attach all partition tables to the posts parent table.
|
||||||
|
|
||||||
|
Alembic autogenerate creates partition tables as standalone tables but does not
|
||||||
|
emit the ALTER TABLE ... ATTACH PARTITION statements needed for PostgreSQL to
|
||||||
|
route inserts to the correct partition.
|
||||||
|
|
||||||
|
Revision ID: a1b2c3d4e5f6
|
||||||
|
Revises: 605b1794838f
|
||||||
|
Create Date: 2026-03-25 10:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from python.orm import DataScienceDevBase
|
||||||
|
from python.orm.data_science_dev.posts.partitions import (
|
||||||
|
PARTITION_END_YEAR,
|
||||||
|
PARTITION_START_YEAR,
|
||||||
|
iso_weeks_in_year,
|
||||||
|
week_bounds,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "a1b2c3d4e5f6"
|
||||||
|
down_revision: str | None = "605b1794838f"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
schema = DataScienceDevBase.schema_name
|
||||||
|
|
||||||
|
ALREADY_ATTACHED_QUERY = text("""
|
||||||
|
SELECT inhrelid::regclass::text
|
||||||
|
FROM pg_inherits
|
||||||
|
WHERE inhparent = :parent::regclass
|
||||||
|
""")
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Attach all weekly partition tables to the posts parent table."""
|
||||||
|
connection = op.get_bind()
|
||||||
|
already_attached = {row[0] for row in connection.execute(ALREADY_ATTACHED_QUERY, {"parent": f"{schema}.posts"})}
|
||||||
|
|
||||||
|
for year in range(PARTITION_START_YEAR, PARTITION_END_YEAR + 1):
|
||||||
|
for week in range(1, iso_weeks_in_year(year) + 1):
|
||||||
|
table_name = f"posts_{year}_{week:02d}"
|
||||||
|
qualified_name = f"{schema}.{table_name}"
|
||||||
|
if qualified_name in already_attached:
|
||||||
|
continue
|
||||||
|
start, end = week_bounds(year, week)
|
||||||
|
start_str = start.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
end_str = end.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
op.execute(
|
||||||
|
f"ALTER TABLE {schema}.posts "
|
||||||
|
f"ATTACH PARTITION {qualified_name} "
|
||||||
|
f"FOR VALUES FROM ('{start_str}') TO ('{end_str}')"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Detach all weekly partition tables from the posts parent table."""
|
||||||
|
for year in range(PARTITION_START_YEAR, PARTITION_END_YEAR + 1):
|
||||||
|
for week in range(1, iso_weeks_in_year(year) + 1):
|
||||||
|
table_name = f"posts_{year}_{week:02d}"
|
||||||
|
op.execute(f"ALTER TABLE {schema}.posts DETACH PARTITION {schema}.{table_name}")
|
||||||
+153
@@ -0,0 +1,153 @@
|
|||||||
|
"""adding congress data.
|
||||||
|
|
||||||
|
Revision ID: 83bfc8af92d8
|
||||||
|
Revises: a1b2c3d4e5f6
|
||||||
|
Create Date: 2026-03-27 10:43:02.324510
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
from python.orm import DataScienceDevBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "83bfc8af92d8"
|
||||||
|
down_revision: str | None = "a1b2c3d4e5f6"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
schema = DataScienceDevBase.schema_name
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table(
|
||||||
|
"bill",
|
||||||
|
sa.Column("congress", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("bill_type", sa.String(), nullable=False),
|
||||||
|
sa.Column("number", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("title", sa.String(), nullable=True),
|
||||||
|
sa.Column("title_short", sa.String(), nullable=True),
|
||||||
|
sa.Column("official_title", sa.String(), nullable=True),
|
||||||
|
sa.Column("status", sa.String(), nullable=True),
|
||||||
|
sa.Column("status_at", sa.Date(), nullable=True),
|
||||||
|
sa.Column("sponsor_bioguide_id", sa.String(), nullable=True),
|
||||||
|
sa.Column("subjects_top_term", sa.String(), nullable=True),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||||
|
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_bill")),
|
||||||
|
sa.UniqueConstraint("congress", "bill_type", "number", name="uq_bill_congress_type_number"),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index("ix_bill_congress", "bill", ["congress"], unique=False, schema=schema)
|
||||||
|
op.create_table(
|
||||||
|
"legislator",
|
||||||
|
sa.Column("bioguide_id", sa.Text(), nullable=False),
|
||||||
|
sa.Column("thomas_id", sa.String(), nullable=True),
|
||||||
|
sa.Column("lis_id", sa.String(), nullable=True),
|
||||||
|
sa.Column("govtrack_id", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("opensecrets_id", sa.String(), nullable=True),
|
||||||
|
sa.Column("fec_ids", sa.String(), nullable=True),
|
||||||
|
sa.Column("first_name", sa.String(), nullable=False),
|
||||||
|
sa.Column("last_name", sa.String(), nullable=False),
|
||||||
|
sa.Column("official_full_name", sa.String(), nullable=True),
|
||||||
|
sa.Column("nickname", sa.String(), nullable=True),
|
||||||
|
sa.Column("birthday", sa.Date(), nullable=True),
|
||||||
|
sa.Column("gender", sa.String(), nullable=True),
|
||||||
|
sa.Column("current_party", sa.String(), nullable=True),
|
||||||
|
sa.Column("current_state", sa.String(), nullable=True),
|
||||||
|
sa.Column("current_district", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("current_chamber", sa.String(), nullable=True),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||||
|
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_legislator")),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index(op.f("ix_legislator_bioguide_id"), "legislator", ["bioguide_id"], unique=True, schema=schema)
|
||||||
|
op.create_table(
|
||||||
|
"bill_text",
|
||||||
|
sa.Column("bill_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("version_code", sa.String(), nullable=False),
|
||||||
|
sa.Column("version_name", sa.String(), nullable=True),
|
||||||
|
sa.Column("text_content", sa.String(), nullable=True),
|
||||||
|
sa.Column("date", sa.Date(), nullable=True),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||||
|
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["bill_id"], [f"{schema}.bill.id"], name=op.f("fk_bill_text_bill_id_bill"), ondelete="CASCADE"
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_bill_text")),
|
||||||
|
sa.UniqueConstraint("bill_id", "version_code", name="uq_bill_text_bill_id_version_code"),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"vote",
|
||||||
|
sa.Column("congress", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("chamber", sa.String(), nullable=False),
|
||||||
|
sa.Column("session", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("number", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("vote_type", sa.String(), nullable=True),
|
||||||
|
sa.Column("question", sa.String(), nullable=True),
|
||||||
|
sa.Column("result", sa.String(), nullable=True),
|
||||||
|
sa.Column("result_text", sa.String(), nullable=True),
|
||||||
|
sa.Column("vote_date", sa.Date(), nullable=False),
|
||||||
|
sa.Column("yea_count", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("nay_count", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("not_voting_count", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("present_count", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("bill_id", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||||
|
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(["bill_id"], [f"{schema}.bill.id"], name=op.f("fk_vote_bill_id_bill")),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_vote")),
|
||||||
|
sa.UniqueConstraint("congress", "chamber", "session", "number", name="uq_vote_congress_chamber_session_number"),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_index("ix_vote_congress_chamber", "vote", ["congress", "chamber"], unique=False, schema=schema)
|
||||||
|
op.create_index("ix_vote_date", "vote", ["vote_date"], unique=False, schema=schema)
|
||||||
|
op.create_table(
|
||||||
|
"vote_record",
|
||||||
|
sa.Column("vote_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("legislator_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("position", sa.String(), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["legislator_id"],
|
||||||
|
[f"{schema}.legislator.id"],
|
||||||
|
name=op.f("fk_vote_record_legislator_id_legislator"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["vote_id"], [f"{schema}.vote.id"], name=op.f("fk_vote_record_vote_id_vote"), ondelete="CASCADE"
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("vote_id", "legislator_id", name=op.f("pk_vote_record")),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_table("vote_record", schema=schema)
|
||||||
|
op.drop_index("ix_vote_date", table_name="vote", schema=schema)
|
||||||
|
op.drop_index("ix_vote_congress_chamber", table_name="vote", schema=schema)
|
||||||
|
op.drop_table("vote", schema=schema)
|
||||||
|
op.drop_table("bill_text", schema=schema)
|
||||||
|
op.drop_index(op.f("ix_legislator_bioguide_id"), table_name="legislator", schema=schema)
|
||||||
|
op.drop_table("legislator", schema=schema)
|
||||||
|
op.drop_index("ix_bill_congress", table_name="bill", schema=schema)
|
||||||
|
op.drop_table("bill", schema=schema)
|
||||||
|
# ### end Alembic commands ###
|
||||||
+58
@@ -0,0 +1,58 @@
|
|||||||
|
"""adding LegislatorSocialMedia.
|
||||||
|
|
||||||
|
Revision ID: 5cd7eee3549d
|
||||||
|
Revises: 83bfc8af92d8
|
||||||
|
Create Date: 2026-03-29 11:53:44.224799
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
from python.orm import DataScienceDevBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "5cd7eee3549d"
|
||||||
|
down_revision: str | None = "83bfc8af92d8"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
schema = DataScienceDevBase.schema_name
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table(
|
||||||
|
"legislator_social_media",
|
||||||
|
sa.Column("legislator_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("platform", sa.String(), nullable=False),
|
||||||
|
sa.Column("account_name", sa.String(), nullable=False),
|
||||||
|
sa.Column("url", sa.String(), nullable=True),
|
||||||
|
sa.Column("source", sa.String(), nullable=False),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||||
|
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["legislator_id"],
|
||||||
|
[f"{schema}.legislator.id"],
|
||||||
|
name=op.f("fk_legislator_social_media_legislator_id_legislator"),
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_legislator_social_media")),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_table("legislator_social_media", schema=schema)
|
||||||
|
# ### end Alembic commands ###
|
||||||
-93
@@ -1,93 +0,0 @@
|
|||||||
"""adding audiobook libreary metadata.
|
|
||||||
|
|
||||||
Revision ID: d7864d1ffc17
|
|
||||||
Revises: c8a794340928
|
|
||||||
Create Date: 2026-06-03 20:24:09.200837
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
from alembic import op
|
|
||||||
|
|
||||||
from python.orm import RichieBase
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from collections.abc import Sequence
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = "d7864d1ffc17"
|
|
||||||
down_revision: str | None = "c8a794340928"
|
|
||||||
branch_labels: str | Sequence[str] | None = None
|
|
||||||
depends_on: str | Sequence[str] | None = None
|
|
||||||
|
|
||||||
schema = RichieBase.schema_name
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
"""Upgrade."""
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
op.create_table(
|
|
||||||
"audiobook_author",
|
|
||||||
sa.Column("name", sa.String(), nullable=False),
|
|
||||||
sa.Column("id", sa.Integer(), nullable=False),
|
|
||||||
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
|
||||||
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
|
||||||
sa.PrimaryKeyConstraint("id", name=op.f("pk_audiobook_author")),
|
|
||||||
sa.UniqueConstraint("name", name=op.f("uq_audiobook_author_name")),
|
|
||||||
schema=schema,
|
|
||||||
)
|
|
||||||
op.create_table(
|
|
||||||
"audiobook_series",
|
|
||||||
sa.Column("name", sa.String(), nullable=False),
|
|
||||||
sa.Column("author_id", sa.Integer(), nullable=False),
|
|
||||||
sa.Column("id", sa.Integer(), nullable=False),
|
|
||||||
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
|
||||||
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
|
||||||
sa.ForeignKeyConstraint(
|
|
||||||
["author_id"],
|
|
||||||
[f"{schema}.audiobook_author.id"],
|
|
||||||
name=op.f("fk_audiobook_series_author_id_audiobook_author"),
|
|
||||||
ondelete="CASCADE",
|
|
||||||
),
|
|
||||||
sa.PrimaryKeyConstraint("id", name=op.f("pk_audiobook_series")),
|
|
||||||
sa.UniqueConstraint("author_id", "name", name=op.f("uq_audiobook_series_author_id")),
|
|
||||||
schema=schema,
|
|
||||||
)
|
|
||||||
op.create_table(
|
|
||||||
"audiobook",
|
|
||||||
sa.Column("title", sa.String(), nullable=False),
|
|
||||||
sa.Column("author_id", sa.Integer(), nullable=False),
|
|
||||||
sa.Column("series_id", sa.Integer(), nullable=True),
|
|
||||||
sa.Column("series_index", sa.Integer(), nullable=False),
|
|
||||||
sa.Column("id", sa.Integer(), nullable=False),
|
|
||||||
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
|
||||||
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
|
||||||
sa.ForeignKeyConstraint(
|
|
||||||
["author_id"],
|
|
||||||
[f"{schema}.audiobook_author.id"],
|
|
||||||
name=op.f("fk_audiobook_author_id_audiobook_author"),
|
|
||||||
ondelete="CASCADE",
|
|
||||||
),
|
|
||||||
sa.ForeignKeyConstraint(
|
|
||||||
["series_id"],
|
|
||||||
[f"{schema}.audiobook_series.id"],
|
|
||||||
name=op.f("fk_audiobook_series_id_audiobook_series"),
|
|
||||||
ondelete="SET NULL",
|
|
||||||
),
|
|
||||||
sa.PrimaryKeyConstraint("id", name=op.f("pk_audiobook")),
|
|
||||||
schema=schema,
|
|
||||||
)
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
"""Downgrade."""
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
op.drop_table("audiobook", schema=schema)
|
|
||||||
op.drop_table("audiobook_series", schema=schema)
|
|
||||||
op.drop_table("audiobook_author", schema=schema)
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
-63
@@ -1,63 +0,0 @@
|
|||||||
"""updated series_index to float and added UniqueConstraint to audiobook and audiobook_author.
|
|
||||||
|
|
||||||
Revision ID: b3c60cc5beb5
|
|
||||||
Revises: d7864d1ffc17
|
|
||||||
Create Date: 2026-06-10 20:02:43.073725
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
from alembic import op
|
|
||||||
|
|
||||||
from python.orm import RichieBase
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from collections.abc import Sequence
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = "b3c60cc5beb5"
|
|
||||||
down_revision: str | None = "d7864d1ffc17"
|
|
||||||
branch_labels: str | Sequence[str] | None = None
|
|
||||||
depends_on: str | Sequence[str] | None = None
|
|
||||||
|
|
||||||
schema = RichieBase.schema_name
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
"""Upgrade."""
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
op.alter_column(
|
|
||||||
"audiobook",
|
|
||||||
"series_index",
|
|
||||||
existing_type=sa.INTEGER(),
|
|
||||||
type_=sa.Float(),
|
|
||||||
existing_nullable=False,
|
|
||||||
schema=schema,
|
|
||||||
)
|
|
||||||
op.create_unique_constraint(
|
|
||||||
op.f("uq_audiobook_author_id"),
|
|
||||||
"audiobook",
|
|
||||||
["author_id", "series_id", "title"],
|
|
||||||
schema=schema,
|
|
||||||
postgresql_nulls_not_distinct=True,
|
|
||||||
)
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
"""Downgrade."""
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
op.drop_constraint(op.f("uq_audiobook_author_id"), "audiobook", schema=schema, type_="unique")
|
|
||||||
op.alter_column(
|
|
||||||
"audiobook",
|
|
||||||
"series_index",
|
|
||||||
existing_type=sa.Float(),
|
|
||||||
type_=sa.INTEGER(),
|
|
||||||
existing_nullable=False,
|
|
||||||
schema=schema,
|
|
||||||
)
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
+100
@@ -0,0 +1,100 @@
|
|||||||
|
"""seprating signal_bot database.
|
||||||
|
|
||||||
|
Revision ID: 6eaf696e07a5
|
||||||
|
Revises:
|
||||||
|
Create Date: 2026-03-17 21:35:37.612672
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
from python.orm import SignalBotBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "6eaf696e07a5"
|
||||||
|
down_revision: str | None = None
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
schema = SignalBotBase.schema_name
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table(
|
||||||
|
"dead_letter_message",
|
||||||
|
sa.Column("source", sa.String(), nullable=False),
|
||||||
|
sa.Column("message", sa.Text(), nullable=False),
|
||||||
|
sa.Column("received_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"status", postgresql.ENUM("UNPROCESSED", "PROCESSED", name="message_status", schema=schema), nullable=False
|
||||||
|
),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||||
|
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_dead_letter_message")),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"role",
|
||||||
|
sa.Column("name", sa.String(length=50), nullable=False),
|
||||||
|
sa.Column("id", sa.SmallInteger(), nullable=False),
|
||||||
|
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||||
|
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_role")),
|
||||||
|
sa.UniqueConstraint("name", name=op.f("uq_role_name")),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"signal_device",
|
||||||
|
sa.Column("phone_number", sa.String(length=50), nullable=False),
|
||||||
|
sa.Column("safety_number", sa.String(), nullable=True),
|
||||||
|
sa.Column(
|
||||||
|
"trust_level",
|
||||||
|
postgresql.ENUM("VERIFIED", "UNVERIFIED", "BLOCKED", name="trust_level", schema=schema),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("last_seen", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||||
|
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_signal_device")),
|
||||||
|
sa.UniqueConstraint("phone_number", name=op.f("uq_signal_device_phone_number")),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"device_role",
|
||||||
|
sa.Column("device_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("role_id", sa.SmallInteger(), nullable=False),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||||
|
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["device_id"], [f"{schema}.signal_device.id"], name=op.f("fk_device_role_device_id_signal_device")
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(["role_id"], [f"{schema}.role.id"], name=op.f("fk_device_role_role_id_role")),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_device_role")),
|
||||||
|
sa.UniqueConstraint("device_id", "role_id", name="uq_device_role_device_role"),
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_table("device_role", schema=schema)
|
||||||
|
op.drop_table("signal_device", schema=schema)
|
||||||
|
op.drop_table("role", schema=schema)
|
||||||
|
op.drop_table("dead_letter_message", schema=schema)
|
||||||
|
# ### end Alembic commands ###
|
||||||
@@ -0,0 +1,72 @@
|
|||||||
|
"""test.
|
||||||
|
|
||||||
|
Revision ID: 66bdd532bcab
|
||||||
|
Revises: 6eaf696e07a5
|
||||||
|
Create Date: 2026-03-18 19:21:14.561568
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
from python.orm import SignalBotBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "66bdd532bcab"
|
||||||
|
down_revision: str | None = "6eaf696e07a5"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
schema = SignalBotBase.schema_name
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.alter_column(
|
||||||
|
"dead_letter_message",
|
||||||
|
"status",
|
||||||
|
existing_type=postgresql.ENUM("UNPROCESSED", "PROCESSED", name="message_status", schema=schema),
|
||||||
|
type_=sa.Enum("UNPROCESSED", "PROCESSED", name="message_status", native_enum=False),
|
||||||
|
existing_nullable=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.alter_column(
|
||||||
|
"signal_device",
|
||||||
|
"trust_level",
|
||||||
|
existing_type=postgresql.ENUM("VERIFIED", "UNVERIFIED", "BLOCKED", name="trust_level", schema=schema),
|
||||||
|
type_=sa.Enum("VERIFIED", "UNVERIFIED", "BLOCKED", name="trust_level", native_enum=False),
|
||||||
|
existing_nullable=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.alter_column(
|
||||||
|
"signal_device",
|
||||||
|
"trust_level",
|
||||||
|
existing_type=sa.Enum("VERIFIED", "UNVERIFIED", "BLOCKED", name="trust_level", native_enum=False),
|
||||||
|
type_=postgresql.ENUM("VERIFIED", "UNVERIFIED", "BLOCKED", name="trust_level", schema=schema),
|
||||||
|
existing_nullable=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
op.alter_column(
|
||||||
|
"dead_letter_message",
|
||||||
|
"status",
|
||||||
|
existing_type=sa.Enum("UNPROCESSED", "PROCESSED", name="message_status", native_enum=False),
|
||||||
|
type_=postgresql.ENUM("UNPROCESSED", "PROCESSED", name="message_status", schema=schema),
|
||||||
|
existing_nullable=False,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
@@ -1,15 +1,11 @@
|
|||||||
"""FastAPI dependencies."""
|
"""FastAPI dependencies."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from collections.abc import Iterator
|
||||||
|
from typing import Annotated
|
||||||
from typing import TYPE_CHECKING, Annotated
|
|
||||||
|
|
||||||
from fastapi import Depends, Request
|
from fastapi import Depends, Request
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from collections.abc import Iterator
|
|
||||||
|
|
||||||
|
|
||||||
def get_db(request: Request) -> Iterator[Session]:
|
def get_db(request: Request) -> Iterator[Session]:
|
||||||
"""Get database session from app state."""
|
"""Get database session from app state."""
|
||||||
|
|||||||
+2
-6
@@ -1,10 +1,9 @@
|
|||||||
"""FastAPI interface for Contact database."""
|
"""FastAPI interface for Contact database."""
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import TYPE_CHECKING, Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
import typer
|
import typer
|
||||||
import uvicorn
|
import uvicorn
|
||||||
@@ -15,9 +14,6 @@ from python.api.routers import contact_router, views_router
|
|||||||
from python.common import configure_logger
|
from python.common import configure_logger
|
||||||
from python.orm.common import get_postgres_engine
|
from python.orm.common import get_postgres_engine
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from collections.abc import AsyncIterator
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,13 +1,9 @@
|
|||||||
"""Middleware for the FastAPI application."""
|
"""Middleware for the FastAPI application."""
|
||||||
|
|
||||||
from compression import zstd
|
from compression import zstd
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
||||||
from starlette.responses import Response
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import Response
|
||||||
|
|
||||||
MINIMUM_RESPONSE_SIZE = 500
|
MINIMUM_RESPONSE_SIZE = 500
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from pydantic import BaseModel
|
|||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
from python.api.dependencies import DbSession # noqa: TC001 this is a FastAPI needed at runtime
|
from python.api.dependencies import DbSession
|
||||||
from python.orm.richie.contact import Contact, ContactRelationship, Need, RelationshipType
|
from python.orm.richie.contact import Contact, ContactRelationship, Need, RelationshipType
|
||||||
|
|
||||||
TEMPLATES_DIR = Path(__file__).parent.parent / "templates"
|
TEMPLATES_DIR = Path(__file__).parent.parent / "templates"
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from fastapi.templating import Jinja2Templates
|
|||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session, selectinload
|
from sqlalchemy.orm import Session, selectinload
|
||||||
|
|
||||||
from python.api.dependencies import DbSession # noqa: TC001 this is a FastAPI needed at runtime
|
from python.api.dependencies import DbSession
|
||||||
from python.orm.richie.contact import Contact, ContactRelationship, Need, RelationshipType
|
from python.orm.richie.contact import Contact, ContactRelationship, Need, RelationshipType
|
||||||
|
|
||||||
TEMPLATES_DIR = Path(__file__).parent.parent / "templates"
|
TEMPLATES_DIR = Path(__file__).parent.parent / "templates"
|
||||||
|
|||||||
@@ -0,0 +1,3 @@
|
|||||||
|
"""Data science CLI tools."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
@@ -0,0 +1,613 @@
|
|||||||
|
"""Ingestion pipeline for loading congress data from unitedstates/congress JSON files.
|
||||||
|
|
||||||
|
Loads legislators, bills, votes, vote records, and bill text into the data_science_dev database.
|
||||||
|
Expects the parent directory to contain congress-tracker/ and congress-legislators/ as siblings.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
ingest-congress /path/to/parent/
|
||||||
|
ingest-congress /path/to/parent/ --congress 118
|
||||||
|
ingest-congress /path/to/parent/ --congress 118 --only bills
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path # noqa: TC003 needed at runtime for typer CLI argument
|
||||||
|
from typing import TYPE_CHECKING, Annotated
|
||||||
|
|
||||||
|
import orjson
|
||||||
|
import typer
|
||||||
|
import yaml
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from python.common import configure_logger
|
||||||
|
from python.orm.common import get_postgres_engine
|
||||||
|
from python.orm.data_science_dev.congress import Bill, BillText, Legislator, LegislatorSocialMedia, Vote, VoteRecord
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Iterator
|
||||||
|
|
||||||
|
from sqlalchemy.engine import Engine
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
BATCH_SIZE = 10_000
|
||||||
|
|
||||||
|
app = typer.Typer(help="Ingest unitedstates/congress data into data_science_dev.")
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def main(
|
||||||
|
parent_dir: Annotated[
|
||||||
|
Path,
|
||||||
|
typer.Argument(help="Parent directory containing congress-tracker/ and congress-legislators/"),
|
||||||
|
],
|
||||||
|
congress: Annotated[int | None, typer.Option(help="Only ingest a specific congress number")] = None,
|
||||||
|
only: Annotated[
|
||||||
|
str | None,
|
||||||
|
typer.Option(help="Only run a specific step: legislators, social-media, bills, votes, bill-text"),
|
||||||
|
] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Ingest congress data from unitedstates/congress JSON files."""
|
||||||
|
configure_logger(level="INFO")
|
||||||
|
|
||||||
|
data_dir = parent_dir / "congress-tracker/congress/data/"
|
||||||
|
legislators_dir = parent_dir / "congress-legislators"
|
||||||
|
|
||||||
|
if not data_dir.is_dir():
|
||||||
|
typer.echo(f"Expected congress-tracker/ directory: {data_dir}", err=True)
|
||||||
|
raise typer.Exit(code=1)
|
||||||
|
|
||||||
|
if not legislators_dir.is_dir():
|
||||||
|
typer.echo(f"Expected congress-legislators/ directory: {legislators_dir}", err=True)
|
||||||
|
raise typer.Exit(code=1)
|
||||||
|
|
||||||
|
engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
|
||||||
|
|
||||||
|
congress_dirs = _resolve_congress_dirs(data_dir, congress)
|
||||||
|
if not congress_dirs:
|
||||||
|
typer.echo("No congress directories found.", err=True)
|
||||||
|
raise typer.Exit(code=1)
|
||||||
|
|
||||||
|
logger.info("Found %d congress directories to process", len(congress_dirs))
|
||||||
|
|
||||||
|
steps: dict[str, tuple] = {
|
||||||
|
"legislators": (ingest_legislators, (engine, legislators_dir)),
|
||||||
|
"legislators-social-media": (ingest_social_media, (engine, legislators_dir)),
|
||||||
|
"bills": (ingest_bills, (engine, congress_dirs)),
|
||||||
|
"votes": (ingest_votes, (engine, congress_dirs)),
|
||||||
|
"bill-text": (ingest_bill_text, (engine, congress_dirs)),
|
||||||
|
}
|
||||||
|
|
||||||
|
if only:
|
||||||
|
if only not in steps:
|
||||||
|
typer.echo(f"Unknown step: {only}. Choose from: {', '.join(steps)}", err=True)
|
||||||
|
raise typer.Exit(code=1)
|
||||||
|
steps = {only: steps[only]}
|
||||||
|
|
||||||
|
for step_name, (step_func, step_args) in steps.items():
|
||||||
|
logger.info("=== Starting step: %s ===", step_name)
|
||||||
|
step_func(*step_args)
|
||||||
|
logger.info("=== Finished step: %s ===", step_name)
|
||||||
|
|
||||||
|
logger.info("ingest-congress done")
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_congress_dirs(data_dir: Path, congress: int | None) -> list[Path]:
|
||||||
|
"""Find congress number directories under data_dir."""
|
||||||
|
if congress is not None:
|
||||||
|
target = data_dir / str(congress)
|
||||||
|
return [target] if target.is_dir() else []
|
||||||
|
return sorted(path for path in data_dir.iterdir() if path.is_dir() and path.name.isdigit())
|
||||||
|
|
||||||
|
|
||||||
|
def _flush_batch(session: Session, batch: list[object], label: str) -> int:
|
||||||
|
"""Add a batch of ORM objects to the session and commit. Returns count added."""
|
||||||
|
if not batch:
|
||||||
|
return 0
|
||||||
|
session.add_all(batch)
|
||||||
|
session.commit()
|
||||||
|
count = len(batch)
|
||||||
|
logger.info("Committed %d %s", count, label)
|
||||||
|
batch.clear()
|
||||||
|
return count
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Legislators — loaded from congress-legislators YAML files
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def ingest_legislators(engine: Engine, legislators_dir: Path) -> None:
|
||||||
|
"""Load legislators from congress-legislators YAML files."""
|
||||||
|
legislators_data = _load_legislators_yaml(legislators_dir)
|
||||||
|
logger.info("Loaded %d legislators from YAML files", len(legislators_data))
|
||||||
|
|
||||||
|
with Session(engine) as session:
|
||||||
|
existing_legislators = {
|
||||||
|
legislator.bioguide_id: legislator for legislator in session.scalars(select(Legislator)).all()
|
||||||
|
}
|
||||||
|
logger.info("Found %d existing legislators in DB", len(existing_legislators))
|
||||||
|
|
||||||
|
total_inserted = 0
|
||||||
|
total_updated = 0
|
||||||
|
for entry in legislators_data:
|
||||||
|
bioguide_id = entry.get("id", {}).get("bioguide")
|
||||||
|
if not bioguide_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
fields = _parse_legislator(entry)
|
||||||
|
if existing := existing_legislators.get(bioguide_id):
|
||||||
|
changed = False
|
||||||
|
for field, value in fields.items():
|
||||||
|
if value is not None and getattr(existing, field) != value:
|
||||||
|
setattr(existing, field, value)
|
||||||
|
changed = True
|
||||||
|
if changed:
|
||||||
|
total_updated += 1
|
||||||
|
else:
|
||||||
|
session.add(Legislator(bioguide_id=bioguide_id, **fields))
|
||||||
|
total_inserted += 1
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
logger.info("Inserted %d new legislators, updated %d existing", total_inserted, total_updated)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_legislators_yaml(legislators_dir: Path) -> list[dict]:
|
||||||
|
"""Load and combine legislators-current.yaml and legislators-historical.yaml."""
|
||||||
|
legislators: list[dict] = []
|
||||||
|
for filename in ("legislators-current.yaml", "legislators-historical.yaml"):
|
||||||
|
path = legislators_dir / filename
|
||||||
|
if not path.exists():
|
||||||
|
logger.warning("Legislators file not found: %s", path)
|
||||||
|
continue
|
||||||
|
with path.open() as file:
|
||||||
|
data = yaml.safe_load(file)
|
||||||
|
if isinstance(data, list):
|
||||||
|
legislators.extend(data)
|
||||||
|
return legislators
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_legislator(entry: dict) -> dict:
|
||||||
|
"""Extract Legislator fields from a congress-legislators YAML entry."""
|
||||||
|
ids = entry.get("id", {})
|
||||||
|
name = entry.get("name", {})
|
||||||
|
bio = entry.get("bio", {})
|
||||||
|
terms = entry.get("terms", [])
|
||||||
|
latest_term = terms[-1] if terms else {}
|
||||||
|
|
||||||
|
fec_ids = ids.get("fec")
|
||||||
|
fec_ids_joined = ",".join(fec_ids) if isinstance(fec_ids, list) else fec_ids
|
||||||
|
|
||||||
|
chamber = latest_term.get("type")
|
||||||
|
chamber_normalized = {"rep": "House", "sen": "Senate"}.get(chamber, chamber)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"thomas_id": ids.get("thomas"),
|
||||||
|
"lis_id": ids.get("lis"),
|
||||||
|
"govtrack_id": ids.get("govtrack"),
|
||||||
|
"opensecrets_id": ids.get("opensecrets"),
|
||||||
|
"fec_ids": fec_ids_joined,
|
||||||
|
"first_name": name.get("first"),
|
||||||
|
"last_name": name.get("last"),
|
||||||
|
"official_full_name": name.get("official_full"),
|
||||||
|
"nickname": name.get("nickname"),
|
||||||
|
"birthday": bio.get("birthday"),
|
||||||
|
"gender": bio.get("gender"),
|
||||||
|
"current_party": latest_term.get("party"),
|
||||||
|
"current_state": latest_term.get("state"),
|
||||||
|
"current_district": latest_term.get("district"),
|
||||||
|
"current_chamber": chamber_normalized,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Social Media — loaded from legislators-social-media.yaml
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
SOCIAL_MEDIA_PLATFORMS = {
|
||||||
|
"twitter": "https://twitter.com/{account}",
|
||||||
|
"facebook": "https://facebook.com/{account}",
|
||||||
|
"youtube": "https://youtube.com/{account}",
|
||||||
|
"instagram": "https://instagram.com/{account}",
|
||||||
|
"mastodon": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def ingest_social_media(engine: Engine, legislators_dir: Path) -> None:
|
||||||
|
"""Load social media accounts from legislators-social-media.yaml."""
|
||||||
|
social_media_path = legislators_dir / "legislators-social-media.yaml"
|
||||||
|
if not social_media_path.exists():
|
||||||
|
logger.warning("Social media file not found: %s", social_media_path)
|
||||||
|
return
|
||||||
|
|
||||||
|
with social_media_path.open() as file:
|
||||||
|
social_media_data = yaml.safe_load(file)
|
||||||
|
|
||||||
|
if not isinstance(social_media_data, list):
|
||||||
|
logger.warning("Unexpected format in %s", social_media_path)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("Loaded %d entries from legislators-social-media.yaml", len(social_media_data))
|
||||||
|
|
||||||
|
with Session(engine) as session:
|
||||||
|
legislator_map = _build_legislator_map(session)
|
||||||
|
existing_accounts = {
|
||||||
|
(account.legislator_id, account.platform)
|
||||||
|
for account in session.scalars(select(LegislatorSocialMedia)).all()
|
||||||
|
}
|
||||||
|
logger.info("Found %d existing social media accounts in DB", len(existing_accounts))
|
||||||
|
|
||||||
|
total_inserted = 0
|
||||||
|
total_updated = 0
|
||||||
|
for entry in social_media_data:
|
||||||
|
bioguide_id = entry.get("id", {}).get("bioguide")
|
||||||
|
if not bioguide_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
legislator_id = legislator_map.get(bioguide_id)
|
||||||
|
if legislator_id is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
social = entry.get("social", {})
|
||||||
|
for platform, url_template in SOCIAL_MEDIA_PLATFORMS.items():
|
||||||
|
account_name = social.get(platform)
|
||||||
|
if not account_name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
url = url_template.format(account=account_name) if url_template else None
|
||||||
|
|
||||||
|
if (legislator_id, platform) in existing_accounts:
|
||||||
|
total_updated += 1
|
||||||
|
else:
|
||||||
|
session.add(
|
||||||
|
LegislatorSocialMedia(
|
||||||
|
legislator_id=legislator_id,
|
||||||
|
platform=platform,
|
||||||
|
account_name=str(account_name),
|
||||||
|
url=url,
|
||||||
|
source="https://github.com/unitedstates/congress-legislators",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
existing_accounts.add((legislator_id, platform))
|
||||||
|
total_inserted += 1
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
logger.info("Inserted %d new social media accounts, updated %d existing", total_inserted, total_updated)
|
||||||
|
|
||||||
|
|
||||||
|
def _iter_voters(position_group: object) -> Iterator[dict]:
|
||||||
|
"""Yield voter dicts from a vote position group (handles list, single dict, or string)."""
|
||||||
|
if isinstance(position_group, dict):
|
||||||
|
yield position_group
|
||||||
|
elif isinstance(position_group, list):
|
||||||
|
for voter in position_group:
|
||||||
|
if isinstance(voter, dict):
|
||||||
|
yield voter
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Bills
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def ingest_bills(engine: Engine, congress_dirs: list[Path]) -> None:
|
||||||
|
"""Load bill data.json files."""
|
||||||
|
with Session(engine) as session:
|
||||||
|
existing_bills = {(bill.congress, bill.bill_type, bill.number) for bill in session.scalars(select(Bill)).all()}
|
||||||
|
logger.info("Found %d existing bills in DB", len(existing_bills))
|
||||||
|
|
||||||
|
total_inserted = 0
|
||||||
|
batch: list[Bill] = []
|
||||||
|
for congress_dir in congress_dirs:
|
||||||
|
bills_dir = congress_dir / "bills"
|
||||||
|
if not bills_dir.is_dir():
|
||||||
|
continue
|
||||||
|
logger.info("Scanning bills from %s", congress_dir.name)
|
||||||
|
for bill_file in bills_dir.rglob("data.json"):
|
||||||
|
data = _read_json(bill_file)
|
||||||
|
if data is None:
|
||||||
|
continue
|
||||||
|
bill = _parse_bill(data, existing_bills)
|
||||||
|
if bill is not None:
|
||||||
|
batch.append(bill)
|
||||||
|
if len(batch) >= BATCH_SIZE:
|
||||||
|
total_inserted += _flush_batch(session, batch, "bills")
|
||||||
|
|
||||||
|
total_inserted += _flush_batch(session, batch, "bills")
|
||||||
|
logger.info("Inserted %d new bills total", total_inserted)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_bill(data: dict, existing_bills: set[tuple[int, str, int]]) -> Bill | None:
|
||||||
|
"""Parse a bill data.json dict into a Bill ORM object, skipping existing."""
|
||||||
|
raw_congress = data.get("congress")
|
||||||
|
bill_type = data.get("bill_type")
|
||||||
|
raw_number = data.get("number")
|
||||||
|
if raw_congress is None or bill_type is None or raw_number is None:
|
||||||
|
return None
|
||||||
|
congress = int(raw_congress)
|
||||||
|
number = int(raw_number)
|
||||||
|
if (congress, bill_type, number) in existing_bills:
|
||||||
|
return None
|
||||||
|
|
||||||
|
sponsor_bioguide = None
|
||||||
|
sponsor = data.get("sponsor")
|
||||||
|
if sponsor:
|
||||||
|
sponsor_bioguide = sponsor.get("bioguide_id")
|
||||||
|
|
||||||
|
return Bill(
|
||||||
|
congress=congress,
|
||||||
|
bill_type=bill_type,
|
||||||
|
number=number,
|
||||||
|
title=data.get("short_title") or data.get("official_title"),
|
||||||
|
title_short=data.get("short_title"),
|
||||||
|
official_title=data.get("official_title"),
|
||||||
|
status=data.get("status"),
|
||||||
|
status_at=data.get("status_at"),
|
||||||
|
sponsor_bioguide_id=sponsor_bioguide,
|
||||||
|
subjects_top_term=data.get("subjects_top_term"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Votes (and vote records)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def ingest_votes(engine: Engine, congress_dirs: list[Path]) -> None:
|
||||||
|
"""Load vote data.json files with their vote records."""
|
||||||
|
with Session(engine) as session:
|
||||||
|
legislator_map = _build_legislator_map(session)
|
||||||
|
logger.info("Loaded %d legislators into lookup map", len(legislator_map))
|
||||||
|
bill_map = _build_bill_map(session)
|
||||||
|
logger.info("Loaded %d bills into lookup map", len(bill_map))
|
||||||
|
existing_votes = {
|
||||||
|
(vote.congress, vote.chamber, vote.session, vote.number) for vote in session.scalars(select(Vote)).all()
|
||||||
|
}
|
||||||
|
logger.info("Found %d existing votes in DB", len(existing_votes))
|
||||||
|
|
||||||
|
total_inserted = 0
|
||||||
|
batch: list[Vote] = []
|
||||||
|
for congress_dir in congress_dirs:
|
||||||
|
votes_dir = congress_dir / "votes"
|
||||||
|
if not votes_dir.is_dir():
|
||||||
|
continue
|
||||||
|
logger.info("Scanning votes from %s", congress_dir.name)
|
||||||
|
for vote_file in votes_dir.rglob("data.json"):
|
||||||
|
data = _read_json(vote_file)
|
||||||
|
if data is None:
|
||||||
|
continue
|
||||||
|
vote = _parse_vote(data, legislator_map, bill_map, existing_votes)
|
||||||
|
if vote is not None:
|
||||||
|
batch.append(vote)
|
||||||
|
if len(batch) >= BATCH_SIZE:
|
||||||
|
total_inserted += _flush_batch(session, batch, "votes")
|
||||||
|
|
||||||
|
total_inserted += _flush_batch(session, batch, "votes")
|
||||||
|
logger.info("Inserted %d new votes total", total_inserted)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_legislator_map(session: Session) -> dict[str, int]:
|
||||||
|
"""Build a mapping of bioguide_id -> legislator.id."""
|
||||||
|
return {legislator.bioguide_id: legislator.id for legislator in session.scalars(select(Legislator)).all()}
|
||||||
|
|
||||||
|
|
||||||
|
def _build_bill_map(session: Session) -> dict[tuple[int, str, int], int]:
|
||||||
|
"""Build a mapping of (congress, bill_type, number) -> bill.id."""
|
||||||
|
return {(bill.congress, bill.bill_type, bill.number): bill.id for bill in session.scalars(select(Bill)).all()}
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_vote(
|
||||||
|
data: dict,
|
||||||
|
legislator_map: dict[str, int],
|
||||||
|
bill_map: dict[tuple[int, str, int], int],
|
||||||
|
existing_votes: set[tuple[int, str, int, int]],
|
||||||
|
) -> Vote | None:
|
||||||
|
"""Parse a vote data.json dict into a Vote ORM object with records."""
|
||||||
|
raw_congress = data.get("congress")
|
||||||
|
chamber = data.get("chamber")
|
||||||
|
raw_number = data.get("number")
|
||||||
|
vote_date = data.get("date")
|
||||||
|
if raw_congress is None or chamber is None or raw_number is None or vote_date is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
raw_session = data.get("session")
|
||||||
|
if raw_session is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
congress = int(raw_congress)
|
||||||
|
number = int(raw_number)
|
||||||
|
session_number = int(raw_session)
|
||||||
|
|
||||||
|
# Normalize chamber from "h"/"s" to "House"/"Senate"
|
||||||
|
chamber_normalized = {"h": "House", "s": "Senate"}.get(chamber, chamber)
|
||||||
|
|
||||||
|
if (congress, chamber_normalized, session_number, number) in existing_votes:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Resolve linked bill
|
||||||
|
bill_id = None
|
||||||
|
bill_ref = data.get("bill")
|
||||||
|
if bill_ref:
|
||||||
|
bill_key = (
|
||||||
|
int(bill_ref.get("congress", congress)),
|
||||||
|
bill_ref.get("type"),
|
||||||
|
int(bill_ref.get("number", 0)),
|
||||||
|
)
|
||||||
|
bill_id = bill_map.get(bill_key)
|
||||||
|
|
||||||
|
raw_votes = data.get("votes", {})
|
||||||
|
vote_counts = _count_votes(raw_votes)
|
||||||
|
vote_records = _build_vote_records(raw_votes, legislator_map)
|
||||||
|
|
||||||
|
return Vote(
|
||||||
|
congress=congress,
|
||||||
|
chamber=chamber_normalized,
|
||||||
|
session=session_number,
|
||||||
|
number=number,
|
||||||
|
vote_type=data.get("type"),
|
||||||
|
question=data.get("question"),
|
||||||
|
result=data.get("result"),
|
||||||
|
result_text=data.get("result_text"),
|
||||||
|
vote_date=vote_date[:10] if isinstance(vote_date, str) else vote_date,
|
||||||
|
bill_id=bill_id,
|
||||||
|
vote_records=vote_records,
|
||||||
|
**vote_counts,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _count_votes(raw_votes: dict) -> dict[str, int]:
|
||||||
|
"""Count voters per position category, correctly handling dict and list formats."""
|
||||||
|
yea_count = 0
|
||||||
|
nay_count = 0
|
||||||
|
not_voting_count = 0
|
||||||
|
present_count = 0
|
||||||
|
|
||||||
|
for position, position_group in raw_votes.items():
|
||||||
|
voter_count = sum(1 for _ in _iter_voters(position_group))
|
||||||
|
if position in ("Yea", "Aye"):
|
||||||
|
yea_count += voter_count
|
||||||
|
elif position in ("Nay", "No"):
|
||||||
|
nay_count += voter_count
|
||||||
|
elif position == "Not Voting":
|
||||||
|
not_voting_count += voter_count
|
||||||
|
elif position == "Present":
|
||||||
|
present_count += voter_count
|
||||||
|
|
||||||
|
return {
|
||||||
|
"yea_count": yea_count,
|
||||||
|
"nay_count": nay_count,
|
||||||
|
"not_voting_count": not_voting_count,
|
||||||
|
"present_count": present_count,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _build_vote_records(raw_votes: dict, legislator_map: dict[str, int]) -> list[VoteRecord]:
|
||||||
|
"""Build VoteRecord objects from raw vote data."""
|
||||||
|
records: list[VoteRecord] = []
|
||||||
|
for position, position_group in raw_votes.items():
|
||||||
|
for voter in _iter_voters(position_group):
|
||||||
|
bioguide_id = voter.get("id")
|
||||||
|
if not bioguide_id:
|
||||||
|
continue
|
||||||
|
legislator_id = legislator_map.get(bioguide_id)
|
||||||
|
if legislator_id is None:
|
||||||
|
continue
|
||||||
|
records.append(
|
||||||
|
VoteRecord(
|
||||||
|
legislator_id=legislator_id,
|
||||||
|
position=position,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return records
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Bill Text
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def ingest_bill_text(engine: Engine, congress_dirs: list[Path]) -> None:
|
||||||
|
"""Load bill text from text-versions directories."""
|
||||||
|
with Session(engine) as session:
|
||||||
|
bill_map = _build_bill_map(session)
|
||||||
|
logger.info("Loaded %d bills into lookup map", len(bill_map))
|
||||||
|
existing_bill_texts = {
|
||||||
|
(bill_text.bill_id, bill_text.version_code) for bill_text in session.scalars(select(BillText)).all()
|
||||||
|
}
|
||||||
|
logger.info("Found %d existing bill text versions in DB", len(existing_bill_texts))
|
||||||
|
|
||||||
|
total_inserted = 0
|
||||||
|
batch: list[BillText] = []
|
||||||
|
for congress_dir in congress_dirs:
|
||||||
|
logger.info("Scanning bill texts from %s", congress_dir.name)
|
||||||
|
for bill_text in _iter_bill_texts(congress_dir, bill_map, existing_bill_texts):
|
||||||
|
batch.append(bill_text)
|
||||||
|
if len(batch) >= BATCH_SIZE:
|
||||||
|
total_inserted += _flush_batch(session, batch, "bill texts")
|
||||||
|
|
||||||
|
total_inserted += _flush_batch(session, batch, "bill texts")
|
||||||
|
logger.info("Inserted %d new bill text versions total", total_inserted)
|
||||||
|
|
||||||
|
|
||||||
|
def _iter_bill_texts(
|
||||||
|
congress_dir: Path,
|
||||||
|
bill_map: dict[tuple[int, str, int], int],
|
||||||
|
existing_bill_texts: set[tuple[int, str]],
|
||||||
|
) -> Iterator[BillText]:
|
||||||
|
"""Yield BillText objects for a single congress directory, skipping existing."""
|
||||||
|
bills_dir = congress_dir / "bills"
|
||||||
|
if not bills_dir.is_dir():
|
||||||
|
return
|
||||||
|
|
||||||
|
for bill_dir in bills_dir.rglob("text-versions"):
|
||||||
|
if not bill_dir.is_dir():
|
||||||
|
continue
|
||||||
|
bill_key = _bill_key_from_dir(bill_dir.parent, congress_dir)
|
||||||
|
if bill_key is None:
|
||||||
|
continue
|
||||||
|
bill_id = bill_map.get(bill_key)
|
||||||
|
if bill_id is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for version_dir in sorted(bill_dir.iterdir()):
|
||||||
|
if not version_dir.is_dir():
|
||||||
|
continue
|
||||||
|
if (bill_id, version_dir.name) in existing_bill_texts:
|
||||||
|
continue
|
||||||
|
text_content = _read_bill_text(version_dir)
|
||||||
|
version_data = _read_json(version_dir / "data.json")
|
||||||
|
yield BillText(
|
||||||
|
bill_id=bill_id,
|
||||||
|
version_code=version_dir.name,
|
||||||
|
version_name=version_data.get("version_name") if version_data else None,
|
||||||
|
date=version_data.get("issued_on") if version_data else None,
|
||||||
|
text_content=text_content,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _bill_key_from_dir(bill_dir: Path, congress_dir: Path) -> tuple[int, str, int] | None:
|
||||||
|
"""Extract (congress, bill_type, number) from directory structure."""
|
||||||
|
congress = int(congress_dir.name)
|
||||||
|
bill_type = bill_dir.parent.name
|
||||||
|
name = bill_dir.name
|
||||||
|
# Directory name is like "hr3590" — strip the type prefix to get the number
|
||||||
|
number_str = name[len(bill_type) :]
|
||||||
|
if not number_str.isdigit():
|
||||||
|
return None
|
||||||
|
return (congress, bill_type, int(number_str))
|
||||||
|
|
||||||
|
|
||||||
|
def _read_bill_text(version_dir: Path) -> str | None:
|
||||||
|
"""Read bill text from a version directory, preferring .txt over .xml."""
|
||||||
|
for extension in ("txt", "htm", "html", "xml"):
|
||||||
|
candidates = list(version_dir.glob(f"document.{extension}"))
|
||||||
|
if not candidates:
|
||||||
|
candidates = list(version_dir.glob(f"*.{extension}"))
|
||||||
|
if candidates:
|
||||||
|
try:
|
||||||
|
return candidates[0].read_text(encoding="utf-8")
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to read %s", candidates[0])
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _read_json(path: Path) -> dict | None:
|
||||||
|
"""Read and parse a JSON file, returning None on failure."""
|
||||||
|
try:
|
||||||
|
return orjson.loads(path.read_bytes())
|
||||||
|
except FileNotFoundError:
|
||||||
|
return None
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to parse %s", path)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
app()
|
||||||
@@ -0,0 +1,247 @@
|
|||||||
|
"""Ingestion pipeline for loading JSONL post files into the weekly-partitioned posts table.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
ingest-posts /path/to/files/
|
||||||
|
ingest-posts /path/to/single_file.jsonl
|
||||||
|
ingest-posts /data/dir/ --workers 4 --batch-size 5000
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from pathlib import Path # noqa: TC003 this is needed for typer
|
||||||
|
from typing import TYPE_CHECKING, Annotated
|
||||||
|
|
||||||
|
import orjson
|
||||||
|
import psycopg
|
||||||
|
import typer
|
||||||
|
|
||||||
|
from python.common import configure_logger
|
||||||
|
from python.orm.common import get_connection_info
|
||||||
|
from python.parallelize import parallelize_process
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Iterator
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
app = typer.Typer(help="Ingest JSONL post files into the partitioned posts table.")
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def main(
|
||||||
|
path: Annotated[Path, typer.Argument(help="Directory containing JSONL files, or a single JSONL file")],
|
||||||
|
batch_size: Annotated[int, typer.Option(help="Rows per INSERT batch")] = 10000,
|
||||||
|
workers: Annotated[int, typer.Option(help="Parallel workers for multi-file ingestion")] = 4,
|
||||||
|
pattern: Annotated[str, typer.Option(help="Glob pattern for JSONL files")] = "*.jsonl",
|
||||||
|
) -> None:
|
||||||
|
"""Ingest JSONL post files into the weekly-partitioned posts table."""
|
||||||
|
configure_logger(level="INFO")
|
||||||
|
|
||||||
|
logger.info("starting ingest-posts")
|
||||||
|
logger.info("path=%s batch_size=%d workers=%d pattern=%s", path, batch_size, workers, pattern)
|
||||||
|
if path.is_file():
|
||||||
|
ingest_file(path, batch_size=batch_size)
|
||||||
|
elif path.is_dir():
|
||||||
|
ingest_directory(path, batch_size=batch_size, max_workers=workers, pattern=pattern)
|
||||||
|
else:
|
||||||
|
typer.echo(f"Path does not exist: {path}", err=True)
|
||||||
|
raise typer.Exit(code=1)
|
||||||
|
|
||||||
|
logger.info("ingest-posts done")
|
||||||
|
|
||||||
|
|
||||||
|
def ingest_directory(
|
||||||
|
directory: Path,
|
||||||
|
*,
|
||||||
|
batch_size: int,
|
||||||
|
max_workers: int,
|
||||||
|
pattern: str = "*.jsonl",
|
||||||
|
) -> None:
|
||||||
|
"""Ingest all JSONL files in a directory using parallel workers."""
|
||||||
|
files = sorted(directory.glob(pattern))
|
||||||
|
if not files:
|
||||||
|
logger.warning("No JSONL files found in %s", directory)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("Found %d JSONL files to ingest", len(files))
|
||||||
|
|
||||||
|
kwargs_list = [{"path": fp, "batch_size": batch_size} for fp in files]
|
||||||
|
parallelize_process(ingest_file, kwargs_list, max_workers=max_workers)
|
||||||
|
|
||||||
|
|
||||||
|
SCHEMA = "main"
|
||||||
|
|
||||||
|
COLUMNS = (
|
||||||
|
"post_id",
|
||||||
|
"user_id",
|
||||||
|
"instance",
|
||||||
|
"date",
|
||||||
|
"text",
|
||||||
|
"langs",
|
||||||
|
"like_count",
|
||||||
|
"reply_count",
|
||||||
|
"repost_count",
|
||||||
|
"reply_to",
|
||||||
|
"replied_author",
|
||||||
|
"thread_root",
|
||||||
|
"thread_root_author",
|
||||||
|
"repost_from",
|
||||||
|
"reposted_author",
|
||||||
|
"quotes",
|
||||||
|
"quoted_author",
|
||||||
|
"labels",
|
||||||
|
"sent_label",
|
||||||
|
"sent_score",
|
||||||
|
)
|
||||||
|
|
||||||
|
INSERT_FROM_STAGING = f"""
|
||||||
|
INSERT INTO {SCHEMA}.posts ({", ".join(COLUMNS)})
|
||||||
|
SELECT {", ".join(COLUMNS)} FROM pg_temp.staging
|
||||||
|
ON CONFLICT (post_id, date) DO NOTHING
|
||||||
|
""" # noqa: S608
|
||||||
|
|
||||||
|
FAILED_INSERT = f"""
|
||||||
|
INSERT INTO {SCHEMA}.failed_ingestion (raw_line, error)
|
||||||
|
VALUES (%(raw_line)s, %(error)s)
|
||||||
|
""" # noqa: S608
|
||||||
|
|
||||||
|
|
||||||
|
def get_psycopg_connection() -> psycopg.Connection:
|
||||||
|
"""Create a raw psycopg3 connection from environment variables."""
|
||||||
|
database, host, port, username, password = get_connection_info("DATA_SCIENCE_DEV")
|
||||||
|
return psycopg.connect(
|
||||||
|
dbname=database,
|
||||||
|
host=host,
|
||||||
|
port=int(port),
|
||||||
|
user=username,
|
||||||
|
password=password,
|
||||||
|
autocommit=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def ingest_file(path: Path, *, batch_size: int) -> None:
|
||||||
|
"""Ingest a single JSONL file into the posts table."""
|
||||||
|
log_trigger = max(100_000 // batch_size, 1)
|
||||||
|
failed_lines: list[dict] = []
|
||||||
|
try:
|
||||||
|
with get_psycopg_connection() as connection:
|
||||||
|
for index, batch in enumerate(read_jsonl_batches(path, batch_size, failed_lines), 1):
|
||||||
|
ingest_batch(connection, batch)
|
||||||
|
if index % log_trigger == 0:
|
||||||
|
logger.info("Ingested %d batches (%d rows) from %s", index, index * batch_size, path)
|
||||||
|
|
||||||
|
if failed_lines:
|
||||||
|
logger.warning("Recording %d malformed lines from %s", len(failed_lines), path.name)
|
||||||
|
with connection.cursor() as cursor:
|
||||||
|
cursor.executemany(FAILED_INSERT, failed_lines)
|
||||||
|
connection.commit()
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to ingest file: %s", path)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def ingest_batch(connection: psycopg.Connection, batch: list[dict]) -> None:
|
||||||
|
"""COPY batch into a temp staging table, then INSERT ... ON CONFLICT into posts."""
|
||||||
|
if not batch:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
with connection.cursor() as cursor:
|
||||||
|
cursor.execute(f"""
|
||||||
|
CREATE TEMP TABLE IF NOT EXISTS staging
|
||||||
|
(LIKE {SCHEMA}.posts INCLUDING DEFAULTS)
|
||||||
|
ON COMMIT DELETE ROWS
|
||||||
|
""")
|
||||||
|
cursor.execute("TRUNCATE pg_temp.staging")
|
||||||
|
|
||||||
|
with cursor.copy(f"COPY pg_temp.staging ({', '.join(COLUMNS)}) FROM STDIN") as copy:
|
||||||
|
for row in batch:
|
||||||
|
copy.write_row(tuple(row.get(column) for column in COLUMNS))
|
||||||
|
|
||||||
|
cursor.execute(INSERT_FROM_STAGING)
|
||||||
|
connection.commit()
|
||||||
|
except Exception as error:
|
||||||
|
connection.rollback()
|
||||||
|
|
||||||
|
if len(batch) == 1:
|
||||||
|
logger.exception("Skipping bad row post_id=%s", batch[0].get("post_id"))
|
||||||
|
with connection.cursor() as cursor:
|
||||||
|
cursor.execute(
|
||||||
|
FAILED_INSERT,
|
||||||
|
{
|
||||||
|
"raw_line": orjson.dumps(batch[0], default=str).decode(),
|
||||||
|
"error": str(error),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
connection.commit()
|
||||||
|
return
|
||||||
|
|
||||||
|
midpoint = len(batch) // 2
|
||||||
|
ingest_batch(connection, batch[:midpoint])
|
||||||
|
ingest_batch(connection, batch[midpoint:])
|
||||||
|
|
||||||
|
|
||||||
|
def read_jsonl_batches(file_path: Path, batch_size: int, failed_lines: list[dict]) -> Iterator[list[dict]]:
|
||||||
|
"""Stream a JSONL file and yield batches of transformed rows."""
|
||||||
|
batch: list[dict] = []
|
||||||
|
with file_path.open("r", encoding="utf-8") as handle:
|
||||||
|
for raw_line in handle:
|
||||||
|
line = raw_line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
batch.extend(parse_line(line, file_path, failed_lines))
|
||||||
|
if len(batch) >= batch_size:
|
||||||
|
yield batch
|
||||||
|
batch = []
|
||||||
|
if batch:
|
||||||
|
yield batch
|
||||||
|
|
||||||
|
|
||||||
|
def parse_line(line: str, file_path: Path, failed_lines: list[dict]) -> Iterator[dict]:
|
||||||
|
"""Parse a JSONL line, handling concatenated JSON objects."""
|
||||||
|
try:
|
||||||
|
yield transform_row(orjson.loads(line))
|
||||||
|
except orjson.JSONDecodeError:
|
||||||
|
if "}{" not in line:
|
||||||
|
logger.warning("Skipping malformed line in %s: %s", file_path.name, line[:120])
|
||||||
|
failed_lines.append({"raw_line": line, "error": "malformed JSON"})
|
||||||
|
return
|
||||||
|
fragments = line.replace("}{", "}\n{").split("\n")
|
||||||
|
for fragment in fragments:
|
||||||
|
try:
|
||||||
|
yield transform_row(orjson.loads(fragment))
|
||||||
|
except (orjson.JSONDecodeError, KeyError, ValueError) as error:
|
||||||
|
logger.warning("Skipping malformed fragment in %s: %s", file_path.name, fragment[:120])
|
||||||
|
failed_lines.append({"raw_line": fragment, "error": str(error)})
|
||||||
|
except Exception as error:
|
||||||
|
logger.exception("Skipping bad row in %s: %s", file_path.name, line[:120])
|
||||||
|
failed_lines.append({"raw_line": line, "error": str(error)})
|
||||||
|
|
||||||
|
|
||||||
|
def transform_row(raw: dict) -> dict:
|
||||||
|
"""Transform a raw JSONL row into a dict matching the Posts table columns."""
|
||||||
|
raw["date"] = parse_date(raw["date"])
|
||||||
|
if raw.get("langs") is not None:
|
||||||
|
raw["langs"] = orjson.dumps(raw["langs"])
|
||||||
|
if raw.get("text") is not None:
|
||||||
|
raw["text"] = raw["text"].replace("\x00", "")
|
||||||
|
return raw
|
||||||
|
|
||||||
|
|
||||||
|
def parse_date(raw_date: int) -> datetime:
|
||||||
|
"""Parse compact YYYYMMDDHHmm integer into a naive datetime (input is UTC by spec)."""
|
||||||
|
return datetime(
|
||||||
|
raw_date // 100000000,
|
||||||
|
(raw_date // 1000000) % 100,
|
||||||
|
(raw_date // 10000) % 100,
|
||||||
|
(raw_date // 100) % 100,
|
||||||
|
raw_date % 100,
|
||||||
|
tzinfo=UTC,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
app()
|
||||||
+29
-3
@@ -4,10 +4,12 @@ Usage:
|
|||||||
database <db_name> <command> [args...]
|
database <db_name> <command> [args...]
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
database van_inventory upgrade head
|
||||||
|
database van_inventory downgrade head-1
|
||||||
|
database van_inventory revision --autogenerate -m "add meals table"
|
||||||
|
database van_inventory check
|
||||||
database richie check
|
database richie check
|
||||||
database richie upgrade head
|
database richie upgrade head
|
||||||
database richie downgrade head-1
|
|
||||||
database richie revision --autogenerate -m "add meals table"
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -46,7 +48,10 @@ class DatabaseConfig:
|
|||||||
|
|
||||||
def alembic_config(self) -> Config:
|
def alembic_config(self) -> Config:
|
||||||
"""Build an alembic Config for this database."""
|
"""Build an alembic Config for this database."""
|
||||||
cfg = Config()
|
# Runtime import needed — Config is in TYPE_CHECKING for the return type annotation
|
||||||
|
from alembic.config import Config as AlembicConfig # noqa: PLC0415
|
||||||
|
|
||||||
|
cfg = AlembicConfig()
|
||||||
cfg.set_main_option("script_location", self.script_location)
|
cfg.set_main_option("script_location", self.script_location)
|
||||||
cfg.set_main_option("file_template", self.file_template)
|
cfg.set_main_option("file_template", self.file_template)
|
||||||
cfg.set_main_option("prepend_sys_path", ".")
|
cfg.set_main_option("prepend_sys_path", ".")
|
||||||
@@ -71,6 +76,27 @@ DATABASES: dict[str, DatabaseConfig] = {
|
|||||||
base_class_name="RichieBase",
|
base_class_name="RichieBase",
|
||||||
models_module="python.orm.richie",
|
models_module="python.orm.richie",
|
||||||
),
|
),
|
||||||
|
"van_inventory": DatabaseConfig(
|
||||||
|
env_prefix="VAN_INVENTORY",
|
||||||
|
version_location="python/alembic/van_inventory/versions",
|
||||||
|
base_module="python.orm.van_inventory.base",
|
||||||
|
base_class_name="VanInventoryBase",
|
||||||
|
models_module="python.orm.van_inventory.models",
|
||||||
|
),
|
||||||
|
"signal_bot": DatabaseConfig(
|
||||||
|
env_prefix="SIGNALBOT",
|
||||||
|
version_location="python/alembic/signal_bot/versions",
|
||||||
|
base_module="python.orm.signal_bot.base",
|
||||||
|
base_class_name="SignalBotBase",
|
||||||
|
models_module="python.orm.signal_bot.models",
|
||||||
|
),
|
||||||
|
"data_science_dev": DatabaseConfig(
|
||||||
|
env_prefix="DATA_SCIENCE_DEV",
|
||||||
|
version_location="python/alembic/data_science_dev/versions",
|
||||||
|
base_module="python.orm.data_science_dev.base",
|
||||||
|
base_class_name="DataScienceDevBase",
|
||||||
|
models_module="python.orm.data_science_dev.models",
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,12 +4,10 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Self
|
from typing import Self
|
||||||
from urllib.parse import quote
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
DEFAULT_PAGE_SIZE = 100
|
DEFAULT_PAGE_SIZE = 100
|
||||||
EXPECTED_NO_CONTENT = 204
|
|
||||||
EXPECTED_CREATED = 201
|
EXPECTED_CREATED = 201
|
||||||
EXPECTED_OK = 200
|
EXPECTED_OK = 200
|
||||||
|
|
||||||
@@ -224,16 +222,6 @@ class GiteaClient:
|
|||||||
json=payload,
|
json=payload,
|
||||||
)
|
)
|
||||||
|
|
||||||
def dispatch_workflow(self, *, owner: str, repo: str, workflow_id: str, ref: str) -> None:
|
|
||||||
"""Trigger a workflow_dispatch run."""
|
|
||||||
workflow_path = quote(workflow_id, safe="")
|
|
||||||
self._request(
|
|
||||||
"POST",
|
|
||||||
f"/api/v1/repos/{owner}/{repo}/actions/workflows/{workflow_path}/dispatches",
|
|
||||||
expected_statuses={EXPECTED_OK, EXPECTED_NO_CONTENT},
|
|
||||||
json={"ref": ref},
|
|
||||||
)
|
|
||||||
|
|
||||||
def list_run_jobs(self, *, owner: str, repo: str, run_id: str | int) -> list[WorkflowJob]:
|
def list_run_jobs(self, *, owner: str, repo: str, run_id: str | int) -> list[WorkflowJob]:
|
||||||
"""List workflow jobs for a specific run."""
|
"""List workflow jobs for a specific run."""
|
||||||
jobs: list[WorkflowJob] = []
|
jobs: list[WorkflowJob] = []
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ DEFAULT_BASE_BRANCH = "main"
|
|||||||
DEFAULT_BRANCH = "automation/update-flake-lock"
|
DEFAULT_BRANCH = "automation/update-flake-lock"
|
||||||
DEFAULT_GITEA_URL = "https://gitea.tmmworkshop.com"
|
DEFAULT_GITEA_URL = "https://gitea.tmmworkshop.com"
|
||||||
PR_LABELS = ["dependencies", "automated", "flake_lock_update"]
|
PR_LABELS = ["dependencies", "automated", "flake_lock_update"]
|
||||||
PR_CHECK_WORKFLOWS = ["build_systems.yml", "treefmt.yml", "pytest.yml"]
|
|
||||||
PR_TITLE = "Update flake.lock"
|
PR_TITLE = "Update flake.lock"
|
||||||
PR_BODY = "Automated flake.lock update."
|
PR_BODY = "Automated flake.lock update."
|
||||||
|
|
||||||
@@ -58,12 +57,6 @@ def find_flake_lock_pull_request(client: GiteaClient, *, owner: str, repo: str)
|
|||||||
return pull_requests[0]
|
return pull_requests[0]
|
||||||
|
|
||||||
|
|
||||||
def dispatch_pull_request_checks(client: GiteaClient, *, owner: str, repo: str, branch: str) -> None:
|
|
||||||
"""Dispatch the workflows that normally run for pull requests."""
|
|
||||||
for workflow in PR_CHECK_WORKFLOWS:
|
|
||||||
client.dispatch_workflow(owner=owner, repo=repo, workflow_id=workflow, ref=branch)
|
|
||||||
|
|
||||||
|
|
||||||
def has_worktree_changes() -> bool:
|
def has_worktree_changes() -> bool:
|
||||||
"""Return whether `flake.lock` has worktree changes."""
|
"""Return whether `flake.lock` has worktree changes."""
|
||||||
result = run_cmd(["git", "diff", "--quiet", "--", "flake.lock"], check=False)
|
result = run_cmd(["git", "diff", "--quiet", "--", "flake.lock"], check=False)
|
||||||
@@ -120,9 +113,6 @@ def update(
|
|||||||
branch=branch,
|
branch=branch,
|
||||||
base=base,
|
base=base,
|
||||||
)
|
)
|
||||||
# We can remove this if Gitea fixes the following issue:
|
|
||||||
# https://github.com/go-gitea/gitea/issues/33963
|
|
||||||
dispatch_pull_request_checks(client, owner=owner, repo=repo_name, branch=branch)
|
|
||||||
typer.echo(pull_request.html_url or f"Pull request #{pull_request.number}")
|
typer.echo(pull_request.html_url or f"Pull request #{pull_request.number}")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
"""FastAPI heater control service."""
|
"""FastAPI heater control service."""
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import TYPE_CHECKING, Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
import typer
|
import typer
|
||||||
import uvicorn
|
import uvicorn
|
||||||
@@ -14,9 +13,6 @@ from python.common import configure_logger
|
|||||||
from python.heater.controller import HeaterController
|
from python.heater.controller import HeaterController
|
||||||
from python.heater.models import ActionResult, DeviceConfig, HeaterStatus
|
from python.heater.models import ActionResult, DeviceConfig, HeaterStatus
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from collections.abc import AsyncIterator
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -262,7 +262,6 @@ def installer(
|
|||||||
):
|
):
|
||||||
run(command, check=True, stdin=test.stdout)
|
run(command, check=True, stdin=test.stdout)
|
||||||
|
|
||||||
# Fixed mount point for the new system; the installer runs as root on a fresh disk
|
|
||||||
mnt_dir = "/tmp/nix_install" # noqa: S108
|
mnt_dir = "/tmp/nix_install" # noqa: S108
|
||||||
|
|
||||||
Path(mnt_dir).mkdir(parents=True, exist_ok=True)
|
Path(mnt_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
|||||||
@@ -1,7 +1,13 @@
|
|||||||
"""ORM package exports."""
|
"""ORM package exports."""
|
||||||
|
|
||||||
|
from python.orm.data_science_dev.base import DataScienceDevBase
|
||||||
from python.orm.richie.base import RichieBase
|
from python.orm.richie.base import RichieBase
|
||||||
|
from python.orm.signal_bot.base import SignalBotBase
|
||||||
|
from python.orm.van_inventory.base import VanInventoryBase
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"DataScienceDevBase",
|
||||||
"RichieBase",
|
"RichieBase",
|
||||||
|
"SignalBotBase",
|
||||||
|
"VanInventoryBase",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -0,0 +1,11 @@
|
|||||||
|
"""Data science dev database ORM exports."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from python.orm.data_science_dev.base import DataScienceDevBase, DataScienceDevTableBase, DataScienceDevTableBaseBig
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DataScienceDevBase",
|
||||||
|
"DataScienceDevTableBase",
|
||||||
|
"DataScienceDevTableBaseBig",
|
||||||
|
]
|
||||||
@@ -0,0 +1,52 @@
|
|||||||
|
"""Data science dev database ORM base."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy import BigInteger, DateTime, MetaData, func
|
||||||
|
from sqlalchemy.ext.declarative import AbstractConcreteBase
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||||
|
|
||||||
|
from python.orm.common import NAMING_CONVENTION
|
||||||
|
|
||||||
|
|
||||||
|
class DataScienceDevBase(DeclarativeBase):
|
||||||
|
"""Base class for data_science_dev database ORM models."""
|
||||||
|
|
||||||
|
schema_name = "main"
|
||||||
|
|
||||||
|
metadata = MetaData(
|
||||||
|
schema=schema_name,
|
||||||
|
naming_convention=NAMING_CONVENTION,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _TableMixin:
|
||||||
|
"""Shared timestamp columns for all table bases."""
|
||||||
|
|
||||||
|
created: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True),
|
||||||
|
server_default=func.now(),
|
||||||
|
)
|
||||||
|
updated: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True),
|
||||||
|
server_default=func.now(),
|
||||||
|
onupdate=func.now(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DataScienceDevTableBase(_TableMixin, AbstractConcreteBase, DataScienceDevBase):
|
||||||
|
"""Table with Integer primary key."""
|
||||||
|
|
||||||
|
__abstract__ = True
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(primary_key=True)
|
||||||
|
|
||||||
|
|
||||||
|
class DataScienceDevTableBaseBig(_TableMixin, AbstractConcreteBase, DataScienceDevBase):
|
||||||
|
"""Table with BigInteger primary key."""
|
||||||
|
|
||||||
|
__abstract__ = True
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
|
||||||
@@ -0,0 +1,14 @@
|
|||||||
|
"""init."""
|
||||||
|
|
||||||
|
from python.orm.data_science_dev.congress.bill import Bill, BillText
|
||||||
|
from python.orm.data_science_dev.congress.legislator import Legislator, LegislatorSocialMedia
|
||||||
|
from python.orm.data_science_dev.congress.vote import Vote, VoteRecord
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Bill",
|
||||||
|
"BillText",
|
||||||
|
"Legislator",
|
||||||
|
"LegislatorSocialMedia",
|
||||||
|
"Vote",
|
||||||
|
"VoteRecord",
|
||||||
|
]
|
||||||
@@ -0,0 +1,66 @@
|
|||||||
|
"""Bill model - legislation introduced in Congress."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import date
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from sqlalchemy import ForeignKey, Index, UniqueConstraint
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from python.orm.data_science_dev.base import DataScienceDevTableBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from python.orm.data_science_dev.congress.vote import Vote
|
||||||
|
|
||||||
|
|
||||||
|
class Bill(DataScienceDevTableBase):
|
||||||
|
"""Legislation with congress number, type, titles, status, and sponsor."""
|
||||||
|
|
||||||
|
__tablename__ = "bill"
|
||||||
|
|
||||||
|
congress: Mapped[int]
|
||||||
|
bill_type: Mapped[str]
|
||||||
|
number: Mapped[int]
|
||||||
|
|
||||||
|
title: Mapped[str | None]
|
||||||
|
title_short: Mapped[str | None]
|
||||||
|
official_title: Mapped[str | None]
|
||||||
|
|
||||||
|
status: Mapped[str | None]
|
||||||
|
status_at: Mapped[date | None]
|
||||||
|
|
||||||
|
sponsor_bioguide_id: Mapped[str | None]
|
||||||
|
|
||||||
|
subjects_top_term: Mapped[str | None]
|
||||||
|
|
||||||
|
votes: Mapped[list[Vote]] = relationship(
|
||||||
|
"Vote",
|
||||||
|
back_populates="bill",
|
||||||
|
)
|
||||||
|
bill_texts: Mapped[list[BillText]] = relationship(
|
||||||
|
"BillText",
|
||||||
|
back_populates="bill",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
)
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint("congress", "bill_type", "number", name="uq_bill_congress_type_number"),
|
||||||
|
Index("ix_bill_congress", "congress"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BillText(DataScienceDevTableBase):
|
||||||
|
"""Stores different text versions of a bill (introduced, enrolled, etc.)."""
|
||||||
|
|
||||||
|
__tablename__ = "bill_text"
|
||||||
|
|
||||||
|
bill_id: Mapped[int] = mapped_column(ForeignKey("main.bill.id", ondelete="CASCADE"))
|
||||||
|
version_code: Mapped[str]
|
||||||
|
version_name: Mapped[str | None]
|
||||||
|
text_content: Mapped[str | None]
|
||||||
|
date: Mapped[date | None]
|
||||||
|
|
||||||
|
bill: Mapped[Bill] = relationship("Bill", back_populates="bill_texts")
|
||||||
|
|
||||||
|
__table_args__ = (UniqueConstraint("bill_id", "version_code", name="uq_bill_text_bill_id_version_code"),)
|
||||||
@@ -0,0 +1,66 @@
|
|||||||
|
"""Legislator model - members of Congress."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import date
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from sqlalchemy import ForeignKey, Text
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from python.orm.data_science_dev.base import DataScienceDevTableBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from python.orm.data_science_dev.congress.vote import VoteRecord
|
||||||
|
|
||||||
|
|
||||||
|
class Legislator(DataScienceDevTableBase):
|
||||||
|
"""Members of Congress with identification and current term info."""
|
||||||
|
|
||||||
|
__tablename__ = "legislator"
|
||||||
|
|
||||||
|
bioguide_id: Mapped[str] = mapped_column(Text, unique=True, index=True)
|
||||||
|
|
||||||
|
thomas_id: Mapped[str | None]
|
||||||
|
lis_id: Mapped[str | None]
|
||||||
|
govtrack_id: Mapped[int | None]
|
||||||
|
opensecrets_id: Mapped[str | None]
|
||||||
|
fec_ids: Mapped[str | None]
|
||||||
|
|
||||||
|
first_name: Mapped[str]
|
||||||
|
last_name: Mapped[str]
|
||||||
|
official_full_name: Mapped[str | None]
|
||||||
|
nickname: Mapped[str | None]
|
||||||
|
|
||||||
|
birthday: Mapped[date | None]
|
||||||
|
gender: Mapped[str | None]
|
||||||
|
|
||||||
|
current_party: Mapped[str | None]
|
||||||
|
current_state: Mapped[str | None]
|
||||||
|
current_district: Mapped[int | None]
|
||||||
|
current_chamber: Mapped[str | None]
|
||||||
|
|
||||||
|
social_media_accounts: Mapped[list[LegislatorSocialMedia]] = relationship(
|
||||||
|
"LegislatorSocialMedia",
|
||||||
|
back_populates="legislator",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
)
|
||||||
|
vote_records: Mapped[list[VoteRecord]] = relationship(
|
||||||
|
"VoteRecord",
|
||||||
|
back_populates="legislator",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LegislatorSocialMedia(DataScienceDevTableBase):
|
||||||
|
"""Social media account linked to a legislator."""
|
||||||
|
|
||||||
|
__tablename__ = "legislator_social_media"
|
||||||
|
|
||||||
|
legislator_id: Mapped[int] = mapped_column(ForeignKey("main.legislator.id"))
|
||||||
|
platform: Mapped[str]
|
||||||
|
account_name: Mapped[str]
|
||||||
|
url: Mapped[str | None]
|
||||||
|
source: Mapped[str]
|
||||||
|
|
||||||
|
legislator: Mapped[Legislator] = relationship(back_populates="social_media_accounts")
|
||||||
@@ -0,0 +1,79 @@
|
|||||||
|
"""Vote model - roll call votes in Congress."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import date
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from sqlalchemy import ForeignKey, Index, UniqueConstraint
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from python.orm.data_science_dev.base import DataScienceDevBase, DataScienceDevTableBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from python.orm.data_science_dev.congress.bill import Bill
|
||||||
|
from python.orm.data_science_dev.congress.legislator import Legislator
|
||||||
|
from python.orm.data_science_dev.congress.vote import Vote
|
||||||
|
|
||||||
|
|
||||||
|
class VoteRecord(DataScienceDevBase):
|
||||||
|
"""Links a vote to a legislator with their position (Yea, Nay, etc.)."""
|
||||||
|
|
||||||
|
__tablename__ = "vote_record"
|
||||||
|
|
||||||
|
vote_id: Mapped[int] = mapped_column(
|
||||||
|
ForeignKey("main.vote.id", ondelete="CASCADE"),
|
||||||
|
primary_key=True,
|
||||||
|
)
|
||||||
|
legislator_id: Mapped[int] = mapped_column(
|
||||||
|
ForeignKey("main.legislator.id", ondelete="CASCADE"),
|
||||||
|
primary_key=True,
|
||||||
|
)
|
||||||
|
position: Mapped[str]
|
||||||
|
|
||||||
|
vote: Mapped[Vote] = relationship("Vote", back_populates="vote_records")
|
||||||
|
legislator: Mapped[Legislator] = relationship("Legislator", back_populates="vote_records")
|
||||||
|
|
||||||
|
|
||||||
|
class Vote(DataScienceDevTableBase):
|
||||||
|
"""Roll call votes with counts and optional bill linkage."""
|
||||||
|
|
||||||
|
__tablename__ = "vote"
|
||||||
|
|
||||||
|
congress: Mapped[int]
|
||||||
|
chamber: Mapped[str]
|
||||||
|
session: Mapped[int]
|
||||||
|
number: Mapped[int]
|
||||||
|
|
||||||
|
vote_type: Mapped[str | None]
|
||||||
|
question: Mapped[str | None]
|
||||||
|
result: Mapped[str | None]
|
||||||
|
result_text: Mapped[str | None]
|
||||||
|
|
||||||
|
vote_date: Mapped[date]
|
||||||
|
|
||||||
|
yea_count: Mapped[int | None]
|
||||||
|
nay_count: Mapped[int | None]
|
||||||
|
not_voting_count: Mapped[int | None]
|
||||||
|
present_count: Mapped[int | None]
|
||||||
|
|
||||||
|
bill_id: Mapped[int | None] = mapped_column(ForeignKey("main.bill.id"))
|
||||||
|
|
||||||
|
bill: Mapped[Bill | None] = relationship("Bill", back_populates="votes")
|
||||||
|
vote_records: Mapped[list[VoteRecord]] = relationship(
|
||||||
|
"VoteRecord",
|
||||||
|
back_populates="vote",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
)
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint(
|
||||||
|
"congress",
|
||||||
|
"chamber",
|
||||||
|
"session",
|
||||||
|
"number",
|
||||||
|
name="uq_vote_congress_chamber_session_number",
|
||||||
|
),
|
||||||
|
Index("ix_vote_date", "vote_date"),
|
||||||
|
Index("ix_vote_congress_chamber", "congress", "chamber"),
|
||||||
|
)
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
"""Data science dev database ORM models."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from python.orm.data_science_dev.congress import Bill, BillText, Legislator, Vote, VoteRecord
|
||||||
|
from python.orm.data_science_dev.posts import partitions # noqa: F401 — registers partition classes in metadata
|
||||||
|
from python.orm.data_science_dev.posts.tables import Posts
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Bill",
|
||||||
|
"BillText",
|
||||||
|
"Legislator",
|
||||||
|
"Posts",
|
||||||
|
"Vote",
|
||||||
|
"VoteRecord",
|
||||||
|
]
|
||||||
@@ -0,0 +1,11 @@
|
|||||||
|
"""Posts module — weekly-partitioned posts table and partition ORM models."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from python.orm.data_science_dev.posts.failed_ingestion import FailedIngestion
|
||||||
|
from python.orm.data_science_dev.posts.tables import Posts
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"FailedIngestion",
|
||||||
|
"Posts",
|
||||||
|
]
|
||||||
@@ -0,0 +1,33 @@
|
|||||||
|
"""Shared column definitions for the posts partitioned table family."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy import BigInteger, SmallInteger, Text
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
|
||||||
|
class PostsColumns:
|
||||||
|
"""Mixin providing all posts columns. Used by both the parent table and partitions."""
|
||||||
|
|
||||||
|
post_id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
|
||||||
|
user_id: Mapped[int] = mapped_column(BigInteger)
|
||||||
|
instance: Mapped[str]
|
||||||
|
date: Mapped[datetime] = mapped_column(primary_key=True)
|
||||||
|
text: Mapped[str] = mapped_column(Text)
|
||||||
|
langs: Mapped[str | None]
|
||||||
|
like_count: Mapped[int]
|
||||||
|
reply_count: Mapped[int]
|
||||||
|
repost_count: Mapped[int]
|
||||||
|
reply_to: Mapped[int | None] = mapped_column(BigInteger)
|
||||||
|
replied_author: Mapped[int | None] = mapped_column(BigInteger)
|
||||||
|
thread_root: Mapped[int | None] = mapped_column(BigInteger)
|
||||||
|
thread_root_author: Mapped[int | None] = mapped_column(BigInteger)
|
||||||
|
repost_from: Mapped[int | None] = mapped_column(BigInteger)
|
||||||
|
reposted_author: Mapped[int | None] = mapped_column(BigInteger)
|
||||||
|
quotes: Mapped[int | None] = mapped_column(BigInteger)
|
||||||
|
quoted_author: Mapped[int | None] = mapped_column(BigInteger)
|
||||||
|
labels: Mapped[str | None]
|
||||||
|
sent_label: Mapped[int | None] = mapped_column(SmallInteger)
|
||||||
|
sent_score: Mapped[float | None]
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
"""Table for storing JSONL lines that failed during post ingestion."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from sqlalchemy import Text
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
from python.orm.data_science_dev.base import DataScienceDevTableBase
|
||||||
|
|
||||||
|
|
||||||
|
class FailedIngestion(DataScienceDevTableBase):
|
||||||
|
"""Stores raw JSONL lines and their error messages when ingestion fails."""
|
||||||
|
|
||||||
|
__tablename__ = "failed_ingestion"
|
||||||
|
|
||||||
|
raw_line: Mapped[str] = mapped_column(Text)
|
||||||
|
error: Mapped[str] = mapped_column(Text)
|
||||||
@@ -0,0 +1,71 @@
|
|||||||
|
"""Dynamically generated ORM classes for each weekly partition of the posts table.
|
||||||
|
|
||||||
|
Each class maps to a PostgreSQL partition table (e.g. posts_2024_01).
|
||||||
|
These are real ORM models tracked by Alembic autogenerate.
|
||||||
|
|
||||||
|
Uses ISO week numbering (datetime.isocalendar().week). ISO years can have
|
||||||
|
52 or 53 weeks, and week boundaries are always Monday to Monday.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
|
from python.orm.data_science_dev.base import DataScienceDevBase
|
||||||
|
from python.orm.data_science_dev.posts.columns import PostsColumns
|
||||||
|
|
||||||
|
PARTITION_START_YEAR = 2023
|
||||||
|
PARTITION_END_YEAR = 2026
|
||||||
|
|
||||||
|
_current_module = sys.modules[__name__]
|
||||||
|
|
||||||
|
|
||||||
|
def iso_weeks_in_year(year: int) -> int:
|
||||||
|
"""Return the number of ISO weeks in a given year (52 or 53)."""
|
||||||
|
dec_28 = datetime(year, 12, 28, tzinfo=UTC)
|
||||||
|
return dec_28.isocalendar().week
|
||||||
|
|
||||||
|
|
||||||
|
def week_bounds(year: int, week: int) -> tuple[datetime, datetime]:
|
||||||
|
"""Return (start, end) datetimes for an ISO week.
|
||||||
|
|
||||||
|
Start = Monday 00:00:00 UTC of the given ISO week.
|
||||||
|
End = Monday 00:00:00 UTC of the following ISO week.
|
||||||
|
"""
|
||||||
|
start = datetime.fromisocalendar(year, week, 1).replace(tzinfo=UTC)
|
||||||
|
if week < iso_weeks_in_year(year):
|
||||||
|
end = datetime.fromisocalendar(year, week + 1, 1).replace(tzinfo=UTC)
|
||||||
|
else:
|
||||||
|
end = datetime.fromisocalendar(year + 1, 1, 1).replace(tzinfo=UTC)
|
||||||
|
return start, end
|
||||||
|
|
||||||
|
|
||||||
|
def _build_partition_classes() -> dict[str, type]:
|
||||||
|
"""Generate one ORM class per ISO week partition."""
|
||||||
|
classes: dict[str, type] = {}
|
||||||
|
|
||||||
|
for year in range(PARTITION_START_YEAR, PARTITION_END_YEAR + 1):
|
||||||
|
for week in range(1, iso_weeks_in_year(year) + 1):
|
||||||
|
class_name = f"PostsWeek{year}W{week:02d}"
|
||||||
|
table_name = f"posts_{year}_{week:02d}"
|
||||||
|
|
||||||
|
partition_class = type(
|
||||||
|
class_name,
|
||||||
|
(PostsColumns, DataScienceDevBase),
|
||||||
|
{
|
||||||
|
"__tablename__": table_name,
|
||||||
|
"__table_args__": ({"implicit_returning": False},),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
classes[class_name] = partition_class
|
||||||
|
|
||||||
|
return classes
|
||||||
|
|
||||||
|
|
||||||
|
# Generate all partition classes and register them on this module
|
||||||
|
_partition_classes = _build_partition_classes()
|
||||||
|
for _name, _cls in _partition_classes.items():
|
||||||
|
setattr(_current_module, _name, _cls)
|
||||||
|
__all__ = list(_partition_classes.keys())
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
"""Posts parent table with PostgreSQL weekly range partitioning on date column."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from python.orm.data_science_dev.base import DataScienceDevBase
|
||||||
|
from python.orm.data_science_dev.posts.columns import PostsColumns
|
||||||
|
|
||||||
|
|
||||||
|
class Posts(PostsColumns, DataScienceDevBase):
|
||||||
|
"""Parent partitioned table for posts, partitioned by week on `date`."""
|
||||||
|
|
||||||
|
__tablename__ = "posts"
|
||||||
|
__table_args__ = ({"postgresql_partition_by": "RANGE (date)"},)
|
||||||
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from python.orm.richie.audiobook import Audiobook, AudiobookAuthor, AudiobookSeries
|
|
||||||
from python.orm.richie.base import RichieBase, TableBase, TableBaseBig, TableBaseSmall
|
from python.orm.richie.base import RichieBase, TableBase, TableBaseBig, TableBaseSmall
|
||||||
from python.orm.richie.contact import (
|
from python.orm.richie.contact import (
|
||||||
Contact,
|
Contact,
|
||||||
@@ -13,9 +12,6 @@ from python.orm.richie.contact import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Audiobook",
|
|
||||||
"AudiobookAuthor",
|
|
||||||
"AudiobookSeries",
|
|
||||||
"Contact",
|
"Contact",
|
||||||
"ContactNeed",
|
"ContactNeed",
|
||||||
"ContactRelationship",
|
"ContactRelationship",
|
||||||
|
|||||||
@@ -1,55 +0,0 @@
|
|||||||
"""Audiobook catalog models."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from sqlalchemy import ForeignKey, UniqueConstraint
|
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
|
||||||
|
|
||||||
from python.orm.richie.base import TableBase
|
|
||||||
|
|
||||||
|
|
||||||
class AudiobookAuthor(TableBase):
|
|
||||||
"""Canonical audiobook author."""
|
|
||||||
|
|
||||||
__tablename__ = "audiobook_author"
|
|
||||||
__table_args__ = (UniqueConstraint("name"),)
|
|
||||||
|
|
||||||
name: Mapped[str]
|
|
||||||
|
|
||||||
books: Mapped[list[Audiobook]] = relationship("Audiobook", back_populates="author")
|
|
||||||
series: Mapped[list[AudiobookSeries]] = relationship("AudiobookSeries", back_populates="author")
|
|
||||||
|
|
||||||
|
|
||||||
class AudiobookSeries(TableBase):
|
|
||||||
"""Canonical audiobook series."""
|
|
||||||
|
|
||||||
__tablename__ = "audiobook_series"
|
|
||||||
__table_args__ = (UniqueConstraint("author_id", "name"),)
|
|
||||||
|
|
||||||
name: Mapped[str]
|
|
||||||
author_id: Mapped[int] = mapped_column(ForeignKey("main.audiobook_author.id", ondelete="CASCADE"))
|
|
||||||
|
|
||||||
author: Mapped[AudiobookAuthor] = relationship("AudiobookAuthor", back_populates="series")
|
|
||||||
books: Mapped[list[Audiobook]] = relationship("Audiobook", back_populates="series")
|
|
||||||
|
|
||||||
|
|
||||||
class Audiobook(TableBase):
|
|
||||||
"""Canonical audiobook title."""
|
|
||||||
|
|
||||||
__tablename__ = "audiobook"
|
|
||||||
__table_args__ = (
|
|
||||||
UniqueConstraint(
|
|
||||||
"author_id",
|
|
||||||
"series_id",
|
|
||||||
"title",
|
|
||||||
postgresql_nulls_not_distinct=True,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
title: Mapped[str]
|
|
||||||
author_id: Mapped[int] = mapped_column(ForeignKey("main.audiobook_author.id", ondelete="CASCADE"))
|
|
||||||
series_id: Mapped[int | None] = mapped_column(ForeignKey("main.audiobook_series.id", ondelete="SET NULL"))
|
|
||||||
series_index: Mapped[float] = mapped_column(default=0.0)
|
|
||||||
|
|
||||||
author: Mapped[AudiobookAuthor] = relationship("AudiobookAuthor", back_populates="books")
|
|
||||||
series: Mapped[AudiobookSeries | None] = relationship("AudiobookSeries", back_populates="books")
|
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
"""Signal bot database ORM exports."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from python.orm.signal_bot.base import SignalBotBase, SignalBotTableBase, SignalBotTableBaseSmall
|
||||||
|
from python.orm.signal_bot.models import DeadLetterMessage, DeviceRole, RoleRecord, SignalDevice
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DeadLetterMessage",
|
||||||
|
"DeviceRole",
|
||||||
|
"RoleRecord",
|
||||||
|
"SignalBotBase",
|
||||||
|
"SignalBotTableBase",
|
||||||
|
"SignalBotTableBaseSmall",
|
||||||
|
"SignalDevice",
|
||||||
|
]
|
||||||
@@ -0,0 +1,52 @@
|
|||||||
|
"""Signal bot database ORM base."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy import DateTime, MetaData, SmallInteger, func
|
||||||
|
from sqlalchemy.ext.declarative import AbstractConcreteBase
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||||
|
|
||||||
|
from python.orm.common import NAMING_CONVENTION
|
||||||
|
|
||||||
|
|
||||||
|
class SignalBotBase(DeclarativeBase):
|
||||||
|
"""Base class for signal_bot database ORM models."""
|
||||||
|
|
||||||
|
schema_name = "main"
|
||||||
|
|
||||||
|
metadata = MetaData(
|
||||||
|
schema=schema_name,
|
||||||
|
naming_convention=NAMING_CONVENTION,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _TableMixin:
|
||||||
|
"""Shared timestamp columns for all table bases."""
|
||||||
|
|
||||||
|
created: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True),
|
||||||
|
server_default=func.now(),
|
||||||
|
)
|
||||||
|
updated: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True),
|
||||||
|
server_default=func.now(),
|
||||||
|
onupdate=func.now(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SignalBotTableBaseSmall(_TableMixin, AbstractConcreteBase, SignalBotBase):
|
||||||
|
"""Table with SmallInteger primary key."""
|
||||||
|
|
||||||
|
__abstract__ = True
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(SmallInteger, primary_key=True)
|
||||||
|
|
||||||
|
|
||||||
|
class SignalBotTableBase(_TableMixin, AbstractConcreteBase, SignalBotBase):
|
||||||
|
"""Table with Integer primary key."""
|
||||||
|
|
||||||
|
__abstract__ = True
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(primary_key=True)
|
||||||
@@ -0,0 +1,62 @@
|
|||||||
|
"""Signal bot device, role, and dead letter ORM models."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy import DateTime, Enum, ForeignKey, SmallInteger, String, Text, UniqueConstraint
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from python.orm.signal_bot.base import SignalBotTableBase, SignalBotTableBaseSmall
|
||||||
|
from python.signal_bot.models import MessageStatus, TrustLevel
|
||||||
|
|
||||||
|
|
||||||
|
class RoleRecord(SignalBotTableBaseSmall):
|
||||||
|
"""Lookup table for RBAC roles, keyed by smallint."""
|
||||||
|
|
||||||
|
__tablename__ = "role"
|
||||||
|
|
||||||
|
name: Mapped[str] = mapped_column(String(50), unique=True)
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceRole(SignalBotTableBase):
|
||||||
|
"""Association between a device and a role."""
|
||||||
|
|
||||||
|
__tablename__ = "device_role"
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint("device_id", "role_id", name="uq_device_role_device_role"),
|
||||||
|
{"schema": "main"},
|
||||||
|
)
|
||||||
|
|
||||||
|
device_id: Mapped[int] = mapped_column(ForeignKey("main.signal_device.id"))
|
||||||
|
role_id: Mapped[int] = mapped_column(SmallInteger, ForeignKey("main.role.id"))
|
||||||
|
|
||||||
|
|
||||||
|
class SignalDevice(SignalBotTableBase):
|
||||||
|
"""A Signal device tracked by phone number and safety number."""
|
||||||
|
|
||||||
|
__tablename__ = "signal_device"
|
||||||
|
|
||||||
|
phone_number: Mapped[str] = mapped_column(String(50), unique=True)
|
||||||
|
safety_number: Mapped[str | None]
|
||||||
|
trust_level: Mapped[TrustLevel] = mapped_column(
|
||||||
|
Enum(TrustLevel, name="trust_level", create_constraint=False, native_enum=False),
|
||||||
|
default=TrustLevel.UNVERIFIED,
|
||||||
|
)
|
||||||
|
last_seen: Mapped[datetime] = mapped_column(DateTime(timezone=True))
|
||||||
|
|
||||||
|
roles: Mapped[list[RoleRecord]] = relationship(secondary=DeviceRole.__table__)
|
||||||
|
|
||||||
|
|
||||||
|
class DeadLetterMessage(SignalBotTableBase):
|
||||||
|
"""A Signal message that failed processing and was sent to the dead letter queue."""
|
||||||
|
|
||||||
|
__tablename__ = "dead_letter_message"
|
||||||
|
|
||||||
|
source: Mapped[str]
|
||||||
|
message: Mapped[str] = mapped_column(Text)
|
||||||
|
received_at: Mapped[datetime] = mapped_column(DateTime(timezone=True))
|
||||||
|
status: Mapped[MessageStatus] = mapped_column(
|
||||||
|
Enum(MessageStatus, name="message_status", create_constraint=False, native_enum=False),
|
||||||
|
default=MessageStatus.UNPROCESSED,
|
||||||
|
)
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
"""Van inventory database ORM exports."""
|
||||||
@@ -0,0 +1,39 @@
|
|||||||
|
"""Van inventory database ORM base."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy import DateTime, MetaData, func
|
||||||
|
from sqlalchemy.ext.declarative import AbstractConcreteBase
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||||
|
|
||||||
|
from python.orm.common import NAMING_CONVENTION
|
||||||
|
|
||||||
|
|
||||||
|
class VanInventoryBase(DeclarativeBase):
|
||||||
|
"""Base class for van_inventory database ORM models."""
|
||||||
|
|
||||||
|
schema_name = "main"
|
||||||
|
|
||||||
|
metadata = MetaData(
|
||||||
|
schema=schema_name,
|
||||||
|
naming_convention=NAMING_CONVENTION,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class VanTableBase(AbstractConcreteBase, VanInventoryBase):
|
||||||
|
"""Abstract concrete base for van_inventory tables with IDs and timestamps."""
|
||||||
|
|
||||||
|
__abstract__ = True
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(primary_key=True)
|
||||||
|
created: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True),
|
||||||
|
server_default=func.now(),
|
||||||
|
)
|
||||||
|
updated: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True),
|
||||||
|
server_default=func.now(),
|
||||||
|
onupdate=func.now(),
|
||||||
|
)
|
||||||
@@ -0,0 +1,46 @@
|
|||||||
|
"""Van inventory ORM models."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from sqlalchemy import ForeignKey, UniqueConstraint
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from python.orm.van_inventory.base import VanTableBase
|
||||||
|
|
||||||
|
|
||||||
|
class Item(VanTableBase):
|
||||||
|
"""A food item in the van."""
|
||||||
|
|
||||||
|
__tablename__ = "items"
|
||||||
|
|
||||||
|
name: Mapped[str] = mapped_column(unique=True)
|
||||||
|
quantity: Mapped[float] = mapped_column(default=0)
|
||||||
|
unit: Mapped[str]
|
||||||
|
category: Mapped[str | None]
|
||||||
|
|
||||||
|
meal_ingredients: Mapped[list[MealIngredient]] = relationship(back_populates="item")
|
||||||
|
|
||||||
|
|
||||||
|
class Meal(VanTableBase):
|
||||||
|
"""A meal that can be made from items in the van."""
|
||||||
|
|
||||||
|
__tablename__ = "meals"
|
||||||
|
|
||||||
|
name: Mapped[str] = mapped_column(unique=True)
|
||||||
|
instructions: Mapped[str | None]
|
||||||
|
|
||||||
|
ingredients: Mapped[list[MealIngredient]] = relationship(back_populates="meal")
|
||||||
|
|
||||||
|
|
||||||
|
class MealIngredient(VanTableBase):
|
||||||
|
"""Links a meal to the items it requires, with quantities."""
|
||||||
|
|
||||||
|
__tablename__ = "meal_ingredients"
|
||||||
|
__table_args__ = (UniqueConstraint("meal_id", "item_id"),)
|
||||||
|
|
||||||
|
meal_id: Mapped[int] = mapped_column(ForeignKey("meals.id"))
|
||||||
|
item_id: Mapped[int] = mapped_column(ForeignKey("items.id"))
|
||||||
|
quantity_needed: Mapped[float]
|
||||||
|
|
||||||
|
meal: Mapped[Meal] = relationship(back_populates="ingredients")
|
||||||
|
item: Mapped[Item] = relationship(back_populates="meal_ingredients")
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
"""Signal command and control bot."""
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
"""Signal bot commands."""
|
||||||
@@ -0,0 +1,137 @@
|
|||||||
|
"""Van inventory command — parse receipts and item lists via LLM, push to API."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from python.signal_bot.models import InventoryItem, InventoryUpdate
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from python.signal_bot.llm_client import LLMClient
|
||||||
|
from python.signal_bot.models import SignalMessage
|
||||||
|
from python.signal_bot.signal_client import SignalClient
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
SYSTEM_PROMPT = """\
|
||||||
|
You are an inventory assistant. Extract items from the input and return ONLY
|
||||||
|
a JSON array. Each element must have these fields:
|
||||||
|
- "name": item name (string)
|
||||||
|
- "quantity": numeric count or amount (default 1)
|
||||||
|
- "unit": unit of measure (e.g. "each", "lb", "oz", "gallon", "bag", "box")
|
||||||
|
- "category": category like "food", "tools", "supplies", etc.
|
||||||
|
- "notes": any extra detail (empty string if none)
|
||||||
|
|
||||||
|
Example output:
|
||||||
|
[{"name": "water bottles", "quantity": 6, "unit": "gallon", "category": "supplies", "notes": "1 gallon each"}]
|
||||||
|
|
||||||
|
Return ONLY the JSON array, no other text.\
|
||||||
|
"""
|
||||||
|
|
||||||
|
IMAGE_PROMPT = "Extract all items from this receipt or inventory photo."
|
||||||
|
TEXT_PROMPT = "Extract all items from this inventory list."
|
||||||
|
|
||||||
|
|
||||||
|
def parse_llm_response(raw: str) -> list[InventoryItem]:
|
||||||
|
"""Parse the LLM JSON response into InventoryItem list."""
|
||||||
|
text = raw.strip()
|
||||||
|
# Strip markdown code fences if present
|
||||||
|
if text.startswith("```"):
|
||||||
|
lines = text.split("\n")
|
||||||
|
lines = [line for line in lines if not line.startswith("```")]
|
||||||
|
text = "\n".join(lines)
|
||||||
|
|
||||||
|
items_data: list[dict[str, Any]] = json.loads(text)
|
||||||
|
return [InventoryItem.model_validate(item) for item in items_data]
|
||||||
|
|
||||||
|
|
||||||
|
def _upsert_item(api_url: str, item: InventoryItem) -> None:
|
||||||
|
"""Create or update an item via the van_inventory API.
|
||||||
|
|
||||||
|
Fetches existing items, and if one with the same name exists,
|
||||||
|
patches its quantity (summing). Otherwise creates a new item.
|
||||||
|
"""
|
||||||
|
base = api_url.rstrip("/")
|
||||||
|
response = httpx.get(f"{base}/api/items", timeout=10)
|
||||||
|
response.raise_for_status()
|
||||||
|
existing: list[dict[str, Any]] = response.json()
|
||||||
|
|
||||||
|
match = next((e for e in existing if e["name"].lower() == item.name.lower()), None)
|
||||||
|
|
||||||
|
if match:
|
||||||
|
new_qty = match["quantity"] + item.quantity
|
||||||
|
patch = {"quantity": new_qty}
|
||||||
|
if item.category:
|
||||||
|
patch["category"] = item.category
|
||||||
|
response = httpx.patch(f"{base}/api/items/{match['id']}", json=patch, timeout=10)
|
||||||
|
response.raise_for_status()
|
||||||
|
return
|
||||||
|
payload = {
|
||||||
|
"name": item.name,
|
||||||
|
"quantity": item.quantity,
|
||||||
|
"unit": item.unit,
|
||||||
|
"category": item.category or None,
|
||||||
|
}
|
||||||
|
response = httpx.post(f"{base}/api/items", json=payload, timeout=10)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
|
||||||
|
def handle_inventory_update(
|
||||||
|
message: SignalMessage,
|
||||||
|
signal: SignalClient,
|
||||||
|
llm: LLMClient,
|
||||||
|
api_url: str,
|
||||||
|
) -> InventoryUpdate:
|
||||||
|
"""Process an inventory update from a Signal message.
|
||||||
|
|
||||||
|
Accepts either an image (receipt photo) or text list.
|
||||||
|
Uses the LLM to extract structured items, then pushes to the van_inventory API.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
logger.info(f"Processing inventory update from {message.source}")
|
||||||
|
if message.attachments:
|
||||||
|
image_data = signal.get_attachment(message.attachments[0])
|
||||||
|
raw_response = llm.chat(
|
||||||
|
IMAGE_PROMPT,
|
||||||
|
image_data=image_data,
|
||||||
|
system=SYSTEM_PROMPT,
|
||||||
|
)
|
||||||
|
source_type = "receipt_photo"
|
||||||
|
elif message.message.strip():
|
||||||
|
raw_response = llm.chat(
|
||||||
|
f"{TEXT_PROMPT}\n\n{message.message}",
|
||||||
|
system=SYSTEM_PROMPT,
|
||||||
|
)
|
||||||
|
source_type = "text_list"
|
||||||
|
else:
|
||||||
|
signal.reply(message, "Send a photo of a receipt or a text list of items to update inventory.")
|
||||||
|
return InventoryUpdate()
|
||||||
|
|
||||||
|
logger.info(f"{raw_response=}")
|
||||||
|
|
||||||
|
new_items = parse_llm_response(raw_response)
|
||||||
|
|
||||||
|
logger.info(f"{new_items=}")
|
||||||
|
|
||||||
|
for item in new_items:
|
||||||
|
_upsert_item(api_url, item)
|
||||||
|
|
||||||
|
summary = _format_summary(new_items)
|
||||||
|
signal.reply(message, f"Inventory updated with {len(new_items)} item(s):\n{summary}")
|
||||||
|
|
||||||
|
return InventoryUpdate(items=new_items, raw_response=raw_response, source_type=source_type)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to process inventory update")
|
||||||
|
signal.reply(message, "Failed to process inventory update. Check logs for details.")
|
||||||
|
return InventoryUpdate()
|
||||||
|
|
||||||
|
|
||||||
|
def _format_summary(items: list[InventoryItem]) -> str:
|
||||||
|
"""Format items into a readable summary."""
|
||||||
|
lines = [f" - {item.name} x{item.quantity} {item.unit} [{item.category}]" for item in items]
|
||||||
|
return "\n".join(lines)
|
||||||
@@ -0,0 +1,64 @@
|
|||||||
|
"""Location command for the Signal bot."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from python.signal_bot.models import SignalMessage
|
||||||
|
from python.signal_bot.signal_client import SignalClient
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_entity_state(ha_url: str, ha_token: str, entity_id: str) -> dict[str, Any]:
|
||||||
|
"""Fetch an entity's state from Home Assistant."""
|
||||||
|
entity_url = f"{ha_url}/api/states/{entity_id}"
|
||||||
|
logger.debug(f"Fetching {entity_url=}")
|
||||||
|
response = httpx.get(
|
||||||
|
entity_url,
|
||||||
|
headers={"Authorization": f"Bearer {ha_token}"},
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
|
||||||
|
def _format_location(latitude: str, longitude: str) -> str:
|
||||||
|
"""Render a friendly location response."""
|
||||||
|
return f"Van location: {latitude}, {longitude}\nhttps://maps.google.com/?q={latitude},{longitude}"
|
||||||
|
|
||||||
|
|
||||||
|
def handle_location_request(
|
||||||
|
message: SignalMessage,
|
||||||
|
signal: SignalClient,
|
||||||
|
ha_url: str | None,
|
||||||
|
ha_token: str | None,
|
||||||
|
) -> None:
|
||||||
|
"""Reply with van location from Home Assistant."""
|
||||||
|
if ha_url is None or ha_token is None:
|
||||||
|
signal.reply(message, "Location command is not configured (missing HA_URL or HA_TOKEN).")
|
||||||
|
return
|
||||||
|
|
||||||
|
lat_payload = None
|
||||||
|
lon_payload = None
|
||||||
|
try:
|
||||||
|
lat_payload = _get_entity_state(ha_url, ha_token, "sensor.van_last_known_latitude")
|
||||||
|
lon_payload = _get_entity_state(ha_url, ha_token, "sensor.van_last_known_longitude")
|
||||||
|
except httpx.HTTPError:
|
||||||
|
logger.exception("Couldn't fetch van location from Home Assistant right now.")
|
||||||
|
logger.debug(f"{ha_url=} {lat_payload=} {lon_payload=}")
|
||||||
|
signal.reply(message, "Couldn't fetch van location from Home Assistant right now.")
|
||||||
|
return
|
||||||
|
|
||||||
|
latitude = lat_payload.get("state", "")
|
||||||
|
longitude = lon_payload.get("state", "")
|
||||||
|
|
||||||
|
if not latitude or not longitude or latitude == "unavailable" or longitude == "unavailable":
|
||||||
|
signal.reply(message, "Van location is unavailable in Home Assistant right now.")
|
||||||
|
return
|
||||||
|
|
||||||
|
signal.reply(message, _format_location(latitude, longitude))
|
||||||
@@ -0,0 +1,284 @@
|
|||||||
|
"""Device registry — tracks verified/unverified devices by safety number."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import TYPE_CHECKING, NamedTuple
|
||||||
|
|
||||||
|
from sqlalchemy import delete, select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from python.common import utcnow
|
||||||
|
from python.orm.signal_bot.models import RoleRecord, SignalDevice
|
||||||
|
from python.signal_bot.models import Role, TrustLevel
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sqlalchemy.engine import Engine
|
||||||
|
|
||||||
|
from python.signal_bot.signal_client import SignalClient
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_BLOCKED_TTL = timedelta(minutes=60)
|
||||||
|
_DEFAULT_TTL = timedelta(minutes=5)
|
||||||
|
|
||||||
|
|
||||||
|
class _CacheEntry(NamedTuple):
|
||||||
|
expires: datetime
|
||||||
|
trust_level: TrustLevel
|
||||||
|
has_safety_number: bool
|
||||||
|
safety_number: str | None
|
||||||
|
roles: list[Role]
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceRegistry:
|
||||||
|
"""Manage device trust based on Signal safety numbers.
|
||||||
|
|
||||||
|
Devices start as UNVERIFIED. An admin verifies them over SSH by calling
|
||||||
|
``verify(phone_number)`` which marks the device VERIFIED and also tells
|
||||||
|
signal-cli to trust the identity.
|
||||||
|
|
||||||
|
Only VERIFIED devices may execute commands.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, signal_client: SignalClient, engine: Engine) -> None:
|
||||||
|
self.signal_client = signal_client
|
||||||
|
self.engine = engine
|
||||||
|
self._contact_cache: dict[str, _CacheEntry] = {}
|
||||||
|
|
||||||
|
def is_verified(self, phone_number: str) -> bool:
|
||||||
|
"""Check if a phone number is verified."""
|
||||||
|
if entry := self._cached(phone_number):
|
||||||
|
return entry.trust_level == TrustLevel.VERIFIED
|
||||||
|
device = self._load_device(phone_number)
|
||||||
|
return device is not None and device.trust_level == TrustLevel.VERIFIED
|
||||||
|
|
||||||
|
def record_contact(self, phone_number: str, safety_number: str | None = None) -> None:
|
||||||
|
"""Record seeing a device. Creates entry if new, updates last_seen."""
|
||||||
|
now = utcnow()
|
||||||
|
|
||||||
|
entry = self._cached(phone_number)
|
||||||
|
if entry and entry.safety_number == safety_number:
|
||||||
|
return
|
||||||
|
|
||||||
|
with Session(self.engine) as session:
|
||||||
|
device = session.scalars(
|
||||||
|
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
|
||||||
|
).one_or_none()
|
||||||
|
|
||||||
|
if device:
|
||||||
|
if device.safety_number != safety_number and device.trust_level != TrustLevel.BLOCKED:
|
||||||
|
logger.warning(f"Safety number changed for {phone_number}, resetting to UNVERIFIED")
|
||||||
|
device.safety_number = safety_number
|
||||||
|
device.trust_level = TrustLevel.UNVERIFIED
|
||||||
|
device.last_seen = now
|
||||||
|
else:
|
||||||
|
device = SignalDevice(
|
||||||
|
phone_number=phone_number,
|
||||||
|
safety_number=safety_number,
|
||||||
|
trust_level=TrustLevel.UNVERIFIED,
|
||||||
|
last_seen=now,
|
||||||
|
)
|
||||||
|
session.add(device)
|
||||||
|
logger.info(f"New device registered: {phone_number}")
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
self._update_cache(phone_number, device)
|
||||||
|
|
||||||
|
def has_safety_number(self, phone_number: str) -> bool:
|
||||||
|
"""Check if a device has a safety number on file."""
|
||||||
|
if entry := self._cached(phone_number):
|
||||||
|
return entry.has_safety_number
|
||||||
|
device = self._load_device(phone_number)
|
||||||
|
return device is not None and device.safety_number is not None
|
||||||
|
|
||||||
|
def verify(self, phone_number: str) -> bool:
|
||||||
|
"""Mark a device as verified. Called by admin over SSH.
|
||||||
|
|
||||||
|
Returns True if the device was found and verified.
|
||||||
|
"""
|
||||||
|
with Session(self.engine) as session:
|
||||||
|
device = session.scalars(
|
||||||
|
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
|
||||||
|
).one_or_none()
|
||||||
|
|
||||||
|
if not device:
|
||||||
|
logger.warning(f"Cannot verify unknown device: {phone_number}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
device.trust_level = TrustLevel.VERIFIED
|
||||||
|
self.signal_client.trust_identity(phone_number, trust_all_known_keys=True)
|
||||||
|
session.commit()
|
||||||
|
self._update_cache(phone_number, device)
|
||||||
|
logger.info(f"Device verified: {phone_number}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def block(self, phone_number: str) -> bool:
|
||||||
|
"""Block a device."""
|
||||||
|
return self._set_trust(phone_number, TrustLevel.BLOCKED, "Device blocked")
|
||||||
|
|
||||||
|
def unverify(self, phone_number: str) -> bool:
|
||||||
|
"""Reset a device to unverified."""
|
||||||
|
return self._set_trust(phone_number, TrustLevel.UNVERIFIED)
|
||||||
|
|
||||||
|
# -- role management ------------------------------------------------------
|
||||||
|
|
||||||
|
def get_roles(self, phone_number: str) -> list[Role]:
|
||||||
|
"""Return the roles for a device, defaulting to empty."""
|
||||||
|
if entry := self._cached(phone_number):
|
||||||
|
return entry.roles
|
||||||
|
device = self._load_device(phone_number)
|
||||||
|
return _extract_roles(device) if device else []
|
||||||
|
|
||||||
|
def has_role(self, phone_number: str, role: Role) -> bool:
|
||||||
|
"""Check if a device has a specific role or is admin."""
|
||||||
|
roles = self.get_roles(phone_number)
|
||||||
|
return Role.ADMIN in roles or role in roles
|
||||||
|
|
||||||
|
def grant_role(self, phone_number: str, role: Role) -> bool:
|
||||||
|
"""Add a role to a device. Called by admin over SSH."""
|
||||||
|
with Session(self.engine) as session:
|
||||||
|
device = session.scalars(
|
||||||
|
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
|
||||||
|
).one_or_none()
|
||||||
|
|
||||||
|
if not device:
|
||||||
|
logger.warning(f"Cannot grant role for unknown device: {phone_number}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if any(record.name == role for record in device.roles):
|
||||||
|
return True
|
||||||
|
|
||||||
|
role_record = session.scalars(select(RoleRecord).where(RoleRecord.name == role)).one_or_none()
|
||||||
|
|
||||||
|
if not role_record:
|
||||||
|
logger.warning(f"Unknown role: {role}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
device.roles.append(role_record)
|
||||||
|
session.commit()
|
||||||
|
self._update_cache(phone_number, device)
|
||||||
|
logger.info(f"Device {phone_number} granted role {role}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def revoke_role(self, phone_number: str, role: Role) -> bool:
|
||||||
|
"""Remove a role from a device. Called by admin over SSH."""
|
||||||
|
with Session(self.engine) as session:
|
||||||
|
device = session.scalars(
|
||||||
|
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
|
||||||
|
).one_or_none()
|
||||||
|
|
||||||
|
if not device:
|
||||||
|
logger.warning(f"Cannot revoke role for unknown device: {phone_number}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
device.roles = [record for record in device.roles if record.name != role]
|
||||||
|
session.commit()
|
||||||
|
self._update_cache(phone_number, device)
|
||||||
|
logger.info(f"Device {phone_number} revoked role {role}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def set_roles(self, phone_number: str, roles: list[Role]) -> bool:
|
||||||
|
"""Replace all roles for a device. Called by admin over SSH."""
|
||||||
|
with Session(self.engine) as session:
|
||||||
|
device = session.scalars(
|
||||||
|
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
|
||||||
|
).one_or_none()
|
||||||
|
|
||||||
|
if not device:
|
||||||
|
logger.warning(f"Cannot set roles for unknown device: {phone_number}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
role_names = [str(role) for role in roles]
|
||||||
|
records = session.scalars(select(RoleRecord).where(RoleRecord.name.in_(role_names))).all()
|
||||||
|
device.roles = records
|
||||||
|
session.commit()
|
||||||
|
self._update_cache(phone_number, device)
|
||||||
|
logger.info(f"Device {phone_number} roles set to {role_names}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
# -- queries --------------------------------------------------------------
|
||||||
|
|
||||||
|
def list_devices(self) -> list[SignalDevice]:
|
||||||
|
"""Return all known devices."""
|
||||||
|
with Session(self.engine) as session:
|
||||||
|
return list(session.scalars(select(SignalDevice)).all())
|
||||||
|
|
||||||
|
def sync_identities(self) -> None:
|
||||||
|
"""Pull identity list from signal-cli and record any new ones."""
|
||||||
|
identities = self.signal_client.get_identities()
|
||||||
|
for identity in identities:
|
||||||
|
number = identity.get("number", "")
|
||||||
|
safety = identity.get("safety_number", identity.get("fingerprint", ""))
|
||||||
|
if number:
|
||||||
|
self.record_contact(number, safety)
|
||||||
|
|
||||||
|
# -- internals ------------------------------------------------------------
|
||||||
|
|
||||||
|
def _cached(self, phone_number: str) -> _CacheEntry | None:
|
||||||
|
"""Return the cache entry if it exists and hasn't expired."""
|
||||||
|
entry = self._contact_cache.get(phone_number)
|
||||||
|
if entry and utcnow() < entry.expires:
|
||||||
|
return entry
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _load_device(self, phone_number: str) -> SignalDevice | None:
|
||||||
|
"""Fetch a device by phone number (with joined roles)."""
|
||||||
|
with Session(self.engine) as session:
|
||||||
|
return session.scalars(select(SignalDevice).where(SignalDevice.phone_number == phone_number)).one_or_none()
|
||||||
|
|
||||||
|
def _update_cache(self, phone_number: str, device: SignalDevice) -> None:
|
||||||
|
"""Refresh the cache entry for a device."""
|
||||||
|
ttl = _BLOCKED_TTL if device.trust_level == TrustLevel.BLOCKED else _DEFAULT_TTL
|
||||||
|
self._contact_cache[phone_number] = _CacheEntry(
|
||||||
|
expires=utcnow() + ttl,
|
||||||
|
trust_level=device.trust_level,
|
||||||
|
has_safety_number=device.safety_number is not None,
|
||||||
|
safety_number=device.safety_number,
|
||||||
|
roles=_extract_roles(device),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _set_trust(self, phone_number: str, level: str, log_msg: str | None = None) -> bool:
|
||||||
|
"""Update the trust level for a device."""
|
||||||
|
with Session(self.engine) as session:
|
||||||
|
device = session.scalars(
|
||||||
|
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
|
||||||
|
).one_or_none()
|
||||||
|
|
||||||
|
if not device:
|
||||||
|
return False
|
||||||
|
|
||||||
|
device.trust_level = level
|
||||||
|
session.commit()
|
||||||
|
self._update_cache(phone_number, device)
|
||||||
|
if log_msg:
|
||||||
|
logger.info(f"{log_msg}: {phone_number}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_roles(device: SignalDevice) -> list[Role]:
|
||||||
|
"""Convert a device's RoleRecord objects to a list of Role enums."""
|
||||||
|
return [Role(record.name) for record in device.roles]
|
||||||
|
|
||||||
|
|
||||||
|
def sync_roles(engine: Engine) -> None:
|
||||||
|
"""Sync the Role enum to the role table, adding new and removing stale entries."""
|
||||||
|
expected = {role.value for role in Role}
|
||||||
|
|
||||||
|
with Session(engine) as session:
|
||||||
|
existing = set(session.scalars(select(RoleRecord.name)).all())
|
||||||
|
|
||||||
|
to_add = expected - existing
|
||||||
|
to_remove = existing - expected
|
||||||
|
|
||||||
|
for name in to_add:
|
||||||
|
session.add(RoleRecord(name=name))
|
||||||
|
logger.info(f"Role added: {name}")
|
||||||
|
|
||||||
|
if to_remove:
|
||||||
|
session.execute(delete(RoleRecord).where(RoleRecord.name.in_(to_remove)))
|
||||||
|
for name in to_remove:
|
||||||
|
logger.info(f"Role removed: {name}")
|
||||||
|
|
||||||
|
session.commit()
|
||||||
@@ -0,0 +1,80 @@
|
|||||||
|
"""Flexible LLM client for ollama backends."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import logging
|
||||||
|
from typing import Any, Self
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LLMClient:
|
||||||
|
"""Talk to an ollama instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Ollama model name.
|
||||||
|
host: Ollama host.
|
||||||
|
port: Ollama port.
|
||||||
|
temperature: Sampling temperature.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
model: str,
|
||||||
|
host: str,
|
||||||
|
port: int = 11434,
|
||||||
|
temperature: float = 0.1,
|
||||||
|
timeout: int = 300,
|
||||||
|
) -> None:
|
||||||
|
self.model = model
|
||||||
|
self.temperature = temperature
|
||||||
|
self._client = httpx.Client(base_url=f"http://{host}:{port}", timeout=timeout)
|
||||||
|
|
||||||
|
def chat(self, prompt: str, image_data: bytes | None = None, system: str | None = None) -> str:
|
||||||
|
"""Send a text prompt and return the response."""
|
||||||
|
messages: list[dict[str, Any]] = []
|
||||||
|
if system:
|
||||||
|
messages.append({"role": "system", "content": system})
|
||||||
|
|
||||||
|
user_msg = {"role": "user", "content": prompt}
|
||||||
|
if image_data:
|
||||||
|
user_msg["images"] = [base64.b64encode(image_data).decode()]
|
||||||
|
|
||||||
|
messages.append(user_msg)
|
||||||
|
return self._generate(messages)
|
||||||
|
|
||||||
|
def _generate(self, messages: list[dict[str, Any]]) -> str:
|
||||||
|
"""Call the ollama chat API."""
|
||||||
|
payload = {
|
||||||
|
"model": self.model,
|
||||||
|
"messages": messages,
|
||||||
|
"stream": False,
|
||||||
|
"options": {"temperature": self.temperature},
|
||||||
|
}
|
||||||
|
logger.info(f"LLM request to {self.model}")
|
||||||
|
response = self._client.post("/api/chat", json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
return data["message"]["content"]
|
||||||
|
|
||||||
|
def list_models(self) -> list[str]:
|
||||||
|
"""List available models on the ollama instance."""
|
||||||
|
response = self._client.get("/api/tags")
|
||||||
|
response.raise_for_status()
|
||||||
|
return [m["name"] for m in response.json().get("models", [])]
|
||||||
|
|
||||||
|
def __enter__(self) -> Self:
|
||||||
|
"""Enter the context manager."""
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *args: object) -> None:
|
||||||
|
"""Close the HTTP client on exit."""
|
||||||
|
self.close()
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""Close the HTTP client."""
|
||||||
|
self._client.close()
|
||||||
@@ -0,0 +1,239 @@
|
|||||||
|
"""Signal command and control bot — main entry point."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from os import getenv
|
||||||
|
from typing import TYPE_CHECKING, Annotated
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
|
import typer
|
||||||
|
from alembic.command import upgrade
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from tenacity import before_sleep_log, retry, stop_after_attempt, wait_exponential
|
||||||
|
|
||||||
|
from python.common import configure_logger, utcnow
|
||||||
|
from python.database_cli import DATABASES
|
||||||
|
from python.orm.common import get_postgres_engine
|
||||||
|
from python.orm.signal_bot.models import DeadLetterMessage
|
||||||
|
from python.signal_bot.commands.inventory import handle_inventory_update
|
||||||
|
from python.signal_bot.commands.location import handle_location_request
|
||||||
|
from python.signal_bot.device_registry import DeviceRegistry, sync_roles
|
||||||
|
from python.signal_bot.llm_client import LLMClient
|
||||||
|
from python.signal_bot.models import BotConfig, MessageStatus, Role, SignalMessage
|
||||||
|
from python.signal_bot.signal_client import SignalClient
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class Command:
|
||||||
|
"""A registered bot command."""
|
||||||
|
|
||||||
|
action: Callable[[SignalMessage, str], None]
|
||||||
|
help_text: str
|
||||||
|
role: Role | None # None = no role required (always allowed)
|
||||||
|
|
||||||
|
|
||||||
|
class Bot:
|
||||||
|
"""Holds shared resources and dispatches incoming messages to command handlers."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
signal: SignalClient,
|
||||||
|
llm: LLMClient,
|
||||||
|
registry: DeviceRegistry,
|
||||||
|
config: BotConfig,
|
||||||
|
) -> None:
|
||||||
|
self.signal = signal
|
||||||
|
self.llm = llm
|
||||||
|
self.registry = registry
|
||||||
|
self.config = config
|
||||||
|
self.commands: dict[str, Command] = {
|
||||||
|
"help": Command(action=self._help, help_text="show this help message", role=None),
|
||||||
|
"status": Command(action=self._status, help_text="show bot status", role=Role.STATUS),
|
||||||
|
"inventory": Command(
|
||||||
|
action=self._inventory,
|
||||||
|
help_text="update van inventory from a text list or receipt photo",
|
||||||
|
role=Role.INVENTORY,
|
||||||
|
),
|
||||||
|
"location": Command(
|
||||||
|
action=self._location,
|
||||||
|
help_text="get current van location",
|
||||||
|
role=Role.LOCATION,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
# -- actions --------------------------------------------------------------
|
||||||
|
|
||||||
|
def _help(self, message: SignalMessage, _cmd: str) -> None:
|
||||||
|
"""Return help text filtered to the sender's roles."""
|
||||||
|
self.signal.reply(message, self._build_help(self.registry.get_roles(message.source)))
|
||||||
|
|
||||||
|
def _status(self, message: SignalMessage, _cmd: str) -> None:
|
||||||
|
"""Return the status of the bot."""
|
||||||
|
models = self.llm.list_models()
|
||||||
|
model_list = ", ".join(models[:10])
|
||||||
|
device_count = len(self.registry.list_devices())
|
||||||
|
self.signal.reply(
|
||||||
|
message,
|
||||||
|
f"Bot online.\nLLM: {self.llm.model}\nAvailable models: {model_list}\nKnown devices: {device_count}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _inventory(self, message: SignalMessage, _cmd: str) -> None:
|
||||||
|
"""Process an inventory update."""
|
||||||
|
handle_inventory_update(message, self.signal, self.llm, self.config.inventory_api_url)
|
||||||
|
|
||||||
|
def _location(self, message: SignalMessage, _cmd: str) -> None:
|
||||||
|
"""Reply with current van location."""
|
||||||
|
handle_location_request(message, self.signal, self.config.ha_url, self.config.ha_token)
|
||||||
|
|
||||||
|
# -- dispatch -------------------------------------------------------------
|
||||||
|
|
||||||
|
def _build_help(self, roles: list[Role]) -> str:
|
||||||
|
"""Build help text showing only the commands the user can access."""
|
||||||
|
is_admin = Role.ADMIN in roles
|
||||||
|
lines = ["Available commands:"]
|
||||||
|
for name, cmd in self.commands.items():
|
||||||
|
if cmd.role is None or is_admin or cmd.role in roles:
|
||||||
|
lines.append(f" {name:20s} — {cmd.help_text}")
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
def dispatch(self, message: SignalMessage) -> None:
|
||||||
|
"""Route an incoming message to the right command handler."""
|
||||||
|
source = message.source
|
||||||
|
|
||||||
|
if not self.registry.is_verified(source):
|
||||||
|
logger.info(f"Device {source} not verified, ignoring message")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not self.registry.has_safety_number(source) and self.registry.has_role(source, Role.ADMIN):
|
||||||
|
logger.warning(f"Admin device {source} missing safety number, ignoring message")
|
||||||
|
return
|
||||||
|
|
||||||
|
text = message.message.strip()
|
||||||
|
parts = text.split()
|
||||||
|
|
||||||
|
if not parts and not message.attachments:
|
||||||
|
return
|
||||||
|
|
||||||
|
cmd = parts[0].lower() if parts else ""
|
||||||
|
|
||||||
|
logger.info(f"f{source=} running {cmd=} with {message=}")
|
||||||
|
|
||||||
|
command = self.commands.get(cmd)
|
||||||
|
if command is None:
|
||||||
|
if message.attachments:
|
||||||
|
command = self.commands["inventory"]
|
||||||
|
cmd = "inventory"
|
||||||
|
else:
|
||||||
|
return
|
||||||
|
|
||||||
|
if command.role is not None and not self.registry.has_role(source, command.role):
|
||||||
|
logger.warning(f"Device {source} denied access to {cmd!r}")
|
||||||
|
self.signal.reply(message, f"Permission denied: you do not have the '{command.role}' role.")
|
||||||
|
return
|
||||||
|
|
||||||
|
command.action(message, cmd)
|
||||||
|
|
||||||
|
def process_message(self, message: SignalMessage) -> None:
|
||||||
|
"""Process a single message, sending it to the dead letter queue after repeated failures."""
|
||||||
|
max_attempts = self.config.max_message_attempts
|
||||||
|
for attempt in range(1, max_attempts + 1):
|
||||||
|
try:
|
||||||
|
safety_number = self.signal.get_safety_number(message.source)
|
||||||
|
self.registry.record_contact(message.source, safety_number)
|
||||||
|
self.dispatch(message)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(f"Failed to process message (attempt {attempt}/{max_attempts})")
|
||||||
|
else:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.error(f"Message from {message.source} failed {max_attempts} times, sending to dead letter queue")
|
||||||
|
with Session(self.config.engine) as session:
|
||||||
|
session.add(
|
||||||
|
DeadLetterMessage(
|
||||||
|
source=message.source,
|
||||||
|
message=message.message,
|
||||||
|
received_at=utcnow(),
|
||||||
|
status=MessageStatus.UNPROCESSED,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
def run(self) -> None:
|
||||||
|
"""Listen for messages via WebSocket, reconnecting on failure."""
|
||||||
|
logger.info("Bot started — listening via WebSocket")
|
||||||
|
|
||||||
|
@retry(
|
||||||
|
stop=stop_after_attempt(self.config.max_retries),
|
||||||
|
wait=wait_exponential(multiplier=self.config.reconnect_delay, max=self.config.max_reconnect_delay),
|
||||||
|
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||||
|
reraise=True,
|
||||||
|
)
|
||||||
|
def _listen() -> None:
|
||||||
|
for message in self.signal.listen():
|
||||||
|
logger.info(f"Message from {message.source}: {message.message[:80]}")
|
||||||
|
self.process_message(message)
|
||||||
|
|
||||||
|
try:
|
||||||
|
_listen()
|
||||||
|
except Exception:
|
||||||
|
logger.critical("Max retries exceeded, shutting down")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def main(
|
||||||
|
log_level: Annotated[str, typer.Option()] = "DEBUG",
|
||||||
|
llm_timeout: Annotated[int, typer.Option()] = 600,
|
||||||
|
) -> None:
|
||||||
|
"""Run the Signal command and control bot."""
|
||||||
|
configure_logger(log_level)
|
||||||
|
signal_api_url = getenv("SIGNAL_API_URL")
|
||||||
|
phone_number = getenv("SIGNAL_PHONE_NUMBER")
|
||||||
|
inventory_api_url = getenv("INVENTORY_API_URL")
|
||||||
|
|
||||||
|
if signal_api_url is None:
|
||||||
|
error = "SIGNAL_API_URL environment variable not set"
|
||||||
|
raise ValueError(error)
|
||||||
|
if phone_number is None:
|
||||||
|
error = "SIGNAL_PHONE_NUMBER environment variable not set"
|
||||||
|
raise ValueError(error)
|
||||||
|
if inventory_api_url is None:
|
||||||
|
error = "INVENTORY_API_URL environment variable not set"
|
||||||
|
raise ValueError(error)
|
||||||
|
|
||||||
|
signal_bot_config = DATABASES["signal_bot"].alembic_config()
|
||||||
|
upgrade(signal_bot_config, "head")
|
||||||
|
engine = get_postgres_engine(name="SIGNALBOT")
|
||||||
|
sync_roles(engine)
|
||||||
|
config = BotConfig(
|
||||||
|
signal_api_url=signal_api_url,
|
||||||
|
phone_number=phone_number,
|
||||||
|
inventory_api_url=inventory_api_url,
|
||||||
|
ha_url=getenv("HA_URL"),
|
||||||
|
ha_token=getenv("HA_TOKEN"),
|
||||||
|
engine=engine,
|
||||||
|
)
|
||||||
|
|
||||||
|
llm_host = getenv("LLM_HOST")
|
||||||
|
llm_model = getenv("LLM_MODEL", "qwen3-vl:32b")
|
||||||
|
llm_port = int(getenv("LLM_PORT", "11434"))
|
||||||
|
if llm_host is None:
|
||||||
|
error = "LLM_HOST environment variable not set"
|
||||||
|
raise ValueError(error)
|
||||||
|
|
||||||
|
with (
|
||||||
|
SignalClient(config.signal_api_url, config.phone_number) as signal,
|
||||||
|
LLMClient(model=llm_model, host=llm_host, port=llm_port, timeout=llm_timeout) as llm,
|
||||||
|
):
|
||||||
|
registry = DeviceRegistry(signal, engine)
|
||||||
|
bot = Bot(signal, llm, registry, config)
|
||||||
|
bot.run()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
typer.run(main)
|
||||||
@@ -0,0 +1,97 @@
|
|||||||
|
"""Models for the Signal command and control bot."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime # noqa: TC003 - pydantic needs this at runtime
|
||||||
|
from enum import StrEnum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
from sqlalchemy.engine import Engine # noqa: TC002 - pydantic needs this at runtime
|
||||||
|
|
||||||
|
|
||||||
|
class TrustLevel(StrEnum):
|
||||||
|
"""Device trust level."""
|
||||||
|
|
||||||
|
VERIFIED = "verified"
|
||||||
|
UNVERIFIED = "unverified"
|
||||||
|
BLOCKED = "blocked"
|
||||||
|
|
||||||
|
|
||||||
|
class Role(StrEnum):
|
||||||
|
"""RBAC roles — one per command, plus admin which grants all."""
|
||||||
|
|
||||||
|
ADMIN = "admin"
|
||||||
|
STATUS = "status"
|
||||||
|
INVENTORY = "inventory"
|
||||||
|
LOCATION = "location"
|
||||||
|
|
||||||
|
|
||||||
|
class MessageStatus(StrEnum):
|
||||||
|
"""Dead letter queue message status."""
|
||||||
|
|
||||||
|
UNPROCESSED = "unprocessed"
|
||||||
|
PROCESSED = "processed"
|
||||||
|
|
||||||
|
|
||||||
|
class Device(BaseModel):
|
||||||
|
"""A registered device tracked by safety number."""
|
||||||
|
|
||||||
|
phone_number: str
|
||||||
|
safety_number: str
|
||||||
|
trust_level: TrustLevel = TrustLevel.UNVERIFIED
|
||||||
|
first_seen: datetime
|
||||||
|
last_seen: datetime
|
||||||
|
|
||||||
|
|
||||||
|
class SignalMessage(BaseModel):
|
||||||
|
"""An incoming Signal message."""
|
||||||
|
|
||||||
|
source: str
|
||||||
|
timestamp: int
|
||||||
|
message: str = ""
|
||||||
|
attachments: list[str] = []
|
||||||
|
group_id: str | None = None
|
||||||
|
is_receipt: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class SignalEnvelope(BaseModel):
|
||||||
|
"""Raw envelope from signal-cli-rest-api."""
|
||||||
|
|
||||||
|
envelope: dict[str, Any]
|
||||||
|
account: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class InventoryItem(BaseModel):
|
||||||
|
"""An item in the van inventory."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
quantity: float = 1
|
||||||
|
unit: str = "each"
|
||||||
|
category: str = ""
|
||||||
|
notes: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class InventoryUpdate(BaseModel):
|
||||||
|
"""Result of processing an inventory update."""
|
||||||
|
|
||||||
|
items: list[InventoryItem] = []
|
||||||
|
raw_response: str = ""
|
||||||
|
source_type: str = "" # "receipt_photo" or "text_list"
|
||||||
|
|
||||||
|
|
||||||
|
class BotConfig(BaseModel):
|
||||||
|
"""Top-level bot configuration."""
|
||||||
|
|
||||||
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
||||||
|
signal_api_url: str
|
||||||
|
phone_number: str
|
||||||
|
inventory_api_url: str
|
||||||
|
ha_url: str | None = None
|
||||||
|
ha_token: str | None = None
|
||||||
|
engine: Engine
|
||||||
|
reconnect_delay: int = 5
|
||||||
|
max_reconnect_delay: int = 300
|
||||||
|
max_retries: int = 10
|
||||||
|
max_message_attempts: int = 3
|
||||||
@@ -0,0 +1,141 @@
|
|||||||
|
"""Client for the signal-cli-rest-api."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Any, Self
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import websockets.sync.client
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Generator
|
||||||
|
|
||||||
|
from python.signal_bot.models import SignalMessage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_envelope(envelope: dict[str, Any]) -> SignalMessage | None:
|
||||||
|
"""Parse a signal-cli envelope into a SignalMessage, or None if not a data message."""
|
||||||
|
data_message = envelope.get("dataMessage")
|
||||||
|
if not data_message:
|
||||||
|
return None
|
||||||
|
|
||||||
|
attachment_ids = [att["id"] for att in data_message.get("attachments", []) if "id" in att]
|
||||||
|
|
||||||
|
group_info = data_message.get("groupInfo")
|
||||||
|
group_id = group_info.get("groupId") if group_info else None
|
||||||
|
|
||||||
|
return SignalMessage(
|
||||||
|
source=envelope.get("source", ""),
|
||||||
|
timestamp=envelope.get("timestamp", 0),
|
||||||
|
message=data_message.get("message", "") or "",
|
||||||
|
attachments=attachment_ids,
|
||||||
|
group_id=group_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SignalClient:
|
||||||
|
"""Communicate with signal-cli-rest-api.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_url: URL of the signal-cli-rest-api (e.g. http://localhost:8989).
|
||||||
|
phone_number: The registered phone number to send/receive as.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, base_url: str, phone_number: str) -> None:
|
||||||
|
self.base_url = base_url.rstrip("/")
|
||||||
|
self.phone_number = phone_number
|
||||||
|
self._client = httpx.Client(base_url=self.base_url, timeout=30)
|
||||||
|
|
||||||
|
def _ws_url(self) -> str:
|
||||||
|
"""Build the WebSocket URL from the base HTTP URL."""
|
||||||
|
url = self.base_url.replace("http://", "ws://").replace("https://", "wss://")
|
||||||
|
return f"{url}/v1/receive/{self.phone_number}"
|
||||||
|
|
||||||
|
def listen(self) -> Generator[SignalMessage]:
|
||||||
|
"""Connect via WebSocket and yield messages as they arrive."""
|
||||||
|
ws_url = self._ws_url()
|
||||||
|
logger.info(f"Connecting to WebSocket: {ws_url}")
|
||||||
|
|
||||||
|
with websockets.sync.client.connect(ws_url) as ws:
|
||||||
|
for raw in ws:
|
||||||
|
try:
|
||||||
|
data = json.loads(raw)
|
||||||
|
envelope = data.get("envelope", {})
|
||||||
|
message = _parse_envelope(envelope)
|
||||||
|
if message:
|
||||||
|
yield message
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning(f"Non-JSON WebSocket frame: {raw[:200]}")
|
||||||
|
|
||||||
|
def send(self, recipient: str, message: str) -> None:
|
||||||
|
"""Send a text message."""
|
||||||
|
payload = {
|
||||||
|
"message": message,
|
||||||
|
"number": self.phone_number,
|
||||||
|
"recipients": [recipient],
|
||||||
|
}
|
||||||
|
response = self._client.post("/v2/send", json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
def send_to_group(self, group_id: str, message: str) -> None:
|
||||||
|
"""Send a message to a group."""
|
||||||
|
payload = {
|
||||||
|
"message": message,
|
||||||
|
"number": self.phone_number,
|
||||||
|
"recipients": [group_id],
|
||||||
|
}
|
||||||
|
response = self._client.post("/v2/send", json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
def get_attachment(self, attachment_id: str) -> bytes:
|
||||||
|
"""Download an attachment by ID."""
|
||||||
|
response = self._client.get(f"/v1/attachments/{attachment_id}")
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.content
|
||||||
|
|
||||||
|
def get_identities(self) -> list[dict[str, Any]]:
|
||||||
|
"""List known identities and their trust levels."""
|
||||||
|
response = self._client.get(f"/v1/identities/{self.phone_number}")
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
def get_safety_number(self, phone_number: str) -> str | None:
|
||||||
|
"""Look up the safety number for a contact from signal-cli's local store."""
|
||||||
|
for identity in self.get_identities():
|
||||||
|
if identity.get("number") == phone_number:
|
||||||
|
return identity.get("safety_number", identity.get("fingerprint", ""))
|
||||||
|
return None
|
||||||
|
|
||||||
|
def trust_identity(self, number_to_trust: str, *, trust_all_known_keys: bool = False) -> None:
|
||||||
|
"""Trust an identity (verify safety number)."""
|
||||||
|
payload: dict[str, Any] = {}
|
||||||
|
if trust_all_known_keys:
|
||||||
|
payload["trust_all_known_keys"] = True
|
||||||
|
response = self._client.put(
|
||||||
|
f"/v1/identities/{self.phone_number}/trust/{number_to_trust}",
|
||||||
|
json=payload,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
def reply(self, message: SignalMessage, text: str) -> None:
|
||||||
|
"""Reply to a message, routing to group or individual."""
|
||||||
|
if message.group_id:
|
||||||
|
self.send_to_group(message.group_id, text)
|
||||||
|
else:
|
||||||
|
self.send(message.source, text)
|
||||||
|
|
||||||
|
def __enter__(self) -> Self:
|
||||||
|
"""Enter the context manager."""
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *args: object) -> None:
|
||||||
|
"""Close the HTTP client on exit."""
|
||||||
|
self.close()
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""Close the HTTP client."""
|
||||||
|
self._client.close()
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
game_data/
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
"""init."""
|
||||||
@@ -0,0 +1,675 @@
|
|||||||
|
"""Base logic for the Splendor game."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import itertools
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import TYPE_CHECKING, Literal, Protocol
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
GemColor = Literal["white", "blue", "green", "red", "black", "gold"]
|
||||||
|
|
||||||
|
GEM_COLORS: tuple[GemColor, ...] = (
|
||||||
|
"white",
|
||||||
|
"blue",
|
||||||
|
"green",
|
||||||
|
"red",
|
||||||
|
"black",
|
||||||
|
"gold",
|
||||||
|
)
|
||||||
|
BASE_COLORS: tuple[GemColor, ...] = (
|
||||||
|
"white",
|
||||||
|
"blue",
|
||||||
|
"green",
|
||||||
|
"red",
|
||||||
|
"black",
|
||||||
|
)
|
||||||
|
|
||||||
|
GEM_ORDER: list[GemColor] = list(GEM_COLORS)
|
||||||
|
GEM_INDEX: dict[GemColor, int] = {c: i for i, c in enumerate(GEM_ORDER)}
|
||||||
|
BASE_INDEX: dict[GemColor, int] = {c: i for i, c in enumerate(BASE_COLORS)}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Card:
|
||||||
|
"""Development card: gives points + a permanent gem discount."""
|
||||||
|
|
||||||
|
tier: int
|
||||||
|
points: int
|
||||||
|
color: GemColor
|
||||||
|
cost: dict[GemColor, int]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Noble:
|
||||||
|
"""Noble tile: gives points if you have enough bonuses."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
points: int
|
||||||
|
requirements: dict[GemColor, int]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PlayerState:
|
||||||
|
"""State of a player in the game."""
|
||||||
|
|
||||||
|
strategy: Strategy
|
||||||
|
tokens: dict[GemColor, int] = field(default_factory=lambda: dict.fromkeys(GEM_COLORS, 0))
|
||||||
|
discounts: dict[GemColor, int] = field(default_factory=lambda: dict.fromkeys(GEM_COLORS, 0))
|
||||||
|
cards: list[Card] = field(default_factory=list)
|
||||||
|
reserved: list[Card] = field(default_factory=list)
|
||||||
|
nobles: list[Noble] = field(default_factory=list)
|
||||||
|
card_score: int = 0
|
||||||
|
noble_score: int = 0
|
||||||
|
|
||||||
|
def total_tokens(self) -> int:
|
||||||
|
"""Total tokens in player's bank."""
|
||||||
|
return sum(self.tokens.values())
|
||||||
|
|
||||||
|
def add_noble(self, noble: Noble) -> None:
|
||||||
|
"""Add a noble to the player."""
|
||||||
|
self.nobles.append(noble)
|
||||||
|
self.noble_score = sum(noble.points for noble in self.nobles)
|
||||||
|
|
||||||
|
def add_card(self, card: Card) -> None:
|
||||||
|
"""Add a card to the player."""
|
||||||
|
self.cards.append(card)
|
||||||
|
self.card_score = sum(card.points for card in self.cards)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def score(self) -> int:
|
||||||
|
"""Total points in player's cards + nobles."""
|
||||||
|
return self.card_score + self.noble_score
|
||||||
|
|
||||||
|
def can_afford(self, card: Card) -> bool:
|
||||||
|
"""Check if player can afford card, using discounts + gold."""
|
||||||
|
missing = 0
|
||||||
|
gold = self.tokens["gold"]
|
||||||
|
|
||||||
|
for color, cost in card.cost.items():
|
||||||
|
missing += max(0, cost - self.discounts.get(color, 0) - self.tokens.get(color, 0))
|
||||||
|
if missing > gold:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def pay_for_card(self, card: Card) -> dict[GemColor, int]:
|
||||||
|
"""Pay tokens for card, move card to tableau, return payment for bank."""
|
||||||
|
if not self.can_afford(card):
|
||||||
|
msg = f"cannot afford card {card}"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
payment: dict[GemColor, int] = dict.fromkeys(GEM_COLORS, 0)
|
||||||
|
gold_available = self.tokens["gold"]
|
||||||
|
|
||||||
|
for color in BASE_COLORS:
|
||||||
|
cost = card.cost.get(color, 0)
|
||||||
|
effective_cost = max(0, cost - self.discounts.get(color, 0))
|
||||||
|
|
||||||
|
use = min(self.tokens[color], effective_cost)
|
||||||
|
self.tokens[color] -= use
|
||||||
|
payment[color] += use
|
||||||
|
|
||||||
|
remaining = effective_cost - use
|
||||||
|
if remaining > 0:
|
||||||
|
use_gold = min(gold_available, remaining)
|
||||||
|
gold_available -= use_gold
|
||||||
|
self.tokens["gold"] -= use_gold
|
||||||
|
payment["gold"] += use_gold
|
||||||
|
|
||||||
|
self.add_card(card)
|
||||||
|
self.discounts[card.color] += 1
|
||||||
|
return payment
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_starting_tokens(player_count: int) -> dict[GemColor, int]:
|
||||||
|
"""get_default_starting_tokens."""
|
||||||
|
token_count = (player_count * player_count - 3 * player_count + 10) // 2
|
||||||
|
return {
|
||||||
|
"white": token_count,
|
||||||
|
"blue": token_count,
|
||||||
|
"green": token_count,
|
||||||
|
"red": token_count,
|
||||||
|
"black": token_count,
|
||||||
|
"gold": 5,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GameConfig:
|
||||||
|
"""Game configuration: gems, bank, cards, nobles, etc."""
|
||||||
|
|
||||||
|
win_score: int = 15
|
||||||
|
table_cards_per_tier: int = 4
|
||||||
|
reserve_limit: int = 3
|
||||||
|
token_limit: int = 10
|
||||||
|
turn_limit: int = 1000
|
||||||
|
minimum_tokens_to_buy_2: int = 4
|
||||||
|
max_token_take: int = 3
|
||||||
|
|
||||||
|
cards: list[Card] = field(default_factory=list)
|
||||||
|
nobles: list[Noble] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class GameState:
|
||||||
|
"""Game state: players, bank, decks, table, available nobles, etc."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: GameConfig,
|
||||||
|
players: list[PlayerState],
|
||||||
|
bank: dict[GemColor, int],
|
||||||
|
decks_by_tier: dict[int, list[Card]],
|
||||||
|
table_by_tier: dict[int, list[Card]],
|
||||||
|
available_nobles: list[Noble],
|
||||||
|
) -> None:
|
||||||
|
"""Game state."""
|
||||||
|
self.config = config
|
||||||
|
self.players = players
|
||||||
|
self.bank = bank
|
||||||
|
self.decks_by_tier = decks_by_tier
|
||||||
|
self.table_by_tier = table_by_tier
|
||||||
|
self.available_nobles = available_nobles
|
||||||
|
self.noble_min_requirements = 0
|
||||||
|
self.get_noble_min_requirements()
|
||||||
|
self.current_player_index = 0
|
||||||
|
self.finished = False
|
||||||
|
|
||||||
|
def get_noble_min_requirements(self) -> None:
|
||||||
|
"""Find the minimum requirement for all available nobles."""
|
||||||
|
test = 0
|
||||||
|
|
||||||
|
for noble in self.available_nobles:
|
||||||
|
test = max(test, min(foo for foo in noble.requirements.values()))
|
||||||
|
|
||||||
|
self.noble_min_requirements = test
|
||||||
|
|
||||||
|
def next_player(self) -> None:
|
||||||
|
"""Advance to the next player."""
|
||||||
|
self.current_player_index = (self.current_player_index + 1) % len(self.players)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_player(self) -> PlayerState:
|
||||||
|
"""Current player."""
|
||||||
|
return self.players[self.current_player_index]
|
||||||
|
|
||||||
|
def refill_table(self) -> None:
|
||||||
|
"""Refill face-up cards from decks."""
|
||||||
|
for tier, deck in self.decks_by_tier.items():
|
||||||
|
table = self.table_by_tier[tier]
|
||||||
|
while len(table) < self.config.table_cards_per_tier and deck:
|
||||||
|
table.append(deck.pop())
|
||||||
|
|
||||||
|
def check_winner_simple(self) -> PlayerState | None:
|
||||||
|
"""Simplified: end immediately when someone hits win_score."""
|
||||||
|
eligible = [player for player in self.players if player.score >= self.config.win_score]
|
||||||
|
if not eligible:
|
||||||
|
return None
|
||||||
|
eligible.sort(
|
||||||
|
key=lambda p: (p.score, -len(p.cards)),
|
||||||
|
reverse=True,
|
||||||
|
)
|
||||||
|
self.finished = True
|
||||||
|
return eligible[0]
|
||||||
|
|
||||||
|
|
||||||
|
class Action:
|
||||||
|
"""Marker protocol for actions."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TakeDifferent(Action):
|
||||||
|
"""Take up to 3 different gem colors."""
|
||||||
|
|
||||||
|
colors: list[GemColor]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TakeDouble(Action):
|
||||||
|
"""Take two of the same color."""
|
||||||
|
|
||||||
|
color: GemColor
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BuyCard(Action):
|
||||||
|
"""Buy a face-up card."""
|
||||||
|
|
||||||
|
tier: int
|
||||||
|
index: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BuyCardReserved(Action):
|
||||||
|
"""Buy a face-up card."""
|
||||||
|
|
||||||
|
index: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ReserveCard(Action):
|
||||||
|
"""Reserve a face-up card."""
|
||||||
|
|
||||||
|
tier: int
|
||||||
|
index: int | None = None
|
||||||
|
from_deck: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class Strategy(Protocol):
|
||||||
|
"""Implement this to make a bot or human controller."""
|
||||||
|
|
||||||
|
def __init__(self, name: str) -> None:
|
||||||
|
"""Initialize a strategy."""
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
def choose_action(self, game: GameState, player: PlayerState) -> Action | None:
|
||||||
|
"""Return an Action, or None to concede/end."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def choose_discard(
|
||||||
|
self,
|
||||||
|
game: GameState, # noqa: ARG002
|
||||||
|
player: PlayerState,
|
||||||
|
excess: int,
|
||||||
|
) -> dict[GemColor, int]:
|
||||||
|
"""Called if player has more than token_limit tokens after an action.
|
||||||
|
|
||||||
|
Default: naive auto-discard.
|
||||||
|
"""
|
||||||
|
return auto_discard_tokens(player, excess)
|
||||||
|
|
||||||
|
def choose_noble(
|
||||||
|
self,
|
||||||
|
game: GameState, # noqa: ARG002
|
||||||
|
player: PlayerState, # noqa: ARG002
|
||||||
|
nobles: list[Noble],
|
||||||
|
) -> Noble:
|
||||||
|
"""Called if player qualifies for multiple nobles. Default: first."""
|
||||||
|
return nobles[0]
|
||||||
|
|
||||||
|
|
||||||
|
def auto_discard_tokens(player: PlayerState, excess: int) -> dict[GemColor, int]:
|
||||||
|
"""Very dumb discard logic: discard from colors you have the most of."""
|
||||||
|
to_discard: dict[GemColor, int] = dict.fromkeys(GEM_COLORS, 0)
|
||||||
|
remaining = excess
|
||||||
|
while remaining > 0:
|
||||||
|
color = max(player.tokens, key=lambda c: player.tokens[c])
|
||||||
|
if player.tokens[color] == 0:
|
||||||
|
break
|
||||||
|
player.tokens[color] -= 1
|
||||||
|
to_discard[color] += 1
|
||||||
|
remaining -= 1
|
||||||
|
return to_discard
|
||||||
|
|
||||||
|
|
||||||
|
def enforce_token_limit(
|
||||||
|
game: GameState,
|
||||||
|
strategy: Strategy,
|
||||||
|
player: PlayerState,
|
||||||
|
) -> None:
|
||||||
|
"""If player has more than token_limit tokens, force discards."""
|
||||||
|
limit = game.config.token_limit
|
||||||
|
total = player.total_tokens()
|
||||||
|
if total <= limit:
|
||||||
|
return
|
||||||
|
excess = total - limit
|
||||||
|
discards = strategy.choose_discard(game, player, excess)
|
||||||
|
for color, amount in discards.items():
|
||||||
|
available = player.tokens[color]
|
||||||
|
to_remove = min(amount, available)
|
||||||
|
if to_remove <= 0:
|
||||||
|
continue
|
||||||
|
player.tokens[color] -= to_remove
|
||||||
|
game.bank[color] += to_remove
|
||||||
|
remaining = player.total_tokens() - limit
|
||||||
|
if remaining > 0:
|
||||||
|
auto = auto_discard_tokens(player, remaining)
|
||||||
|
for color, amount in auto.items():
|
||||||
|
game.bank[color] += amount
|
||||||
|
|
||||||
|
|
||||||
|
def _check_nobles_for_player(player: PlayerState, noble: Noble) -> bool:
|
||||||
|
# this rule is slower
|
||||||
|
for color, cost in noble.requirements.items(): # noqa: SIM110
|
||||||
|
if player.discounts[color] < cost:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def check_nobles_for_player(
|
||||||
|
game: GameState,
|
||||||
|
strategy: Strategy,
|
||||||
|
player: PlayerState,
|
||||||
|
) -> None:
|
||||||
|
"""Award at most one noble to player if they qualify."""
|
||||||
|
if game.noble_min_requirements > max(player.discounts.values()):
|
||||||
|
return
|
||||||
|
|
||||||
|
candidates = [noble for noble in game.available_nobles if _check_nobles_for_player(player, noble)]
|
||||||
|
|
||||||
|
if not candidates:
|
||||||
|
return
|
||||||
|
|
||||||
|
chosen = candidates[0] if len(candidates) == 1 else strategy.choose_noble(game, player, candidates)
|
||||||
|
|
||||||
|
if chosen not in game.available_nobles:
|
||||||
|
return
|
||||||
|
game.available_nobles.remove(chosen)
|
||||||
|
game.get_noble_min_requirements()
|
||||||
|
|
||||||
|
player.add_noble(chosen)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_take_different(game: GameState, strategy: Strategy, action: TakeDifferent) -> None:
|
||||||
|
"""Mutate game state according to action."""
|
||||||
|
player = game.current_player
|
||||||
|
|
||||||
|
colors = [color for color in action.colors if color in BASE_COLORS and game.bank[color] > 0]
|
||||||
|
if not (1 <= len(colors) <= game.config.max_token_take):
|
||||||
|
return
|
||||||
|
|
||||||
|
for color in colors:
|
||||||
|
game.bank[color] -= 1
|
||||||
|
player.tokens[color] += 1
|
||||||
|
|
||||||
|
enforce_token_limit(game, strategy, player)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_take_double(game: GameState, strategy: Strategy, action: TakeDouble) -> None:
|
||||||
|
"""Mutate game state according to action."""
|
||||||
|
player = game.current_player
|
||||||
|
color = action.color
|
||||||
|
if color not in BASE_COLORS:
|
||||||
|
return
|
||||||
|
if game.bank[color] < game.config.minimum_tokens_to_buy_2:
|
||||||
|
return
|
||||||
|
game.bank[color] -= 2
|
||||||
|
player.tokens[color] += 2
|
||||||
|
enforce_token_limit(game, strategy, player)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_buy_card(game: GameState, _strategy: Strategy, action: BuyCard) -> None:
|
||||||
|
"""Mutate game state according to action."""
|
||||||
|
player = game.current_player
|
||||||
|
|
||||||
|
row = game.table_by_tier.get(action.tier)
|
||||||
|
if row is None or not (0 <= action.index < len(row)):
|
||||||
|
return
|
||||||
|
card = row[action.index]
|
||||||
|
if not player.can_afford(card):
|
||||||
|
return
|
||||||
|
row.pop(action.index)
|
||||||
|
payment = player.pay_for_card(card)
|
||||||
|
for color, amount in payment.items():
|
||||||
|
game.bank[color] += amount
|
||||||
|
game.refill_table()
|
||||||
|
|
||||||
|
|
||||||
|
def apply_buy_card_reserved(game: GameState, _strategy: Strategy, action: BuyCardReserved) -> None:
|
||||||
|
"""Mutate game state according to action."""
|
||||||
|
player = game.current_player
|
||||||
|
if not (0 <= action.index < len(player.reserved)):
|
||||||
|
return
|
||||||
|
card = player.reserved[action.index]
|
||||||
|
if not player.can_afford(card):
|
||||||
|
return
|
||||||
|
player.reserved.pop(action.index)
|
||||||
|
payment = player.pay_for_card(card)
|
||||||
|
for color, amount in payment.items():
|
||||||
|
game.bank[color] += amount
|
||||||
|
|
||||||
|
|
||||||
|
def apply_reserve_card(game: GameState, strategy: Strategy, action: ReserveCard) -> None:
|
||||||
|
"""Mutate game state according to action."""
|
||||||
|
player = game.current_player
|
||||||
|
|
||||||
|
if len(player.reserved) >= game.config.reserve_limit:
|
||||||
|
return
|
||||||
|
|
||||||
|
card: Card | None = None
|
||||||
|
if action.from_deck:
|
||||||
|
deck = game.decks_by_tier.get(action.tier)
|
||||||
|
if deck:
|
||||||
|
card = deck.pop()
|
||||||
|
else:
|
||||||
|
row = game.table_by_tier.get(action.tier)
|
||||||
|
if row is None:
|
||||||
|
return
|
||||||
|
if action.index is None or not (0 <= action.index < len(row)):
|
||||||
|
return
|
||||||
|
card = row.pop(action.index)
|
||||||
|
game.refill_table()
|
||||||
|
|
||||||
|
if card is None:
|
||||||
|
return
|
||||||
|
player.reserved.append(card)
|
||||||
|
|
||||||
|
if game.bank["gold"] > 0:
|
||||||
|
game.bank["gold"] -= 1
|
||||||
|
player.tokens["gold"] += 1
|
||||||
|
enforce_token_limit(game, strategy, player)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_action(game: GameState, strategy: Strategy, action: Action) -> None:
|
||||||
|
"""Mutate game state according to action."""
|
||||||
|
actions = {
|
||||||
|
TakeDifferent: apply_take_different,
|
||||||
|
TakeDouble: apply_take_double,
|
||||||
|
BuyCard: apply_buy_card,
|
||||||
|
ReserveCard: apply_reserve_card,
|
||||||
|
BuyCardReserved: apply_buy_card_reserved,
|
||||||
|
}
|
||||||
|
action_func = actions.get(type(action))
|
||||||
|
if action_func is None:
|
||||||
|
msg = f"Unknown action type: {type(action)}"
|
||||||
|
raise ValueError(msg)
|
||||||
|
action_func(game, strategy, action)
|
||||||
|
|
||||||
|
|
||||||
|
# not sure how to simplify this yet
|
||||||
|
def get_legal_actions( # noqa: C901
|
||||||
|
game: GameState,
|
||||||
|
player: PlayerState | None = None,
|
||||||
|
) -> list[Action]:
|
||||||
|
"""Enumerate all syntactically legal actions for the given player.
|
||||||
|
|
||||||
|
This enforces:
|
||||||
|
- token-taking rules
|
||||||
|
- reserve limits
|
||||||
|
- affordability for buys
|
||||||
|
"""
|
||||||
|
if player is None:
|
||||||
|
player = game.players[game.current_player_index]
|
||||||
|
|
||||||
|
actions: list[Action] = []
|
||||||
|
|
||||||
|
colors_available = [c for c in BASE_COLORS if game.bank[c] > 0]
|
||||||
|
for r in (1, 2, 3):
|
||||||
|
actions.extend(TakeDifferent(colors=list(combo)) for combo in itertools.combinations(colors_available, r))
|
||||||
|
|
||||||
|
actions.extend(
|
||||||
|
TakeDouble(color=color) for color in BASE_COLORS if game.bank[color] >= game.config.minimum_tokens_to_buy_2
|
||||||
|
)
|
||||||
|
|
||||||
|
for tier, row in game.table_by_tier.items():
|
||||||
|
for idx, card in enumerate(row):
|
||||||
|
if player.can_afford(card):
|
||||||
|
actions.append(BuyCard(tier=tier, index=idx))
|
||||||
|
|
||||||
|
for idx, card in enumerate(player.reserved):
|
||||||
|
if player.can_afford(card):
|
||||||
|
actions.append(BuyCardReserved(index=idx))
|
||||||
|
|
||||||
|
if len(player.reserved) < game.config.reserve_limit:
|
||||||
|
for tier, row in game.table_by_tier.items():
|
||||||
|
for idx, _ in enumerate(row):
|
||||||
|
actions.append(
|
||||||
|
ReserveCard(tier=tier, index=idx, from_deck=False),
|
||||||
|
)
|
||||||
|
for tier, deck in game.decks_by_tier.items():
|
||||||
|
if deck:
|
||||||
|
actions.append(
|
||||||
|
ReserveCard(tier=tier, index=None, from_deck=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
return actions
|
||||||
|
|
||||||
|
|
||||||
|
def create_random_cards_tier(
|
||||||
|
tier: int,
|
||||||
|
card_count: int,
|
||||||
|
cost_choices: list[int],
|
||||||
|
point_choices: list[int],
|
||||||
|
) -> list[Card]:
|
||||||
|
"""Create a random set of cards for a given tier."""
|
||||||
|
cards: list[Card] = []
|
||||||
|
|
||||||
|
for color in BASE_COLORS:
|
||||||
|
for _ in range(card_count):
|
||||||
|
cost = dict.fromkeys(GEM_COLORS, 0)
|
||||||
|
for c in BASE_COLORS:
|
||||||
|
if c == color:
|
||||||
|
continue
|
||||||
|
cost[c] = random.choice(cost_choices)
|
||||||
|
points = random.choice(point_choices)
|
||||||
|
cards.append(Card(tier=tier, points=points, color=color, cost=cost))
|
||||||
|
|
||||||
|
return cards
|
||||||
|
|
||||||
|
|
||||||
|
def create_random_cards() -> list[Card]:
|
||||||
|
"""Generate a generic but Splendor-ish set of cards.
|
||||||
|
|
||||||
|
This is not the official deck, but structured similarly enough for play.
|
||||||
|
"""
|
||||||
|
cards: list[Card] = []
|
||||||
|
cards.extend(
|
||||||
|
create_random_cards_tier(
|
||||||
|
tier=1,
|
||||||
|
card_count=5,
|
||||||
|
cost_choices=[0, 1, 1, 2],
|
||||||
|
point_choices=[0, 0, 1],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
cards.extend(
|
||||||
|
create_random_cards_tier(
|
||||||
|
tier=2,
|
||||||
|
card_count=4,
|
||||||
|
cost_choices=[2, 3, 4],
|
||||||
|
point_choices=[1, 2, 2, 3],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
cards.extend(
|
||||||
|
create_random_cards_tier(
|
||||||
|
tier=3,
|
||||||
|
card_count=3,
|
||||||
|
cost_choices=[4, 5, 6],
|
||||||
|
point_choices=[3, 4, 5],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
random.shuffle(cards)
|
||||||
|
return cards
|
||||||
|
|
||||||
|
|
||||||
|
def create_random_nobles() -> list[Noble]:
|
||||||
|
"""A small set of noble tiles, roughly Splendor-ish."""
|
||||||
|
nobles: list[Noble] = []
|
||||||
|
|
||||||
|
base_requirements: list[dict[GemColor, int]] = [
|
||||||
|
{"white": 3, "blue": 3, "green": 3},
|
||||||
|
{"blue": 3, "green": 3, "red": 3},
|
||||||
|
{"green": 3, "red": 3, "black": 3},
|
||||||
|
{"red": 3, "black": 3, "white": 3},
|
||||||
|
{"black": 3, "white": 3, "blue": 3},
|
||||||
|
{"white": 4, "blue": 4},
|
||||||
|
{"green": 4, "red": 4},
|
||||||
|
{"blue": 4, "black": 4},
|
||||||
|
]
|
||||||
|
|
||||||
|
for idx, req in enumerate(base_requirements, start=1):
|
||||||
|
nobles.append(
|
||||||
|
Noble(
|
||||||
|
name=f"Noble {idx}",
|
||||||
|
points=3,
|
||||||
|
requirements=dict(req.items()),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return nobles
|
||||||
|
|
||||||
|
|
||||||
|
def load_nobles(file: Path) -> list[Noble]:
|
||||||
|
"""Load nobles from a file."""
|
||||||
|
nobles = json.loads(file.read_text())
|
||||||
|
return [Noble(**noble) for noble in nobles]
|
||||||
|
|
||||||
|
|
||||||
|
def load_cards(file: Path) -> list[Card]:
|
||||||
|
"""Load cards from a file."""
|
||||||
|
cards = json.loads(file.read_text())
|
||||||
|
return [Card(**card) for card in cards]
|
||||||
|
|
||||||
|
|
||||||
|
def new_game(
|
||||||
|
strategies: Sequence[Strategy],
|
||||||
|
config: GameConfig,
|
||||||
|
) -> GameState:
|
||||||
|
"""Create a new game state from a config + list of players."""
|
||||||
|
num_players = len(strategies)
|
||||||
|
bank = get_default_starting_tokens(num_players)
|
||||||
|
|
||||||
|
decks_by_tier: dict[int, list[Card]] = {1: [], 2: [], 3: []}
|
||||||
|
for card in config.cards:
|
||||||
|
decks_by_tier.setdefault(card.tier, []).append(card)
|
||||||
|
for deck in decks_by_tier.values():
|
||||||
|
random.shuffle(deck)
|
||||||
|
|
||||||
|
table_by_tier: dict[int, list[Card]] = {1: [], 2: [], 3: []}
|
||||||
|
players = [PlayerState(strategy=strategy) for strategy in strategies]
|
||||||
|
|
||||||
|
nobles = list(config.nobles)
|
||||||
|
random.shuffle(nobles)
|
||||||
|
nobles = nobles[: num_players + 1]
|
||||||
|
|
||||||
|
game = GameState(
|
||||||
|
config=config,
|
||||||
|
players=players,
|
||||||
|
bank=bank,
|
||||||
|
decks_by_tier=decks_by_tier,
|
||||||
|
table_by_tier=table_by_tier,
|
||||||
|
available_nobles=nobles,
|
||||||
|
)
|
||||||
|
game.refill_table()
|
||||||
|
return game
|
||||||
|
|
||||||
|
|
||||||
|
def run_game(game: GameState) -> tuple[PlayerState, int]:
|
||||||
|
"""Run a full game loop until someone wins or a player returns None."""
|
||||||
|
turn_count = 0
|
||||||
|
while not game.finished:
|
||||||
|
turn_count += 1
|
||||||
|
player = game.current_player
|
||||||
|
strategy = player.strategy
|
||||||
|
action = strategy.choose_action(game, player)
|
||||||
|
if action is None:
|
||||||
|
game.finished = True
|
||||||
|
break
|
||||||
|
|
||||||
|
apply_action(game, strategy, action)
|
||||||
|
check_nobles_for_player(game, strategy, player)
|
||||||
|
|
||||||
|
winner = game.check_winner_simple()
|
||||||
|
if winner is not None:
|
||||||
|
return winner, turn_count
|
||||||
|
|
||||||
|
game.next_player()
|
||||||
|
if turn_count >= game.config.turn_limit:
|
||||||
|
break
|
||||||
|
|
||||||
|
fallback = max(game.players, key=lambda player: player.score)
|
||||||
|
return fallback, turn_count
|
||||||
@@ -0,0 +1,288 @@
|
|||||||
|
"""Bot for Splendor game."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import random
|
||||||
|
|
||||||
|
from .base import (
|
||||||
|
BASE_COLORS,
|
||||||
|
Action,
|
||||||
|
BuyCard,
|
||||||
|
BuyCardReserved,
|
||||||
|
Card,
|
||||||
|
GameState,
|
||||||
|
GemColor,
|
||||||
|
PlayerState,
|
||||||
|
ReserveCard,
|
||||||
|
Strategy,
|
||||||
|
TakeDifferent,
|
||||||
|
TakeDouble,
|
||||||
|
auto_discard_tokens,
|
||||||
|
get_legal_actions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def can_bot_afford(player: PlayerState, card: Card) -> bool:
|
||||||
|
"""Check if player can afford card, using discounts + gold."""
|
||||||
|
missing = 0
|
||||||
|
gold = player.tokens["gold"]
|
||||||
|
for color, cost in card.cost.items():
|
||||||
|
missing += max(0, cost - player.discounts.get(color, 0) - player.tokens.get(color, 0))
|
||||||
|
if missing > gold:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class RandomBot(Strategy):
|
||||||
|
"""Dumb bot that follows rules but doesn't think."""
|
||||||
|
|
||||||
|
def __init__(self, name: str) -> None:
|
||||||
|
"""Initialize the bot."""
|
||||||
|
super().__init__(name=name)
|
||||||
|
|
||||||
|
def choose_action(self, game: GameState, player: PlayerState) -> Action | None:
|
||||||
|
"""Choose an action for the current player."""
|
||||||
|
affordable: list[tuple[int, int]] = []
|
||||||
|
for tier, row in game.table_by_tier.items():
|
||||||
|
for idx, card in enumerate(row):
|
||||||
|
if can_bot_afford(player, card):
|
||||||
|
affordable.append((tier, idx))
|
||||||
|
if affordable and random.random() < 0.5:
|
||||||
|
tier, idx = random.choice(affordable)
|
||||||
|
return BuyCard(tier=tier, index=idx)
|
||||||
|
|
||||||
|
if random.random() < 0.2:
|
||||||
|
tier = random.choice([1, 2, 3])
|
||||||
|
row = game.table_by_tier.get(tier, [])
|
||||||
|
if row:
|
||||||
|
idx = random.randrange(len(row))
|
||||||
|
return ReserveCard(tier=tier, index=idx, from_deck=False)
|
||||||
|
|
||||||
|
if random.random() < 0.5:
|
||||||
|
colors_for_double = [c for c in BASE_COLORS if game.bank[c] >= 4]
|
||||||
|
if colors_for_double:
|
||||||
|
return TakeDouble(color=random.choice(colors_for_double))
|
||||||
|
|
||||||
|
colors_for_diff = [c for c in BASE_COLORS if game.bank[c] > 0]
|
||||||
|
random.shuffle(colors_for_diff)
|
||||||
|
return TakeDifferent(colors=colors_for_diff[:3])
|
||||||
|
|
||||||
|
def choose_discard(
|
||||||
|
self,
|
||||||
|
game: GameState, # noqa: ARG002
|
||||||
|
player: PlayerState,
|
||||||
|
excess: int,
|
||||||
|
) -> dict[GemColor, int]:
|
||||||
|
"""Choose how many tokens to discard."""
|
||||||
|
return auto_discard_tokens(player, excess)
|
||||||
|
|
||||||
|
|
||||||
|
def check_cards_in_tier(row: list[Card], player: PlayerState) -> list[int]:
|
||||||
|
"""Check if player can afford card, using discounts + gold."""
|
||||||
|
return [index for index, card in enumerate(row) if can_bot_afford(player, card)]
|
||||||
|
|
||||||
|
|
||||||
|
class PersonalizedBot(Strategy):
|
||||||
|
"""PersonalizedBot."""
|
||||||
|
|
||||||
|
"""Dumb bot that follows rules but doesn't think."""
|
||||||
|
|
||||||
|
def __init__(self, name: str) -> None:
|
||||||
|
"""Initialize the bot."""
|
||||||
|
super().__init__(name=name)
|
||||||
|
|
||||||
|
def choose_action(self, game: GameState, player: PlayerState) -> Action | None:
|
||||||
|
"""Choose an action for the current player."""
|
||||||
|
for tier in (1, 2, 3):
|
||||||
|
row = game.table_by_tier[tier]
|
||||||
|
if affordable := check_cards_in_tier(row, player):
|
||||||
|
index = random.choice(affordable)
|
||||||
|
return BuyCard(tier=tier, index=index)
|
||||||
|
|
||||||
|
colors_for_diff = [c for c in BASE_COLORS if game.bank[c] > 0]
|
||||||
|
random.shuffle(colors_for_diff)
|
||||||
|
return TakeDifferent(colors=colors_for_diff[:3])
|
||||||
|
|
||||||
|
def choose_discard(
|
||||||
|
self,
|
||||||
|
game: GameState, # noqa: ARG002
|
||||||
|
player: PlayerState,
|
||||||
|
excess: int,
|
||||||
|
) -> dict[GemColor, int]:
|
||||||
|
"""Choose how many tokens to discard."""
|
||||||
|
return auto_discard_tokens(player, excess)
|
||||||
|
|
||||||
|
|
||||||
|
class PersonalizedBot2(Strategy):
|
||||||
|
"""PersonalizedBot2."""
|
||||||
|
|
||||||
|
"""Dumb bot that follows rules but doesn't think."""
|
||||||
|
|
||||||
|
def __init__(self, name: str) -> None:
|
||||||
|
"""Initialize the bot."""
|
||||||
|
super().__init__(name=name)
|
||||||
|
|
||||||
|
def choose_action(self, game: GameState, player: PlayerState) -> Action | None:
|
||||||
|
"""Choose an action for the current player."""
|
||||||
|
tiers = (1, 2, 3)
|
||||||
|
for tier in tiers:
|
||||||
|
row = game.table_by_tier[tier]
|
||||||
|
if affordable := check_cards_in_tier(row, player):
|
||||||
|
index = random.choice(affordable)
|
||||||
|
return BuyCard(tier=tier, index=index)
|
||||||
|
|
||||||
|
if affordable := check_cards_in_tier(player.reserved, player):
|
||||||
|
index = random.choice(affordable)
|
||||||
|
return BuyCardReserved(index=index)
|
||||||
|
|
||||||
|
colors_for_diff = [c for c in BASE_COLORS if game.bank[c] > 0]
|
||||||
|
if len(colors_for_diff) >= 3:
|
||||||
|
random.shuffle(colors_for_diff)
|
||||||
|
return TakeDifferent(colors=colors_for_diff[:3])
|
||||||
|
|
||||||
|
for tier in tiers:
|
||||||
|
len_deck = len(game.decks_by_tier[tier])
|
||||||
|
if len_deck:
|
||||||
|
return ReserveCard(tier=tier, index=None, from_deck=True)
|
||||||
|
|
||||||
|
return TakeDifferent(colors=colors_for_diff[:3])
|
||||||
|
|
||||||
|
def choose_discard(
|
||||||
|
self,
|
||||||
|
game: GameState, # noqa: ARG002
|
||||||
|
player: PlayerState,
|
||||||
|
excess: int,
|
||||||
|
) -> dict[GemColor, int]:
|
||||||
|
"""Choose how many tokens to discard."""
|
||||||
|
return auto_discard_tokens(player, excess)
|
||||||
|
|
||||||
|
|
||||||
|
def buy_card_reserved(player: PlayerState) -> Action | None:
|
||||||
|
"""Buy a card reserved."""
|
||||||
|
if affordable := check_cards_in_tier(player.reserved, player):
|
||||||
|
index = random.choice(affordable)
|
||||||
|
return BuyCardReserved(index=index)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def buy_card(game: GameState, player: PlayerState) -> Action | None:
|
||||||
|
"""Buy a card."""
|
||||||
|
for tier in (1, 2, 3):
|
||||||
|
row = game.table_by_tier[tier]
|
||||||
|
if affordable := check_cards_in_tier(row, player):
|
||||||
|
index = random.choice(affordable)
|
||||||
|
return BuyCard(tier=tier, index=index)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def take_tokens(game: GameState) -> Action | None:
|
||||||
|
"""Take tokens."""
|
||||||
|
colors_for_diff = [color for color in BASE_COLORS if game.bank[color] > 0]
|
||||||
|
if len(colors_for_diff) >= 3:
|
||||||
|
random.shuffle(colors_for_diff)
|
||||||
|
return TakeDifferent(colors=colors_for_diff[: game.config.max_token_take])
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class PersonalizedBot3(Strategy):
|
||||||
|
"""PersonalizedBot3."""
|
||||||
|
|
||||||
|
"""Dumb bot that follows rules but doesn't think."""
|
||||||
|
|
||||||
|
def __init__(self, name: str) -> None:
|
||||||
|
"""Initialize the bot."""
|
||||||
|
super().__init__(name=name)
|
||||||
|
|
||||||
|
def choose_action(self, game: GameState, player: PlayerState) -> Action | None:
|
||||||
|
"""Choose an action for the current player."""
|
||||||
|
print(len(get_legal_actions(game, player)))
|
||||||
|
print(get_legal_actions(game, player))
|
||||||
|
if action := buy_card_reserved(player):
|
||||||
|
return action
|
||||||
|
if action := buy_card(game, player):
|
||||||
|
return action
|
||||||
|
|
||||||
|
colors_for_diff = [color for color in BASE_COLORS if game.bank[color] > 0]
|
||||||
|
if len(colors_for_diff) >= 3:
|
||||||
|
random.shuffle(colors_for_diff)
|
||||||
|
return TakeDifferent(colors=colors_for_diff[:3])
|
||||||
|
|
||||||
|
for tier in (1, 2, 3):
|
||||||
|
len_deck = len(game.decks_by_tier[tier])
|
||||||
|
if len_deck:
|
||||||
|
return ReserveCard(tier=tier, index=None, from_deck=True)
|
||||||
|
|
||||||
|
return TakeDifferent(colors=colors_for_diff[:3])
|
||||||
|
|
||||||
|
def choose_discard(
|
||||||
|
self,
|
||||||
|
game: GameState, # noqa: ARG002
|
||||||
|
player: PlayerState,
|
||||||
|
excess: int,
|
||||||
|
) -> dict[GemColor, int]:
|
||||||
|
"""Choose how many tokens to discard."""
|
||||||
|
return auto_discard_tokens(player, excess)
|
||||||
|
|
||||||
|
|
||||||
|
def estimate_value_of_card(game: GameState, player: PlayerState, color: GemColor) -> int:
|
||||||
|
"""Estimate value of a color in the player's bank."""
|
||||||
|
return game.bank[color] - player.discounts.get(color, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def estimate_value_of_token(game: GameState, player: PlayerState, color: GemColor) -> int:
|
||||||
|
"""Estimate value of a color in the player's bank."""
|
||||||
|
return game.bank[color] - player.discounts.get(color, 0)
|
||||||
|
|
||||||
|
|
||||||
|
class PersonalizedBot4(Strategy):
|
||||||
|
"""PersonalizedBot4."""
|
||||||
|
|
||||||
|
def __init__(self, name: str) -> None:
|
||||||
|
"""Initialize the bot."""
|
||||||
|
super().__init__(name=name)
|
||||||
|
|
||||||
|
def filter_actions(self, actions: list[Action]) -> list[Action]:
|
||||||
|
"""Filter actions to only take different."""
|
||||||
|
return [
|
||||||
|
action
|
||||||
|
for action in actions
|
||||||
|
if (isinstance(action, TakeDifferent) and len(action.colors) == 3) or not isinstance(action, TakeDifferent)
|
||||||
|
]
|
||||||
|
|
||||||
|
def choose_action(self, game: GameState, player: PlayerState) -> Action | None:
|
||||||
|
"""Choose an action for the current player."""
|
||||||
|
legal_actions = get_legal_actions(game, player)
|
||||||
|
print(len(legal_actions))
|
||||||
|
|
||||||
|
good_actions = self.filter_actions(legal_actions)
|
||||||
|
print(len(good_actions))
|
||||||
|
|
||||||
|
print(good_actions)
|
||||||
|
|
||||||
|
print(len(get_legal_actions(game, player)))
|
||||||
|
if action := buy_card_reserved(player):
|
||||||
|
return action
|
||||||
|
if action := buy_card(game, player):
|
||||||
|
return action
|
||||||
|
|
||||||
|
colors_for_diff = [color for color in BASE_COLORS if game.bank[color] > 0]
|
||||||
|
if len(colors_for_diff) >= 3:
|
||||||
|
random.shuffle(colors_for_diff)
|
||||||
|
return TakeDifferent(colors=colors_for_diff[:3])
|
||||||
|
|
||||||
|
for tier in (1, 2, 3):
|
||||||
|
len_deck = len(game.decks_by_tier[tier])
|
||||||
|
if len_deck:
|
||||||
|
return ReserveCard(tier=tier, index=None, from_deck=True)
|
||||||
|
|
||||||
|
return TakeDifferent(colors=colors_for_diff[:3])
|
||||||
|
|
||||||
|
def choose_discard(
|
||||||
|
self,
|
||||||
|
game: GameState, # noqa: ARG002
|
||||||
|
player: PlayerState,
|
||||||
|
excess: int,
|
||||||
|
) -> dict[GemColor, int]:
|
||||||
|
"""Choose how many tokens to discard."""
|
||||||
|
return auto_discard_tokens(player, excess)
|
||||||
@@ -0,0 +1,724 @@
|
|||||||
|
"""Splendor game."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from textual.app import App, ComposeResult
|
||||||
|
from textual.containers import Horizontal, Vertical
|
||||||
|
from textual.widget import Widget
|
||||||
|
from textual.widgets import Footer, Header, Input, Static
|
||||||
|
|
||||||
|
from .base import (
|
||||||
|
BASE_COLORS,
|
||||||
|
GEM_COLORS,
|
||||||
|
Action,
|
||||||
|
BuyCard,
|
||||||
|
BuyCardReserved,
|
||||||
|
Card,
|
||||||
|
GameState,
|
||||||
|
GemColor,
|
||||||
|
Noble,
|
||||||
|
PlayerState,
|
||||||
|
ReserveCard,
|
||||||
|
Strategy,
|
||||||
|
TakeDifferent,
|
||||||
|
TakeDouble,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Mapping
|
||||||
|
|
||||||
|
# Abbreviations used when rendering costs
|
||||||
|
COST_ABBR: dict[GemColor, str] = {
|
||||||
|
"white": "W",
|
||||||
|
"blue": "B",
|
||||||
|
"green": "G",
|
||||||
|
"red": "R",
|
||||||
|
"black": "K",
|
||||||
|
"gold": "O",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Abbreviations players can type on the command line
|
||||||
|
COLOR_ABBR_TO_FULL: dict[str, GemColor] = {
|
||||||
|
"w": "white",
|
||||||
|
"b": "blue",
|
||||||
|
"g": "green",
|
||||||
|
"r": "red",
|
||||||
|
"k": "black",
|
||||||
|
"o": "gold",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def parse_color_token(raw: str) -> GemColor:
|
||||||
|
"""Convert user input into a GemColor.
|
||||||
|
|
||||||
|
Supports:
|
||||||
|
- full names: white, blue, green, red, black, gold
|
||||||
|
- abbreviations: w, b, g, r, k, o
|
||||||
|
"""
|
||||||
|
key = raw.lower()
|
||||||
|
|
||||||
|
# full color names first
|
||||||
|
if key in BASE_COLORS:
|
||||||
|
return key # type: ignore[return-value]
|
||||||
|
|
||||||
|
# abbreviations
|
||||||
|
if key in COLOR_ABBR_TO_FULL:
|
||||||
|
return COLOR_ABBR_TO_FULL[key]
|
||||||
|
|
||||||
|
error = f"Unknown color: {raw}"
|
||||||
|
raise ValueError(error)
|
||||||
|
|
||||||
|
|
||||||
|
def format_cost(cost: Mapping[GemColor, int]) -> str:
|
||||||
|
"""Format a cost/requirements dict as colored tokens like 'B:2, R:1'.
|
||||||
|
|
||||||
|
Uses `color_token` internally so colors are guaranteed to match your bank.
|
||||||
|
"""
|
||||||
|
parts: list[str] = []
|
||||||
|
for color in GEM_COLORS:
|
||||||
|
n = cost.get(color, 0)
|
||||||
|
if not n:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# color_token gives us e.g. "[blue]blue: 3[/]"
|
||||||
|
token = color_token(color, n)
|
||||||
|
|
||||||
|
# Turn the leading color name into the abbreviation (blue: 3 → B:3)
|
||||||
|
# We only replace the first occurrence.
|
||||||
|
full = f"{color}:"
|
||||||
|
abbr = f"{COST_ABBR[color]}:"
|
||||||
|
token = token.replace(full, abbr, 1)
|
||||||
|
|
||||||
|
parts.append(token)
|
||||||
|
|
||||||
|
return ", ".join(parts) if parts else "-"
|
||||||
|
|
||||||
|
|
||||||
|
def format_card(card: Card) -> str:
|
||||||
|
"""Readable card line using dataclass fields instead of __str__."""
|
||||||
|
color_abbr = COST_ABBR[card.color]
|
||||||
|
header = f"T{card.tier} {color_abbr} P{card.points}"
|
||||||
|
cost_str = format_cost(card.cost)
|
||||||
|
return f"{header} ({cost_str})"
|
||||||
|
|
||||||
|
|
||||||
|
def format_noble(noble: Noble) -> str:
|
||||||
|
"""Readable noble line using dataclass fields instead of __str__."""
|
||||||
|
cost_str = format_cost(noble.requirements)
|
||||||
|
return f"{noble.name} +{noble.points} ({cost_str})"
|
||||||
|
|
||||||
|
|
||||||
|
def format_tokens(tokens: Mapping[GemColor, int]) -> str:
|
||||||
|
"""Colored 'color: n' list for a token dict."""
|
||||||
|
return " ".join(color_token(c, tokens.get(c, 0)) for c in GEM_COLORS)
|
||||||
|
|
||||||
|
|
||||||
|
def format_discounts(discounts: Mapping[GemColor, int]) -> str:
|
||||||
|
"""Colored discounts, skipping zeros."""
|
||||||
|
parts: list[str] = []
|
||||||
|
for c in GEM_COLORS:
|
||||||
|
n = discounts.get(c, 0)
|
||||||
|
if not n:
|
||||||
|
continue
|
||||||
|
abbr = COST_ABBR[c]
|
||||||
|
fg, bg = COLOR_STYLE[c]
|
||||||
|
parts.append(f"[{fg} on {bg}]{abbr}:{n}[/{fg} on {bg}]")
|
||||||
|
return ", ".join(parts) if parts else "-"
|
||||||
|
|
||||||
|
|
||||||
|
COLOR_STYLE: dict[GemColor, tuple[str, str]] = {
|
||||||
|
"white": ("black", "white"), # fg, bg
|
||||||
|
"blue": ("bright_white", "blue"),
|
||||||
|
"green": ("bright_white", "sea_green4"),
|
||||||
|
"red": ("white", "red3"),
|
||||||
|
"black": ("white", "grey0"),
|
||||||
|
"gold": ("black", "yellow3"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def fmt_gem(color: GemColor) -> str:
|
||||||
|
"""Render gem name with fg/bg matching real token color."""
|
||||||
|
fg, bg = COLOR_STYLE[color]
|
||||||
|
return f"[{fg} on {bg}] {color} [/{fg} on {bg}]"
|
||||||
|
|
||||||
|
|
||||||
|
def fmt_number(value: int) -> str:
|
||||||
|
"""Return a Rich-markup colored 'value' string."""
|
||||||
|
return f"[bold cyan]{value}[/]"
|
||||||
|
|
||||||
|
|
||||||
|
def color_token(name: GemColor, amount: int) -> str:
|
||||||
|
"""Return a Rich-markup colored 'name: n' string."""
|
||||||
|
# Map Splendor colors -> terminal colors
|
||||||
|
color_map: Mapping[GemColor, str] = {
|
||||||
|
"white": "white",
|
||||||
|
"blue": "blue",
|
||||||
|
"green": "green",
|
||||||
|
"red": "red",
|
||||||
|
"black": "grey70", # 'black' is unreadable on dark backgrounds
|
||||||
|
"gold": "yellow",
|
||||||
|
}
|
||||||
|
style = color_map.get(name, "white")
|
||||||
|
return f"[{style}]{name}: {amount}[/]"
|
||||||
|
|
||||||
|
|
||||||
|
class Board(Widget):
|
||||||
|
"""Big board widget with the layout you sketched."""
|
||||||
|
|
||||||
|
def __init__(self, game: GameState, me: PlayerState, **kwargs: Any) -> None: # noqa: ANN401
|
||||||
|
"""Initialize the board widget."""
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.game = game
|
||||||
|
self.me = me
|
||||||
|
|
||||||
|
def compose(self) -> ComposeResult:
|
||||||
|
"""Compose the board widget."""
|
||||||
|
# Structure:
|
||||||
|
# ┌ bank row
|
||||||
|
# ├ middle row (tiers | nobles)
|
||||||
|
# └ players row
|
||||||
|
with Vertical(id="board_root"):
|
||||||
|
yield Static(id="bank_box")
|
||||||
|
with Horizontal(id="middle_row"):
|
||||||
|
with Vertical(id="tiers_box"):
|
||||||
|
yield Static(id="tier1_box")
|
||||||
|
yield Static(id="tier2_box")
|
||||||
|
yield Static(id="tier3_box")
|
||||||
|
yield Static(id="nobles_box")
|
||||||
|
yield Static(id="players_box")
|
||||||
|
|
||||||
|
def on_mount(self) -> None:
|
||||||
|
"""Refresh the board content."""
|
||||||
|
self.refresh_content()
|
||||||
|
|
||||||
|
def refresh_content(self) -> None:
|
||||||
|
"""Refresh the board content."""
|
||||||
|
self._render_bank()
|
||||||
|
self._render_tiers()
|
||||||
|
self._render_nobles()
|
||||||
|
self._render_players()
|
||||||
|
|
||||||
|
# --- sections ----------------------------------------------------
|
||||||
|
|
||||||
|
def _render_bank(self) -> None:
|
||||||
|
bank = self.game.bank
|
||||||
|
parts: list[str] = ["[b]Bank:[/b]"]
|
||||||
|
# One line, all tokens colored
|
||||||
|
parts.append(format_tokens(bank))
|
||||||
|
self.query_one("#bank_box", Static).update("\n".join(parts))
|
||||||
|
|
||||||
|
def _render_tiers(self) -> None:
|
||||||
|
for tier in (1, 2, 3):
|
||||||
|
box = self.query_one(f"#tier{tier}_box", Static)
|
||||||
|
cards: list[Card] = self.game.table_by_tier.get(tier, [])
|
||||||
|
lines: list[str] = [f"[b]Tier {tier} cards:[/b]"]
|
||||||
|
if not cards:
|
||||||
|
lines.append(" (none)")
|
||||||
|
else:
|
||||||
|
for idx, card in enumerate(cards):
|
||||||
|
lines.append(f" [{idx}] {format_card(card)}")
|
||||||
|
box.update("\n".join(lines))
|
||||||
|
|
||||||
|
def _render_nobles(self) -> None:
|
||||||
|
nobles_box = self.query_one("#nobles_box", Static)
|
||||||
|
lines: list[str] = ["[b]Nobles[/b]"]
|
||||||
|
if not self.game.available_nobles:
|
||||||
|
lines.append(" (none)")
|
||||||
|
else:
|
||||||
|
lines.extend(" - " + format_noble(noble) for noble in self.game.available_nobles)
|
||||||
|
nobles_box.update("\n".join(lines))
|
||||||
|
|
||||||
|
def _render_players(self) -> None:
|
||||||
|
players_box = self.query_one("#players_box", Static)
|
||||||
|
lines: list[str] = ["[b]Players:[/b]", ""]
|
||||||
|
for player in self.game.players:
|
||||||
|
mark = "*" if player is self.me else " "
|
||||||
|
token_str = format_tokens(player.tokens)
|
||||||
|
discount_str = format_discounts(player.discounts)
|
||||||
|
|
||||||
|
lines.append(
|
||||||
|
f"{mark} {player.name:10} Score={player.score:2d} Discounts={discount_str}",
|
||||||
|
)
|
||||||
|
lines.append(f" Tokens: {token_str}")
|
||||||
|
|
||||||
|
if player.nobles:
|
||||||
|
noble_names = ", ".join(n.name for n in player.nobles)
|
||||||
|
lines.append(f" Nobles: {noble_names}")
|
||||||
|
|
||||||
|
# Optional: show counts of cards / reserved
|
||||||
|
if player.cards:
|
||||||
|
lines.append(f" Cards: {len(player.cards)}")
|
||||||
|
if player.reserved:
|
||||||
|
lines.append(f" Reserved: {len(player.reserved)}")
|
||||||
|
|
||||||
|
lines.append("")
|
||||||
|
players_box.update("\n".join(lines))
|
||||||
|
|
||||||
|
|
||||||
|
class ActionApp(App[None]):
|
||||||
|
"""Textual app that asks for a single action command and returns an Action."""
|
||||||
|
|
||||||
|
CSS = """
|
||||||
|
Screen {
|
||||||
|
/* 3 rows: command zone, board, footer */
|
||||||
|
layout: grid;
|
||||||
|
grid-size: 1 3;
|
||||||
|
grid-rows: auto 1fr auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Top area with input + instructions */
|
||||||
|
#command_zone {
|
||||||
|
grid-columns: 1;
|
||||||
|
grid-rows: 1;
|
||||||
|
padding: 1 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Board sits in the middle row and can grow */
|
||||||
|
#board {
|
||||||
|
grid-columns: 1;
|
||||||
|
grid-rows: 2;
|
||||||
|
padding: 0 1 1 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
Footer {
|
||||||
|
grid-columns: 1;
|
||||||
|
grid-rows: 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
Input {
|
||||||
|
border: round $accent;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* === Board layout === */
|
||||||
|
|
||||||
|
#board_root {
|
||||||
|
/* outer frame around the whole board area */
|
||||||
|
border: heavy white;
|
||||||
|
padding: 0 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Bank row: full width */
|
||||||
|
#bank_box {
|
||||||
|
border: heavy white;
|
||||||
|
padding: 0 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Middle row: tiers (left) + nobles (right) */
|
||||||
|
#middle_row {
|
||||||
|
layout: horizontal;
|
||||||
|
}
|
||||||
|
|
||||||
|
#tiers_box {
|
||||||
|
border: heavy white;
|
||||||
|
padding: 0 1;
|
||||||
|
width: 70%;
|
||||||
|
}
|
||||||
|
|
||||||
|
#tier1_box,
|
||||||
|
#tier2_box,
|
||||||
|
#tier3_box {
|
||||||
|
border-bottom: heavy white;
|
||||||
|
padding: 0 0 1 0;
|
||||||
|
margin-bottom: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
#nobles_box {
|
||||||
|
border: heavy white;
|
||||||
|
padding: 0 1;
|
||||||
|
width: 30%;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Players row: full width at bottom */
|
||||||
|
#players_box {
|
||||||
|
border: heavy white;
|
||||||
|
padding: 0 1;
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, game: GameState, player: PlayerState) -> None:
|
||||||
|
"""Initialize the action app."""
|
||||||
|
super().__init__()
|
||||||
|
self.game = game
|
||||||
|
self.player = player
|
||||||
|
self.result: Action | None = None
|
||||||
|
self.message: str = ""
|
||||||
|
|
||||||
|
def compose(self) -> ComposeResult:
|
||||||
|
"""Compose the action app."""
|
||||||
|
# Row 1: input + Actions text
|
||||||
|
with Vertical(id="command_zone"):
|
||||||
|
yield Input(
|
||||||
|
placeholder="Enter command, e.g. '1 white blue red' or '1 w b r' or 'q'",
|
||||||
|
id="input_line",
|
||||||
|
)
|
||||||
|
yield Static("", id="prompt")
|
||||||
|
|
||||||
|
# Row 2: board
|
||||||
|
yield Board(self.game, self.player, id="board")
|
||||||
|
|
||||||
|
# Row 3: footer
|
||||||
|
yield Footer()
|
||||||
|
|
||||||
|
def on_mount(self) -> None:
|
||||||
|
"""Mount the action app."""
|
||||||
|
self._update_prompt()
|
||||||
|
self.query_one(Input).focus()
|
||||||
|
|
||||||
|
def _update_prompt(self) -> None:
|
||||||
|
lines: list[str] = []
|
||||||
|
lines.append("[bold underline]Actions:[/]")
|
||||||
|
lines.append(
|
||||||
|
" [bold green]1[/] <colors...> - Take up to 3 different gem colors "
|
||||||
|
"(e.g. [cyan]1 white blue red[/] or [cyan]1 w b r[/])",
|
||||||
|
)
|
||||||
|
lines.append(
|
||||||
|
f" [bold green]2[/] <color> - Take 2 of the same color (needs {fmt_number(4)} in bank, "
|
||||||
|
"e.g. [cyan]2 blue[/] or [cyan]2 b[/])",
|
||||||
|
)
|
||||||
|
lines.append(
|
||||||
|
" [bold green]3[/] <tier> <idx> - Buy a face-up card (e.g. [cyan]3 1 0[/] for tier 1, index 0)",
|
||||||
|
)
|
||||||
|
lines.append(" [bold green]4[/] <idx> - Buy a reserved card")
|
||||||
|
lines.append(" [bold green]5[/] <tier> <idx> - Reserve a face-up card")
|
||||||
|
lines.append(" [bold green]6[/] <tier> - Reserve top card of a deck")
|
||||||
|
lines.append(" [bold red]q[/] - Quit game")
|
||||||
|
if self.message:
|
||||||
|
lines.append("")
|
||||||
|
lines.append(f"[bold red]Message:[/] {self.message}")
|
||||||
|
self.query_one("#prompt", Static).update("\n".join(lines))
|
||||||
|
|
||||||
|
def _cmd_1(self, parts: list[str]) -> str | None:
|
||||||
|
"""Take up to 3 different gem colors: 1 white blue red OR 1 w b r."""
|
||||||
|
color_names = parts[1:]
|
||||||
|
if not color_names:
|
||||||
|
return "Need at least one color (full name or abbreviation)."
|
||||||
|
colors: list[GemColor] = []
|
||||||
|
for name in color_names:
|
||||||
|
color = parse_color_token(name)
|
||||||
|
if self.game.bank[color] <= 0:
|
||||||
|
return f"No tokens left for color: {color}"
|
||||||
|
colors.append(color)
|
||||||
|
self.result = TakeDifferent(colors=colors[:3])
|
||||||
|
self.exit()
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _cmd_2(self, parts: list[str]) -> str | None:
|
||||||
|
"""Take two of the same color."""
|
||||||
|
if len(parts) < 2:
|
||||||
|
return "Usage: 2 <color>"
|
||||||
|
color = parse_color_token(parts[1])
|
||||||
|
if self.game.bank[color] < self.game.config.minimum_tokens_to_buy_2:
|
||||||
|
return "Bank must have at least 4 of that color."
|
||||||
|
self.result = TakeDouble(color=color)
|
||||||
|
self.exit()
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _cmd_3(self, parts: list[str]) -> str | None:
|
||||||
|
"""Buy face-up card."""
|
||||||
|
if len(parts) < 3:
|
||||||
|
return "Usage: 3 <tier> <index>"
|
||||||
|
tier = int(parts[1])
|
||||||
|
idx = int(parts[2])
|
||||||
|
self.result = BuyCard(tier=tier, index=idx)
|
||||||
|
self.exit()
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _cmd_4(self, parts: list[str]) -> str | None:
|
||||||
|
"""Buy reserved card."""
|
||||||
|
if len(parts) < 2:
|
||||||
|
return "Usage: 4 <reserved_index>"
|
||||||
|
idx = int(parts[1])
|
||||||
|
if not (0 <= idx < len(self.player.reserved)):
|
||||||
|
return "Reserved index out of range."
|
||||||
|
self.result = BuyCardReserved(tier=0, index=idx)
|
||||||
|
self.exit()
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _cmd_5(self, parts: list[str]) -> str | None:
|
||||||
|
"""Reserve face-up card."""
|
||||||
|
if len(parts) < 3:
|
||||||
|
return "Usage: 5 <tier> <index>"
|
||||||
|
tier = int(parts[1])
|
||||||
|
idx = int(parts[2])
|
||||||
|
self.result = ReserveCard(tier=tier, index=idx, from_deck=False)
|
||||||
|
self.exit()
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _cmd_6(self, parts: list[str]) -> str | None:
|
||||||
|
"""Reserve top of deck."""
|
||||||
|
if len(parts) < 2:
|
||||||
|
return "Usage: 6 <tier>"
|
||||||
|
tier = int(parts[1])
|
||||||
|
self.result = ReserveCard(tier=tier, index=None, from_deck=True)
|
||||||
|
self.exit()
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _unknown_cmd(self, _parts: list[str]) -> str:
|
||||||
|
return "Unknown command."
|
||||||
|
|
||||||
|
def on_input_submitted(self, event: Input.Submitted) -> None:
|
||||||
|
"""Handle user input."""
|
||||||
|
text = (event.value or "").strip()
|
||||||
|
event.input.value = ""
|
||||||
|
if not text:
|
||||||
|
return
|
||||||
|
if text.lower() in {"q", "quit", "0"}:
|
||||||
|
self.result = None
|
||||||
|
self.exit()
|
||||||
|
return
|
||||||
|
|
||||||
|
parts = text.split()
|
||||||
|
|
||||||
|
cmds = {
|
||||||
|
"1": self._cmd_1,
|
||||||
|
"2": self._cmd_2,
|
||||||
|
"3": self._cmd_3,
|
||||||
|
"4": self._cmd_4,
|
||||||
|
"5": self._cmd_5,
|
||||||
|
"6": self._cmd_6,
|
||||||
|
}
|
||||||
|
cmd = parts[0]
|
||||||
|
|
||||||
|
error = cmds.get(cmd, self._unknown_cmd)(parts)
|
||||||
|
|
||||||
|
if error:
|
||||||
|
self.message = error
|
||||||
|
self._update_prompt()
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
class DiscardApp(App[None]):
|
||||||
|
"""Textual app to choose discards when over token limit."""
|
||||||
|
|
||||||
|
CSS = """
|
||||||
|
Screen {
|
||||||
|
layout: vertical;
|
||||||
|
}
|
||||||
|
|
||||||
|
#command_zone {
|
||||||
|
padding: 1 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
#board {
|
||||||
|
padding: 0 1 1 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
Input {
|
||||||
|
border: round $accent;
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, game: GameState, player: PlayerState) -> None:
|
||||||
|
"""Initialize the discard app."""
|
||||||
|
super().__init__()
|
||||||
|
self.game = game
|
||||||
|
self.player = player
|
||||||
|
self.discards: dict[GemColor, int] = dict.fromkeys(GEM_COLORS, 0)
|
||||||
|
self.message: str = ""
|
||||||
|
|
||||||
|
def compose(self) -> ComposeResult: # type: ignore[override]
|
||||||
|
"""Compose the discard app."""
|
||||||
|
yield Header(show_clock=False)
|
||||||
|
|
||||||
|
with Vertical(id="command_zone"):
|
||||||
|
yield Input(
|
||||||
|
placeholder="Enter color to discard, e.g. 'blue' or 'b'",
|
||||||
|
id="input_line",
|
||||||
|
)
|
||||||
|
yield Static("", id="prompt")
|
||||||
|
|
||||||
|
# Board directly under the command zone
|
||||||
|
yield Board(self.game, self.player, id="board")
|
||||||
|
|
||||||
|
yield Footer()
|
||||||
|
|
||||||
|
def on_mount(self) -> None: # type: ignore[override]
|
||||||
|
"""Mount the discard app."""
|
||||||
|
self._update_prompt()
|
||||||
|
self.query_one(Input).focus()
|
||||||
|
|
||||||
|
def _remaining_to_discard(self) -> int:
|
||||||
|
return self.player.total_tokens() - sum(self.discards.values()) - self.game.config.token_limit
|
||||||
|
|
||||||
|
def _update_prompt(self) -> None:
|
||||||
|
remaining = max(self._remaining_to_discard(), 0)
|
||||||
|
lines: list[str] = []
|
||||||
|
lines.append(
|
||||||
|
"You must discard "
|
||||||
|
f"{fmt_number(remaining)} token(s) "
|
||||||
|
f"to get down to {fmt_number(self.game.config.token_limit)}.",
|
||||||
|
)
|
||||||
|
disc_str = ", ".join(f"{fmt_gem(c)}={fmt_number(self.discards[c])}" for c in GEM_COLORS)
|
||||||
|
lines.append(f"Current planned discards: {{ {disc_str} }}")
|
||||||
|
lines.append(
|
||||||
|
"Type a color name or abbreviation (e.g. 'blue' or 'b') to discard one token.",
|
||||||
|
)
|
||||||
|
if self.message:
|
||||||
|
lines.append("")
|
||||||
|
lines.append(f"[bold red]Message:[/] {self.message}")
|
||||||
|
self.query_one("#prompt", Static).update("\n".join(lines))
|
||||||
|
|
||||||
|
def on_input_submitted(self, event: Input.Submitted) -> None: # type: ignore[override]
|
||||||
|
"""Handle user input."""
|
||||||
|
raw = (event.value or "").strip()
|
||||||
|
event.input.value = ""
|
||||||
|
if not raw:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
color = parse_color_token(raw)
|
||||||
|
except ValueError:
|
||||||
|
self.message = f"Unknown color: {raw}"
|
||||||
|
self._update_prompt()
|
||||||
|
return
|
||||||
|
|
||||||
|
available = self.player.tokens[color] - self.discards[color]
|
||||||
|
if available <= 0:
|
||||||
|
self.message = f"No more {color} tokens available to discard."
|
||||||
|
self._update_prompt()
|
||||||
|
return
|
||||||
|
|
||||||
|
self.discards[color] += 1
|
||||||
|
if self._remaining_to_discard() <= 0:
|
||||||
|
self.exit()
|
||||||
|
return
|
||||||
|
|
||||||
|
self.message = ""
|
||||||
|
self._update_prompt()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Noble choice app
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class NobleChoiceApp(App[None]):
|
||||||
|
"""Textual app to choose one noble."""
|
||||||
|
|
||||||
|
CSS = """
|
||||||
|
Screen {
|
||||||
|
layout: vertical;
|
||||||
|
}
|
||||||
|
|
||||||
|
#command_zone {
|
||||||
|
padding: 1 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
#board {
|
||||||
|
padding: 0 1 1 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
Input {
|
||||||
|
border: round $accent;
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
game: GameState,
|
||||||
|
player: PlayerState,
|
||||||
|
nobles: list[Noble],
|
||||||
|
) -> None:
|
||||||
|
"""Initialize the noble choice app."""
|
||||||
|
super().__init__()
|
||||||
|
self.game = game
|
||||||
|
self.player = player
|
||||||
|
self.nobles = nobles
|
||||||
|
self.result: Noble | None = None
|
||||||
|
self.message: str = ""
|
||||||
|
|
||||||
|
def compose(self) -> ComposeResult: # type: ignore[override]
|
||||||
|
"""Compose the noble choice app."""
|
||||||
|
yield Header(show_clock=False)
|
||||||
|
|
||||||
|
with Vertical(id="command_zone"):
|
||||||
|
yield Input(
|
||||||
|
placeholder="Enter noble index, e.g. '0'",
|
||||||
|
id="input_line",
|
||||||
|
)
|
||||||
|
yield Static("", id="prompt")
|
||||||
|
|
||||||
|
# Board directly under the command zone
|
||||||
|
yield Board(self.game, self.player, id="board")
|
||||||
|
|
||||||
|
yield Footer()
|
||||||
|
|
||||||
|
def on_mount(self) -> None: # type: ignore[override]
|
||||||
|
"""Mount the noble choice app."""
|
||||||
|
self._update_prompt()
|
||||||
|
self.query_one(Input).focus()
|
||||||
|
|
||||||
|
def _update_prompt(self) -> None:
|
||||||
|
lines: list[str] = []
|
||||||
|
lines.append("[bold underline]You qualify for nobles:[/]")
|
||||||
|
for i, noble in enumerate(self.nobles):
|
||||||
|
lines.append(f" [bright_cyan]{i})[/] {format_noble(noble)}")
|
||||||
|
lines.append("Enter the index of the noble you want.")
|
||||||
|
if self.message:
|
||||||
|
lines.append("")
|
||||||
|
lines.append(f"[bold red]Message:[/] {self.message}")
|
||||||
|
self.query_one("#prompt", Static).update("\n".join(lines))
|
||||||
|
|
||||||
|
def on_input_submitted(self, event: Input.Submitted) -> None: # type: ignore[override]
|
||||||
|
"""Handle user input."""
|
||||||
|
raw = (event.value or "").strip()
|
||||||
|
event.input.value = ""
|
||||||
|
if not raw:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
idx = int(raw)
|
||||||
|
except ValueError:
|
||||||
|
self.message = "Please enter a valid integer index."
|
||||||
|
self._update_prompt()
|
||||||
|
return
|
||||||
|
if not (0 <= idx < len(self.nobles)):
|
||||||
|
self.message = "Index out of range."
|
||||||
|
self._update_prompt()
|
||||||
|
return
|
||||||
|
self.result = self.nobles[idx]
|
||||||
|
self.exit()
|
||||||
|
|
||||||
|
|
||||||
|
class TuiHuman(Strategy):
|
||||||
|
"""Textual-based human player Strategy with colorful board."""
|
||||||
|
|
||||||
|
def choose_action(
|
||||||
|
self,
|
||||||
|
game: GameState,
|
||||||
|
player: PlayerState,
|
||||||
|
) -> Action | None:
|
||||||
|
"""Choose an action for the player."""
|
||||||
|
if not sys.stdout.isatty():
|
||||||
|
return None
|
||||||
|
app = ActionApp(game, player)
|
||||||
|
app.run()
|
||||||
|
return app.result
|
||||||
|
|
||||||
|
def choose_discard(
|
||||||
|
self,
|
||||||
|
game: GameState,
|
||||||
|
player: PlayerState,
|
||||||
|
excess: int, # noqa: ARG002
|
||||||
|
) -> dict[GemColor, int]:
|
||||||
|
"""Choose tokens to discard."""
|
||||||
|
if not sys.stdout.isatty():
|
||||||
|
return dict.fromkeys(GEM_COLORS, 0)
|
||||||
|
app = DiscardApp(game, player)
|
||||||
|
app.run()
|
||||||
|
return app.discards
|
||||||
|
|
||||||
|
def choose_noble(
|
||||||
|
self,
|
||||||
|
game: GameState,
|
||||||
|
player: PlayerState,
|
||||||
|
nobles: list[Noble],
|
||||||
|
) -> Noble:
|
||||||
|
"""Choose a noble for the player."""
|
||||||
|
if not sys.stdout.isatty():
|
||||||
|
return nobles[0]
|
||||||
|
app = NobleChoiceApp(game, player, nobles)
|
||||||
|
app.run()
|
||||||
|
return app.result
|
||||||
@@ -0,0 +1,19 @@
|
|||||||
|
"""Main entry point for Splendor game."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from .base import new_game, run_game
|
||||||
|
from .bot import RandomBot
|
||||||
|
from .human import TuiHuman
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
"""Main entry point."""
|
||||||
|
human = TuiHuman()
|
||||||
|
bot = RandomBot()
|
||||||
|
game_state = new_game(["You", "Bot A"])
|
||||||
|
run_game(game_state, [human, bot])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -0,0 +1,111 @@
|
|||||||
|
"""Public state for RL/search."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from .base import (
|
||||||
|
BASE_COLORS,
|
||||||
|
BASE_INDEX,
|
||||||
|
GEM_ORDER,
|
||||||
|
Card,
|
||||||
|
GameState,
|
||||||
|
Noble,
|
||||||
|
PlayerState,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ObsCard:
|
||||||
|
"""Numeric-ish card view for RL/search."""
|
||||||
|
|
||||||
|
tier: int
|
||||||
|
points: int
|
||||||
|
color_index: int
|
||||||
|
cost: list[int]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ObsNoble:
|
||||||
|
"""Numeric-ish noble view for RL/search."""
|
||||||
|
|
||||||
|
points: int
|
||||||
|
requirements: list[int]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ObsPlayer:
|
||||||
|
"""Numeric-ish player view for RL/search."""
|
||||||
|
|
||||||
|
tokens: list[int]
|
||||||
|
discounts: list[int]
|
||||||
|
score: int
|
||||||
|
cards: list[ObsCard]
|
||||||
|
reserved: list[ObsCard]
|
||||||
|
nobles: list[ObsNoble]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Observation:
|
||||||
|
"""Full public state for RL/search."""
|
||||||
|
|
||||||
|
current_player: int
|
||||||
|
bank: list[int]
|
||||||
|
players: list[ObsPlayer]
|
||||||
|
table_by_tier: dict[int, list[ObsCard]]
|
||||||
|
decks_remaining: dict[int, int]
|
||||||
|
available_nobles: list[ObsNoble]
|
||||||
|
|
||||||
|
|
||||||
|
def _encode_card(card: Card) -> ObsCard:
|
||||||
|
color_index = BASE_INDEX.get(card.color, -1)
|
||||||
|
cost_vec = [card.cost.get(c, 0) for c in BASE_COLORS]
|
||||||
|
return ObsCard(
|
||||||
|
tier=card.tier,
|
||||||
|
points=card.points,
|
||||||
|
color_index=color_index,
|
||||||
|
cost=cost_vec,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _encode_noble(noble: Noble) -> ObsNoble:
|
||||||
|
req_vec = [noble.requirements.get(c, 0) for c in BASE_COLORS]
|
||||||
|
return ObsNoble(
|
||||||
|
points=noble.points,
|
||||||
|
requirements=req_vec,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _encode_player(player: PlayerState) -> ObsPlayer:
|
||||||
|
tokens_vec = [player.tokens[c] for c in GEM_ORDER]
|
||||||
|
discounts_vec = [player.discounts[c] for c in GEM_ORDER]
|
||||||
|
cards_enc = [_encode_card(c) for c in player.cards]
|
||||||
|
reserved_enc = [_encode_card(c) for c in player.reserved]
|
||||||
|
nobles_enc = [_encode_noble(n) for n in player.nobles]
|
||||||
|
return ObsPlayer(
|
||||||
|
tokens=tokens_vec,
|
||||||
|
discounts=discounts_vec,
|
||||||
|
score=player.score,
|
||||||
|
cards=cards_enc,
|
||||||
|
reserved=reserved_enc,
|
||||||
|
nobles=nobles_enc,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def to_observation(game: GameState) -> Observation:
|
||||||
|
"""Create a structured observation of the full public state."""
|
||||||
|
bank_vec = [game.bank[c] for c in GEM_ORDER]
|
||||||
|
players_enc = [_encode_player(p) for p in game.players]
|
||||||
|
table_enc: dict[int, list[ObsCard]] = {
|
||||||
|
tier: [_encode_card(c) for c in row] for tier, row in game.table_by_tier.items()
|
||||||
|
}
|
||||||
|
decks_remaining = {tier: len(deck) for tier, deck in game.decks_by_tier.items()}
|
||||||
|
nobles_enc = [_encode_noble(n) for n in game.available_nobles]
|
||||||
|
return Observation(
|
||||||
|
current_player=game.current_player_index,
|
||||||
|
bank=bank_vec,
|
||||||
|
players=players_enc,
|
||||||
|
table_by_tier=table_enc,
|
||||||
|
decks_remaining=decks_remaining,
|
||||||
|
available_nobles=nobles_enc,
|
||||||
|
)
|
||||||
@@ -0,0 +1,36 @@
|
|||||||
|
"""Simulate a step in the game."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import copy
|
||||||
|
|
||||||
|
from .base import Action, GameState, PlayerState, apply_action, check_nobles_for_player
|
||||||
|
from .bot import RandomBot
|
||||||
|
|
||||||
|
|
||||||
|
class SimStrategy(RandomBot):
|
||||||
|
"""Strategy used in simulate_step.
|
||||||
|
|
||||||
|
We never call choose_action here (caller chooses actions),
|
||||||
|
but we reuse discard/noble-selection logic.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def choose_action(self, game: GameState, player: PlayerState) -> Action | None: # noqa: ARG002
|
||||||
|
"""Choose an action for the current player."""
|
||||||
|
msg = "SimStrategy.choose_action should not be used in simulate_step"
|
||||||
|
raise RuntimeError(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def simulate_step(game: GameState, action: Action) -> GameState:
|
||||||
|
"""Return a deep-copied next state after applying action for the current player.
|
||||||
|
|
||||||
|
Useful for tree search / MCTS:
|
||||||
|
|
||||||
|
next_state = simulate_step(state, action)
|
||||||
|
"""
|
||||||
|
next_state = copy.deepcopy(game)
|
||||||
|
sim_strategy = SimStrategy()
|
||||||
|
apply_action(next_state, sim_strategy, action)
|
||||||
|
check_nobles_for_player(next_state, sim_strategy, next_state.current_player)
|
||||||
|
next_state.next_player()
|
||||||
|
return next_state
|
||||||
@@ -0,0 +1,50 @@
|
|||||||
|
"""Simulator for Splendor game."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
from statistics import mean
|
||||||
|
|
||||||
|
from .base import GameConfig, load_cards, load_nobles, new_game, run_game
|
||||||
|
from .bot import PersonalizedBot4, RandomBot
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
"""Main entry point."""
|
||||||
|
turn_limit = 1000
|
||||||
|
good_games = 0
|
||||||
|
games = 1
|
||||||
|
winners: dict[str, list] = defaultdict(list)
|
||||||
|
game_data = Path(__file__).parent / "game_data"
|
||||||
|
|
||||||
|
cards = load_cards(game_data / "cards/default.json")
|
||||||
|
nobles = load_nobles(game_data / "nobles/default.json")
|
||||||
|
|
||||||
|
for _ in range(games):
|
||||||
|
bot_a = RandomBot("bot_a")
|
||||||
|
bot_b = RandomBot("bot_b")
|
||||||
|
bot_c = RandomBot("bot_c")
|
||||||
|
bot_d = PersonalizedBot4("my_bot")
|
||||||
|
config = GameConfig(
|
||||||
|
cards=cards,
|
||||||
|
nobles=nobles,
|
||||||
|
turn_limit=turn_limit,
|
||||||
|
)
|
||||||
|
players = (bot_a, bot_b, bot_c, bot_d)
|
||||||
|
game_state = new_game(players, config)
|
||||||
|
winner, turns = run_game(game_state)
|
||||||
|
if turns < turn_limit:
|
||||||
|
good_games += 1
|
||||||
|
winners[winner.strategy.name].append(turns)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"out of {games} {turn_limit} turn games with {len(players)}"
|
||||||
|
f"random bots there where {good_games} games where a bot won"
|
||||||
|
)
|
||||||
|
for name, turns in winners.items():
|
||||||
|
print(f"{name} won {len(turns)} games in {mean(turns):.2f} turns")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -4,7 +4,7 @@ import logging
|
|||||||
import sys
|
import sys
|
||||||
import tomllib
|
import tomllib
|
||||||
from os import environ
|
from os import environ
|
||||||
from pathlib import Path # noqa: TC003 This is required for the typer CLI
|
from pathlib import Path
|
||||||
from socket import gethostname
|
from socket import gethostname
|
||||||
|
|
||||||
import typer
|
import typer
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
"""Audiobook tools."""
|
|
||||||
@@ -1,471 +0,0 @@
|
|||||||
"""Convert Audible AAX downloads into Audiobookshelf-friendly M4B files."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import re
|
|
||||||
import shutil
|
|
||||||
import subprocess
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
from dataclasses import asdict, dataclass
|
|
||||||
from os import getenv
|
|
||||||
from pathlib import Path # noqa: TC003 This is required for the typer CLI
|
|
||||||
from typing import TYPE_CHECKING, Annotated, Any
|
|
||||||
from uuid import uuid7
|
|
||||||
|
|
||||||
import typer
|
|
||||||
|
|
||||||
from python.common import configure_logger
|
|
||||||
from python.orm.common import get_postgres_engine
|
|
||||||
from python.tools.audiobook.metadata_agent import (
|
|
||||||
AgentConfig,
|
|
||||||
StandardBookMetadata,
|
|
||||||
standard_book_metadata,
|
|
||||||
write_agent_log,
|
|
||||||
)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from sqlalchemy.engine import Engine
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
SENSITIVE_COMMAND_ARGUMENTS = {"-activation_bytes"}
|
|
||||||
BOOK_RANGE_PATTERN = re.compile(r"(?:^|-)books?-(?P<start>[1-9]\d*)-(?P<end>[1-9]\d*)(?:-|$)")
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class ConversionConfig:
|
|
||||||
"""Runtime settings for one conversion command."""
|
|
||||||
|
|
||||||
resolved_output: Path
|
|
||||||
ollama_api_key: str
|
|
||||||
agent_config: AgentConfig
|
|
||||||
engine: Engine
|
|
||||||
activation_bytes: str | None
|
|
||||||
dry_run: bool
|
|
||||||
overwrite: bool
|
|
||||||
work_directory_name: str = ".audible_convert"
|
|
||||||
dry_run_directory_name: str = "dry-run"
|
|
||||||
temp_directory_name: str = "tmp"
|
|
||||||
log_directory_name: str = "logs"
|
|
||||||
review_directory_name: str = "review"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class ConcurrentConversionResult:
|
|
||||||
"""Result from running ffmpeg and metadata resolution together."""
|
|
||||||
|
|
||||||
metadata: StandardBookMetadata | None
|
|
||||||
conversion_error: Exception | None
|
|
||||||
metadata_error: Exception | None
|
|
||||||
|
|
||||||
|
|
||||||
class CommandExecutionError(RuntimeError):
|
|
||||||
"""Command failed without exposing sensitive arguments."""
|
|
||||||
|
|
||||||
def __init__(self, arguments: list[str], returncode: int) -> None:
|
|
||||||
"""Create a redacted command failure."""
|
|
||||||
self.arguments = tuple(arguments)
|
|
||||||
self.returncode = returncode
|
|
||||||
command = " ".join(redact_command_arguments(arguments))
|
|
||||||
super().__init__(f"Command failed with exit code {returncode}: {command}")
|
|
||||||
|
|
||||||
|
|
||||||
def main(
|
|
||||||
input_directory: Annotated[Path, typer.Argument(help="Directory audible-cli downloads AAX files into.")],
|
|
||||||
output_directory: Annotated[Path, typer.Argument(help="Audiobook output directory.")],
|
|
||||||
*,
|
|
||||||
dry_run: Annotated[
|
|
||||||
bool,
|
|
||||||
typer.Option("--dry-run", help="Print planned output files and write marker files without converting."),
|
|
||||||
] = False,
|
|
||||||
overwrite: Annotated[bool, typer.Option("--overwrite", help="Overwrite existing M4B files.")] = False,
|
|
||||||
) -> None:
|
|
||||||
"""Convert AAX files from a download directory into M4B files."""
|
|
||||||
configure_logger()
|
|
||||||
resolved_input = input_directory.resolve(strict=True)
|
|
||||||
resolved_output = output_directory.resolve()
|
|
||||||
if not dry_run:
|
|
||||||
resolved_output.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
ollama_api_key = getenv("OLLAMA_API_KEY")
|
|
||||||
if not ollama_api_key:
|
|
||||||
msg = "OLLAMA_API_KEY is required for audiobook metadata resolution"
|
|
||||||
raise RuntimeError(msg)
|
|
||||||
|
|
||||||
config = ConversionConfig(
|
|
||||||
resolved_output=resolved_output,
|
|
||||||
ollama_api_key=ollama_api_key,
|
|
||||||
agent_config=AgentConfig(),
|
|
||||||
engine=get_postgres_engine(name="RICHIE"),
|
|
||||||
activation_bytes=getenv("AUDIBLE_ACTIVATION_BYTES"),
|
|
||||||
dry_run=dry_run,
|
|
||||||
overwrite=overwrite,
|
|
||||||
)
|
|
||||||
|
|
||||||
aax_files = sorted(resolved_input.glob("*.aax"))
|
|
||||||
if not aax_files:
|
|
||||||
logger.info("No AAX files found in %s", resolved_input)
|
|
||||||
return
|
|
||||||
for aax_file in aax_files:
|
|
||||||
logger.info("Converting %s", aax_file)
|
|
||||||
convert_aax_file_with_agent(aax_file, config)
|
|
||||||
|
|
||||||
|
|
||||||
def run_command(arguments: list[str], *, capture: bool = False) -> subprocess.CompletedProcess[str]:
|
|
||||||
"""Run a command and return the completed process.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
arguments: Command and arguments to run.
|
|
||||||
capture: Whether to capture stdout and stderr.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The completed process.
|
|
||||||
"""
|
|
||||||
logger.debug("%s", " ".join(redact_command_arguments(arguments)))
|
|
||||||
try:
|
|
||||||
return subprocess.run(arguments, check=True, capture_output=capture, text=True)
|
|
||||||
except subprocess.CalledProcessError as error:
|
|
||||||
raise CommandExecutionError(arguments, error.returncode) from error
|
|
||||||
|
|
||||||
|
|
||||||
def redact_command_arguments(arguments: list[str]) -> list[str]:
|
|
||||||
"""Return command arguments with sensitive values redacted."""
|
|
||||||
redacted = []
|
|
||||||
redact_next = False
|
|
||||||
for argument in arguments:
|
|
||||||
if redact_next:
|
|
||||||
redacted.append("<redacted>")
|
|
||||||
redact_next = False
|
|
||||||
continue
|
|
||||||
|
|
||||||
redacted.append(argument)
|
|
||||||
redact_next = argument in SENSITIVE_COMMAND_ARGUMENTS
|
|
||||||
return redacted
|
|
||||||
|
|
||||||
|
|
||||||
def read_metadata(aax_file: Path) -> dict[str, str]:
|
|
||||||
"""Read ffprobe format tags from an AAX file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
aax_file: AAX file to inspect.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Lower-cased metadata tag names mapped to their values.
|
|
||||||
"""
|
|
||||||
completed = run_command(
|
|
||||||
[
|
|
||||||
"ffprobe",
|
|
||||||
"-v",
|
|
||||||
"quiet",
|
|
||||||
"-print_format",
|
|
||||||
"json",
|
|
||||||
"-show_format",
|
|
||||||
str(aax_file),
|
|
||||||
],
|
|
||||||
capture=True,
|
|
||||||
)
|
|
||||||
ffprobe_data: dict[str, Any] = json.loads(completed.stdout)
|
|
||||||
tags = ffprobe_data.get("format", {}).get("tags", {})
|
|
||||||
return {str(key).lower(): str(value) for key, value in tags.items()}
|
|
||||||
|
|
||||||
|
|
||||||
def output_stem(metadata: StandardBookMetadata) -> str:
|
|
||||||
"""Build the output stem for a book.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
metadata: Book metadata.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Output stem in author-series_01-title form.
|
|
||||||
"""
|
|
||||||
index_slug = series_index_slug(metadata.series_index, metadata.title)
|
|
||||||
return f"{metadata.author}-{metadata.series}_{index_slug}-{metadata.title}"
|
|
||||||
|
|
||||||
|
|
||||||
def series_index_slug(series_index: float, title: str = "") -> str:
|
|
||||||
"""Return a filename-safe series index."""
|
|
||||||
if title_range := title_series_range_slug(series_index, title):
|
|
||||||
return title_range
|
|
||||||
index = float(series_index)
|
|
||||||
if index.is_integer():
|
|
||||||
return f"{int(index):02}"
|
|
||||||
return f"{int(index):02}.5"
|
|
||||||
|
|
||||||
|
|
||||||
def title_series_range_slug(series_index: float, title: str) -> str | None:
|
|
||||||
"""Return a series range slug found in an omnibus title."""
|
|
||||||
index = float(series_index)
|
|
||||||
if not index.is_integer():
|
|
||||||
return None
|
|
||||||
first_index = int(index)
|
|
||||||
for match in BOOK_RANGE_PATTERN.finditer(title):
|
|
||||||
start = int(match.group("start"))
|
|
||||||
end = int(match.group("end"))
|
|
||||||
if start == first_index and end > start:
|
|
||||||
return f"{start:02}-{end:02}"
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def metadata_output_path(output_directory: Path, metadata: StandardBookMetadata) -> Path:
|
|
||||||
"""Build the final M4B path from resolved metadata."""
|
|
||||||
stem = output_stem(metadata)
|
|
||||||
return output_directory / stem / f"{stem}.m4b"
|
|
||||||
|
|
||||||
|
|
||||||
def convert_aax_file(
|
|
||||||
aax_file: Path,
|
|
||||||
destination: Path,
|
|
||||||
activation_bytes: str | None,
|
|
||||||
*,
|
|
||||||
overwrite: bool,
|
|
||||||
) -> None:
|
|
||||||
"""Convert an AAX file into an M4B file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
aax_file: Source AAX file.
|
|
||||||
destination: Destination M4B file.
|
|
||||||
activation_bytes: Optional Audible activation bytes for ffmpeg.
|
|
||||||
overwrite: Whether to overwrite an existing M4B.
|
|
||||||
"""
|
|
||||||
if destination.exists() and not overwrite:
|
|
||||||
logger.info("Skipping existing file %s", destination)
|
|
||||||
return
|
|
||||||
|
|
||||||
destination.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
arguments = ["ffmpeg", "-hide_banner", "-y" if overwrite else "-n"]
|
|
||||||
if activation_bytes:
|
|
||||||
arguments.extend(["-activation_bytes", activation_bytes])
|
|
||||||
arguments.extend(["-i", str(aax_file), "-map_metadata", "0", "-c", "copy", str(destination)])
|
|
||||||
run_command(arguments)
|
|
||||||
|
|
||||||
|
|
||||||
def write_review_file(
|
|
||||||
*,
|
|
||||||
destination: Path | None,
|
|
||||||
ffprobe_metadata: dict[str, str],
|
|
||||||
log_file: Path,
|
|
||||||
metadata: StandardBookMetadata | None,
|
|
||||||
reason: str,
|
|
||||||
review_file: Path,
|
|
||||||
source: Path,
|
|
||||||
temp_file: Path | None,
|
|
||||||
) -> None:
|
|
||||||
"""Write a manual review file for an unresolved conversion."""
|
|
||||||
review_file.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
payload = {
|
|
||||||
"destination": str(destination) if destination else None,
|
|
||||||
"ffprobe_metadata": ffprobe_metadata,
|
|
||||||
"metadata": asdict(metadata) if metadata else None,
|
|
||||||
"reason": reason,
|
|
||||||
"source": str(source),
|
|
||||||
"temp_file": str(temp_file) if temp_file else None,
|
|
||||||
}
|
|
||||||
review_file.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
|
|
||||||
write_agent_log(log_file, "review_written", path=str(review_file), reason=reason)
|
|
||||||
|
|
||||||
|
|
||||||
def cleanup_temp_output(temp_file: Path) -> None:
|
|
||||||
"""Remove a run's temporary output directory."""
|
|
||||||
shutil.rmtree(temp_file.parent, ignore_errors=True)
|
|
||||||
|
|
||||||
|
|
||||||
def dry_run_aax_file_with_agent(
|
|
||||||
aax_file: Path,
|
|
||||||
ffprobe_metadata: dict[str, str],
|
|
||||||
engine: Engine,
|
|
||||||
config: ConversionConfig,
|
|
||||||
log_file: Path,
|
|
||||||
review_file: Path,
|
|
||||||
) -> None:
|
|
||||||
"""Resolve and print the planned output path without converting."""
|
|
||||||
metadata = standard_book_metadata(
|
|
||||||
aax_file.name,
|
|
||||||
ffprobe_metadata,
|
|
||||||
engine,
|
|
||||||
log_file,
|
|
||||||
config.ollama_api_key,
|
|
||||||
config.agent_config,
|
|
||||||
)
|
|
||||||
destination = None if metadata.needs_review else metadata_output_path(config.resolved_output, metadata)
|
|
||||||
if metadata.needs_review:
|
|
||||||
write_review_file(
|
|
||||||
destination=destination,
|
|
||||||
ffprobe_metadata=ffprobe_metadata,
|
|
||||||
log_file=log_file,
|
|
||||||
metadata=metadata,
|
|
||||||
reason="metadata_needs_review",
|
|
||||||
review_file=review_file,
|
|
||||||
source=aax_file,
|
|
||||||
temp_file=None,
|
|
||||||
)
|
|
||||||
typer.echo(f"{aax_file} -> REVIEW {review_file}")
|
|
||||||
else:
|
|
||||||
stem = output_stem(metadata)
|
|
||||||
dry_run_file = (
|
|
||||||
config.resolved_output / config.work_directory_name / config.dry_run_directory_name / stem / f"{stem}.m4b"
|
|
||||||
)
|
|
||||||
dry_run_file.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
dry_run_file.write_text(f"{destination}\n", encoding="utf-8")
|
|
||||||
write_agent_log(
|
|
||||||
log_file,
|
|
||||||
"dry_run_file_written",
|
|
||||||
destination=str(destination),
|
|
||||||
path=str(dry_run_file),
|
|
||||||
)
|
|
||||||
typer.echo(f"{aax_file} -> {destination}")
|
|
||||||
|
|
||||||
|
|
||||||
def convert_temp_file_and_resolve_metadata(
|
|
||||||
aax_file: Path,
|
|
||||||
temp_file: Path,
|
|
||||||
ffprobe_metadata: dict[str, str],
|
|
||||||
config: ConversionConfig,
|
|
||||||
log_file: Path,
|
|
||||||
) -> ConcurrentConversionResult:
|
|
||||||
"""Run ffmpeg and metadata resolution in parallel."""
|
|
||||||
conversion_error: Exception | None = None
|
|
||||||
metadata_error: Exception | None = None
|
|
||||||
metadata: StandardBookMetadata | None = None
|
|
||||||
|
|
||||||
with ThreadPoolExecutor(max_workers=2) as executor:
|
|
||||||
conversion_future = executor.submit(
|
|
||||||
convert_aax_file,
|
|
||||||
aax_file,
|
|
||||||
temp_file,
|
|
||||||
config.activation_bytes,
|
|
||||||
overwrite=True,
|
|
||||||
)
|
|
||||||
metadata_future = executor.submit(
|
|
||||||
standard_book_metadata,
|
|
||||||
aax_file.name,
|
|
||||||
ffprobe_metadata,
|
|
||||||
config.engine,
|
|
||||||
log_file,
|
|
||||||
config.ollama_api_key,
|
|
||||||
config.agent_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
conversion_error = conversion_future.exception()
|
|
||||||
if conversion_error is None:
|
|
||||||
conversion_future.result()
|
|
||||||
|
|
||||||
metadata_error = metadata_future.exception()
|
|
||||||
if metadata_error is None:
|
|
||||||
metadata = metadata_future.result()
|
|
||||||
|
|
||||||
return ConcurrentConversionResult(
|
|
||||||
metadata=metadata,
|
|
||||||
conversion_error=conversion_error,
|
|
||||||
metadata_error=metadata_error,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def convert_aax_file_with_agent(aax_file: Path, config: ConversionConfig) -> None:
|
|
||||||
"""Convert one AAX file using the metadata agent for the final path."""
|
|
||||||
run_id = uuid7().hex
|
|
||||||
log_file = config.resolved_output / config.work_directory_name / config.log_directory_name / f"{run_id}.jsonl"
|
|
||||||
review_file = config.resolved_output / config.work_directory_name / config.review_directory_name / f"{run_id}.json"
|
|
||||||
write_agent_log(log_file, "conversion_start", source=str(aax_file), dry_run=config.dry_run)
|
|
||||||
try:
|
|
||||||
ffprobe_metadata = read_metadata(aax_file)
|
|
||||||
except Exception as error:
|
|
||||||
logger.exception("ffprobe failed")
|
|
||||||
write_review_file(
|
|
||||||
destination=None,
|
|
||||||
ffprobe_metadata={},
|
|
||||||
log_file=log_file,
|
|
||||||
metadata=None,
|
|
||||||
reason=f"ffprobe_failed: {error}",
|
|
||||||
review_file=review_file,
|
|
||||||
source=aax_file,
|
|
||||||
temp_file=None,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
if config.dry_run:
|
|
||||||
dry_run_aax_file_with_agent(
|
|
||||||
aax_file,
|
|
||||||
ffprobe_metadata,
|
|
||||||
config.engine,
|
|
||||||
config,
|
|
||||||
log_file,
|
|
||||||
review_file,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
temp_file = (
|
|
||||||
config.resolved_output / config.work_directory_name / config.temp_directory_name / run_id / "converted.m4b"
|
|
||||||
)
|
|
||||||
temp_file.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
result = convert_temp_file_and_resolve_metadata(aax_file, temp_file, ffprobe_metadata, config, log_file)
|
|
||||||
|
|
||||||
if result.conversion_error:
|
|
||||||
reason = f"ffmpeg_failed: {result.conversion_error}"
|
|
||||||
write_review_file(
|
|
||||||
destination=None,
|
|
||||||
ffprobe_metadata=ffprobe_metadata,
|
|
||||||
log_file=log_file,
|
|
||||||
metadata=result.metadata,
|
|
||||||
reason=reason,
|
|
||||||
review_file=review_file,
|
|
||||||
source=aax_file,
|
|
||||||
temp_file=temp_file if temp_file.exists() else None,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
if result.metadata_error:
|
|
||||||
write_review_file(
|
|
||||||
destination=None,
|
|
||||||
ffprobe_metadata=ffprobe_metadata,
|
|
||||||
log_file=log_file,
|
|
||||||
metadata=None,
|
|
||||||
reason=f"metadata_failed: {result.metadata_error}",
|
|
||||||
review_file=review_file,
|
|
||||||
source=aax_file,
|
|
||||||
temp_file=temp_file,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
if result.metadata is None or result.metadata.needs_review:
|
|
||||||
write_review_file(
|
|
||||||
destination=None,
|
|
||||||
ffprobe_metadata=ffprobe_metadata,
|
|
||||||
log_file=log_file,
|
|
||||||
metadata=result.metadata,
|
|
||||||
reason="metadata_needs_review",
|
|
||||||
review_file=review_file,
|
|
||||||
source=aax_file,
|
|
||||||
temp_file=temp_file,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
destination = metadata_output_path(config.resolved_output, result.metadata)
|
|
||||||
if destination.exists() and not config.overwrite:
|
|
||||||
write_agent_log(log_file, "destination_exists", destination=str(destination))
|
|
||||||
cleanup_temp_output(temp_file)
|
|
||||||
return
|
|
||||||
|
|
||||||
destination.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
try:
|
|
||||||
temp_file.replace(destination)
|
|
||||||
except OSError as error:
|
|
||||||
write_review_file(
|
|
||||||
destination=destination,
|
|
||||||
ffprobe_metadata=ffprobe_metadata,
|
|
||||||
log_file=log_file,
|
|
||||||
metadata=result.metadata,
|
|
||||||
reason=f"rename_failed: {error}",
|
|
||||||
review_file=review_file,
|
|
||||||
source=aax_file,
|
|
||||||
temp_file=temp_file if temp_file.exists() else None,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
cleanup_temp_output(temp_file)
|
|
||||||
write_agent_log(log_file, "conversion_complete", destination=str(destination))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
typer.run(main)
|
|
||||||
@@ -1,175 +0,0 @@
|
|||||||
"""Import audiobook catalog authors and series from CSV files."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import csv
|
|
||||||
import logging
|
|
||||||
from pathlib import Path # noqa: TC003 This is required for the typer CLI
|
|
||||||
from typing import Annotated
|
|
||||||
|
|
||||||
import typer
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from python.common import configure_logger
|
|
||||||
from python.orm.common import get_postgres_engine
|
|
||||||
from python.orm.richie import AudiobookAuthor, AudiobookSeries
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
AUTHOR_NAME_COLUMN = "author_name"
|
|
||||||
ID_COLUMN = "id"
|
|
||||||
NAME_COLUMN = "name"
|
|
||||||
|
|
||||||
|
|
||||||
class CatalogImportError(ValueError):
|
|
||||||
"""CSV catalog import failed validation."""
|
|
||||||
|
|
||||||
|
|
||||||
def main(
|
|
||||||
authors_csv: Annotated[Path, typer.Argument(help="CSV with name and optional id.")],
|
|
||||||
series_csv: Annotated[Path, typer.Argument(help="CSV with name, author_name, and optional id.")],
|
|
||||||
) -> None:
|
|
||||||
"""Upsert audiobook authors and series from CSV files."""
|
|
||||||
configure_logger()
|
|
||||||
try:
|
|
||||||
engine = get_postgres_engine(name="RICHIE")
|
|
||||||
with Session(engine) as session:
|
|
||||||
author_count = upsert_authors_from_csv(session, authors_csv)
|
|
||||||
series_count = upsert_series_from_csv(session, series_csv)
|
|
||||||
session.commit()
|
|
||||||
except CatalogImportError as error:
|
|
||||||
typer.echo(str(error), err=True)
|
|
||||||
raise typer.Exit(code=1) from error
|
|
||||||
|
|
||||||
logger.info("Upserted %s authors and %s series", author_count, series_count)
|
|
||||||
|
|
||||||
|
|
||||||
def upsert_authors_from_csv(session: Session, authors_csv: Path) -> int:
|
|
||||||
"""Upsert authors from a CSV file."""
|
|
||||||
count = 0
|
|
||||||
for row_number, row in csv_rows(authors_csv):
|
|
||||||
name = required_csv_value(row, authors_csv, row_number, NAME_COLUMN)
|
|
||||||
upsert_author(session, name, csv_id(row, authors_csv, row_number))
|
|
||||||
count += 1
|
|
||||||
return count
|
|
||||||
|
|
||||||
|
|
||||||
def upsert_series_from_csv(session: Session, series_csv: Path) -> int:
|
|
||||||
"""Upsert series from a CSV file."""
|
|
||||||
count = 0
|
|
||||||
for row_number, row in csv_rows(series_csv):
|
|
||||||
series_name = required_csv_value(row, series_csv, row_number, NAME_COLUMN)
|
|
||||||
author_name = required_csv_value(row, series_csv, row_number, AUTHOR_NAME_COLUMN)
|
|
||||||
author = find_author_by_name(session, author_name)
|
|
||||||
if author is None:
|
|
||||||
msg = f"{series_csv}:{row_number}: author not found: {author_name}"
|
|
||||||
raise CatalogImportError(msg)
|
|
||||||
upsert_series(session, series_name, author, csv_id(row, series_csv, row_number))
|
|
||||||
count += 1
|
|
||||||
return count
|
|
||||||
|
|
||||||
|
|
||||||
def upsert_author(session: Session, name: str, author_id: int | None) -> AudiobookAuthor:
|
|
||||||
"""Upsert one author by id or exact name."""
|
|
||||||
if author_id is not None:
|
|
||||||
author = session.get(AudiobookAuthor, author_id)
|
|
||||||
if author is None:
|
|
||||||
author = AudiobookAuthor(id=author_id, name=name)
|
|
||||||
session.add(author)
|
|
||||||
else:
|
|
||||||
author.name = name
|
|
||||||
session.flush()
|
|
||||||
return author
|
|
||||||
|
|
||||||
author = find_author_by_name(session, name)
|
|
||||||
if author is None:
|
|
||||||
author = AudiobookAuthor(name=name)
|
|
||||||
session.add(author)
|
|
||||||
session.flush()
|
|
||||||
return author
|
|
||||||
|
|
||||||
|
|
||||||
def upsert_series(
|
|
||||||
session: Session,
|
|
||||||
name: str,
|
|
||||||
author: AudiobookAuthor,
|
|
||||||
series_id: int | None,
|
|
||||||
) -> AudiobookSeries:
|
|
||||||
"""Upsert one series by id or exact author/name match."""
|
|
||||||
if series_id is not None:
|
|
||||||
series = session.get(AudiobookSeries, series_id)
|
|
||||||
if series is None:
|
|
||||||
series = AudiobookSeries(id=series_id, name=name, author=author)
|
|
||||||
session.add(series)
|
|
||||||
else:
|
|
||||||
series.name = name
|
|
||||||
series.author = author
|
|
||||||
session.flush()
|
|
||||||
return series
|
|
||||||
|
|
||||||
series = find_series_by_name_and_author(session, name, author.id)
|
|
||||||
if series is None:
|
|
||||||
series = AudiobookSeries(name=name, author=author)
|
|
||||||
session.add(series)
|
|
||||||
session.flush()
|
|
||||||
return series
|
|
||||||
|
|
||||||
|
|
||||||
def find_author_by_name(session: Session, name: str) -> AudiobookAuthor | None:
|
|
||||||
"""Find one author by exact name."""
|
|
||||||
return session.scalar(select(AudiobookAuthor).where(AudiobookAuthor.name == name))
|
|
||||||
|
|
||||||
|
|
||||||
def find_series_by_name_and_author(
|
|
||||||
session: Session,
|
|
||||||
name: str,
|
|
||||||
author_id: int,
|
|
||||||
) -> AudiobookSeries | None:
|
|
||||||
"""Find one series by exact name and author."""
|
|
||||||
return session.scalar(
|
|
||||||
select(AudiobookSeries).where(
|
|
||||||
AudiobookSeries.name == name,
|
|
||||||
AudiobookSeries.author_id == author_id,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def csv_rows(csv_path: Path) -> list[tuple[int, dict[str, str | None]]]:
|
|
||||||
"""Read a CSV file as numbered rows."""
|
|
||||||
with csv_path.open(newline="", encoding="utf-8") as file:
|
|
||||||
reader = csv.DictReader(file)
|
|
||||||
if reader.fieldnames is None:
|
|
||||||
msg = f"{csv_path}: missing CSV header"
|
|
||||||
raise CatalogImportError(msg)
|
|
||||||
return [(row_number, row) for row_number, row in enumerate(reader, start=2)]
|
|
||||||
|
|
||||||
|
|
||||||
def required_csv_value(
|
|
||||||
row: dict[str, str | None],
|
|
||||||
csv_path: Path,
|
|
||||||
row_number: int,
|
|
||||||
column: str,
|
|
||||||
) -> str:
|
|
||||||
"""Read a required CSV value."""
|
|
||||||
value = row.get(column)
|
|
||||||
if value and value.strip():
|
|
||||||
return value.strip()
|
|
||||||
msg = f"{csv_path}:{row_number}: missing required column value: {column}"
|
|
||||||
raise CatalogImportError(msg)
|
|
||||||
|
|
||||||
|
|
||||||
def csv_id(row: dict[str, str | None], csv_path: Path, row_number: int) -> int | None:
|
|
||||||
"""Read an optional id field from a CSV row."""
|
|
||||||
value = row.get(ID_COLUMN)
|
|
||||||
if value is None or not value.strip():
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
return int(value)
|
|
||||||
except ValueError as error:
|
|
||||||
msg = f"{csv_path}:{row_number}: id must be an integer: {value}"
|
|
||||||
raise CatalogImportError(msg) from error
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
typer.run(main)
|
|
||||||
@@ -1,599 +0,0 @@
|
|||||||
"""LLM tool calling support for audiobook metadata resolution."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
import re
|
|
||||||
import time
|
|
||||||
from collections.abc import Callable
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
from sqlalchemy import or_, select
|
|
||||||
|
|
||||||
from python.orm.richie import Audiobook, AudiobookAuthor, AudiobookSeries
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from python.tools.audiobook.metadata_agent import AgentConfig
|
|
||||||
|
|
||||||
CATALOG_SLUG_PATTERN = re.compile(r"^[a-z0-9]+(?:_[a-z0-9]+)*$")
|
|
||||||
TITLE_SLUG_PATTERN = re.compile(r"^[a-z0-9]+(?:-[a-z0-9]+)*$")
|
|
||||||
|
|
||||||
LogWriter = Callable[..., None]
|
|
||||||
|
|
||||||
|
|
||||||
class MetadataResolutionError(ValueError):
|
|
||||||
"""Metadata resolution failed validation."""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class EnsuredBook:
|
|
||||||
"""Book row plus whether it was created."""
|
|
||||||
|
|
||||||
book: Audiobook
|
|
||||||
action: str
|
|
||||||
|
|
||||||
|
|
||||||
class CatalogToolRegistry:
|
|
||||||
"""Controlled catalog tools exposed to the metadata model."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
session: Session,
|
|
||||||
log_path: Path,
|
|
||||||
config: AgentConfig,
|
|
||||||
write_log: LogWriter,
|
|
||||||
) -> None:
|
|
||||||
"""Create a registry bound to one database session and audit log."""
|
|
||||||
self.session = session
|
|
||||||
self.log_path = log_path
|
|
||||||
self.config = config
|
|
||||||
self.write_log = write_log
|
|
||||||
self.seen_author_ids: set[int] = set()
|
|
||||||
self.seen_series_ids: set[int] = set()
|
|
||||||
self.seen_book_ids: set[int] = set()
|
|
||||||
self.created_author_ids: set[int] = set()
|
|
||||||
self.created_series_ids: set[int] = set()
|
|
||||||
self.created_book_ids: set[int] = set()
|
|
||||||
|
|
||||||
def tool_schemas(self) -> list[dict[str, object]]:
|
|
||||||
"""Return Ollama tool schemas."""
|
|
||||||
schemas = [
|
|
||||||
{
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "search_authors",
|
|
||||||
"description": "Search canonical audiobook authors by slug or noisy source text.",
|
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"query": {"type": "string"}},
|
|
||||||
"required": ["query"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "search_series",
|
|
||||||
"description": "Search canonical audiobook series by slug or noisy source text.",
|
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"query": {"type": "string"},
|
|
||||||
"author_id": {"type": ["integer", "null"]},
|
|
||||||
},
|
|
||||||
"required": ["query"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "search_books",
|
|
||||||
"description": "Search canonical audiobook titles with optional author and series filters.",
|
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"query": {"type": "string"},
|
|
||||||
"author_id": {"type": ["integer", "null"]},
|
|
||||||
"series_id": {"type": ["integer", "null"]},
|
|
||||||
},
|
|
||||||
"required": ["query"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "ensure_author",
|
|
||||||
"description": "Normalize an author name to a catalog slug, then return or create that author.",
|
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"name": {"type": "string"}},
|
|
||||||
"required": ["name"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "ensure_series",
|
|
||||||
"description": "Normalize a series name to a catalog slug, then return or create it for an author.",
|
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"name": {"type": "string"},
|
|
||||||
"author_id": {"type": "integer"},
|
|
||||||
},
|
|
||||||
"required": ["name", "author_id"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "ensure_book",
|
|
||||||
"description": "Normalize a title to a book slug, then return or create it for an author/series.",
|
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"title": {"type": "string"},
|
|
||||||
"author_id": {"type": "integer"},
|
|
||||||
"series_id": {"type": ["integer", "null"]},
|
|
||||||
"series_index": {"type": "number", "multipleOf": 0.5},
|
|
||||||
},
|
|
||||||
"required": ["title", "author_id", "series_id", "series_index"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
]
|
|
||||||
enabled_tool_names = set(self.config.tool_names)
|
|
||||||
return [schema for schema in schemas if schema["function"]["name"] in enabled_tool_names]
|
|
||||||
|
|
||||||
def run(self, name: str, arguments: dict[str, object]) -> list[dict[str, object]]:
|
|
||||||
"""Run one catalog tool and audit the call."""
|
|
||||||
handlers = {
|
|
||||||
"search_authors": self.run_search_authors,
|
|
||||||
"search_series": self.run_search_series,
|
|
||||||
"search_books": self.run_search_books,
|
|
||||||
"ensure_author": self.run_ensure_author,
|
|
||||||
"ensure_series": self.run_ensure_series,
|
|
||||||
"ensure_book": self.run_ensure_book,
|
|
||||||
}
|
|
||||||
handler = handlers.get(name)
|
|
||||||
if handler is None:
|
|
||||||
self.write_log(self.log_path, "tool_error", tool=name, arguments=arguments, error="unknown_tool")
|
|
||||||
msg = f"Unknown audiobook metadata tool: {name}"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
if name not in self.config.tool_names:
|
|
||||||
self.write_log(self.log_path, "tool_error", tool=name, arguments=arguments, error="tool_not_enabled")
|
|
||||||
msg = f"Audiobook metadata tool is not enabled: {name}"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
|
|
||||||
started = time.perf_counter()
|
|
||||||
self.write_log(self.log_path, "tool_call", tool=name, arguments=arguments)
|
|
||||||
result = handler(arguments)
|
|
||||||
duration_ms = round((time.perf_counter() - started) * 1000, 3)
|
|
||||||
self.write_log(
|
|
||||||
self.log_path,
|
|
||||||
"tool_result",
|
|
||||||
tool=name,
|
|
||||||
duration_ms=duration_ms,
|
|
||||||
result_count=len(result),
|
|
||||||
preview=result[:3],
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
def get_author(self, author_id: int) -> AudiobookAuthor | None:
|
|
||||||
"""Return an author by id."""
|
|
||||||
return self.session.get(AudiobookAuthor, author_id)
|
|
||||||
|
|
||||||
def get_book(self, book_id: int) -> Audiobook | None:
|
|
||||||
"""Return a book by id."""
|
|
||||||
return self.session.get(Audiobook, book_id)
|
|
||||||
|
|
||||||
def get_series(self, series_id: int) -> AudiobookSeries | None:
|
|
||||||
"""Return a series by id."""
|
|
||||||
return self.session.get(AudiobookSeries, series_id)
|
|
||||||
|
|
||||||
def prune_unused_created_rows(self, *, author_id: int, book_id: int | None, series_id: int | None) -> None:
|
|
||||||
"""Remove catalog rows created during this run but not used by final metadata."""
|
|
||||||
used_book_ids = {book_id} if book_id is not None else set()
|
|
||||||
for created_book_id in self.created_book_ids - used_book_ids:
|
|
||||||
if book := self.get_book(created_book_id):
|
|
||||||
self.session.delete(book)
|
|
||||||
|
|
||||||
self.session.flush()
|
|
||||||
used_series_ids = {series_id} if series_id is not None else set()
|
|
||||||
for created_series_id in self.created_series_ids - used_series_ids:
|
|
||||||
series = self.get_series(created_series_id)
|
|
||||||
if series and not series.books:
|
|
||||||
self.session.delete(series)
|
|
||||||
|
|
||||||
self.session.flush()
|
|
||||||
for created_author_id in self.created_author_ids - {author_id}:
|
|
||||||
author = self.get_author(created_author_id)
|
|
||||||
if author and not author.books and not author.series:
|
|
||||||
self.session.delete(author)
|
|
||||||
|
|
||||||
def run_search_authors(self, arguments: dict[str, object]) -> list[dict[str, object]]:
|
|
||||||
"""Search authors from tool arguments and remember returned ids."""
|
|
||||||
query = required_string(arguments, "query")
|
|
||||||
statement = select(AudiobookAuthor).order_by(AudiobookAuthor.name).limit(self.config.max_tool_results)
|
|
||||||
if terms := query_terms(query):
|
|
||||||
statement = statement.where(or_(*(AudiobookAuthor.name.ilike(f"%{term}%") for term in terms)))
|
|
||||||
|
|
||||||
authors = self.session.scalars(statement).all()
|
|
||||||
self.seen_author_ids.update(author.id for author in authors)
|
|
||||||
return [{"id": author.id, "name": author.name} for author in authors]
|
|
||||||
|
|
||||||
def run_search_series(self, arguments: dict[str, object]) -> list[dict[str, object]]:
|
|
||||||
"""Search series from tool arguments and remember returned ids."""
|
|
||||||
query = required_string(arguments, "query")
|
|
||||||
author_id = optional_int(arguments.get("author_id"), "author_id")
|
|
||||||
statement = select(AudiobookSeries).order_by(AudiobookSeries.name).limit(self.config.max_tool_results)
|
|
||||||
if terms := query_terms(query):
|
|
||||||
statement = statement.where(or_(*(AudiobookSeries.name.ilike(f"%{term}%") for term in terms)))
|
|
||||||
if author_id is not None:
|
|
||||||
statement = statement.where(AudiobookSeries.author_id == author_id)
|
|
||||||
|
|
||||||
series_rows = self.session.scalars(statement).all()
|
|
||||||
self.seen_series_ids.update(series.id for series in series_rows)
|
|
||||||
self.seen_author_ids.update(series.author_id for series in series_rows)
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"id": series.id,
|
|
||||||
"name": series.name,
|
|
||||||
"author_id": series.author_id,
|
|
||||||
"author": series.author.name,
|
|
||||||
}
|
|
||||||
for series in series_rows
|
|
||||||
]
|
|
||||||
|
|
||||||
def run_search_books(self, arguments: dict[str, object]) -> list[dict[str, object]]:
|
|
||||||
"""Search books from tool arguments and remember returned ids."""
|
|
||||||
query = required_string(arguments, "query")
|
|
||||||
author_id = optional_int(arguments.get("author_id"), "author_id")
|
|
||||||
series_id = optional_int(arguments.get("series_id"), "series_id")
|
|
||||||
statement = select(Audiobook).order_by(Audiobook.title).limit(self.config.max_tool_results)
|
|
||||||
if terms := query_terms(query):
|
|
||||||
statement = statement.where(or_(*(Audiobook.title.ilike(f"%{term}%") for term in terms)))
|
|
||||||
if author_id is not None:
|
|
||||||
statement = statement.where(Audiobook.author_id == author_id)
|
|
||||||
if series_id is not None:
|
|
||||||
statement = statement.where(Audiobook.series_id == series_id)
|
|
||||||
|
|
||||||
books = self.session.scalars(statement).all()
|
|
||||||
self.seen_book_ids.update(book.id for book in books)
|
|
||||||
self.seen_author_ids.update(book.author_id for book in books)
|
|
||||||
self.seen_series_ids.update(book.series_id for book in books if book.series_id is not None)
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"id": book.id,
|
|
||||||
"title": book.title,
|
|
||||||
"author_id": book.author_id,
|
|
||||||
"author": book.author.name,
|
|
||||||
"series_id": book.series_id,
|
|
||||||
"series": book.series.name if book.series else self.config.standalone_series,
|
|
||||||
"series_index": book.series_index,
|
|
||||||
}
|
|
||||||
for book in books
|
|
||||||
]
|
|
||||||
|
|
||||||
def run_ensure_author(self, arguments: dict[str, object]) -> list[dict[str, object]]:
|
|
||||||
"""Ensure an author from tool arguments and return a tool result."""
|
|
||||||
name = normalize_catalog_slug(required_string(arguments, "name"))
|
|
||||||
validate_catalog_slug(name, "author")
|
|
||||||
author = self.session.scalar(select(AudiobookAuthor).where(AudiobookAuthor.name == name))
|
|
||||||
action = "existing"
|
|
||||||
if author is None:
|
|
||||||
author = AudiobookAuthor(name=name)
|
|
||||||
self.session.add(author)
|
|
||||||
self.session.flush()
|
|
||||||
self.created_author_ids.add(author.id)
|
|
||||||
action = "created"
|
|
||||||
|
|
||||||
self.seen_author_ids.add(author.id)
|
|
||||||
return [{"id": author.id, "name": author.name, "action": action}]
|
|
||||||
|
|
||||||
def run_ensure_series(self, arguments: dict[str, object]) -> list[dict[str, object]]:
|
|
||||||
"""Ensure a series from tool arguments and return a tool result."""
|
|
||||||
name = normalize_catalog_slug(required_string(arguments, "name"))
|
|
||||||
author_id = required_int(arguments, "author_id")
|
|
||||||
validate_catalog_slug(name, "series")
|
|
||||||
author = self.required_author(author_id)
|
|
||||||
series = self.find_series_by_catalog_slug(name, author.id)
|
|
||||||
action = "existing"
|
|
||||||
if series is None:
|
|
||||||
series = AudiobookSeries(name=name, author=author)
|
|
||||||
self.session.add(series)
|
|
||||||
self.session.flush()
|
|
||||||
self.created_series_ids.add(series.id)
|
|
||||||
action = "created"
|
|
||||||
|
|
||||||
self.seen_author_ids.add(author.id)
|
|
||||||
self.seen_series_ids.add(series.id)
|
|
||||||
return [self.series_result(series, action)]
|
|
||||||
|
|
||||||
def run_ensure_book(self, arguments: dict[str, object]) -> list[dict[str, object]]:
|
|
||||||
"""Ensure a book from tool arguments and return a tool result."""
|
|
||||||
title = required_string(arguments, "title")
|
|
||||||
author_id = required_int(arguments, "author_id")
|
|
||||||
series_id = optional_int(arguments.get("series_id"), "series_id")
|
|
||||||
series_index = required_series_index(arguments, "series_index")
|
|
||||||
ensured = self.ensure_book(title, author_id, series_id, series_index)
|
|
||||||
return [self.book_result(ensured.book, ensured.action)]
|
|
||||||
|
|
||||||
def ensure_book(
|
|
||||||
self,
|
|
||||||
title: str,
|
|
||||||
author_id: int,
|
|
||||||
series_id: int | None,
|
|
||||||
series_index: float,
|
|
||||||
) -> EnsuredBook:
|
|
||||||
"""Return an existing book row, or create it after validating ownership."""
|
|
||||||
title = normalize_title_slug(title)
|
|
||||||
validate_title_slug(title)
|
|
||||||
author = self.required_author(author_id)
|
|
||||||
series = None
|
|
||||||
if series_id is None:
|
|
||||||
if series_index != 0:
|
|
||||||
msg = "standalone books must use series_index 0"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
else:
|
|
||||||
series = self.required_series(series_id)
|
|
||||||
if series.author_id != author.id:
|
|
||||||
msg = f"series_id {series_id} does not belong to author_id {author_id}"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
if series_index <= 0:
|
|
||||||
msg = "series books must use a positive series_index"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
|
|
||||||
statement = select(Audiobook).where(
|
|
||||||
Audiobook.title == title,
|
|
||||||
Audiobook.author_id == author.id,
|
|
||||||
)
|
|
||||||
if series is None:
|
|
||||||
statement = statement.where(Audiobook.series_id.is_(None))
|
|
||||||
else:
|
|
||||||
statement = statement.where(Audiobook.series_id == series.id)
|
|
||||||
book = self.session.scalar(statement)
|
|
||||||
if book is None:
|
|
||||||
book = Audiobook(title=title, author=author, series=series, series_index=series_index)
|
|
||||||
self.session.add(book)
|
|
||||||
self.session.flush()
|
|
||||||
self.created_book_ids.add(book.id)
|
|
||||||
action = "created"
|
|
||||||
else:
|
|
||||||
action = "existing"
|
|
||||||
|
|
||||||
self.seen_book_ids.add(book.id)
|
|
||||||
self.seen_author_ids.add(author.id)
|
|
||||||
if book.series_id is not None:
|
|
||||||
self.seen_series_ids.add(book.series_id)
|
|
||||||
return EnsuredBook(book=book, action=action)
|
|
||||||
|
|
||||||
def required_author(self, author_id: int) -> AudiobookAuthor:
|
|
||||||
"""Return an author or fail metadata resolution."""
|
|
||||||
author = self.get_author(author_id)
|
|
||||||
if author is None:
|
|
||||||
msg = f"author_id {author_id} does not exist"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
return author
|
|
||||||
|
|
||||||
def required_series(self, series_id: int) -> AudiobookSeries:
|
|
||||||
"""Return a series or fail metadata resolution."""
|
|
||||||
series = self.get_series(series_id)
|
|
||||||
if series is None:
|
|
||||||
msg = f"series_id {series_id} does not exist"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
return series
|
|
||||||
|
|
||||||
def find_series_by_catalog_slug(self, name: str, author_id: int) -> AudiobookSeries | None:
|
|
||||||
"""Return a series by exact slug or underscore-insensitive slug."""
|
|
||||||
exact = self.session.scalar(
|
|
||||||
select(AudiobookSeries).where(
|
|
||||||
AudiobookSeries.name == name,
|
|
||||||
AudiobookSeries.author_id == author_id,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
if exact is not None:
|
|
||||||
return exact
|
|
||||||
|
|
||||||
compact_name = compact_catalog_slug(name)
|
|
||||||
series_rows = self.session.scalars(
|
|
||||||
select(AudiobookSeries).where(AudiobookSeries.author_id == author_id).order_by(AudiobookSeries.name),
|
|
||||||
).all()
|
|
||||||
for series in series_rows:
|
|
||||||
if compact_catalog_slug(series.name) == compact_name:
|
|
||||||
return series
|
|
||||||
return None
|
|
||||||
|
|
||||||
def series_result(self, series: AudiobookSeries, action: str) -> dict[str, object]:
|
|
||||||
"""Build a normalized series tool result."""
|
|
||||||
return {
|
|
||||||
"id": series.id,
|
|
||||||
"name": series.name,
|
|
||||||
"author_id": series.author_id,
|
|
||||||
"author": series.author.name,
|
|
||||||
"action": action,
|
|
||||||
}
|
|
||||||
|
|
||||||
def book_result(self, book: Audiobook, action: str) -> dict[str, object]:
|
|
||||||
"""Build a normalized book tool result."""
|
|
||||||
return {
|
|
||||||
"id": book.id,
|
|
||||||
"title": book.title,
|
|
||||||
"author_id": book.author_id,
|
|
||||||
"author": book.author.name,
|
|
||||||
"series_id": book.series_id,
|
|
||||||
"series": book.series.name if book.series else self.config.standalone_series,
|
|
||||||
"series_index": book.series_index,
|
|
||||||
"action": action,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def run_tool_calls(
|
|
||||||
messages: list[dict[str, object]],
|
|
||||||
message: dict[str, object],
|
|
||||||
tool_calls: list[tuple[str, dict[str, object]]],
|
|
||||||
registry: CatalogToolRegistry,
|
|
||||||
log_path: Path,
|
|
||||||
write_log: LogWriter,
|
|
||||||
) -> str | None:
|
|
||||||
"""Run tool calls, append tool messages, and return fatal error text when stopped."""
|
|
||||||
messages.append(message)
|
|
||||||
for tool_name, arguments in tool_calls:
|
|
||||||
try:
|
|
||||||
tool_result = registry.run(tool_name, arguments)
|
|
||||||
except MetadataResolutionError as error:
|
|
||||||
if is_fatal_tool_error(error):
|
|
||||||
return str(error)
|
|
||||||
write_log(log_path, "tool_error", tool=tool_name, arguments=arguments, error=str(error))
|
|
||||||
messages.append(
|
|
||||||
{
|
|
||||||
"role": "tool",
|
|
||||||
"tool_name": tool_name,
|
|
||||||
"content": json.dumps({"error": str(error)}, sort_keys=True),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
messages.append(
|
|
||||||
{
|
|
||||||
"role": "tool",
|
|
||||||
"tool_name": tool_name,
|
|
||||||
"content": json.dumps(tool_result, sort_keys=True),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def parse_tool_calls(message: dict[str, object]) -> list[tuple[str, dict[str, object]]]:
|
|
||||||
"""Parse Ollama tool calls from a response message."""
|
|
||||||
raw_tool_calls = message.get("tool_calls") or []
|
|
||||||
if not isinstance(raw_tool_calls, list):
|
|
||||||
msg = "tool_calls must be a list"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
|
|
||||||
tool_calls = []
|
|
||||||
for raw_call in raw_tool_calls:
|
|
||||||
if not isinstance(raw_call, dict):
|
|
||||||
msg = "tool call must be an object"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
function = raw_call.get("function")
|
|
||||||
if not isinstance(function, dict):
|
|
||||||
msg = "tool call is missing function"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
name = function.get("name")
|
|
||||||
if not isinstance(name, str) or not name:
|
|
||||||
msg = "tool call is missing function name"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
arguments = parse_tool_arguments(function.get("arguments", {}))
|
|
||||||
tool_calls.append((name, arguments))
|
|
||||||
return tool_calls
|
|
||||||
|
|
||||||
|
|
||||||
def parse_tool_arguments(raw_arguments: object) -> dict[str, object]:
|
|
||||||
"""Parse tool call arguments returned by Ollama."""
|
|
||||||
if isinstance(raw_arguments, dict):
|
|
||||||
return {str(key): value for key, value in raw_arguments.items()}
|
|
||||||
if isinstance(raw_arguments, str):
|
|
||||||
parsed = json.loads(raw_arguments) if raw_arguments else {}
|
|
||||||
if isinstance(parsed, dict):
|
|
||||||
return {str(key): value for key, value in parsed.items()}
|
|
||||||
msg = "tool arguments must be an object"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
|
|
||||||
|
|
||||||
def validate_title_slug(title: str) -> None:
|
|
||||||
"""Validate a canonical book title slug."""
|
|
||||||
if not TITLE_SLUG_PATTERN.fullmatch(title):
|
|
||||||
msg = f"title slug is invalid: {title}"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
|
|
||||||
|
|
||||||
def validate_catalog_slug(value: str, label: str) -> None:
|
|
||||||
"""Validate a canonical catalog slug."""
|
|
||||||
if not CATALOG_SLUG_PATTERN.fullmatch(value):
|
|
||||||
msg = f"{label} slug is invalid: {value}"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_catalog_slug(value: str) -> str:
|
|
||||||
"""Normalize noisy catalog names into lower snake-case slugs."""
|
|
||||||
return re.sub(r"[^a-z0-9]+", "_", value.strip().casefold()).strip("_")
|
|
||||||
|
|
||||||
|
|
||||||
def compact_catalog_slug(value: str) -> str:
|
|
||||||
"""Return a catalog slug comparison key that ignores underscores."""
|
|
||||||
return normalize_catalog_slug(value).replace("_", "")
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_title_slug(value: str) -> str:
|
|
||||||
"""Normalize noisy book titles into lower kebab-case slugs."""
|
|
||||||
return re.sub(r"[^a-z0-9]+", "-", value.strip().casefold()).strip("-")
|
|
||||||
|
|
||||||
|
|
||||||
def is_fatal_tool_error(error: MetadataResolutionError) -> bool:
|
|
||||||
"""Return whether a tool error should stop the agent immediately."""
|
|
||||||
message = str(error)
|
|
||||||
return message.startswith(
|
|
||||||
(
|
|
||||||
"Unknown audiobook metadata tool",
|
|
||||||
"Audiobook metadata tool is not enabled",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def query_terms(query: str) -> tuple[str, ...]:
|
|
||||||
"""Return text variants useful for matching noisy audiobook metadata."""
|
|
||||||
normalized = query.strip().casefold()
|
|
||||||
underscore_slug = normalize_catalog_slug(normalized)
|
|
||||||
compact_slug = compact_catalog_slug(normalized)
|
|
||||||
hyphen_slug = normalize_title_slug(normalized)
|
|
||||||
return tuple(dict.fromkeys(term for term in (normalized, underscore_slug, compact_slug, hyphen_slug) if term))
|
|
||||||
|
|
||||||
|
|
||||||
def required_string(data: dict[str, object], key: str) -> str:
|
|
||||||
"""Read a required string field."""
|
|
||||||
value = data.get(key)
|
|
||||||
if not isinstance(value, str) or not value.strip():
|
|
||||||
msg = f"{key} must be a non-empty string"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
return value.strip()
|
|
||||||
|
|
||||||
|
|
||||||
def required_int(data: dict[str, object], key: str) -> int:
|
|
||||||
"""Read a required integer field."""
|
|
||||||
value = data.get(key)
|
|
||||||
if isinstance(value, bool) or not isinstance(value, int):
|
|
||||||
msg = f"{key} must be an integer"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
def required_series_index(data: dict[str, object], key: str) -> float:
|
|
||||||
"""Read a required whole-number or half-number series index."""
|
|
||||||
value = data.get(key)
|
|
||||||
if isinstance(value, bool) or not isinstance(value, int | float):
|
|
||||||
msg = f"{key} must be a number"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
series_index = float(value)
|
|
||||||
if not (series_index * 2).is_integer():
|
|
||||||
msg = f"{key} must be a whole number or .5 increment"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
return series_index
|
|
||||||
|
|
||||||
|
|
||||||
def optional_int(value: object, key: str) -> int | None:
|
|
||||||
"""Read an optional integer field."""
|
|
||||||
if value is None:
|
|
||||||
return None
|
|
||||||
if isinstance(value, bool) or not isinstance(value, int):
|
|
||||||
msg = f"{key} must be an integer or null"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
return value
|
|
||||||
@@ -1,575 +0,0 @@
|
|||||||
"""Resolve audiobook metadata with a controlled Ollama tool loop."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
import re
|
|
||||||
from dataclasses import asdict, dataclass, is_dataclass, replace
|
|
||||||
from os import PathLike
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from python.common import utcnow
|
|
||||||
from python.tools.audiobook.llm_tool_calling import (
|
|
||||||
CatalogToolRegistry,
|
|
||||||
MetadataResolutionError,
|
|
||||||
normalize_title_slug,
|
|
||||||
optional_int,
|
|
||||||
parse_tool_calls,
|
|
||||||
required_int,
|
|
||||||
required_series_index,
|
|
||||||
required_string,
|
|
||||||
run_tool_calls,
|
|
||||||
validate_catalog_slug,
|
|
||||||
validate_title_slug,
|
|
||||||
)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from sqlalchemy.engine import Engine
|
|
||||||
|
|
||||||
from python.orm.richie import AudiobookAuthor
|
|
||||||
|
|
||||||
FENCED_JSON_PATTERN = re.compile(r"^```(?:json)?\s*(?P<json>.*?)\s*```$", re.IGNORECASE | re.DOTALL)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class AgentConfig:
|
|
||||||
"""Runtime settings for the audiobook metadata agent."""
|
|
||||||
|
|
||||||
model: str = "deepseek-v4-flash:cloud"
|
|
||||||
ollama_chat_url: str = "https://ollama.com/api/chat"
|
|
||||||
http_timeout_seconds: int = 300
|
|
||||||
max_agent_turns: int = 8
|
|
||||||
max_tool_results: int = 10
|
|
||||||
min_confidence: float = 0.85
|
|
||||||
invalid_final_retries: int = 1
|
|
||||||
standalone_series: str = "standalone"
|
|
||||||
tool_names: tuple[str, ...] = (
|
|
||||||
"search_authors",
|
|
||||||
"search_series",
|
|
||||||
"search_books",
|
|
||||||
"ensure_author",
|
|
||||||
"ensure_series",
|
|
||||||
"ensure_book",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class StandardBookMetadata:
|
|
||||||
"""Canonical metadata for the final audiobook path."""
|
|
||||||
|
|
||||||
author_id: int
|
|
||||||
author: str
|
|
||||||
book_id: int | None
|
|
||||||
title: str
|
|
||||||
series_id: int | None
|
|
||||||
series: str
|
|
||||||
series_index: float
|
|
||||||
confidence: float
|
|
||||||
needs_review: bool
|
|
||||||
evidence: list[str]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class FinalMetadataFields:
|
|
||||||
"""Raw model fields after schema validation."""
|
|
||||||
|
|
||||||
author_id: int
|
|
||||||
book_id: int | None
|
|
||||||
title: str
|
|
||||||
series_id: int | None
|
|
||||||
series_index: float
|
|
||||||
confidence: float
|
|
||||||
evidence: list[str]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class ResolvedBookFields:
|
|
||||||
"""Book fields after optional catalog book resolution."""
|
|
||||||
|
|
||||||
book_id: int | None
|
|
||||||
title: str
|
|
||||||
series_id: int | None
|
|
||||||
series_index: float
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class AgentStepResult:
|
|
||||||
"""Outcome from one model response."""
|
|
||||||
|
|
||||||
metadata: StandardBookMetadata | None
|
|
||||||
invalid_final_count: int
|
|
||||||
should_continue: bool
|
|
||||||
|
|
||||||
|
|
||||||
def standard_book_metadata(
|
|
||||||
aax_file_name: str,
|
|
||||||
aax_metadata_from_ffprobe: dict[str, str],
|
|
||||||
engine: Engine,
|
|
||||||
log_path: Path,
|
|
||||||
ollama_api_key: str,
|
|
||||||
config: AgentConfig,
|
|
||||||
) -> StandardBookMetadata:
|
|
||||||
"""Resolve canonical audiobook metadata with the configured Ollama Cloud model."""
|
|
||||||
with Session(engine) as session:
|
|
||||||
registry = CatalogToolRegistry(session, log_path, config, write_agent_log)
|
|
||||||
agent = AudiobookMetadataAgent(
|
|
||||||
registry=registry, log_path=log_path, ollama_api_key=ollama_api_key, config=config
|
|
||||||
)
|
|
||||||
metadata = agent.run(aax_file_name, aax_metadata_from_ffprobe)
|
|
||||||
if metadata.needs_review:
|
|
||||||
session.rollback()
|
|
||||||
else:
|
|
||||||
registry.prune_unused_created_rows(
|
|
||||||
author_id=metadata.author_id,
|
|
||||||
book_id=metadata.book_id,
|
|
||||||
series_id=metadata.series_id,
|
|
||||||
)
|
|
||||||
session.commit()
|
|
||||||
return metadata
|
|
||||||
|
|
||||||
|
|
||||||
class AudiobookMetadataAgent:
|
|
||||||
"""Ollama-backed metadata resolver with a fixed local tool registry."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
registry: CatalogToolRegistry,
|
|
||||||
log_path: Path,
|
|
||||||
ollama_api_key: str,
|
|
||||||
config: AgentConfig,
|
|
||||||
) -> None:
|
|
||||||
"""Create an Ollama metadata agent."""
|
|
||||||
self._registry = registry
|
|
||||||
self._log_path = log_path
|
|
||||||
self._ollama_api_key = ollama_api_key
|
|
||||||
self._config = config
|
|
||||||
|
|
||||||
def run(self, aax_file_name: str, aax_metadata_from_ffprobe: dict[str, str]) -> StandardBookMetadata:
|
|
||||||
"""Resolve metadata for one AAX file."""
|
|
||||||
messages = [
|
|
||||||
{"role": "system", "content": system_prompt()},
|
|
||||||
{"role": "user", "content": user_prompt(aax_file_name, aax_metadata_from_ffprobe)},
|
|
||||||
]
|
|
||||||
invalid_final_count = 0
|
|
||||||
result: StandardBookMetadata | None = None
|
|
||||||
|
|
||||||
for turn in range(1, self._config.max_agent_turns + 1):
|
|
||||||
step = self.run_step(messages, turn, invalid_final_count)
|
|
||||||
invalid_final_count = step.invalid_final_count
|
|
||||||
if step.should_continue:
|
|
||||||
continue
|
|
||||||
result = step.metadata
|
|
||||||
break
|
|
||||||
|
|
||||||
if result is None:
|
|
||||||
return self.force_final_response(messages)
|
|
||||||
return result
|
|
||||||
|
|
||||||
def run_step(
|
|
||||||
self,
|
|
||||||
messages: list[dict[str, object]],
|
|
||||||
turn: int,
|
|
||||||
invalid_final_count: int,
|
|
||||||
) -> AgentStepResult:
|
|
||||||
"""Run one model turn and return the next agent-loop action."""
|
|
||||||
data = self.chat(messages, turn)
|
|
||||||
message = data.get("message")
|
|
||||||
if not isinstance(message, dict):
|
|
||||||
return AgentStepResult(
|
|
||||||
metadata=review_metadata("Ollama response did not include a message", self._config),
|
|
||||||
invalid_final_count=invalid_final_count,
|
|
||||||
should_continue=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
tool_calls = parse_tool_calls(message)
|
|
||||||
except (json.JSONDecodeError, MetadataResolutionError) as error:
|
|
||||||
return AgentStepResult(
|
|
||||||
metadata=review_metadata(str(error), self._config),
|
|
||||||
invalid_final_count=invalid_final_count,
|
|
||||||
should_continue=False,
|
|
||||||
)
|
|
||||||
if tool_calls:
|
|
||||||
fatal_error = run_tool_calls(messages, message, tool_calls, self._registry, self._log_path, write_agent_log)
|
|
||||||
if fatal_error is not None:
|
|
||||||
return AgentStepResult(
|
|
||||||
metadata=review_metadata(fatal_error, self._config),
|
|
||||||
invalid_final_count=invalid_final_count,
|
|
||||||
should_continue=False,
|
|
||||||
)
|
|
||||||
return AgentStepResult(metadata=None, invalid_final_count=invalid_final_count, should_continue=True)
|
|
||||||
return self.handle_final_message(messages, message, invalid_final_count)
|
|
||||||
|
|
||||||
def handle_final_message(
|
|
||||||
self,
|
|
||||||
messages: list[dict[str, object]],
|
|
||||||
message: dict[str, object],
|
|
||||||
invalid_final_count: int,
|
|
||||||
) -> AgentStepResult:
|
|
||||||
"""Validate a final model message or request one retry."""
|
|
||||||
content = message.get("content")
|
|
||||||
if not isinstance(content, str):
|
|
||||||
return AgentStepResult(
|
|
||||||
metadata=review_metadata("Ollama final response did not include string content", self._config),
|
|
||||||
invalid_final_count=invalid_final_count,
|
|
||||||
should_continue=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
resolved = self.validate_final(parse_final_json_content(content))
|
|
||||||
except (json.JSONDecodeError, MetadataResolutionError) as error:
|
|
||||||
return self.handle_invalid_final(messages, error, invalid_final_count)
|
|
||||||
|
|
||||||
write_agent_log(self._log_path, "final_metadata", metadata=resolved)
|
|
||||||
return AgentStepResult(metadata=resolved, invalid_final_count=invalid_final_count, should_continue=False)
|
|
||||||
|
|
||||||
def handle_invalid_final(
|
|
||||||
self,
|
|
||||||
messages: list[dict[str, object]],
|
|
||||||
error: json.JSONDecodeError | MetadataResolutionError,
|
|
||||||
invalid_final_count: int,
|
|
||||||
) -> AgentStepResult:
|
|
||||||
"""Log invalid final JSON and either retry or return review metadata."""
|
|
||||||
invalid_final_count += 1
|
|
||||||
write_agent_log(
|
|
||||||
self._log_path,
|
|
||||||
"final_validation_error",
|
|
||||||
error=str(error),
|
|
||||||
invalid_final_count=invalid_final_count,
|
|
||||||
)
|
|
||||||
if invalid_final_count > self._config.invalid_final_retries:
|
|
||||||
return AgentStepResult(
|
|
||||||
metadata=review_metadata(str(error), self._config),
|
|
||||||
invalid_final_count=invalid_final_count,
|
|
||||||
should_continue=False,
|
|
||||||
)
|
|
||||||
messages.append(
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": (
|
|
||||||
"Your previous final answer was invalid. Return only valid JSON matching the required "
|
|
||||||
f"schema. Validation error: {error}"
|
|
||||||
),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
return AgentStepResult(metadata=None, invalid_final_count=invalid_final_count, should_continue=True)
|
|
||||||
|
|
||||||
def force_final_response(self, messages: list[dict[str, object]]) -> StandardBookMetadata:
|
|
||||||
"""Request a no-tool final answer after the normal turn limit."""
|
|
||||||
messages.append({"role": "user", "content": forced_final_prompt()})
|
|
||||||
write_agent_log(self._log_path, "forced_final_request", reason="max_turns")
|
|
||||||
data = self.chat(messages, self._config.max_agent_turns + 1, tools_enabled=False)
|
|
||||||
message = data.get("message")
|
|
||||||
if not isinstance(message, dict):
|
|
||||||
return review_metadata("Ollama forced final response did not include a message", self._config)
|
|
||||||
content = message.get("content")
|
|
||||||
if not isinstance(content, str):
|
|
||||||
return review_metadata("Ollama forced final response did not include string content", self._config)
|
|
||||||
try:
|
|
||||||
resolved = self.validate_final(parse_final_json_content(content))
|
|
||||||
except (json.JSONDecodeError, MetadataResolutionError) as error:
|
|
||||||
return review_metadata(f"Ollama forced final response was invalid: {error}", self._config)
|
|
||||||
write_agent_log(self._log_path, "final_metadata", metadata=resolved)
|
|
||||||
return resolved
|
|
||||||
|
|
||||||
def chat(self, messages: list[dict[str, object]], turn: int, *, tools_enabled: bool = True) -> dict[str, object]:
|
|
||||||
"""Send one chat request to Ollama and log the request and response."""
|
|
||||||
payload = {
|
|
||||||
"model": self._config.model,
|
|
||||||
"messages": messages,
|
|
||||||
"stream": False,
|
|
||||||
"options": {"temperature": 0.1},
|
|
||||||
}
|
|
||||||
tool_names = []
|
|
||||||
if tools_enabled:
|
|
||||||
payload["tools"] = self._registry.tool_schemas()
|
|
||||||
tool_names = self._config.tool_names
|
|
||||||
write_agent_log(
|
|
||||||
self._log_path,
|
|
||||||
"model_request",
|
|
||||||
model=self._config.model,
|
|
||||||
turn=turn,
|
|
||||||
message_count=len(messages),
|
|
||||||
tool_names=tool_names,
|
|
||||||
tools_enabled=tools_enabled,
|
|
||||||
)
|
|
||||||
write_agent_log(
|
|
||||||
self._log_path,
|
|
||||||
"llm_messages_sent",
|
|
||||||
model=self._config.model,
|
|
||||||
turn=turn,
|
|
||||||
messages=messages,
|
|
||||||
tools_enabled=tools_enabled,
|
|
||||||
)
|
|
||||||
response = httpx.post(
|
|
||||||
self._config.ollama_chat_url,
|
|
||||||
headers={"Authorization": f"Bearer {self._ollama_api_key}"},
|
|
||||||
json=payload,
|
|
||||||
timeout=self._config.http_timeout_seconds,
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
raw_data = response.json()
|
|
||||||
if not isinstance(raw_data, dict):
|
|
||||||
return {}
|
|
||||||
data = {str(key): value for key, value in raw_data.items()}
|
|
||||||
message = data.get("message", {})
|
|
||||||
content = message.get("content") if isinstance(message, dict) else ""
|
|
||||||
write_agent_log(
|
|
||||||
self._log_path,
|
|
||||||
"llm_message_received",
|
|
||||||
model=self._config.model,
|
|
||||||
turn=turn,
|
|
||||||
message=message,
|
|
||||||
)
|
|
||||||
write_agent_log(
|
|
||||||
self._log_path,
|
|
||||||
"model_response",
|
|
||||||
model=self._config.model,
|
|
||||||
turn=turn,
|
|
||||||
has_tool_calls=bool(isinstance(message, dict) and message.get("tool_calls")),
|
|
||||||
content_chars=len(content) if isinstance(content, str) else 0,
|
|
||||||
)
|
|
||||||
return data
|
|
||||||
|
|
||||||
def validate_final(self, raw_metadata: object) -> StandardBookMetadata:
|
|
||||||
"""Validate final model metadata against catalog rows."""
|
|
||||||
fields = parse_final_metadata_fields(raw_metadata)
|
|
||||||
fields = replace(fields, title=normalize_title_slug(fields.title))
|
|
||||||
author = self.validate_author(fields.author_id)
|
|
||||||
validate_title_slug(fields.title)
|
|
||||||
book_fields = self.resolve_book_fields(fields)
|
|
||||||
series = self.validate_series(fields.author_id, book_fields.series_id, book_fields.series_index)
|
|
||||||
|
|
||||||
return StandardBookMetadata(
|
|
||||||
author_id=fields.author_id,
|
|
||||||
author=author.name,
|
|
||||||
book_id=book_fields.book_id,
|
|
||||||
title=book_fields.title,
|
|
||||||
series_id=book_fields.series_id,
|
|
||||||
series=series,
|
|
||||||
series_index=book_fields.series_index,
|
|
||||||
confidence=fields.confidence,
|
|
||||||
needs_review=fields.confidence < self._config.min_confidence,
|
|
||||||
evidence=fields.evidence,
|
|
||||||
)
|
|
||||||
|
|
||||||
def validate_author(self, author_id: int) -> AudiobookAuthor:
|
|
||||||
"""Validate that an author id was seen and exists."""
|
|
||||||
if author_id not in self._registry.seen_author_ids:
|
|
||||||
msg = f"author_id {author_id} was not returned by search_authors"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
author = self._registry.get_author(author_id)
|
|
||||||
if author is None:
|
|
||||||
msg = f"author_id {author_id} does not exist"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
validate_catalog_slug(author.name, "author")
|
|
||||||
return author
|
|
||||||
|
|
||||||
def resolve_book_fields(self, fields: FinalMetadataFields) -> ResolvedBookFields:
|
|
||||||
"""Resolve final book fields from a seen book id or created book."""
|
|
||||||
if fields.book_id is None:
|
|
||||||
ensured = self._registry.ensure_book(
|
|
||||||
fields.title,
|
|
||||||
fields.author_id,
|
|
||||||
fields.series_id,
|
|
||||||
fields.series_index,
|
|
||||||
)
|
|
||||||
return ResolvedBookFields(
|
|
||||||
book_id=ensured.book.id,
|
|
||||||
title=ensured.book.title,
|
|
||||||
series_id=ensured.book.series_id,
|
|
||||||
series_index=ensured.book.series_index,
|
|
||||||
)
|
|
||||||
|
|
||||||
if fields.book_id not in self._registry.seen_book_ids:
|
|
||||||
msg = f"book_id {fields.book_id} was not returned by search_books"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
book = self._registry.get_book(fields.book_id)
|
|
||||||
if book is None:
|
|
||||||
msg = f"book_id {fields.book_id} does not exist"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
if book.author_id != fields.author_id:
|
|
||||||
msg = f"book_id {fields.book_id} does not belong to author_id {fields.author_id}"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
return ResolvedBookFields(
|
|
||||||
book_id=fields.book_id,
|
|
||||||
title=book.title,
|
|
||||||
series_id=book.series_id,
|
|
||||||
series_index=book.series_index,
|
|
||||||
)
|
|
||||||
|
|
||||||
def validate_series(self, author_id: int, series_id: int | None, series_index: float) -> str:
|
|
||||||
"""Validate final series fields and return the canonical series slug."""
|
|
||||||
if series_id is None:
|
|
||||||
if series_index != 0:
|
|
||||||
msg = "standalone books must use series_index 0"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
return self._config.standalone_series
|
|
||||||
|
|
||||||
if series_id not in self._registry.seen_series_ids:
|
|
||||||
msg = f"series_id {series_id} was not returned by search_series"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
series = self._registry.get_series(series_id)
|
|
||||||
if series is None:
|
|
||||||
msg = f"series_id {series_id} does not exist"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
if series.author_id != author_id:
|
|
||||||
msg = f"series_id {series_id} does not belong to author_id {author_id}"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
if series_index <= 0:
|
|
||||||
msg = "series books must use a positive series_index"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
validate_catalog_slug(series.name, "series")
|
|
||||||
return series.name
|
|
||||||
|
|
||||||
|
|
||||||
def write_agent_log(log_path: Path, event: str, **fields: object) -> None:
|
|
||||||
"""Append one JSONL audit event."""
|
|
||||||
log_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
record = {
|
|
||||||
"created": utcnow().isoformat(),
|
|
||||||
"event": event,
|
|
||||||
**{key: json_log_value(value) for key, value in fields.items()},
|
|
||||||
}
|
|
||||||
with log_path.open("a", encoding="utf-8") as file:
|
|
||||||
file.write(json.dumps(record, sort_keys=True))
|
|
||||||
file.write("\n")
|
|
||||||
|
|
||||||
|
|
||||||
def json_log_value(value: object) -> object:
|
|
||||||
"""Return a JSON-serializable value for audit logs."""
|
|
||||||
if is_dataclass(value) and not isinstance(value, type):
|
|
||||||
return json_log_value(asdict(value))
|
|
||||||
if isinstance(value, dict):
|
|
||||||
return {str(key): json_log_value(item) for key, item in value.items()}
|
|
||||||
if isinstance(value, list | tuple):
|
|
||||||
return [json_log_value(item) for item in value]
|
|
||||||
if isinstance(value, set):
|
|
||||||
return [json_log_value(item) for item in sorted(value, key=str)]
|
|
||||||
if isinstance(value, PathLike):
|
|
||||||
return str(value)
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
def system_prompt() -> str:
|
|
||||||
"""Return the stable system prompt."""
|
|
||||||
return """You standardize Audible audiobook metadata against a private catalog.
|
|
||||||
|
|
||||||
Rules:
|
|
||||||
- You must use the provided tools before returning final metadata.
|
|
||||||
- Only use author_id, series_id, or book_id values returned by tools.
|
|
||||||
- Return final metadata as JSON only. Do not wrap it in Markdown.
|
|
||||||
- The final JSON object must contain author_id, book_id, title, series_id, series_index, confidence, and evidence.
|
|
||||||
- title must be a canonical title slug using lower-case words separated by hyphens.
|
|
||||||
- Use series_id null and series_index 0 for standalone books.
|
|
||||||
- If you use a series_id, series_index must be a whole number or .5 value greater than 0.
|
|
||||||
- Treat series slugs that differ only by underscores as the same series. Prefer the existing catalog row instead of
|
|
||||||
creating a new series.
|
|
||||||
- Detect omnibus or box-set editions that contain multiple numbered novels, books, or novellas.
|
|
||||||
- For an omnibus, make a best-effort range from the filename, tags, and catalog rows. Keep series_index as the
|
|
||||||
first covered book number and include the range in the title when the source title includes it, for example
|
|
||||||
books-1-3.
|
|
||||||
- Be careful with omnibuses of novels or novellas later published as one book: keep the omnibus as the audiobook's
|
|
||||||
book record unless catalog rows clearly identify a better match.
|
|
||||||
- Do not create publisher collections or author collections as series unless the book metadata clearly gives a
|
|
||||||
numbered series.
|
|
||||||
- Series belong to authors. Use a series_id only when it belongs to the selected author_id.
|
|
||||||
- Always search for the author before creating one. If no exact author slug exists, call ensure_author.
|
|
||||||
- Always search for a series with author_id before creating one. If no exact series slug exists, call ensure_series.
|
|
||||||
- Always search for a book before creating one. If no exact title slug exists, call ensure_book.
|
|
||||||
- If a tool returns an error, correct your tool arguments or final metadata before continuing.
|
|
||||||
- confidence must be a number from 0 to 1.
|
|
||||||
- evidence must be a short list of strings explaining which filename, tags, and catalog rows support the answer."""
|
|
||||||
|
|
||||||
|
|
||||||
def forced_final_prompt() -> str:
|
|
||||||
"""Return the no-tools finalization prompt."""
|
|
||||||
return (
|
|
||||||
"Stop calling tools. Return final metadata as JSON only using the tool results already provided. "
|
|
||||||
"If search_books returned no matching rows but author and series are known, use book_id null and resolve "
|
|
||||||
"the title slug from the AAX filename and ffprobe tags. The validator will create the missing book. "
|
|
||||||
"Use only author_id and series_id values returned by earlier tool results."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def user_prompt(aax_file_name: str, metadata: dict[str, str]) -> str:
|
|
||||||
"""Build the user prompt from source metadata."""
|
|
||||||
return (
|
|
||||||
"Resolve this Audible audiobook.\n\n"
|
|
||||||
f"AAX file name: {aax_file_name}\n\n"
|
|
||||||
"ffprobe format tags:\n"
|
|
||||||
f"{json.dumps(metadata, indent=2, sort_keys=True)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_final_json_content(content: str) -> object:
|
|
||||||
"""Parse final model content, accepting bare or fenced JSON."""
|
|
||||||
stripped = content.strip()
|
|
||||||
if match := FENCED_JSON_PATTERN.fullmatch(stripped):
|
|
||||||
stripped = match.group("json").strip()
|
|
||||||
return json.loads(stripped)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_final_metadata_fields(raw_metadata: object) -> FinalMetadataFields:
|
|
||||||
"""Parse the model's final JSON object into typed fields."""
|
|
||||||
if not isinstance(raw_metadata, dict):
|
|
||||||
msg = "Final metadata must be a JSON object"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
data = {str(key): value for key, value in raw_metadata.items()}
|
|
||||||
return FinalMetadataFields(
|
|
||||||
author_id=required_int(data, "author_id"),
|
|
||||||
book_id=optional_int(data.get("book_id"), "book_id"),
|
|
||||||
title=required_string(data, "title"),
|
|
||||||
series_id=optional_int(data.get("series_id"), "series_id"),
|
|
||||||
series_index=required_series_index(data, "series_index"),
|
|
||||||
confidence=required_float(data, "confidence"),
|
|
||||||
evidence=required_string_list(data, "evidence"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def review_metadata(reason: str, config: AgentConfig) -> StandardBookMetadata:
|
|
||||||
"""Return a metadata result that must be reviewed manually."""
|
|
||||||
return StandardBookMetadata(
|
|
||||||
author_id=0,
|
|
||||||
author="unknown_author",
|
|
||||||
book_id=None,
|
|
||||||
title="unknown-title",
|
|
||||||
series_id=None,
|
|
||||||
series=config.standalone_series,
|
|
||||||
series_index=0,
|
|
||||||
confidence=0,
|
|
||||||
needs_review=True,
|
|
||||||
evidence=[reason],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def required_float(data: dict[str, object], key: str) -> float:
|
|
||||||
"""Read a required float field."""
|
|
||||||
value = data.get(key)
|
|
||||||
if isinstance(value, bool) or not isinstance(value, int | float):
|
|
||||||
msg = f"{key} must be a number"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
confidence = float(value)
|
|
||||||
if confidence < 0 or confidence > 1:
|
|
||||||
msg = f"{key} must be between 0 and 1"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
return confidence
|
|
||||||
|
|
||||||
|
|
||||||
def required_string_list(data: dict[str, object], key: str) -> list[str]:
|
|
||||||
"""Read a required list of strings."""
|
|
||||||
value = data.get(key)
|
|
||||||
if not isinstance(value, list) or not value or not all(isinstance(item, str) for item in value):
|
|
||||||
msg = f"{key} must be a non-empty list of strings"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
strings = [item.strip() for item in value if item.strip()]
|
|
||||||
if not strings:
|
|
||||||
msg = f"{key} must include at least one non-empty string"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
return strings
|
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
"""Van inventory FastAPI application."""
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
"""FastAPI dependencies for van inventory."""
|
||||||
|
|
||||||
|
from collections.abc import Iterator
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import Depends, Request
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
|
||||||
|
def get_db(request: Request) -> Iterator[Session]:
|
||||||
|
"""Get database session from app state."""
|
||||||
|
with Session(request.app.state.engine) as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
DbSession = Annotated[Session, Depends(get_db)]
|
||||||
@@ -0,0 +1,56 @@
|
|||||||
|
"""FastAPI app for van inventory."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, Annotated
|
||||||
|
|
||||||
|
import typer
|
||||||
|
import uvicorn
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
|
||||||
|
from python.common import configure_logger
|
||||||
|
from python.orm.common import get_postgres_engine
|
||||||
|
from python.van_inventory.routers import api_router, frontend_router
|
||||||
|
|
||||||
|
STATIC_DIR = Path(__file__).resolve().parent / "static"
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def create_app() -> FastAPI:
|
||||||
|
"""Create and configure the FastAPI application."""
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
|
||||||
|
app.state.engine = get_postgres_engine(name="VAN_INVENTORY")
|
||||||
|
yield
|
||||||
|
app.state.engine.dispose()
|
||||||
|
|
||||||
|
app = FastAPI(title="Van Inventory", lifespan=lifespan)
|
||||||
|
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
|
||||||
|
app.include_router(api_router)
|
||||||
|
app.include_router(frontend_router)
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
def serve(
|
||||||
|
# Intentionally binds all interfaces — this is a LAN-only van server
|
||||||
|
host: Annotated[str, typer.Option("--host", "-h", help="Host to bind to")] = "0.0.0.0", # noqa: S104
|
||||||
|
port: Annotated[int, typer.Option("--port", "-p", help="Port to bind to")] = 8001,
|
||||||
|
log_level: Annotated[str, typer.Option("--log-level", "-l", help="Log level")] = "INFO",
|
||||||
|
) -> None:
|
||||||
|
"""Start the Van Inventory server."""
|
||||||
|
configure_logger(log_level)
|
||||||
|
app = create_app()
|
||||||
|
uvicorn.run(app, host=host, port=port)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
typer.run(serve)
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
"""Van inventory API routers."""
|
||||||
|
|
||||||
|
from python.van_inventory.routers.api import router as api_router
|
||||||
|
from python.van_inventory.routers.frontend import router as frontend_router
|
||||||
|
|
||||||
|
__all__ = ["api_router", "frontend_router"]
|
||||||
@@ -0,0 +1,314 @@
|
|||||||
|
"""Van inventory API router."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
|
from python.orm.van_inventory.models import Item, Meal, MealIngredient
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from python.van_inventory.dependencies import DbSession
|
||||||
|
|
||||||
|
|
||||||
|
# --- Schemas ---
|
||||||
|
|
||||||
|
|
||||||
|
class ItemCreate(BaseModel):
|
||||||
|
"""Schema for creating an item."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
quantity: float = Field(default=0, ge=0)
|
||||||
|
unit: str
|
||||||
|
category: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ItemUpdate(BaseModel):
|
||||||
|
"""Schema for updating an item."""
|
||||||
|
|
||||||
|
name: str | None = None
|
||||||
|
quantity: float | None = Field(default=None, ge=0)
|
||||||
|
unit: str | None = None
|
||||||
|
category: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ItemResponse(BaseModel):
|
||||||
|
"""Schema for item response."""
|
||||||
|
|
||||||
|
id: int
|
||||||
|
name: str
|
||||||
|
quantity: float
|
||||||
|
unit: str
|
||||||
|
category: str | None
|
||||||
|
|
||||||
|
model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
|
||||||
|
class IngredientCreate(BaseModel):
|
||||||
|
"""Schema for adding an ingredient to a meal."""
|
||||||
|
|
||||||
|
item_id: int
|
||||||
|
quantity_needed: float = Field(gt=0)
|
||||||
|
|
||||||
|
|
||||||
|
class MealCreate(BaseModel):
|
||||||
|
"""Schema for creating a meal."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
instructions: str | None = None
|
||||||
|
ingredients: list[IngredientCreate] = []
|
||||||
|
|
||||||
|
|
||||||
|
class MealUpdate(BaseModel):
|
||||||
|
"""Schema for updating a meal."""
|
||||||
|
|
||||||
|
name: str | None = None
|
||||||
|
instructions: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class IngredientResponse(BaseModel):
|
||||||
|
"""Schema for ingredient response."""
|
||||||
|
|
||||||
|
item_id: int
|
||||||
|
item_name: str
|
||||||
|
quantity_needed: float
|
||||||
|
unit: str
|
||||||
|
|
||||||
|
model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
|
||||||
|
class MealResponse(BaseModel):
|
||||||
|
"""Schema for meal response."""
|
||||||
|
|
||||||
|
id: int
|
||||||
|
name: str
|
||||||
|
instructions: str | None
|
||||||
|
ingredients: list[IngredientResponse] = []
|
||||||
|
|
||||||
|
model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_meal(cls, meal: Meal) -> MealResponse:
|
||||||
|
"""Build a MealResponse from an ORM Meal with loaded ingredients."""
|
||||||
|
return cls(
|
||||||
|
id=meal.id,
|
||||||
|
name=meal.name,
|
||||||
|
instructions=meal.instructions,
|
||||||
|
ingredients=[
|
||||||
|
IngredientResponse(
|
||||||
|
item_id=mi.item_id,
|
||||||
|
item_name=mi.item.name,
|
||||||
|
quantity_needed=mi.quantity_needed,
|
||||||
|
unit=mi.item.unit,
|
||||||
|
)
|
||||||
|
for mi in meal.ingredients
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ShoppingItem(BaseModel):
|
||||||
|
"""An item needed for a meal that is short on stock."""
|
||||||
|
|
||||||
|
item_name: str
|
||||||
|
unit: str
|
||||||
|
needed: float
|
||||||
|
have: float
|
||||||
|
short: float
|
||||||
|
|
||||||
|
|
||||||
|
class MealAvailability(BaseModel):
|
||||||
|
"""Availability status for a meal."""
|
||||||
|
|
||||||
|
meal_id: int
|
||||||
|
meal_name: str
|
||||||
|
can_make: bool
|
||||||
|
missing: list[ShoppingItem] = []
|
||||||
|
|
||||||
|
|
||||||
|
# --- Routes ---
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api", tags=["van_inventory"])
|
||||||
|
|
||||||
|
|
||||||
|
# Items
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/items", response_model=ItemResponse)
|
||||||
|
def create_item(item: ItemCreate, db: DbSession) -> Item:
|
||||||
|
"""Create a new inventory item."""
|
||||||
|
db_item = Item(**item.model_dump())
|
||||||
|
db.add(db_item)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(db_item)
|
||||||
|
return db_item
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/items", response_model=list[ItemResponse])
|
||||||
|
def list_items(db: DbSession) -> list[Item]:
|
||||||
|
"""List all inventory items."""
|
||||||
|
return list(db.scalars(select(Item).order_by(Item.name)).all())
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/items/{item_id}", response_model=ItemResponse)
|
||||||
|
def get_item(item_id: int, db: DbSession) -> Item:
|
||||||
|
"""Get an item by ID."""
|
||||||
|
item = db.get(Item, item_id)
|
||||||
|
if not item:
|
||||||
|
raise HTTPException(status_code=404, detail="Item not found")
|
||||||
|
return item
|
||||||
|
|
||||||
|
|
||||||
|
@router.patch("/items/{item_id}", response_model=ItemResponse)
|
||||||
|
def update_item(item_id: int, item: ItemUpdate, db: DbSession) -> Item:
|
||||||
|
"""Update an item by ID."""
|
||||||
|
db_item = db.get(Item, item_id)
|
||||||
|
if not db_item:
|
||||||
|
raise HTTPException(status_code=404, detail="Item not found")
|
||||||
|
for key, value in item.model_dump(exclude_unset=True).items():
|
||||||
|
setattr(db_item, key, value)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(db_item)
|
||||||
|
return db_item
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/items/{item_id}")
|
||||||
|
def delete_item(item_id: int, db: DbSession) -> dict[str, bool]:
|
||||||
|
"""Delete an item by ID."""
|
||||||
|
item = db.get(Item, item_id)
|
||||||
|
if not item:
|
||||||
|
raise HTTPException(status_code=404, detail="Item not found")
|
||||||
|
db.delete(item)
|
||||||
|
db.commit()
|
||||||
|
return {"deleted": True}
|
||||||
|
|
||||||
|
|
||||||
|
# Meals
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/meals", response_model=MealResponse)
|
||||||
|
def create_meal(meal: MealCreate, db: DbSession) -> MealResponse:
|
||||||
|
"""Create a new meal with optional ingredients."""
|
||||||
|
for ing in meal.ingredients:
|
||||||
|
if not db.get(Item, ing.item_id):
|
||||||
|
raise HTTPException(status_code=422, detail=f"Item {ing.item_id} not found")
|
||||||
|
db_meal = Meal(name=meal.name, instructions=meal.instructions)
|
||||||
|
db.add(db_meal)
|
||||||
|
db.flush()
|
||||||
|
for ing in meal.ingredients:
|
||||||
|
db.add(MealIngredient(meal_id=db_meal.id, item_id=ing.item_id, quantity_needed=ing.quantity_needed))
|
||||||
|
db.commit()
|
||||||
|
db_meal = db.scalar(
|
||||||
|
select(Meal)
|
||||||
|
.where(Meal.id == db_meal.id)
|
||||||
|
.options(selectinload(Meal.ingredients).selectinload(MealIngredient.item))
|
||||||
|
)
|
||||||
|
return MealResponse.from_meal(db_meal)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/meals", response_model=list[MealResponse])
|
||||||
|
def list_meals(db: DbSession) -> list[MealResponse]:
|
||||||
|
"""List all meals with ingredients."""
|
||||||
|
meals = list(
|
||||||
|
db.scalars(
|
||||||
|
select(Meal).options(selectinload(Meal.ingredients).selectinload(MealIngredient.item)).order_by(Meal.name)
|
||||||
|
).all()
|
||||||
|
)
|
||||||
|
return [MealResponse.from_meal(m) for m in meals]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/meals/availability", response_model=list[MealAvailability])
|
||||||
|
def check_all_meals(db: DbSession) -> list[MealAvailability]:
|
||||||
|
"""Check which meals can be made with current inventory."""
|
||||||
|
meals = list(
|
||||||
|
db.scalars(select(Meal).options(selectinload(Meal.ingredients).selectinload(MealIngredient.item))).all()
|
||||||
|
)
|
||||||
|
return [_check_meal(m) for m in meals]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/meals/{meal_id}", response_model=MealResponse)
|
||||||
|
def get_meal(meal_id: int, db: DbSession) -> MealResponse:
|
||||||
|
"""Get a meal by ID with ingredients."""
|
||||||
|
meal = db.scalar(
|
||||||
|
select(Meal).where(Meal.id == meal_id).options(selectinload(Meal.ingredients).selectinload(MealIngredient.item))
|
||||||
|
)
|
||||||
|
if not meal:
|
||||||
|
raise HTTPException(status_code=404, detail="Meal not found")
|
||||||
|
return MealResponse.from_meal(meal)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/meals/{meal_id}")
|
||||||
|
def delete_meal(meal_id: int, db: DbSession) -> dict[str, bool]:
|
||||||
|
"""Delete a meal by ID."""
|
||||||
|
meal = db.get(Meal, meal_id)
|
||||||
|
if not meal:
|
||||||
|
raise HTTPException(status_code=404, detail="Meal not found")
|
||||||
|
db.delete(meal)
|
||||||
|
db.commit()
|
||||||
|
return {"deleted": True}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/meals/{meal_id}/ingredients", response_model=MealResponse)
|
||||||
|
def add_ingredient(meal_id: int, ingredient: IngredientCreate, db: DbSession) -> MealResponse:
|
||||||
|
"""Add an ingredient to a meal."""
|
||||||
|
meal = db.get(Meal, meal_id)
|
||||||
|
if not meal:
|
||||||
|
raise HTTPException(status_code=404, detail="Meal not found")
|
||||||
|
if not db.get(Item, ingredient.item_id):
|
||||||
|
raise HTTPException(status_code=422, detail="Item not found")
|
||||||
|
existing = db.scalar(
|
||||||
|
select(MealIngredient).where(MealIngredient.meal_id == meal_id, MealIngredient.item_id == ingredient.item_id)
|
||||||
|
)
|
||||||
|
if existing:
|
||||||
|
raise HTTPException(status_code=409, detail="Ingredient already exists for this meal")
|
||||||
|
db.add(MealIngredient(meal_id=meal_id, item_id=ingredient.item_id, quantity_needed=ingredient.quantity_needed))
|
||||||
|
db.commit()
|
||||||
|
meal = db.scalar(
|
||||||
|
select(Meal).where(Meal.id == meal_id).options(selectinload(Meal.ingredients).selectinload(MealIngredient.item))
|
||||||
|
)
|
||||||
|
return MealResponse.from_meal(meal)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/meals/{meal_id}/ingredients/{item_id}")
|
||||||
|
def remove_ingredient(meal_id: int, item_id: int, db: DbSession) -> dict[str, bool]:
|
||||||
|
"""Remove an ingredient from a meal."""
|
||||||
|
mi = db.scalar(select(MealIngredient).where(MealIngredient.meal_id == meal_id, MealIngredient.item_id == item_id))
|
||||||
|
if not mi:
|
||||||
|
raise HTTPException(status_code=404, detail="Ingredient not found")
|
||||||
|
db.delete(mi)
|
||||||
|
db.commit()
|
||||||
|
return {"deleted": True}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/meals/{meal_id}/availability", response_model=MealAvailability)
|
||||||
|
def check_meal(meal_id: int, db: DbSession) -> MealAvailability:
|
||||||
|
"""Check if a specific meal can be made and what's missing."""
|
||||||
|
meal = db.scalar(
|
||||||
|
select(Meal).where(Meal.id == meal_id).options(selectinload(Meal.ingredients).selectinload(MealIngredient.item))
|
||||||
|
)
|
||||||
|
if not meal:
|
||||||
|
raise HTTPException(status_code=404, detail="Meal not found")
|
||||||
|
return _check_meal(meal)
|
||||||
|
|
||||||
|
|
||||||
|
def _check_meal(meal: Meal) -> MealAvailability:
|
||||||
|
missing = [
|
||||||
|
ShoppingItem(
|
||||||
|
item_name=mi.item.name,
|
||||||
|
unit=mi.item.unit,
|
||||||
|
needed=mi.quantity_needed,
|
||||||
|
have=mi.item.quantity,
|
||||||
|
short=mi.quantity_needed - mi.item.quantity,
|
||||||
|
)
|
||||||
|
for mi in meal.ingredients
|
||||||
|
if mi.item.quantity < mi.quantity_needed
|
||||||
|
]
|
||||||
|
return MealAvailability(
|
||||||
|
meal_id=meal.id,
|
||||||
|
meal_name=meal.name,
|
||||||
|
can_make=len(missing) == 0,
|
||||||
|
missing=missing,
|
||||||
|
)
|
||||||
@@ -0,0 +1,198 @@
|
|||||||
|
"""HTMX frontend routes for van inventory."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Form, HTTPException, Request
|
||||||
|
from fastapi.responses import HTMLResponse
|
||||||
|
from fastapi.templating import Jinja2Templates
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
|
from python.orm.van_inventory.models import Item, Meal, MealIngredient
|
||||||
|
|
||||||
|
# FastAPI needs DbSession at runtime to resolve the Depends() annotation
|
||||||
|
from python.van_inventory.dependencies import DbSession # noqa: TC001
|
||||||
|
from python.van_inventory.routers.api import _check_meal
|
||||||
|
|
||||||
|
TEMPLATE_DIR = Path(__file__).resolve().parent.parent / "templates"
|
||||||
|
templates = Jinja2Templates(directory=TEMPLATE_DIR)
|
||||||
|
|
||||||
|
router = APIRouter(tags=["frontend"])
|
||||||
|
|
||||||
|
|
||||||
|
# --- Items ---
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/", response_class=HTMLResponse)
|
||||||
|
def items_page(request: Request, db: DbSession) -> HTMLResponse:
|
||||||
|
"""Render the inventory page."""
|
||||||
|
items = list(db.scalars(select(Item).order_by(Item.name)).all())
|
||||||
|
return templates.TemplateResponse(request, "items.html", {"items": items})
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/items", response_class=HTMLResponse)
|
||||||
|
def htmx_create_item(
|
||||||
|
request: Request,
|
||||||
|
db: DbSession,
|
||||||
|
name: Annotated[str, Form()],
|
||||||
|
quantity: Annotated[float, Form()] = 0,
|
||||||
|
unit: Annotated[str, Form()] = "",
|
||||||
|
category: Annotated[str | None, Form()] = None,
|
||||||
|
) -> HTMLResponse:
|
||||||
|
"""Create an item and return updated item rows."""
|
||||||
|
if quantity < 0:
|
||||||
|
raise HTTPException(status_code=422, detail="Quantity must not be negative")
|
||||||
|
db.add(Item(name=name, quantity=quantity, unit=unit, category=category or None))
|
||||||
|
db.commit()
|
||||||
|
items = list(db.scalars(select(Item).order_by(Item.name)).all())
|
||||||
|
return templates.TemplateResponse(request, "partials/item_rows.html", {"items": items})
|
||||||
|
|
||||||
|
|
||||||
|
@router.patch("/items/{item_id}", response_class=HTMLResponse)
|
||||||
|
def htmx_update_item(
|
||||||
|
request: Request,
|
||||||
|
item_id: int,
|
||||||
|
db: DbSession,
|
||||||
|
quantity: Annotated[float, Form()],
|
||||||
|
) -> HTMLResponse:
|
||||||
|
"""Update an item's quantity and return updated item rows."""
|
||||||
|
if quantity < 0:
|
||||||
|
raise HTTPException(status_code=422, detail="Quantity must not be negative")
|
||||||
|
item = db.get(Item, item_id)
|
||||||
|
if item:
|
||||||
|
item.quantity = quantity
|
||||||
|
db.commit()
|
||||||
|
items = list(db.scalars(select(Item).order_by(Item.name)).all())
|
||||||
|
return templates.TemplateResponse(request, "partials/item_rows.html", {"items": items})
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/items/{item_id}", response_class=HTMLResponse)
|
||||||
|
def htmx_delete_item(request: Request, item_id: int, db: DbSession) -> HTMLResponse:
|
||||||
|
"""Delete an item and return updated item rows."""
|
||||||
|
item = db.get(Item, item_id)
|
||||||
|
if item:
|
||||||
|
db.delete(item)
|
||||||
|
db.commit()
|
||||||
|
items = list(db.scalars(select(Item).order_by(Item.name)).all())
|
||||||
|
return templates.TemplateResponse(request, "partials/item_rows.html", {"items": items})
|
||||||
|
|
||||||
|
|
||||||
|
# --- Meals ---
|
||||||
|
|
||||||
|
|
||||||
|
def _load_meals(db: DbSession) -> list[Meal]:
|
||||||
|
return list(
|
||||||
|
db.scalars(
|
||||||
|
select(Meal).options(selectinload(Meal.ingredients).selectinload(MealIngredient.item)).order_by(Meal.name)
|
||||||
|
).all()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/meals", response_class=HTMLResponse)
|
||||||
|
def meals_page(request: Request, db: DbSession) -> HTMLResponse:
|
||||||
|
"""Render the meals page."""
|
||||||
|
meals = _load_meals(db)
|
||||||
|
return templates.TemplateResponse(request, "meals.html", {"meals": meals})
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/meals", response_class=HTMLResponse)
|
||||||
|
def htmx_create_meal(
|
||||||
|
request: Request,
|
||||||
|
db: DbSession,
|
||||||
|
name: Annotated[str, Form()],
|
||||||
|
instructions: Annotated[str | None, Form()] = None,
|
||||||
|
) -> HTMLResponse:
|
||||||
|
"""Create a meal and return updated meal rows."""
|
||||||
|
db.add(Meal(name=name, instructions=instructions or None))
|
||||||
|
db.commit()
|
||||||
|
meals = _load_meals(db)
|
||||||
|
return templates.TemplateResponse(request, "partials/meal_rows.html", {"meals": meals})
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/meals/{meal_id}", response_class=HTMLResponse)
|
||||||
|
def htmx_delete_meal(request: Request, meal_id: int, db: DbSession) -> HTMLResponse:
|
||||||
|
"""Delete a meal and return updated meal rows."""
|
||||||
|
meal = db.get(Meal, meal_id)
|
||||||
|
if meal:
|
||||||
|
db.delete(meal)
|
||||||
|
db.commit()
|
||||||
|
meals = _load_meals(db)
|
||||||
|
return templates.TemplateResponse(request, "partials/meal_rows.html", {"meals": meals})
|
||||||
|
|
||||||
|
|
||||||
|
# --- Meal detail ---
|
||||||
|
|
||||||
|
|
||||||
|
def _load_meal(db: DbSession, meal_id: int) -> Meal | None:
|
||||||
|
return db.scalar(
|
||||||
|
select(Meal).where(Meal.id == meal_id).options(selectinload(Meal.ingredients).selectinload(MealIngredient.item))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/meals/{meal_id}", response_class=HTMLResponse)
|
||||||
|
def meal_detail_page(request: Request, meal_id: int, db: DbSession) -> HTMLResponse:
|
||||||
|
"""Render the meal detail page."""
|
||||||
|
meal = _load_meal(db, meal_id)
|
||||||
|
if not meal:
|
||||||
|
raise HTTPException(status_code=404, detail="Meal not found")
|
||||||
|
items = list(db.scalars(select(Item).order_by(Item.name)).all())
|
||||||
|
return templates.TemplateResponse(request, "meal_detail.html", {"meal": meal, "items": items})
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/meals/{meal_id}/ingredients", response_class=HTMLResponse)
|
||||||
|
def htmx_add_ingredient(
|
||||||
|
request: Request,
|
||||||
|
meal_id: int,
|
||||||
|
db: DbSession,
|
||||||
|
item_id: Annotated[int, Form()],
|
||||||
|
quantity_needed: Annotated[float, Form()],
|
||||||
|
) -> HTMLResponse:
|
||||||
|
"""Add an ingredient to a meal and return updated ingredient rows."""
|
||||||
|
if quantity_needed <= 0:
|
||||||
|
raise HTTPException(status_code=422, detail="Quantity must be positive")
|
||||||
|
meal = db.get(Meal, meal_id)
|
||||||
|
if not meal:
|
||||||
|
raise HTTPException(status_code=404, detail="Meal not found")
|
||||||
|
if not db.get(Item, item_id):
|
||||||
|
raise HTTPException(status_code=422, detail="Item not found")
|
||||||
|
existing = db.scalar(
|
||||||
|
select(MealIngredient).where(MealIngredient.meal_id == meal_id, MealIngredient.item_id == item_id)
|
||||||
|
)
|
||||||
|
if existing:
|
||||||
|
raise HTTPException(status_code=409, detail="Ingredient already exists for this meal")
|
||||||
|
db.add(MealIngredient(meal_id=meal_id, item_id=item_id, quantity_needed=quantity_needed))
|
||||||
|
db.commit()
|
||||||
|
meal = _load_meal(db, meal_id)
|
||||||
|
return templates.TemplateResponse(request, "partials/ingredient_rows.html", {"meal": meal})
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/meals/{meal_id}/ingredients/{item_id}", response_class=HTMLResponse)
|
||||||
|
def htmx_remove_ingredient(
|
||||||
|
request: Request,
|
||||||
|
meal_id: int,
|
||||||
|
item_id: int,
|
||||||
|
db: DbSession,
|
||||||
|
) -> HTMLResponse:
|
||||||
|
"""Remove an ingredient from a meal and return updated ingredient rows."""
|
||||||
|
mi = db.scalar(select(MealIngredient).where(MealIngredient.meal_id == meal_id, MealIngredient.item_id == item_id))
|
||||||
|
if mi:
|
||||||
|
db.delete(mi)
|
||||||
|
db.commit()
|
||||||
|
meal = _load_meal(db, meal_id)
|
||||||
|
return templates.TemplateResponse(request, "partials/ingredient_rows.html", {"meal": meal})
|
||||||
|
|
||||||
|
|
||||||
|
# --- Availability ---
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/availability", response_class=HTMLResponse)
|
||||||
|
def availability_page(request: Request, db: DbSession) -> HTMLResponse:
|
||||||
|
"""Render the meal availability page."""
|
||||||
|
meals = list(
|
||||||
|
db.scalars(select(Meal).options(selectinload(Meal.ingredients).selectinload(MealIngredient.item))).all()
|
||||||
|
)
|
||||||
|
availability = [_check_meal(m) for m in meals]
|
||||||
|
return templates.TemplateResponse(request, "availability.html", {"availability": availability})
|
||||||
@@ -0,0 +1,212 @@
|
|||||||
|
:root {
|
||||||
|
--neon-pink: #ff2a6d;
|
||||||
|
--neon-cyan: #05d9e8;
|
||||||
|
--neon-yellow: #f9f002;
|
||||||
|
--neon-purple: #d300c5;
|
||||||
|
--bg-dark: #0a0a0f;
|
||||||
|
--bg-panel: #0d0d1a;
|
||||||
|
--bg-input: #111128;
|
||||||
|
--border: #1a1a3e;
|
||||||
|
--text: #c0c0d0;
|
||||||
|
--text-dim: #8e8ea0;
|
||||||
|
}
|
||||||
|
|
||||||
|
* { box-sizing: border-box; margin: 0; padding: 0; }
|
||||||
|
|
||||||
|
body {
|
||||||
|
font-family: 'Share Tech Mono', monospace;
|
||||||
|
max-width: 900px;
|
||||||
|
margin: 0 auto;
|
||||||
|
padding: 1rem;
|
||||||
|
background: var(--bg-dark);
|
||||||
|
color: var(--text);
|
||||||
|
position: relative;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Scanline overlay */
|
||||||
|
body::before {
|
||||||
|
content: '';
|
||||||
|
position: fixed;
|
||||||
|
top: 0; left: 0; right: 0; bottom: 0;
|
||||||
|
background: repeating-linear-gradient(
|
||||||
|
0deg,
|
||||||
|
transparent,
|
||||||
|
transparent 2px,
|
||||||
|
rgba(0, 0, 0, 0.08) 2px,
|
||||||
|
rgba(0, 0, 0, 0.08) 4px
|
||||||
|
);
|
||||||
|
pointer-events: none;
|
||||||
|
z-index: 9999;
|
||||||
|
}
|
||||||
|
|
||||||
|
h1, h2, h3 {
|
||||||
|
font-family: 'Orbitron', sans-serif;
|
||||||
|
margin-bottom: 0.5rem;
|
||||||
|
color: var(--neon-cyan);
|
||||||
|
text-shadow: 0 0 10px rgba(5, 217, 232, 0.5), 0 0 40px rgba(5, 217, 232, 0.2);
|
||||||
|
text-transform: uppercase;
|
||||||
|
letter-spacing: 2px;
|
||||||
|
}
|
||||||
|
|
||||||
|
a { color: var(--neon-pink); text-decoration: none; transition: all 0.2s; }
|
||||||
|
a:hover {
|
||||||
|
text-shadow: 0 0 8px rgba(255, 42, 109, 0.8), 0 0 20px rgba(255, 42, 109, 0.4);
|
||||||
|
}
|
||||||
|
|
||||||
|
nav {
|
||||||
|
display: flex;
|
||||||
|
gap: 1.5rem;
|
||||||
|
padding: 1rem 0;
|
||||||
|
border-bottom: 1px solid var(--border);
|
||||||
|
margin-bottom: 1.5rem;
|
||||||
|
position: relative;
|
||||||
|
}
|
||||||
|
nav::after {
|
||||||
|
content: '';
|
||||||
|
position: absolute;
|
||||||
|
bottom: -1px;
|
||||||
|
left: 0;
|
||||||
|
right: 0;
|
||||||
|
height: 1px;
|
||||||
|
background: linear-gradient(90deg, var(--neon-pink), var(--neon-cyan), var(--neon-purple));
|
||||||
|
opacity: 0.6;
|
||||||
|
}
|
||||||
|
nav a {
|
||||||
|
font-family: 'Orbitron', sans-serif;
|
||||||
|
font-weight: 700;
|
||||||
|
font-size: 0.85rem;
|
||||||
|
letter-spacing: 1px;
|
||||||
|
text-transform: uppercase;
|
||||||
|
padding: 0.3rem 0;
|
||||||
|
border-bottom: 2px solid transparent;
|
||||||
|
transition: all 0.2s;
|
||||||
|
}
|
||||||
|
nav a:hover {
|
||||||
|
border-bottom-color: var(--neon-pink);
|
||||||
|
text-shadow: 0 0 8px rgba(255, 42, 109, 0.8);
|
||||||
|
}
|
||||||
|
|
||||||
|
table {
|
||||||
|
width: 100%;
|
||||||
|
border-collapse: collapse;
|
||||||
|
margin: 1rem 0;
|
||||||
|
border: 1px solid var(--border);
|
||||||
|
}
|
||||||
|
th, td {
|
||||||
|
text-align: left;
|
||||||
|
padding: 0.6rem 0.75rem;
|
||||||
|
border-bottom: 1px solid var(--border);
|
||||||
|
}
|
||||||
|
th {
|
||||||
|
font-family: 'Orbitron', sans-serif;
|
||||||
|
color: var(--neon-cyan);
|
||||||
|
font-size: 0.7rem;
|
||||||
|
text-transform: uppercase;
|
||||||
|
letter-spacing: 2px;
|
||||||
|
background: var(--bg-panel);
|
||||||
|
border-bottom: 1px solid var(--neon-cyan);
|
||||||
|
text-shadow: 0 0 6px rgba(5, 217, 232, 0.3);
|
||||||
|
}
|
||||||
|
tr:hover td {
|
||||||
|
background: rgba(5, 217, 232, 0.03);
|
||||||
|
}
|
||||||
|
|
||||||
|
form {
|
||||||
|
display: flex;
|
||||||
|
flex-wrap: wrap;
|
||||||
|
gap: 0.5rem;
|
||||||
|
align-items: end;
|
||||||
|
margin: 1rem 0;
|
||||||
|
padding: 1rem;
|
||||||
|
border: 1px solid var(--border);
|
||||||
|
background: var(--bg-panel);
|
||||||
|
}
|
||||||
|
|
||||||
|
input, select {
|
||||||
|
padding: 0.5rem 0.6rem;
|
||||||
|
border: 1px solid var(--border);
|
||||||
|
border-radius: 2px;
|
||||||
|
background: var(--bg-input);
|
||||||
|
color: var(--neon-cyan);
|
||||||
|
font-family: 'Share Tech Mono', monospace;
|
||||||
|
transition: all 0.2s;
|
||||||
|
}
|
||||||
|
input:focus, select:focus {
|
||||||
|
outline: none;
|
||||||
|
border-color: var(--neon-cyan);
|
||||||
|
box-shadow: 0 0 8px rgba(5, 217, 232, 0.3), inset 0 0 8px rgba(5, 217, 232, 0.05);
|
||||||
|
}
|
||||||
|
|
||||||
|
button {
|
||||||
|
padding: 0.5rem 1.2rem;
|
||||||
|
border: 1px solid var(--neon-pink);
|
||||||
|
border-radius: 2px;
|
||||||
|
background: transparent;
|
||||||
|
color: var(--neon-pink);
|
||||||
|
cursor: pointer;
|
||||||
|
font-family: 'Orbitron', sans-serif;
|
||||||
|
font-weight: 700;
|
||||||
|
font-size: 0.7rem;
|
||||||
|
letter-spacing: 1px;
|
||||||
|
text-transform: uppercase;
|
||||||
|
transition: all 0.2s;
|
||||||
|
}
|
||||||
|
button:hover {
|
||||||
|
background: var(--neon-pink);
|
||||||
|
color: var(--bg-dark);
|
||||||
|
box-shadow: 0 0 15px rgba(255, 42, 109, 0.5), 0 0 30px rgba(255, 42, 109, 0.2);
|
||||||
|
}
|
||||||
|
button.danger {
|
||||||
|
border-color: var(--text-dim);
|
||||||
|
color: var(--text-dim);
|
||||||
|
}
|
||||||
|
button.danger:hover {
|
||||||
|
border-color: var(--neon-pink);
|
||||||
|
background: var(--neon-pink);
|
||||||
|
color: var(--bg-dark);
|
||||||
|
box-shadow: 0 0 15px rgba(255, 42, 109, 0.5);
|
||||||
|
}
|
||||||
|
|
||||||
|
.badge {
|
||||||
|
display: inline-block;
|
||||||
|
padding: 0.2rem 0.6rem;
|
||||||
|
border-radius: 2px;
|
||||||
|
font-family: 'Orbitron', sans-serif;
|
||||||
|
font-size: 0.65rem;
|
||||||
|
font-weight: 700;
|
||||||
|
letter-spacing: 1px;
|
||||||
|
text-transform: uppercase;
|
||||||
|
}
|
||||||
|
.badge.yes {
|
||||||
|
background: rgba(5, 217, 232, 0.1);
|
||||||
|
color: var(--neon-cyan);
|
||||||
|
border: 1px solid var(--neon-cyan);
|
||||||
|
text-shadow: 0 0 6px rgba(5, 217, 232, 0.5);
|
||||||
|
}
|
||||||
|
.badge.no {
|
||||||
|
background: rgba(255, 42, 109, 0.1);
|
||||||
|
color: var(--neon-pink);
|
||||||
|
border: 1px solid var(--neon-pink);
|
||||||
|
text-shadow: 0 0 6px rgba(255, 42, 109, 0.5);
|
||||||
|
}
|
||||||
|
|
||||||
|
.missing-list { font-size: 0.85rem; color: var(--text-dim); }
|
||||||
|
|
||||||
|
label {
|
||||||
|
font-size: 0.75rem;
|
||||||
|
color: var(--text-dim);
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 0.2rem;
|
||||||
|
text-transform: uppercase;
|
||||||
|
letter-spacing: 1px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.flash {
|
||||||
|
padding: 0.5rem 1rem;
|
||||||
|
margin: 0.5rem 0;
|
||||||
|
border-radius: 2px;
|
||||||
|
background: rgba(5, 217, 232, 0.1);
|
||||||
|
color: var(--neon-cyan);
|
||||||
|
border: 1px solid var(--neon-cyan);
|
||||||
|
}
|
||||||
@@ -0,0 +1,30 @@
|
|||||||
|
{% extends "base.html" %}
|
||||||
|
{% block title %}What Can I Make? - Van{% endblock %}
|
||||||
|
{% block content %}
|
||||||
|
<h1>What Can I Make?</h1>
|
||||||
|
|
||||||
|
<table>
|
||||||
|
<thead>
|
||||||
|
<tr><th>Meal</th><th>Status</th><th>Missing</th></tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
{% for meal in availability %}
|
||||||
|
<tr>
|
||||||
|
<td><a href="/meals/{{ meal.meal_id }}">{{ meal.meal_name }}</a></td>
|
||||||
|
<td>
|
||||||
|
{% if meal.can_make %}
|
||||||
|
<span class="badge yes">Ready</span>
|
||||||
|
{% else %}
|
||||||
|
<span class="badge no">Missing items</span>
|
||||||
|
{% endif %}
|
||||||
|
</td>
|
||||||
|
<td class="missing-list">
|
||||||
|
{% for m in meal.missing %}
|
||||||
|
{{ m.item_name }}: need {{ m.short }} more {{ m.unit }}{% if not loop.last %}, {% endif %}
|
||||||
|
{% endfor %}
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
{% endfor %}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
{% endblock %}
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>{% block title %}Van Inventory{% endblock %}</title>
|
||||||
|
<script src="https://unpkg.com/htmx.org@2.0.4"></script>
|
||||||
|
<link rel="preconnect" href="https://fonts.googleapis.com">
|
||||||
|
<link href="https://fonts.googleapis.com/css2?family=Orbitron:wght@400;700;900&family=Share+Tech+Mono&display=swap" rel="stylesheet">
|
||||||
|
<link rel="stylesheet" href="/static/style.css">
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<nav>
|
||||||
|
<a href="/">Inventory</a>
|
||||||
|
<a href="/meals">Meals</a>
|
||||||
|
<a href="/availability">What Can I Make?</a>
|
||||||
|
</nav>
|
||||||
|
{% block content %}{% endblock %}
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
{% extends "base.html" %}
|
||||||
|
{% block title %}Inventory - Van{% endblock %}
|
||||||
|
{% block content %}
|
||||||
|
<h1>Van Inventory</h1>
|
||||||
|
|
||||||
|
<form hx-post="/items" hx-target="#item-list" hx-swap="innerHTML" hx-on::after-request="if(event.detail.successful) this.reset()">
|
||||||
|
<label>Name <input type="text" name="name" required></label>
|
||||||
|
<label>Qty <input type="number" name="quantity" step="any" value="0" min="0" required></label>
|
||||||
|
<label>Unit <input type="text" name="unit" required placeholder="lbs, cans, etc"></label>
|
||||||
|
<label>Category <input type="text" name="category" placeholder="optional"></label>
|
||||||
|
<button type="submit">Add Item</button>
|
||||||
|
</form>
|
||||||
|
|
||||||
|
<div id="item-list">
|
||||||
|
{% include "partials/item_rows.html" %}
|
||||||
|
</div>
|
||||||
|
{% endblock %}
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
{% extends "base.html" %}
|
||||||
|
{% block title %}{{ meal.name }} - Van{% endblock %}
|
||||||
|
{% block content %}
|
||||||
|
<h1>{{ meal.name }}</h1>
|
||||||
|
{% if meal.instructions %}<p>{{ meal.instructions }}</p>{% endif %}
|
||||||
|
|
||||||
|
<h2>Ingredients</h2>
|
||||||
|
<form hx-post="/meals/{{ meal.id }}/ingredients" hx-target="#ingredient-list" hx-swap="innerHTML" hx-on::after-request="if(event.detail.successful) this.reset()">
|
||||||
|
<label>Item
|
||||||
|
<select name="item_id" required>
|
||||||
|
<option value="">--</option>
|
||||||
|
{% for item in items %}
|
||||||
|
<option value="{{ item.id }}">{{ item.name }} ({{ item.unit }})</option>
|
||||||
|
{% endfor %}
|
||||||
|
</select>
|
||||||
|
</label>
|
||||||
|
<label>Qty needed <input type="number" name="quantity_needed" step="any" min="0.01" required></label>
|
||||||
|
<button type="submit">Add</button>
|
||||||
|
</form>
|
||||||
|
|
||||||
|
<div id="ingredient-list">
|
||||||
|
{% include "partials/ingredient_rows.html" %}
|
||||||
|
</div>
|
||||||
|
{% endblock %}
|
||||||
@@ -0,0 +1,15 @@
|
|||||||
|
{% extends "base.html" %}
|
||||||
|
{% block title %}Meals - Van{% endblock %}
|
||||||
|
{% block content %}
|
||||||
|
<h1>Meals</h1>
|
||||||
|
|
||||||
|
<form hx-post="/meals" hx-target="#meal-list" hx-swap="innerHTML" hx-on::after-request="if(event.detail.successful) this.reset()">
|
||||||
|
<label>Name <input type="text" name="name" required></label>
|
||||||
|
<label>Instructions <input type="text" name="instructions" placeholder="optional"></label>
|
||||||
|
<button type="submit">Add Meal</button>
|
||||||
|
</form>
|
||||||
|
|
||||||
|
<div id="meal-list">
|
||||||
|
{% include "partials/meal_rows.html" %}
|
||||||
|
</div>
|
||||||
|
{% endblock %}
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
<table>
|
||||||
|
<thead>
|
||||||
|
<tr><th>Item</th><th>Needed</th><th>Have</th><th>Unit</th><th></th></tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
{% for mi in meal.ingredients %}
|
||||||
|
<tr>
|
||||||
|
<td>{{ mi.item.name }}</td>
|
||||||
|
<td>{{ mi.quantity_needed }}</td>
|
||||||
|
<td>{{ mi.item.quantity }}</td>
|
||||||
|
<td>{{ mi.item.unit }}</td>
|
||||||
|
<td><button class="danger" hx-delete="/meals/{{ meal.id }}/ingredients/{{ mi.item_id }}" hx-target="#ingredient-list" hx-swap="innerHTML" hx-confirm="Remove {{ mi.item.name }}?">X</button></td>
|
||||||
|
</tr>
|
||||||
|
{% endfor %}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
<table>
|
||||||
|
<thead>
|
||||||
|
<tr><th>Name</th><th>Qty</th><th>Unit</th><th>Category</th><th></th></tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
{% for item in items %}
|
||||||
|
<tr>
|
||||||
|
<td>{{ item.name }}</td>
|
||||||
|
<td>
|
||||||
|
<form hx-patch="/items/{{ item.id }}" hx-target="#item-list" hx-swap="innerHTML" style="display:inline; margin:0;">
|
||||||
|
<input type="number" name="quantity" value="{{ item.quantity }}" step="any" min="0" style="width:5rem">
|
||||||
|
<button type="submit" style="padding:0.2rem 0.5rem; font-size:0.8rem;">Update</button>
|
||||||
|
</form>
|
||||||
|
</td>
|
||||||
|
<td>{{ item.unit }}</td>
|
||||||
|
<td>{{ item.category or "" }}</td>
|
||||||
|
<td><button class="danger" hx-delete="/items/{{ item.id }}" hx-target="#item-list" hx-swap="innerHTML" hx-confirm="Delete {{ item.name }}?">X</button></td>
|
||||||
|
</tr>
|
||||||
|
{% endfor %}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
@@ -0,0 +1,15 @@
|
|||||||
|
<table>
|
||||||
|
<thead>
|
||||||
|
<tr><th>Name</th><th>Ingredients</th><th>Instructions</th><th></th></tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
{% for meal in meals %}
|
||||||
|
<tr>
|
||||||
|
<td><a href="/meals/{{ meal.id }}">{{ meal.name }}</a></td>
|
||||||
|
<td>{{ meal.ingredients | length }}</td>
|
||||||
|
<td>{{ (meal.instructions or "")[:50] }}</td>
|
||||||
|
<td><button class="danger" hx-delete="/meals/{{ meal.id }}" hx-target="#meal-list" hx-swap="innerHTML" hx-confirm="Delete {{ meal.name }}?">X</button></td>
|
||||||
|
</tr>
|
||||||
|
{% endfor %}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
@@ -257,7 +257,7 @@ def update_weather(config: Config) -> None:
|
|||||||
|
|
||||||
logger.info(f"Masked location: {masked_lat}, {masked_lon}")
|
logger.info(f"Masked location: {masked_lat}, {masked_lon}")
|
||||||
|
|
||||||
weather = fetch_weather(config.pirate_weather_api_key, masked_lat, masked_lon)
|
weather = fetch_weather(config.pirate_weather_api_key, lat, lon)
|
||||||
logger.info(f"Weather: {weather.temperature}°F, {weather.condition}")
|
logger.info(f"Weather: {weather.temperature}°F, {weather.condition}")
|
||||||
|
|
||||||
post_to_ha(config.ha_url, config.ha_token, weather)
|
post_to_ha(config.ha_url, config.ha_token, weather)
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
"""Models for van weather service."""
|
"""Models for van weather service."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from datetime import datetime
|
||||||
|
|
||||||
from datetime import datetime # noqa: TC003 This is required for pydantic
|
|
||||||
|
|
||||||
from pydantic import BaseModel, field_serializer
|
from pydantic import BaseModel, field_serializer
|
||||||
|
|
||||||
|
|||||||
@@ -108,7 +108,7 @@ class Dataset:
|
|||||||
self.written = int(properties["written"]["value"])
|
self.written = int(properties["written"]["value"])
|
||||||
self.xattr = properties["xattr"]["value"]
|
self.xattr = properties["xattr"]["value"]
|
||||||
|
|
||||||
def get_snapshots(self) -> list[Snapshot]:
|
def get_snapshots(self) -> list[Snapshot] | None:
|
||||||
"""Get all snapshots from zfs and process then is test dicts of sets."""
|
"""Get all snapshots from zfs and process then is test dicts of sets."""
|
||||||
snapshots_data = _zfs_list(f"zfs list -t snapshot -pHj {self.name} -o all")
|
snapshots_data = _zfs_list(f"zfs list -t snapshot -pHj {self.name} -o all")
|
||||||
|
|
||||||
@@ -125,7 +125,7 @@ class Dataset:
|
|||||||
if return_code == 0:
|
if return_code == 0:
|
||||||
return "snapshot created"
|
return "snapshot created"
|
||||||
|
|
||||||
snapshots = self.get_snapshots()
|
if snapshots := self.get_snapshots():
|
||||||
snapshot_names = {snapshot.name for snapshot in snapshots}
|
snapshot_names = {snapshot.name for snapshot in snapshots}
|
||||||
if snapshot_name in snapshot_names:
|
if snapshot_name in snapshot_names:
|
||||||
return f"Snapshot {snapshot_name} already exists for {self.name}"
|
return f"Snapshot {snapshot_name} already exists for {self.name}"
|
||||||
|
|||||||
@@ -28,12 +28,7 @@
|
|||||||
networking = {
|
networking = {
|
||||||
hostName = "bob";
|
hostName = "bob";
|
||||||
hostId = "7c678a41";
|
hostId = "7c678a41";
|
||||||
firewall = {
|
firewall.enable = true;
|
||||||
enable = true;
|
|
||||||
allowedTCPPorts = [
|
|
||||||
8000
|
|
||||||
];
|
|
||||||
};
|
|
||||||
networkmanager.enable = true;
|
networkmanager.enable = true;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -30,11 +30,6 @@
|
|||||||
keyFile = "/dev/disk/by-id/usb-Samsung_Flash_Drive_FIT_0374620080067131-0:0";
|
keyFile = "/dev/disk/by-id/usb-Samsung_Flash_Drive_FIT_0374620080067131-0:0";
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
zfs.extraPools = [
|
|
||||||
"storage"
|
|
||||||
];
|
|
||||||
|
|
||||||
kernelModules = [ "kvm-amd" ];
|
kernelModules = [ "kvm-amd" ];
|
||||||
extraModulePackages = [ ];
|
extraModulePackages = [ ];
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -0,0 +1,50 @@
|
|||||||
|
{
|
||||||
|
pkgs,
|
||||||
|
inputs,
|
||||||
|
...
|
||||||
|
}:
|
||||||
|
{
|
||||||
|
networking.firewall.allowedTCPPorts = [ 8001 ];
|
||||||
|
|
||||||
|
users = {
|
||||||
|
users.vaninventory = {
|
||||||
|
isSystemUser = true;
|
||||||
|
group = "vaninventory";
|
||||||
|
};
|
||||||
|
groups.vaninventory = { };
|
||||||
|
};
|
||||||
|
|
||||||
|
systemd.services.van_inventory = {
|
||||||
|
description = "Van Inventory API";
|
||||||
|
after = [
|
||||||
|
"network.target"
|
||||||
|
"postgresql.service"
|
||||||
|
];
|
||||||
|
requires = [ "postgresql.service" ];
|
||||||
|
wantedBy = [ "multi-user.target" ];
|
||||||
|
|
||||||
|
environment = {
|
||||||
|
PYTHONPATH = "${inputs.self}/";
|
||||||
|
VAN_INVENTORY_DB = "vaninventory";
|
||||||
|
VAN_INVENTORY_USER = "vaninventory";
|
||||||
|
VAN_INVENTORY_HOST = "/run/postgresql";
|
||||||
|
VAN_INVENTORY_PORT = "5432";
|
||||||
|
};
|
||||||
|
|
||||||
|
serviceConfig = {
|
||||||
|
Type = "simple";
|
||||||
|
User = "vaninventory";
|
||||||
|
Group = "vaninventory";
|
||||||
|
ExecStart = "${pkgs.my_python}/bin/python -m python.van_inventory.main --host 0.0.0.0 --port 8001";
|
||||||
|
Restart = "on-failure";
|
||||||
|
RestartSec = "5s";
|
||||||
|
StandardOutput = "journal";
|
||||||
|
StandardError = "journal";
|
||||||
|
NoNewPrivileges = true;
|
||||||
|
ProtectSystem = "strict";
|
||||||
|
ProtectHome = "read-only";
|
||||||
|
PrivateTmp = true;
|
||||||
|
ReadOnlyPaths = [ "${inputs.self}" ];
|
||||||
|
};
|
||||||
|
};
|
||||||
|
}
|
||||||
@@ -0,0 +1,40 @@
|
|||||||
|
"""Shared test fixtures."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import create_engine, event
|
||||||
|
|
||||||
|
from python.orm.signal_bot.base import SignalBotBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Generator
|
||||||
|
|
||||||
|
from sqlalchemy.engine import Engine
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def sqlite_engine() -> Generator[Engine]:
|
||||||
|
"""Create an in-memory SQLite engine for testing."""
|
||||||
|
engine = create_engine("sqlite:///:memory:")
|
||||||
|
|
||||||
|
@event.listens_for(engine, "connect")
|
||||||
|
def _set_sqlite_pragma(dbapi_connection, _connection_record):
|
||||||
|
cursor = dbapi_connection.cursor()
|
||||||
|
cursor.execute("PRAGMA foreign_keys=ON")
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
|
SignalBotBase.metadata.create_all(engine)
|
||||||
|
yield engine
|
||||||
|
engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def engine(sqlite_engine: Engine) -> Generator[Engine]:
|
||||||
|
"""Yield the shared engine after cleaning all tables between tests."""
|
||||||
|
yield sqlite_engine
|
||||||
|
with sqlite_engine.begin() as connection:
|
||||||
|
for table in reversed(SignalBotBase.metadata.sorted_tables):
|
||||||
|
connection.execute(table.delete())
|
||||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user