Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 87e9963840 |
@@ -8,7 +8,6 @@ jobs:
|
||||
lockfile:
|
||||
runs-on: self-hosted
|
||||
permissions:
|
||||
actions: write
|
||||
contents: write
|
||||
pull-requests: write
|
||||
steps:
|
||||
|
||||
+1
-2
@@ -171,5 +171,4 @@ frontend/dist/
|
||||
frontend/node_modules/
|
||||
|
||||
# data from testing llms
|
||||
data/*
|
||||
.ebook_search_bm25
|
||||
data/*
|
||||
@@ -0,0 +1,12 @@
|
||||
## Dev environment tips
|
||||
|
||||
- use treefmt to format all files
|
||||
- make python code ruff compliant
|
||||
- use pytest to test python code
|
||||
- always use the minimum amount of complexity
|
||||
- if judgment calls are easy to reverse make them. if not ask me first
|
||||
- Match existing code style.
|
||||
- Use builtin helpers getenv() over os.environ.get.
|
||||
- Prefer single-purpose functions over “do everything” helpers.
|
||||
- Avoid compatibility branches like PG_USER and POSTGRESQL_URL unless requested.
|
||||
- Keep helpers only if reused or they simplify the code otherwise inline.
|
||||
File diff suppressed because one or more lines are too long
Generated
+15
-15
@@ -8,11 +8,11 @@
|
||||
},
|
||||
"locked": {
|
||||
"dir": "pkgs/firefox-addons",
|
||||
"lastModified": 1781150628,
|
||||
"narHash": "sha256-b4mp8l3qWuSCyYYo9HSngDtcB3PpecYiOXjULrjwwlw=",
|
||||
"lastModified": 1780733803,
|
||||
"narHash": "sha256-QBJPq12P1DAXFGezoEJaSO/xPUrPlnaI3ddSaMG2JpM=",
|
||||
"owner": "rycee",
|
||||
"repo": "nur-expressions",
|
||||
"rev": "753319310f4673a2dabbfab87482187b40bf9bac",
|
||||
"rev": "c80b0aa94392c5f3612ac797108f6d952752036d",
|
||||
"type": "gitlab"
|
||||
},
|
||||
"original": {
|
||||
@@ -29,11 +29,11 @@
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1781189114,
|
||||
"narHash": "sha256-5inaamLgUMWy+MOBE9ChF9QAF1o/74LFuHkI0W/9rqc=",
|
||||
"lastModified": 1780679734,
|
||||
"narHash": "sha256-KmRNvpNOb7QEORa06bVgjW9kITcx0VhsI7w0vhmZyD8=",
|
||||
"owner": "nix-community",
|
||||
"repo": "home-manager",
|
||||
"rev": "486595d2cf49cfcd649b58a284fa11ac0e34da22",
|
||||
"rev": "b2b7db486e06e098711dc291bb25db82850e1d16",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -47,11 +47,11 @@
|
||||
"nixpkgs": "nixpkgs"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1781168557,
|
||||
"narHash": "sha256-LOnLQ2tpYF9gqIDDr3+j3DbpJJr/QCH6zPRT2GzEUOE=",
|
||||
"lastModified": 1780310866,
|
||||
"narHash": "sha256-fPBRVf6A5xlACYcOI59shGrjURuvwu0lRsDoSCEXt/I=",
|
||||
"owner": "nixos",
|
||||
"repo": "nixos-hardware",
|
||||
"rev": "6358ff76821101c178e3ab4919a62799bfe3652e",
|
||||
"rev": "4ed851c979641e28597a05086332d75cdc9e395f",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -76,11 +76,11 @@
|
||||
},
|
||||
"nixpkgs-master": {
|
||||
"locked": {
|
||||
"lastModified": 1781229721,
|
||||
"narHash": "sha256-ORvqDbb/LYxiJljGIejapjkc/kJbVote2N1WSb9W45I=",
|
||||
"lastModified": 1780798858,
|
||||
"narHash": "sha256-4KLc5ZMjfMQosXA2JasUgZTk3i+c/i1zMH4custtmI0=",
|
||||
"owner": "nixos",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "173d0ad7a974f8543a9ab01d2271b2e290341b33",
|
||||
"rev": "92840095e65b9970125843175f4be974b71a92ad",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -108,11 +108,11 @@
|
||||
},
|
||||
"nixpkgs_2": {
|
||||
"locked": {
|
||||
"lastModified": 1781074563,
|
||||
"narHash": "sha256-md8WlXOlfnIeHeOScMTTHFyf2d6iaTwPl2apR5EQ3P4=",
|
||||
"lastModified": 1780243769,
|
||||
"narHash": "sha256-x5UQuRsH3MqI0U9afaXSNqzTPSeZlRLvFAav2Ux1pNw=",
|
||||
"owner": "nixos",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "9ae611a455b90cf061d8f332b977e387bda8e1ca",
|
||||
"rev": "331800de5053fcebacf6813adb5db9c9dca22a0c",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
# Logs
|
||||
logs
|
||||
*.log
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
pnpm-debug.log*
|
||||
lerna-debug.log*
|
||||
|
||||
node_modules
|
||||
dist
|
||||
dist-ssr
|
||||
*.local
|
||||
|
||||
# Editor directories and files
|
||||
.vscode/*
|
||||
!.vscode/extensions.json
|
||||
.idea
|
||||
.DS_Store
|
||||
*.suo
|
||||
*.ntvs*
|
||||
*.njsproj
|
||||
*.sln
|
||||
*.sw?
|
||||
+1
-28
@@ -17,41 +17,15 @@
|
||||
|
||||
python-env = final: _prev: {
|
||||
my_python = final.python314.withPackages (
|
||||
ps:
|
||||
let
|
||||
bm25s = ps.buildPythonPackage rec {
|
||||
pname = "bm25s";
|
||||
version = "0.3.9";
|
||||
pyproject = true;
|
||||
|
||||
src = final.fetchPypi {
|
||||
inherit pname version;
|
||||
hash = "sha256-iVxnnZUrfeg1XttfPhpiCh4vKU0dQrkZvwghzOLi9Zc=";
|
||||
};
|
||||
|
||||
build-system = [ ps.setuptools ];
|
||||
dependencies = with ps; [
|
||||
numpy
|
||||
scipy
|
||||
];
|
||||
|
||||
pythonImportsCheck = [ "bm25s" ];
|
||||
};
|
||||
in
|
||||
with ps;
|
||||
[
|
||||
ps: with ps; [
|
||||
alembic
|
||||
apprise
|
||||
apscheduler
|
||||
beautifulsoup4
|
||||
ebooklib
|
||||
fastapi
|
||||
fastapi-cli
|
||||
httpx
|
||||
mypy
|
||||
numpy
|
||||
orjson
|
||||
pgvector
|
||||
polars
|
||||
psycopg
|
||||
pydantic
|
||||
@@ -65,7 +39,6 @@
|
||||
scalene
|
||||
sqlalchemy
|
||||
sqlalchemy
|
||||
bm25s
|
||||
tenacity
|
||||
textual
|
||||
tiktoken
|
||||
|
||||
@@ -84,6 +84,9 @@ lint.ignore = [
|
||||
"python/alembic/**" = [
|
||||
"INP001", # (perm) this creates LSP issues for alembic
|
||||
]
|
||||
"python/signal_bot/**" = [
|
||||
"D107", # (perm) class docstrings cover __init__
|
||||
]
|
||||
|
||||
[tool.ruff.lint.pydocstyle]
|
||||
convention = "google"
|
||||
|
||||
+1417
File diff suppressed because it is too large
Load Diff
+50
@@ -0,0 +1,50 @@
|
||||
"""adding FailedIngestion.
|
||||
|
||||
Revision ID: 2f43120e3ffc
|
||||
Revises: f99be864fe69
|
||||
Create Date: 2026-03-24 23:46:17.277897
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
from python.orm import DataScienceDevBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "2f43120e3ffc"
|
||||
down_revision: str | None = "f99be864fe69"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
schema = DataScienceDevBase.schema_name
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"failed_ingestion",
|
||||
sa.Column("raw_line", sa.Text(), nullable=False),
|
||||
sa.Column("error", sa.Text(), nullable=False),
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("pk_failed_ingestion")),
|
||||
schema=schema,
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("failed_ingestion", schema=schema)
|
||||
# ### end Alembic commands ###
|
||||
+2770
File diff suppressed because it is too large
Load Diff
+1391
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,72 @@
|
||||
"""Attach all partition tables to the posts parent table.
|
||||
|
||||
Alembic autogenerate creates partition tables as standalone tables but does not
|
||||
emit the ALTER TABLE ... ATTACH PARTITION statements needed for PostgreSQL to
|
||||
route inserts to the correct partition.
|
||||
|
||||
Revision ID: a1b2c3d4e5f6
|
||||
Revises: 605b1794838f
|
||||
Create Date: 2026-03-25 10:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from alembic import op
|
||||
from sqlalchemy import text
|
||||
|
||||
from python.orm import DataScienceDevBase
|
||||
from python.orm.data_science_dev.posts.partitions import (
|
||||
PARTITION_END_YEAR,
|
||||
PARTITION_START_YEAR,
|
||||
iso_weeks_in_year,
|
||||
week_bounds,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "a1b2c3d4e5f6"
|
||||
down_revision: str | None = "605b1794838f"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
schema = DataScienceDevBase.schema_name
|
||||
|
||||
ALREADY_ATTACHED_QUERY = text("""
|
||||
SELECT inhrelid::regclass::text
|
||||
FROM pg_inherits
|
||||
WHERE inhparent = :parent::regclass
|
||||
""")
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Attach all weekly partition tables to the posts parent table."""
|
||||
connection = op.get_bind()
|
||||
already_attached = {row[0] for row in connection.execute(ALREADY_ATTACHED_QUERY, {"parent": f"{schema}.posts"})}
|
||||
|
||||
for year in range(PARTITION_START_YEAR, PARTITION_END_YEAR + 1):
|
||||
for week in range(1, iso_weeks_in_year(year) + 1):
|
||||
table_name = f"posts_{year}_{week:02d}"
|
||||
qualified_name = f"{schema}.{table_name}"
|
||||
if qualified_name in already_attached:
|
||||
continue
|
||||
start, end = week_bounds(year, week)
|
||||
start_str = start.strftime("%Y-%m-%d %H:%M:%S")
|
||||
end_str = end.strftime("%Y-%m-%d %H:%M:%S")
|
||||
op.execute(
|
||||
f"ALTER TABLE {schema}.posts "
|
||||
f"ATTACH PARTITION {qualified_name} "
|
||||
f"FOR VALUES FROM ('{start_str}') TO ('{end_str}')"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Detach all weekly partition tables from the posts parent table."""
|
||||
for year in range(PARTITION_START_YEAR, PARTITION_END_YEAR + 1):
|
||||
for week in range(1, iso_weeks_in_year(year) + 1):
|
||||
table_name = f"posts_{year}_{week:02d}"
|
||||
op.execute(f"ALTER TABLE {schema}.posts DETACH PARTITION {schema}.{table_name}")
|
||||
+153
@@ -0,0 +1,153 @@
|
||||
"""adding congress data.
|
||||
|
||||
Revision ID: 83bfc8af92d8
|
||||
Revises: a1b2c3d4e5f6
|
||||
Create Date: 2026-03-27 10:43:02.324510
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
from python.orm import DataScienceDevBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "83bfc8af92d8"
|
||||
down_revision: str | None = "a1b2c3d4e5f6"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
schema = DataScienceDevBase.schema_name
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"bill",
|
||||
sa.Column("congress", sa.Integer(), nullable=False),
|
||||
sa.Column("bill_type", sa.String(), nullable=False),
|
||||
sa.Column("number", sa.Integer(), nullable=False),
|
||||
sa.Column("title", sa.String(), nullable=True),
|
||||
sa.Column("title_short", sa.String(), nullable=True),
|
||||
sa.Column("official_title", sa.String(), nullable=True),
|
||||
sa.Column("status", sa.String(), nullable=True),
|
||||
sa.Column("status_at", sa.Date(), nullable=True),
|
||||
sa.Column("sponsor_bioguide_id", sa.String(), nullable=True),
|
||||
sa.Column("subjects_top_term", sa.String(), nullable=True),
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("pk_bill")),
|
||||
sa.UniqueConstraint("congress", "bill_type", "number", name="uq_bill_congress_type_number"),
|
||||
schema=schema,
|
||||
)
|
||||
op.create_index("ix_bill_congress", "bill", ["congress"], unique=False, schema=schema)
|
||||
op.create_table(
|
||||
"legislator",
|
||||
sa.Column("bioguide_id", sa.Text(), nullable=False),
|
||||
sa.Column("thomas_id", sa.String(), nullable=True),
|
||||
sa.Column("lis_id", sa.String(), nullable=True),
|
||||
sa.Column("govtrack_id", sa.Integer(), nullable=True),
|
||||
sa.Column("opensecrets_id", sa.String(), nullable=True),
|
||||
sa.Column("fec_ids", sa.String(), nullable=True),
|
||||
sa.Column("first_name", sa.String(), nullable=False),
|
||||
sa.Column("last_name", sa.String(), nullable=False),
|
||||
sa.Column("official_full_name", sa.String(), nullable=True),
|
||||
sa.Column("nickname", sa.String(), nullable=True),
|
||||
sa.Column("birthday", sa.Date(), nullable=True),
|
||||
sa.Column("gender", sa.String(), nullable=True),
|
||||
sa.Column("current_party", sa.String(), nullable=True),
|
||||
sa.Column("current_state", sa.String(), nullable=True),
|
||||
sa.Column("current_district", sa.Integer(), nullable=True),
|
||||
sa.Column("current_chamber", sa.String(), nullable=True),
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("pk_legislator")),
|
||||
schema=schema,
|
||||
)
|
||||
op.create_index(op.f("ix_legislator_bioguide_id"), "legislator", ["bioguide_id"], unique=True, schema=schema)
|
||||
op.create_table(
|
||||
"bill_text",
|
||||
sa.Column("bill_id", sa.Integer(), nullable=False),
|
||||
sa.Column("version_code", sa.String(), nullable=False),
|
||||
sa.Column("version_name", sa.String(), nullable=True),
|
||||
sa.Column("text_content", sa.String(), nullable=True),
|
||||
sa.Column("date", sa.Date(), nullable=True),
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["bill_id"], [f"{schema}.bill.id"], name=op.f("fk_bill_text_bill_id_bill"), ondelete="CASCADE"
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("pk_bill_text")),
|
||||
sa.UniqueConstraint("bill_id", "version_code", name="uq_bill_text_bill_id_version_code"),
|
||||
schema=schema,
|
||||
)
|
||||
op.create_table(
|
||||
"vote",
|
||||
sa.Column("congress", sa.Integer(), nullable=False),
|
||||
sa.Column("chamber", sa.String(), nullable=False),
|
||||
sa.Column("session", sa.Integer(), nullable=False),
|
||||
sa.Column("number", sa.Integer(), nullable=False),
|
||||
sa.Column("vote_type", sa.String(), nullable=True),
|
||||
sa.Column("question", sa.String(), nullable=True),
|
||||
sa.Column("result", sa.String(), nullable=True),
|
||||
sa.Column("result_text", sa.String(), nullable=True),
|
||||
sa.Column("vote_date", sa.Date(), nullable=False),
|
||||
sa.Column("yea_count", sa.Integer(), nullable=True),
|
||||
sa.Column("nay_count", sa.Integer(), nullable=True),
|
||||
sa.Column("not_voting_count", sa.Integer(), nullable=True),
|
||||
sa.Column("present_count", sa.Integer(), nullable=True),
|
||||
sa.Column("bill_id", sa.Integer(), nullable=True),
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.ForeignKeyConstraint(["bill_id"], [f"{schema}.bill.id"], name=op.f("fk_vote_bill_id_bill")),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("pk_vote")),
|
||||
sa.UniqueConstraint("congress", "chamber", "session", "number", name="uq_vote_congress_chamber_session_number"),
|
||||
schema=schema,
|
||||
)
|
||||
op.create_index("ix_vote_congress_chamber", "vote", ["congress", "chamber"], unique=False, schema=schema)
|
||||
op.create_index("ix_vote_date", "vote", ["vote_date"], unique=False, schema=schema)
|
||||
op.create_table(
|
||||
"vote_record",
|
||||
sa.Column("vote_id", sa.Integer(), nullable=False),
|
||||
sa.Column("legislator_id", sa.Integer(), nullable=False),
|
||||
sa.Column("position", sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["legislator_id"],
|
||||
[f"{schema}.legislator.id"],
|
||||
name=op.f("fk_vote_record_legislator_id_legislator"),
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["vote_id"], [f"{schema}.vote.id"], name=op.f("fk_vote_record_vote_id_vote"), ondelete="CASCADE"
|
||||
),
|
||||
sa.PrimaryKeyConstraint("vote_id", "legislator_id", name=op.f("pk_vote_record")),
|
||||
schema=schema,
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("vote_record", schema=schema)
|
||||
op.drop_index("ix_vote_date", table_name="vote", schema=schema)
|
||||
op.drop_index("ix_vote_congress_chamber", table_name="vote", schema=schema)
|
||||
op.drop_table("vote", schema=schema)
|
||||
op.drop_table("bill_text", schema=schema)
|
||||
op.drop_index(op.f("ix_legislator_bioguide_id"), table_name="legislator", schema=schema)
|
||||
op.drop_table("legislator", schema=schema)
|
||||
op.drop_index("ix_bill_congress", table_name="bill", schema=schema)
|
||||
op.drop_table("bill", schema=schema)
|
||||
# ### end Alembic commands ###
|
||||
+58
@@ -0,0 +1,58 @@
|
||||
"""adding LegislatorSocialMedia.
|
||||
|
||||
Revision ID: 5cd7eee3549d
|
||||
Revises: 83bfc8af92d8
|
||||
Create Date: 2026-03-29 11:53:44.224799
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
from python.orm import DataScienceDevBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "5cd7eee3549d"
|
||||
down_revision: str | None = "83bfc8af92d8"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
schema = DataScienceDevBase.schema_name
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"legislator_social_media",
|
||||
sa.Column("legislator_id", sa.Integer(), nullable=False),
|
||||
sa.Column("platform", sa.String(), nullable=False),
|
||||
sa.Column("account_name", sa.String(), nullable=False),
|
||||
sa.Column("url", sa.String(), nullable=True),
|
||||
sa.Column("source", sa.String(), nullable=False),
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["legislator_id"],
|
||||
[f"{schema}.legislator.id"],
|
||||
name=op.f("fk_legislator_social_media_legislator_id_legislator"),
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("pk_legislator_social_media")),
|
||||
schema=schema,
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("legislator_social_media", schema=schema)
|
||||
# ### end Alembic commands ###
|
||||
-93
@@ -1,93 +0,0 @@
|
||||
"""adding audiobook libreary metadata.
|
||||
|
||||
Revision ID: d7864d1ffc17
|
||||
Revises: c8a794340928
|
||||
Create Date: 2026-06-03 20:24:09.200837
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
from python.orm import RichieBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "d7864d1ffc17"
|
||||
down_revision: str | None = "c8a794340928"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
schema = RichieBase.schema_name
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"audiobook_author",
|
||||
sa.Column("name", sa.String(), nullable=False),
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("pk_audiobook_author")),
|
||||
sa.UniqueConstraint("name", name=op.f("uq_audiobook_author_name")),
|
||||
schema=schema,
|
||||
)
|
||||
op.create_table(
|
||||
"audiobook_series",
|
||||
sa.Column("name", sa.String(), nullable=False),
|
||||
sa.Column("author_id", sa.Integer(), nullable=False),
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["author_id"],
|
||||
[f"{schema}.audiobook_author.id"],
|
||||
name=op.f("fk_audiobook_series_author_id_audiobook_author"),
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("pk_audiobook_series")),
|
||||
sa.UniqueConstraint("author_id", "name", name=op.f("uq_audiobook_series_author_id")),
|
||||
schema=schema,
|
||||
)
|
||||
op.create_table(
|
||||
"audiobook",
|
||||
sa.Column("title", sa.String(), nullable=False),
|
||||
sa.Column("author_id", sa.Integer(), nullable=False),
|
||||
sa.Column("series_id", sa.Integer(), nullable=True),
|
||||
sa.Column("series_index", sa.Integer(), nullable=False),
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["author_id"],
|
||||
[f"{schema}.audiobook_author.id"],
|
||||
name=op.f("fk_audiobook_author_id_audiobook_author"),
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["series_id"],
|
||||
[f"{schema}.audiobook_series.id"],
|
||||
name=op.f("fk_audiobook_series_id_audiobook_series"),
|
||||
ondelete="SET NULL",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("pk_audiobook")),
|
||||
schema=schema,
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("audiobook", schema=schema)
|
||||
op.drop_table("audiobook_series", schema=schema)
|
||||
op.drop_table("audiobook_author", schema=schema)
|
||||
# ### end Alembic commands ###
|
||||
@@ -1,200 +0,0 @@
|
||||
"""add ebook search tables.
|
||||
|
||||
Revision ID: 2db132cace1a
|
||||
Revises: b3c60cc5beb5
|
||||
Create Date: 2026-06-10 22:10:54.379159
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pgvector
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
from python.orm import RichieBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "2db132cace1a"
|
||||
down_revision: str | None = "b3c60cc5beb5"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
schema = RichieBase.schema_name
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"ebook_embedding_model",
|
||||
sa.Column("name", sa.String(), nullable=False),
|
||||
sa.Column("dimension", sa.Integer(), nullable=False),
|
||||
sa.Column("is_default", sa.Boolean(), nullable=False),
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("pk_ebook_embedding_model")),
|
||||
sa.UniqueConstraint("name", name=op.f("uq_ebook_embedding_model_name")),
|
||||
schema=schema,
|
||||
)
|
||||
op.create_table(
|
||||
"ebook_source",
|
||||
sa.Column("title", sa.String(), nullable=False),
|
||||
sa.Column("author", sa.String(), nullable=True),
|
||||
sa.Column("language", sa.String(), nullable=True),
|
||||
sa.Column("publisher", sa.String(), nullable=True),
|
||||
sa.Column("identifier", sa.String(), nullable=True),
|
||||
sa.Column("file_path", sa.String(), nullable=False),
|
||||
sa.Column("file_sha256", sa.String(length=64), nullable=False),
|
||||
sa.Column("file_mtime", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("file_size", sa.BigInteger(), nullable=False),
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("pk_ebook_source")),
|
||||
sa.UniqueConstraint("file_path", name=op.f("uq_ebook_source_file_path")),
|
||||
sa.UniqueConstraint("file_sha256", name=op.f("uq_ebook_source_file_sha256")),
|
||||
schema=schema,
|
||||
)
|
||||
op.create_table(
|
||||
"ebook_chapter",
|
||||
sa.Column("source_id", sa.Integer(), nullable=False),
|
||||
sa.Column("spine_index", sa.Integer(), nullable=False),
|
||||
sa.Column("title", sa.String(), nullable=True),
|
||||
sa.Column("href", sa.String(), nullable=True),
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["source_id"],
|
||||
[f"{schema}.ebook_source.id"],
|
||||
name=op.f("fk_ebook_chapter_source_id_ebook_source"),
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("pk_ebook_chapter")),
|
||||
sa.UniqueConstraint("source_id", "spine_index", name=op.f("uq_ebook_chapter_source_id")),
|
||||
schema=schema,
|
||||
)
|
||||
op.create_table(
|
||||
"ebook_chunk",
|
||||
sa.Column("source_id", sa.Integer(), nullable=False),
|
||||
sa.Column("chapter_id", sa.Integer(), nullable=True),
|
||||
sa.Column("chunk_index", sa.Integer(), nullable=False),
|
||||
sa.Column("text", sa.String(), nullable=False),
|
||||
sa.Column("token_start", sa.Integer(), nullable=False),
|
||||
sa.Column("token_count", sa.Integer(), nullable=False),
|
||||
sa.Column("page_label", sa.String(), nullable=True),
|
||||
sa.Column("content_sha256", sa.String(length=64), nullable=False),
|
||||
sa.Column("search_text", sa.String(), nullable=False),
|
||||
sa.Column("id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["chapter_id"],
|
||||
[f"{schema}.ebook_chapter.id"],
|
||||
name=op.f("fk_ebook_chunk_chapter_id_ebook_chapter"),
|
||||
ondelete="SET NULL",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["source_id"],
|
||||
[f"{schema}.ebook_source.id"],
|
||||
name=op.f("fk_ebook_chunk_source_id_ebook_source"),
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("pk_ebook_chunk")),
|
||||
sa.UniqueConstraint("source_id", "chunk_index", name="uq_ebook_chunk_source_id_chunk_index"),
|
||||
sa.UniqueConstraint("source_id", "content_sha256", name="uq_ebook_chunk_source_id_content_sha256"),
|
||||
schema=schema,
|
||||
)
|
||||
op.create_table(
|
||||
"ebook_chunk_embedding_1024",
|
||||
sa.Column("chunk_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("model_id", sa.Integer(), nullable=False),
|
||||
sa.Column("embedding", pgvector.sqlalchemy.vector.VECTOR(dim=1024), nullable=False),
|
||||
sa.Column("id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["chunk_id"],
|
||||
[f"{schema}.ebook_chunk.id"],
|
||||
name=op.f("fk_ebook_chunk_embedding_1024_chunk_id_ebook_chunk"),
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["model_id"],
|
||||
[f"{schema}.ebook_embedding_model.id"],
|
||||
name=op.f("fk_ebook_chunk_embedding_1024_model_id_ebook_embedding_model"),
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("pk_ebook_chunk_embedding_1024")),
|
||||
sa.UniqueConstraint("chunk_id", "model_id", name=op.f("uq_ebook_chunk_embedding_1024_chunk_id")),
|
||||
schema=schema,
|
||||
)
|
||||
op.create_table(
|
||||
"ebook_chunk_embedding_2560",
|
||||
sa.Column("chunk_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("model_id", sa.Integer(), nullable=False),
|
||||
sa.Column("embedding", pgvector.sqlalchemy.vector.VECTOR(dim=2560), nullable=False),
|
||||
sa.Column("id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["chunk_id"],
|
||||
[f"{schema}.ebook_chunk.id"],
|
||||
name=op.f("fk_ebook_chunk_embedding_2560_chunk_id_ebook_chunk"),
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["model_id"],
|
||||
[f"{schema}.ebook_embedding_model.id"],
|
||||
name=op.f("fk_ebook_chunk_embedding_2560_model_id_ebook_embedding_model"),
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("pk_ebook_chunk_embedding_2560")),
|
||||
sa.UniqueConstraint("chunk_id", "model_id", name=op.f("uq_ebook_chunk_embedding_2560_chunk_id")),
|
||||
schema=schema,
|
||||
)
|
||||
op.create_table(
|
||||
"ebook_chunk_embedding_4096",
|
||||
sa.Column("chunk_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("model_id", sa.Integer(), nullable=False),
|
||||
sa.Column("embedding", pgvector.sqlalchemy.vector.VECTOR(dim=4096), nullable=False),
|
||||
sa.Column("id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["chunk_id"],
|
||||
[f"{schema}.ebook_chunk.id"],
|
||||
name=op.f("fk_ebook_chunk_embedding_4096_chunk_id_ebook_chunk"),
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["model_id"],
|
||||
[f"{schema}.ebook_embedding_model.id"],
|
||||
name=op.f("fk_ebook_chunk_embedding_4096_model_id_ebook_embedding_model"),
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("pk_ebook_chunk_embedding_4096")),
|
||||
sa.UniqueConstraint("chunk_id", "model_id", name=op.f("uq_ebook_chunk_embedding_4096_chunk_id")),
|
||||
schema=schema,
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("ebook_chunk_embedding_4096", schema=schema)
|
||||
op.drop_table("ebook_chunk_embedding_2560", schema=schema)
|
||||
op.drop_table("ebook_chunk_embedding_1024", schema=schema)
|
||||
op.drop_table("ebook_chunk", schema=schema)
|
||||
op.drop_table("ebook_chapter", schema=schema)
|
||||
op.drop_table("ebook_source", schema=schema)
|
||||
op.drop_table("ebook_embedding_model", schema=schema)
|
||||
# ### end Alembic commands ###
|
||||
-63
@@ -1,63 +0,0 @@
|
||||
"""updated series_index to float and added UniqueConstraint to audiobook and audiobook_author.
|
||||
|
||||
Revision ID: b3c60cc5beb5
|
||||
Revises: d7864d1ffc17
|
||||
Create Date: 2026-06-10 20:02:43.073725
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
from python.orm import RichieBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "b3c60cc5beb5"
|
||||
down_revision: str | None = "d7864d1ffc17"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
schema = RichieBase.schema_name
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.alter_column(
|
||||
"audiobook",
|
||||
"series_index",
|
||||
existing_type=sa.INTEGER(),
|
||||
type_=sa.Float(),
|
||||
existing_nullable=False,
|
||||
schema=schema,
|
||||
)
|
||||
op.create_unique_constraint(
|
||||
op.f("uq_audiobook_author_id"),
|
||||
"audiobook",
|
||||
["author_id", "series_id", "title"],
|
||||
schema=schema,
|
||||
postgresql_nulls_not_distinct=True,
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_constraint(op.f("uq_audiobook_author_id"), "audiobook", schema=schema, type_="unique")
|
||||
op.alter_column(
|
||||
"audiobook",
|
||||
"series_index",
|
||||
existing_type=sa.Float(),
|
||||
type_=sa.INTEGER(),
|
||||
existing_nullable=False,
|
||||
schema=schema,
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
-54
@@ -1,54 +0,0 @@
|
||||
"""add 1024 ebook embedding cosine index.
|
||||
|
||||
Revision ID: c460105682d2
|
||||
Revises: 2db132cace1a
|
||||
Create Date: 2026-06-13 19:53:45.680289
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from alembic import op
|
||||
|
||||
from python.orm import RichieBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "c460105682d2"
|
||||
down_revision: str | None = "2db132cace1a"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
schema = RichieBase.schema_name
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_index(
|
||||
"ix_ebook_chunk_embedding_1024_embedding_cosine",
|
||||
"ebook_chunk_embedding_1024",
|
||||
["embedding"],
|
||||
unique=False,
|
||||
schema=schema,
|
||||
postgresql_using="hnsw",
|
||||
postgresql_ops={"embedding": "vector_cosine_ops"},
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(
|
||||
"ix_ebook_chunk_embedding_1024_embedding_cosine",
|
||||
table_name="ebook_chunk_embedding_1024",
|
||||
schema=schema,
|
||||
postgresql_using="hnsw",
|
||||
postgresql_ops={"embedding": "vector_cosine_ops"},
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
+100
@@ -0,0 +1,100 @@
|
||||
"""seprating signal_bot database.
|
||||
|
||||
Revision ID: 6eaf696e07a5
|
||||
Revises:
|
||||
Create Date: 2026-03-17 21:35:37.612672
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from python.orm import SignalBotBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "6eaf696e07a5"
|
||||
down_revision: str | None = None
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
schema = SignalBotBase.schema_name
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"dead_letter_message",
|
||||
sa.Column("source", sa.String(), nullable=False),
|
||||
sa.Column("message", sa.Text(), nullable=False),
|
||||
sa.Column("received_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column(
|
||||
"status", postgresql.ENUM("UNPROCESSED", "PROCESSED", name="message_status", schema=schema), nullable=False
|
||||
),
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("pk_dead_letter_message")),
|
||||
schema=schema,
|
||||
)
|
||||
op.create_table(
|
||||
"role",
|
||||
sa.Column("name", sa.String(length=50), nullable=False),
|
||||
sa.Column("id", sa.SmallInteger(), nullable=False),
|
||||
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("pk_role")),
|
||||
sa.UniqueConstraint("name", name=op.f("uq_role_name")),
|
||||
schema=schema,
|
||||
)
|
||||
op.create_table(
|
||||
"signal_device",
|
||||
sa.Column("phone_number", sa.String(length=50), nullable=False),
|
||||
sa.Column("safety_number", sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
"trust_level",
|
||||
postgresql.ENUM("VERIFIED", "UNVERIFIED", "BLOCKED", name="trust_level", schema=schema),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("last_seen", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("pk_signal_device")),
|
||||
sa.UniqueConstraint("phone_number", name=op.f("uq_signal_device_phone_number")),
|
||||
schema=schema,
|
||||
)
|
||||
op.create_table(
|
||||
"device_role",
|
||||
sa.Column("device_id", sa.Integer(), nullable=False),
|
||||
sa.Column("role_id", sa.SmallInteger(), nullable=False),
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["device_id"], [f"{schema}.signal_device.id"], name=op.f("fk_device_role_device_id_signal_device")
|
||||
),
|
||||
sa.ForeignKeyConstraint(["role_id"], [f"{schema}.role.id"], name=op.f("fk_device_role_role_id_role")),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("pk_device_role")),
|
||||
sa.UniqueConstraint("device_id", "role_id", name="uq_device_role_device_role"),
|
||||
schema=schema,
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("device_role", schema=schema)
|
||||
op.drop_table("signal_device", schema=schema)
|
||||
op.drop_table("role", schema=schema)
|
||||
op.drop_table("dead_letter_message", schema=schema)
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,72 @@
|
||||
"""test.
|
||||
|
||||
Revision ID: 66bdd532bcab
|
||||
Revises: 6eaf696e07a5
|
||||
Create Date: 2026-03-18 19:21:14.561568
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from python.orm import SignalBotBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "66bdd532bcab"
|
||||
down_revision: str | None = "6eaf696e07a5"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
schema = SignalBotBase.schema_name
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.alter_column(
|
||||
"dead_letter_message",
|
||||
"status",
|
||||
existing_type=postgresql.ENUM("UNPROCESSED", "PROCESSED", name="message_status", schema=schema),
|
||||
type_=sa.Enum("UNPROCESSED", "PROCESSED", name="message_status", native_enum=False),
|
||||
existing_nullable=False,
|
||||
schema=schema,
|
||||
)
|
||||
op.alter_column(
|
||||
"signal_device",
|
||||
"trust_level",
|
||||
existing_type=postgresql.ENUM("VERIFIED", "UNVERIFIED", "BLOCKED", name="trust_level", schema=schema),
|
||||
type_=sa.Enum("VERIFIED", "UNVERIFIED", "BLOCKED", name="trust_level", native_enum=False),
|
||||
existing_nullable=False,
|
||||
schema=schema,
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.alter_column(
|
||||
"signal_device",
|
||||
"trust_level",
|
||||
existing_type=sa.Enum("VERIFIED", "UNVERIFIED", "BLOCKED", name="trust_level", native_enum=False),
|
||||
type_=postgresql.ENUM("VERIFIED", "UNVERIFIED", "BLOCKED", name="trust_level", schema=schema),
|
||||
existing_nullable=False,
|
||||
schema=schema,
|
||||
)
|
||||
op.alter_column(
|
||||
"dead_letter_message",
|
||||
"status",
|
||||
existing_type=sa.Enum("UNPROCESSED", "PROCESSED", name="message_status", native_enum=False),
|
||||
type_=postgresql.ENUM("UNPROCESSED", "PROCESSED", name="message_status", schema=schema),
|
||||
existing_nullable=False,
|
||||
schema=schema,
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
+1
-1
@@ -9,9 +9,9 @@ import typer
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
|
||||
from python.api.middleware import ZstdMiddleware
|
||||
from python.api.routers import contact_router, views_router
|
||||
from python.common import configure_logger
|
||||
from python.fastapi_tools import ZstdMiddleware
|
||||
from python.orm.common import get_postgres_engine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Zstd response compression middleware."""
|
||||
"""Middleware for the FastAPI application."""
|
||||
|
||||
from compression import zstd
|
||||
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
||||
@@ -9,7 +9,7 @@ from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from python.fastapi_tools.db import DbSession
|
||||
from python.api.dependencies import DbSession
|
||||
from python.orm.richie.contact import Contact, ContactRelationship, Need, RelationshipType
|
||||
|
||||
TEMPLATES_DIR = Path(__file__).parent.parent / "templates"
|
||||
|
||||
@@ -9,7 +9,7 @@ from fastapi.templating import Jinja2Templates
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
from python.fastapi_tools.db import DbSession
|
||||
from python.api.dependencies import DbSession
|
||||
from python.orm.richie.contact import Contact, ContactRelationship, Need, RelationshipType
|
||||
|
||||
TEMPLATES_DIR = Path(__file__).parent.parent / "templates"
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
"""Data science CLI tools."""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1,613 @@
|
||||
"""Ingestion pipeline for loading congress data from unitedstates/congress JSON files.
|
||||
|
||||
Loads legislators, bills, votes, vote records, and bill text into the data_science_dev database.
|
||||
Expects the parent directory to contain congress-tracker/ and congress-legislators/ as siblings.
|
||||
|
||||
Usage:
|
||||
ingest-congress /path/to/parent/
|
||||
ingest-congress /path/to/parent/ --congress 118
|
||||
ingest-congress /path/to/parent/ --congress 118 --only bills
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path # noqa: TC003 needed at runtime for typer CLI argument
|
||||
from typing import TYPE_CHECKING, Annotated
|
||||
|
||||
import orjson
|
||||
import typer
|
||||
import yaml
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from python.common import configure_logger
|
||||
from python.orm.common import get_postgres_engine
|
||||
from python.orm.data_science_dev.congress import Bill, BillText, Legislator, LegislatorSocialMedia, Vote, VoteRecord
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator
|
||||
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BATCH_SIZE = 10_000
|
||||
|
||||
app = typer.Typer(help="Ingest unitedstates/congress data into data_science_dev.")
|
||||
|
||||
|
||||
@app.command()
|
||||
def main(
|
||||
parent_dir: Annotated[
|
||||
Path,
|
||||
typer.Argument(help="Parent directory containing congress-tracker/ and congress-legislators/"),
|
||||
],
|
||||
congress: Annotated[int | None, typer.Option(help="Only ingest a specific congress number")] = None,
|
||||
only: Annotated[
|
||||
str | None,
|
||||
typer.Option(help="Only run a specific step: legislators, social-media, bills, votes, bill-text"),
|
||||
] = None,
|
||||
) -> None:
|
||||
"""Ingest congress data from unitedstates/congress JSON files."""
|
||||
configure_logger(level="INFO")
|
||||
|
||||
data_dir = parent_dir / "congress-tracker/congress/data/"
|
||||
legislators_dir = parent_dir / "congress-legislators"
|
||||
|
||||
if not data_dir.is_dir():
|
||||
typer.echo(f"Expected congress-tracker/ directory: {data_dir}", err=True)
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
if not legislators_dir.is_dir():
|
||||
typer.echo(f"Expected congress-legislators/ directory: {legislators_dir}", err=True)
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
|
||||
|
||||
congress_dirs = _resolve_congress_dirs(data_dir, congress)
|
||||
if not congress_dirs:
|
||||
typer.echo("No congress directories found.", err=True)
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
logger.info("Found %d congress directories to process", len(congress_dirs))
|
||||
|
||||
steps: dict[str, tuple] = {
|
||||
"legislators": (ingest_legislators, (engine, legislators_dir)),
|
||||
"legislators-social-media": (ingest_social_media, (engine, legislators_dir)),
|
||||
"bills": (ingest_bills, (engine, congress_dirs)),
|
||||
"votes": (ingest_votes, (engine, congress_dirs)),
|
||||
"bill-text": (ingest_bill_text, (engine, congress_dirs)),
|
||||
}
|
||||
|
||||
if only:
|
||||
if only not in steps:
|
||||
typer.echo(f"Unknown step: {only}. Choose from: {', '.join(steps)}", err=True)
|
||||
raise typer.Exit(code=1)
|
||||
steps = {only: steps[only]}
|
||||
|
||||
for step_name, (step_func, step_args) in steps.items():
|
||||
logger.info("=== Starting step: %s ===", step_name)
|
||||
step_func(*step_args)
|
||||
logger.info("=== Finished step: %s ===", step_name)
|
||||
|
||||
logger.info("ingest-congress done")
|
||||
|
||||
|
||||
def _resolve_congress_dirs(data_dir: Path, congress: int | None) -> list[Path]:
|
||||
"""Find congress number directories under data_dir."""
|
||||
if congress is not None:
|
||||
target = data_dir / str(congress)
|
||||
return [target] if target.is_dir() else []
|
||||
return sorted(path for path in data_dir.iterdir() if path.is_dir() and path.name.isdigit())
|
||||
|
||||
|
||||
def _flush_batch(session: Session, batch: list[object], label: str) -> int:
|
||||
"""Add a batch of ORM objects to the session and commit. Returns count added."""
|
||||
if not batch:
|
||||
return 0
|
||||
session.add_all(batch)
|
||||
session.commit()
|
||||
count = len(batch)
|
||||
logger.info("Committed %d %s", count, label)
|
||||
batch.clear()
|
||||
return count
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Legislators — loaded from congress-legislators YAML files
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def ingest_legislators(engine: Engine, legislators_dir: Path) -> None:
|
||||
"""Load legislators from congress-legislators YAML files."""
|
||||
legislators_data = _load_legislators_yaml(legislators_dir)
|
||||
logger.info("Loaded %d legislators from YAML files", len(legislators_data))
|
||||
|
||||
with Session(engine) as session:
|
||||
existing_legislators = {
|
||||
legislator.bioguide_id: legislator for legislator in session.scalars(select(Legislator)).all()
|
||||
}
|
||||
logger.info("Found %d existing legislators in DB", len(existing_legislators))
|
||||
|
||||
total_inserted = 0
|
||||
total_updated = 0
|
||||
for entry in legislators_data:
|
||||
bioguide_id = entry.get("id", {}).get("bioguide")
|
||||
if not bioguide_id:
|
||||
continue
|
||||
|
||||
fields = _parse_legislator(entry)
|
||||
if existing := existing_legislators.get(bioguide_id):
|
||||
changed = False
|
||||
for field, value in fields.items():
|
||||
if value is not None and getattr(existing, field) != value:
|
||||
setattr(existing, field, value)
|
||||
changed = True
|
||||
if changed:
|
||||
total_updated += 1
|
||||
else:
|
||||
session.add(Legislator(bioguide_id=bioguide_id, **fields))
|
||||
total_inserted += 1
|
||||
|
||||
session.commit()
|
||||
logger.info("Inserted %d new legislators, updated %d existing", total_inserted, total_updated)
|
||||
|
||||
|
||||
def _load_legislators_yaml(legislators_dir: Path) -> list[dict]:
|
||||
"""Load and combine legislators-current.yaml and legislators-historical.yaml."""
|
||||
legislators: list[dict] = []
|
||||
for filename in ("legislators-current.yaml", "legislators-historical.yaml"):
|
||||
path = legislators_dir / filename
|
||||
if not path.exists():
|
||||
logger.warning("Legislators file not found: %s", path)
|
||||
continue
|
||||
with path.open() as file:
|
||||
data = yaml.safe_load(file)
|
||||
if isinstance(data, list):
|
||||
legislators.extend(data)
|
||||
return legislators
|
||||
|
||||
|
||||
def _parse_legislator(entry: dict) -> dict:
|
||||
"""Extract Legislator fields from a congress-legislators YAML entry."""
|
||||
ids = entry.get("id", {})
|
||||
name = entry.get("name", {})
|
||||
bio = entry.get("bio", {})
|
||||
terms = entry.get("terms", [])
|
||||
latest_term = terms[-1] if terms else {}
|
||||
|
||||
fec_ids = ids.get("fec")
|
||||
fec_ids_joined = ",".join(fec_ids) if isinstance(fec_ids, list) else fec_ids
|
||||
|
||||
chamber = latest_term.get("type")
|
||||
chamber_normalized = {"rep": "House", "sen": "Senate"}.get(chamber, chamber)
|
||||
|
||||
return {
|
||||
"thomas_id": ids.get("thomas"),
|
||||
"lis_id": ids.get("lis"),
|
||||
"govtrack_id": ids.get("govtrack"),
|
||||
"opensecrets_id": ids.get("opensecrets"),
|
||||
"fec_ids": fec_ids_joined,
|
||||
"first_name": name.get("first"),
|
||||
"last_name": name.get("last"),
|
||||
"official_full_name": name.get("official_full"),
|
||||
"nickname": name.get("nickname"),
|
||||
"birthday": bio.get("birthday"),
|
||||
"gender": bio.get("gender"),
|
||||
"current_party": latest_term.get("party"),
|
||||
"current_state": latest_term.get("state"),
|
||||
"current_district": latest_term.get("district"),
|
||||
"current_chamber": chamber_normalized,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Social Media — loaded from legislators-social-media.yaml
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SOCIAL_MEDIA_PLATFORMS = {
|
||||
"twitter": "https://twitter.com/{account}",
|
||||
"facebook": "https://facebook.com/{account}",
|
||||
"youtube": "https://youtube.com/{account}",
|
||||
"instagram": "https://instagram.com/{account}",
|
||||
"mastodon": None,
|
||||
}
|
||||
|
||||
|
||||
def ingest_social_media(engine: Engine, legislators_dir: Path) -> None:
|
||||
"""Load social media accounts from legislators-social-media.yaml."""
|
||||
social_media_path = legislators_dir / "legislators-social-media.yaml"
|
||||
if not social_media_path.exists():
|
||||
logger.warning("Social media file not found: %s", social_media_path)
|
||||
return
|
||||
|
||||
with social_media_path.open() as file:
|
||||
social_media_data = yaml.safe_load(file)
|
||||
|
||||
if not isinstance(social_media_data, list):
|
||||
logger.warning("Unexpected format in %s", social_media_path)
|
||||
return
|
||||
|
||||
logger.info("Loaded %d entries from legislators-social-media.yaml", len(social_media_data))
|
||||
|
||||
with Session(engine) as session:
|
||||
legislator_map = _build_legislator_map(session)
|
||||
existing_accounts = {
|
||||
(account.legislator_id, account.platform)
|
||||
for account in session.scalars(select(LegislatorSocialMedia)).all()
|
||||
}
|
||||
logger.info("Found %d existing social media accounts in DB", len(existing_accounts))
|
||||
|
||||
total_inserted = 0
|
||||
total_updated = 0
|
||||
for entry in social_media_data:
|
||||
bioguide_id = entry.get("id", {}).get("bioguide")
|
||||
if not bioguide_id:
|
||||
continue
|
||||
|
||||
legislator_id = legislator_map.get(bioguide_id)
|
||||
if legislator_id is None:
|
||||
continue
|
||||
|
||||
social = entry.get("social", {})
|
||||
for platform, url_template in SOCIAL_MEDIA_PLATFORMS.items():
|
||||
account_name = social.get(platform)
|
||||
if not account_name:
|
||||
continue
|
||||
|
||||
url = url_template.format(account=account_name) if url_template else None
|
||||
|
||||
if (legislator_id, platform) in existing_accounts:
|
||||
total_updated += 1
|
||||
else:
|
||||
session.add(
|
||||
LegislatorSocialMedia(
|
||||
legislator_id=legislator_id,
|
||||
platform=platform,
|
||||
account_name=str(account_name),
|
||||
url=url,
|
||||
source="https://github.com/unitedstates/congress-legislators",
|
||||
)
|
||||
)
|
||||
existing_accounts.add((legislator_id, platform))
|
||||
total_inserted += 1
|
||||
|
||||
session.commit()
|
||||
logger.info("Inserted %d new social media accounts, updated %d existing", total_inserted, total_updated)
|
||||
|
||||
|
||||
def _iter_voters(position_group: object) -> Iterator[dict]:
|
||||
"""Yield voter dicts from a vote position group (handles list, single dict, or string)."""
|
||||
if isinstance(position_group, dict):
|
||||
yield position_group
|
||||
elif isinstance(position_group, list):
|
||||
for voter in position_group:
|
||||
if isinstance(voter, dict):
|
||||
yield voter
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bills
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def ingest_bills(engine: Engine, congress_dirs: list[Path]) -> None:
|
||||
"""Load bill data.json files."""
|
||||
with Session(engine) as session:
|
||||
existing_bills = {(bill.congress, bill.bill_type, bill.number) for bill in session.scalars(select(Bill)).all()}
|
||||
logger.info("Found %d existing bills in DB", len(existing_bills))
|
||||
|
||||
total_inserted = 0
|
||||
batch: list[Bill] = []
|
||||
for congress_dir in congress_dirs:
|
||||
bills_dir = congress_dir / "bills"
|
||||
if not bills_dir.is_dir():
|
||||
continue
|
||||
logger.info("Scanning bills from %s", congress_dir.name)
|
||||
for bill_file in bills_dir.rglob("data.json"):
|
||||
data = _read_json(bill_file)
|
||||
if data is None:
|
||||
continue
|
||||
bill = _parse_bill(data, existing_bills)
|
||||
if bill is not None:
|
||||
batch.append(bill)
|
||||
if len(batch) >= BATCH_SIZE:
|
||||
total_inserted += _flush_batch(session, batch, "bills")
|
||||
|
||||
total_inserted += _flush_batch(session, batch, "bills")
|
||||
logger.info("Inserted %d new bills total", total_inserted)
|
||||
|
||||
|
||||
def _parse_bill(data: dict, existing_bills: set[tuple[int, str, int]]) -> Bill | None:
|
||||
"""Parse a bill data.json dict into a Bill ORM object, skipping existing."""
|
||||
raw_congress = data.get("congress")
|
||||
bill_type = data.get("bill_type")
|
||||
raw_number = data.get("number")
|
||||
if raw_congress is None or bill_type is None or raw_number is None:
|
||||
return None
|
||||
congress = int(raw_congress)
|
||||
number = int(raw_number)
|
||||
if (congress, bill_type, number) in existing_bills:
|
||||
return None
|
||||
|
||||
sponsor_bioguide = None
|
||||
sponsor = data.get("sponsor")
|
||||
if sponsor:
|
||||
sponsor_bioguide = sponsor.get("bioguide_id")
|
||||
|
||||
return Bill(
|
||||
congress=congress,
|
||||
bill_type=bill_type,
|
||||
number=number,
|
||||
title=data.get("short_title") or data.get("official_title"),
|
||||
title_short=data.get("short_title"),
|
||||
official_title=data.get("official_title"),
|
||||
status=data.get("status"),
|
||||
status_at=data.get("status_at"),
|
||||
sponsor_bioguide_id=sponsor_bioguide,
|
||||
subjects_top_term=data.get("subjects_top_term"),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Votes (and vote records)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def ingest_votes(engine: Engine, congress_dirs: list[Path]) -> None:
|
||||
"""Load vote data.json files with their vote records."""
|
||||
with Session(engine) as session:
|
||||
legislator_map = _build_legislator_map(session)
|
||||
logger.info("Loaded %d legislators into lookup map", len(legislator_map))
|
||||
bill_map = _build_bill_map(session)
|
||||
logger.info("Loaded %d bills into lookup map", len(bill_map))
|
||||
existing_votes = {
|
||||
(vote.congress, vote.chamber, vote.session, vote.number) for vote in session.scalars(select(Vote)).all()
|
||||
}
|
||||
logger.info("Found %d existing votes in DB", len(existing_votes))
|
||||
|
||||
total_inserted = 0
|
||||
batch: list[Vote] = []
|
||||
for congress_dir in congress_dirs:
|
||||
votes_dir = congress_dir / "votes"
|
||||
if not votes_dir.is_dir():
|
||||
continue
|
||||
logger.info("Scanning votes from %s", congress_dir.name)
|
||||
for vote_file in votes_dir.rglob("data.json"):
|
||||
data = _read_json(vote_file)
|
||||
if data is None:
|
||||
continue
|
||||
vote = _parse_vote(data, legislator_map, bill_map, existing_votes)
|
||||
if vote is not None:
|
||||
batch.append(vote)
|
||||
if len(batch) >= BATCH_SIZE:
|
||||
total_inserted += _flush_batch(session, batch, "votes")
|
||||
|
||||
total_inserted += _flush_batch(session, batch, "votes")
|
||||
logger.info("Inserted %d new votes total", total_inserted)
|
||||
|
||||
|
||||
def _build_legislator_map(session: Session) -> dict[str, int]:
|
||||
"""Build a mapping of bioguide_id -> legislator.id."""
|
||||
return {legislator.bioguide_id: legislator.id for legislator in session.scalars(select(Legislator)).all()}
|
||||
|
||||
|
||||
def _build_bill_map(session: Session) -> dict[tuple[int, str, int], int]:
|
||||
"""Build a mapping of (congress, bill_type, number) -> bill.id."""
|
||||
return {(bill.congress, bill.bill_type, bill.number): bill.id for bill in session.scalars(select(Bill)).all()}
|
||||
|
||||
|
||||
def _parse_vote(
|
||||
data: dict,
|
||||
legislator_map: dict[str, int],
|
||||
bill_map: dict[tuple[int, str, int], int],
|
||||
existing_votes: set[tuple[int, str, int, int]],
|
||||
) -> Vote | None:
|
||||
"""Parse a vote data.json dict into a Vote ORM object with records."""
|
||||
raw_congress = data.get("congress")
|
||||
chamber = data.get("chamber")
|
||||
raw_number = data.get("number")
|
||||
vote_date = data.get("date")
|
||||
if raw_congress is None or chamber is None or raw_number is None or vote_date is None:
|
||||
return None
|
||||
|
||||
raw_session = data.get("session")
|
||||
if raw_session is None:
|
||||
return None
|
||||
|
||||
congress = int(raw_congress)
|
||||
number = int(raw_number)
|
||||
session_number = int(raw_session)
|
||||
|
||||
# Normalize chamber from "h"/"s" to "House"/"Senate"
|
||||
chamber_normalized = {"h": "House", "s": "Senate"}.get(chamber, chamber)
|
||||
|
||||
if (congress, chamber_normalized, session_number, number) in existing_votes:
|
||||
return None
|
||||
|
||||
# Resolve linked bill
|
||||
bill_id = None
|
||||
bill_ref = data.get("bill")
|
||||
if bill_ref:
|
||||
bill_key = (
|
||||
int(bill_ref.get("congress", congress)),
|
||||
bill_ref.get("type"),
|
||||
int(bill_ref.get("number", 0)),
|
||||
)
|
||||
bill_id = bill_map.get(bill_key)
|
||||
|
||||
raw_votes = data.get("votes", {})
|
||||
vote_counts = _count_votes(raw_votes)
|
||||
vote_records = _build_vote_records(raw_votes, legislator_map)
|
||||
|
||||
return Vote(
|
||||
congress=congress,
|
||||
chamber=chamber_normalized,
|
||||
session=session_number,
|
||||
number=number,
|
||||
vote_type=data.get("type"),
|
||||
question=data.get("question"),
|
||||
result=data.get("result"),
|
||||
result_text=data.get("result_text"),
|
||||
vote_date=vote_date[:10] if isinstance(vote_date, str) else vote_date,
|
||||
bill_id=bill_id,
|
||||
vote_records=vote_records,
|
||||
**vote_counts,
|
||||
)
|
||||
|
||||
|
||||
def _count_votes(raw_votes: dict) -> dict[str, int]:
|
||||
"""Count voters per position category, correctly handling dict and list formats."""
|
||||
yea_count = 0
|
||||
nay_count = 0
|
||||
not_voting_count = 0
|
||||
present_count = 0
|
||||
|
||||
for position, position_group in raw_votes.items():
|
||||
voter_count = sum(1 for _ in _iter_voters(position_group))
|
||||
if position in ("Yea", "Aye"):
|
||||
yea_count += voter_count
|
||||
elif position in ("Nay", "No"):
|
||||
nay_count += voter_count
|
||||
elif position == "Not Voting":
|
||||
not_voting_count += voter_count
|
||||
elif position == "Present":
|
||||
present_count += voter_count
|
||||
|
||||
return {
|
||||
"yea_count": yea_count,
|
||||
"nay_count": nay_count,
|
||||
"not_voting_count": not_voting_count,
|
||||
"present_count": present_count,
|
||||
}
|
||||
|
||||
|
||||
def _build_vote_records(raw_votes: dict, legislator_map: dict[str, int]) -> list[VoteRecord]:
|
||||
"""Build VoteRecord objects from raw vote data."""
|
||||
records: list[VoteRecord] = []
|
||||
for position, position_group in raw_votes.items():
|
||||
for voter in _iter_voters(position_group):
|
||||
bioguide_id = voter.get("id")
|
||||
if not bioguide_id:
|
||||
continue
|
||||
legislator_id = legislator_map.get(bioguide_id)
|
||||
if legislator_id is None:
|
||||
continue
|
||||
records.append(
|
||||
VoteRecord(
|
||||
legislator_id=legislator_id,
|
||||
position=position,
|
||||
)
|
||||
)
|
||||
return records
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bill Text
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def ingest_bill_text(engine: Engine, congress_dirs: list[Path]) -> None:
|
||||
"""Load bill text from text-versions directories."""
|
||||
with Session(engine) as session:
|
||||
bill_map = _build_bill_map(session)
|
||||
logger.info("Loaded %d bills into lookup map", len(bill_map))
|
||||
existing_bill_texts = {
|
||||
(bill_text.bill_id, bill_text.version_code) for bill_text in session.scalars(select(BillText)).all()
|
||||
}
|
||||
logger.info("Found %d existing bill text versions in DB", len(existing_bill_texts))
|
||||
|
||||
total_inserted = 0
|
||||
batch: list[BillText] = []
|
||||
for congress_dir in congress_dirs:
|
||||
logger.info("Scanning bill texts from %s", congress_dir.name)
|
||||
for bill_text in _iter_bill_texts(congress_dir, bill_map, existing_bill_texts):
|
||||
batch.append(bill_text)
|
||||
if len(batch) >= BATCH_SIZE:
|
||||
total_inserted += _flush_batch(session, batch, "bill texts")
|
||||
|
||||
total_inserted += _flush_batch(session, batch, "bill texts")
|
||||
logger.info("Inserted %d new bill text versions total", total_inserted)
|
||||
|
||||
|
||||
def _iter_bill_texts(
|
||||
congress_dir: Path,
|
||||
bill_map: dict[tuple[int, str, int], int],
|
||||
existing_bill_texts: set[tuple[int, str]],
|
||||
) -> Iterator[BillText]:
|
||||
"""Yield BillText objects for a single congress directory, skipping existing."""
|
||||
bills_dir = congress_dir / "bills"
|
||||
if not bills_dir.is_dir():
|
||||
return
|
||||
|
||||
for bill_dir in bills_dir.rglob("text-versions"):
|
||||
if not bill_dir.is_dir():
|
||||
continue
|
||||
bill_key = _bill_key_from_dir(bill_dir.parent, congress_dir)
|
||||
if bill_key is None:
|
||||
continue
|
||||
bill_id = bill_map.get(bill_key)
|
||||
if bill_id is None:
|
||||
continue
|
||||
|
||||
for version_dir in sorted(bill_dir.iterdir()):
|
||||
if not version_dir.is_dir():
|
||||
continue
|
||||
if (bill_id, version_dir.name) in existing_bill_texts:
|
||||
continue
|
||||
text_content = _read_bill_text(version_dir)
|
||||
version_data = _read_json(version_dir / "data.json")
|
||||
yield BillText(
|
||||
bill_id=bill_id,
|
||||
version_code=version_dir.name,
|
||||
version_name=version_data.get("version_name") if version_data else None,
|
||||
date=version_data.get("issued_on") if version_data else None,
|
||||
text_content=text_content,
|
||||
)
|
||||
|
||||
|
||||
def _bill_key_from_dir(bill_dir: Path, congress_dir: Path) -> tuple[int, str, int] | None:
|
||||
"""Extract (congress, bill_type, number) from directory structure."""
|
||||
congress = int(congress_dir.name)
|
||||
bill_type = bill_dir.parent.name
|
||||
name = bill_dir.name
|
||||
# Directory name is like "hr3590" — strip the type prefix to get the number
|
||||
number_str = name[len(bill_type) :]
|
||||
if not number_str.isdigit():
|
||||
return None
|
||||
return (congress, bill_type, int(number_str))
|
||||
|
||||
|
||||
def _read_bill_text(version_dir: Path) -> str | None:
|
||||
"""Read bill text from a version directory, preferring .txt over .xml."""
|
||||
for extension in ("txt", "htm", "html", "xml"):
|
||||
candidates = list(version_dir.glob(f"document.{extension}"))
|
||||
if not candidates:
|
||||
candidates = list(version_dir.glob(f"*.{extension}"))
|
||||
if candidates:
|
||||
try:
|
||||
return candidates[0].read_text(encoding="utf-8")
|
||||
except Exception:
|
||||
logger.exception("Failed to read %s", candidates[0])
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _read_json(path: Path) -> dict | None:
|
||||
"""Read and parse a JSON file, returning None on failure."""
|
||||
try:
|
||||
return orjson.loads(path.read_bytes())
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
except Exception:
|
||||
logger.exception("Failed to parse %s", path)
|
||||
return None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
@@ -0,0 +1,247 @@
|
||||
"""Ingestion pipeline for loading JSONL post files into the weekly-partitioned posts table.
|
||||
|
||||
Usage:
|
||||
ingest-posts /path/to/files/
|
||||
ingest-posts /path/to/single_file.jsonl
|
||||
ingest-posts /data/dir/ --workers 4 --batch-size 5000
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path # noqa: TC003 this is needed for typer
|
||||
from typing import TYPE_CHECKING, Annotated
|
||||
|
||||
import orjson
|
||||
import psycopg
|
||||
import typer
|
||||
|
||||
from python.common import configure_logger
|
||||
from python.orm.common import get_connection_info
|
||||
from python.parallelize import parallelize_process
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
app = typer.Typer(help="Ingest JSONL post files into the partitioned posts table.")
|
||||
|
||||
|
||||
@app.command()
|
||||
def main(
|
||||
path: Annotated[Path, typer.Argument(help="Directory containing JSONL files, or a single JSONL file")],
|
||||
batch_size: Annotated[int, typer.Option(help="Rows per INSERT batch")] = 10000,
|
||||
workers: Annotated[int, typer.Option(help="Parallel workers for multi-file ingestion")] = 4,
|
||||
pattern: Annotated[str, typer.Option(help="Glob pattern for JSONL files")] = "*.jsonl",
|
||||
) -> None:
|
||||
"""Ingest JSONL post files into the weekly-partitioned posts table."""
|
||||
configure_logger(level="INFO")
|
||||
|
||||
logger.info("starting ingest-posts")
|
||||
logger.info("path=%s batch_size=%d workers=%d pattern=%s", path, batch_size, workers, pattern)
|
||||
if path.is_file():
|
||||
ingest_file(path, batch_size=batch_size)
|
||||
elif path.is_dir():
|
||||
ingest_directory(path, batch_size=batch_size, max_workers=workers, pattern=pattern)
|
||||
else:
|
||||
typer.echo(f"Path does not exist: {path}", err=True)
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
logger.info("ingest-posts done")
|
||||
|
||||
|
||||
def ingest_directory(
|
||||
directory: Path,
|
||||
*,
|
||||
batch_size: int,
|
||||
max_workers: int,
|
||||
pattern: str = "*.jsonl",
|
||||
) -> None:
|
||||
"""Ingest all JSONL files in a directory using parallel workers."""
|
||||
files = sorted(directory.glob(pattern))
|
||||
if not files:
|
||||
logger.warning("No JSONL files found in %s", directory)
|
||||
return
|
||||
|
||||
logger.info("Found %d JSONL files to ingest", len(files))
|
||||
|
||||
kwargs_list = [{"path": fp, "batch_size": batch_size} for fp in files]
|
||||
parallelize_process(ingest_file, kwargs_list, max_workers=max_workers)
|
||||
|
||||
|
||||
SCHEMA = "main"
|
||||
|
||||
COLUMNS = (
|
||||
"post_id",
|
||||
"user_id",
|
||||
"instance",
|
||||
"date",
|
||||
"text",
|
||||
"langs",
|
||||
"like_count",
|
||||
"reply_count",
|
||||
"repost_count",
|
||||
"reply_to",
|
||||
"replied_author",
|
||||
"thread_root",
|
||||
"thread_root_author",
|
||||
"repost_from",
|
||||
"reposted_author",
|
||||
"quotes",
|
||||
"quoted_author",
|
||||
"labels",
|
||||
"sent_label",
|
||||
"sent_score",
|
||||
)
|
||||
|
||||
INSERT_FROM_STAGING = f"""
|
||||
INSERT INTO {SCHEMA}.posts ({", ".join(COLUMNS)})
|
||||
SELECT {", ".join(COLUMNS)} FROM pg_temp.staging
|
||||
ON CONFLICT (post_id, date) DO NOTHING
|
||||
""" # noqa: S608
|
||||
|
||||
FAILED_INSERT = f"""
|
||||
INSERT INTO {SCHEMA}.failed_ingestion (raw_line, error)
|
||||
VALUES (%(raw_line)s, %(error)s)
|
||||
""" # noqa: S608
|
||||
|
||||
|
||||
def get_psycopg_connection() -> psycopg.Connection:
|
||||
"""Create a raw psycopg3 connection from environment variables."""
|
||||
database, host, port, username, password = get_connection_info("DATA_SCIENCE_DEV")
|
||||
return psycopg.connect(
|
||||
dbname=database,
|
||||
host=host,
|
||||
port=int(port),
|
||||
user=username,
|
||||
password=password,
|
||||
autocommit=False,
|
||||
)
|
||||
|
||||
|
||||
def ingest_file(path: Path, *, batch_size: int) -> None:
|
||||
"""Ingest a single JSONL file into the posts table."""
|
||||
log_trigger = max(100_000 // batch_size, 1)
|
||||
failed_lines: list[dict] = []
|
||||
try:
|
||||
with get_psycopg_connection() as connection:
|
||||
for index, batch in enumerate(read_jsonl_batches(path, batch_size, failed_lines), 1):
|
||||
ingest_batch(connection, batch)
|
||||
if index % log_trigger == 0:
|
||||
logger.info("Ingested %d batches (%d rows) from %s", index, index * batch_size, path)
|
||||
|
||||
if failed_lines:
|
||||
logger.warning("Recording %d malformed lines from %s", len(failed_lines), path.name)
|
||||
with connection.cursor() as cursor:
|
||||
cursor.executemany(FAILED_INSERT, failed_lines)
|
||||
connection.commit()
|
||||
except Exception:
|
||||
logger.exception("Failed to ingest file: %s", path)
|
||||
raise
|
||||
|
||||
|
||||
def ingest_batch(connection: psycopg.Connection, batch: list[dict]) -> None:
|
||||
"""COPY batch into a temp staging table, then INSERT ... ON CONFLICT into posts."""
|
||||
if not batch:
|
||||
return
|
||||
|
||||
try:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(f"""
|
||||
CREATE TEMP TABLE IF NOT EXISTS staging
|
||||
(LIKE {SCHEMA}.posts INCLUDING DEFAULTS)
|
||||
ON COMMIT DELETE ROWS
|
||||
""")
|
||||
cursor.execute("TRUNCATE pg_temp.staging")
|
||||
|
||||
with cursor.copy(f"COPY pg_temp.staging ({', '.join(COLUMNS)}) FROM STDIN") as copy:
|
||||
for row in batch:
|
||||
copy.write_row(tuple(row.get(column) for column in COLUMNS))
|
||||
|
||||
cursor.execute(INSERT_FROM_STAGING)
|
||||
connection.commit()
|
||||
except Exception as error:
|
||||
connection.rollback()
|
||||
|
||||
if len(batch) == 1:
|
||||
logger.exception("Skipping bad row post_id=%s", batch[0].get("post_id"))
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
FAILED_INSERT,
|
||||
{
|
||||
"raw_line": orjson.dumps(batch[0], default=str).decode(),
|
||||
"error": str(error),
|
||||
},
|
||||
)
|
||||
connection.commit()
|
||||
return
|
||||
|
||||
midpoint = len(batch) // 2
|
||||
ingest_batch(connection, batch[:midpoint])
|
||||
ingest_batch(connection, batch[midpoint:])
|
||||
|
||||
|
||||
def read_jsonl_batches(file_path: Path, batch_size: int, failed_lines: list[dict]) -> Iterator[list[dict]]:
|
||||
"""Stream a JSONL file and yield batches of transformed rows."""
|
||||
batch: list[dict] = []
|
||||
with file_path.open("r", encoding="utf-8") as handle:
|
||||
for raw_line in handle:
|
||||
line = raw_line.strip()
|
||||
if not line:
|
||||
continue
|
||||
batch.extend(parse_line(line, file_path, failed_lines))
|
||||
if len(batch) >= batch_size:
|
||||
yield batch
|
||||
batch = []
|
||||
if batch:
|
||||
yield batch
|
||||
|
||||
|
||||
def parse_line(line: str, file_path: Path, failed_lines: list[dict]) -> Iterator[dict]:
|
||||
"""Parse a JSONL line, handling concatenated JSON objects."""
|
||||
try:
|
||||
yield transform_row(orjson.loads(line))
|
||||
except orjson.JSONDecodeError:
|
||||
if "}{" not in line:
|
||||
logger.warning("Skipping malformed line in %s: %s", file_path.name, line[:120])
|
||||
failed_lines.append({"raw_line": line, "error": "malformed JSON"})
|
||||
return
|
||||
fragments = line.replace("}{", "}\n{").split("\n")
|
||||
for fragment in fragments:
|
||||
try:
|
||||
yield transform_row(orjson.loads(fragment))
|
||||
except (orjson.JSONDecodeError, KeyError, ValueError) as error:
|
||||
logger.warning("Skipping malformed fragment in %s: %s", file_path.name, fragment[:120])
|
||||
failed_lines.append({"raw_line": fragment, "error": str(error)})
|
||||
except Exception as error:
|
||||
logger.exception("Skipping bad row in %s: %s", file_path.name, line[:120])
|
||||
failed_lines.append({"raw_line": line, "error": str(error)})
|
||||
|
||||
|
||||
def transform_row(raw: dict) -> dict:
|
||||
"""Transform a raw JSONL row into a dict matching the Posts table columns."""
|
||||
raw["date"] = parse_date(raw["date"])
|
||||
if raw.get("langs") is not None:
|
||||
raw["langs"] = orjson.dumps(raw["langs"])
|
||||
if raw.get("text") is not None:
|
||||
raw["text"] = raw["text"].replace("\x00", "")
|
||||
return raw
|
||||
|
||||
|
||||
def parse_date(raw_date: int) -> datetime:
|
||||
"""Parse compact YYYYMMDDHHmm integer into a naive datetime (input is UTC by spec)."""
|
||||
return datetime(
|
||||
raw_date // 100000000,
|
||||
(raw_date // 1000000) % 100,
|
||||
(raw_date // 10000) % 100,
|
||||
(raw_date // 100) % 100,
|
||||
raw_date % 100,
|
||||
tzinfo=UTC,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
@@ -83,6 +83,20 @@ DATABASES: dict[str, DatabaseConfig] = {
|
||||
base_class_name="VanInventoryBase",
|
||||
models_module="python.orm.van_inventory.models",
|
||||
),
|
||||
"signal_bot": DatabaseConfig(
|
||||
env_prefix="SIGNALBOT",
|
||||
version_location="python/alembic/signal_bot/versions",
|
||||
base_module="python.orm.signal_bot.base",
|
||||
base_class_name="SignalBotBase",
|
||||
models_module="python.orm.signal_bot.models",
|
||||
),
|
||||
"data_science_dev": DatabaseConfig(
|
||||
env_prefix="DATA_SCIENCE_DEV",
|
||||
version_location="python/alembic/data_science_dev/versions",
|
||||
base_module="python.orm.data_science_dev.base",
|
||||
base_class_name="DataScienceDevBase",
|
||||
models_module="python.orm.data_science_dev.models",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
"""EPUB search package."""
|
||||
@@ -1,57 +0,0 @@
|
||||
"""Grounded answer generation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from python.ebook_search.llm_interface import request_chat_completion
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from python.ebook_search.config import EbookSearchConfig
|
||||
from python.ebook_search.search import SearchResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def answer_query(query: str, results: list[SearchResult], config: EbookSearchConfig) -> str:
|
||||
"""Answer a question using only retrieved chunks."""
|
||||
if not config.answer_enabled:
|
||||
logger.info("ebook_answer_skipped_disabled")
|
||||
return "Answer generation is disabled. Source chunks are shown below."
|
||||
|
||||
if not results:
|
||||
logger.info("ebook_answer_skipped_no_results")
|
||||
return "No relevant sources were found."
|
||||
|
||||
logger.info(
|
||||
"ebook_answer_request_start base_url=%s model=%s sources=%s query_length=%s",
|
||||
config.vllm_base_url,
|
||||
config.chat_model,
|
||||
len(results),
|
||||
len(query),
|
||||
)
|
||||
context = "\n\n".join(
|
||||
f"[{index}] {result.source_title}{' - ' + result.chapter_title if result.chapter_title else ''}\n{result.text}"
|
||||
for index, result in enumerate(results, start=1)
|
||||
)
|
||||
content = request_chat_completion(
|
||||
config,
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"Answer only from the provided context. Cite sources with bracketed numbers like [1]. "
|
||||
"If the context is insufficient, say so."
|
||||
),
|
||||
},
|
||||
{"role": "user", "content": f"Question:\n{query}\n\nContext:\n{context}"},
|
||||
],
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"ebook_answer_request_complete model=%s answer_length=%s",
|
||||
config.chat_model,
|
||||
len(content),
|
||||
)
|
||||
return content or "The model returned an empty answer."
|
||||
@@ -1 +0,0 @@
|
||||
"""Web and external API adapters for EPUB search."""
|
||||
@@ -1,60 +0,0 @@
|
||||
"""Background BM25 refresh tasks for the web app."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from threading import Timer
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from python.ebook_search.bm25_corpus import load_bm25_corpus, refresh_bm25_corpus
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import FastAPI
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
from python.ebook_search.config import EbookSearchConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def schedule_bm25_refresh(app: FastAPI) -> None:
|
||||
"""Schedule a delayed BM25 corpus refresh, replacing any pending refresh."""
|
||||
existing_timer = getattr(app.state, "bm25_refresh_timer", None)
|
||||
if existing_timer is not None:
|
||||
existing_timer.cancel()
|
||||
|
||||
timer = Timer(app.state.config.bm25_refresh_delay_seconds, refresh_bm25_for_app, args=(app,))
|
||||
timer.daemon = True
|
||||
timer.start()
|
||||
app.state.bm25_refresh_timer = timer
|
||||
logger.info(
|
||||
"ebook_bm25_refresh_scheduled delay_seconds=%s",
|
||||
app.state.config.bm25_refresh_delay_seconds,
|
||||
)
|
||||
|
||||
|
||||
def cancel_bm25_refresh(app: FastAPI) -> None:
|
||||
"""Cancel any pending BM25 corpus refresh."""
|
||||
existing_timer = getattr(app.state, "bm25_refresh_timer", None)
|
||||
if existing_timer is not None:
|
||||
existing_timer.cancel()
|
||||
app.state.bm25_refresh_timer = None
|
||||
logger.info("ebook_bm25_refresh_cancelled")
|
||||
|
||||
|
||||
def refresh_bm25_for_app(app: FastAPI) -> None:
|
||||
"""Refresh the BM25 corpus using the app engine and config."""
|
||||
try:
|
||||
refresh_bm25_for_engine(app.state.engine, app.state.config)
|
||||
except Exception:
|
||||
logger.exception("ebook_bm25_refresh_failed")
|
||||
|
||||
|
||||
def refresh_bm25_for_engine(engine: Engine, config: EbookSearchConfig) -> None:
|
||||
"""Refresh the BM25 corpus using a SQLAlchemy engine."""
|
||||
with Session(engine) as session:
|
||||
refresh_bm25_corpus(session, config)
|
||||
load_bm25_corpus.cache_clear()
|
||||
logger.info("ebook_bm25_corpus_cache_cleared_after_refresh")
|
||||
@@ -1,79 +0,0 @@
|
||||
"""FastAPI HTMX app for EPUB search."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import TYPE_CHECKING, Annotated
|
||||
|
||||
import typer
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from python.common import configure_logger
|
||||
from python.ebook_search.api.bm25_tasks import cancel_bm25_refresh
|
||||
from python.ebook_search.api.routes import admin_router, page_router, search_router
|
||||
from python.ebook_search.api.web import STATIC_DIR
|
||||
from python.ebook_search.bm25_corpus import ensure_bm25_corpus
|
||||
from python.ebook_search.config import load_config
|
||||
from python.fastapi_tools import ZstdMiddleware
|
||||
from python.orm.common import get_postgres_engine
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
|
||||
"""Manage application startup and shutdown resources."""
|
||||
logger.info("ebook_search_startup")
|
||||
app.state.engine = get_postgres_engine(name="RICHIE", vector_engine=True)
|
||||
with Session(app.state.engine) as session:
|
||||
ensure_bm25_corpus(session, app.state.config)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
logger.info("ebook_search_shutdown")
|
||||
cancel_bm25_refresh(app)
|
||||
app.state.engine.dispose()
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
"""Create the EPUB search web app."""
|
||||
app = FastAPI(title="EPUB Search", lifespan=lifespan)
|
||||
app.add_middleware(ZstdMiddleware)
|
||||
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
|
||||
app.state.config = load_config()
|
||||
logger.info(
|
||||
"ebook_search_config_loaded top_k=%s embedding_model=%s rerank_enabled=%s answer_enabled=%s library_paths=%s",
|
||||
app.state.config.top_k,
|
||||
app.state.config.embedding_model,
|
||||
app.state.config.rerank.enabled,
|
||||
app.state.config.answer_enabled,
|
||||
len(app.state.config.library_paths),
|
||||
)
|
||||
|
||||
app.include_router(admin_router)
|
||||
app.include_router(page_router)
|
||||
app.include_router(search_router)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def serve(
|
||||
host: Annotated[str, typer.Option("--host", "-h", help="Host to bind to")] = "127.0.0.1",
|
||||
port: Annotated[int, typer.Option("--port", "-p", help="Port to bind to")] = 8070,
|
||||
log_level: Annotated[str, typer.Option("--log-level", "-l", help="Log level")] = "INFO",
|
||||
) -> None:
|
||||
"""Start the EPUB search server."""
|
||||
configure_logger(log_level)
|
||||
uvicorn.run(create_app(), host=host, port=port)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
typer.run(serve)
|
||||
@@ -1,11 +0,0 @@
|
||||
"""EPUB search web route modules."""
|
||||
|
||||
from python.ebook_search.api.routes.admin import router as admin_router
|
||||
from python.ebook_search.api.routes.page import router as page_router
|
||||
from python.ebook_search.api.routes.search import router as search_router
|
||||
|
||||
__all__ = [
|
||||
"admin_router",
|
||||
"page_router",
|
||||
"search_router",
|
||||
]
|
||||
@@ -1,107 +0,0 @@
|
||||
"""Admin routes for the EPUB search web UI."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import replace
|
||||
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from python.ebook_search.api.bm25_tasks import schedule_bm25_refresh
|
||||
from python.ebook_search.api.web import templates
|
||||
from python.ebook_search.embeddings import embed_missing_chunks, embedding_model_stats
|
||||
from python.ebook_search.ingest import ingest_configured_paths
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/admin")
|
||||
EMBED_ALL_BATCH_SIZE = 32
|
||||
|
||||
|
||||
@router.get("", response_class=HTMLResponse)
|
||||
def admin(request: Request) -> HTMLResponse:
|
||||
"""Render the admin page."""
|
||||
with Session(request.app.state.engine) as session:
|
||||
stats = embedding_model_stats(session)
|
||||
logger.info("ebook_admin_page_loaded models=%s", len(stats))
|
||||
return templates.TemplateResponse(request, "admin.html", {"config": request.app.state.config, "stats": stats})
|
||||
|
||||
|
||||
@router.post("/scan", response_class=HTMLResponse)
|
||||
def scan_library(request: Request) -> HTMLResponse:
|
||||
"""Scan configured library paths for EPUB changes."""
|
||||
try:
|
||||
with Session(request.app.state.engine) as session:
|
||||
count = ingest_configured_paths(session, request.app.state.config)
|
||||
session.commit()
|
||||
except Exception as error:
|
||||
logger.exception("ebook_admin_scan_failed")
|
||||
return templates.TemplateResponse(request, "partials/error.html", {"message": str(error)}, status_code=500)
|
||||
|
||||
logger.info("ebook_admin_scan_complete changed_files=%s", count)
|
||||
if count > 0:
|
||||
schedule_bm25_refresh(request.app)
|
||||
return templates.TemplateResponse(request, "partials/admin_status.html", {"message": f"Indexed {count} EPUBs"})
|
||||
|
||||
|
||||
@router.post("/embed-missing", response_class=HTMLResponse)
|
||||
def embed_missing(request: Request) -> HTMLResponse:
|
||||
"""Embed chunks missing vectors for the configured model."""
|
||||
try:
|
||||
with Session(request.app.state.engine) as session:
|
||||
count = embed_missing_chunks(session, request.app.state.config)
|
||||
session.commit()
|
||||
except Exception as error:
|
||||
logger.exception("ebook_admin_embed_missing_failed")
|
||||
return templates.TemplateResponse(request, "partials/error.html", {"message": str(error)}, status_code=500)
|
||||
|
||||
logger.info("ebook_admin_embed_missing_complete chunks=%s", count)
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"partials/admin_status.html",
|
||||
{"message": f"Embedded {count} chunks"},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/embed-all", response_class=HTMLResponse)
|
||||
def embed_all(request: Request) -> HTMLResponse:
|
||||
"""Embed all chunks missing vectors in fixed-size batches."""
|
||||
total = 0
|
||||
batches = 0
|
||||
config = replace(request.app.state.config, embedding_batch_size=EMBED_ALL_BATCH_SIZE)
|
||||
try:
|
||||
with Session(request.app.state.engine) as session:
|
||||
while True:
|
||||
count = embed_missing_chunks(session, config)
|
||||
if count == 0:
|
||||
break
|
||||
session.commit()
|
||||
total += count
|
||||
batches += 1
|
||||
logger.info(
|
||||
"ebook_admin_embed_all_batch_complete batch=%s chunks=%s total_chunks=%s",
|
||||
batches,
|
||||
count,
|
||||
total,
|
||||
)
|
||||
except Exception as error:
|
||||
logger.exception(
|
||||
"ebook_admin_embed_all_failed batches=%s chunks=%s",
|
||||
batches,
|
||||
total,
|
||||
)
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"partials/error.html",
|
||||
{"message": f"Embed all failed after {total} chunks in {batches} batches: {error}"},
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
logger.info("ebook_admin_embed_all_complete batches=%s chunks=%s", batches, total)
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"partials/admin_status.html",
|
||||
{"message": f"Embedded {total} chunks in {batches} batches of {EMBED_ALL_BATCH_SIZE}"},
|
||||
)
|
||||
@@ -1,57 +0,0 @@
|
||||
"""Page routes for the EPUB search web UI."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from python.ebook_search.api.web import templates
|
||||
from python.orm.richie import EbookSource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/", response_class=HTMLResponse)
|
||||
def index(request: Request) -> HTMLResponse:
|
||||
"""Render the search page."""
|
||||
return templates.TemplateResponse(request, "search.html", {"config": request.app.state.config})
|
||||
|
||||
|
||||
@router.get("/books", response_class=HTMLResponse)
|
||||
def books(request: Request) -> HTMLResponse:
|
||||
"""Render the indexed books page."""
|
||||
with Session(request.app.state.engine) as session:
|
||||
sources = list(session.scalars(select(EbookSource).order_by(EbookSource.title)).all())
|
||||
logger.info("ebook_books_page_loaded count=%s", len(sources))
|
||||
return templates.TemplateResponse(request, "books.html", {"sources": sources})
|
||||
|
||||
|
||||
@router.get("/books/{source_id}", response_class=HTMLResponse)
|
||||
def book_detail(source_id: int, request: Request) -> HTMLResponse:
|
||||
"""Render details for one indexed book."""
|
||||
with Session(request.app.state.engine) as session:
|
||||
source = session.get(EbookSource, source_id)
|
||||
if source is not None:
|
||||
chapter_count = len(source.chapters)
|
||||
chunk_count = len(source.chunks)
|
||||
else:
|
||||
chapter_count = 0
|
||||
chunk_count = 0
|
||||
logger.info(
|
||||
"ebook_book_detail_loaded source_id=%s found=%s chapters=%s chunks=%s",
|
||||
source_id,
|
||||
source is not None,
|
||||
chapter_count,
|
||||
chunk_count,
|
||||
)
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"book_detail.html",
|
||||
{"chapter_count": chapter_count, "chunk_count": chunk_count, "source": source},
|
||||
)
|
||||
@@ -1,58 +0,0 @@
|
||||
"""Search routes for the EPUB search web UI."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import replace
|
||||
from time import perf_counter
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Form, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
from python.ebook_search.answer import answer_query
|
||||
from python.ebook_search.api.web import templates
|
||||
from python.ebook_search.search import search_ebooks
|
||||
from python.ebook_search.timing import runtime_step_from_start
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/search", response_class=HTMLResponse)
|
||||
def search(
|
||||
request: Request,
|
||||
query: Annotated[str, Form()],
|
||||
rerank: Annotated[str | None, Form()] = None,
|
||||
) -> HTMLResponse:
|
||||
"""Run a search and render HTMX results."""
|
||||
try:
|
||||
response = search_ebooks(request.app.state.engine, query, request.app.state.config, rerank=rerank == "true")
|
||||
except Exception as error:
|
||||
logger.exception("ebook_search_request_failed")
|
||||
return templates.TemplateResponse(request, "partials/error.html", {"message": str(error)}, status_code=500)
|
||||
|
||||
answer_start = perf_counter()
|
||||
if request.app.state.config.answer_enabled:
|
||||
try:
|
||||
answer = answer_query(query, response.results, request.app.state.config)
|
||||
except RuntimeError as error:
|
||||
logger.warning("ebook_answer_request_failed_falling_back error=%s", error)
|
||||
answer = "Answer generation failed. Source chunks are still shown below."
|
||||
else:
|
||||
logger.info("ebook_answer_skipped_disabled")
|
||||
answer = "Answer generation is disabled. Source chunks are shown below."
|
||||
answer_step_name = "Answer generation" if request.app.state.config.answer_enabled else "Answer skipped"
|
||||
response = replace(
|
||||
response,
|
||||
timings=(*response.timings, runtime_step_from_start(answer_step_name, answer_start)),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"ebook_search_request_complete results=%s rank_label=%s runtime_ms=%.1f",
|
||||
len(response.results),
|
||||
response.rank_label,
|
||||
response.total_runtime_ms,
|
||||
)
|
||||
return templates.TemplateResponse(request, "partials/results.html", {"answer": answer, "response": response})
|
||||
@@ -1,140 +0,0 @@
|
||||
body {
|
||||
margin: 0;
|
||||
background: #f7f7f4;
|
||||
color: #202124;
|
||||
font-family: system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif;
|
||||
}
|
||||
|
||||
main {
|
||||
max-width: 960px;
|
||||
margin: 0 auto;
|
||||
padding: 24px;
|
||||
}
|
||||
|
||||
nav {
|
||||
display: flex;
|
||||
gap: 12px;
|
||||
align-items: center;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
|
||||
nav form {
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.actions {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 12px;
|
||||
margin-bottom: 24px;
|
||||
}
|
||||
|
||||
textarea {
|
||||
display: block;
|
||||
width: 100%;
|
||||
margin: 8px 0 12px;
|
||||
}
|
||||
|
||||
button {
|
||||
padding: 8px 14px;
|
||||
}
|
||||
|
||||
.check {
|
||||
display: inline-flex;
|
||||
gap: 8px;
|
||||
align-items: center;
|
||||
margin-right: 12px;
|
||||
}
|
||||
|
||||
.rank-label {
|
||||
margin-top: 24px;
|
||||
font-weight: 700;
|
||||
}
|
||||
|
||||
.results {
|
||||
padding-left: 24px;
|
||||
}
|
||||
|
||||
.meta,
|
||||
.scores,
|
||||
.status {
|
||||
color: #626a73;
|
||||
}
|
||||
|
||||
.scores {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 8px;
|
||||
margin: 12px 0;
|
||||
}
|
||||
|
||||
.scores div {
|
||||
display: inline-flex;
|
||||
gap: 4px;
|
||||
align-items: baseline;
|
||||
}
|
||||
|
||||
.scores dt {
|
||||
font-weight: 700;
|
||||
}
|
||||
|
||||
.scores dd {
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.runtime {
|
||||
margin-top: 16px;
|
||||
}
|
||||
|
||||
.timing-chart {
|
||||
display: grid;
|
||||
gap: 8px;
|
||||
padding: 0;
|
||||
list-style: none;
|
||||
}
|
||||
|
||||
.timing-chart li {
|
||||
display: grid;
|
||||
grid-template-columns: minmax(150px, 1fr) minmax(160px, 2fr) auto auto;
|
||||
gap: 8px;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.timing-bar {
|
||||
height: 10px;
|
||||
overflow: hidden;
|
||||
background: #e5e5df;
|
||||
}
|
||||
|
||||
.timing-bar span {
|
||||
display: block;
|
||||
height: 100%;
|
||||
background: #3767c8;
|
||||
}
|
||||
|
||||
.timing-value,
|
||||
.timing-remaining {
|
||||
color: #626a73;
|
||||
font-variant-numeric: tabular-nums;
|
||||
}
|
||||
|
||||
table {
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
}
|
||||
|
||||
th,
|
||||
td {
|
||||
padding: 8px;
|
||||
border-bottom: 1px solid #d8d8d2;
|
||||
text-align: left;
|
||||
}
|
||||
|
||||
th {
|
||||
font-weight: 700;
|
||||
}
|
||||
|
||||
.error {
|
||||
color: #9f1d20;
|
||||
font-weight: 700;
|
||||
}
|
||||
@@ -1,57 +0,0 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>EPUB Admin</title>
|
||||
<script src="https://unpkg.com/htmx.org@2.0.4"></script>
|
||||
<link rel="stylesheet" href="/static/style.css">
|
||||
</head>
|
||||
<body>
|
||||
<main>
|
||||
<nav>
|
||||
<a href="/">Search</a>
|
||||
<a href="/books">Books</a>
|
||||
<a href="/admin">Admin</a>
|
||||
</nav>
|
||||
<h1>Admin</h1>
|
||||
<section id="admin-status"></section>
|
||||
<section class="actions">
|
||||
<form hx-post="/admin/scan" hx-target="#admin-status" hx-swap="innerHTML">
|
||||
<button type="submit">Scan</button>
|
||||
</form>
|
||||
<form hx-post="/admin/embed-missing" hx-target="#admin-status" hx-swap="innerHTML">
|
||||
<button type="submit">Embed</button>
|
||||
</form>
|
||||
<form hx-post="/admin/embed-all" hx-target="#admin-status" hx-swap="innerHTML">
|
||||
<button type="submit">Embed all</button>
|
||||
</form>
|
||||
</section>
|
||||
<section>
|
||||
<h2>Embeddings</h2>
|
||||
<table>
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Model</th>
|
||||
<th>Dimensions</th>
|
||||
<th>Embedded</th>
|
||||
<th>Missing</th>
|
||||
<th>Total chunks</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{% for item in stats %}
|
||||
<tr>
|
||||
<td>{{ item.model_name }}</td>
|
||||
<td>{{ item.dimension }}</td>
|
||||
<td>{{ item.embedded_chunks }}</td>
|
||||
<td>{{ item.missing_chunks }}</td>
|
||||
<td>{{ item.total_chunks }}</td>
|
||||
</tr>
|
||||
{% endfor %}
|
||||
</tbody>
|
||||
</table>
|
||||
</section>
|
||||
</main>
|
||||
</body>
|
||||
</html>
|
||||
@@ -1,32 +0,0 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>{% if source %}{{ source.title }}{% else %}Book not found{% endif %}</title>
|
||||
<link rel="stylesheet" href="/static/style.css">
|
||||
</head>
|
||||
<body>
|
||||
<main>
|
||||
<nav>
|
||||
<a href="/">Search</a>
|
||||
<a href="/books">Books</a>
|
||||
<a href="/admin">Admin</a>
|
||||
</nav>
|
||||
{% if source %}
|
||||
<h1>{{ source.title }}</h1>
|
||||
<p class="meta">{{ source.author or "Unknown author" }}</p>
|
||||
<dl>
|
||||
<dt>File</dt>
|
||||
<dd>{{ source.file_path }}</dd>
|
||||
<dt>Chapters</dt>
|
||||
<dd>{{ chapter_count }}</dd>
|
||||
<dt>Chunks</dt>
|
||||
<dd>{{ chunk_count }}</dd>
|
||||
</dl>
|
||||
{% else %}
|
||||
<h1>Book not found</h1>
|
||||
{% endif %}
|
||||
</main>
|
||||
</body>
|
||||
</html>
|
||||
@@ -1,31 +0,0 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>EPUB Books</title>
|
||||
<link rel="stylesheet" href="/static/style.css">
|
||||
</head>
|
||||
<body>
|
||||
<main>
|
||||
<nav>
|
||||
<a href="/">Search</a>
|
||||
<a href="/books">Books</a>
|
||||
<a href="/admin">Admin</a>
|
||||
</nav>
|
||||
<h1>Books</h1>
|
||||
{% if sources %}
|
||||
<ol class="results">
|
||||
{% for source in sources %}
|
||||
<li>
|
||||
<h2><a href="/books/{{ source.id }}">{{ source.title }}</a></h2>
|
||||
<p class="meta">{{ source.author or "Unknown author" }}</p>
|
||||
</li>
|
||||
{% endfor %}
|
||||
</ol>
|
||||
{% else %}
|
||||
<p>No EPUBs indexed.</p>
|
||||
{% endif %}
|
||||
</main>
|
||||
</body>
|
||||
</html>
|
||||
@@ -1 +0,0 @@
|
||||
<p class="status">{{ message }}</p>
|
||||
@@ -1 +0,0 @@
|
||||
<p class="error">{{ message }}</p>
|
||||
@@ -1,74 +0,0 @@
|
||||
<div class="rank-label">{{ response.rank_label }}</div>
|
||||
{% if response.timings %}
|
||||
<section class="runtime">
|
||||
<h2>Runtime</h2>
|
||||
<p class="meta">Total {{ "%.1f"|format(response.total_runtime_ms) }} ms</p>
|
||||
<ol class="timing-chart">
|
||||
{% set total = response.total_runtime_ms %}
|
||||
{% set ns = namespace(remaining=total) %}
|
||||
{% for step in response.timings %}
|
||||
{% set width = (step.duration_ms / total * 100) if total else 0 %}
|
||||
{% if step.counts_toward_total %}
|
||||
{% set ns.remaining = ns.remaining - step.duration_ms %}
|
||||
{% endif %}
|
||||
<li>
|
||||
<span class="timing-label">{{ step.name }}</span>
|
||||
<span class="timing-bar"><span style="width: {{ "%.2f"|format(width) }}%"></span></span>
|
||||
<span class="timing-value">{{ "%.1f"|format(step.duration_ms) }} ms</span>
|
||||
<span class="timing-remaining">{{ "%.1f"|format([ns.remaining, 0]|max) }} ms left</span>
|
||||
</li>
|
||||
{% endfor %}
|
||||
</ol>
|
||||
</section>
|
||||
{% endif %}
|
||||
<section class="answer">
|
||||
<h2>Answer</h2>
|
||||
<p>{{ answer }}</p>
|
||||
</section>
|
||||
{% if response.results %}
|
||||
<ol class="results">
|
||||
{% for result in response.results %}
|
||||
<li>
|
||||
<h2>{{ result.source_title }}</h2>
|
||||
<p class="meta">
|
||||
{% if result.source_author %}{{ result.source_author }}{% endif %}
|
||||
{% if result.chapter_title %} · {{ result.chapter_title }}{% endif %}
|
||||
{% if result.page_label %} · page {{ result.page_label }}{% endif %}
|
||||
</p>
|
||||
<p>{{ result.text }}</p>
|
||||
<dl class="scores">
|
||||
<div>
|
||||
<dt>final</dt>
|
||||
<dd>{{ "%.3f"|format(result.score) }}</dd>
|
||||
</div>
|
||||
{% if result.rerank_score is not none %}
|
||||
<div>
|
||||
<dt>rerank</dt>
|
||||
<dd>{{ "%.3f"|format(result.rerank_score) }}</dd>
|
||||
</div>
|
||||
{% endif %}
|
||||
{% if result.vector_score is not none %}
|
||||
<div>
|
||||
<dt>vector cosine</dt>
|
||||
<dd>{{ "%.3f"|format(result.vector_score) }}</dd>
|
||||
</div>
|
||||
{% endif %}
|
||||
{% if result.bm25_score is not none %}
|
||||
<div>
|
||||
<dt>BM25</dt>
|
||||
<dd>{{ "%.6f"|format(result.bm25_score) }}</dd>
|
||||
</div>
|
||||
{% endif %}
|
||||
{% if result.fused_score is not none %}
|
||||
<div>
|
||||
<dt>RRF</dt>
|
||||
<dd>{{ "%.3f"|format(result.fused_score) }}</dd>
|
||||
</div>
|
||||
{% endif %}
|
||||
</dl>
|
||||
</li>
|
||||
{% endfor %}
|
||||
</ol>
|
||||
{% else %}
|
||||
<p>No results.</p>
|
||||
{% endif %}
|
||||
@@ -1,30 +0,0 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>EPUB Search</title>
|
||||
<script src="https://unpkg.com/htmx.org@2.0.4"></script>
|
||||
<link rel="stylesheet" href="/static/style.css">
|
||||
</head>
|
||||
<body>
|
||||
<main>
|
||||
<nav>
|
||||
<a href="/">Search</a>
|
||||
<a href="/books">Books</a>
|
||||
<a href="/admin">Admin</a>
|
||||
</nav>
|
||||
<h1>EPUB Search</h1>
|
||||
<form hx-post="/search" hx-target="#results" hx-swap="innerHTML">
|
||||
<label for="query">Search</label>
|
||||
<textarea id="query" name="query" rows="4" required></textarea>
|
||||
<label class="check">
|
||||
<input type="checkbox" name="rerank" value="true" {% if config.rerank.enabled %}checked{% endif %}>
|
||||
Rerank
|
||||
</label>
|
||||
<button type="submit">Search</button>
|
||||
</form>
|
||||
<section id="results"></section>
|
||||
</main>
|
||||
</body>
|
||||
</html>
|
||||
@@ -1,13 +0,0 @@
|
||||
"""Shared web UI resources for EPUB search."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi.templating import Jinja2Templates
|
||||
|
||||
PACKAGE_DIR = Path(__file__).resolve().parent
|
||||
TEMPLATE_DIR = PACKAGE_DIR / "templates"
|
||||
STATIC_DIR = PACKAGE_DIR / "static"
|
||||
|
||||
templates = Jinja2Templates(directory=TEMPLATE_DIR)
|
||||
@@ -1,281 +0,0 @@
|
||||
"""Persisted BM25 corpus management."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import shutil
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from functools import cache
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import bm25s
|
||||
from sqlalchemy import func, select, union_all
|
||||
|
||||
from python.orm.richie import EbookChapter, EbookChunk, EbookSource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from python.ebook_search.config import EbookSearchConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
MANIFEST_NAME = "manifest.json"
|
||||
REQUIRED_INDEX_FILES = frozenset(
|
||||
{
|
||||
"data.csc.index.npy",
|
||||
"indices.csc.index.npy",
|
||||
"indptr.csc.index.npy",
|
||||
"params.index.json",
|
||||
"vocab.index.json",
|
||||
"corpus.jsonl",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BM25Manifest:
|
||||
"""Metadata describing a persisted BM25 corpus."""
|
||||
|
||||
created_at: datetime
|
||||
db_updated_at: datetime | None
|
||||
chunk_count: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BM25Corpus:
|
||||
"""Loaded persisted BM25 corpus and retriever."""
|
||||
|
||||
retriever: object | None
|
||||
records: tuple[dict[str, object], ...]
|
||||
manifest: BM25Manifest
|
||||
|
||||
|
||||
class BM25CorpusUnavailableError(RuntimeError):
|
||||
"""Raised when the persisted BM25 corpus cannot be loaded."""
|
||||
|
||||
|
||||
def bm25_index_path(config: EbookSearchConfig) -> Path:
|
||||
"""Return the configured BM25 index root path relative to the current working directory."""
|
||||
path = Path(config.bm25_index_dir).expanduser()
|
||||
if path.is_absolute():
|
||||
return path
|
||||
return Path.cwd() / path
|
||||
|
||||
|
||||
def get_current_bm25_index(index_path: Path) -> Path:
|
||||
"""Return the live BM25 index directory."""
|
||||
current_path = index_path / "current"
|
||||
if current_path.exists() or current_path.is_symlink():
|
||||
return current_path
|
||||
return index_path
|
||||
|
||||
|
||||
def ensure_bm25_corpus(session: Session, config: EbookSearchConfig) -> None:
|
||||
"""Create or refresh the persisted BM25 corpus when it is missing or stale."""
|
||||
index_path = bm25_index_path(config)
|
||||
manifest = read_bm25_manifest(index_path)
|
||||
db_updated_at = corpus_last_updated_at(session)
|
||||
if not bm25_index_exists(index_path, manifest):
|
||||
logger.info("ebook_bm25_index_missing path=%s", index_path)
|
||||
refresh_bm25_corpus(session, config, db_updated_at=db_updated_at)
|
||||
return
|
||||
if db_updated_at is not None and manifest is not None and manifest.created_at < db_updated_at:
|
||||
logger.info(
|
||||
"ebook_bm25_index_stale path=%s created_at=%s db_updated_at=%s",
|
||||
index_path,
|
||||
manifest.created_at.isoformat(),
|
||||
db_updated_at.isoformat(),
|
||||
)
|
||||
refresh_bm25_corpus(session, config, db_updated_at=db_updated_at)
|
||||
return
|
||||
logger.info(
|
||||
"ebook_bm25_index_current path=%s chunks=%s created_at=%s",
|
||||
index_path,
|
||||
manifest.chunk_count if manifest else 0,
|
||||
manifest.created_at.isoformat() if manifest else None,
|
||||
)
|
||||
|
||||
|
||||
def refresh_bm25_corpus(
|
||||
session: Session,
|
||||
config: EbookSearchConfig,
|
||||
*,
|
||||
db_updated_at: datetime | None = None,
|
||||
) -> BM25Manifest:
|
||||
"""Rebuild and persist the BM25 corpus from the current database chunks."""
|
||||
index_path = bm25_index_path(config)
|
||||
records, texts = fetch_bm25_corpus_records(session)
|
||||
manifest = BM25Manifest(
|
||||
created_at=datetime.now(tz=UTC),
|
||||
db_updated_at=db_updated_at if db_updated_at is not None else corpus_last_updated_at(session),
|
||||
chunk_count=len(records),
|
||||
)
|
||||
write_bm25_corpus(index_path, records, texts, manifest)
|
||||
logger.info(
|
||||
"ebook_bm25_index_refreshed path=%s chunks=%s created_at=%s",
|
||||
index_path,
|
||||
manifest.chunk_count,
|
||||
manifest.created_at.isoformat(),
|
||||
)
|
||||
return manifest
|
||||
|
||||
|
||||
@cache
|
||||
def load_bm25_corpus(config: EbookSearchConfig) -> BM25Corpus:
|
||||
"""Load the BM25 corpus into memory once per process.
|
||||
|
||||
Background refresh tasks clear this cache after rebuilding the on-disk corpus.
|
||||
"""
|
||||
index_path = bm25_index_path(config)
|
||||
active_index_path = get_current_bm25_index(index_path)
|
||||
logger.info("ebook_bm25_corpus_cache_load path=%s active_path=%s", index_path, active_index_path)
|
||||
manifest = read_bm25_manifest(index_path)
|
||||
if manifest is None or not bm25_index_exists(index_path, manifest):
|
||||
msg = f"BM25 corpus is not available: {index_path}"
|
||||
raise BM25CorpusUnavailableError(msg)
|
||||
if manifest.chunk_count == 0:
|
||||
return BM25Corpus(retriever=None, records=(), manifest=manifest)
|
||||
|
||||
retriever = bm25s.BM25.load(active_index_path, load_corpus=True, mmap=True)
|
||||
records = tuple(dict(record) for record in retriever.corpus)
|
||||
return BM25Corpus(retriever=retriever, records=records, manifest=manifest)
|
||||
|
||||
|
||||
def score_bm25_corpus(query: str, corpus: BM25Corpus, *, limit: int) -> list[tuple[dict[str, object], float]]:
|
||||
"""Score a query against a loaded BM25 corpus."""
|
||||
if corpus.retriever is None or not corpus.records:
|
||||
return []
|
||||
k = min(limit, len(corpus.records))
|
||||
documents, scores = corpus.retriever.retrieve(
|
||||
bm25s.tokenize(query, show_progress=False),
|
||||
corpus=list(corpus.records),
|
||||
k=k,
|
||||
show_progress=False,
|
||||
)
|
||||
results: list[tuple[dict[str, object], float]] = []
|
||||
for document, score in zip(documents[0], scores[0], strict=True):
|
||||
score_value = float(score)
|
||||
if score_value <= 0:
|
||||
continue
|
||||
results.append((dict(document), score_value))
|
||||
return results
|
||||
|
||||
|
||||
def fetch_bm25_corpus_records(session: Session) -> tuple[list[dict[str, object]], list[str]]:
|
||||
"""Fetch persistable BM25 corpus records and their matching index texts from the database.
|
||||
|
||||
search_text is only needed to build the index, so it is returned separately instead of
|
||||
being persisted into the corpus records, which would double the corpus size.
|
||||
"""
|
||||
statement = (
|
||||
select(
|
||||
EbookChunk.id.label("chunk_id"),
|
||||
EbookChunk.text.label("text"),
|
||||
EbookSource.title.label("source_title"),
|
||||
EbookSource.author.label("source_author"),
|
||||
EbookChapter.title.label("chapter_title"),
|
||||
EbookChunk.page_label.label("page_label"),
|
||||
EbookChunk.search_text.label("bm25_text"),
|
||||
)
|
||||
.select_from(EbookChunk)
|
||||
.join(EbookSource, EbookSource.id == EbookChunk.source_id)
|
||||
.outerjoin(EbookChapter, EbookChapter.id == EbookChunk.chapter_id)
|
||||
.order_by(EbookChunk.id)
|
||||
)
|
||||
records: list[dict[str, object]] = []
|
||||
texts: list[str] = []
|
||||
for row in session.execute(statement).mappings():
|
||||
record = dict(row)
|
||||
texts.append(str(record.pop("bm25_text")))
|
||||
records.append(record)
|
||||
return records, texts
|
||||
|
||||
|
||||
def corpus_last_updated_at(session: Session) -> datetime | None:
|
||||
"""Return the latest source/chapter/chunk update timestamp relevant to BM25 text."""
|
||||
update_times = union_all(
|
||||
select(func.max(EbookSource.updated).label("updated")),
|
||||
select(func.max(EbookChapter.updated).label("updated")),
|
||||
select(func.max(EbookChunk.updated).label("updated")),
|
||||
).subquery()
|
||||
return session.scalar(select(func.max(update_times.c.updated)))
|
||||
|
||||
|
||||
def write_bm25_corpus(
|
||||
index_path: Path,
|
||||
records: list[dict[str, object]],
|
||||
texts: list[str],
|
||||
manifest: BM25Manifest,
|
||||
) -> None:
|
||||
"""Write a BM25 corpus generation and publish it through the current symlink."""
|
||||
index_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
generations_path = index_path / "generations"
|
||||
generations_path.mkdir(exist_ok=True)
|
||||
|
||||
generation_path = next_bm25_generation_path(generations_path, manifest.created_at)
|
||||
current_path = index_path / "current"
|
||||
next_current_path = index_path / f".current.{generation_path.name}.tmp"
|
||||
try:
|
||||
generation_path.mkdir()
|
||||
|
||||
# Empty corpora publish a manifest-only generation so startup succeeds before any chunks exist.
|
||||
if records:
|
||||
retriever = bm25s.BM25()
|
||||
retriever.index(bm25s.tokenize(texts, show_progress=False), show_progress=False)
|
||||
retriever.save(generation_path, corpus=records, show_progress=False)
|
||||
write_bm25_manifest(generation_path, manifest)
|
||||
next_current_path.unlink(missing_ok=True)
|
||||
next_current_path.symlink_to(generation_path, target_is_directory=True)
|
||||
next_current_path.replace(current_path)
|
||||
except Exception:
|
||||
next_current_path.unlink(missing_ok=True)
|
||||
shutil.rmtree(generation_path, ignore_errors=True)
|
||||
raise
|
||||
|
||||
|
||||
def read_bm25_manifest(index_path: Path) -> BM25Manifest | None:
|
||||
"""Read the BM25 manifest if it exists and is valid."""
|
||||
manifest_path = get_current_bm25_index(index_path) / MANIFEST_NAME
|
||||
if not manifest_path.exists():
|
||||
return None
|
||||
body = json.loads(manifest_path.read_text(encoding="utf-8"))
|
||||
return BM25Manifest(
|
||||
created_at=datetime.fromisoformat(str(body["created_at"])),
|
||||
db_updated_at=datetime.fromisoformat(str(body["db_updated_at"])) if body.get("db_updated_at") else None,
|
||||
chunk_count=int(body["chunk_count"]),
|
||||
)
|
||||
|
||||
|
||||
def write_bm25_manifest(index_path: Path, manifest: BM25Manifest) -> None:
|
||||
"""Write the BM25 manifest to an index directory."""
|
||||
body = {
|
||||
"created_at": manifest.created_at.isoformat(),
|
||||
"db_updated_at": manifest.db_updated_at.isoformat() if manifest.db_updated_at else None,
|
||||
"chunk_count": manifest.chunk_count,
|
||||
}
|
||||
(index_path / MANIFEST_NAME).write_text(json.dumps(body, indent=2, sort_keys=True), encoding="utf-8")
|
||||
|
||||
|
||||
def bm25_index_exists(index_path: Path, manifest: BM25Manifest | None) -> bool:
|
||||
"""Return whether a usable persisted BM25 index exists."""
|
||||
active_index_path = get_current_bm25_index(index_path)
|
||||
if manifest is None or not active_index_path.is_dir():
|
||||
return False
|
||||
if manifest.chunk_count == 0:
|
||||
return True
|
||||
return all((active_index_path / file_name).exists() for file_name in REQUIRED_INDEX_FILES)
|
||||
|
||||
|
||||
def next_bm25_generation_path(generations_path: Path, created_at: datetime) -> Path:
|
||||
"""Return an unused dated BM25 generation path."""
|
||||
base_name = created_at.astimezone(UTC).strftime("%Y%m%dT%H%M%S.%fZ")
|
||||
generation_path = generations_path / base_name
|
||||
suffix = 1
|
||||
while generation_path.exists():
|
||||
generation_path = generations_path / f"{base_name}.{suffix}"
|
||||
suffix += 1
|
||||
return generation_path
|
||||
@@ -1,117 +0,0 @@
|
||||
"""Configuration for the EPUB search app."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from os import getenv
|
||||
|
||||
|
||||
def getenv_bool(name: str, *, default: bool) -> bool:
|
||||
"""Read a boolean environment variable with a default fallback."""
|
||||
value = getenv(name)
|
||||
if value is None:
|
||||
return default
|
||||
return value.strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def getenv_int(name: str, *, default: int) -> int:
|
||||
"""Read an integer environment variable with a default fallback."""
|
||||
value = getenv(name)
|
||||
if value is None or not value.strip():
|
||||
return default
|
||||
return int(value)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RerankConfig:
|
||||
"""vLLM reranker settings."""
|
||||
|
||||
enabled: bool = False
|
||||
base_url: str = "http://192.168.90.25:8001"
|
||||
model: str = "qwen3-reranker-06b"
|
||||
candidates: int = 24
|
||||
timeout_seconds: float = 30.0
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class EbookSearchConfig:
|
||||
"""Runtime settings for EPUB search."""
|
||||
|
||||
rerank: RerankConfig
|
||||
top_k: int = 12
|
||||
library_paths: tuple[str, ...] = ()
|
||||
vllm_base_url: str = "https://ollama.com/v1"
|
||||
vllm_api_key: str = "not-needed"
|
||||
chat_model: str = "deepseek-v4-flash"
|
||||
answer_enabled: bool = True
|
||||
embedding_base_url: str = "http://192.168.90.25:8000/v1"
|
||||
embedding_api_key: str = "not-needed"
|
||||
embedding_model: str = "qwen3-embedding-0.6b"
|
||||
embedding_batch_size: int = 32
|
||||
bm25_index_dir: str = ".ebook_search_bm25"
|
||||
bm25_refresh_delay_seconds: int = 60
|
||||
|
||||
|
||||
def load_rerank_config() -> RerankConfig:
|
||||
"""Load reranker config from environment variables."""
|
||||
return RerankConfig(
|
||||
enabled=getenv_bool("EBOOK_SEARCH_RERANK_ENABLED", default=False),
|
||||
base_url=getenv("EBOOK_SEARCH_RERANK_BASE_URL", "http://192.168.90.25:8001"),
|
||||
model=getenv("EBOOK_SEARCH_RERANK_MODEL", "qwen3-reranker-06b"),
|
||||
candidates=getenv_int("EBOOK_SEARCH_RERANK_CANDIDATES", default=24),
|
||||
timeout_seconds=float(getenv_int("EBOOK_SEARCH_RERANK_TIMEOUT_SECONDS", default=30)),
|
||||
)
|
||||
|
||||
|
||||
def load_config() -> EbookSearchConfig:
|
||||
"""Load EPUB search config from environment variables."""
|
||||
return EbookSearchConfig(
|
||||
rerank=load_rerank_config(),
|
||||
top_k=getenv_int("EBOOK_SEARCH_TOP_K", default=12),
|
||||
library_paths=library_paths_from_env(),
|
||||
vllm_base_url=getenv("EBOOK_SEARCH_VLLM_BASE_URL", "https://ollama.com/v1"),
|
||||
vllm_api_key=getenv("EBOOK_SEARCH_VLLM_API_KEY") or getenv("OLLAMA_API_KEY") or "not-needed",
|
||||
chat_model=getenv("EBOOK_SEARCH_CHAT_MODEL", "deepseek-v4-flash"),
|
||||
answer_enabled=getenv_bool("EBOOK_SEARCH_ANSWER_ENABLED", default=True),
|
||||
embedding_base_url=getenv("EBOOK_SEARCH_EMBEDDING_BASE_URL", "http://192.168.90.25:8000/v1"),
|
||||
embedding_api_key=getenv("EBOOK_SEARCH_EMBEDDING_API_KEY", "not-needed"),
|
||||
embedding_model=normalize_embedding_model(),
|
||||
embedding_batch_size=getenv_int("EBOOK_SEARCH_EMBEDDING_BATCH_SIZE", default=32),
|
||||
bm25_index_dir=getenv("EBOOK_SEARCH_BM25_INDEX_DIR", ".ebook_search_bm25"),
|
||||
bm25_refresh_delay_seconds=getenv_int("EBOOK_SEARCH_BM25_REFRESH_DELAY_SECONDS", default=60),
|
||||
)
|
||||
|
||||
|
||||
def normalize_embedding_model(default: str = "qwen3-embedding-0.6b") -> str:
|
||||
"""Normalize supported embedding aliases to provider model names."""
|
||||
aliases = {
|
||||
"Qwen3-Embedding-0.6B": "qwen3-embedding-0.6b",
|
||||
"Qwen3-Embedding-4B": "qwen3-embedding-4b",
|
||||
"Qwen3-Embedding-8B": "qwen3-embedding-8b",
|
||||
"Qwen/Qwen3-Embedding-0.6B": "qwen3-embedding-0.6b",
|
||||
"Qwen/Qwen3-Embedding-4B": "qwen3-embedding-4b",
|
||||
"Qwen/Qwen3-Embedding-8B": "qwen3-embedding-8b",
|
||||
"qwen3-embedding:0.6b": "qwen3-embedding-0.6b",
|
||||
"qwen3-embedding:4b": "qwen3-embedding-4b",
|
||||
"qwen3-embedding:8b": "qwen3-embedding-8b",
|
||||
"qwen3-embedding-0.6b": "qwen3-embedding-0.6b",
|
||||
"qwen3-embedding-4b": "qwen3-embedding-4b",
|
||||
"qwen3-embedding-8b": "qwen3-embedding-8b",
|
||||
}
|
||||
|
||||
model = getenv("EBOOK_SEARCH_EMBEDDING_MODEL", default)
|
||||
standard_model = aliases.get(model)
|
||||
|
||||
if standard_model is None:
|
||||
error = f"Embedding model {model} is not supported. Supported models are {aliases.keys()}"
|
||||
raise ValueError(error)
|
||||
|
||||
return standard_model
|
||||
|
||||
|
||||
def library_paths_from_env() -> tuple[str, ...]:
|
||||
"""Read configured EPUB library paths from the environment."""
|
||||
value = getenv("EBOOK_SEARCH_LIBRARY_PATHS")
|
||||
if value is None:
|
||||
return ()
|
||||
return tuple(path for path in value.split(":") if path)
|
||||
@@ -1,170 +0,0 @@
|
||||
"""Embedding model helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
|
||||
from python.ebook_search.llm_interface import request_embeddings
|
||||
from python.orm.richie import (
|
||||
EbookChunk,
|
||||
EbookChunkEmbedding1024,
|
||||
EbookChunkEmbedding2560,
|
||||
EbookChunkEmbedding4096,
|
||||
EbookEmbeddingModel,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from python.ebook_search.config import EbookSearchConfig
|
||||
|
||||
MODEL_DIMENSIONS = {
|
||||
"qwen3-embedding-0.6b": 1024,
|
||||
"qwen3-embedding-4b": 2560,
|
||||
"qwen3-embedding-8b": 4096,
|
||||
}
|
||||
|
||||
|
||||
def get_embedding_table(
|
||||
dimension: int,
|
||||
) -> type[EbookChunkEmbedding1024 | EbookChunkEmbedding2560 | EbookChunkEmbedding4096]:
|
||||
"""Return the embedding table mapped to an embedding dimension."""
|
||||
embedding_tables = {
|
||||
1024: EbookChunkEmbedding1024,
|
||||
2560: EbookChunkEmbedding2560,
|
||||
4096: EbookChunkEmbedding4096,
|
||||
}
|
||||
table = embedding_tables.get(dimension)
|
||||
if not table:
|
||||
msg = f"Embedding dimension {dimension} is not supported"
|
||||
raise ValueError(msg)
|
||||
return table
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class EmbeddingModelStats:
|
||||
"""Embedding coverage for one model."""
|
||||
|
||||
model_name: str
|
||||
dimension: int
|
||||
embedded_chunks: int
|
||||
total_chunks: int
|
||||
|
||||
@property
|
||||
def missing_chunks(self) -> int:
|
||||
"""Return chunks missing this embedding model."""
|
||||
return max(self.total_chunks - self.embedded_chunks, 0)
|
||||
|
||||
|
||||
def embed_texts(texts: Sequence[str], config: EbookSearchConfig) -> list[list[float]]:
|
||||
"""Embed text with the configured vLLM embedding model."""
|
||||
logger.info(
|
||||
"ebook_embed_request_start base_url=%s model=%s count=%s",
|
||||
config.embedding_base_url,
|
||||
config.embedding_model,
|
||||
len(texts),
|
||||
)
|
||||
vectors = request_embeddings(texts, config)
|
||||
expected_dimension = MODEL_DIMENSIONS[config.embedding_model]
|
||||
for vector in vectors:
|
||||
if len(vector) != expected_dimension:
|
||||
msg = f"Expected {expected_dimension} dimensions, got {len(vector)}"
|
||||
raise ValueError(msg)
|
||||
logger.info(
|
||||
"ebook_embed_request_complete model=%s count=%s dimension=%s",
|
||||
config.embedding_model,
|
||||
len(vectors),
|
||||
expected_dimension,
|
||||
)
|
||||
return vectors
|
||||
|
||||
|
||||
def embed_query(query: str, config: EbookSearchConfig) -> list[float]:
|
||||
"""Embed a search query with the Qwen retrieval instruction."""
|
||||
instructed_query = f"Instruct: Retrieve relevant passages for the query.\nQuery: {query}"
|
||||
return embed_texts([instructed_query], config)[0]
|
||||
|
||||
|
||||
def ensure_embedding_models(session: Session) -> None:
|
||||
"""Ensure supported embedding model rows exist."""
|
||||
for name, dimension in MODEL_DIMENSIONS.items():
|
||||
existing = session.scalar(select(EbookEmbeddingModel).where(EbookEmbeddingModel.name == name))
|
||||
if existing is None:
|
||||
session.add(EbookEmbeddingModel(name=name, dimension=dimension, is_default=name == "qwen3-embedding-0.6b"))
|
||||
logger.info("ebook_embedding_model_created model=%s dimension=%s", name, dimension)
|
||||
session.flush()
|
||||
|
||||
|
||||
def embedding_model_stats(session: Session) -> list[EmbeddingModelStats]:
|
||||
"""Return embedding coverage counts for every supported model."""
|
||||
total_chunks = session.scalar(select(func.count(EbookChunk.id))) or 0
|
||||
models = {
|
||||
model.name: model
|
||||
for model in session.scalars(
|
||||
select(EbookEmbeddingModel)
|
||||
.where(EbookEmbeddingModel.name.in_(MODEL_DIMENSIONS))
|
||||
.order_by(EbookEmbeddingModel.name)
|
||||
)
|
||||
}
|
||||
|
||||
stats: list[EmbeddingModelStats] = []
|
||||
for model_name, dimension in MODEL_DIMENSIONS.items():
|
||||
model = models.get(model_name)
|
||||
embedded_chunks = 0
|
||||
if model is not None:
|
||||
table = get_embedding_table(dimension)
|
||||
embedded_chunks = session.scalar(select(func.count(table.id)).where(table.model_id == model.id)) or 0
|
||||
stats.append(
|
||||
EmbeddingModelStats(
|
||||
model_name=model_name,
|
||||
dimension=dimension,
|
||||
embedded_chunks=embedded_chunks,
|
||||
total_chunks=total_chunks,
|
||||
)
|
||||
)
|
||||
return stats
|
||||
|
||||
|
||||
def embed_missing_chunks(session: Session, config: EbookSearchConfig) -> int:
|
||||
"""Embed chunks missing embeddings for the configured model."""
|
||||
ensure_embedding_models(session)
|
||||
model = session.scalar(select(EbookEmbeddingModel).where(EbookEmbeddingModel.name == config.embedding_model))
|
||||
if model is None:
|
||||
supported_models = ", ".join(MODEL_DIMENSIONS)
|
||||
msg = f"Unknown embedding model: {config.embedding_model}. Supported models: {supported_models}"
|
||||
raise ValueError(msg)
|
||||
|
||||
table = get_embedding_table(model.dimension)
|
||||
chunks = list(
|
||||
session.scalars(
|
||||
select(EbookChunk)
|
||||
.outerjoin(table, (table.chunk_id == EbookChunk.id) & (table.model_id == model.id))
|
||||
.where(table.id.is_(None))
|
||||
.order_by(EbookChunk.id)
|
||||
.limit(config.embedding_batch_size)
|
||||
)
|
||||
)
|
||||
if not chunks:
|
||||
logger.info("ebook_embed_missing_none model=%s", config.embedding_model)
|
||||
return 0
|
||||
|
||||
logger.info("ebook_embed_missing_batch_start model=%s count=%s", config.embedding_model, len(chunks))
|
||||
vectors = embed_texts([chunk.text for chunk in chunks], config)
|
||||
rows = [
|
||||
{"chunk_id": chunk.id, "model_id": model.id, "embedding": vector}
|
||||
for chunk, vector in zip(chunks, vectors, strict=True)
|
||||
]
|
||||
statement = insert(table).values(rows).on_conflict_do_nothing(index_elements=["chunk_id", "model_id"])
|
||||
session.execute(statement)
|
||||
session.flush()
|
||||
logger.info("ebook_embed_missing_batch_complete model=%s count=%s", config.embedding_model, len(rows))
|
||||
return len(rows)
|
||||
@@ -1,95 +0,0 @@
|
||||
"""EPUB parsing helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
from ebooklib import ITEM_DOCUMENT, epub
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
WHITESPACE_RE = re.compile(r"\s+")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ParsedChapter:
|
||||
"""Text extracted from one EPUB spine document."""
|
||||
|
||||
title: str | None
|
||||
href: str | None
|
||||
text: str
|
||||
page_labels: tuple[str, ...]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ParsedEpub:
|
||||
"""Parsed EPUB metadata and text."""
|
||||
|
||||
title: str
|
||||
author: str | None
|
||||
language: str | None
|
||||
publisher: str | None
|
||||
identifier: str | None
|
||||
chapters: tuple[ParsedChapter, ...]
|
||||
|
||||
|
||||
def parse_epub(path: Path) -> ParsedEpub:
|
||||
"""Parse EPUB metadata and spine text."""
|
||||
book = epub.read_epub(path)
|
||||
chapters = []
|
||||
for item in book.get_items_of_type(ITEM_DOCUMENT):
|
||||
soup = BeautifulSoup(item.get_content(), "html.parser")
|
||||
title = chapter_title(soup)
|
||||
page_labels = tuple(extract_page_labels(soup))
|
||||
text = clean_text(soup.get_text(" "))
|
||||
if text:
|
||||
chapters.append(ParsedChapter(title=title, href=item.get_name(), text=text, page_labels=page_labels))
|
||||
|
||||
return ParsedEpub(
|
||||
title=metadata_value(book, "title") or path.stem,
|
||||
author=metadata_value(book, "creator"),
|
||||
language=metadata_value(book, "language"),
|
||||
publisher=metadata_value(book, "publisher"),
|
||||
identifier=metadata_value(book, "identifier"),
|
||||
chapters=tuple(chapters),
|
||||
)
|
||||
|
||||
|
||||
def metadata_value(book: epub.EpubBook, name: str) -> str | None:
|
||||
"""Return the first non-empty Dublin Core metadata value for a name."""
|
||||
values = book.get_metadata("DC", name)
|
||||
if not values:
|
||||
return None
|
||||
value = values[0][0]
|
||||
return str(value).strip() or None
|
||||
|
||||
|
||||
def chapter_title(soup: BeautifulSoup) -> str | None:
|
||||
"""Extract the best available title from an EPUB document soup."""
|
||||
heading = soup.find(["h1", "h2", "h3"])
|
||||
if heading is None:
|
||||
title = soup.find("title")
|
||||
if title is None:
|
||||
return None
|
||||
return clean_text(title.get_text(" ")) or None
|
||||
return clean_text(heading.get_text(" ")) or None
|
||||
|
||||
|
||||
def extract_page_labels(soup: BeautifulSoup) -> list[str]:
|
||||
"""Extract EPUB page-break labels from a document soup."""
|
||||
labels: list[str] = []
|
||||
for tag in soup.find_all(attrs={"epub:type": "pagebreak"}):
|
||||
label = tag.get("title") or tag.get("aria-label") or tag.get_text(" ")
|
||||
clean = clean_text(str(label))
|
||||
if clean:
|
||||
labels.append(clean)
|
||||
return labels
|
||||
|
||||
|
||||
def clean_text(text: str) -> str:
|
||||
"""Normalize whitespace in extracted EPUB text."""
|
||||
return WHITESPACE_RE.sub(" ", text).strip()
|
||||
@@ -1,190 +0,0 @@
|
||||
"""EPUB ingestion into Richie DB."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import tiktoken
|
||||
from sqlalchemy import or_, select
|
||||
|
||||
from python.ebook_search.epub_parse import parse_epub
|
||||
from python.orm.richie import EbookChapter, EbookChunk, EbookSource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_CHUNK_TOKENS = 700
|
||||
DEFAULT_CHUNK_OVERLAP = 100
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from python.ebook_search.config import EbookSearchConfig
|
||||
from python.ebook_search.epub_parse import ParsedChapter
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TextChunk:
|
||||
"""A token-bounded chunk of text."""
|
||||
|
||||
text: str
|
||||
token_start: int
|
||||
token_count: int
|
||||
|
||||
|
||||
def chunk_text(
|
||||
text: str,
|
||||
*,
|
||||
chunk_tokens: int = DEFAULT_CHUNK_TOKENS,
|
||||
overlap_tokens: int = DEFAULT_CHUNK_OVERLAP,
|
||||
) -> list[TextChunk]:
|
||||
"""Split text into overlapping token chunks."""
|
||||
if chunk_tokens <= 0:
|
||||
msg = "chunk_tokens must be positive"
|
||||
raise ValueError(msg)
|
||||
if overlap_tokens < 0 or overlap_tokens >= chunk_tokens:
|
||||
msg = "overlap_tokens must be non-negative and smaller than chunk_tokens"
|
||||
raise ValueError(msg)
|
||||
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
tokens = encoding.encode(text)
|
||||
if not tokens:
|
||||
return []
|
||||
|
||||
chunks: list[TextChunk] = []
|
||||
step = chunk_tokens - overlap_tokens
|
||||
for start in range(0, len(tokens), step):
|
||||
chunk = tokens[start : start + chunk_tokens]
|
||||
if not chunk:
|
||||
continue
|
||||
chunks.append(
|
||||
TextChunk(
|
||||
text=encoding.decode(chunk).strip(),
|
||||
token_start=start,
|
||||
token_count=len(chunk),
|
||||
)
|
||||
)
|
||||
if start + chunk_tokens >= len(tokens):
|
||||
break
|
||||
return [chunk for chunk in chunks if chunk.text]
|
||||
|
||||
|
||||
def ingest_configured_paths(session: Session, config: EbookSearchConfig) -> int:
|
||||
"""Ingest every EPUB found under configured library paths."""
|
||||
count = 0
|
||||
for library_path in config.library_paths:
|
||||
path = Path(library_path).expanduser()
|
||||
logger.info("ebook_ingest_path_start path=%s", path)
|
||||
if path.is_file() and path.suffix.lower() == ".epub":
|
||||
count += int(ingest_file(session, path))
|
||||
elif path.is_dir():
|
||||
for epub_path in sorted(path.rglob("*.epub")):
|
||||
count += int(ingest_file(session, epub_path))
|
||||
else:
|
||||
logger.warning("ebook_ingest_path_missing path=%s", path)
|
||||
logger.info("ebook_ingest_paths_complete changed_files=%s configured_paths=%s", count, len(config.library_paths))
|
||||
return count
|
||||
|
||||
|
||||
def ingest_file(session: Session, path: Path) -> bool:
|
||||
"""Ingest one EPUB file. Return True when the database changed."""
|
||||
resolved_path = path.expanduser().resolve()
|
||||
logger.info("ebook_ingest_file_start path=%s", resolved_path)
|
||||
file_hash = sha256_file(resolved_path)
|
||||
existing = find_existing_source(session, resolved_path, file_hash)
|
||||
if existing is not None and existing.file_sha256 == file_hash:
|
||||
stat = resolved_path.stat()
|
||||
existing.file_path = str(resolved_path)
|
||||
existing.file_mtime = datetime.fromtimestamp(stat.st_mtime, tz=UTC)
|
||||
existing.file_size = stat.st_size
|
||||
session.flush()
|
||||
logger.info("ebook_ingest_file_unchanged source_id=%s path=%s", existing.id, resolved_path)
|
||||
return False
|
||||
if existing is not None:
|
||||
logger.info("ebook_ingest_file_replacing source_id=%s path=%s", existing.id, resolved_path)
|
||||
session.delete(existing)
|
||||
session.flush()
|
||||
|
||||
stat = resolved_path.stat()
|
||||
parsed = parse_epub(resolved_path)
|
||||
source = EbookSource(
|
||||
title=parsed.title,
|
||||
author=parsed.author,
|
||||
language=parsed.language,
|
||||
publisher=parsed.publisher,
|
||||
identifier=parsed.identifier,
|
||||
file_path=str(resolved_path),
|
||||
file_sha256=file_hash,
|
||||
file_mtime=datetime.fromtimestamp(stat.st_mtime, tz=UTC),
|
||||
file_size=stat.st_size,
|
||||
)
|
||||
session.add(source)
|
||||
session.flush()
|
||||
|
||||
chunk_index = 0
|
||||
for spine_index, parsed_chapter in enumerate(parsed.chapters):
|
||||
chapter = EbookChapter(
|
||||
source_id=source.id,
|
||||
spine_index=spine_index,
|
||||
title=parsed_chapter.title,
|
||||
href=parsed_chapter.href,
|
||||
)
|
||||
session.add(chapter)
|
||||
session.flush()
|
||||
chunk_index = add_chapter_chunks(session, source, chapter, parsed_chapter, chunk_index)
|
||||
|
||||
session.flush()
|
||||
logger.info(
|
||||
"ebook_ingest_file_complete source_id=%s path=%s chapters=%s chunks=%s",
|
||||
source.id,
|
||||
resolved_path,
|
||||
len(parsed.chapters),
|
||||
chunk_index,
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def find_existing_source(session: Session, path: Path, file_hash: str) -> EbookSource | None:
|
||||
"""Find an existing source by canonical path or file hash."""
|
||||
return session.scalar(
|
||||
select(EbookSource).where(or_(EbookSource.file_path == str(path), EbookSource.file_sha256 == file_hash))
|
||||
)
|
||||
|
||||
|
||||
def add_chapter_chunks(
|
||||
session: Session,
|
||||
source: EbookSource,
|
||||
chapter: EbookChapter,
|
||||
parsed_chapter: ParsedChapter,
|
||||
chunk_index: int,
|
||||
) -> int:
|
||||
"""Add chunk rows for one parsed chapter and return the next chunk index."""
|
||||
page_label = parsed_chapter.page_labels[0] if parsed_chapter.page_labels else None
|
||||
for text_chunk in chunk_text(parsed_chapter.text):
|
||||
session.add(
|
||||
EbookChunk(
|
||||
source_id=source.id,
|
||||
chapter_id=chapter.id,
|
||||
chunk_index=chunk_index,
|
||||
text=text_chunk.text,
|
||||
token_start=text_chunk.token_start,
|
||||
token_count=text_chunk.token_count,
|
||||
page_label=page_label,
|
||||
content_sha256=hashlib.sha256(text_chunk.text.encode()).hexdigest(),
|
||||
search_text=f"{source.title} {source.author or ''} {chapter.title or ''} {text_chunk.text}",
|
||||
)
|
||||
)
|
||||
chunk_index += 1
|
||||
return chunk_index
|
||||
|
||||
|
||||
def sha256_file(path: Path) -> str:
|
||||
"""Calculate the SHA-256 digest for a file."""
|
||||
digest = hashlib.sha256()
|
||||
with path.open("rb") as file:
|
||||
for block in iter(lambda: file.read(1024 * 1024), b""):
|
||||
digest.update(block)
|
||||
return digest.hexdigest()
|
||||
@@ -1,143 +0,0 @@
|
||||
"""LLM provider HTTP adapters."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import httpx
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
from python.ebook_search.config import EbookSearchConfig, RerankConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def auth_headers(api_key: str) -> dict[str, str]:
|
||||
"""Build authorization headers when an API key is configured."""
|
||||
if api_key == "not-needed":
|
||||
return {}
|
||||
return {"Authorization": f"Bearer {api_key}"}
|
||||
|
||||
|
||||
def request_embeddings(texts: Sequence[str], config: EbookSearchConfig) -> list[list[float]]:
|
||||
"""Request embeddings from the configured OpenAI-compatible endpoint."""
|
||||
try:
|
||||
response = httpx.post(
|
||||
f"{config.embedding_base_url.rstrip('/')}/embeddings",
|
||||
headers=auth_headers(config.embedding_api_key),
|
||||
json={"model": config.embedding_model, "input": list(texts)},
|
||||
timeout=60,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return embedding_vectors_from_response(response.json())
|
||||
except (httpx.HTTPError, ValueError, KeyError, TypeError) as error:
|
||||
logger.exception(
|
||||
"ebook_embed_request_failed base_url=%s model=%s count=%s",
|
||||
config.embedding_base_url,
|
||||
config.embedding_model,
|
||||
len(texts),
|
||||
)
|
||||
msg = f"Embedding request failed. base_url={config.embedding_base_url} model={config.embedding_model}"
|
||||
raise RuntimeError(msg) from error
|
||||
|
||||
|
||||
def embedding_vectors_from_response(body: object) -> list[list[float]]:
|
||||
"""Extract embedding vectors from an OpenAI-compatible embedding response."""
|
||||
if not isinstance(body, dict):
|
||||
msg = "Embedding response is not an object"
|
||||
raise TypeError(msg)
|
||||
|
||||
data = body["data"]
|
||||
if not isinstance(data, list):
|
||||
msg = "Embedding response data is not a list"
|
||||
raise TypeError(msg)
|
||||
|
||||
vectors: list[list[float]] = []
|
||||
for item in data:
|
||||
if not isinstance(item, dict):
|
||||
msg = "Embedding item is not an object"
|
||||
raise TypeError(msg)
|
||||
embedding = item["embedding"]
|
||||
if not isinstance(embedding, list):
|
||||
msg = "Embedding value is not a list"
|
||||
raise TypeError(msg)
|
||||
vectors.append([float(value) for value in embedding])
|
||||
return vectors
|
||||
|
||||
|
||||
def request_rerank(
|
||||
query: str,
|
||||
documents: Sequence[str],
|
||||
config: RerankConfig,
|
||||
) -> object | None:
|
||||
"""Request rerank scores from the configured vLLM endpoint."""
|
||||
payload = {
|
||||
"model": config.model,
|
||||
"query": query,
|
||||
"documents": list(documents),
|
||||
}
|
||||
response = httpx.post(
|
||||
f"{config.base_url.rstrip('/')}/rerank",
|
||||
json=payload,
|
||||
timeout=config.timeout_seconds,
|
||||
)
|
||||
response.raise_for_status()
|
||||
try:
|
||||
return response.json()
|
||||
except ValueError:
|
||||
logger.debug("ebook_rerank_response_invalid_json", extra={"response": response.text})
|
||||
return None
|
||||
|
||||
|
||||
def request_chat_completion(
|
||||
config: EbookSearchConfig,
|
||||
messages: Sequence[dict[str, str]],
|
||||
) -> str:
|
||||
"""Request a chat completion from the configured OpenAI-compatible endpoint."""
|
||||
try:
|
||||
response = httpx.post(
|
||||
f"{config.vllm_base_url.rstrip('/')}/chat/completions",
|
||||
headers=auth_headers(config.vllm_api_key),
|
||||
json={
|
||||
"model": config.chat_model,
|
||||
"messages": list(messages),
|
||||
"temperature": 0,
|
||||
},
|
||||
timeout=60,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return chat_content_from_response(response.json())
|
||||
except (httpx.HTTPError, ValueError, KeyError, TypeError) as error:
|
||||
msg = f"Chat request failed. base_url={config.vllm_base_url} model={config.chat_model}"
|
||||
raise RuntimeError(msg) from error
|
||||
|
||||
|
||||
def chat_content_from_response(body: object) -> str:
|
||||
"""Extract text content from an OpenAI-compatible chat response."""
|
||||
if not isinstance(body, dict):
|
||||
msg = "Chat response is not an object"
|
||||
raise TypeError(msg)
|
||||
|
||||
choices = body["choices"]
|
||||
if not isinstance(choices, list) or not choices:
|
||||
msg = "Chat response has no choices"
|
||||
raise ValueError(msg)
|
||||
|
||||
first = choices[0]
|
||||
if not isinstance(first, dict):
|
||||
msg = "Chat choice is not an object"
|
||||
raise TypeError(msg)
|
||||
|
||||
message = first["message"]
|
||||
if not isinstance(message, dict):
|
||||
msg = "Chat message is not an object"
|
||||
raise TypeError(msg)
|
||||
|
||||
content = message.get("content") or ""
|
||||
if not isinstance(content, str):
|
||||
msg = "Chat content is not text"
|
||||
raise TypeError(msg)
|
||||
return content
|
||||
@@ -1,129 +0,0 @@
|
||||
"""vLLM-backed optional reranking."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, replace
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from python.ebook_search.llm_interface import request_rerank
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from python.ebook_search.config import RerankConfig
|
||||
from python.ebook_search.search import SearchResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
RERANK_SCORE_WEIGHT = 0.7
|
||||
HYBRID_SCORE_WEIGHT = 0.3
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RerankResult:
|
||||
"""A relevance score for one candidate chunk."""
|
||||
|
||||
chunk_id: int
|
||||
score: float
|
||||
|
||||
|
||||
def rerank_chunks(query: str, candidates: list[SearchResult], config: RerankConfig) -> list[SearchResult]:
|
||||
"""Rerank candidates with a vLLM rerank endpoint."""
|
||||
if not candidates:
|
||||
return []
|
||||
|
||||
logger.info(
|
||||
"ebook_rerank_request_start base_url=%s model=%s candidates=%s",
|
||||
config.base_url,
|
||||
config.model,
|
||||
len(candidates),
|
||||
)
|
||||
scores = score_candidates(query, candidates, config)
|
||||
results = sorted(
|
||||
(
|
||||
replace(
|
||||
result,
|
||||
score=final_rerank_score(result, scores[result.chunk_id].score, candidates),
|
||||
rerank_score=scores[result.chunk_id].score,
|
||||
)
|
||||
for result in candidates
|
||||
),
|
||||
key=lambda result: result.score,
|
||||
reverse=True,
|
||||
)
|
||||
logger.info(
|
||||
"ebook_rerank_request_complete base_url=%s model=%s candidates=%s",
|
||||
config.base_url,
|
||||
config.model,
|
||||
len(results),
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
def score_candidates(
|
||||
query: str,
|
||||
candidates: list[SearchResult],
|
||||
config: RerankConfig,
|
||||
) -> dict[int, RerankResult]:
|
||||
"""Score candidate chunks with the configured rerank API."""
|
||||
body = request_rerank(query, [candidate.text for candidate in candidates], config)
|
||||
if body is None:
|
||||
return zero_rerank_scores(candidates)
|
||||
|
||||
scores = parse_vllm_scores(body, candidates)
|
||||
for result in scores.values():
|
||||
logger.debug("ebook_rerank_candidate_scored chunk_id=%s score=%s", result.chunk_id, result.score)
|
||||
return scores
|
||||
|
||||
|
||||
def parse_vllm_scores(body: object, candidates: list[SearchResult]) -> dict[int, RerankResult]:
|
||||
"""Parse vLLM rerank scores into chunk-id keyed results."""
|
||||
if not isinstance(body, dict):
|
||||
logger.debug("ebook_rerank_response_not_object", extra={"response": body})
|
||||
return zero_rerank_scores(candidates)
|
||||
|
||||
results = body.get("results") or body.get("data")
|
||||
if not isinstance(results, list):
|
||||
logger.debug("ebook_rerank_response_missing_results", extra={"response": body})
|
||||
return zero_rerank_scores(candidates)
|
||||
|
||||
scores = zero_rerank_scores(candidates)
|
||||
for item in results:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
index = item.get("index")
|
||||
score = item.get("relevance_score", item.get("score"))
|
||||
if not isinstance(index, int) or index < 0 or index >= len(candidates):
|
||||
continue
|
||||
if not isinstance(score, int | float):
|
||||
continue
|
||||
chunk_id = candidates[index].chunk_id
|
||||
scores[chunk_id] = RerankResult(chunk_id=chunk_id, score=clamp_score(float(score)))
|
||||
return scores
|
||||
|
||||
|
||||
def zero_rerank_scores(candidates: list[SearchResult]) -> dict[int, RerankResult]:
|
||||
"""Return zero relevance scores for all candidate chunks."""
|
||||
return {candidate.chunk_id: RerankResult(chunk_id=candidate.chunk_id, score=0.0) for candidate in candidates}
|
||||
|
||||
|
||||
def clamp_score(score: float) -> float:
|
||||
"""Clamp a rerank score into the supported 0.0 to 1.0 range."""
|
||||
return min(max(score, 0.0), 1.0)
|
||||
|
||||
|
||||
def final_rerank_score(result: SearchResult, rerank_score: float, candidates: list[SearchResult]) -> float:
|
||||
"""Combine rerank relevance with normalized hybrid retrieval evidence."""
|
||||
return (RERANK_SCORE_WEIGHT * rerank_score) + (HYBRID_SCORE_WEIGHT * normalized_hybrid_score(result, candidates))
|
||||
|
||||
|
||||
def normalized_hybrid_score(result: SearchResult, candidates: list[SearchResult]) -> float:
|
||||
"""Normalize a candidate hybrid score against the rerank candidate set."""
|
||||
hybrid_scores = [
|
||||
candidate.fused_score if candidate.fused_score is not None else candidate.score for candidate in candidates
|
||||
]
|
||||
low = min(hybrid_scores)
|
||||
high = max(hybrid_scores)
|
||||
if high == low:
|
||||
return 1.0
|
||||
|
||||
score = result.fused_score if result.fused_score is not None else result.score
|
||||
return (score - low) / (high - low)
|
||||
@@ -1,383 +0,0 @@
|
||||
"""Hybrid search orchestration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass, replace
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from sqlalchemy import literal, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from python.ebook_search.bm25_corpus import (
|
||||
BM25CorpusUnavailableError,
|
||||
load_bm25_corpus,
|
||||
score_bm25_corpus,
|
||||
)
|
||||
from python.ebook_search.embeddings import MODEL_DIMENSIONS, embed_query, get_embedding_table
|
||||
from python.ebook_search.rerank import rerank_chunks
|
||||
from python.ebook_search.timing import RuntimeStep, timed_result
|
||||
from python.orm.richie import (
|
||||
EbookChapter,
|
||||
EbookChunk,
|
||||
EbookEmbeddingModel,
|
||||
EbookSource,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Mapping
|
||||
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
from python.ebook_search.config import EbookSearchConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
BM25_CANDIDATE_LIMIT = 120
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SearchResult:
|
||||
"""One source chunk returned by search."""
|
||||
|
||||
chunk_id: int
|
||||
text: str
|
||||
source_title: str
|
||||
score: float = 0.0
|
||||
vector_score: float | None = None
|
||||
bm25_score: float | None = None
|
||||
fused_score: float | None = None
|
||||
rerank_score: float | None = None
|
||||
source_author: str | None = None
|
||||
chapter_title: str | None = None
|
||||
page_label: str | None = None
|
||||
rank_source: str = "Hybrid"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SearchResponse:
|
||||
"""Search output for the UI."""
|
||||
|
||||
query: str
|
||||
results: list[SearchResult]
|
||||
rank_label: str
|
||||
timings: tuple[RuntimeStep, ...] = ()
|
||||
|
||||
@property
|
||||
def total_runtime_ms(self) -> float:
|
||||
"""Return total measured runtime for the response."""
|
||||
return sum(step.duration_ms for step in self.timings if step.counts_toward_total)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RetrievalResponse:
|
||||
"""Parallel retrieval output for vector and BM25 candidates."""
|
||||
|
||||
vector_results: list[SearchResult]
|
||||
lexical_results: list[SearchResult]
|
||||
timings: tuple[RuntimeStep, ...]
|
||||
|
||||
|
||||
def search_ebooks(
|
||||
engine: Engine,
|
||||
query: str,
|
||||
config: EbookSearchConfig,
|
||||
*,
|
||||
rerank: bool = False,
|
||||
) -> SearchResponse:
|
||||
"""Run hybrid vector/BM25 search and optional reranking."""
|
||||
if not query.strip():
|
||||
logger.info("ebook_search_empty_query")
|
||||
return SearchResponse(query=query, results=[], rank_label="Hybrid")
|
||||
|
||||
logger.info("ebook_search_start query_length=%s rerank=%s", len(query), rerank)
|
||||
timings: list[RuntimeStep] = []
|
||||
bm25_query, timing = timed_result("BM25 query preparation", retrieval_query_from_text, query)
|
||||
timings.append(timing)
|
||||
retrieval, timing = timed_result(
|
||||
"Hybrid retrieval",
|
||||
parallel_retrieval,
|
||||
engine,
|
||||
query,
|
||||
bm25_query,
|
||||
config,
|
||||
)
|
||||
timings.extend(retrieval.timings)
|
||||
timings.append(timing)
|
||||
fused, timing = timed_result(
|
||||
"Reciprocal rank fusion",
|
||||
reciprocal_rank_fusion,
|
||||
retrieval.vector_results,
|
||||
retrieval.lexical_results,
|
||||
)
|
||||
timings.append(timing)
|
||||
if config.rerank.enabled and rerank:
|
||||
response, timing = timed_result("Rerank", apply_rerank, query, fused, config)
|
||||
else:
|
||||
response, timing = timed_result("Rerank skipped", skip_rerank, query, fused, config)
|
||||
timings.append(timing)
|
||||
response = replace(response, timings=tuple(timings))
|
||||
logger.info(
|
||||
"ebook_search_complete vector_candidates=%s lexical_candidates=%s "
|
||||
"fused_candidates=%s returned=%s rank_label=%s runtime_ms=%.1f",
|
||||
len(retrieval.vector_results),
|
||||
len(retrieval.lexical_results),
|
||||
len(fused),
|
||||
len(response.results),
|
||||
response.rank_label,
|
||||
response.total_runtime_ms,
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
def parallel_retrieval(
|
||||
engine: Engine,
|
||||
vector_query: str,
|
||||
bm25_query: str,
|
||||
config: EbookSearchConfig,
|
||||
) -> RetrievalResponse:
|
||||
"""Run vector and BM25 candidate retrieval concurrently with separate database sessions."""
|
||||
with ThreadPoolExecutor(max_workers=2, thread_name_prefix="ebook-search") as executor:
|
||||
vector_future = executor.submit(
|
||||
timed_result,
|
||||
"Embedding + vector search",
|
||||
vector_candidates,
|
||||
engine,
|
||||
vector_query,
|
||||
config,
|
||||
)
|
||||
bm25_future = executor.submit(
|
||||
timed_result,
|
||||
"BM25 search",
|
||||
bm25_candidates,
|
||||
bm25_query,
|
||||
config,
|
||||
)
|
||||
vector_results, vector_timing = vector_future.result()
|
||||
lexical_results, lexical_timing = bm25_future.result()
|
||||
|
||||
logger.info(
|
||||
"ebook_parallel_retrieval_complete vector_candidates=%s lexical_candidates=%s",
|
||||
len(vector_results),
|
||||
len(lexical_results),
|
||||
)
|
||||
return RetrievalResponse(
|
||||
vector_results=vector_results,
|
||||
lexical_results=lexical_results,
|
||||
timings=(
|
||||
replace(vector_timing, counts_toward_total=False),
|
||||
replace(lexical_timing, counts_toward_total=False),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def skip_rerank(
|
||||
query: str,
|
||||
candidates: list[SearchResult],
|
||||
config: EbookSearchConfig,
|
||||
) -> SearchResponse:
|
||||
"""Return fused hybrid results without reranking."""
|
||||
logger.info("ebook_rerank_skipped candidates=%s", len(candidates))
|
||||
return SearchResponse(query=query, results=candidates[: config.top_k], rank_label="Hybrid")
|
||||
|
||||
|
||||
def apply_rerank(
|
||||
query: str,
|
||||
candidates: list[SearchResult],
|
||||
config: EbookSearchConfig,
|
||||
) -> SearchResponse:
|
||||
"""Rerank already-fused hybrid candidates."""
|
||||
reranked = rerank_chunks(query, candidates[: config.rerank.candidates], config.rerank)
|
||||
logger.info(
|
||||
"ebook_rerank_complete input_candidates=%s returned=%s",
|
||||
min(len(candidates), config.rerank.candidates),
|
||||
len(reranked),
|
||||
)
|
||||
return SearchResponse(
|
||||
query=query,
|
||||
results=[replace(result, rank_source="Hybrid + rerank") for result in reranked[: config.top_k]],
|
||||
rank_label="Hybrid + rerank",
|
||||
)
|
||||
|
||||
|
||||
def vector_candidates(engine: Engine, query: str, config: EbookSearchConfig) -> list[SearchResult]:
|
||||
"""Return pgvector cosine candidates for a natural-language query."""
|
||||
with Session(engine) as session:
|
||||
model = session.scalar(select(EbookEmbeddingModel).where(EbookEmbeddingModel.name == config.embedding_model))
|
||||
if model is None:
|
||||
msg = f"Embedding model is not registered: {config.embedding_model}"
|
||||
raise ValueError(msg)
|
||||
|
||||
expected_dimension = MODEL_DIMENSIONS[config.embedding_model]
|
||||
if model.dimension != expected_dimension:
|
||||
msg = f"Model row dimension {model.dimension} does not match configured dimension {expected_dimension}"
|
||||
raise ValueError(msg)
|
||||
|
||||
embedding = embed_query(query, config)
|
||||
limit = max(config.rerank.candidates, config.top_k) * 4
|
||||
embedding_table = get_embedding_table(model.dimension)
|
||||
|
||||
embedding_param = literal(embedding, type_=Vector(model.dimension))
|
||||
distance = embedding_table.embedding.op("<=>")(embedding_param)
|
||||
score = (literal(1.0) - distance).label("score")
|
||||
statement = (
|
||||
select(
|
||||
EbookChunk.id.label("chunk_id"),
|
||||
EbookChunk.text.label("text"),
|
||||
EbookSource.title.label("source_title"),
|
||||
EbookSource.author.label("source_author"),
|
||||
EbookChapter.title.label("chapter_title"),
|
||||
EbookChunk.page_label.label("page_label"),
|
||||
score,
|
||||
)
|
||||
.select_from(embedding_table)
|
||||
.join(EbookChunk, EbookChunk.id == embedding_table.chunk_id)
|
||||
.join(EbookSource, EbookSource.id == EbookChunk.source_id)
|
||||
.outerjoin(EbookChapter, EbookChapter.id == EbookChunk.chapter_id)
|
||||
.where(embedding_table.model_id == model.id)
|
||||
.order_by(distance)
|
||||
.limit(limit)
|
||||
)
|
||||
rows = session.execute(statement).mappings()
|
||||
results = [search_result_from_row(row) for row in rows]
|
||||
logger.info(
|
||||
"ebook_vector_search_complete model=%s dimension=%s candidates=%s",
|
||||
config.embedding_model,
|
||||
model.dimension,
|
||||
len(results),
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
def bm25_candidates(query: str, config: EbookSearchConfig) -> list[SearchResult]:
|
||||
"""Return BM25-ranked lexical candidates using the persisted corpus."""
|
||||
try:
|
||||
corpus = load_bm25_corpus(config)
|
||||
except BM25CorpusUnavailableError as error:
|
||||
logger.warning("ebook_bm25_index_unavailable_skipping error=%s", error)
|
||||
return []
|
||||
|
||||
if not corpus.records:
|
||||
logger.info("ebook_bm25_search_complete corpus=0 candidates=0")
|
||||
return []
|
||||
|
||||
scored_records = score_bm25_corpus(query, corpus, limit=BM25_CANDIDATE_LIMIT)
|
||||
results = [
|
||||
replace(search_result_from_row(record), score=score, vector_score=None, bm25_score=score)
|
||||
for record, score in scored_records
|
||||
]
|
||||
|
||||
max_score = results[0].bm25_score if results else 0.0
|
||||
logger.info(
|
||||
"ebook_bm25_search_complete corpus=%s candidates=%s max_score=%.6f",
|
||||
len(corpus.records),
|
||||
len(results),
|
||||
max_score,
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
def reciprocal_rank_fusion(
|
||||
vector_results: list[SearchResult],
|
||||
lexical_results: list[SearchResult],
|
||||
*,
|
||||
rank_constant: int = 60,
|
||||
) -> list[SearchResult]:
|
||||
"""Fuse vector and lexical rankings with Reciprocal Rank Fusion."""
|
||||
by_chunk: dict[int, SearchResult] = {}
|
||||
scores: dict[int, float] = {}
|
||||
vector_scores: dict[int, float] = {}
|
||||
bm25_scores: dict[int, float] = {}
|
||||
|
||||
for rank, result in enumerate(vector_results, start=1):
|
||||
by_chunk.setdefault(result.chunk_id, result)
|
||||
vector_scores[result.chunk_id] = result.vector_score if result.vector_score is not None else result.score
|
||||
scores[result.chunk_id] = scores.get(result.chunk_id, 0.0) + (1 / (rank_constant + rank))
|
||||
|
||||
for rank, result in enumerate(lexical_results, start=1):
|
||||
by_chunk.setdefault(result.chunk_id, result)
|
||||
bm25_scores[result.chunk_id] = result.bm25_score if result.bm25_score is not None else result.score
|
||||
scores[result.chunk_id] = scores.get(result.chunk_id, 0.0) + (1 / (rank_constant + rank))
|
||||
|
||||
return sorted(
|
||||
(
|
||||
replace(
|
||||
result,
|
||||
score=scores[result.chunk_id],
|
||||
vector_score=vector_scores.get(result.chunk_id),
|
||||
bm25_score=bm25_scores.get(result.chunk_id),
|
||||
fused_score=scores[result.chunk_id],
|
||||
rank_source="Hybrid",
|
||||
)
|
||||
for result in by_chunk.values()
|
||||
),
|
||||
key=lambda result: result.score,
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
|
||||
def search_result_from_row(row: Mapping[str, object]) -> SearchResult:
|
||||
"""Convert a database row mapping into a search result."""
|
||||
return SearchResult(
|
||||
chunk_id=int(row["chunk_id"]),
|
||||
text=str(row["text"]),
|
||||
source_title=str(row["source_title"]),
|
||||
source_author=optional_str(row["source_author"]),
|
||||
chapter_title=optional_str(row["chapter_title"]),
|
||||
page_label=optional_str(row["page_label"]),
|
||||
score=float(row["score"]) if "score" in row else 0.0,
|
||||
vector_score=float(row["score"]) if "score" in row else None,
|
||||
)
|
||||
|
||||
|
||||
def optional_str(value: object) -> str | None:
|
||||
"""Convert nullable database values to optional strings."""
|
||||
if value is None:
|
||||
return None
|
||||
return str(value)
|
||||
|
||||
|
||||
TOKEN_RE = re.compile(r"[A-Za-z0-9_]+")
|
||||
|
||||
|
||||
def tokens(text_value: str) -> list[str]:
|
||||
"""Extract tokens from a text value.
|
||||
|
||||
This is a simple approximation of the tokenization used by PostgreSQL's full-text search,
|
||||
which is sufficient for BM25 candidate retrieval. It lowercases tokens and includes alphanumeric characters and
|
||||
underscores.
|
||||
"""
|
||||
return [match.group(0).lower() for match in TOKEN_RE.finditer(text_value)]
|
||||
|
||||
|
||||
QUERY_STOP_WORDS = {
|
||||
"a",
|
||||
"an",
|
||||
"and",
|
||||
"are",
|
||||
"as",
|
||||
"at",
|
||||
"does",
|
||||
"for",
|
||||
"in",
|
||||
"is",
|
||||
"of",
|
||||
"the",
|
||||
"to",
|
||||
"what",
|
||||
"when",
|
||||
"where",
|
||||
"which",
|
||||
"who",
|
||||
"why",
|
||||
}
|
||||
|
||||
|
||||
def retrieval_query_from_text(query: str) -> str:
|
||||
"""Remove generic question words while preserving entity and series terms."""
|
||||
keywords = [token for token in tokens(query) if token not in QUERY_STOP_WORDS]
|
||||
if not keywords:
|
||||
return query
|
||||
return " ".join(keywords)
|
||||
@@ -1,36 +0,0 @@
|
||||
"""Runtime timing helpers for EPUB search."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from time import perf_counter
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RuntimeStep:
|
||||
"""Elapsed runtime for one named search step."""
|
||||
|
||||
name: str
|
||||
duration_ms: float
|
||||
counts_toward_total: bool = True
|
||||
|
||||
|
||||
def runtime_step_from_start(name: str, start_seconds: float) -> RuntimeStep:
|
||||
"""Create a runtime step from a prior perf_counter timestamp."""
|
||||
return RuntimeStep(name=name, duration_ms=(perf_counter() - start_seconds) * 1000)
|
||||
|
||||
|
||||
def timed_result[T, **P](
|
||||
name: str,
|
||||
operation: Callable[P, T],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> tuple[T, RuntimeStep]:
|
||||
"""Run an operation and return its result plus elapsed runtime."""
|
||||
start_seconds = perf_counter()
|
||||
result = operation(*args, **kwargs)
|
||||
return result, runtime_step_from_start(name, start_seconds)
|
||||
@@ -1,6 +0,0 @@
|
||||
"""Reusable FastAPI tools."""
|
||||
|
||||
from python.fastapi_tools.db import DbSession, get_db
|
||||
from python.fastapi_tools.zstd_middleware import ZstdMiddleware
|
||||
|
||||
__all__ = ["DbSession", "ZstdMiddleware", "get_db"]
|
||||
@@ -1,9 +1,13 @@
|
||||
"""ORM package exports."""
|
||||
|
||||
from python.orm.data_science_dev.base import DataScienceDevBase
|
||||
from python.orm.richie.base import RichieBase
|
||||
from python.orm.signal_bot.base import SignalBotBase
|
||||
from python.orm.van_inventory.base import VanInventoryBase
|
||||
|
||||
__all__ = [
|
||||
"DataScienceDevBase",
|
||||
"RichieBase",
|
||||
"SignalBotBase",
|
||||
"VanInventoryBase",
|
||||
]
|
||||
|
||||
+2
-24
@@ -31,24 +31,8 @@ def get_connection_info(name: str) -> tuple[str, str, str, str, str | None]:
|
||||
return cast("tuple[str, str, str, str, str | None]", (database, host, port, username, password))
|
||||
|
||||
|
||||
def get_postgres_engine(
|
||||
*,
|
||||
name: str = "POSTGRES",
|
||||
pool_pre_ping: bool = True,
|
||||
vector_engine: bool = False,
|
||||
) -> Engine:
|
||||
"""Create a SQLAlchemy engine from environment variables.
|
||||
|
||||
Args:
|
||||
name (str, optional): The name of the environment variable prefix. Defaults to "POSTGRES".
|
||||
pool_pre_ping (bool, optional): Whether to ping the database before each connection. Defaults to True.
|
||||
This fixes the issue of trying to use a conection that has timed out on the database side.
|
||||
vector_engine (bool, optional): Whether to use the vector search schema. Defaults to False.
|
||||
This updates the search path the incldued the vecore types and operators.
|
||||
|
||||
Returns:
|
||||
Engine: The SQLAlchemy engine.
|
||||
"""
|
||||
def get_postgres_engine(*, name: str = "POSTGRES", pool_pre_ping: bool = True) -> Engine:
|
||||
"""Create a SQLAlchemy engine from environment variables."""
|
||||
database, host, port, username, password = get_connection_info(name)
|
||||
|
||||
url = URL.create(
|
||||
@@ -60,14 +44,8 @@ def get_postgres_engine(
|
||||
database=database,
|
||||
)
|
||||
|
||||
connect_args = {}
|
||||
# There more better way to do this is with separate PG account and a dedicated vector schema for the vector types
|
||||
if vector_engine:
|
||||
connect_args["options"] = "-csearch_path=main,public"
|
||||
|
||||
return create_engine(
|
||||
url=url,
|
||||
pool_pre_ping=pool_pre_ping,
|
||||
pool_recycle=1800,
|
||||
connect_args=connect_args,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
"""Data science dev database ORM exports."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from python.orm.data_science_dev.base import DataScienceDevBase, DataScienceDevTableBase, DataScienceDevTableBaseBig
|
||||
|
||||
__all__ = [
|
||||
"DataScienceDevBase",
|
||||
"DataScienceDevTableBase",
|
||||
"DataScienceDevTableBaseBig",
|
||||
]
|
||||
@@ -0,0 +1,52 @@
|
||||
"""Data science dev database ORM base."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import BigInteger, DateTime, MetaData, func
|
||||
from sqlalchemy.ext.declarative import AbstractConcreteBase
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
from python.orm.common import NAMING_CONVENTION
|
||||
|
||||
|
||||
class DataScienceDevBase(DeclarativeBase):
|
||||
"""Base class for data_science_dev database ORM models."""
|
||||
|
||||
schema_name = "main"
|
||||
|
||||
metadata = MetaData(
|
||||
schema=schema_name,
|
||||
naming_convention=NAMING_CONVENTION,
|
||||
)
|
||||
|
||||
|
||||
class _TableMixin:
|
||||
"""Shared timestamp columns for all table bases."""
|
||||
|
||||
created: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
)
|
||||
updated: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
|
||||
|
||||
class DataScienceDevTableBase(_TableMixin, AbstractConcreteBase, DataScienceDevBase):
|
||||
"""Table with Integer primary key."""
|
||||
|
||||
__abstract__ = True
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
|
||||
|
||||
class DataScienceDevTableBaseBig(_TableMixin, AbstractConcreteBase, DataScienceDevBase):
|
||||
"""Table with BigInteger primary key."""
|
||||
|
||||
__abstract__ = True
|
||||
|
||||
id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
|
||||
@@ -0,0 +1,14 @@
|
||||
"""init."""
|
||||
|
||||
from python.orm.data_science_dev.congress.bill import Bill, BillText
|
||||
from python.orm.data_science_dev.congress.legislator import Legislator, LegislatorSocialMedia
|
||||
from python.orm.data_science_dev.congress.vote import Vote, VoteRecord
|
||||
|
||||
__all__ = [
|
||||
"Bill",
|
||||
"BillText",
|
||||
"Legislator",
|
||||
"LegislatorSocialMedia",
|
||||
"Vote",
|
||||
"VoteRecord",
|
||||
]
|
||||
@@ -0,0 +1,66 @@
|
||||
"""Bill model - legislation introduced in Congress."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import ForeignKey, Index, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from python.orm.data_science_dev.base import DataScienceDevTableBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from python.orm.data_science_dev.congress.vote import Vote
|
||||
|
||||
|
||||
class Bill(DataScienceDevTableBase):
|
||||
"""Legislation with congress number, type, titles, status, and sponsor."""
|
||||
|
||||
__tablename__ = "bill"
|
||||
|
||||
congress: Mapped[int]
|
||||
bill_type: Mapped[str]
|
||||
number: Mapped[int]
|
||||
|
||||
title: Mapped[str | None]
|
||||
title_short: Mapped[str | None]
|
||||
official_title: Mapped[str | None]
|
||||
|
||||
status: Mapped[str | None]
|
||||
status_at: Mapped[date | None]
|
||||
|
||||
sponsor_bioguide_id: Mapped[str | None]
|
||||
|
||||
subjects_top_term: Mapped[str | None]
|
||||
|
||||
votes: Mapped[list[Vote]] = relationship(
|
||||
"Vote",
|
||||
back_populates="bill",
|
||||
)
|
||||
bill_texts: Mapped[list[BillText]] = relationship(
|
||||
"BillText",
|
||||
back_populates="bill",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("congress", "bill_type", "number", name="uq_bill_congress_type_number"),
|
||||
Index("ix_bill_congress", "congress"),
|
||||
)
|
||||
|
||||
|
||||
class BillText(DataScienceDevTableBase):
|
||||
"""Stores different text versions of a bill (introduced, enrolled, etc.)."""
|
||||
|
||||
__tablename__ = "bill_text"
|
||||
|
||||
bill_id: Mapped[int] = mapped_column(ForeignKey("main.bill.id", ondelete="CASCADE"))
|
||||
version_code: Mapped[str]
|
||||
version_name: Mapped[str | None]
|
||||
text_content: Mapped[str | None]
|
||||
date: Mapped[date | None]
|
||||
|
||||
bill: Mapped[Bill] = relationship("Bill", back_populates="bill_texts")
|
||||
|
||||
__table_args__ = (UniqueConstraint("bill_id", "version_code", name="uq_bill_text_bill_id_version_code"),)
|
||||
@@ -0,0 +1,66 @@
|
||||
"""Legislator model - members of Congress."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import ForeignKey, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from python.orm.data_science_dev.base import DataScienceDevTableBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from python.orm.data_science_dev.congress.vote import VoteRecord
|
||||
|
||||
|
||||
class Legislator(DataScienceDevTableBase):
|
||||
"""Members of Congress with identification and current term info."""
|
||||
|
||||
__tablename__ = "legislator"
|
||||
|
||||
bioguide_id: Mapped[str] = mapped_column(Text, unique=True, index=True)
|
||||
|
||||
thomas_id: Mapped[str | None]
|
||||
lis_id: Mapped[str | None]
|
||||
govtrack_id: Mapped[int | None]
|
||||
opensecrets_id: Mapped[str | None]
|
||||
fec_ids: Mapped[str | None]
|
||||
|
||||
first_name: Mapped[str]
|
||||
last_name: Mapped[str]
|
||||
official_full_name: Mapped[str | None]
|
||||
nickname: Mapped[str | None]
|
||||
|
||||
birthday: Mapped[date | None]
|
||||
gender: Mapped[str | None]
|
||||
|
||||
current_party: Mapped[str | None]
|
||||
current_state: Mapped[str | None]
|
||||
current_district: Mapped[int | None]
|
||||
current_chamber: Mapped[str | None]
|
||||
|
||||
social_media_accounts: Mapped[list[LegislatorSocialMedia]] = relationship(
|
||||
"LegislatorSocialMedia",
|
||||
back_populates="legislator",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
vote_records: Mapped[list[VoteRecord]] = relationship(
|
||||
"VoteRecord",
|
||||
back_populates="legislator",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
|
||||
class LegislatorSocialMedia(DataScienceDevTableBase):
|
||||
"""Social media account linked to a legislator."""
|
||||
|
||||
__tablename__ = "legislator_social_media"
|
||||
|
||||
legislator_id: Mapped[int] = mapped_column(ForeignKey("main.legislator.id"))
|
||||
platform: Mapped[str]
|
||||
account_name: Mapped[str]
|
||||
url: Mapped[str | None]
|
||||
source: Mapped[str]
|
||||
|
||||
legislator: Mapped[Legislator] = relationship(back_populates="social_media_accounts")
|
||||
@@ -0,0 +1,79 @@
|
||||
"""Vote model - roll call votes in Congress."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import ForeignKey, Index, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from python.orm.data_science_dev.base import DataScienceDevBase, DataScienceDevTableBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from python.orm.data_science_dev.congress.bill import Bill
|
||||
from python.orm.data_science_dev.congress.legislator import Legislator
|
||||
from python.orm.data_science_dev.congress.vote import Vote
|
||||
|
||||
|
||||
class VoteRecord(DataScienceDevBase):
|
||||
"""Links a vote to a legislator with their position (Yea, Nay, etc.)."""
|
||||
|
||||
__tablename__ = "vote_record"
|
||||
|
||||
vote_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("main.vote.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
)
|
||||
legislator_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("main.legislator.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
)
|
||||
position: Mapped[str]
|
||||
|
||||
vote: Mapped[Vote] = relationship("Vote", back_populates="vote_records")
|
||||
legislator: Mapped[Legislator] = relationship("Legislator", back_populates="vote_records")
|
||||
|
||||
|
||||
class Vote(DataScienceDevTableBase):
|
||||
"""Roll call votes with counts and optional bill linkage."""
|
||||
|
||||
__tablename__ = "vote"
|
||||
|
||||
congress: Mapped[int]
|
||||
chamber: Mapped[str]
|
||||
session: Mapped[int]
|
||||
number: Mapped[int]
|
||||
|
||||
vote_type: Mapped[str | None]
|
||||
question: Mapped[str | None]
|
||||
result: Mapped[str | None]
|
||||
result_text: Mapped[str | None]
|
||||
|
||||
vote_date: Mapped[date]
|
||||
|
||||
yea_count: Mapped[int | None]
|
||||
nay_count: Mapped[int | None]
|
||||
not_voting_count: Mapped[int | None]
|
||||
present_count: Mapped[int | None]
|
||||
|
||||
bill_id: Mapped[int | None] = mapped_column(ForeignKey("main.bill.id"))
|
||||
|
||||
bill: Mapped[Bill | None] = relationship("Bill", back_populates="votes")
|
||||
vote_records: Mapped[list[VoteRecord]] = relationship(
|
||||
"VoteRecord",
|
||||
back_populates="vote",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"congress",
|
||||
"chamber",
|
||||
"session",
|
||||
"number",
|
||||
name="uq_vote_congress_chamber_session_number",
|
||||
),
|
||||
Index("ix_vote_date", "vote_date"),
|
||||
Index("ix_vote_congress_chamber", "congress", "chamber"),
|
||||
)
|
||||
@@ -0,0 +1,16 @@
|
||||
"""Data science dev database ORM models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from python.orm.data_science_dev.congress import Bill, BillText, Legislator, Vote, VoteRecord
|
||||
from python.orm.data_science_dev.posts import partitions # noqa: F401 — registers partition classes in metadata
|
||||
from python.orm.data_science_dev.posts.tables import Posts
|
||||
|
||||
__all__ = [
|
||||
"Bill",
|
||||
"BillText",
|
||||
"Legislator",
|
||||
"Posts",
|
||||
"Vote",
|
||||
"VoteRecord",
|
||||
]
|
||||
@@ -0,0 +1,11 @@
|
||||
"""Posts module — weekly-partitioned posts table and partition ORM models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from python.orm.data_science_dev.posts.failed_ingestion import FailedIngestion
|
||||
from python.orm.data_science_dev.posts.tables import Posts
|
||||
|
||||
__all__ = [
|
||||
"FailedIngestion",
|
||||
"Posts",
|
||||
]
|
||||
@@ -0,0 +1,33 @@
|
||||
"""Shared column definitions for the posts partitioned table family."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import BigInteger, SmallInteger, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
|
||||
class PostsColumns:
|
||||
"""Mixin providing all posts columns. Used by both the parent table and partitions."""
|
||||
|
||||
post_id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
|
||||
user_id: Mapped[int] = mapped_column(BigInteger)
|
||||
instance: Mapped[str]
|
||||
date: Mapped[datetime] = mapped_column(primary_key=True)
|
||||
text: Mapped[str] = mapped_column(Text)
|
||||
langs: Mapped[str | None]
|
||||
like_count: Mapped[int]
|
||||
reply_count: Mapped[int]
|
||||
repost_count: Mapped[int]
|
||||
reply_to: Mapped[int | None] = mapped_column(BigInteger)
|
||||
replied_author: Mapped[int | None] = mapped_column(BigInteger)
|
||||
thread_root: Mapped[int | None] = mapped_column(BigInteger)
|
||||
thread_root_author: Mapped[int | None] = mapped_column(BigInteger)
|
||||
repost_from: Mapped[int | None] = mapped_column(BigInteger)
|
||||
reposted_author: Mapped[int | None] = mapped_column(BigInteger)
|
||||
quotes: Mapped[int | None] = mapped_column(BigInteger)
|
||||
quoted_author: Mapped[int | None] = mapped_column(BigInteger)
|
||||
labels: Mapped[str | None]
|
||||
sent_label: Mapped[int | None] = mapped_column(SmallInteger)
|
||||
sent_score: Mapped[float | None]
|
||||
@@ -0,0 +1,17 @@
|
||||
"""Table for storing JSONL lines that failed during post ingestion."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from python.orm.data_science_dev.base import DataScienceDevTableBase
|
||||
|
||||
|
||||
class FailedIngestion(DataScienceDevTableBase):
|
||||
"""Stores raw JSONL lines and their error messages when ingestion fails."""
|
||||
|
||||
__tablename__ = "failed_ingestion"
|
||||
|
||||
raw_line: Mapped[str] = mapped_column(Text)
|
||||
error: Mapped[str] = mapped_column(Text)
|
||||
@@ -0,0 +1,71 @@
|
||||
"""Dynamically generated ORM classes for each weekly partition of the posts table.
|
||||
|
||||
Each class maps to a PostgreSQL partition table (e.g. posts_2024_01).
|
||||
These are real ORM models tracked by Alembic autogenerate.
|
||||
|
||||
Uses ISO week numbering (datetime.isocalendar().week). ISO years can have
|
||||
52 or 53 weeks, and week boundaries are always Monday to Monday.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from python.orm.data_science_dev.base import DataScienceDevBase
|
||||
from python.orm.data_science_dev.posts.columns import PostsColumns
|
||||
|
||||
PARTITION_START_YEAR = 2023
|
||||
PARTITION_END_YEAR = 2026
|
||||
|
||||
_current_module = sys.modules[__name__]
|
||||
|
||||
|
||||
def iso_weeks_in_year(year: int) -> int:
|
||||
"""Return the number of ISO weeks in a given year (52 or 53)."""
|
||||
dec_28 = datetime(year, 12, 28, tzinfo=UTC)
|
||||
return dec_28.isocalendar().week
|
||||
|
||||
|
||||
def week_bounds(year: int, week: int) -> tuple[datetime, datetime]:
|
||||
"""Return (start, end) datetimes for an ISO week.
|
||||
|
||||
Start = Monday 00:00:00 UTC of the given ISO week.
|
||||
End = Monday 00:00:00 UTC of the following ISO week.
|
||||
"""
|
||||
start = datetime.fromisocalendar(year, week, 1).replace(tzinfo=UTC)
|
||||
if week < iso_weeks_in_year(year):
|
||||
end = datetime.fromisocalendar(year, week + 1, 1).replace(tzinfo=UTC)
|
||||
else:
|
||||
end = datetime.fromisocalendar(year + 1, 1, 1).replace(tzinfo=UTC)
|
||||
return start, end
|
||||
|
||||
|
||||
def _build_partition_classes() -> dict[str, type]:
|
||||
"""Generate one ORM class per ISO week partition."""
|
||||
classes: dict[str, type] = {}
|
||||
|
||||
for year in range(PARTITION_START_YEAR, PARTITION_END_YEAR + 1):
|
||||
for week in range(1, iso_weeks_in_year(year) + 1):
|
||||
class_name = f"PostsWeek{year}W{week:02d}"
|
||||
table_name = f"posts_{year}_{week:02d}"
|
||||
|
||||
partition_class = type(
|
||||
class_name,
|
||||
(PostsColumns, DataScienceDevBase),
|
||||
{
|
||||
"__tablename__": table_name,
|
||||
"__table_args__": ({"implicit_returning": False},),
|
||||
},
|
||||
)
|
||||
|
||||
classes[class_name] = partition_class
|
||||
|
||||
return classes
|
||||
|
||||
|
||||
# Generate all partition classes and register them on this module
|
||||
_partition_classes = _build_partition_classes()
|
||||
for _name, _cls in _partition_classes.items():
|
||||
setattr(_current_module, _name, _cls)
|
||||
__all__ = list(_partition_classes.keys())
|
||||
@@ -0,0 +1,13 @@
|
||||
"""Posts parent table with PostgreSQL weekly range partitioning on date column."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from python.orm.data_science_dev.base import DataScienceDevBase
|
||||
from python.orm.data_science_dev.posts.columns import PostsColumns
|
||||
|
||||
|
||||
class Posts(PostsColumns, DataScienceDevBase):
|
||||
"""Parent partitioned table for posts, partitioned by week on `date`."""
|
||||
|
||||
__tablename__ = "posts"
|
||||
__table_args__ = ({"postgresql_partition_by": "RANGE (date)"},)
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from python.orm.richie.audiobook import Audiobook, AudiobookAuthor, AudiobookSeries
|
||||
from python.orm.richie.base import RichieBase, TableBase, TableBaseBig, TableBaseSmall
|
||||
from python.orm.richie.contact import (
|
||||
Contact,
|
||||
@@ -11,30 +10,11 @@ from python.orm.richie.contact import (
|
||||
Need,
|
||||
RelationshipType,
|
||||
)
|
||||
from python.orm.richie.ebook import (
|
||||
EbookChapter,
|
||||
EbookChunk,
|
||||
EbookChunkEmbedding1024,
|
||||
EbookChunkEmbedding2560,
|
||||
EbookChunkEmbedding4096,
|
||||
EbookEmbeddingModel,
|
||||
EbookSource,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Audiobook",
|
||||
"AudiobookAuthor",
|
||||
"AudiobookSeries",
|
||||
"Contact",
|
||||
"ContactNeed",
|
||||
"ContactRelationship",
|
||||
"EbookChapter",
|
||||
"EbookChunk",
|
||||
"EbookChunkEmbedding1024",
|
||||
"EbookChunkEmbedding2560",
|
||||
"EbookChunkEmbedding4096",
|
||||
"EbookEmbeddingModel",
|
||||
"EbookSource",
|
||||
"Need",
|
||||
"RelationshipType",
|
||||
"RichieBase",
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
"""Audiobook catalog models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import ForeignKey, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from python.orm.richie.base import TableBase
|
||||
|
||||
|
||||
class AudiobookAuthor(TableBase):
|
||||
"""Canonical audiobook author."""
|
||||
|
||||
__tablename__ = "audiobook_author"
|
||||
__table_args__ = (UniqueConstraint("name"),)
|
||||
|
||||
name: Mapped[str] = mapped_column(String, unique=True)
|
||||
|
||||
books: Mapped[list[Audiobook]] = relationship("Audiobook", back_populates="author")
|
||||
series: Mapped[list[AudiobookSeries]] = relationship("AudiobookSeries", back_populates="author")
|
||||
|
||||
|
||||
class AudiobookSeries(TableBase):
|
||||
"""Canonical audiobook series."""
|
||||
|
||||
__tablename__ = "audiobook_series"
|
||||
__table_args__ = (UniqueConstraint("author_id", "name"),)
|
||||
|
||||
name: Mapped[str] = mapped_column(String)
|
||||
author_id: Mapped[int] = mapped_column(ForeignKey("main.audiobook_author.id", ondelete="CASCADE"))
|
||||
|
||||
author: Mapped[AudiobookAuthor] = relationship("AudiobookAuthor", back_populates="series")
|
||||
books: Mapped[list[Audiobook]] = relationship("Audiobook", back_populates="series")
|
||||
|
||||
|
||||
class Audiobook(TableBase):
|
||||
"""Canonical audiobook title."""
|
||||
|
||||
__tablename__ = "audiobook"
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"author_id",
|
||||
"series_id",
|
||||
"title",
|
||||
postgresql_nulls_not_distinct=True,
|
||||
),
|
||||
)
|
||||
|
||||
title: Mapped[str] = mapped_column(String)
|
||||
author_id: Mapped[int] = mapped_column(ForeignKey("main.audiobook_author.id", ondelete="CASCADE"))
|
||||
series_id: Mapped[int | None] = mapped_column(ForeignKey("main.audiobook_series.id", ondelete="SET NULL"))
|
||||
series_index: Mapped[float] = mapped_column(default=0.0)
|
||||
|
||||
author: Mapped[AudiobookAuthor] = relationship("AudiobookAuthor", back_populates="books")
|
||||
series: Mapped[AudiobookSeries | None] = relationship("AudiobookSeries", back_populates="books")
|
||||
@@ -1,138 +0,0 @@
|
||||
"""EPUB search models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from sqlalchemy import BigInteger, Boolean, DateTime, ForeignKey, Index, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from python.orm.richie.base import TableBase, TableBaseBig
|
||||
|
||||
|
||||
class EbookSource(TableBase):
|
||||
"""One indexed EPUB file."""
|
||||
|
||||
__tablename__ = "ebook_source"
|
||||
__table_args__ = (
|
||||
UniqueConstraint("file_path"),
|
||||
UniqueConstraint("file_sha256"),
|
||||
)
|
||||
|
||||
title: Mapped[str]
|
||||
author: Mapped[str | None]
|
||||
language: Mapped[str | None]
|
||||
publisher: Mapped[str | None]
|
||||
identifier: Mapped[str | None]
|
||||
file_path: Mapped[str]
|
||||
file_sha256: Mapped[str] = mapped_column(String(64))
|
||||
file_mtime: Mapped[datetime] = mapped_column(DateTime(timezone=True))
|
||||
file_size: Mapped[int] = mapped_column(BigInteger)
|
||||
|
||||
chapters: Mapped[list[EbookChapter]] = relationship(
|
||||
"EbookChapter",
|
||||
back_populates="source",
|
||||
cascade="all, delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
chunks: Mapped[list[EbookChunk]] = relationship(
|
||||
"EbookChunk",
|
||||
back_populates="source",
|
||||
cascade="all, delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
|
||||
class EbookChapter(TableBase):
|
||||
"""A chapter or spine document inside an EPUB."""
|
||||
|
||||
__tablename__ = "ebook_chapter"
|
||||
__table_args__ = (UniqueConstraint("source_id", "spine_index"),)
|
||||
|
||||
source_id: Mapped[int] = mapped_column(ForeignKey("main.ebook_source.id", ondelete="CASCADE"))
|
||||
spine_index: Mapped[int]
|
||||
title: Mapped[str | None]
|
||||
href: Mapped[str | None]
|
||||
|
||||
source: Mapped[EbookSource] = relationship("EbookSource", back_populates="chapters")
|
||||
chunks: Mapped[list[EbookChunk]] = relationship(
|
||||
"EbookChunk",
|
||||
back_populates="chapter",
|
||||
cascade="all, delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
|
||||
class EbookChunk(TableBaseBig):
|
||||
"""A searchable text chunk."""
|
||||
|
||||
__tablename__ = "ebook_chunk"
|
||||
__table_args__ = (
|
||||
UniqueConstraint("source_id", "chunk_index", name="uq_ebook_chunk_source_id_chunk_index"),
|
||||
UniqueConstraint("source_id", "content_sha256", name="uq_ebook_chunk_source_id_content_sha256"),
|
||||
)
|
||||
|
||||
source_id: Mapped[int] = mapped_column(ForeignKey("main.ebook_source.id", ondelete="CASCADE"))
|
||||
chapter_id: Mapped[int | None] = mapped_column(ForeignKey("main.ebook_chapter.id", ondelete="SET NULL"))
|
||||
chunk_index: Mapped[int]
|
||||
text: Mapped[str]
|
||||
token_start: Mapped[int]
|
||||
token_count: Mapped[int]
|
||||
page_label: Mapped[str | None]
|
||||
content_sha256: Mapped[str] = mapped_column(String(64))
|
||||
search_text: Mapped[str]
|
||||
|
||||
source: Mapped[EbookSource] = relationship("EbookSource", back_populates="chunks")
|
||||
chapter: Mapped[EbookChapter | None] = relationship("EbookChapter", back_populates="chunks")
|
||||
|
||||
|
||||
class EbookEmbeddingModel(TableBase):
|
||||
"""A supported embedding model."""
|
||||
|
||||
__tablename__ = "ebook_embedding_model"
|
||||
|
||||
name: Mapped[str] = mapped_column(String, unique=True)
|
||||
dimension: Mapped[int]
|
||||
is_default: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
|
||||
class EbookChunkEmbedding1024(TableBaseBig):
|
||||
"""1024-dimensional chunk embedding."""
|
||||
|
||||
__tablename__ = "ebook_chunk_embedding_1024"
|
||||
__table_args__ = (
|
||||
UniqueConstraint("chunk_id", "model_id"),
|
||||
Index(
|
||||
"ix_ebook_chunk_embedding_1024_embedding_cosine",
|
||||
"embedding",
|
||||
postgresql_using="hnsw",
|
||||
postgresql_ops={"embedding": "vector_cosine_ops"},
|
||||
),
|
||||
)
|
||||
|
||||
chunk_id: Mapped[int] = mapped_column(ForeignKey("main.ebook_chunk.id", ondelete="CASCADE"))
|
||||
model_id: Mapped[int] = mapped_column(ForeignKey("main.ebook_embedding_model.id", ondelete="CASCADE"))
|
||||
embedding: Mapped[list[float]] = mapped_column(Vector(1024))
|
||||
|
||||
|
||||
class EbookChunkEmbedding2560(TableBaseBig):
|
||||
"""2560-dimensional chunk embedding."""
|
||||
|
||||
__tablename__ = "ebook_chunk_embedding_2560"
|
||||
__table_args__ = (UniqueConstraint("chunk_id", "model_id"),)
|
||||
|
||||
chunk_id: Mapped[int] = mapped_column(ForeignKey("main.ebook_chunk.id", ondelete="CASCADE"))
|
||||
model_id: Mapped[int] = mapped_column(ForeignKey("main.ebook_embedding_model.id", ondelete="CASCADE"))
|
||||
embedding: Mapped[list[float]] = mapped_column(Vector(2560))
|
||||
|
||||
|
||||
class EbookChunkEmbedding4096(TableBaseBig):
|
||||
"""4096-dimensional chunk embedding."""
|
||||
|
||||
__tablename__ = "ebook_chunk_embedding_4096"
|
||||
__table_args__ = (UniqueConstraint("chunk_id", "model_id"),)
|
||||
|
||||
chunk_id: Mapped[int] = mapped_column(ForeignKey("main.ebook_chunk.id", ondelete="CASCADE"))
|
||||
model_id: Mapped[int] = mapped_column(ForeignKey("main.ebook_embedding_model.id", ondelete="CASCADE"))
|
||||
embedding: Mapped[list[float]] = mapped_column(Vector(4096))
|
||||
@@ -0,0 +1,16 @@
|
||||
"""Signal bot database ORM exports."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from python.orm.signal_bot.base import SignalBotBase, SignalBotTableBase, SignalBotTableBaseSmall
|
||||
from python.orm.signal_bot.models import DeadLetterMessage, DeviceRole, RoleRecord, SignalDevice
|
||||
|
||||
__all__ = [
|
||||
"DeadLetterMessage",
|
||||
"DeviceRole",
|
||||
"RoleRecord",
|
||||
"SignalBotBase",
|
||||
"SignalBotTableBase",
|
||||
"SignalBotTableBaseSmall",
|
||||
"SignalDevice",
|
||||
]
|
||||
@@ -0,0 +1,52 @@
|
||||
"""Signal bot database ORM base."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, MetaData, SmallInteger, func
|
||||
from sqlalchemy.ext.declarative import AbstractConcreteBase
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
from python.orm.common import NAMING_CONVENTION
|
||||
|
||||
|
||||
class SignalBotBase(DeclarativeBase):
|
||||
"""Base class for signal_bot database ORM models."""
|
||||
|
||||
schema_name = "main"
|
||||
|
||||
metadata = MetaData(
|
||||
schema=schema_name,
|
||||
naming_convention=NAMING_CONVENTION,
|
||||
)
|
||||
|
||||
|
||||
class _TableMixin:
|
||||
"""Shared timestamp columns for all table bases."""
|
||||
|
||||
created: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
)
|
||||
updated: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
|
||||
|
||||
class SignalBotTableBaseSmall(_TableMixin, AbstractConcreteBase, SignalBotBase):
|
||||
"""Table with SmallInteger primary key."""
|
||||
|
||||
__abstract__ = True
|
||||
|
||||
id: Mapped[int] = mapped_column(SmallInteger, primary_key=True)
|
||||
|
||||
|
||||
class SignalBotTableBase(_TableMixin, AbstractConcreteBase, SignalBotBase):
|
||||
"""Table with Integer primary key."""
|
||||
|
||||
__abstract__ = True
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
@@ -0,0 +1,62 @@
|
||||
"""Signal bot device, role, and dead letter ORM models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, Enum, ForeignKey, SmallInteger, String, Text, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from python.orm.signal_bot.base import SignalBotTableBase, SignalBotTableBaseSmall
|
||||
from python.signal_bot.models import MessageStatus, TrustLevel
|
||||
|
||||
|
||||
class RoleRecord(SignalBotTableBaseSmall):
|
||||
"""Lookup table for RBAC roles, keyed by smallint."""
|
||||
|
||||
__tablename__ = "role"
|
||||
|
||||
name: Mapped[str] = mapped_column(String(50), unique=True)
|
||||
|
||||
|
||||
class DeviceRole(SignalBotTableBase):
|
||||
"""Association between a device and a role."""
|
||||
|
||||
__tablename__ = "device_role"
|
||||
__table_args__ = (
|
||||
UniqueConstraint("device_id", "role_id", name="uq_device_role_device_role"),
|
||||
{"schema": "main"},
|
||||
)
|
||||
|
||||
device_id: Mapped[int] = mapped_column(ForeignKey("main.signal_device.id"))
|
||||
role_id: Mapped[int] = mapped_column(SmallInteger, ForeignKey("main.role.id"))
|
||||
|
||||
|
||||
class SignalDevice(SignalBotTableBase):
|
||||
"""A Signal device tracked by phone number and safety number."""
|
||||
|
||||
__tablename__ = "signal_device"
|
||||
|
||||
phone_number: Mapped[str] = mapped_column(String(50), unique=True)
|
||||
safety_number: Mapped[str | None]
|
||||
trust_level: Mapped[TrustLevel] = mapped_column(
|
||||
Enum(TrustLevel, name="trust_level", create_constraint=False, native_enum=False),
|
||||
default=TrustLevel.UNVERIFIED,
|
||||
)
|
||||
last_seen: Mapped[datetime] = mapped_column(DateTime(timezone=True))
|
||||
|
||||
roles: Mapped[list[RoleRecord]] = relationship(secondary=DeviceRole.__table__)
|
||||
|
||||
|
||||
class DeadLetterMessage(SignalBotTableBase):
|
||||
"""A Signal message that failed processing and was sent to the dead letter queue."""
|
||||
|
||||
__tablename__ = "dead_letter_message"
|
||||
|
||||
source: Mapped[str]
|
||||
message: Mapped[str] = mapped_column(Text)
|
||||
received_at: Mapped[datetime] = mapped_column(DateTime(timezone=True))
|
||||
status: Mapped[MessageStatus] = mapped_column(
|
||||
Enum(MessageStatus, name="message_status", create_constraint=False, native_enum=False),
|
||||
default=MessageStatus.UNPROCESSED,
|
||||
)
|
||||
@@ -0,0 +1 @@
|
||||
"""Signal command and control bot."""
|
||||
@@ -0,0 +1 @@
|
||||
"""Signal bot commands."""
|
||||
@@ -0,0 +1,137 @@
|
||||
"""Van inventory command — parse receipts and item lists via LLM, push to API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import httpx
|
||||
|
||||
from python.signal_bot.models import InventoryItem, InventoryUpdate
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from python.signal_bot.llm_client import LLMClient
|
||||
from python.signal_bot.models import SignalMessage
|
||||
from python.signal_bot.signal_client import SignalClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SYSTEM_PROMPT = """\
|
||||
You are an inventory assistant. Extract items from the input and return ONLY
|
||||
a JSON array. Each element must have these fields:
|
||||
- "name": item name (string)
|
||||
- "quantity": numeric count or amount (default 1)
|
||||
- "unit": unit of measure (e.g. "each", "lb", "oz", "gallon", "bag", "box")
|
||||
- "category": category like "food", "tools", "supplies", etc.
|
||||
- "notes": any extra detail (empty string if none)
|
||||
|
||||
Example output:
|
||||
[{"name": "water bottles", "quantity": 6, "unit": "gallon", "category": "supplies", "notes": "1 gallon each"}]
|
||||
|
||||
Return ONLY the JSON array, no other text.\
|
||||
"""
|
||||
|
||||
IMAGE_PROMPT = "Extract all items from this receipt or inventory photo."
|
||||
TEXT_PROMPT = "Extract all items from this inventory list."
|
||||
|
||||
|
||||
def parse_llm_response(raw: str) -> list[InventoryItem]:
|
||||
"""Parse the LLM JSON response into InventoryItem list."""
|
||||
text = raw.strip()
|
||||
# Strip markdown code fences if present
|
||||
if text.startswith("```"):
|
||||
lines = text.split("\n")
|
||||
lines = [line for line in lines if not line.startswith("```")]
|
||||
text = "\n".join(lines)
|
||||
|
||||
items_data: list[dict[str, Any]] = json.loads(text)
|
||||
return [InventoryItem.model_validate(item) for item in items_data]
|
||||
|
||||
|
||||
def _upsert_item(api_url: str, item: InventoryItem) -> None:
|
||||
"""Create or update an item via the van_inventory API.
|
||||
|
||||
Fetches existing items, and if one with the same name exists,
|
||||
patches its quantity (summing). Otherwise creates a new item.
|
||||
"""
|
||||
base = api_url.rstrip("/")
|
||||
response = httpx.get(f"{base}/api/items", timeout=10)
|
||||
response.raise_for_status()
|
||||
existing: list[dict[str, Any]] = response.json()
|
||||
|
||||
match = next((e for e in existing if e["name"].lower() == item.name.lower()), None)
|
||||
|
||||
if match:
|
||||
new_qty = match["quantity"] + item.quantity
|
||||
patch = {"quantity": new_qty}
|
||||
if item.category:
|
||||
patch["category"] = item.category
|
||||
response = httpx.patch(f"{base}/api/items/{match['id']}", json=patch, timeout=10)
|
||||
response.raise_for_status()
|
||||
return
|
||||
payload = {
|
||||
"name": item.name,
|
||||
"quantity": item.quantity,
|
||||
"unit": item.unit,
|
||||
"category": item.category or None,
|
||||
}
|
||||
response = httpx.post(f"{base}/api/items", json=payload, timeout=10)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
def handle_inventory_update(
|
||||
message: SignalMessage,
|
||||
signal: SignalClient,
|
||||
llm: LLMClient,
|
||||
api_url: str,
|
||||
) -> InventoryUpdate:
|
||||
"""Process an inventory update from a Signal message.
|
||||
|
||||
Accepts either an image (receipt photo) or text list.
|
||||
Uses the LLM to extract structured items, then pushes to the van_inventory API.
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Processing inventory update from {message.source}")
|
||||
if message.attachments:
|
||||
image_data = signal.get_attachment(message.attachments[0])
|
||||
raw_response = llm.chat(
|
||||
IMAGE_PROMPT,
|
||||
image_data=image_data,
|
||||
system=SYSTEM_PROMPT,
|
||||
)
|
||||
source_type = "receipt_photo"
|
||||
elif message.message.strip():
|
||||
raw_response = llm.chat(
|
||||
f"{TEXT_PROMPT}\n\n{message.message}",
|
||||
system=SYSTEM_PROMPT,
|
||||
)
|
||||
source_type = "text_list"
|
||||
else:
|
||||
signal.reply(message, "Send a photo of a receipt or a text list of items to update inventory.")
|
||||
return InventoryUpdate()
|
||||
|
||||
logger.info(f"{raw_response=}")
|
||||
|
||||
new_items = parse_llm_response(raw_response)
|
||||
|
||||
logger.info(f"{new_items=}")
|
||||
|
||||
for item in new_items:
|
||||
_upsert_item(api_url, item)
|
||||
|
||||
summary = _format_summary(new_items)
|
||||
signal.reply(message, f"Inventory updated with {len(new_items)} item(s):\n{summary}")
|
||||
|
||||
return InventoryUpdate(items=new_items, raw_response=raw_response, source_type=source_type)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to process inventory update")
|
||||
signal.reply(message, "Failed to process inventory update. Check logs for details.")
|
||||
return InventoryUpdate()
|
||||
|
||||
|
||||
def _format_summary(items: list[InventoryItem]) -> str:
|
||||
"""Format items into a readable summary."""
|
||||
lines = [f" - {item.name} x{item.quantity} {item.unit} [{item.category}]" for item in items]
|
||||
return "\n".join(lines)
|
||||
@@ -0,0 +1,64 @@
|
||||
"""Location command for the Signal bot."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import httpx
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from python.signal_bot.models import SignalMessage
|
||||
from python.signal_bot.signal_client import SignalClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_entity_state(ha_url: str, ha_token: str, entity_id: str) -> dict[str, Any]:
|
||||
"""Fetch an entity's state from Home Assistant."""
|
||||
entity_url = f"{ha_url}/api/states/{entity_id}"
|
||||
logger.debug(f"Fetching {entity_url=}")
|
||||
response = httpx.get(
|
||||
entity_url,
|
||||
headers={"Authorization": f"Bearer {ha_token}"},
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
def _format_location(latitude: str, longitude: str) -> str:
|
||||
"""Render a friendly location response."""
|
||||
return f"Van location: {latitude}, {longitude}\nhttps://maps.google.com/?q={latitude},{longitude}"
|
||||
|
||||
|
||||
def handle_location_request(
|
||||
message: SignalMessage,
|
||||
signal: SignalClient,
|
||||
ha_url: str | None,
|
||||
ha_token: str | None,
|
||||
) -> None:
|
||||
"""Reply with van location from Home Assistant."""
|
||||
if ha_url is None or ha_token is None:
|
||||
signal.reply(message, "Location command is not configured (missing HA_URL or HA_TOKEN).")
|
||||
return
|
||||
|
||||
lat_payload = None
|
||||
lon_payload = None
|
||||
try:
|
||||
lat_payload = _get_entity_state(ha_url, ha_token, "sensor.van_last_known_latitude")
|
||||
lon_payload = _get_entity_state(ha_url, ha_token, "sensor.van_last_known_longitude")
|
||||
except httpx.HTTPError:
|
||||
logger.exception("Couldn't fetch van location from Home Assistant right now.")
|
||||
logger.debug(f"{ha_url=} {lat_payload=} {lon_payload=}")
|
||||
signal.reply(message, "Couldn't fetch van location from Home Assistant right now.")
|
||||
return
|
||||
|
||||
latitude = lat_payload.get("state", "")
|
||||
longitude = lon_payload.get("state", "")
|
||||
|
||||
if not latitude or not longitude or latitude == "unavailable" or longitude == "unavailable":
|
||||
signal.reply(message, "Van location is unavailable in Home Assistant right now.")
|
||||
return
|
||||
|
||||
signal.reply(message, _format_location(latitude, longitude))
|
||||
@@ -0,0 +1,284 @@
|
||||
"""Device registry — tracks verified/unverified devices by safety number."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import TYPE_CHECKING, NamedTuple
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from python.common import utcnow
|
||||
from python.orm.signal_bot.models import RoleRecord, SignalDevice
|
||||
from python.signal_bot.models import Role, TrustLevel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
from python.signal_bot.signal_client import SignalClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_BLOCKED_TTL = timedelta(minutes=60)
|
||||
_DEFAULT_TTL = timedelta(minutes=5)
|
||||
|
||||
|
||||
class _CacheEntry(NamedTuple):
|
||||
expires: datetime
|
||||
trust_level: TrustLevel
|
||||
has_safety_number: bool
|
||||
safety_number: str | None
|
||||
roles: list[Role]
|
||||
|
||||
|
||||
class DeviceRegistry:
|
||||
"""Manage device trust based on Signal safety numbers.
|
||||
|
||||
Devices start as UNVERIFIED. An admin verifies them over SSH by calling
|
||||
``verify(phone_number)`` which marks the device VERIFIED and also tells
|
||||
signal-cli to trust the identity.
|
||||
|
||||
Only VERIFIED devices may execute commands.
|
||||
"""
|
||||
|
||||
def __init__(self, signal_client: SignalClient, engine: Engine) -> None:
|
||||
self.signal_client = signal_client
|
||||
self.engine = engine
|
||||
self._contact_cache: dict[str, _CacheEntry] = {}
|
||||
|
||||
def is_verified(self, phone_number: str) -> bool:
|
||||
"""Check if a phone number is verified."""
|
||||
if entry := self._cached(phone_number):
|
||||
return entry.trust_level == TrustLevel.VERIFIED
|
||||
device = self._load_device(phone_number)
|
||||
return device is not None and device.trust_level == TrustLevel.VERIFIED
|
||||
|
||||
def record_contact(self, phone_number: str, safety_number: str | None = None) -> None:
|
||||
"""Record seeing a device. Creates entry if new, updates last_seen."""
|
||||
now = utcnow()
|
||||
|
||||
entry = self._cached(phone_number)
|
||||
if entry and entry.safety_number == safety_number:
|
||||
return
|
||||
|
||||
with Session(self.engine) as session:
|
||||
device = session.scalars(
|
||||
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
|
||||
).one_or_none()
|
||||
|
||||
if device:
|
||||
if device.safety_number != safety_number and device.trust_level != TrustLevel.BLOCKED:
|
||||
logger.warning(f"Safety number changed for {phone_number}, resetting to UNVERIFIED")
|
||||
device.safety_number = safety_number
|
||||
device.trust_level = TrustLevel.UNVERIFIED
|
||||
device.last_seen = now
|
||||
else:
|
||||
device = SignalDevice(
|
||||
phone_number=phone_number,
|
||||
safety_number=safety_number,
|
||||
trust_level=TrustLevel.UNVERIFIED,
|
||||
last_seen=now,
|
||||
)
|
||||
session.add(device)
|
||||
logger.info(f"New device registered: {phone_number}")
|
||||
|
||||
session.commit()
|
||||
self._update_cache(phone_number, device)
|
||||
|
||||
def has_safety_number(self, phone_number: str) -> bool:
|
||||
"""Check if a device has a safety number on file."""
|
||||
if entry := self._cached(phone_number):
|
||||
return entry.has_safety_number
|
||||
device = self._load_device(phone_number)
|
||||
return device is not None and device.safety_number is not None
|
||||
|
||||
def verify(self, phone_number: str) -> bool:
|
||||
"""Mark a device as verified. Called by admin over SSH.
|
||||
|
||||
Returns True if the device was found and verified.
|
||||
"""
|
||||
with Session(self.engine) as session:
|
||||
device = session.scalars(
|
||||
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
|
||||
).one_or_none()
|
||||
|
||||
if not device:
|
||||
logger.warning(f"Cannot verify unknown device: {phone_number}")
|
||||
return False
|
||||
|
||||
device.trust_level = TrustLevel.VERIFIED
|
||||
self.signal_client.trust_identity(phone_number, trust_all_known_keys=True)
|
||||
session.commit()
|
||||
self._update_cache(phone_number, device)
|
||||
logger.info(f"Device verified: {phone_number}")
|
||||
return True
|
||||
|
||||
def block(self, phone_number: str) -> bool:
|
||||
"""Block a device."""
|
||||
return self._set_trust(phone_number, TrustLevel.BLOCKED, "Device blocked")
|
||||
|
||||
def unverify(self, phone_number: str) -> bool:
|
||||
"""Reset a device to unverified."""
|
||||
return self._set_trust(phone_number, TrustLevel.UNVERIFIED)
|
||||
|
||||
# -- role management ------------------------------------------------------
|
||||
|
||||
def get_roles(self, phone_number: str) -> list[Role]:
|
||||
"""Return the roles for a device, defaulting to empty."""
|
||||
if entry := self._cached(phone_number):
|
||||
return entry.roles
|
||||
device = self._load_device(phone_number)
|
||||
return _extract_roles(device) if device else []
|
||||
|
||||
def has_role(self, phone_number: str, role: Role) -> bool:
|
||||
"""Check if a device has a specific role or is admin."""
|
||||
roles = self.get_roles(phone_number)
|
||||
return Role.ADMIN in roles or role in roles
|
||||
|
||||
def grant_role(self, phone_number: str, role: Role) -> bool:
|
||||
"""Add a role to a device. Called by admin over SSH."""
|
||||
with Session(self.engine) as session:
|
||||
device = session.scalars(
|
||||
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
|
||||
).one_or_none()
|
||||
|
||||
if not device:
|
||||
logger.warning(f"Cannot grant role for unknown device: {phone_number}")
|
||||
return False
|
||||
|
||||
if any(record.name == role for record in device.roles):
|
||||
return True
|
||||
|
||||
role_record = session.scalars(select(RoleRecord).where(RoleRecord.name == role)).one_or_none()
|
||||
|
||||
if not role_record:
|
||||
logger.warning(f"Unknown role: {role}")
|
||||
return False
|
||||
|
||||
device.roles.append(role_record)
|
||||
session.commit()
|
||||
self._update_cache(phone_number, device)
|
||||
logger.info(f"Device {phone_number} granted role {role}")
|
||||
return True
|
||||
|
||||
def revoke_role(self, phone_number: str, role: Role) -> bool:
|
||||
"""Remove a role from a device. Called by admin over SSH."""
|
||||
with Session(self.engine) as session:
|
||||
device = session.scalars(
|
||||
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
|
||||
).one_or_none()
|
||||
|
||||
if not device:
|
||||
logger.warning(f"Cannot revoke role for unknown device: {phone_number}")
|
||||
return False
|
||||
|
||||
device.roles = [record for record in device.roles if record.name != role]
|
||||
session.commit()
|
||||
self._update_cache(phone_number, device)
|
||||
logger.info(f"Device {phone_number} revoked role {role}")
|
||||
return True
|
||||
|
||||
def set_roles(self, phone_number: str, roles: list[Role]) -> bool:
|
||||
"""Replace all roles for a device. Called by admin over SSH."""
|
||||
with Session(self.engine) as session:
|
||||
device = session.scalars(
|
||||
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
|
||||
).one_or_none()
|
||||
|
||||
if not device:
|
||||
logger.warning(f"Cannot set roles for unknown device: {phone_number}")
|
||||
return False
|
||||
|
||||
role_names = [str(role) for role in roles]
|
||||
records = session.scalars(select(RoleRecord).where(RoleRecord.name.in_(role_names))).all()
|
||||
device.roles = records
|
||||
session.commit()
|
||||
self._update_cache(phone_number, device)
|
||||
logger.info(f"Device {phone_number} roles set to {role_names}")
|
||||
return True
|
||||
|
||||
# -- queries --------------------------------------------------------------
|
||||
|
||||
def list_devices(self) -> list[SignalDevice]:
|
||||
"""Return all known devices."""
|
||||
with Session(self.engine) as session:
|
||||
return list(session.scalars(select(SignalDevice)).all())
|
||||
|
||||
def sync_identities(self) -> None:
|
||||
"""Pull identity list from signal-cli and record any new ones."""
|
||||
identities = self.signal_client.get_identities()
|
||||
for identity in identities:
|
||||
number = identity.get("number", "")
|
||||
safety = identity.get("safety_number", identity.get("fingerprint", ""))
|
||||
if number:
|
||||
self.record_contact(number, safety)
|
||||
|
||||
# -- internals ------------------------------------------------------------
|
||||
|
||||
def _cached(self, phone_number: str) -> _CacheEntry | None:
|
||||
"""Return the cache entry if it exists and hasn't expired."""
|
||||
entry = self._contact_cache.get(phone_number)
|
||||
if entry and utcnow() < entry.expires:
|
||||
return entry
|
||||
return None
|
||||
|
||||
def _load_device(self, phone_number: str) -> SignalDevice | None:
|
||||
"""Fetch a device by phone number (with joined roles)."""
|
||||
with Session(self.engine) as session:
|
||||
return session.scalars(select(SignalDevice).where(SignalDevice.phone_number == phone_number)).one_or_none()
|
||||
|
||||
def _update_cache(self, phone_number: str, device: SignalDevice) -> None:
|
||||
"""Refresh the cache entry for a device."""
|
||||
ttl = _BLOCKED_TTL if device.trust_level == TrustLevel.BLOCKED else _DEFAULT_TTL
|
||||
self._contact_cache[phone_number] = _CacheEntry(
|
||||
expires=utcnow() + ttl,
|
||||
trust_level=device.trust_level,
|
||||
has_safety_number=device.safety_number is not None,
|
||||
safety_number=device.safety_number,
|
||||
roles=_extract_roles(device),
|
||||
)
|
||||
|
||||
def _set_trust(self, phone_number: str, level: str, log_msg: str | None = None) -> bool:
|
||||
"""Update the trust level for a device."""
|
||||
with Session(self.engine) as session:
|
||||
device = session.scalars(
|
||||
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
|
||||
).one_or_none()
|
||||
|
||||
if not device:
|
||||
return False
|
||||
|
||||
device.trust_level = level
|
||||
session.commit()
|
||||
self._update_cache(phone_number, device)
|
||||
if log_msg:
|
||||
logger.info(f"{log_msg}: {phone_number}")
|
||||
return True
|
||||
|
||||
|
||||
def _extract_roles(device: SignalDevice) -> list[Role]:
|
||||
"""Convert a device's RoleRecord objects to a list of Role enums."""
|
||||
return [Role(record.name) for record in device.roles]
|
||||
|
||||
|
||||
def sync_roles(engine: Engine) -> None:
|
||||
"""Sync the Role enum to the role table, adding new and removing stale entries."""
|
||||
expected = {role.value for role in Role}
|
||||
|
||||
with Session(engine) as session:
|
||||
existing = set(session.scalars(select(RoleRecord.name)).all())
|
||||
|
||||
to_add = expected - existing
|
||||
to_remove = existing - expected
|
||||
|
||||
for name in to_add:
|
||||
session.add(RoleRecord(name=name))
|
||||
logger.info(f"Role added: {name}")
|
||||
|
||||
if to_remove:
|
||||
session.execute(delete(RoleRecord).where(RoleRecord.name.in_(to_remove)))
|
||||
for name in to_remove:
|
||||
logger.info(f"Role removed: {name}")
|
||||
|
||||
session.commit()
|
||||
@@ -0,0 +1,80 @@
|
||||
"""Flexible LLM client for ollama backends."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import logging
|
||||
from typing import Any, Self
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMClient:
|
||||
"""Talk to an ollama instance.
|
||||
|
||||
Args:
|
||||
model: Ollama model name.
|
||||
host: Ollama host.
|
||||
port: Ollama port.
|
||||
temperature: Sampling temperature.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model: str,
|
||||
host: str,
|
||||
port: int = 11434,
|
||||
temperature: float = 0.1,
|
||||
timeout: int = 300,
|
||||
) -> None:
|
||||
self.model = model
|
||||
self.temperature = temperature
|
||||
self._client = httpx.Client(base_url=f"http://{host}:{port}", timeout=timeout)
|
||||
|
||||
def chat(self, prompt: str, image_data: bytes | None = None, system: str | None = None) -> str:
|
||||
"""Send a text prompt and return the response."""
|
||||
messages: list[dict[str, Any]] = []
|
||||
if system:
|
||||
messages.append({"role": "system", "content": system})
|
||||
|
||||
user_msg = {"role": "user", "content": prompt}
|
||||
if image_data:
|
||||
user_msg["images"] = [base64.b64encode(image_data).decode()]
|
||||
|
||||
messages.append(user_msg)
|
||||
return self._generate(messages)
|
||||
|
||||
def _generate(self, messages: list[dict[str, Any]]) -> str:
|
||||
"""Call the ollama chat API."""
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"stream": False,
|
||||
"options": {"temperature": self.temperature},
|
||||
}
|
||||
logger.info(f"LLM request to {self.model}")
|
||||
response = self._client.post("/api/chat", json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data["message"]["content"]
|
||||
|
||||
def list_models(self) -> list[str]:
|
||||
"""List available models on the ollama instance."""
|
||||
response = self._client.get("/api/tags")
|
||||
response.raise_for_status()
|
||||
return [m["name"] for m in response.json().get("models", [])]
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
"""Enter the context manager."""
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: object) -> None:
|
||||
"""Close the HTTP client on exit."""
|
||||
self.close()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the HTTP client."""
|
||||
self._client.close()
|
||||
@@ -0,0 +1,239 @@
|
||||
"""Signal command and control bot — main entry point."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from os import getenv
|
||||
from typing import TYPE_CHECKING, Annotated
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
import typer
|
||||
from alembic.command import upgrade
|
||||
from sqlalchemy.orm import Session
|
||||
from tenacity import before_sleep_log, retry, stop_after_attempt, wait_exponential
|
||||
|
||||
from python.common import configure_logger, utcnow
|
||||
from python.database_cli import DATABASES
|
||||
from python.orm.common import get_postgres_engine
|
||||
from python.orm.signal_bot.models import DeadLetterMessage
|
||||
from python.signal_bot.commands.inventory import handle_inventory_update
|
||||
from python.signal_bot.commands.location import handle_location_request
|
||||
from python.signal_bot.device_registry import DeviceRegistry, sync_roles
|
||||
from python.signal_bot.llm_client import LLMClient
|
||||
from python.signal_bot.models import BotConfig, MessageStatus, Role, SignalMessage
|
||||
from python.signal_bot.signal_client import SignalClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class Command:
|
||||
"""A registered bot command."""
|
||||
|
||||
action: Callable[[SignalMessage, str], None]
|
||||
help_text: str
|
||||
role: Role | None # None = no role required (always allowed)
|
||||
|
||||
|
||||
class Bot:
|
||||
"""Holds shared resources and dispatches incoming messages to command handlers."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
signal: SignalClient,
|
||||
llm: LLMClient,
|
||||
registry: DeviceRegistry,
|
||||
config: BotConfig,
|
||||
) -> None:
|
||||
self.signal = signal
|
||||
self.llm = llm
|
||||
self.registry = registry
|
||||
self.config = config
|
||||
self.commands: dict[str, Command] = {
|
||||
"help": Command(action=self._help, help_text="show this help message", role=None),
|
||||
"status": Command(action=self._status, help_text="show bot status", role=Role.STATUS),
|
||||
"inventory": Command(
|
||||
action=self._inventory,
|
||||
help_text="update van inventory from a text list or receipt photo",
|
||||
role=Role.INVENTORY,
|
||||
),
|
||||
"location": Command(
|
||||
action=self._location,
|
||||
help_text="get current van location",
|
||||
role=Role.LOCATION,
|
||||
),
|
||||
}
|
||||
|
||||
# -- actions --------------------------------------------------------------
|
||||
|
||||
def _help(self, message: SignalMessage, _cmd: str) -> None:
|
||||
"""Return help text filtered to the sender's roles."""
|
||||
self.signal.reply(message, self._build_help(self.registry.get_roles(message.source)))
|
||||
|
||||
def _status(self, message: SignalMessage, _cmd: str) -> None:
|
||||
"""Return the status of the bot."""
|
||||
models = self.llm.list_models()
|
||||
model_list = ", ".join(models[:10])
|
||||
device_count = len(self.registry.list_devices())
|
||||
self.signal.reply(
|
||||
message,
|
||||
f"Bot online.\nLLM: {self.llm.model}\nAvailable models: {model_list}\nKnown devices: {device_count}",
|
||||
)
|
||||
|
||||
def _inventory(self, message: SignalMessage, _cmd: str) -> None:
|
||||
"""Process an inventory update."""
|
||||
handle_inventory_update(message, self.signal, self.llm, self.config.inventory_api_url)
|
||||
|
||||
def _location(self, message: SignalMessage, _cmd: str) -> None:
|
||||
"""Reply with current van location."""
|
||||
handle_location_request(message, self.signal, self.config.ha_url, self.config.ha_token)
|
||||
|
||||
# -- dispatch -------------------------------------------------------------
|
||||
|
||||
def _build_help(self, roles: list[Role]) -> str:
|
||||
"""Build help text showing only the commands the user can access."""
|
||||
is_admin = Role.ADMIN in roles
|
||||
lines = ["Available commands:"]
|
||||
for name, cmd in self.commands.items():
|
||||
if cmd.role is None or is_admin or cmd.role in roles:
|
||||
lines.append(f" {name:20s} — {cmd.help_text}")
|
||||
return "\n".join(lines)
|
||||
|
||||
def dispatch(self, message: SignalMessage) -> None:
|
||||
"""Route an incoming message to the right command handler."""
|
||||
source = message.source
|
||||
|
||||
if not self.registry.is_verified(source):
|
||||
logger.info(f"Device {source} not verified, ignoring message")
|
||||
return
|
||||
|
||||
if not self.registry.has_safety_number(source) and self.registry.has_role(source, Role.ADMIN):
|
||||
logger.warning(f"Admin device {source} missing safety number, ignoring message")
|
||||
return
|
||||
|
||||
text = message.message.strip()
|
||||
parts = text.split()
|
||||
|
||||
if not parts and not message.attachments:
|
||||
return
|
||||
|
||||
cmd = parts[0].lower() if parts else ""
|
||||
|
||||
logger.info(f"f{source=} running {cmd=} with {message=}")
|
||||
|
||||
command = self.commands.get(cmd)
|
||||
if command is None:
|
||||
if message.attachments:
|
||||
command = self.commands["inventory"]
|
||||
cmd = "inventory"
|
||||
else:
|
||||
return
|
||||
|
||||
if command.role is not None and not self.registry.has_role(source, command.role):
|
||||
logger.warning(f"Device {source} denied access to {cmd!r}")
|
||||
self.signal.reply(message, f"Permission denied: you do not have the '{command.role}' role.")
|
||||
return
|
||||
|
||||
command.action(message, cmd)
|
||||
|
||||
def process_message(self, message: SignalMessage) -> None:
|
||||
"""Process a single message, sending it to the dead letter queue after repeated failures."""
|
||||
max_attempts = self.config.max_message_attempts
|
||||
for attempt in range(1, max_attempts + 1):
|
||||
try:
|
||||
safety_number = self.signal.get_safety_number(message.source)
|
||||
self.registry.record_contact(message.source, safety_number)
|
||||
self.dispatch(message)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to process message (attempt {attempt}/{max_attempts})")
|
||||
else:
|
||||
return
|
||||
|
||||
logger.error(f"Message from {message.source} failed {max_attempts} times, sending to dead letter queue")
|
||||
with Session(self.config.engine) as session:
|
||||
session.add(
|
||||
DeadLetterMessage(
|
||||
source=message.source,
|
||||
message=message.message,
|
||||
received_at=utcnow(),
|
||||
status=MessageStatus.UNPROCESSED,
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
def run(self) -> None:
|
||||
"""Listen for messages via WebSocket, reconnecting on failure."""
|
||||
logger.info("Bot started — listening via WebSocket")
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(self.config.max_retries),
|
||||
wait=wait_exponential(multiplier=self.config.reconnect_delay, max=self.config.max_reconnect_delay),
|
||||
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||
reraise=True,
|
||||
)
|
||||
def _listen() -> None:
|
||||
for message in self.signal.listen():
|
||||
logger.info(f"Message from {message.source}: {message.message[:80]}")
|
||||
self.process_message(message)
|
||||
|
||||
try:
|
||||
_listen()
|
||||
except Exception:
|
||||
logger.critical("Max retries exceeded, shutting down")
|
||||
raise
|
||||
|
||||
|
||||
def main(
|
||||
log_level: Annotated[str, typer.Option()] = "DEBUG",
|
||||
llm_timeout: Annotated[int, typer.Option()] = 600,
|
||||
) -> None:
|
||||
"""Run the Signal command and control bot."""
|
||||
configure_logger(log_level)
|
||||
signal_api_url = getenv("SIGNAL_API_URL")
|
||||
phone_number = getenv("SIGNAL_PHONE_NUMBER")
|
||||
inventory_api_url = getenv("INVENTORY_API_URL")
|
||||
|
||||
if signal_api_url is None:
|
||||
error = "SIGNAL_API_URL environment variable not set"
|
||||
raise ValueError(error)
|
||||
if phone_number is None:
|
||||
error = "SIGNAL_PHONE_NUMBER environment variable not set"
|
||||
raise ValueError(error)
|
||||
if inventory_api_url is None:
|
||||
error = "INVENTORY_API_URL environment variable not set"
|
||||
raise ValueError(error)
|
||||
|
||||
signal_bot_config = DATABASES["signal_bot"].alembic_config()
|
||||
upgrade(signal_bot_config, "head")
|
||||
engine = get_postgres_engine(name="SIGNALBOT")
|
||||
sync_roles(engine)
|
||||
config = BotConfig(
|
||||
signal_api_url=signal_api_url,
|
||||
phone_number=phone_number,
|
||||
inventory_api_url=inventory_api_url,
|
||||
ha_url=getenv("HA_URL"),
|
||||
ha_token=getenv("HA_TOKEN"),
|
||||
engine=engine,
|
||||
)
|
||||
|
||||
llm_host = getenv("LLM_HOST")
|
||||
llm_model = getenv("LLM_MODEL", "qwen3-vl:32b")
|
||||
llm_port = int(getenv("LLM_PORT", "11434"))
|
||||
if llm_host is None:
|
||||
error = "LLM_HOST environment variable not set"
|
||||
raise ValueError(error)
|
||||
|
||||
with (
|
||||
SignalClient(config.signal_api_url, config.phone_number) as signal,
|
||||
LLMClient(model=llm_model, host=llm_host, port=llm_port, timeout=llm_timeout) as llm,
|
||||
):
|
||||
registry = DeviceRegistry(signal, engine)
|
||||
bot = Bot(signal, llm, registry, config)
|
||||
bot.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
typer.run(main)
|
||||
@@ -0,0 +1,97 @@
|
||||
"""Models for the Signal command and control bot."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime # noqa: TC003 - pydantic needs this at runtime
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy.engine import Engine # noqa: TC002 - pydantic needs this at runtime
|
||||
|
||||
|
||||
class TrustLevel(StrEnum):
|
||||
"""Device trust level."""
|
||||
|
||||
VERIFIED = "verified"
|
||||
UNVERIFIED = "unverified"
|
||||
BLOCKED = "blocked"
|
||||
|
||||
|
||||
class Role(StrEnum):
|
||||
"""RBAC roles — one per command, plus admin which grants all."""
|
||||
|
||||
ADMIN = "admin"
|
||||
STATUS = "status"
|
||||
INVENTORY = "inventory"
|
||||
LOCATION = "location"
|
||||
|
||||
|
||||
class MessageStatus(StrEnum):
|
||||
"""Dead letter queue message status."""
|
||||
|
||||
UNPROCESSED = "unprocessed"
|
||||
PROCESSED = "processed"
|
||||
|
||||
|
||||
class Device(BaseModel):
|
||||
"""A registered device tracked by safety number."""
|
||||
|
||||
phone_number: str
|
||||
safety_number: str
|
||||
trust_level: TrustLevel = TrustLevel.UNVERIFIED
|
||||
first_seen: datetime
|
||||
last_seen: datetime
|
||||
|
||||
|
||||
class SignalMessage(BaseModel):
|
||||
"""An incoming Signal message."""
|
||||
|
||||
source: str
|
||||
timestamp: int
|
||||
message: str = ""
|
||||
attachments: list[str] = []
|
||||
group_id: str | None = None
|
||||
is_receipt: bool = False
|
||||
|
||||
|
||||
class SignalEnvelope(BaseModel):
|
||||
"""Raw envelope from signal-cli-rest-api."""
|
||||
|
||||
envelope: dict[str, Any]
|
||||
account: str | None = None
|
||||
|
||||
|
||||
class InventoryItem(BaseModel):
|
||||
"""An item in the van inventory."""
|
||||
|
||||
name: str
|
||||
quantity: float = 1
|
||||
unit: str = "each"
|
||||
category: str = ""
|
||||
notes: str = ""
|
||||
|
||||
|
||||
class InventoryUpdate(BaseModel):
|
||||
"""Result of processing an inventory update."""
|
||||
|
||||
items: list[InventoryItem] = []
|
||||
raw_response: str = ""
|
||||
source_type: str = "" # "receipt_photo" or "text_list"
|
||||
|
||||
|
||||
class BotConfig(BaseModel):
|
||||
"""Top-level bot configuration."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
signal_api_url: str
|
||||
phone_number: str
|
||||
inventory_api_url: str
|
||||
ha_url: str | None = None
|
||||
ha_token: str | None = None
|
||||
engine: Engine
|
||||
reconnect_delay: int = 5
|
||||
max_reconnect_delay: int = 300
|
||||
max_retries: int = 10
|
||||
max_message_attempts: int = 3
|
||||
@@ -0,0 +1,141 @@
|
||||
"""Client for the signal-cli-rest-api."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Self
|
||||
|
||||
import httpx
|
||||
import websockets.sync.client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
|
||||
from python.signal_bot.models import SignalMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _parse_envelope(envelope: dict[str, Any]) -> SignalMessage | None:
|
||||
"""Parse a signal-cli envelope into a SignalMessage, or None if not a data message."""
|
||||
data_message = envelope.get("dataMessage")
|
||||
if not data_message:
|
||||
return None
|
||||
|
||||
attachment_ids = [att["id"] for att in data_message.get("attachments", []) if "id" in att]
|
||||
|
||||
group_info = data_message.get("groupInfo")
|
||||
group_id = group_info.get("groupId") if group_info else None
|
||||
|
||||
return SignalMessage(
|
||||
source=envelope.get("source", ""),
|
||||
timestamp=envelope.get("timestamp", 0),
|
||||
message=data_message.get("message", "") or "",
|
||||
attachments=attachment_ids,
|
||||
group_id=group_id,
|
||||
)
|
||||
|
||||
|
||||
class SignalClient:
|
||||
"""Communicate with signal-cli-rest-api.
|
||||
|
||||
Args:
|
||||
base_url: URL of the signal-cli-rest-api (e.g. http://localhost:8989).
|
||||
phone_number: The registered phone number to send/receive as.
|
||||
"""
|
||||
|
||||
def __init__(self, base_url: str, phone_number: str) -> None:
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.phone_number = phone_number
|
||||
self._client = httpx.Client(base_url=self.base_url, timeout=30)
|
||||
|
||||
def _ws_url(self) -> str:
|
||||
"""Build the WebSocket URL from the base HTTP URL."""
|
||||
url = self.base_url.replace("http://", "ws://").replace("https://", "wss://")
|
||||
return f"{url}/v1/receive/{self.phone_number}"
|
||||
|
||||
def listen(self) -> Generator[SignalMessage]:
|
||||
"""Connect via WebSocket and yield messages as they arrive."""
|
||||
ws_url = self._ws_url()
|
||||
logger.info(f"Connecting to WebSocket: {ws_url}")
|
||||
|
||||
with websockets.sync.client.connect(ws_url) as ws:
|
||||
for raw in ws:
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
envelope = data.get("envelope", {})
|
||||
message = _parse_envelope(envelope)
|
||||
if message:
|
||||
yield message
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Non-JSON WebSocket frame: {raw[:200]}")
|
||||
|
||||
def send(self, recipient: str, message: str) -> None:
|
||||
"""Send a text message."""
|
||||
payload = {
|
||||
"message": message,
|
||||
"number": self.phone_number,
|
||||
"recipients": [recipient],
|
||||
}
|
||||
response = self._client.post("/v2/send", json=payload)
|
||||
response.raise_for_status()
|
||||
|
||||
def send_to_group(self, group_id: str, message: str) -> None:
|
||||
"""Send a message to a group."""
|
||||
payload = {
|
||||
"message": message,
|
||||
"number": self.phone_number,
|
||||
"recipients": [group_id],
|
||||
}
|
||||
response = self._client.post("/v2/send", json=payload)
|
||||
response.raise_for_status()
|
||||
|
||||
def get_attachment(self, attachment_id: str) -> bytes:
|
||||
"""Download an attachment by ID."""
|
||||
response = self._client.get(f"/v1/attachments/{attachment_id}")
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
|
||||
def get_identities(self) -> list[dict[str, Any]]:
|
||||
"""List known identities and their trust levels."""
|
||||
response = self._client.get(f"/v1/identities/{self.phone_number}")
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def get_safety_number(self, phone_number: str) -> str | None:
|
||||
"""Look up the safety number for a contact from signal-cli's local store."""
|
||||
for identity in self.get_identities():
|
||||
if identity.get("number") == phone_number:
|
||||
return identity.get("safety_number", identity.get("fingerprint", ""))
|
||||
return None
|
||||
|
||||
def trust_identity(self, number_to_trust: str, *, trust_all_known_keys: bool = False) -> None:
|
||||
"""Trust an identity (verify safety number)."""
|
||||
payload: dict[str, Any] = {}
|
||||
if trust_all_known_keys:
|
||||
payload["trust_all_known_keys"] = True
|
||||
response = self._client.put(
|
||||
f"/v1/identities/{self.phone_number}/trust/{number_to_trust}",
|
||||
json=payload,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
def reply(self, message: SignalMessage, text: str) -> None:
|
||||
"""Reply to a message, routing to group or individual."""
|
||||
if message.group_id:
|
||||
self.send_to_group(message.group_id, text)
|
||||
else:
|
||||
self.send(message.source, text)
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
"""Enter the context manager."""
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: object) -> None:
|
||||
"""Close the HTTP client on exit."""
|
||||
self.close()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the HTTP client."""
|
||||
self._client.close()
|
||||
@@ -1 +0,0 @@
|
||||
"""Audiobook tools."""
|
||||
@@ -1,471 +0,0 @@
|
||||
"""Convert Audible AAX downloads into Audiobookshelf-friendly M4B files."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import asdict, dataclass
|
||||
from os import getenv
|
||||
from pathlib import Path # noqa: TC003 This is required for the typer CLI
|
||||
from typing import TYPE_CHECKING, Annotated, Any
|
||||
from uuid import uuid7
|
||||
|
||||
import typer
|
||||
|
||||
from python.common import configure_logger
|
||||
from python.orm.common import get_postgres_engine
|
||||
from python.tools.audiobook.metadata_agent import (
|
||||
AgentConfig,
|
||||
StandardBookMetadata,
|
||||
standard_book_metadata,
|
||||
write_agent_log,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SENSITIVE_COMMAND_ARGUMENTS = {"-activation_bytes"}
|
||||
BOOK_RANGE_PATTERN = re.compile(r"(?:^|-)books?-(?P<start>[1-9]\d*)-(?P<end>[1-9]\d*)(?:-|$)")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConversionConfig:
|
||||
"""Runtime settings for one conversion command."""
|
||||
|
||||
resolved_output: Path
|
||||
ollama_api_key: str
|
||||
agent_config: AgentConfig
|
||||
engine: Engine
|
||||
activation_bytes: str | None
|
||||
dry_run: bool
|
||||
overwrite: bool
|
||||
work_directory_name: str = ".audible_convert"
|
||||
dry_run_directory_name: str = "dry-run"
|
||||
temp_directory_name: str = "tmp"
|
||||
log_directory_name: str = "logs"
|
||||
review_directory_name: str = "review"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConcurrentConversionResult:
|
||||
"""Result from running ffmpeg and metadata resolution together."""
|
||||
|
||||
metadata: StandardBookMetadata | None
|
||||
conversion_error: Exception | None
|
||||
metadata_error: Exception | None
|
||||
|
||||
|
||||
class CommandExecutionError(RuntimeError):
|
||||
"""Command failed without exposing sensitive arguments."""
|
||||
|
||||
def __init__(self, arguments: list[str], returncode: int) -> None:
|
||||
"""Create a redacted command failure."""
|
||||
self.arguments = tuple(arguments)
|
||||
self.returncode = returncode
|
||||
command = " ".join(redact_command_arguments(arguments))
|
||||
super().__init__(f"Command failed with exit code {returncode}: {command}")
|
||||
|
||||
|
||||
def main(
|
||||
input_directory: Annotated[Path, typer.Argument(help="Directory audible-cli downloads AAX files into.")],
|
||||
output_directory: Annotated[Path, typer.Argument(help="Audiobook output directory.")],
|
||||
*,
|
||||
dry_run: Annotated[
|
||||
bool,
|
||||
typer.Option("--dry-run", help="Print planned output files and write marker files without converting."),
|
||||
] = False,
|
||||
overwrite: Annotated[bool, typer.Option("--overwrite", help="Overwrite existing M4B files.")] = False,
|
||||
) -> None:
|
||||
"""Convert AAX files from a download directory into M4B files."""
|
||||
configure_logger()
|
||||
resolved_input = input_directory.resolve(strict=True)
|
||||
resolved_output = output_directory.resolve()
|
||||
if not dry_run:
|
||||
resolved_output.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
ollama_api_key = getenv("OLLAMA_API_KEY")
|
||||
if not ollama_api_key:
|
||||
msg = "OLLAMA_API_KEY is required for audiobook metadata resolution"
|
||||
raise RuntimeError(msg)
|
||||
|
||||
config = ConversionConfig(
|
||||
resolved_output=resolved_output,
|
||||
ollama_api_key=ollama_api_key,
|
||||
agent_config=AgentConfig(),
|
||||
engine=get_postgres_engine(name="RICHIE"),
|
||||
activation_bytes=getenv("AUDIBLE_ACTIVATION_BYTES"),
|
||||
dry_run=dry_run,
|
||||
overwrite=overwrite,
|
||||
)
|
||||
|
||||
aax_files = sorted(resolved_input.glob("*.aax"))
|
||||
if not aax_files:
|
||||
logger.info("No AAX files found in %s", resolved_input)
|
||||
return
|
||||
for aax_file in aax_files:
|
||||
logger.info("Converting %s", aax_file)
|
||||
convert_aax_file_with_agent(aax_file, config)
|
||||
|
||||
|
||||
def run_command(arguments: list[str], *, capture: bool = False) -> subprocess.CompletedProcess[str]:
|
||||
"""Run a command and return the completed process.
|
||||
|
||||
Args:
|
||||
arguments: Command and arguments to run.
|
||||
capture: Whether to capture stdout and stderr.
|
||||
|
||||
Returns:
|
||||
The completed process.
|
||||
"""
|
||||
logger.debug("%s", " ".join(redact_command_arguments(arguments)))
|
||||
try:
|
||||
return subprocess.run(arguments, check=True, capture_output=capture, text=True)
|
||||
except subprocess.CalledProcessError as error:
|
||||
raise CommandExecutionError(arguments, error.returncode) from error
|
||||
|
||||
|
||||
def redact_command_arguments(arguments: list[str]) -> list[str]:
|
||||
"""Return command arguments with sensitive values redacted."""
|
||||
redacted = []
|
||||
redact_next = False
|
||||
for argument in arguments:
|
||||
if redact_next:
|
||||
redacted.append("<redacted>")
|
||||
redact_next = False
|
||||
continue
|
||||
|
||||
redacted.append(argument)
|
||||
redact_next = argument in SENSITIVE_COMMAND_ARGUMENTS
|
||||
return redacted
|
||||
|
||||
|
||||
def read_metadata(aax_file: Path) -> dict[str, str]:
|
||||
"""Read ffprobe format tags from an AAX file.
|
||||
|
||||
Args:
|
||||
aax_file: AAX file to inspect.
|
||||
|
||||
Returns:
|
||||
Lower-cased metadata tag names mapped to their values.
|
||||
"""
|
||||
completed = run_command(
|
||||
[
|
||||
"ffprobe",
|
||||
"-v",
|
||||
"quiet",
|
||||
"-print_format",
|
||||
"json",
|
||||
"-show_format",
|
||||
str(aax_file),
|
||||
],
|
||||
capture=True,
|
||||
)
|
||||
ffprobe_data: dict[str, Any] = json.loads(completed.stdout)
|
||||
tags = ffprobe_data.get("format", {}).get("tags", {})
|
||||
return {str(key).lower(): str(value) for key, value in tags.items()}
|
||||
|
||||
|
||||
def output_stem(metadata: StandardBookMetadata) -> str:
|
||||
"""Build the output stem for a book.
|
||||
|
||||
Args:
|
||||
metadata: Book metadata.
|
||||
|
||||
Returns:
|
||||
Output stem in author-series_01-title form.
|
||||
"""
|
||||
index_slug = series_index_slug(metadata.series_index, metadata.title)
|
||||
return f"{metadata.author}-{metadata.series}_{index_slug}-{metadata.title}"
|
||||
|
||||
|
||||
def series_index_slug(series_index: float, title: str = "") -> str:
|
||||
"""Return a filename-safe series index."""
|
||||
if title_range := title_series_range_slug(series_index, title):
|
||||
return title_range
|
||||
index = float(series_index)
|
||||
if index.is_integer():
|
||||
return f"{int(index):02}"
|
||||
return f"{int(index):02}.5"
|
||||
|
||||
|
||||
def title_series_range_slug(series_index: float, title: str) -> str | None:
|
||||
"""Return a series range slug found in an omnibus title."""
|
||||
index = float(series_index)
|
||||
if not index.is_integer():
|
||||
return None
|
||||
first_index = int(index)
|
||||
for match in BOOK_RANGE_PATTERN.finditer(title):
|
||||
start = int(match.group("start"))
|
||||
end = int(match.group("end"))
|
||||
if start == first_index and end > start:
|
||||
return f"{start:02}-{end:02}"
|
||||
return None
|
||||
|
||||
|
||||
def metadata_output_path(output_directory: Path, metadata: StandardBookMetadata) -> Path:
|
||||
"""Build the final M4B path from resolved metadata."""
|
||||
stem = output_stem(metadata)
|
||||
return output_directory / stem / f"{stem}.m4b"
|
||||
|
||||
|
||||
def convert_aax_file(
|
||||
aax_file: Path,
|
||||
destination: Path,
|
||||
activation_bytes: str | None,
|
||||
*,
|
||||
overwrite: bool,
|
||||
) -> None:
|
||||
"""Convert an AAX file into an M4B file.
|
||||
|
||||
Args:
|
||||
aax_file: Source AAX file.
|
||||
destination: Destination M4B file.
|
||||
activation_bytes: Optional Audible activation bytes for ffmpeg.
|
||||
overwrite: Whether to overwrite an existing M4B.
|
||||
"""
|
||||
if destination.exists() and not overwrite:
|
||||
logger.info("Skipping existing file %s", destination)
|
||||
return
|
||||
|
||||
destination.parent.mkdir(parents=True, exist_ok=True)
|
||||
arguments = ["ffmpeg", "-hide_banner", "-y" if overwrite else "-n"]
|
||||
if activation_bytes:
|
||||
arguments.extend(["-activation_bytes", activation_bytes])
|
||||
arguments.extend(["-i", str(aax_file), "-map_metadata", "0", "-c", "copy", str(destination)])
|
||||
run_command(arguments)
|
||||
|
||||
|
||||
def write_review_file(
|
||||
*,
|
||||
destination: Path | None,
|
||||
ffprobe_metadata: dict[str, str],
|
||||
log_file: Path,
|
||||
metadata: StandardBookMetadata | None,
|
||||
reason: str,
|
||||
review_file: Path,
|
||||
source: Path,
|
||||
temp_file: Path | None,
|
||||
) -> None:
|
||||
"""Write a manual review file for an unresolved conversion."""
|
||||
review_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
payload = {
|
||||
"destination": str(destination) if destination else None,
|
||||
"ffprobe_metadata": ffprobe_metadata,
|
||||
"metadata": asdict(metadata) if metadata else None,
|
||||
"reason": reason,
|
||||
"source": str(source),
|
||||
"temp_file": str(temp_file) if temp_file else None,
|
||||
}
|
||||
review_file.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
|
||||
write_agent_log(log_file, "review_written", path=str(review_file), reason=reason)
|
||||
|
||||
|
||||
def cleanup_temp_output(temp_file: Path) -> None:
|
||||
"""Remove a run's temporary output directory."""
|
||||
shutil.rmtree(temp_file.parent, ignore_errors=True)
|
||||
|
||||
|
||||
def dry_run_aax_file_with_agent(
|
||||
aax_file: Path,
|
||||
ffprobe_metadata: dict[str, str],
|
||||
engine: Engine,
|
||||
config: ConversionConfig,
|
||||
log_file: Path,
|
||||
review_file: Path,
|
||||
) -> None:
|
||||
"""Resolve and print the planned output path without converting."""
|
||||
metadata = standard_book_metadata(
|
||||
aax_file.name,
|
||||
ffprobe_metadata,
|
||||
engine,
|
||||
log_file,
|
||||
config.ollama_api_key,
|
||||
config.agent_config,
|
||||
)
|
||||
destination = None if metadata.needs_review else metadata_output_path(config.resolved_output, metadata)
|
||||
if metadata.needs_review:
|
||||
write_review_file(
|
||||
destination=destination,
|
||||
ffprobe_metadata=ffprobe_metadata,
|
||||
log_file=log_file,
|
||||
metadata=metadata,
|
||||
reason="metadata_needs_review",
|
||||
review_file=review_file,
|
||||
source=aax_file,
|
||||
temp_file=None,
|
||||
)
|
||||
typer.echo(f"{aax_file} -> REVIEW {review_file}")
|
||||
else:
|
||||
stem = output_stem(metadata)
|
||||
dry_run_file = (
|
||||
config.resolved_output / config.work_directory_name / config.dry_run_directory_name / stem / f"{stem}.m4b"
|
||||
)
|
||||
dry_run_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
dry_run_file.write_text(f"{destination}\n", encoding="utf-8")
|
||||
write_agent_log(
|
||||
log_file,
|
||||
"dry_run_file_written",
|
||||
destination=str(destination),
|
||||
path=str(dry_run_file),
|
||||
)
|
||||
typer.echo(f"{aax_file} -> {destination}")
|
||||
|
||||
|
||||
def convert_temp_file_and_resolve_metadata(
|
||||
aax_file: Path,
|
||||
temp_file: Path,
|
||||
ffprobe_metadata: dict[str, str],
|
||||
config: ConversionConfig,
|
||||
log_file: Path,
|
||||
) -> ConcurrentConversionResult:
|
||||
"""Run ffmpeg and metadata resolution in parallel."""
|
||||
conversion_error: Exception | None = None
|
||||
metadata_error: Exception | None = None
|
||||
metadata: StandardBookMetadata | None = None
|
||||
|
||||
with ThreadPoolExecutor(max_workers=2) as executor:
|
||||
conversion_future = executor.submit(
|
||||
convert_aax_file,
|
||||
aax_file,
|
||||
temp_file,
|
||||
config.activation_bytes,
|
||||
overwrite=True,
|
||||
)
|
||||
metadata_future = executor.submit(
|
||||
standard_book_metadata,
|
||||
aax_file.name,
|
||||
ffprobe_metadata,
|
||||
config.engine,
|
||||
log_file,
|
||||
config.ollama_api_key,
|
||||
config.agent_config,
|
||||
)
|
||||
|
||||
conversion_error = conversion_future.exception()
|
||||
if conversion_error is None:
|
||||
conversion_future.result()
|
||||
|
||||
metadata_error = metadata_future.exception()
|
||||
if metadata_error is None:
|
||||
metadata = metadata_future.result()
|
||||
|
||||
return ConcurrentConversionResult(
|
||||
metadata=metadata,
|
||||
conversion_error=conversion_error,
|
||||
metadata_error=metadata_error,
|
||||
)
|
||||
|
||||
|
||||
def convert_aax_file_with_agent(aax_file: Path, config: ConversionConfig) -> None:
|
||||
"""Convert one AAX file using the metadata agent for the final path."""
|
||||
run_id = uuid7().hex
|
||||
log_file = config.resolved_output / config.work_directory_name / config.log_directory_name / f"{run_id}.jsonl"
|
||||
review_file = config.resolved_output / config.work_directory_name / config.review_directory_name / f"{run_id}.json"
|
||||
write_agent_log(log_file, "conversion_start", source=str(aax_file), dry_run=config.dry_run)
|
||||
try:
|
||||
ffprobe_metadata = read_metadata(aax_file)
|
||||
except Exception as error:
|
||||
logger.exception("ffprobe failed")
|
||||
write_review_file(
|
||||
destination=None,
|
||||
ffprobe_metadata={},
|
||||
log_file=log_file,
|
||||
metadata=None,
|
||||
reason=f"ffprobe_failed: {error}",
|
||||
review_file=review_file,
|
||||
source=aax_file,
|
||||
temp_file=None,
|
||||
)
|
||||
return
|
||||
|
||||
if config.dry_run:
|
||||
dry_run_aax_file_with_agent(
|
||||
aax_file,
|
||||
ffprobe_metadata,
|
||||
config.engine,
|
||||
config,
|
||||
log_file,
|
||||
review_file,
|
||||
)
|
||||
return
|
||||
|
||||
temp_file = (
|
||||
config.resolved_output / config.work_directory_name / config.temp_directory_name / run_id / "converted.m4b"
|
||||
)
|
||||
temp_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
result = convert_temp_file_and_resolve_metadata(aax_file, temp_file, ffprobe_metadata, config, log_file)
|
||||
|
||||
if result.conversion_error:
|
||||
reason = f"ffmpeg_failed: {result.conversion_error}"
|
||||
write_review_file(
|
||||
destination=None,
|
||||
ffprobe_metadata=ffprobe_metadata,
|
||||
log_file=log_file,
|
||||
metadata=result.metadata,
|
||||
reason=reason,
|
||||
review_file=review_file,
|
||||
source=aax_file,
|
||||
temp_file=temp_file if temp_file.exists() else None,
|
||||
)
|
||||
return
|
||||
|
||||
if result.metadata_error:
|
||||
write_review_file(
|
||||
destination=None,
|
||||
ffprobe_metadata=ffprobe_metadata,
|
||||
log_file=log_file,
|
||||
metadata=None,
|
||||
reason=f"metadata_failed: {result.metadata_error}",
|
||||
review_file=review_file,
|
||||
source=aax_file,
|
||||
temp_file=temp_file,
|
||||
)
|
||||
return
|
||||
|
||||
if result.metadata is None or result.metadata.needs_review:
|
||||
write_review_file(
|
||||
destination=None,
|
||||
ffprobe_metadata=ffprobe_metadata,
|
||||
log_file=log_file,
|
||||
metadata=result.metadata,
|
||||
reason="metadata_needs_review",
|
||||
review_file=review_file,
|
||||
source=aax_file,
|
||||
temp_file=temp_file,
|
||||
)
|
||||
return
|
||||
|
||||
destination = metadata_output_path(config.resolved_output, result.metadata)
|
||||
if destination.exists() and not config.overwrite:
|
||||
write_agent_log(log_file, "destination_exists", destination=str(destination))
|
||||
cleanup_temp_output(temp_file)
|
||||
return
|
||||
|
||||
destination.parent.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
temp_file.replace(destination)
|
||||
except Exception as error: # noqa: BLE001
|
||||
write_review_file(
|
||||
destination=destination,
|
||||
ffprobe_metadata=ffprobe_metadata,
|
||||
log_file=log_file,
|
||||
metadata=result.metadata,
|
||||
reason=f"rename_failed: {error}",
|
||||
review_file=review_file,
|
||||
source=aax_file,
|
||||
temp_file=temp_file if temp_file.exists() else None,
|
||||
)
|
||||
else:
|
||||
cleanup_temp_output(temp_file)
|
||||
write_agent_log(log_file, "conversion_complete", destination=str(destination))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
typer.run(main)
|
||||
@@ -1,176 +0,0 @@
|
||||
"""Import audiobook catalog authors and series from CSV files."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import logging
|
||||
from pathlib import Path # noqa: TC003 This is required for the typer CLI
|
||||
from typing import Annotated
|
||||
|
||||
import typer
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from python.common import configure_logger
|
||||
from python.orm.common import get_postgres_engine
|
||||
from python.orm.richie import AudiobookAuthor, AudiobookSeries
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
AUTHOR_NAME_COLUMN = "author_name"
|
||||
ID_COLUMN = "id"
|
||||
NAME_COLUMN = "name"
|
||||
|
||||
|
||||
class CatalogImportError(ValueError):
|
||||
"""CSV catalog import failed validation."""
|
||||
|
||||
|
||||
def main(
|
||||
authors_csv: Annotated[Path, typer.Argument(help="CSV with name and optional id.")],
|
||||
series_csv: Annotated[Path, typer.Argument(help="CSV with name, author_name, and optional id.")],
|
||||
) -> None:
|
||||
"""Upsert audiobook authors and series from CSV files."""
|
||||
configure_logger()
|
||||
try:
|
||||
engine = get_postgres_engine(name="RICHIE")
|
||||
with Session(engine) as session:
|
||||
author_count = upsert_authors_from_csv(session, authors_csv)
|
||||
series_count = upsert_series_from_csv(session, series_csv)
|
||||
session.commit()
|
||||
except CatalogImportError as error:
|
||||
typer.echo(str(error), err=True)
|
||||
raise typer.Exit(code=1) from error
|
||||
|
||||
logger.info("Upserted %s authors and %s series", author_count, series_count)
|
||||
|
||||
|
||||
def upsert_authors_from_csv(session: Session, authors_csv: Path) -> int:
|
||||
"""Upsert authors from a CSV file."""
|
||||
count = 0
|
||||
for row_number, row in csv_rows(authors_csv):
|
||||
name = required_csv_value(row, authors_csv, row_number, NAME_COLUMN)
|
||||
upsert_author(session, name, csv_id(row, authors_csv, row_number))
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
def upsert_series_from_csv(session: Session, series_csv: Path) -> int:
|
||||
"""Upsert series from a CSV file."""
|
||||
count = 0
|
||||
for row_number, row in csv_rows(series_csv):
|
||||
series_name = required_csv_value(row, series_csv, row_number, NAME_COLUMN)
|
||||
author_name = required_csv_value(row, series_csv, row_number, AUTHOR_NAME_COLUMN)
|
||||
author = find_author_by_name(session, author_name)
|
||||
if author is None:
|
||||
msg = f"{series_csv}:{row_number}: author not found: {author_name}"
|
||||
raise CatalogImportError(msg)
|
||||
upsert_series(session, series_name, author, csv_id(row, series_csv, row_number))
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
def upsert_author(session: Session, name: str, author_id: int | None) -> AudiobookAuthor:
|
||||
"""Upsert one author by id or exact name."""
|
||||
if author_id is not None:
|
||||
author = session.get(AudiobookAuthor, author_id)
|
||||
if author is None:
|
||||
author = AudiobookAuthor(id=author_id, name=name)
|
||||
session.add(author)
|
||||
else:
|
||||
author.name = name
|
||||
session.flush()
|
||||
return author
|
||||
|
||||
author = find_author_by_name(session, name)
|
||||
if author is None:
|
||||
author = AudiobookAuthor(name=name)
|
||||
session.add(author)
|
||||
session.flush()
|
||||
return author
|
||||
|
||||
|
||||
def upsert_series(
|
||||
session: Session,
|
||||
name: str,
|
||||
author: AudiobookAuthor,
|
||||
series_id: int | None,
|
||||
) -> AudiobookSeries:
|
||||
"""Upsert one series by id or exact author/name match."""
|
||||
if series_id is not None:
|
||||
series = session.get(AudiobookSeries, series_id)
|
||||
if series is None:
|
||||
series = AudiobookSeries(id=series_id, name=name, author=author)
|
||||
session.add(series)
|
||||
else:
|
||||
series.name = name
|
||||
series.author = author
|
||||
session.flush()
|
||||
return series
|
||||
|
||||
series = find_series_by_name_and_author(session, name, author.id)
|
||||
if series is None:
|
||||
series = AudiobookSeries(name=name, author=author)
|
||||
session.add(series)
|
||||
session.flush()
|
||||
return series
|
||||
|
||||
|
||||
def find_author_by_name(session: Session, name: str) -> AudiobookAuthor | None:
|
||||
"""Find one author by exact name."""
|
||||
return session.scalar(select(AudiobookAuthor).where(AudiobookAuthor.name == name))
|
||||
|
||||
|
||||
def find_series_by_name_and_author(
|
||||
session: Session,
|
||||
name: str,
|
||||
author_id: int,
|
||||
) -> AudiobookSeries | None:
|
||||
"""Find one series by exact name and author."""
|
||||
return session.scalar(
|
||||
select(AudiobookSeries).where(
|
||||
AudiobookSeries.name == name,
|
||||
AudiobookSeries.author_id == author_id,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def csv_rows(csv_path: Path) -> list[tuple[int, dict[str, str | None]]]:
|
||||
"""Read a CSV file as numbered rows."""
|
||||
with csv_path.open(newline="", encoding="utf-8") as file:
|
||||
reader = csv.DictReader(file)
|
||||
if reader.fieldnames is None:
|
||||
msg = f"{csv_path}: missing CSV header"
|
||||
raise CatalogImportError(msg)
|
||||
return [(row_number, row) for row_number, row in enumerate(reader, start=2)]
|
||||
|
||||
|
||||
def required_csv_value(
|
||||
row: dict[str, str | None],
|
||||
csv_path: Path,
|
||||
row_number: int,
|
||||
column: str,
|
||||
) -> str:
|
||||
"""Read a required CSV value."""
|
||||
value = row.get(column)
|
||||
if value and value.strip():
|
||||
return value.strip()
|
||||
msg = f"{csv_path}:{row_number}: missing required column value: {column}"
|
||||
raise CatalogImportError(msg)
|
||||
|
||||
|
||||
def csv_id(row: dict[str, str | None], csv_path: Path, row_number: int) -> int | None:
|
||||
"""Read an optional id field from a CSV row."""
|
||||
value = row.get(ID_COLUMN)
|
||||
if value is None or not value.strip():
|
||||
return None
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError as error:
|
||||
msg = f"{csv_path}:{row_number}: id must be an integer: {value}"
|
||||
raise CatalogImportError(msg) from error
|
||||
return None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
typer.run(main)
|
||||
@@ -1,599 +0,0 @@
|
||||
"""LLM tool calling support for audiobook metadata resolution."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import or_, select
|
||||
|
||||
from python.orm.richie import Audiobook, AudiobookAuthor, AudiobookSeries
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from python.tools.audiobook.metadata_agent import AgentConfig
|
||||
|
||||
CATALOG_SLUG_PATTERN = re.compile(r"^[a-z0-9]+(?:_[a-z0-9]+)*$")
|
||||
TITLE_SLUG_PATTERN = re.compile(r"^[a-z0-9]+(?:-[a-z0-9]+)*$")
|
||||
|
||||
LogWriter = Callable[..., None]
|
||||
|
||||
|
||||
class MetadataResolutionError(ValueError):
|
||||
"""Metadata resolution failed validation."""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class EnsuredBook:
|
||||
"""Book row plus whether it was created."""
|
||||
|
||||
book: Audiobook
|
||||
action: str
|
||||
|
||||
|
||||
class CatalogToolRegistry:
|
||||
"""Controlled catalog tools exposed to the metadata model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: Session,
|
||||
log_path: Path,
|
||||
config: AgentConfig,
|
||||
write_log: LogWriter,
|
||||
) -> None:
|
||||
"""Create a registry bound to one database session and audit log."""
|
||||
self.session = session
|
||||
self.log_path = log_path
|
||||
self.config = config
|
||||
self.write_log = write_log
|
||||
self.seen_author_ids: set[int] = set()
|
||||
self.seen_series_ids: set[int] = set()
|
||||
self.seen_book_ids: set[int] = set()
|
||||
self.created_author_ids: set[int] = set()
|
||||
self.created_series_ids: set[int] = set()
|
||||
self.created_book_ids: set[int] = set()
|
||||
|
||||
def tool_schemas(self) -> list[dict[str, object]]:
|
||||
"""Return Ollama tool schemas."""
|
||||
schemas = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search_authors",
|
||||
"description": "Search canonical audiobook authors by slug or noisy source text.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"query": {"type": "string"}},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search_series",
|
||||
"description": "Search canonical audiobook series by slug or noisy source text.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
"author_id": {"type": ["integer", "null"]},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search_books",
|
||||
"description": "Search canonical audiobook titles with optional author and series filters.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
"author_id": {"type": ["integer", "null"]},
|
||||
"series_id": {"type": ["integer", "null"]},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "ensure_author",
|
||||
"description": "Normalize an author name to a catalog slug, then return or create that author.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}},
|
||||
"required": ["name"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "ensure_series",
|
||||
"description": "Normalize a series name to a catalog slug, then return or create it for an author.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"author_id": {"type": "integer"},
|
||||
},
|
||||
"required": ["name", "author_id"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "ensure_book",
|
||||
"description": "Normalize a title to a book slug, then return or create it for an author/series.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"title": {"type": "string"},
|
||||
"author_id": {"type": "integer"},
|
||||
"series_id": {"type": ["integer", "null"]},
|
||||
"series_index": {"type": "number", "multipleOf": 0.5},
|
||||
},
|
||||
"required": ["title", "author_id", "series_id", "series_index"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
enabled_tool_names = set(self.config.tool_names)
|
||||
return [schema for schema in schemas if schema["function"]["name"] in enabled_tool_names]
|
||||
|
||||
def run(self, name: str, arguments: dict[str, object]) -> list[dict[str, object]]:
|
||||
"""Run one catalog tool and audit the call."""
|
||||
handlers = {
|
||||
"search_authors": self.run_search_authors,
|
||||
"search_series": self.run_search_series,
|
||||
"search_books": self.run_search_books,
|
||||
"ensure_author": self.run_ensure_author,
|
||||
"ensure_series": self.run_ensure_series,
|
||||
"ensure_book": self.run_ensure_book,
|
||||
}
|
||||
handler = handlers.get(name)
|
||||
if handler is None:
|
||||
self.write_log(self.log_path, "tool_error", tool=name, arguments=arguments, error="unknown_tool")
|
||||
msg = f"Unknown audiobook metadata tool: {name}"
|
||||
raise MetadataResolutionError(msg)
|
||||
if name not in self.config.tool_names:
|
||||
self.write_log(self.log_path, "tool_error", tool=name, arguments=arguments, error="tool_not_enabled")
|
||||
msg = f"Audiobook metadata tool is not enabled: {name}"
|
||||
raise MetadataResolutionError(msg)
|
||||
|
||||
started = time.perf_counter()
|
||||
self.write_log(self.log_path, "tool_call", tool=name, arguments=arguments)
|
||||
result = handler(arguments)
|
||||
duration_ms = round((time.perf_counter() - started) * 1000, 3)
|
||||
self.write_log(
|
||||
self.log_path,
|
||||
"tool_result",
|
||||
tool=name,
|
||||
duration_ms=duration_ms,
|
||||
result_count=len(result),
|
||||
preview=result[:3],
|
||||
)
|
||||
return result
|
||||
|
||||
def get_author(self, author_id: int) -> AudiobookAuthor | None:
|
||||
"""Return an author by id."""
|
||||
return self.session.get(AudiobookAuthor, author_id)
|
||||
|
||||
def get_book(self, book_id: int) -> Audiobook | None:
|
||||
"""Return a book by id."""
|
||||
return self.session.get(Audiobook, book_id)
|
||||
|
||||
def get_series(self, series_id: int) -> AudiobookSeries | None:
|
||||
"""Return a series by id."""
|
||||
return self.session.get(AudiobookSeries, series_id)
|
||||
|
||||
def prune_unused_created_rows(self, *, author_id: int, book_id: int | None, series_id: int | None) -> None:
|
||||
"""Remove catalog rows created during this run but not used by final metadata."""
|
||||
used_book_ids = {book_id} if book_id is not None else set()
|
||||
for created_book_id in self.created_book_ids - used_book_ids:
|
||||
if book := self.get_book(created_book_id):
|
||||
self.session.delete(book)
|
||||
|
||||
self.session.flush()
|
||||
used_series_ids = {series_id} if series_id is not None else set()
|
||||
for created_series_id in self.created_series_ids - used_series_ids:
|
||||
series = self.get_series(created_series_id)
|
||||
if series and not series.books:
|
||||
self.session.delete(series)
|
||||
|
||||
self.session.flush()
|
||||
for created_author_id in self.created_author_ids - {author_id}:
|
||||
author = self.get_author(created_author_id)
|
||||
if author and not author.books and not author.series:
|
||||
self.session.delete(author)
|
||||
|
||||
def run_search_authors(self, arguments: dict[str, object]) -> list[dict[str, object]]:
|
||||
"""Search authors from tool arguments and remember returned ids."""
|
||||
query = required_string(arguments, "query")
|
||||
statement = select(AudiobookAuthor).order_by(AudiobookAuthor.name).limit(self.config.max_tool_results)
|
||||
if terms := query_terms(query):
|
||||
statement = statement.where(or_(*(AudiobookAuthor.name.ilike(f"%{term}%") for term in terms)))
|
||||
|
||||
authors = self.session.scalars(statement).all()
|
||||
self.seen_author_ids.update(author.id for author in authors)
|
||||
return [{"id": author.id, "name": author.name} for author in authors]
|
||||
|
||||
def run_search_series(self, arguments: dict[str, object]) -> list[dict[str, object]]:
|
||||
"""Search series from tool arguments and remember returned ids."""
|
||||
query = required_string(arguments, "query")
|
||||
author_id = optional_int(arguments.get("author_id"), "author_id")
|
||||
statement = select(AudiobookSeries).order_by(AudiobookSeries.name).limit(self.config.max_tool_results)
|
||||
if terms := query_terms(query):
|
||||
statement = statement.where(or_(*(AudiobookSeries.name.ilike(f"%{term}%") for term in terms)))
|
||||
if author_id is not None:
|
||||
statement = statement.where(AudiobookSeries.author_id == author_id)
|
||||
|
||||
series_rows = self.session.scalars(statement).all()
|
||||
self.seen_series_ids.update(series.id for series in series_rows)
|
||||
self.seen_author_ids.update(series.author_id for series in series_rows)
|
||||
return [
|
||||
{
|
||||
"id": series.id,
|
||||
"name": series.name,
|
||||
"author_id": series.author_id,
|
||||
"author": series.author.name,
|
||||
}
|
||||
for series in series_rows
|
||||
]
|
||||
|
||||
def run_search_books(self, arguments: dict[str, object]) -> list[dict[str, object]]:
|
||||
"""Search books from tool arguments and remember returned ids."""
|
||||
query = required_string(arguments, "query")
|
||||
author_id = optional_int(arguments.get("author_id"), "author_id")
|
||||
series_id = optional_int(arguments.get("series_id"), "series_id")
|
||||
statement = select(Audiobook).order_by(Audiobook.title).limit(self.config.max_tool_results)
|
||||
if terms := query_terms(query):
|
||||
statement = statement.where(or_(*(Audiobook.title.ilike(f"%{term}%") for term in terms)))
|
||||
if author_id is not None:
|
||||
statement = statement.where(Audiobook.author_id == author_id)
|
||||
if series_id is not None:
|
||||
statement = statement.where(Audiobook.series_id == series_id)
|
||||
|
||||
books = self.session.scalars(statement).all()
|
||||
self.seen_book_ids.update(book.id for book in books)
|
||||
self.seen_author_ids.update(book.author_id for book in books)
|
||||
self.seen_series_ids.update(book.series_id for book in books if book.series_id is not None)
|
||||
return [
|
||||
{
|
||||
"id": book.id,
|
||||
"title": book.title,
|
||||
"author_id": book.author_id,
|
||||
"author": book.author.name,
|
||||
"series_id": book.series_id,
|
||||
"series": book.series.name if book.series else self.config.standalone_series,
|
||||
"series_index": book.series_index,
|
||||
}
|
||||
for book in books
|
||||
]
|
||||
|
||||
def run_ensure_author(self, arguments: dict[str, object]) -> list[dict[str, object]]:
|
||||
"""Ensure an author from tool arguments and return a tool result."""
|
||||
name = normalize_catalog_slug(required_string(arguments, "name"))
|
||||
validate_catalog_slug(name, "author")
|
||||
author = self.session.scalar(select(AudiobookAuthor).where(AudiobookAuthor.name == name))
|
||||
action = "existing"
|
||||
if author is None:
|
||||
author = AudiobookAuthor(name=name)
|
||||
self.session.add(author)
|
||||
self.session.flush()
|
||||
self.created_author_ids.add(author.id)
|
||||
action = "created"
|
||||
|
||||
self.seen_author_ids.add(author.id)
|
||||
return [{"id": author.id, "name": author.name, "action": action}]
|
||||
|
||||
def run_ensure_series(self, arguments: dict[str, object]) -> list[dict[str, object]]:
|
||||
"""Ensure a series from tool arguments and return a tool result."""
|
||||
name = normalize_catalog_slug(required_string(arguments, "name"))
|
||||
author_id = required_int(arguments, "author_id")
|
||||
validate_catalog_slug(name, "series")
|
||||
author = self.required_author(author_id)
|
||||
series = self.find_series_by_catalog_slug(name, author.id)
|
||||
action = "existing"
|
||||
if series is None:
|
||||
series = AudiobookSeries(name=name, author=author)
|
||||
self.session.add(series)
|
||||
self.session.flush()
|
||||
self.created_series_ids.add(series.id)
|
||||
action = "created"
|
||||
|
||||
self.seen_author_ids.add(author.id)
|
||||
self.seen_series_ids.add(series.id)
|
||||
return [self.series_result(series, action)]
|
||||
|
||||
def run_ensure_book(self, arguments: dict[str, object]) -> list[dict[str, object]]:
|
||||
"""Ensure a book from tool arguments and return a tool result."""
|
||||
title = required_string(arguments, "title")
|
||||
author_id = required_int(arguments, "author_id")
|
||||
series_id = optional_int(arguments.get("series_id"), "series_id")
|
||||
series_index = required_series_index(arguments, "series_index")
|
||||
ensured = self.ensure_book(title, author_id, series_id, series_index)
|
||||
return [self.book_result(ensured.book, ensured.action)]
|
||||
|
||||
def ensure_book(
|
||||
self,
|
||||
title: str,
|
||||
author_id: int,
|
||||
series_id: int | None,
|
||||
series_index: float,
|
||||
) -> EnsuredBook:
|
||||
"""Return an existing book row, or create it after validating ownership."""
|
||||
title = normalize_title_slug(title)
|
||||
validate_title_slug(title)
|
||||
author = self.required_author(author_id)
|
||||
series = None
|
||||
if series_id is None:
|
||||
if series_index != 0:
|
||||
msg = "standalone books must use series_index 0"
|
||||
raise MetadataResolutionError(msg)
|
||||
else:
|
||||
series = self.required_series(series_id)
|
||||
if series.author_id != author.id:
|
||||
msg = f"series_id {series_id} does not belong to author_id {author_id}"
|
||||
raise MetadataResolutionError(msg)
|
||||
if series_index <= 0:
|
||||
msg = "series books must use a positive series_index"
|
||||
raise MetadataResolutionError(msg)
|
||||
|
||||
statement = select(Audiobook).where(
|
||||
Audiobook.title == title,
|
||||
Audiobook.author_id == author.id,
|
||||
)
|
||||
if series is None:
|
||||
statement = statement.where(Audiobook.series_id.is_(None))
|
||||
else:
|
||||
statement = statement.where(Audiobook.series_id == series.id)
|
||||
book = self.session.scalar(statement)
|
||||
if book is None:
|
||||
book = Audiobook(title=title, author=author, series=series, series_index=series_index)
|
||||
self.session.add(book)
|
||||
self.session.flush()
|
||||
self.created_book_ids.add(book.id)
|
||||
action = "created"
|
||||
else:
|
||||
action = "existing"
|
||||
|
||||
self.seen_book_ids.add(book.id)
|
||||
self.seen_author_ids.add(author.id)
|
||||
if book.series_id is not None:
|
||||
self.seen_series_ids.add(book.series_id)
|
||||
return EnsuredBook(book=book, action=action)
|
||||
|
||||
def required_author(self, author_id: int) -> AudiobookAuthor:
|
||||
"""Return an author or fail metadata resolution."""
|
||||
author = self.get_author(author_id)
|
||||
if author is None:
|
||||
msg = f"author_id {author_id} does not exist"
|
||||
raise MetadataResolutionError(msg)
|
||||
return author
|
||||
|
||||
def required_series(self, series_id: int) -> AudiobookSeries:
|
||||
"""Return a series or fail metadata resolution."""
|
||||
series = self.get_series(series_id)
|
||||
if series is None:
|
||||
msg = f"series_id {series_id} does not exist"
|
||||
raise MetadataResolutionError(msg)
|
||||
return series
|
||||
|
||||
def find_series_by_catalog_slug(self, name: str, author_id: int) -> AudiobookSeries | None:
|
||||
"""Return a series by exact slug or underscore-insensitive slug."""
|
||||
exact = self.session.scalar(
|
||||
select(AudiobookSeries).where(
|
||||
AudiobookSeries.name == name,
|
||||
AudiobookSeries.author_id == author_id,
|
||||
),
|
||||
)
|
||||
if exact is not None:
|
||||
return exact
|
||||
|
||||
compact_name = compact_catalog_slug(name)
|
||||
series_rows = self.session.scalars(
|
||||
select(AudiobookSeries).where(AudiobookSeries.author_id == author_id).order_by(AudiobookSeries.name),
|
||||
).all()
|
||||
for series in series_rows:
|
||||
if compact_catalog_slug(series.name) == compact_name:
|
||||
return series
|
||||
return None
|
||||
|
||||
def series_result(self, series: AudiobookSeries, action: str) -> dict[str, object]:
|
||||
"""Build a normalized series tool result."""
|
||||
return {
|
||||
"id": series.id,
|
||||
"name": series.name,
|
||||
"author_id": series.author_id,
|
||||
"author": series.author.name,
|
||||
"action": action,
|
||||
}
|
||||
|
||||
def book_result(self, book: Audiobook, action: str) -> dict[str, object]:
|
||||
"""Build a normalized book tool result."""
|
||||
return {
|
||||
"id": book.id,
|
||||
"title": book.title,
|
||||
"author_id": book.author_id,
|
||||
"author": book.author.name,
|
||||
"series_id": book.series_id,
|
||||
"series": book.series.name if book.series else self.config.standalone_series,
|
||||
"series_index": book.series_index,
|
||||
"action": action,
|
||||
}
|
||||
|
||||
|
||||
def run_tool_calls(
|
||||
messages: list[dict[str, object]],
|
||||
message: dict[str, object],
|
||||
tool_calls: list[tuple[str, dict[str, object]]],
|
||||
registry: CatalogToolRegistry,
|
||||
log_path: Path,
|
||||
write_log: LogWriter,
|
||||
) -> str | None:
|
||||
"""Run tool calls, append tool messages, and return fatal error text when stopped."""
|
||||
messages.append(message)
|
||||
for tool_name, arguments in tool_calls:
|
||||
try:
|
||||
tool_result = registry.run(tool_name, arguments)
|
||||
except MetadataResolutionError as error:
|
||||
if is_fatal_tool_error(error):
|
||||
return str(error)
|
||||
write_log(log_path, "tool_error", tool=tool_name, arguments=arguments, error=str(error))
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_name": tool_name,
|
||||
"content": json.dumps({"error": str(error)}, sort_keys=True),
|
||||
},
|
||||
)
|
||||
continue
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_name": tool_name,
|
||||
"content": json.dumps(tool_result, sort_keys=True),
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def parse_tool_calls(message: dict[str, object]) -> list[tuple[str, dict[str, object]]]:
|
||||
"""Parse Ollama tool calls from a response message."""
|
||||
raw_tool_calls = message.get("tool_calls") or []
|
||||
if not isinstance(raw_tool_calls, list):
|
||||
msg = "tool_calls must be a list"
|
||||
raise MetadataResolutionError(msg)
|
||||
|
||||
tool_calls = []
|
||||
for raw_call in raw_tool_calls:
|
||||
if not isinstance(raw_call, dict):
|
||||
msg = "tool call must be an object"
|
||||
raise MetadataResolutionError(msg)
|
||||
function = raw_call.get("function")
|
||||
if not isinstance(function, dict):
|
||||
msg = "tool call is missing function"
|
||||
raise MetadataResolutionError(msg)
|
||||
name = function.get("name")
|
||||
if not isinstance(name, str) or not name:
|
||||
msg = "tool call is missing function name"
|
||||
raise MetadataResolutionError(msg)
|
||||
arguments = parse_tool_arguments(function.get("arguments", {}))
|
||||
tool_calls.append((name, arguments))
|
||||
return tool_calls
|
||||
|
||||
|
||||
def parse_tool_arguments(raw_arguments: object) -> dict[str, object]:
|
||||
"""Parse tool call arguments returned by Ollama."""
|
||||
if isinstance(raw_arguments, dict):
|
||||
return {str(key): value for key, value in raw_arguments.items()}
|
||||
if isinstance(raw_arguments, str):
|
||||
parsed = json.loads(raw_arguments) if raw_arguments else {}
|
||||
if isinstance(parsed, dict):
|
||||
return {str(key): value for key, value in parsed.items()}
|
||||
msg = "tool arguments must be an object"
|
||||
raise MetadataResolutionError(msg)
|
||||
|
||||
|
||||
def validate_title_slug(title: str) -> None:
|
||||
"""Validate a canonical book title slug."""
|
||||
if not TITLE_SLUG_PATTERN.fullmatch(title):
|
||||
msg = f"title slug is invalid: {title}"
|
||||
raise MetadataResolutionError(msg)
|
||||
|
||||
|
||||
def validate_catalog_slug(value: str, label: str) -> None:
|
||||
"""Validate a canonical catalog slug."""
|
||||
if not CATALOG_SLUG_PATTERN.fullmatch(value):
|
||||
msg = f"{label} slug is invalid: {value}"
|
||||
raise MetadataResolutionError(msg)
|
||||
|
||||
|
||||
def normalize_catalog_slug(value: str) -> str:
|
||||
"""Normalize noisy catalog names into lower snake-case slugs."""
|
||||
return re.sub(r"[^a-z0-9]+", "_", value.strip().casefold()).strip("_")
|
||||
|
||||
|
||||
def compact_catalog_slug(value: str) -> str:
|
||||
"""Return a catalog slug comparison key that ignores underscores."""
|
||||
return normalize_catalog_slug(value).replace("_", "")
|
||||
|
||||
|
||||
def normalize_title_slug(value: str) -> str:
|
||||
"""Normalize noisy book titles into lower kebab-case slugs."""
|
||||
return re.sub(r"[^a-z0-9]+", "-", value.strip().casefold()).strip("-")
|
||||
|
||||
|
||||
def is_fatal_tool_error(error: MetadataResolutionError) -> bool:
|
||||
"""Return whether a tool error should stop the agent immediately."""
|
||||
message = str(error)
|
||||
return message.startswith(
|
||||
(
|
||||
"Unknown audiobook metadata tool",
|
||||
"Audiobook metadata tool is not enabled",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def query_terms(query: str) -> tuple[str, ...]:
|
||||
"""Return text variants useful for matching noisy audiobook metadata."""
|
||||
normalized = query.strip().casefold()
|
||||
underscore_slug = normalize_catalog_slug(normalized)
|
||||
compact_slug = compact_catalog_slug(normalized)
|
||||
hyphen_slug = normalize_title_slug(normalized)
|
||||
return tuple(dict.fromkeys(term for term in (normalized, underscore_slug, compact_slug, hyphen_slug) if term))
|
||||
|
||||
|
||||
def required_string(data: dict[str, object], key: str) -> str:
|
||||
"""Read a required string field."""
|
||||
value = data.get(key)
|
||||
if not isinstance(value, str) or not value.strip():
|
||||
msg = f"{key} must be a non-empty string"
|
||||
raise MetadataResolutionError(msg)
|
||||
return value.strip()
|
||||
|
||||
|
||||
def required_int(data: dict[str, object], key: str) -> int:
|
||||
"""Read a required integer field."""
|
||||
value = data.get(key)
|
||||
if isinstance(value, bool) or not isinstance(value, int):
|
||||
msg = f"{key} must be an integer"
|
||||
raise MetadataResolutionError(msg)
|
||||
return value
|
||||
|
||||
|
||||
def required_series_index(data: dict[str, object], key: str) -> float:
|
||||
"""Read a required whole-number or half-number series index."""
|
||||
value = data.get(key)
|
||||
if isinstance(value, bool) or not isinstance(value, int | float):
|
||||
msg = f"{key} must be a number"
|
||||
raise MetadataResolutionError(msg)
|
||||
series_index = float(value)
|
||||
if not (series_index * 2).is_integer():
|
||||
msg = f"{key} must be a whole number or .5 increment"
|
||||
raise MetadataResolutionError(msg)
|
||||
return series_index
|
||||
|
||||
|
||||
def optional_int(value: object, key: str) -> int | None:
|
||||
"""Read an optional integer field."""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, bool) or not isinstance(value, int):
|
||||
msg = f"{key} must be an integer or null"
|
||||
raise MetadataResolutionError(msg)
|
||||
return value
|
||||
@@ -1,575 +0,0 @@
|
||||
"""Resolve audiobook metadata with a controlled Ollama tool loop."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from dataclasses import asdict, dataclass, is_dataclass, replace
|
||||
from os import PathLike
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import httpx
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from python.common import utcnow
|
||||
from python.tools.audiobook.llm_tool_calling import (
|
||||
CatalogToolRegistry,
|
||||
MetadataResolutionError,
|
||||
normalize_title_slug,
|
||||
optional_int,
|
||||
parse_tool_calls,
|
||||
required_int,
|
||||
required_series_index,
|
||||
required_string,
|
||||
run_tool_calls,
|
||||
validate_catalog_slug,
|
||||
validate_title_slug,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
from python.orm.richie import AudiobookAuthor
|
||||
|
||||
FENCED_JSON_PATTERN = re.compile(r"^```(?:json)?\s*(?P<json>.*?)\s*```$", re.IGNORECASE | re.DOTALL)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AgentConfig:
|
||||
"""Runtime settings for the audiobook metadata agent."""
|
||||
|
||||
model: str = "deepseek-v4-flash:cloud"
|
||||
ollama_chat_url: str = "https://ollama.com/api/chat"
|
||||
http_timeout_seconds: int = 300
|
||||
max_agent_turns: int = 8
|
||||
max_tool_results: int = 10
|
||||
min_confidence: float = 0.85
|
||||
invalid_final_retries: int = 1
|
||||
standalone_series: str = "standalone"
|
||||
tool_names: tuple[str, ...] = (
|
||||
"search_authors",
|
||||
"search_series",
|
||||
"search_books",
|
||||
"ensure_author",
|
||||
"ensure_series",
|
||||
"ensure_book",
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class StandardBookMetadata:
|
||||
"""Canonical metadata for the final audiobook path."""
|
||||
|
||||
author_id: int
|
||||
author: str
|
||||
book_id: int | None
|
||||
title: str
|
||||
series_id: int | None
|
||||
series: str
|
||||
series_index: float
|
||||
confidence: float
|
||||
needs_review: bool
|
||||
evidence: list[str]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FinalMetadataFields:
|
||||
"""Raw model fields after schema validation."""
|
||||
|
||||
author_id: int
|
||||
book_id: int | None
|
||||
title: str
|
||||
series_id: int | None
|
||||
series_index: float
|
||||
confidence: float
|
||||
evidence: list[str]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ResolvedBookFields:
|
||||
"""Book fields after optional catalog book resolution."""
|
||||
|
||||
book_id: int | None
|
||||
title: str
|
||||
series_id: int | None
|
||||
series_index: float
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AgentStepResult:
|
||||
"""Outcome from one model response."""
|
||||
|
||||
metadata: StandardBookMetadata | None
|
||||
invalid_final_count: int
|
||||
should_continue: bool
|
||||
|
||||
|
||||
def standard_book_metadata(
|
||||
aax_file_name: str,
|
||||
aax_metadata_from_ffprobe: dict[str, str],
|
||||
engine: Engine,
|
||||
log_path: Path,
|
||||
ollama_api_key: str,
|
||||
config: AgentConfig,
|
||||
) -> StandardBookMetadata:
|
||||
"""Resolve canonical audiobook metadata with the configured Ollama Cloud model."""
|
||||
with Session(engine) as session:
|
||||
registry = CatalogToolRegistry(session, log_path, config, write_agent_log)
|
||||
agent = AudiobookMetadataAgent(
|
||||
registry=registry, log_path=log_path, ollama_api_key=ollama_api_key, config=config
|
||||
)
|
||||
metadata = agent.run(aax_file_name, aax_metadata_from_ffprobe)
|
||||
if metadata.needs_review:
|
||||
session.rollback()
|
||||
else:
|
||||
registry.prune_unused_created_rows(
|
||||
author_id=metadata.author_id,
|
||||
book_id=metadata.book_id,
|
||||
series_id=metadata.series_id,
|
||||
)
|
||||
session.commit()
|
||||
return metadata
|
||||
|
||||
|
||||
class AudiobookMetadataAgent:
|
||||
"""Ollama-backed metadata resolver with a fixed local tool registry."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
registry: CatalogToolRegistry,
|
||||
log_path: Path,
|
||||
ollama_api_key: str,
|
||||
config: AgentConfig,
|
||||
) -> None:
|
||||
"""Create an Ollama metadata agent."""
|
||||
self._registry = registry
|
||||
self._log_path = log_path
|
||||
self._ollama_api_key = ollama_api_key
|
||||
self._config = config
|
||||
|
||||
def run(self, aax_file_name: str, aax_metadata_from_ffprobe: dict[str, str]) -> StandardBookMetadata:
|
||||
"""Resolve metadata for one AAX file."""
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt()},
|
||||
{"role": "user", "content": user_prompt(aax_file_name, aax_metadata_from_ffprobe)},
|
||||
]
|
||||
invalid_final_count = 0
|
||||
result: StandardBookMetadata | None = None
|
||||
|
||||
for turn in range(1, self._config.max_agent_turns + 1):
|
||||
step = self.run_step(messages, turn, invalid_final_count)
|
||||
invalid_final_count = step.invalid_final_count
|
||||
if step.should_continue:
|
||||
continue
|
||||
result = step.metadata
|
||||
break
|
||||
|
||||
if result is None:
|
||||
return self.force_final_response(messages)
|
||||
return result
|
||||
|
||||
def run_step(
|
||||
self,
|
||||
messages: list[dict[str, object]],
|
||||
turn: int,
|
||||
invalid_final_count: int,
|
||||
) -> AgentStepResult:
|
||||
"""Run one model turn and return the next agent-loop action."""
|
||||
data = self.chat(messages, turn)
|
||||
message = data.get("message")
|
||||
if not isinstance(message, dict):
|
||||
return AgentStepResult(
|
||||
metadata=review_metadata("Ollama response did not include a message", self._config),
|
||||
invalid_final_count=invalid_final_count,
|
||||
should_continue=False,
|
||||
)
|
||||
|
||||
try:
|
||||
tool_calls = parse_tool_calls(message)
|
||||
except (json.JSONDecodeError, MetadataResolutionError) as error:
|
||||
return AgentStepResult(
|
||||
metadata=review_metadata(str(error), self._config),
|
||||
invalid_final_count=invalid_final_count,
|
||||
should_continue=False,
|
||||
)
|
||||
if tool_calls:
|
||||
fatal_error = run_tool_calls(messages, message, tool_calls, self._registry, self._log_path, write_agent_log)
|
||||
if fatal_error is not None:
|
||||
return AgentStepResult(
|
||||
metadata=review_metadata(fatal_error, self._config),
|
||||
invalid_final_count=invalid_final_count,
|
||||
should_continue=False,
|
||||
)
|
||||
return AgentStepResult(metadata=None, invalid_final_count=invalid_final_count, should_continue=True)
|
||||
return self.handle_final_message(messages, message, invalid_final_count)
|
||||
|
||||
def handle_final_message(
|
||||
self,
|
||||
messages: list[dict[str, object]],
|
||||
message: dict[str, object],
|
||||
invalid_final_count: int,
|
||||
) -> AgentStepResult:
|
||||
"""Validate a final model message or request one retry."""
|
||||
content = message.get("content")
|
||||
if not isinstance(content, str):
|
||||
return AgentStepResult(
|
||||
metadata=review_metadata("Ollama final response did not include string content", self._config),
|
||||
invalid_final_count=invalid_final_count,
|
||||
should_continue=False,
|
||||
)
|
||||
|
||||
try:
|
||||
resolved = self.validate_final(parse_final_json_content(content))
|
||||
except (json.JSONDecodeError, MetadataResolutionError) as error:
|
||||
return self.handle_invalid_final(messages, error, invalid_final_count)
|
||||
|
||||
write_agent_log(self._log_path, "final_metadata", metadata=resolved)
|
||||
return AgentStepResult(metadata=resolved, invalid_final_count=invalid_final_count, should_continue=False)
|
||||
|
||||
def handle_invalid_final(
|
||||
self,
|
||||
messages: list[dict[str, object]],
|
||||
error: json.JSONDecodeError | MetadataResolutionError,
|
||||
invalid_final_count: int,
|
||||
) -> AgentStepResult:
|
||||
"""Log invalid final JSON and either retry or return review metadata."""
|
||||
invalid_final_count += 1
|
||||
write_agent_log(
|
||||
self._log_path,
|
||||
"final_validation_error",
|
||||
error=str(error),
|
||||
invalid_final_count=invalid_final_count,
|
||||
)
|
||||
if invalid_final_count > self._config.invalid_final_retries:
|
||||
return AgentStepResult(
|
||||
metadata=review_metadata(str(error), self._config),
|
||||
invalid_final_count=invalid_final_count,
|
||||
should_continue=False,
|
||||
)
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"Your previous final answer was invalid. Return only valid JSON matching the required "
|
||||
f"schema. Validation error: {error}"
|
||||
),
|
||||
},
|
||||
)
|
||||
return AgentStepResult(metadata=None, invalid_final_count=invalid_final_count, should_continue=True)
|
||||
|
||||
def force_final_response(self, messages: list[dict[str, object]]) -> StandardBookMetadata:
|
||||
"""Request a no-tool final answer after the normal turn limit."""
|
||||
messages.append({"role": "user", "content": forced_final_prompt()})
|
||||
write_agent_log(self._log_path, "forced_final_request", reason="max_turns")
|
||||
data = self.chat(messages, self._config.max_agent_turns + 1, tools_enabled=False)
|
||||
message = data.get("message")
|
||||
if not isinstance(message, dict):
|
||||
return review_metadata("Ollama forced final response did not include a message", self._config)
|
||||
content = message.get("content")
|
||||
if not isinstance(content, str):
|
||||
return review_metadata("Ollama forced final response did not include string content", self._config)
|
||||
try:
|
||||
resolved = self.validate_final(parse_final_json_content(content))
|
||||
except (json.JSONDecodeError, MetadataResolutionError) as error:
|
||||
return review_metadata(f"Ollama forced final response was invalid: {error}", self._config)
|
||||
write_agent_log(self._log_path, "final_metadata", metadata=resolved)
|
||||
return resolved
|
||||
|
||||
def chat(self, messages: list[dict[str, object]], turn: int, *, tools_enabled: bool = True) -> dict[str, object]:
|
||||
"""Send one chat request to Ollama and log the request and response."""
|
||||
payload = {
|
||||
"model": self._config.model,
|
||||
"messages": messages,
|
||||
"stream": False,
|
||||
"options": {"temperature": 0.1},
|
||||
}
|
||||
tool_names = []
|
||||
if tools_enabled:
|
||||
payload["tools"] = self._registry.tool_schemas()
|
||||
tool_names = self._config.tool_names
|
||||
write_agent_log(
|
||||
self._log_path,
|
||||
"model_request",
|
||||
model=self._config.model,
|
||||
turn=turn,
|
||||
message_count=len(messages),
|
||||
tool_names=tool_names,
|
||||
tools_enabled=tools_enabled,
|
||||
)
|
||||
write_agent_log(
|
||||
self._log_path,
|
||||
"llm_messages_sent",
|
||||
model=self._config.model,
|
||||
turn=turn,
|
||||
messages=messages,
|
||||
tools_enabled=tools_enabled,
|
||||
)
|
||||
response = httpx.post(
|
||||
self._config.ollama_chat_url,
|
||||
headers={"Authorization": f"Bearer {self._ollama_api_key}"},
|
||||
json=payload,
|
||||
timeout=self._config.http_timeout_seconds,
|
||||
)
|
||||
response.raise_for_status()
|
||||
raw_data = response.json()
|
||||
if not isinstance(raw_data, dict):
|
||||
return {}
|
||||
data = {str(key): value for key, value in raw_data.items()}
|
||||
message = data.get("message", {})
|
||||
content = message.get("content") if isinstance(message, dict) else ""
|
||||
write_agent_log(
|
||||
self._log_path,
|
||||
"llm_message_received",
|
||||
model=self._config.model,
|
||||
turn=turn,
|
||||
message=message,
|
||||
)
|
||||
write_agent_log(
|
||||
self._log_path,
|
||||
"model_response",
|
||||
model=self._config.model,
|
||||
turn=turn,
|
||||
has_tool_calls=bool(isinstance(message, dict) and message.get("tool_calls")),
|
||||
content_chars=len(content) if isinstance(content, str) else 0,
|
||||
)
|
||||
return data
|
||||
|
||||
def validate_final(self, raw_metadata: object) -> StandardBookMetadata:
|
||||
"""Validate final model metadata against catalog rows."""
|
||||
fields = parse_final_metadata_fields(raw_metadata)
|
||||
fields = replace(fields, title=normalize_title_slug(fields.title))
|
||||
author = self.validate_author(fields.author_id)
|
||||
validate_title_slug(fields.title)
|
||||
book_fields = self.resolve_book_fields(fields)
|
||||
series = self.validate_series(fields.author_id, book_fields.series_id, book_fields.series_index)
|
||||
|
||||
return StandardBookMetadata(
|
||||
author_id=fields.author_id,
|
||||
author=author.name,
|
||||
book_id=book_fields.book_id,
|
||||
title=book_fields.title,
|
||||
series_id=book_fields.series_id,
|
||||
series=series,
|
||||
series_index=book_fields.series_index,
|
||||
confidence=fields.confidence,
|
||||
needs_review=fields.confidence < self._config.min_confidence,
|
||||
evidence=fields.evidence,
|
||||
)
|
||||
|
||||
def validate_author(self, author_id: int) -> AudiobookAuthor:
|
||||
"""Validate that an author id was seen and exists."""
|
||||
if author_id not in self._registry.seen_author_ids:
|
||||
msg = f"author_id {author_id} was not returned by search_authors"
|
||||
raise MetadataResolutionError(msg)
|
||||
author = self._registry.get_author(author_id)
|
||||
if author is None:
|
||||
msg = f"author_id {author_id} does not exist"
|
||||
raise MetadataResolutionError(msg)
|
||||
validate_catalog_slug(author.name, "author")
|
||||
return author
|
||||
|
||||
def resolve_book_fields(self, fields: FinalMetadataFields) -> ResolvedBookFields:
|
||||
"""Resolve final book fields from a seen book id or created book."""
|
||||
if fields.book_id is None:
|
||||
ensured = self._registry.ensure_book(
|
||||
fields.title,
|
||||
fields.author_id,
|
||||
fields.series_id,
|
||||
fields.series_index,
|
||||
)
|
||||
return ResolvedBookFields(
|
||||
book_id=ensured.book.id,
|
||||
title=ensured.book.title,
|
||||
series_id=ensured.book.series_id,
|
||||
series_index=ensured.book.series_index,
|
||||
)
|
||||
|
||||
if fields.book_id not in self._registry.seen_book_ids:
|
||||
msg = f"book_id {fields.book_id} was not returned by search_books"
|
||||
raise MetadataResolutionError(msg)
|
||||
book = self._registry.get_book(fields.book_id)
|
||||
if book is None:
|
||||
msg = f"book_id {fields.book_id} does not exist"
|
||||
raise MetadataResolutionError(msg)
|
||||
if book.author_id != fields.author_id:
|
||||
msg = f"book_id {fields.book_id} does not belong to author_id {fields.author_id}"
|
||||
raise MetadataResolutionError(msg)
|
||||
return ResolvedBookFields(
|
||||
book_id=fields.book_id,
|
||||
title=book.title,
|
||||
series_id=book.series_id,
|
||||
series_index=book.series_index,
|
||||
)
|
||||
|
||||
def validate_series(self, author_id: int, series_id: int | None, series_index: float) -> str:
|
||||
"""Validate final series fields and return the canonical series slug."""
|
||||
if series_id is None:
|
||||
if series_index != 0:
|
||||
msg = "standalone books must use series_index 0"
|
||||
raise MetadataResolutionError(msg)
|
||||
return self._config.standalone_series
|
||||
|
||||
if series_id not in self._registry.seen_series_ids:
|
||||
msg = f"series_id {series_id} was not returned by search_series"
|
||||
raise MetadataResolutionError(msg)
|
||||
series = self._registry.get_series(series_id)
|
||||
if series is None:
|
||||
msg = f"series_id {series_id} does not exist"
|
||||
raise MetadataResolutionError(msg)
|
||||
if series.author_id != author_id:
|
||||
msg = f"series_id {series_id} does not belong to author_id {author_id}"
|
||||
raise MetadataResolutionError(msg)
|
||||
if series_index <= 0:
|
||||
msg = "series books must use a positive series_index"
|
||||
raise MetadataResolutionError(msg)
|
||||
validate_catalog_slug(series.name, "series")
|
||||
return series.name
|
||||
|
||||
|
||||
def write_agent_log(log_path: Path, event: str, **fields: object) -> None:
|
||||
"""Append one JSONL audit event."""
|
||||
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
record = {
|
||||
"created": utcnow().isoformat(),
|
||||
"event": event,
|
||||
**{key: json_log_value(value) for key, value in fields.items()},
|
||||
}
|
||||
with log_path.open("a", encoding="utf-8") as file:
|
||||
file.write(json.dumps(record, sort_keys=True))
|
||||
file.write("\n")
|
||||
|
||||
|
||||
def json_log_value(value: object) -> object:
|
||||
"""Return a JSON-serializable value for audit logs."""
|
||||
if is_dataclass(value) and not isinstance(value, type):
|
||||
return json_log_value(asdict(value))
|
||||
if isinstance(value, dict):
|
||||
return {str(key): json_log_value(item) for key, item in value.items()}
|
||||
if isinstance(value, list | tuple):
|
||||
return [json_log_value(item) for item in value]
|
||||
if isinstance(value, set):
|
||||
return [json_log_value(item) for item in sorted(value, key=str)]
|
||||
if isinstance(value, PathLike):
|
||||
return str(value)
|
||||
return value
|
||||
|
||||
|
||||
def system_prompt() -> str:
|
||||
"""Return the stable system prompt."""
|
||||
return """You standardize Audible audiobook metadata against a private catalog.
|
||||
|
||||
Rules:
|
||||
- You must use the provided tools before returning final metadata.
|
||||
- Only use author_id, series_id, or book_id values returned by tools.
|
||||
- Return final metadata as JSON only. Do not wrap it in Markdown.
|
||||
- The final JSON object must contain author_id, book_id, title, series_id, series_index, confidence, and evidence.
|
||||
- title must be a canonical title slug using lower-case words separated by hyphens.
|
||||
- Use series_id null and series_index 0 for standalone books.
|
||||
- If you use a series_id, series_index must be a whole number or .5 value greater than 0.
|
||||
- Treat series slugs that differ only by underscores as the same series. Prefer the existing catalog row instead of
|
||||
creating a new series.
|
||||
- Detect omnibus or box-set editions that contain multiple numbered novels, books, or novellas.
|
||||
- For an omnibus, make a best-effort range from the filename, tags, and catalog rows. Keep series_index as the
|
||||
first covered book number and include the range in the title when the source title includes it, for example
|
||||
books-1-3.
|
||||
- Be careful with omnibuses of novels or novellas later published as one book: keep the omnibus as the audiobook's
|
||||
book record unless catalog rows clearly identify a better match.
|
||||
- Do not create publisher collections or author collections as series unless the book metadata clearly gives a
|
||||
numbered series.
|
||||
- Series belong to authors. Use a series_id only when it belongs to the selected author_id.
|
||||
- Always search for the author before creating one. If no exact author slug exists, call ensure_author.
|
||||
- Always search for a series with author_id before creating one. If no exact series slug exists, call ensure_series.
|
||||
- Always search for a book before creating one. If no exact title slug exists, call ensure_book.
|
||||
- If a tool returns an error, correct your tool arguments or final metadata before continuing.
|
||||
- confidence must be a number from 0 to 1.
|
||||
- evidence must be a short list of strings explaining which filename, tags, and catalog rows support the answer."""
|
||||
|
||||
|
||||
def forced_final_prompt() -> str:
|
||||
"""Return the no-tools finalization prompt."""
|
||||
return (
|
||||
"Stop calling tools. Return final metadata as JSON only using the tool results already provided. "
|
||||
"If search_books returned no matching rows but author and series are known, use book_id null and resolve "
|
||||
"the title slug from the AAX filename and ffprobe tags. The validator will create the missing book. "
|
||||
"Use only author_id and series_id values returned by earlier tool results."
|
||||
)
|
||||
|
||||
|
||||
def user_prompt(aax_file_name: str, metadata: dict[str, str]) -> str:
|
||||
"""Build the user prompt from source metadata."""
|
||||
return (
|
||||
"Resolve this Audible audiobook.\n\n"
|
||||
f"AAX file name: {aax_file_name}\n\n"
|
||||
"ffprobe format tags:\n"
|
||||
f"{json.dumps(metadata, indent=2, sort_keys=True)}"
|
||||
)
|
||||
|
||||
|
||||
def parse_final_json_content(content: str) -> object:
|
||||
"""Parse final model content, accepting bare or fenced JSON."""
|
||||
stripped = content.strip()
|
||||
if match := FENCED_JSON_PATTERN.fullmatch(stripped):
|
||||
stripped = match.group("json").strip()
|
||||
return json.loads(stripped)
|
||||
|
||||
|
||||
def parse_final_metadata_fields(raw_metadata: object) -> FinalMetadataFields:
|
||||
"""Parse the model's final JSON object into typed fields."""
|
||||
if not isinstance(raw_metadata, dict):
|
||||
msg = "Final metadata must be a JSON object"
|
||||
raise MetadataResolutionError(msg)
|
||||
data = {str(key): value for key, value in raw_metadata.items()}
|
||||
return FinalMetadataFields(
|
||||
author_id=required_int(data, "author_id"),
|
||||
book_id=optional_int(data.get("book_id"), "book_id"),
|
||||
title=required_string(data, "title"),
|
||||
series_id=optional_int(data.get("series_id"), "series_id"),
|
||||
series_index=required_series_index(data, "series_index"),
|
||||
confidence=required_float(data, "confidence"),
|
||||
evidence=required_string_list(data, "evidence"),
|
||||
)
|
||||
|
||||
|
||||
def review_metadata(reason: str, config: AgentConfig) -> StandardBookMetadata:
|
||||
"""Return a metadata result that must be reviewed manually."""
|
||||
return StandardBookMetadata(
|
||||
author_id=0,
|
||||
author="unknown_author",
|
||||
book_id=None,
|
||||
title="unknown-title",
|
||||
series_id=None,
|
||||
series=config.standalone_series,
|
||||
series_index=0,
|
||||
confidence=0,
|
||||
needs_review=True,
|
||||
evidence=[reason],
|
||||
)
|
||||
|
||||
|
||||
def required_float(data: dict[str, object], key: str) -> float:
|
||||
"""Read a required float field."""
|
||||
value = data.get(key)
|
||||
if isinstance(value, bool) or not isinstance(value, int | float):
|
||||
msg = f"{key} must be a number"
|
||||
raise MetadataResolutionError(msg)
|
||||
confidence = float(value)
|
||||
if confidence < 0 or confidence > 1:
|
||||
msg = f"{key} must be between 0 and 1"
|
||||
raise MetadataResolutionError(msg)
|
||||
return confidence
|
||||
|
||||
|
||||
def required_string_list(data: dict[str, object], key: str) -> list[str]:
|
||||
"""Read a required list of strings."""
|
||||
value = data.get(key)
|
||||
if not isinstance(value, list) or not value or not all(isinstance(item, str) for item in value):
|
||||
msg = f"{key} must be a non-empty list of strings"
|
||||
raise MetadataResolutionError(msg)
|
||||
strings = [item.strip() for item in value if item.strip()]
|
||||
if not strings:
|
||||
msg = f"{key} must include at least one non-empty string"
|
||||
raise MetadataResolutionError(msg)
|
||||
return strings
|
||||
@@ -28,14 +28,7 @@
|
||||
networking = {
|
||||
hostName = "bob";
|
||||
hostId = "7c678a41";
|
||||
firewall = {
|
||||
enable = true;
|
||||
allowedTCPPorts = [
|
||||
8000
|
||||
8001
|
||||
8002
|
||||
];
|
||||
};
|
||||
firewall.enable = true;
|
||||
networkmanager.enable = true;
|
||||
};
|
||||
|
||||
|
||||
@@ -30,11 +30,6 @@
|
||||
keyFile = "/dev/disk/by-id/usb-Samsung_Flash_Drive_FIT_0374620080067131-0:0";
|
||||
};
|
||||
};
|
||||
|
||||
zfs.extraPools = [
|
||||
"storage"
|
||||
];
|
||||
|
||||
kernelModules = [ "kvm-amd" ];
|
||||
extraModulePackages = [ ];
|
||||
};
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
host = "0.0.0.0";
|
||||
enable = true;
|
||||
|
||||
syncModels = false;
|
||||
syncModels = true;
|
||||
loadModels = [
|
||||
"codellama:7b"
|
||||
"deepscaler:1.5b"
|
||||
|
||||
@@ -17,9 +17,6 @@
|
||||
allowedTCPPorts = [ ];
|
||||
allowedUDPPorts = [ ];
|
||||
};
|
||||
allowedTCPPorts = [
|
||||
8070
|
||||
];
|
||||
};
|
||||
useNetworkd = true;
|
||||
};
|
||||
|
||||
@@ -6,7 +6,7 @@ in
|
||||
user = "ollama";
|
||||
enable = true;
|
||||
host = "0.0.0.0";
|
||||
syncModels = false;
|
||||
syncModels = true;
|
||||
loadModels = [
|
||||
"codellama:7b"
|
||||
"deepscaler:1.5b"
|
||||
@@ -30,9 +30,6 @@ in
|
||||
"ministral-3:14b"
|
||||
"nemotron-3-nano:30b"
|
||||
"qwen3-coder:30b"
|
||||
"qwen3-embedding:0.6b"
|
||||
"qwen3-embedding:4b"
|
||||
"qwen3-embedding:8b"
|
||||
"qwen3-vl:32b"
|
||||
"qwen3:14b"
|
||||
"qwen3.5:35b"
|
||||
|
||||
@@ -38,6 +38,9 @@ in
|
||||
# signalbot
|
||||
local signalbot signalbot trust
|
||||
|
||||
# hedgedoc
|
||||
local hedgedoc hedgedoc trust
|
||||
|
||||
# math
|
||||
local postgres math trust
|
||||
host postgres math 127.0.0.1/32 trust
|
||||
@@ -117,11 +120,19 @@ in
|
||||
login = true;
|
||||
};
|
||||
}
|
||||
{
|
||||
name = "hedgedoc";
|
||||
ensureDBOwnership = true;
|
||||
ensureClauses = {
|
||||
login = true;
|
||||
};
|
||||
}
|
||||
];
|
||||
ensureDatabases = [
|
||||
"data_science_dev"
|
||||
"hass"
|
||||
"gitea"
|
||||
"hedgedoc"
|
||||
"math"
|
||||
"n8n"
|
||||
"richie"
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
"""Shared test fixtures."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine, event
|
||||
|
||||
from python.orm.signal_bot.base import SignalBotBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def sqlite_engine() -> Generator[Engine]:
|
||||
"""Create an in-memory SQLite engine for testing."""
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
|
||||
@event.listens_for(engine, "connect")
|
||||
def _set_sqlite_pragma(dbapi_connection, _connection_record):
|
||||
cursor = dbapi_connection.cursor()
|
||||
cursor.execute("PRAGMA foreign_keys=ON")
|
||||
cursor.close()
|
||||
|
||||
SignalBotBase.metadata.create_all(engine)
|
||||
yield engine
|
||||
engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def engine(sqlite_engine: Engine) -> Generator[Engine]:
|
||||
"""Yield the shared engine after cleaning all tables between tests."""
|
||||
yield sqlite_engine
|
||||
with sqlite_engine.begin() as connection:
|
||||
for table in reversed(SignalBotBase.metadata.sorted_tables):
|
||||
connection.execute(table.delete())
|
||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user