Compare commits

..

24 Commits

Author SHA1 Message Date
Richie 09d847e66a add .ebook_search_bm25 to gitignore
treefmt / nix fmt (pull_request) Failing after 6s
build_systems / build-brain (pull_request) Successful in 51s
build_systems / build-rhapsody-in-green (pull_request) Successful in 1m7s
pytest / pytest (pull_request) Successful in 31s
build_systems / build-bob (pull_request) Successful in 52s
build_systems / build-leviathan (pull_request) Successful in 1m5s
build_systems / build-jeeves (pull_request) Successful in 2m43s
2026-06-12 13:11:17 -04:00
Richie b68db3bcaf updated python
treefmt / nix fmt (pull_request) Failing after 6s
pytest / pytest (pull_request) Successful in 30s
build_systems / build-brain (pull_request) Successful in 55s
build_systems / build-bob (pull_request) Successful in 59s
build_systems / build-leviathan (pull_request) Successful in 1m33s
build_systems / build-rhapsody-in-green (pull_request) Successful in 1m38s
build_systems / build-jeeves (pull_request) Successful in 2m35s
2026-06-12 03:12:26 -04:00
Richie 41592491dc setup tests
treefmt / nix fmt (pull_request) Failing after 7s
pytest / pytest (pull_request) Failing after 12s
build_systems / build-brain (pull_request) Successful in 44s
build_systems / build-bob (pull_request) Successful in 46s
build_systems / build-leviathan (pull_request) Successful in 54s
build_systems / build-rhapsody-in-green (pull_request) Successful in 59s
build_systems / build-jeeves (pull_request) Successful in 2m29s
2026-06-12 03:10:53 -04:00
Richie 9850a13dbe build api and frountend 2026-06-12 03:10:19 -04:00
Richie 258919c461 added answer.py and config 2026-06-12 03:09:51 -04:00
Richie a2a5f145dc added __init__ 2026-06-12 03:09:10 -04:00
Richie f5368879c9 made llm_interface.py 2026-06-12 03:08:21 -04:00
Richie 0325013323 added rerank 2026-06-12 03:06:27 -04:00
Richie e3266ac694 built ingest 2026-06-12 03:06:16 -04:00
Richie 7010a4f3b9 built rag search setup 2026-06-12 02:45:44 -04:00
Richie e8f5a3d334 set up embedding system 2026-06-12 01:57:36 -04:00
Richie fc0fd6e3f5 built BM25 search foundation 2026-06-12 01:30:20 -04:00
Richie e1078bc084 clean up 2026-06-11 22:08:48 -04:00
Richie 8354d5613c added ebook embedding to orm 2026-06-11 18:08:31 -04:00
Richie 978594bc61 removed hedgedoc 2026-06-10 20:18:16 -04:00
Richie 799f03d90e adding embedding Models to jeeves 2026-06-10 20:16:59 -04:00
Richie 74e9d84f57 updated series_index to float and added UniqueConstraint to audiobook and audiobook_author
treefmt / nix fmt (pull_request) Failing after 6s
pytest / pytest (pull_request) Successful in 30s
build_systems / build-brain (pull_request) Successful in 47s
build_systems / build-bob (pull_request) Successful in 48s
build_systems / build-leviathan (pull_request) Successful in 55s
build_systems / build-rhapsody-in-green (pull_request) Successful in 1m1s
build_systems / build-jeeves (pull_request) Successful in 2m37s
2026-06-10 20:07:27 -04:00
Richie 65c4dc0f18 fixed omnibus for audio books
treefmt / nix fmt (pull_request) Failing after 7s
pytest / pytest (pull_request) Successful in 29s
build_systems / build-brain (pull_request) Successful in 48s
build_systems / build-bob (pull_request) Successful in 50s
build_systems / build-leviathan (pull_request) Successful in 54s
build_systems / build-rhapsody-in-green (pull_request) Successful in 1m2s
build_systems / build-jeeves (pull_request) Successful in 2m31s
2026-06-10 16:33:16 -04:00
Richie d111e27cad moved installer to python dir 2026-06-10 15:15:59 -04:00
Richie f01d2cc001 deleted frontend dir 2026-06-10 15:15:46 -04:00
Richie cf8c91635a added llm_tool_calling.py
treefmt / nix fmt (pull_request) Successful in 6s
pytest / pytest (pull_request) Successful in 27s
build_systems / build-bob (pull_request) Successful in 53s
build_systems / build-brain (pull_request) Successful in 53s
build_systems / build-leviathan (pull_request) Successful in 56s
build_systems / build-rhapsody-in-green (pull_request) Successful in 1m6s
build_systems / build-jeeves (pull_request) Successful in 2m41s
2026-06-07 10:39:29 -04:00
Richie 79162b65d4 built workflow 2026-06-07 10:39:29 -04:00
Richie 67a95657a2 Add catalog.py for manually adding authors and series to the database. 2026-06-07 10:33:26 -04:00
Richie a31b83b9c9 adding audiobook data to DB 2026-06-07 10:33:26 -04:00
71 changed files with 9166 additions and 582 deletions
-1
View File
@@ -8,7 +8,6 @@ jobs:
lockfile:
runs-on: self-hosted
permissions:
actions: write
contents: write
pull-requests: write
steps:
+12
View File
@@ -0,0 +1,12 @@
## Dev environment tips
- use treefmt to format all files
- make python code ruff compliant
- use pytest to test python code
- always use the minimum amount of complexity
- if judgment calls are easy to reverse make them. if not ask me first
- Match existing code style.
- Use builtin helpers getenv() over os.environ.get.
- Prefer single-purpose functions over “do everything” helpers.
- Avoid compatibility branches like PG_USER and POSTGRESQL_URL unless requested.
- Keep helpers only if reused or they simplify the code otherwise inline.
File diff suppressed because one or more lines are too long
Generated
+15 -15
View File
@@ -8,11 +8,11 @@
},
"locked": {
"dir": "pkgs/firefox-addons",
"lastModified": 1781150628,
"narHash": "sha256-b4mp8l3qWuSCyYYo9HSngDtcB3PpecYiOXjULrjwwlw=",
"lastModified": 1780733803,
"narHash": "sha256-QBJPq12P1DAXFGezoEJaSO/xPUrPlnaI3ddSaMG2JpM=",
"owner": "rycee",
"repo": "nur-expressions",
"rev": "753319310f4673a2dabbfab87482187b40bf9bac",
"rev": "c80b0aa94392c5f3612ac797108f6d952752036d",
"type": "gitlab"
},
"original": {
@@ -29,11 +29,11 @@
]
},
"locked": {
"lastModified": 1781189114,
"narHash": "sha256-5inaamLgUMWy+MOBE9ChF9QAF1o/74LFuHkI0W/9rqc=",
"lastModified": 1780679734,
"narHash": "sha256-KmRNvpNOb7QEORa06bVgjW9kITcx0VhsI7w0vhmZyD8=",
"owner": "nix-community",
"repo": "home-manager",
"rev": "486595d2cf49cfcd649b58a284fa11ac0e34da22",
"rev": "b2b7db486e06e098711dc291bb25db82850e1d16",
"type": "github"
},
"original": {
@@ -47,11 +47,11 @@
"nixpkgs": "nixpkgs"
},
"locked": {
"lastModified": 1781168557,
"narHash": "sha256-LOnLQ2tpYF9gqIDDr3+j3DbpJJr/QCH6zPRT2GzEUOE=",
"lastModified": 1780310866,
"narHash": "sha256-fPBRVf6A5xlACYcOI59shGrjURuvwu0lRsDoSCEXt/I=",
"owner": "nixos",
"repo": "nixos-hardware",
"rev": "6358ff76821101c178e3ab4919a62799bfe3652e",
"rev": "4ed851c979641e28597a05086332d75cdc9e395f",
"type": "github"
},
"original": {
@@ -76,11 +76,11 @@
},
"nixpkgs-master": {
"locked": {
"lastModified": 1781229721,
"narHash": "sha256-ORvqDbb/LYxiJljGIejapjkc/kJbVote2N1WSb9W45I=",
"lastModified": 1780798858,
"narHash": "sha256-4KLc5ZMjfMQosXA2JasUgZTk3i+c/i1zMH4custtmI0=",
"owner": "nixos",
"repo": "nixpkgs",
"rev": "173d0ad7a974f8543a9ab01d2271b2e290341b33",
"rev": "92840095e65b9970125843175f4be974b71a92ad",
"type": "github"
},
"original": {
@@ -108,11 +108,11 @@
},
"nixpkgs_2": {
"locked": {
"lastModified": 1781074563,
"narHash": "sha256-md8WlXOlfnIeHeOScMTTHFyf2d6iaTwPl2apR5EQ3P4=",
"lastModified": 1780243769,
"narHash": "sha256-x5UQuRsH3MqI0U9afaXSNqzTPSeZlRLvFAav2Ux1pNw=",
"owner": "nixos",
"repo": "nixpkgs",
"rev": "9ae611a455b90cf061d8f332b977e387bda8e1ca",
"rev": "331800de5053fcebacf6813adb5db9c9dca22a0c",
"type": "github"
},
"original": {
+3
View File
@@ -84,6 +84,9 @@ lint.ignore = [
"python/alembic/**" = [
"INP001", # (perm) this creates LSP issues for alembic
]
"python/signal_bot/**" = [
"D107", # (perm) class docstrings cover __init__
]
[tool.ruff.lint.pydocstyle]
convention = "google"
@@ -0,0 +1,50 @@
"""adding FailedIngestion.
Revision ID: 2f43120e3ffc
Revises: f99be864fe69
Create Date: 2026-03-24 23:46:17.277897
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import sqlalchemy as sa
from alembic import op
from python.orm import DataScienceDevBase
if TYPE_CHECKING:
from collections.abc import Sequence
# revision identifiers, used by Alembic.
revision: str = "2f43120e3ffc"
down_revision: str | None = "f99be864fe69"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
schema = DataScienceDevBase.schema_name
def upgrade() -> None:
"""Upgrade."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"failed_ingestion",
sa.Column("raw_line", sa.Text(), nullable=False),
sa.Column("error", sa.Text(), nullable=False),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.PrimaryKeyConstraint("id", name=op.f("pk_failed_ingestion")),
schema=schema,
)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("failed_ingestion", schema=schema)
# ### end Alembic commands ###
@@ -0,0 +1,72 @@
"""Attach all partition tables to the posts parent table.
Alembic autogenerate creates partition tables as standalone tables but does not
emit the ALTER TABLE ... ATTACH PARTITION statements needed for PostgreSQL to
route inserts to the correct partition.
Revision ID: a1b2c3d4e5f6
Revises: 605b1794838f
Create Date: 2026-03-25 10:00:00.000000
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from alembic import op
from sqlalchemy import text
from python.orm import DataScienceDevBase
from python.orm.data_science_dev.posts.partitions import (
PARTITION_END_YEAR,
PARTITION_START_YEAR,
iso_weeks_in_year,
week_bounds,
)
if TYPE_CHECKING:
from collections.abc import Sequence
# revision identifiers, used by Alembic.
revision: str = "a1b2c3d4e5f6"
down_revision: str | None = "605b1794838f"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
schema = DataScienceDevBase.schema_name
ALREADY_ATTACHED_QUERY = text("""
SELECT inhrelid::regclass::text
FROM pg_inherits
WHERE inhparent = :parent::regclass
""")
def upgrade() -> None:
"""Attach all weekly partition tables to the posts parent table."""
connection = op.get_bind()
already_attached = {row[0] for row in connection.execute(ALREADY_ATTACHED_QUERY, {"parent": f"{schema}.posts"})}
for year in range(PARTITION_START_YEAR, PARTITION_END_YEAR + 1):
for week in range(1, iso_weeks_in_year(year) + 1):
table_name = f"posts_{year}_{week:02d}"
qualified_name = f"{schema}.{table_name}"
if qualified_name in already_attached:
continue
start, end = week_bounds(year, week)
start_str = start.strftime("%Y-%m-%d %H:%M:%S")
end_str = end.strftime("%Y-%m-%d %H:%M:%S")
op.execute(
f"ALTER TABLE {schema}.posts "
f"ATTACH PARTITION {qualified_name} "
f"FOR VALUES FROM ('{start_str}') TO ('{end_str}')"
)
def downgrade() -> None:
"""Detach all weekly partition tables from the posts parent table."""
for year in range(PARTITION_START_YEAR, PARTITION_END_YEAR + 1):
for week in range(1, iso_weeks_in_year(year) + 1):
table_name = f"posts_{year}_{week:02d}"
op.execute(f"ALTER TABLE {schema}.posts DETACH PARTITION {schema}.{table_name}")
@@ -0,0 +1,153 @@
"""adding congress data.
Revision ID: 83bfc8af92d8
Revises: a1b2c3d4e5f6
Create Date: 2026-03-27 10:43:02.324510
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import sqlalchemy as sa
from alembic import op
from python.orm import DataScienceDevBase
if TYPE_CHECKING:
from collections.abc import Sequence
# revision identifiers, used by Alembic.
revision: str = "83bfc8af92d8"
down_revision: str | None = "a1b2c3d4e5f6"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
schema = DataScienceDevBase.schema_name
def upgrade() -> None:
"""Upgrade."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"bill",
sa.Column("congress", sa.Integer(), nullable=False),
sa.Column("bill_type", sa.String(), nullable=False),
sa.Column("number", sa.Integer(), nullable=False),
sa.Column("title", sa.String(), nullable=True),
sa.Column("title_short", sa.String(), nullable=True),
sa.Column("official_title", sa.String(), nullable=True),
sa.Column("status", sa.String(), nullable=True),
sa.Column("status_at", sa.Date(), nullable=True),
sa.Column("sponsor_bioguide_id", sa.String(), nullable=True),
sa.Column("subjects_top_term", sa.String(), nullable=True),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.PrimaryKeyConstraint("id", name=op.f("pk_bill")),
sa.UniqueConstraint("congress", "bill_type", "number", name="uq_bill_congress_type_number"),
schema=schema,
)
op.create_index("ix_bill_congress", "bill", ["congress"], unique=False, schema=schema)
op.create_table(
"legislator",
sa.Column("bioguide_id", sa.Text(), nullable=False),
sa.Column("thomas_id", sa.String(), nullable=True),
sa.Column("lis_id", sa.String(), nullable=True),
sa.Column("govtrack_id", sa.Integer(), nullable=True),
sa.Column("opensecrets_id", sa.String(), nullable=True),
sa.Column("fec_ids", sa.String(), nullable=True),
sa.Column("first_name", sa.String(), nullable=False),
sa.Column("last_name", sa.String(), nullable=False),
sa.Column("official_full_name", sa.String(), nullable=True),
sa.Column("nickname", sa.String(), nullable=True),
sa.Column("birthday", sa.Date(), nullable=True),
sa.Column("gender", sa.String(), nullable=True),
sa.Column("current_party", sa.String(), nullable=True),
sa.Column("current_state", sa.String(), nullable=True),
sa.Column("current_district", sa.Integer(), nullable=True),
sa.Column("current_chamber", sa.String(), nullable=True),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.PrimaryKeyConstraint("id", name=op.f("pk_legislator")),
schema=schema,
)
op.create_index(op.f("ix_legislator_bioguide_id"), "legislator", ["bioguide_id"], unique=True, schema=schema)
op.create_table(
"bill_text",
sa.Column("bill_id", sa.Integer(), nullable=False),
sa.Column("version_code", sa.String(), nullable=False),
sa.Column("version_name", sa.String(), nullable=True),
sa.Column("text_content", sa.String(), nullable=True),
sa.Column("date", sa.Date(), nullable=True),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.ForeignKeyConstraint(
["bill_id"], [f"{schema}.bill.id"], name=op.f("fk_bill_text_bill_id_bill"), ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("id", name=op.f("pk_bill_text")),
sa.UniqueConstraint("bill_id", "version_code", name="uq_bill_text_bill_id_version_code"),
schema=schema,
)
op.create_table(
"vote",
sa.Column("congress", sa.Integer(), nullable=False),
sa.Column("chamber", sa.String(), nullable=False),
sa.Column("session", sa.Integer(), nullable=False),
sa.Column("number", sa.Integer(), nullable=False),
sa.Column("vote_type", sa.String(), nullable=True),
sa.Column("question", sa.String(), nullable=True),
sa.Column("result", sa.String(), nullable=True),
sa.Column("result_text", sa.String(), nullable=True),
sa.Column("vote_date", sa.Date(), nullable=False),
sa.Column("yea_count", sa.Integer(), nullable=True),
sa.Column("nay_count", sa.Integer(), nullable=True),
sa.Column("not_voting_count", sa.Integer(), nullable=True),
sa.Column("present_count", sa.Integer(), nullable=True),
sa.Column("bill_id", sa.Integer(), nullable=True),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.ForeignKeyConstraint(["bill_id"], [f"{schema}.bill.id"], name=op.f("fk_vote_bill_id_bill")),
sa.PrimaryKeyConstraint("id", name=op.f("pk_vote")),
sa.UniqueConstraint("congress", "chamber", "session", "number", name="uq_vote_congress_chamber_session_number"),
schema=schema,
)
op.create_index("ix_vote_congress_chamber", "vote", ["congress", "chamber"], unique=False, schema=schema)
op.create_index("ix_vote_date", "vote", ["vote_date"], unique=False, schema=schema)
op.create_table(
"vote_record",
sa.Column("vote_id", sa.Integer(), nullable=False),
sa.Column("legislator_id", sa.Integer(), nullable=False),
sa.Column("position", sa.String(), nullable=False),
sa.ForeignKeyConstraint(
["legislator_id"],
[f"{schema}.legislator.id"],
name=op.f("fk_vote_record_legislator_id_legislator"),
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["vote_id"], [f"{schema}.vote.id"], name=op.f("fk_vote_record_vote_id_vote"), ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("vote_id", "legislator_id", name=op.f("pk_vote_record")),
schema=schema,
)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("vote_record", schema=schema)
op.drop_index("ix_vote_date", table_name="vote", schema=schema)
op.drop_index("ix_vote_congress_chamber", table_name="vote", schema=schema)
op.drop_table("vote", schema=schema)
op.drop_table("bill_text", schema=schema)
op.drop_index(op.f("ix_legislator_bioguide_id"), table_name="legislator", schema=schema)
op.drop_table("legislator", schema=schema)
op.drop_index("ix_bill_congress", table_name="bill", schema=schema)
op.drop_table("bill", schema=schema)
# ### end Alembic commands ###
@@ -0,0 +1,58 @@
"""adding LegislatorSocialMedia.
Revision ID: 5cd7eee3549d
Revises: 83bfc8af92d8
Create Date: 2026-03-29 11:53:44.224799
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import sqlalchemy as sa
from alembic import op
from python.orm import DataScienceDevBase
if TYPE_CHECKING:
from collections.abc import Sequence
# revision identifiers, used by Alembic.
revision: str = "5cd7eee3549d"
down_revision: str | None = "83bfc8af92d8"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
schema = DataScienceDevBase.schema_name
def upgrade() -> None:
"""Upgrade."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"legislator_social_media",
sa.Column("legislator_id", sa.Integer(), nullable=False),
sa.Column("platform", sa.String(), nullable=False),
sa.Column("account_name", sa.String(), nullable=False),
sa.Column("url", sa.String(), nullable=True),
sa.Column("source", sa.String(), nullable=False),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.ForeignKeyConstraint(
["legislator_id"],
[f"{schema}.legislator.id"],
name=op.f("fk_legislator_social_media_legislator_id_legislator"),
),
sa.PrimaryKeyConstraint("id", name=op.f("pk_legislator_social_media")),
schema=schema,
)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("legislator_social_media", schema=schema)
# ### end Alembic commands ###
@@ -1,54 +0,0 @@
"""add 1024 ebook embedding cosine index.
Revision ID: c460105682d2
Revises: 2db132cace1a
Create Date: 2026-06-13 19:53:45.680289
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from alembic import op
from python.orm import RichieBase
if TYPE_CHECKING:
from collections.abc import Sequence
# revision identifiers, used by Alembic.
revision: str = "c460105682d2"
down_revision: str | None = "2db132cace1a"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
schema = RichieBase.schema_name
def upgrade() -> None:
"""Upgrade."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_index(
"ix_ebook_chunk_embedding_1024_embedding_cosine",
"ebook_chunk_embedding_1024",
["embedding"],
unique=False,
schema=schema,
postgresql_using="hnsw",
postgresql_ops={"embedding": "vector_cosine_ops"},
)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(
"ix_ebook_chunk_embedding_1024_embedding_cosine",
table_name="ebook_chunk_embedding_1024",
schema=schema,
postgresql_using="hnsw",
postgresql_ops={"embedding": "vector_cosine_ops"},
)
# ### end Alembic commands ###
@@ -0,0 +1,100 @@
"""seprating signal_bot database.
Revision ID: 6eaf696e07a5
Revises:
Create Date: 2026-03-17 21:35:37.612672
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
from python.orm import SignalBotBase
if TYPE_CHECKING:
from collections.abc import Sequence
# revision identifiers, used by Alembic.
revision: str = "6eaf696e07a5"
down_revision: str | None = None
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
schema = SignalBotBase.schema_name
def upgrade() -> None:
"""Upgrade."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"dead_letter_message",
sa.Column("source", sa.String(), nullable=False),
sa.Column("message", sa.Text(), nullable=False),
sa.Column("received_at", sa.DateTime(timezone=True), nullable=False),
sa.Column(
"status", postgresql.ENUM("UNPROCESSED", "PROCESSED", name="message_status", schema=schema), nullable=False
),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.PrimaryKeyConstraint("id", name=op.f("pk_dead_letter_message")),
schema=schema,
)
op.create_table(
"role",
sa.Column("name", sa.String(length=50), nullable=False),
sa.Column("id", sa.SmallInteger(), nullable=False),
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.PrimaryKeyConstraint("id", name=op.f("pk_role")),
sa.UniqueConstraint("name", name=op.f("uq_role_name")),
schema=schema,
)
op.create_table(
"signal_device",
sa.Column("phone_number", sa.String(length=50), nullable=False),
sa.Column("safety_number", sa.String(), nullable=True),
sa.Column(
"trust_level",
postgresql.ENUM("VERIFIED", "UNVERIFIED", "BLOCKED", name="trust_level", schema=schema),
nullable=False,
),
sa.Column("last_seen", sa.DateTime(timezone=True), nullable=False),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.PrimaryKeyConstraint("id", name=op.f("pk_signal_device")),
sa.UniqueConstraint("phone_number", name=op.f("uq_signal_device_phone_number")),
schema=schema,
)
op.create_table(
"device_role",
sa.Column("device_id", sa.Integer(), nullable=False),
sa.Column("role_id", sa.SmallInteger(), nullable=False),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.ForeignKeyConstraint(
["device_id"], [f"{schema}.signal_device.id"], name=op.f("fk_device_role_device_id_signal_device")
),
sa.ForeignKeyConstraint(["role_id"], [f"{schema}.role.id"], name=op.f("fk_device_role_role_id_role")),
sa.PrimaryKeyConstraint("id", name=op.f("pk_device_role")),
sa.UniqueConstraint("device_id", "role_id", name="uq_device_role_device_role"),
schema=schema,
)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("device_role", schema=schema)
op.drop_table("signal_device", schema=schema)
op.drop_table("role", schema=schema)
op.drop_table("dead_letter_message", schema=schema)
# ### end Alembic commands ###
@@ -0,0 +1,72 @@
"""test.
Revision ID: 66bdd532bcab
Revises: 6eaf696e07a5
Create Date: 2026-03-18 19:21:14.561568
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
from python.orm import SignalBotBase
if TYPE_CHECKING:
from collections.abc import Sequence
# revision identifiers, used by Alembic.
revision: str = "66bdd532bcab"
down_revision: str | None = "6eaf696e07a5"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
schema = SignalBotBase.schema_name
def upgrade() -> None:
"""Upgrade."""
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"dead_letter_message",
"status",
existing_type=postgresql.ENUM("UNPROCESSED", "PROCESSED", name="message_status", schema=schema),
type_=sa.Enum("UNPROCESSED", "PROCESSED", name="message_status", native_enum=False),
existing_nullable=False,
schema=schema,
)
op.alter_column(
"signal_device",
"trust_level",
existing_type=postgresql.ENUM("VERIFIED", "UNVERIFIED", "BLOCKED", name="trust_level", schema=schema),
type_=sa.Enum("VERIFIED", "UNVERIFIED", "BLOCKED", name="trust_level", native_enum=False),
existing_nullable=False,
schema=schema,
)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade."""
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"signal_device",
"trust_level",
existing_type=sa.Enum("VERIFIED", "UNVERIFIED", "BLOCKED", name="trust_level", native_enum=False),
type_=postgresql.ENUM("VERIFIED", "UNVERIFIED", "BLOCKED", name="trust_level", schema=schema),
existing_nullable=False,
schema=schema,
)
op.alter_column(
"dead_letter_message",
"status",
existing_type=sa.Enum("UNPROCESSED", "PROCESSED", name="message_status", native_enum=False),
type_=postgresql.ENUM("UNPROCESSED", "PROCESSED", name="message_status", schema=schema),
existing_nullable=False,
schema=schema,
)
# ### end Alembic commands ###
+1 -1
View File
@@ -9,9 +9,9 @@ import typer
import uvicorn
from fastapi import FastAPI
from python.api.middleware import ZstdMiddleware
from python.api.routers import contact_router, views_router
from python.common import configure_logger
from python.fastapi_tools import ZstdMiddleware
from python.orm.common import get_postgres_engine
logger = logging.getLogger(__name__)
@@ -1,4 +1,4 @@
"""Zstd response compression middleware."""
"""Middleware for the FastAPI application."""
from compression import zstd
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
+1 -1
View File
@@ -9,7 +9,7 @@ from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from python.fastapi_tools.db import DbSession
from python.api.dependencies import DbSession
from python.orm.richie.contact import Contact, ContactRelationship, Need, RelationshipType
TEMPLATES_DIR = Path(__file__).parent.parent / "templates"
+1 -1
View File
@@ -9,7 +9,7 @@ from fastapi.templating import Jinja2Templates
from sqlalchemy import select
from sqlalchemy.orm import Session, selectinload
from python.fastapi_tools.db import DbSession
from python.api.dependencies import DbSession
from python.orm.richie.contact import Contact, ContactRelationship, Need, RelationshipType
TEMPLATES_DIR = Path(__file__).parent.parent / "templates"
+3
View File
@@ -0,0 +1,3 @@
"""Data science CLI tools."""
from __future__ import annotations
+613
View File
@@ -0,0 +1,613 @@
"""Ingestion pipeline for loading congress data from unitedstates/congress JSON files.
Loads legislators, bills, votes, vote records, and bill text into the data_science_dev database.
Expects the parent directory to contain congress-tracker/ and congress-legislators/ as siblings.
Usage:
ingest-congress /path/to/parent/
ingest-congress /path/to/parent/ --congress 118
ingest-congress /path/to/parent/ --congress 118 --only bills
"""
from __future__ import annotations
import logging
from pathlib import Path # noqa: TC003 needed at runtime for typer CLI argument
from typing import TYPE_CHECKING, Annotated
import orjson
import typer
import yaml
from sqlalchemy import select
from sqlalchemy.orm import Session
from python.common import configure_logger
from python.orm.common import get_postgres_engine
from python.orm.data_science_dev.congress import Bill, BillText, Legislator, LegislatorSocialMedia, Vote, VoteRecord
if TYPE_CHECKING:
from collections.abc import Iterator
from sqlalchemy.engine import Engine
logger = logging.getLogger(__name__)
BATCH_SIZE = 10_000
app = typer.Typer(help="Ingest unitedstates/congress data into data_science_dev.")
@app.command()
def main(
parent_dir: Annotated[
Path,
typer.Argument(help="Parent directory containing congress-tracker/ and congress-legislators/"),
],
congress: Annotated[int | None, typer.Option(help="Only ingest a specific congress number")] = None,
only: Annotated[
str | None,
typer.Option(help="Only run a specific step: legislators, social-media, bills, votes, bill-text"),
] = None,
) -> None:
"""Ingest congress data from unitedstates/congress JSON files."""
configure_logger(level="INFO")
data_dir = parent_dir / "congress-tracker/congress/data/"
legislators_dir = parent_dir / "congress-legislators"
if not data_dir.is_dir():
typer.echo(f"Expected congress-tracker/ directory: {data_dir}", err=True)
raise typer.Exit(code=1)
if not legislators_dir.is_dir():
typer.echo(f"Expected congress-legislators/ directory: {legislators_dir}", err=True)
raise typer.Exit(code=1)
engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
congress_dirs = _resolve_congress_dirs(data_dir, congress)
if not congress_dirs:
typer.echo("No congress directories found.", err=True)
raise typer.Exit(code=1)
logger.info("Found %d congress directories to process", len(congress_dirs))
steps: dict[str, tuple] = {
"legislators": (ingest_legislators, (engine, legislators_dir)),
"legislators-social-media": (ingest_social_media, (engine, legislators_dir)),
"bills": (ingest_bills, (engine, congress_dirs)),
"votes": (ingest_votes, (engine, congress_dirs)),
"bill-text": (ingest_bill_text, (engine, congress_dirs)),
}
if only:
if only not in steps:
typer.echo(f"Unknown step: {only}. Choose from: {', '.join(steps)}", err=True)
raise typer.Exit(code=1)
steps = {only: steps[only]}
for step_name, (step_func, step_args) in steps.items():
logger.info("=== Starting step: %s ===", step_name)
step_func(*step_args)
logger.info("=== Finished step: %s ===", step_name)
logger.info("ingest-congress done")
def _resolve_congress_dirs(data_dir: Path, congress: int | None) -> list[Path]:
"""Find congress number directories under data_dir."""
if congress is not None:
target = data_dir / str(congress)
return [target] if target.is_dir() else []
return sorted(path for path in data_dir.iterdir() if path.is_dir() and path.name.isdigit())
def _flush_batch(session: Session, batch: list[object], label: str) -> int:
"""Add a batch of ORM objects to the session and commit. Returns count added."""
if not batch:
return 0
session.add_all(batch)
session.commit()
count = len(batch)
logger.info("Committed %d %s", count, label)
batch.clear()
return count
# ---------------------------------------------------------------------------
# Legislators — loaded from congress-legislators YAML files
# ---------------------------------------------------------------------------
def ingest_legislators(engine: Engine, legislators_dir: Path) -> None:
"""Load legislators from congress-legislators YAML files."""
legislators_data = _load_legislators_yaml(legislators_dir)
logger.info("Loaded %d legislators from YAML files", len(legislators_data))
with Session(engine) as session:
existing_legislators = {
legislator.bioguide_id: legislator for legislator in session.scalars(select(Legislator)).all()
}
logger.info("Found %d existing legislators in DB", len(existing_legislators))
total_inserted = 0
total_updated = 0
for entry in legislators_data:
bioguide_id = entry.get("id", {}).get("bioguide")
if not bioguide_id:
continue
fields = _parse_legislator(entry)
if existing := existing_legislators.get(bioguide_id):
changed = False
for field, value in fields.items():
if value is not None and getattr(existing, field) != value:
setattr(existing, field, value)
changed = True
if changed:
total_updated += 1
else:
session.add(Legislator(bioguide_id=bioguide_id, **fields))
total_inserted += 1
session.commit()
logger.info("Inserted %d new legislators, updated %d existing", total_inserted, total_updated)
def _load_legislators_yaml(legislators_dir: Path) -> list[dict]:
"""Load and combine legislators-current.yaml and legislators-historical.yaml."""
legislators: list[dict] = []
for filename in ("legislators-current.yaml", "legislators-historical.yaml"):
path = legislators_dir / filename
if not path.exists():
logger.warning("Legislators file not found: %s", path)
continue
with path.open() as file:
data = yaml.safe_load(file)
if isinstance(data, list):
legislators.extend(data)
return legislators
def _parse_legislator(entry: dict) -> dict:
"""Extract Legislator fields from a congress-legislators YAML entry."""
ids = entry.get("id", {})
name = entry.get("name", {})
bio = entry.get("bio", {})
terms = entry.get("terms", [])
latest_term = terms[-1] if terms else {}
fec_ids = ids.get("fec")
fec_ids_joined = ",".join(fec_ids) if isinstance(fec_ids, list) else fec_ids
chamber = latest_term.get("type")
chamber_normalized = {"rep": "House", "sen": "Senate"}.get(chamber, chamber)
return {
"thomas_id": ids.get("thomas"),
"lis_id": ids.get("lis"),
"govtrack_id": ids.get("govtrack"),
"opensecrets_id": ids.get("opensecrets"),
"fec_ids": fec_ids_joined,
"first_name": name.get("first"),
"last_name": name.get("last"),
"official_full_name": name.get("official_full"),
"nickname": name.get("nickname"),
"birthday": bio.get("birthday"),
"gender": bio.get("gender"),
"current_party": latest_term.get("party"),
"current_state": latest_term.get("state"),
"current_district": latest_term.get("district"),
"current_chamber": chamber_normalized,
}
# ---------------------------------------------------------------------------
# Social Media — loaded from legislators-social-media.yaml
# ---------------------------------------------------------------------------
SOCIAL_MEDIA_PLATFORMS = {
"twitter": "https://twitter.com/{account}",
"facebook": "https://facebook.com/{account}",
"youtube": "https://youtube.com/{account}",
"instagram": "https://instagram.com/{account}",
"mastodon": None,
}
def ingest_social_media(engine: Engine, legislators_dir: Path) -> None:
"""Load social media accounts from legislators-social-media.yaml."""
social_media_path = legislators_dir / "legislators-social-media.yaml"
if not social_media_path.exists():
logger.warning("Social media file not found: %s", social_media_path)
return
with social_media_path.open() as file:
social_media_data = yaml.safe_load(file)
if not isinstance(social_media_data, list):
logger.warning("Unexpected format in %s", social_media_path)
return
logger.info("Loaded %d entries from legislators-social-media.yaml", len(social_media_data))
with Session(engine) as session:
legislator_map = _build_legislator_map(session)
existing_accounts = {
(account.legislator_id, account.platform)
for account in session.scalars(select(LegislatorSocialMedia)).all()
}
logger.info("Found %d existing social media accounts in DB", len(existing_accounts))
total_inserted = 0
total_updated = 0
for entry in social_media_data:
bioguide_id = entry.get("id", {}).get("bioguide")
if not bioguide_id:
continue
legislator_id = legislator_map.get(bioguide_id)
if legislator_id is None:
continue
social = entry.get("social", {})
for platform, url_template in SOCIAL_MEDIA_PLATFORMS.items():
account_name = social.get(platform)
if not account_name:
continue
url = url_template.format(account=account_name) if url_template else None
if (legislator_id, platform) in existing_accounts:
total_updated += 1
else:
session.add(
LegislatorSocialMedia(
legislator_id=legislator_id,
platform=platform,
account_name=str(account_name),
url=url,
source="https://github.com/unitedstates/congress-legislators",
)
)
existing_accounts.add((legislator_id, platform))
total_inserted += 1
session.commit()
logger.info("Inserted %d new social media accounts, updated %d existing", total_inserted, total_updated)
def _iter_voters(position_group: object) -> Iterator[dict]:
"""Yield voter dicts from a vote position group (handles list, single dict, or string)."""
if isinstance(position_group, dict):
yield position_group
elif isinstance(position_group, list):
for voter in position_group:
if isinstance(voter, dict):
yield voter
# ---------------------------------------------------------------------------
# Bills
# ---------------------------------------------------------------------------
def ingest_bills(engine: Engine, congress_dirs: list[Path]) -> None:
"""Load bill data.json files."""
with Session(engine) as session:
existing_bills = {(bill.congress, bill.bill_type, bill.number) for bill in session.scalars(select(Bill)).all()}
logger.info("Found %d existing bills in DB", len(existing_bills))
total_inserted = 0
batch: list[Bill] = []
for congress_dir in congress_dirs:
bills_dir = congress_dir / "bills"
if not bills_dir.is_dir():
continue
logger.info("Scanning bills from %s", congress_dir.name)
for bill_file in bills_dir.rglob("data.json"):
data = _read_json(bill_file)
if data is None:
continue
bill = _parse_bill(data, existing_bills)
if bill is not None:
batch.append(bill)
if len(batch) >= BATCH_SIZE:
total_inserted += _flush_batch(session, batch, "bills")
total_inserted += _flush_batch(session, batch, "bills")
logger.info("Inserted %d new bills total", total_inserted)
def _parse_bill(data: dict, existing_bills: set[tuple[int, str, int]]) -> Bill | None:
"""Parse a bill data.json dict into a Bill ORM object, skipping existing."""
raw_congress = data.get("congress")
bill_type = data.get("bill_type")
raw_number = data.get("number")
if raw_congress is None or bill_type is None or raw_number is None:
return None
congress = int(raw_congress)
number = int(raw_number)
if (congress, bill_type, number) in existing_bills:
return None
sponsor_bioguide = None
sponsor = data.get("sponsor")
if sponsor:
sponsor_bioguide = sponsor.get("bioguide_id")
return Bill(
congress=congress,
bill_type=bill_type,
number=number,
title=data.get("short_title") or data.get("official_title"),
title_short=data.get("short_title"),
official_title=data.get("official_title"),
status=data.get("status"),
status_at=data.get("status_at"),
sponsor_bioguide_id=sponsor_bioguide,
subjects_top_term=data.get("subjects_top_term"),
)
# ---------------------------------------------------------------------------
# Votes (and vote records)
# ---------------------------------------------------------------------------
def ingest_votes(engine: Engine, congress_dirs: list[Path]) -> None:
"""Load vote data.json files with their vote records."""
with Session(engine) as session:
legislator_map = _build_legislator_map(session)
logger.info("Loaded %d legislators into lookup map", len(legislator_map))
bill_map = _build_bill_map(session)
logger.info("Loaded %d bills into lookup map", len(bill_map))
existing_votes = {
(vote.congress, vote.chamber, vote.session, vote.number) for vote in session.scalars(select(Vote)).all()
}
logger.info("Found %d existing votes in DB", len(existing_votes))
total_inserted = 0
batch: list[Vote] = []
for congress_dir in congress_dirs:
votes_dir = congress_dir / "votes"
if not votes_dir.is_dir():
continue
logger.info("Scanning votes from %s", congress_dir.name)
for vote_file in votes_dir.rglob("data.json"):
data = _read_json(vote_file)
if data is None:
continue
vote = _parse_vote(data, legislator_map, bill_map, existing_votes)
if vote is not None:
batch.append(vote)
if len(batch) >= BATCH_SIZE:
total_inserted += _flush_batch(session, batch, "votes")
total_inserted += _flush_batch(session, batch, "votes")
logger.info("Inserted %d new votes total", total_inserted)
def _build_legislator_map(session: Session) -> dict[str, int]:
"""Build a mapping of bioguide_id -> legislator.id."""
return {legislator.bioguide_id: legislator.id for legislator in session.scalars(select(Legislator)).all()}
def _build_bill_map(session: Session) -> dict[tuple[int, str, int], int]:
"""Build a mapping of (congress, bill_type, number) -> bill.id."""
return {(bill.congress, bill.bill_type, bill.number): bill.id for bill in session.scalars(select(Bill)).all()}
def _parse_vote(
data: dict,
legislator_map: dict[str, int],
bill_map: dict[tuple[int, str, int], int],
existing_votes: set[tuple[int, str, int, int]],
) -> Vote | None:
"""Parse a vote data.json dict into a Vote ORM object with records."""
raw_congress = data.get("congress")
chamber = data.get("chamber")
raw_number = data.get("number")
vote_date = data.get("date")
if raw_congress is None or chamber is None or raw_number is None or vote_date is None:
return None
raw_session = data.get("session")
if raw_session is None:
return None
congress = int(raw_congress)
number = int(raw_number)
session_number = int(raw_session)
# Normalize chamber from "h"/"s" to "House"/"Senate"
chamber_normalized = {"h": "House", "s": "Senate"}.get(chamber, chamber)
if (congress, chamber_normalized, session_number, number) in existing_votes:
return None
# Resolve linked bill
bill_id = None
bill_ref = data.get("bill")
if bill_ref:
bill_key = (
int(bill_ref.get("congress", congress)),
bill_ref.get("type"),
int(bill_ref.get("number", 0)),
)
bill_id = bill_map.get(bill_key)
raw_votes = data.get("votes", {})
vote_counts = _count_votes(raw_votes)
vote_records = _build_vote_records(raw_votes, legislator_map)
return Vote(
congress=congress,
chamber=chamber_normalized,
session=session_number,
number=number,
vote_type=data.get("type"),
question=data.get("question"),
result=data.get("result"),
result_text=data.get("result_text"),
vote_date=vote_date[:10] if isinstance(vote_date, str) else vote_date,
bill_id=bill_id,
vote_records=vote_records,
**vote_counts,
)
def _count_votes(raw_votes: dict) -> dict[str, int]:
"""Count voters per position category, correctly handling dict and list formats."""
yea_count = 0
nay_count = 0
not_voting_count = 0
present_count = 0
for position, position_group in raw_votes.items():
voter_count = sum(1 for _ in _iter_voters(position_group))
if position in ("Yea", "Aye"):
yea_count += voter_count
elif position in ("Nay", "No"):
nay_count += voter_count
elif position == "Not Voting":
not_voting_count += voter_count
elif position == "Present":
present_count += voter_count
return {
"yea_count": yea_count,
"nay_count": nay_count,
"not_voting_count": not_voting_count,
"present_count": present_count,
}
def _build_vote_records(raw_votes: dict, legislator_map: dict[str, int]) -> list[VoteRecord]:
"""Build VoteRecord objects from raw vote data."""
records: list[VoteRecord] = []
for position, position_group in raw_votes.items():
for voter in _iter_voters(position_group):
bioguide_id = voter.get("id")
if not bioguide_id:
continue
legislator_id = legislator_map.get(bioguide_id)
if legislator_id is None:
continue
records.append(
VoteRecord(
legislator_id=legislator_id,
position=position,
)
)
return records
# ---------------------------------------------------------------------------
# Bill Text
# ---------------------------------------------------------------------------
def ingest_bill_text(engine: Engine, congress_dirs: list[Path]) -> None:
"""Load bill text from text-versions directories."""
with Session(engine) as session:
bill_map = _build_bill_map(session)
logger.info("Loaded %d bills into lookup map", len(bill_map))
existing_bill_texts = {
(bill_text.bill_id, bill_text.version_code) for bill_text in session.scalars(select(BillText)).all()
}
logger.info("Found %d existing bill text versions in DB", len(existing_bill_texts))
total_inserted = 0
batch: list[BillText] = []
for congress_dir in congress_dirs:
logger.info("Scanning bill texts from %s", congress_dir.name)
for bill_text in _iter_bill_texts(congress_dir, bill_map, existing_bill_texts):
batch.append(bill_text)
if len(batch) >= BATCH_SIZE:
total_inserted += _flush_batch(session, batch, "bill texts")
total_inserted += _flush_batch(session, batch, "bill texts")
logger.info("Inserted %d new bill text versions total", total_inserted)
def _iter_bill_texts(
congress_dir: Path,
bill_map: dict[tuple[int, str, int], int],
existing_bill_texts: set[tuple[int, str]],
) -> Iterator[BillText]:
"""Yield BillText objects for a single congress directory, skipping existing."""
bills_dir = congress_dir / "bills"
if not bills_dir.is_dir():
return
for bill_dir in bills_dir.rglob("text-versions"):
if not bill_dir.is_dir():
continue
bill_key = _bill_key_from_dir(bill_dir.parent, congress_dir)
if bill_key is None:
continue
bill_id = bill_map.get(bill_key)
if bill_id is None:
continue
for version_dir in sorted(bill_dir.iterdir()):
if not version_dir.is_dir():
continue
if (bill_id, version_dir.name) in existing_bill_texts:
continue
text_content = _read_bill_text(version_dir)
version_data = _read_json(version_dir / "data.json")
yield BillText(
bill_id=bill_id,
version_code=version_dir.name,
version_name=version_data.get("version_name") if version_data else None,
date=version_data.get("issued_on") if version_data else None,
text_content=text_content,
)
def _bill_key_from_dir(bill_dir: Path, congress_dir: Path) -> tuple[int, str, int] | None:
"""Extract (congress, bill_type, number) from directory structure."""
congress = int(congress_dir.name)
bill_type = bill_dir.parent.name
name = bill_dir.name
# Directory name is like "hr3590" — strip the type prefix to get the number
number_str = name[len(bill_type) :]
if not number_str.isdigit():
return None
return (congress, bill_type, int(number_str))
def _read_bill_text(version_dir: Path) -> str | None:
"""Read bill text from a version directory, preferring .txt over .xml."""
for extension in ("txt", "htm", "html", "xml"):
candidates = list(version_dir.glob(f"document.{extension}"))
if not candidates:
candidates = list(version_dir.glob(f"*.{extension}"))
if candidates:
try:
return candidates[0].read_text(encoding="utf-8")
except Exception:
logger.exception("Failed to read %s", candidates[0])
return None
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _read_json(path: Path) -> dict | None:
"""Read and parse a JSON file, returning None on failure."""
try:
return orjson.loads(path.read_bytes())
except FileNotFoundError:
return None
except Exception:
logger.exception("Failed to parse %s", path)
return None
if __name__ == "__main__":
app()
+247
View File
@@ -0,0 +1,247 @@
"""Ingestion pipeline for loading JSONL post files into the weekly-partitioned posts table.
Usage:
ingest-posts /path/to/files/
ingest-posts /path/to/single_file.jsonl
ingest-posts /data/dir/ --workers 4 --batch-size 5000
"""
from __future__ import annotations
import logging
from datetime import UTC, datetime
from pathlib import Path # noqa: TC003 this is needed for typer
from typing import TYPE_CHECKING, Annotated
import orjson
import psycopg
import typer
from python.common import configure_logger
from python.orm.common import get_connection_info
from python.parallelize import parallelize_process
if TYPE_CHECKING:
from collections.abc import Iterator
logger = logging.getLogger(__name__)
app = typer.Typer(help="Ingest JSONL post files into the partitioned posts table.")
@app.command()
def main(
path: Annotated[Path, typer.Argument(help="Directory containing JSONL files, or a single JSONL file")],
batch_size: Annotated[int, typer.Option(help="Rows per INSERT batch")] = 10000,
workers: Annotated[int, typer.Option(help="Parallel workers for multi-file ingestion")] = 4,
pattern: Annotated[str, typer.Option(help="Glob pattern for JSONL files")] = "*.jsonl",
) -> None:
"""Ingest JSONL post files into the weekly-partitioned posts table."""
configure_logger(level="INFO")
logger.info("starting ingest-posts")
logger.info("path=%s batch_size=%d workers=%d pattern=%s", path, batch_size, workers, pattern)
if path.is_file():
ingest_file(path, batch_size=batch_size)
elif path.is_dir():
ingest_directory(path, batch_size=batch_size, max_workers=workers, pattern=pattern)
else:
typer.echo(f"Path does not exist: {path}", err=True)
raise typer.Exit(code=1)
logger.info("ingest-posts done")
def ingest_directory(
directory: Path,
*,
batch_size: int,
max_workers: int,
pattern: str = "*.jsonl",
) -> None:
"""Ingest all JSONL files in a directory using parallel workers."""
files = sorted(directory.glob(pattern))
if not files:
logger.warning("No JSONL files found in %s", directory)
return
logger.info("Found %d JSONL files to ingest", len(files))
kwargs_list = [{"path": fp, "batch_size": batch_size} for fp in files]
parallelize_process(ingest_file, kwargs_list, max_workers=max_workers)
SCHEMA = "main"
COLUMNS = (
"post_id",
"user_id",
"instance",
"date",
"text",
"langs",
"like_count",
"reply_count",
"repost_count",
"reply_to",
"replied_author",
"thread_root",
"thread_root_author",
"repost_from",
"reposted_author",
"quotes",
"quoted_author",
"labels",
"sent_label",
"sent_score",
)
INSERT_FROM_STAGING = f"""
INSERT INTO {SCHEMA}.posts ({", ".join(COLUMNS)})
SELECT {", ".join(COLUMNS)} FROM pg_temp.staging
ON CONFLICT (post_id, date) DO NOTHING
""" # noqa: S608
FAILED_INSERT = f"""
INSERT INTO {SCHEMA}.failed_ingestion (raw_line, error)
VALUES (%(raw_line)s, %(error)s)
""" # noqa: S608
def get_psycopg_connection() -> psycopg.Connection:
"""Create a raw psycopg3 connection from environment variables."""
database, host, port, username, password = get_connection_info("DATA_SCIENCE_DEV")
return psycopg.connect(
dbname=database,
host=host,
port=int(port),
user=username,
password=password,
autocommit=False,
)
def ingest_file(path: Path, *, batch_size: int) -> None:
"""Ingest a single JSONL file into the posts table."""
log_trigger = max(100_000 // batch_size, 1)
failed_lines: list[dict] = []
try:
with get_psycopg_connection() as connection:
for index, batch in enumerate(read_jsonl_batches(path, batch_size, failed_lines), 1):
ingest_batch(connection, batch)
if index % log_trigger == 0:
logger.info("Ingested %d batches (%d rows) from %s", index, index * batch_size, path)
if failed_lines:
logger.warning("Recording %d malformed lines from %s", len(failed_lines), path.name)
with connection.cursor() as cursor:
cursor.executemany(FAILED_INSERT, failed_lines)
connection.commit()
except Exception:
logger.exception("Failed to ingest file: %s", path)
raise
def ingest_batch(connection: psycopg.Connection, batch: list[dict]) -> None:
"""COPY batch into a temp staging table, then INSERT ... ON CONFLICT into posts."""
if not batch:
return
try:
with connection.cursor() as cursor:
cursor.execute(f"""
CREATE TEMP TABLE IF NOT EXISTS staging
(LIKE {SCHEMA}.posts INCLUDING DEFAULTS)
ON COMMIT DELETE ROWS
""")
cursor.execute("TRUNCATE pg_temp.staging")
with cursor.copy(f"COPY pg_temp.staging ({', '.join(COLUMNS)}) FROM STDIN") as copy:
for row in batch:
copy.write_row(tuple(row.get(column) for column in COLUMNS))
cursor.execute(INSERT_FROM_STAGING)
connection.commit()
except Exception as error:
connection.rollback()
if len(batch) == 1:
logger.exception("Skipping bad row post_id=%s", batch[0].get("post_id"))
with connection.cursor() as cursor:
cursor.execute(
FAILED_INSERT,
{
"raw_line": orjson.dumps(batch[0], default=str).decode(),
"error": str(error),
},
)
connection.commit()
return
midpoint = len(batch) // 2
ingest_batch(connection, batch[:midpoint])
ingest_batch(connection, batch[midpoint:])
def read_jsonl_batches(file_path: Path, batch_size: int, failed_lines: list[dict]) -> Iterator[list[dict]]:
"""Stream a JSONL file and yield batches of transformed rows."""
batch: list[dict] = []
with file_path.open("r", encoding="utf-8") as handle:
for raw_line in handle:
line = raw_line.strip()
if not line:
continue
batch.extend(parse_line(line, file_path, failed_lines))
if len(batch) >= batch_size:
yield batch
batch = []
if batch:
yield batch
def parse_line(line: str, file_path: Path, failed_lines: list[dict]) -> Iterator[dict]:
"""Parse a JSONL line, handling concatenated JSON objects."""
try:
yield transform_row(orjson.loads(line))
except orjson.JSONDecodeError:
if "}{" not in line:
logger.warning("Skipping malformed line in %s: %s", file_path.name, line[:120])
failed_lines.append({"raw_line": line, "error": "malformed JSON"})
return
fragments = line.replace("}{", "}\n{").split("\n")
for fragment in fragments:
try:
yield transform_row(orjson.loads(fragment))
except (orjson.JSONDecodeError, KeyError, ValueError) as error:
logger.warning("Skipping malformed fragment in %s: %s", file_path.name, fragment[:120])
failed_lines.append({"raw_line": fragment, "error": str(error)})
except Exception as error:
logger.exception("Skipping bad row in %s: %s", file_path.name, line[:120])
failed_lines.append({"raw_line": line, "error": str(error)})
def transform_row(raw: dict) -> dict:
"""Transform a raw JSONL row into a dict matching the Posts table columns."""
raw["date"] = parse_date(raw["date"])
if raw.get("langs") is not None:
raw["langs"] = orjson.dumps(raw["langs"])
if raw.get("text") is not None:
raw["text"] = raw["text"].replace("\x00", "")
return raw
def parse_date(raw_date: int) -> datetime:
"""Parse compact YYYYMMDDHHmm integer into a naive datetime (input is UTC by spec)."""
return datetime(
raw_date // 100000000,
(raw_date // 1000000) % 100,
(raw_date // 10000) % 100,
(raw_date // 100) % 100,
raw_date % 100,
tzinfo=UTC,
)
if __name__ == "__main__":
app()
+14
View File
@@ -83,6 +83,20 @@ DATABASES: dict[str, DatabaseConfig] = {
base_class_name="VanInventoryBase",
models_module="python.orm.van_inventory.models",
),
"signal_bot": DatabaseConfig(
env_prefix="SIGNALBOT",
version_location="python/alembic/signal_bot/versions",
base_module="python.orm.signal_bot.base",
base_class_name="SignalBotBase",
models_module="python.orm.signal_bot.models",
),
"data_science_dev": DatabaseConfig(
env_prefix="DATA_SCIENCE_DEV",
version_location="python/alembic/data_science_dev/versions",
base_module="python.orm.data_science_dev.base",
base_class_name="DataScienceDevBase",
models_module="python.orm.data_science_dev.models",
),
}
+1 -3
View File
@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING
from sqlalchemy.orm import Session
from python.ebook_search.bm25_corpus import load_bm25_corpus, refresh_bm25_corpus
from python.ebook_search.bm25_corpus import refresh_bm25_corpus
if TYPE_CHECKING:
from fastapi import FastAPI
@@ -56,5 +56,3 @@ def refresh_bm25_for_engine(engine: Engine, config: EbookSearchConfig) -> None:
"""Refresh the BM25 corpus using a SQLAlchemy engine."""
with Session(engine) as session:
refresh_bm25_corpus(session, config)
load_bm25_corpus.cache_clear()
logger.info("ebook_bm25_corpus_cache_cleared_after_refresh")
+5 -9
View File
@@ -14,11 +14,10 @@ from sqlalchemy.orm import Session
from python.common import configure_logger
from python.ebook_search.api.bm25_tasks import cancel_bm25_refresh
from python.ebook_search.api.routes import admin_router, page_router, search_router
from python.ebook_search.api.routes import register_admin_routes, register_page_routes, register_search_routes
from python.ebook_search.api.web import STATIC_DIR
from python.ebook_search.bm25_corpus import ensure_bm25_corpus
from python.ebook_search.config import load_config
from python.fastapi_tools import ZstdMiddleware
from python.orm.common import get_postgres_engine
if TYPE_CHECKING:
@@ -32,7 +31,7 @@ logger = logging.getLogger(__name__)
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
"""Manage application startup and shutdown resources."""
logger.info("ebook_search_startup")
app.state.engine = get_postgres_engine(name="RICHIE", vector_engine=True)
app.state.engine = get_postgres_engine(name="RICHIE")
with Session(app.state.engine) as session:
ensure_bm25_corpus(session, app.state.config)
try:
@@ -46,7 +45,6 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]:
def create_app() -> FastAPI:
"""Create the EPUB search web app."""
app = FastAPI(title="EPUB Search", lifespan=lifespan)
app.add_middleware(ZstdMiddleware)
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
app.state.config = load_config()
logger.info(
@@ -57,11 +55,9 @@ def create_app() -> FastAPI:
app.state.config.answer_enabled,
len(app.state.config.library_paths),
)
app.include_router(admin_router)
app.include_router(page_router)
app.include_router(search_router)
register_page_routes(app)
register_search_routes(app)
register_admin_routes(app)
return app
+11 -6
View File
@@ -1,11 +1,16 @@
"""EPUB search web route modules."""
from python.ebook_search.api.routes.admin import router as admin_router
from python.ebook_search.api.routes.page import router as page_router
from python.ebook_search.api.routes.search import router as search_router
from python.ebook_search.api.routes import admin, page, search
register_admin_routes = admin.register_admin_routes
register_page_routes = page.register_page_routes
register_search_routes = search.register_search_routes
__all__ = [
"admin_router",
"page_router",
"search_router",
"admin",
"page",
"register_admin_routes",
"register_page_routes",
"register_search_routes",
"search",
]
+9
View File
@@ -4,6 +4,7 @@ from __future__ import annotations
import logging
from dataclasses import replace
from typing import TYPE_CHECKING
from fastapi import APIRouter, Request
from fastapi.responses import HTMLResponse
@@ -14,12 +15,20 @@ from python.ebook_search.api.web import templates
from python.ebook_search.embeddings import embed_missing_chunks, embedding_model_stats
from python.ebook_search.ingest import ingest_configured_paths
if TYPE_CHECKING:
from fastapi import FastAPI
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/admin")
EMBED_ALL_BATCH_SIZE = 32
def register_admin_routes(app: FastAPI) -> None:
"""Register admin routes on the app."""
app.include_router(router)
@router.get("", response_class=HTMLResponse)
def admin(request: Request) -> HTMLResponse:
"""Render the admin page."""
+9
View File
@@ -3,6 +3,7 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
from fastapi import APIRouter, Request
from fastapi.responses import HTMLResponse
@@ -12,11 +13,19 @@ from sqlalchemy.orm import Session
from python.ebook_search.api.web import templates
from python.orm.richie import EbookSource
if TYPE_CHECKING:
from fastapi import FastAPI
logger = logging.getLogger(__name__)
router = APIRouter()
def register_page_routes(app: FastAPI) -> None:
"""Register page routes on the app."""
app.include_router(router)
@router.get("/", response_class=HTMLResponse)
def index(request: Request) -> HTMLResponse:
"""Render the search page."""
+9 -1
View File
@@ -5,7 +5,7 @@ from __future__ import annotations
import logging
from dataclasses import replace
from time import perf_counter
from typing import Annotated
from typing import TYPE_CHECKING, Annotated
from fastapi import APIRouter, Form, Request
from fastapi.responses import HTMLResponse
@@ -15,11 +15,19 @@ from python.ebook_search.api.web import templates
from python.ebook_search.search import search_ebooks
from python.ebook_search.timing import runtime_step_from_start
if TYPE_CHECKING:
from fastapi import FastAPI
logger = logging.getLogger(__name__)
router = APIRouter()
def register_search_routes(app: FastAPI) -> None:
"""Register search routes on the app."""
app.include_router(router)
@router.post("/search", response_class=HTMLResponse)
def search(
request: Request,
+38 -70
View File
@@ -5,6 +5,7 @@ from __future__ import annotations
import json
import logging
import shutil
import tempfile
from dataclasses import dataclass
from datetime import UTC, datetime
from functools import cache
@@ -58,21 +59,13 @@ class BM25CorpusUnavailableError(RuntimeError):
def bm25_index_path(config: EbookSearchConfig) -> Path:
"""Return the configured BM25 index root path relative to the current working directory."""
"""Return the configured BM25 index path relative to the current working directory."""
path = Path(config.bm25_index_dir).expanduser()
if path.is_absolute():
return path
return Path.cwd() / path
def get_current_bm25_index(index_path: Path) -> Path:
"""Return the live BM25 index directory."""
current_path = index_path / "current"
if current_path.exists() or current_path.is_symlink():
return current_path
return index_path
def ensure_bm25_corpus(session: Session, config: EbookSearchConfig) -> None:
"""Create or refresh the persisted BM25 corpus when it is missing or stale."""
index_path = bm25_index_path(config)
@@ -107,18 +100,19 @@ def refresh_bm25_corpus(
) -> BM25Manifest:
"""Rebuild and persist the BM25 corpus from the current database chunks."""
index_path = bm25_index_path(config)
records, texts = fetch_bm25_corpus_records(session)
records = fetch_bm25_corpus_records(session)
manifest = BM25Manifest(
created_at=datetime.now(tz=UTC),
db_updated_at=db_updated_at if db_updated_at is not None else corpus_last_updated_at(session),
chunk_count=len(records),
)
write_bm25_corpus(index_path, records, texts, manifest)
write_bm25_corpus(index_path, records, manifest)
logger.info(
"ebook_bm25_index_refreshed path=%s chunks=%s created_at=%s",
"ebook_bm25_index_refreshed path=%s chunks=%s created_at=%s note=%s",
index_path,
manifest.chunk_count,
manifest.created_at.isoformat(),
"restart_service_to_use_refreshed_bm25_cache",
)
return manifest
@@ -127,11 +121,15 @@ def refresh_bm25_corpus(
def load_bm25_corpus(config: EbookSearchConfig) -> BM25Corpus:
"""Load the BM25 corpus into memory once per process.
Background refresh tasks clear this cache after rebuilding the on-disk corpus.
This cache intentionally does not notice later on-disk corpus refreshes. Restart the service after rebuilding the
BM25 corpus for searches to use the new index.
"""
index_path = bm25_index_path(config)
active_index_path = get_current_bm25_index(index_path)
logger.info("ebook_bm25_corpus_cache_load path=%s active_path=%s", index_path, active_index_path)
logger.info(
"ebook_bm25_corpus_cache_load path=%s note=%s",
index_path,
"restart_service_after_bm25_refresh",
)
manifest = read_bm25_manifest(index_path)
if manifest is None or not bm25_index_exists(index_path, manifest):
msg = f"BM25 corpus is not available: {index_path}"
@@ -139,7 +137,7 @@ def load_bm25_corpus(config: EbookSearchConfig) -> BM25Corpus:
if manifest.chunk_count == 0:
return BM25Corpus(retriever=None, records=(), manifest=manifest)
retriever = bm25s.BM25.load(active_index_path, load_corpus=True, mmap=True)
retriever = bm25s.BM25.load(index_path, load_corpus=True, mmap=True)
records = tuple(dict(record) for record in retriever.corpus)
return BM25Corpus(retriever=retriever, records=records, manifest=manifest)
@@ -164,12 +162,8 @@ def score_bm25_corpus(query: str, corpus: BM25Corpus, *, limit: int) -> list[tup
return results
def fetch_bm25_corpus_records(session: Session) -> tuple[list[dict[str, object]], list[str]]:
"""Fetch persistable BM25 corpus records and their matching index texts from the database.
search_text is only needed to build the index, so it is returned separately instead of
being persisted into the corpus records, which would double the corpus size.
"""
def fetch_bm25_corpus_records(session: Session) -> list[dict[str, object]]:
"""Fetch BM25 corpus records from the database."""
statement = (
select(
EbookChunk.id.label("chunk_id"),
@@ -178,20 +172,20 @@ def fetch_bm25_corpus_records(session: Session) -> tuple[list[dict[str, object]]
EbookSource.author.label("source_author"),
EbookChapter.title.label("chapter_title"),
EbookChunk.page_label.label("page_label"),
EbookChunk.search_text.label("bm25_text"),
func.concat_ws(
" ",
EbookSource.title,
EbookSource.author,
EbookChapter.title,
EbookChunk.search_text,
).label("bm25_text"),
)
.select_from(EbookChunk)
.join(EbookSource, EbookSource.id == EbookChunk.source_id)
.outerjoin(EbookChapter, EbookChapter.id == EbookChunk.chapter_id)
.order_by(EbookChunk.id)
)
records: list[dict[str, object]] = []
texts: list[str] = []
for row in session.execute(statement).mappings():
record = dict(row)
texts.append(str(record.pop("bm25_text")))
records.append(record)
return records, texts
return [dict(row) for row in session.execute(statement).mappings()]
def corpus_last_updated_at(session: Session) -> datetime | None:
@@ -204,42 +198,28 @@ def corpus_last_updated_at(session: Session) -> datetime | None:
return session.scalar(select(func.max(update_times.c.updated)))
def write_bm25_corpus(
index_path: Path,
records: list[dict[str, object]],
texts: list[str],
manifest: BM25Manifest,
) -> None:
"""Write a BM25 corpus generation and publish it through the current symlink."""
index_path.mkdir(parents=True, exist_ok=True)
generations_path = index_path / "generations"
generations_path.mkdir(exist_ok=True)
generation_path = next_bm25_generation_path(generations_path, manifest.created_at)
current_path = index_path / "current"
next_current_path = index_path / f".current.{generation_path.name}.tmp"
def write_bm25_corpus(index_path: Path, records: list[dict[str, object]], manifest: BM25Manifest) -> None:
"""Write a BM25 corpus and manifest atomically."""
index_path.parent.mkdir(parents=True, exist_ok=True)
temp_path = Path(tempfile.mkdtemp(prefix=f"{index_path.name}.", dir=index_path.parent))
try:
generation_path.mkdir()
# Empty corpora publish a manifest-only generation so startup succeeds before any chunks exist.
if records:
retriever = bm25s.BM25()
texts = [str(record["bm25_text"]) for record in records]
retriever.index(bm25s.tokenize(texts, show_progress=False), show_progress=False)
retriever.save(generation_path, corpus=records, show_progress=False)
write_bm25_manifest(generation_path, manifest)
next_current_path.unlink(missing_ok=True)
next_current_path.symlink_to(generation_path, target_is_directory=True)
next_current_path.replace(current_path)
retriever.save(temp_path, corpus=records, show_progress=False)
write_bm25_manifest(temp_path, manifest)
if index_path.exists():
shutil.rmtree(index_path)
temp_path.rename(index_path)
except Exception:
next_current_path.unlink(missing_ok=True)
shutil.rmtree(generation_path, ignore_errors=True)
shutil.rmtree(temp_path, ignore_errors=True)
raise
def read_bm25_manifest(index_path: Path) -> BM25Manifest | None:
"""Read the BM25 manifest if it exists and is valid."""
manifest_path = get_current_bm25_index(index_path) / MANIFEST_NAME
manifest_path = index_path / MANIFEST_NAME
if not manifest_path.exists():
return None
body = json.loads(manifest_path.read_text(encoding="utf-8"))
@@ -262,20 +242,8 @@ def write_bm25_manifest(index_path: Path, manifest: BM25Manifest) -> None:
def bm25_index_exists(index_path: Path, manifest: BM25Manifest | None) -> bool:
"""Return whether a usable persisted BM25 index exists."""
active_index_path = get_current_bm25_index(index_path)
if manifest is None or not active_index_path.is_dir():
if manifest is None or not index_path.is_dir():
return False
if manifest.chunk_count == 0:
return True
return all((active_index_path / file_name).exists() for file_name in REQUIRED_INDEX_FILES)
def next_bm25_generation_path(generations_path: Path, created_at: datetime) -> Path:
"""Return an unused dated BM25 generation path."""
base_name = created_at.astimezone(UTC).strftime("%Y%m%dT%H%M%S.%fZ")
generation_path = generations_path / base_name
suffix = 1
while generation_path.exists():
generation_path = generations_path / f"{base_name}.{suffix}"
suffix += 1
return generation_path
return all((index_path / file_name).exists() for file_name in REQUIRED_INDEX_FILES)
+1 -3
View File
@@ -13,8 +13,6 @@ if TYPE_CHECKING:
from python.ebook_search.search import SearchResult
logger = logging.getLogger(__name__)
RERANK_SCORE_WEIGHT = 0.7
HYBRID_SCORE_WEIGHT = 0.3
@dataclass(frozen=True)
@@ -112,7 +110,7 @@ def clamp_score(score: float) -> float:
def final_rerank_score(result: SearchResult, rerank_score: float, candidates: list[SearchResult]) -> float:
"""Combine rerank relevance with normalized hybrid retrieval evidence."""
return (RERANK_SCORE_WEIGHT * rerank_score) + (HYBRID_SCORE_WEIGHT * normalized_hybrid_score(result, candidates))
return rerank_score * normalized_hybrid_score(result, candidates)
def normalized_hybrid_score(result: SearchResult, candidates: list[SearchResult]) -> float:
+7 -19
View File
@@ -13,7 +13,6 @@ from sqlalchemy import literal, select
from sqlalchemy.orm import Session
from python.ebook_search.bm25_corpus import (
BM25CorpusUnavailableError,
load_bm25_corpus,
score_bm25_corpus,
)
@@ -94,14 +93,13 @@ def search_ebooks(
logger.info("ebook_search_start query_length=%s rerank=%s", len(query), rerank)
timings: list[RuntimeStep] = []
bm25_query, timing = timed_result("BM25 query preparation", retrieval_query_from_text, query)
retrieval_query, timing = timed_result("Query preparation", retrieval_query_from_text, query)
timings.append(timing)
retrieval, timing = timed_result(
"Hybrid retrieval",
parallel_retrieval,
engine,
query,
bm25_query,
retrieval_query,
config,
)
timings.extend(retrieval.timings)
@@ -132,12 +130,7 @@ def search_ebooks(
return response
def parallel_retrieval(
engine: Engine,
vector_query: str,
bm25_query: str,
config: EbookSearchConfig,
) -> RetrievalResponse:
def parallel_retrieval(engine: Engine, query: str, config: EbookSearchConfig) -> RetrievalResponse:
"""Run vector and BM25 candidate retrieval concurrently with separate database sessions."""
with ThreadPoolExecutor(max_workers=2, thread_name_prefix="ebook-search") as executor:
vector_future = executor.submit(
@@ -145,14 +138,14 @@ def parallel_retrieval(
"Embedding + vector search",
vector_candidates,
engine,
vector_query,
query,
config,
)
bm25_future = executor.submit(
timed_result,
"BM25 search",
bm25_candidates,
bm25_query,
query,
config,
)
vector_results, vector_timing = vector_future.result()
@@ -203,7 +196,7 @@ def apply_rerank(
def vector_candidates(engine: Engine, query: str, config: EbookSearchConfig) -> list[SearchResult]:
"""Return pgvector cosine candidates for a natural-language query."""
"""Return pgvector cosine candidates for a normalized query."""
with Session(engine) as session:
model = session.scalar(select(EbookEmbeddingModel).where(EbookEmbeddingModel.name == config.embedding_model))
if model is None:
@@ -253,12 +246,7 @@ def vector_candidates(engine: Engine, query: str, config: EbookSearchConfig) ->
def bm25_candidates(query: str, config: EbookSearchConfig) -> list[SearchResult]:
"""Return BM25-ranked lexical candidates using the persisted corpus."""
try:
corpus = load_bm25_corpus(config)
except BM25CorpusUnavailableError as error:
logger.warning("ebook_bm25_index_unavailable_skipping error=%s", error)
return []
corpus = load_bm25_corpus(config)
if not corpus.records:
logger.info("ebook_bm25_search_complete corpus=0 candidates=0")
return []
-6
View File
@@ -1,6 +0,0 @@
"""Reusable FastAPI tools."""
from python.fastapi_tools.db import DbSession, get_db
from python.fastapi_tools.zstd_middleware import ZstdMiddleware
__all__ = ["DbSession", "ZstdMiddleware", "get_db"]
+4
View File
@@ -1,9 +1,13 @@
"""ORM package exports."""
from python.orm.data_science_dev.base import DataScienceDevBase
from python.orm.richie.base import RichieBase
from python.orm.signal_bot.base import SignalBotBase
from python.orm.van_inventory.base import VanInventoryBase
__all__ = [
"DataScienceDevBase",
"RichieBase",
"SignalBotBase",
"VanInventoryBase",
]
+2 -24
View File
@@ -31,24 +31,8 @@ def get_connection_info(name: str) -> tuple[str, str, str, str, str | None]:
return cast("tuple[str, str, str, str, str | None]", (database, host, port, username, password))
def get_postgres_engine(
*,
name: str = "POSTGRES",
pool_pre_ping: bool = True,
vector_engine: bool = False,
) -> Engine:
"""Create a SQLAlchemy engine from environment variables.
Args:
name (str, optional): The name of the environment variable prefix. Defaults to "POSTGRES".
pool_pre_ping (bool, optional): Whether to ping the database before each connection. Defaults to True.
This fixes the issue of trying to use a conection that has timed out on the database side.
vector_engine (bool, optional): Whether to use the vector search schema. Defaults to False.
This updates the search path the incldued the vecore types and operators.
Returns:
Engine: The SQLAlchemy engine.
"""
def get_postgres_engine(*, name: str = "POSTGRES", pool_pre_ping: bool = True) -> Engine:
"""Create a SQLAlchemy engine from environment variables."""
database, host, port, username, password = get_connection_info(name)
url = URL.create(
@@ -60,14 +44,8 @@ def get_postgres_engine(
database=database,
)
connect_args = {}
# There more better way to do this is with separate PG account and a dedicated vector schema for the vector types
if vector_engine:
connect_args["options"] = "-csearch_path=main,public"
return create_engine(
url=url,
pool_pre_ping=pool_pre_ping,
pool_recycle=1800,
connect_args=connect_args,
)
+11
View File
@@ -0,0 +1,11 @@
"""Data science dev database ORM exports."""
from __future__ import annotations
from python.orm.data_science_dev.base import DataScienceDevBase, DataScienceDevTableBase, DataScienceDevTableBaseBig
__all__ = [
"DataScienceDevBase",
"DataScienceDevTableBase",
"DataScienceDevTableBaseBig",
]
+52
View File
@@ -0,0 +1,52 @@
"""Data science dev database ORM base."""
from __future__ import annotations
from datetime import datetime
from sqlalchemy import BigInteger, DateTime, MetaData, func
from sqlalchemy.ext.declarative import AbstractConcreteBase
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from python.orm.common import NAMING_CONVENTION
class DataScienceDevBase(DeclarativeBase):
"""Base class for data_science_dev database ORM models."""
schema_name = "main"
metadata = MetaData(
schema=schema_name,
naming_convention=NAMING_CONVENTION,
)
class _TableMixin:
"""Shared timestamp columns for all table bases."""
created: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
)
updated: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
)
class DataScienceDevTableBase(_TableMixin, AbstractConcreteBase, DataScienceDevBase):
"""Table with Integer primary key."""
__abstract__ = True
id: Mapped[int] = mapped_column(primary_key=True)
class DataScienceDevTableBaseBig(_TableMixin, AbstractConcreteBase, DataScienceDevBase):
"""Table with BigInteger primary key."""
__abstract__ = True
id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
@@ -0,0 +1,14 @@
"""init."""
from python.orm.data_science_dev.congress.bill import Bill, BillText
from python.orm.data_science_dev.congress.legislator import Legislator, LegislatorSocialMedia
from python.orm.data_science_dev.congress.vote import Vote, VoteRecord
__all__ = [
"Bill",
"BillText",
"Legislator",
"LegislatorSocialMedia",
"Vote",
"VoteRecord",
]
@@ -0,0 +1,66 @@
"""Bill model - legislation introduced in Congress."""
from __future__ import annotations
from datetime import date
from typing import TYPE_CHECKING
from sqlalchemy import ForeignKey, Index, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from python.orm.data_science_dev.base import DataScienceDevTableBase
if TYPE_CHECKING:
from python.orm.data_science_dev.congress.vote import Vote
class Bill(DataScienceDevTableBase):
"""Legislation with congress number, type, titles, status, and sponsor."""
__tablename__ = "bill"
congress: Mapped[int]
bill_type: Mapped[str]
number: Mapped[int]
title: Mapped[str | None]
title_short: Mapped[str | None]
official_title: Mapped[str | None]
status: Mapped[str | None]
status_at: Mapped[date | None]
sponsor_bioguide_id: Mapped[str | None]
subjects_top_term: Mapped[str | None]
votes: Mapped[list[Vote]] = relationship(
"Vote",
back_populates="bill",
)
bill_texts: Mapped[list[BillText]] = relationship(
"BillText",
back_populates="bill",
cascade="all, delete-orphan",
)
__table_args__ = (
UniqueConstraint("congress", "bill_type", "number", name="uq_bill_congress_type_number"),
Index("ix_bill_congress", "congress"),
)
class BillText(DataScienceDevTableBase):
"""Stores different text versions of a bill (introduced, enrolled, etc.)."""
__tablename__ = "bill_text"
bill_id: Mapped[int] = mapped_column(ForeignKey("main.bill.id", ondelete="CASCADE"))
version_code: Mapped[str]
version_name: Mapped[str | None]
text_content: Mapped[str | None]
date: Mapped[date | None]
bill: Mapped[Bill] = relationship("Bill", back_populates="bill_texts")
__table_args__ = (UniqueConstraint("bill_id", "version_code", name="uq_bill_text_bill_id_version_code"),)
@@ -0,0 +1,66 @@
"""Legislator model - members of Congress."""
from __future__ import annotations
from datetime import date
from typing import TYPE_CHECKING
from sqlalchemy import ForeignKey, Text
from sqlalchemy.orm import Mapped, mapped_column, relationship
from python.orm.data_science_dev.base import DataScienceDevTableBase
if TYPE_CHECKING:
from python.orm.data_science_dev.congress.vote import VoteRecord
class Legislator(DataScienceDevTableBase):
"""Members of Congress with identification and current term info."""
__tablename__ = "legislator"
bioguide_id: Mapped[str] = mapped_column(Text, unique=True, index=True)
thomas_id: Mapped[str | None]
lis_id: Mapped[str | None]
govtrack_id: Mapped[int | None]
opensecrets_id: Mapped[str | None]
fec_ids: Mapped[str | None]
first_name: Mapped[str]
last_name: Mapped[str]
official_full_name: Mapped[str | None]
nickname: Mapped[str | None]
birthday: Mapped[date | None]
gender: Mapped[str | None]
current_party: Mapped[str | None]
current_state: Mapped[str | None]
current_district: Mapped[int | None]
current_chamber: Mapped[str | None]
social_media_accounts: Mapped[list[LegislatorSocialMedia]] = relationship(
"LegislatorSocialMedia",
back_populates="legislator",
cascade="all, delete-orphan",
)
vote_records: Mapped[list[VoteRecord]] = relationship(
"VoteRecord",
back_populates="legislator",
cascade="all, delete-orphan",
)
class LegislatorSocialMedia(DataScienceDevTableBase):
"""Social media account linked to a legislator."""
__tablename__ = "legislator_social_media"
legislator_id: Mapped[int] = mapped_column(ForeignKey("main.legislator.id"))
platform: Mapped[str]
account_name: Mapped[str]
url: Mapped[str | None]
source: Mapped[str]
legislator: Mapped[Legislator] = relationship(back_populates="social_media_accounts")
@@ -0,0 +1,79 @@
"""Vote model - roll call votes in Congress."""
from __future__ import annotations
from datetime import date
from typing import TYPE_CHECKING
from sqlalchemy import ForeignKey, Index, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from python.orm.data_science_dev.base import DataScienceDevBase, DataScienceDevTableBase
if TYPE_CHECKING:
from python.orm.data_science_dev.congress.bill import Bill
from python.orm.data_science_dev.congress.legislator import Legislator
from python.orm.data_science_dev.congress.vote import Vote
class VoteRecord(DataScienceDevBase):
"""Links a vote to a legislator with their position (Yea, Nay, etc.)."""
__tablename__ = "vote_record"
vote_id: Mapped[int] = mapped_column(
ForeignKey("main.vote.id", ondelete="CASCADE"),
primary_key=True,
)
legislator_id: Mapped[int] = mapped_column(
ForeignKey("main.legislator.id", ondelete="CASCADE"),
primary_key=True,
)
position: Mapped[str]
vote: Mapped[Vote] = relationship("Vote", back_populates="vote_records")
legislator: Mapped[Legislator] = relationship("Legislator", back_populates="vote_records")
class Vote(DataScienceDevTableBase):
"""Roll call votes with counts and optional bill linkage."""
__tablename__ = "vote"
congress: Mapped[int]
chamber: Mapped[str]
session: Mapped[int]
number: Mapped[int]
vote_type: Mapped[str | None]
question: Mapped[str | None]
result: Mapped[str | None]
result_text: Mapped[str | None]
vote_date: Mapped[date]
yea_count: Mapped[int | None]
nay_count: Mapped[int | None]
not_voting_count: Mapped[int | None]
present_count: Mapped[int | None]
bill_id: Mapped[int | None] = mapped_column(ForeignKey("main.bill.id"))
bill: Mapped[Bill | None] = relationship("Bill", back_populates="votes")
vote_records: Mapped[list[VoteRecord]] = relationship(
"VoteRecord",
back_populates="vote",
cascade="all, delete-orphan",
)
__table_args__ = (
UniqueConstraint(
"congress",
"chamber",
"session",
"number",
name="uq_vote_congress_chamber_session_number",
),
Index("ix_vote_date", "vote_date"),
Index("ix_vote_congress_chamber", "congress", "chamber"),
)
+16
View File
@@ -0,0 +1,16 @@
"""Data science dev database ORM models."""
from __future__ import annotations
from python.orm.data_science_dev.congress import Bill, BillText, Legislator, Vote, VoteRecord
from python.orm.data_science_dev.posts import partitions # noqa: F401 — registers partition classes in metadata
from python.orm.data_science_dev.posts.tables import Posts
__all__ = [
"Bill",
"BillText",
"Legislator",
"Posts",
"Vote",
"VoteRecord",
]
@@ -0,0 +1,11 @@
"""Posts module — weekly-partitioned posts table and partition ORM models."""
from __future__ import annotations
from python.orm.data_science_dev.posts.failed_ingestion import FailedIngestion
from python.orm.data_science_dev.posts.tables import Posts
__all__ = [
"FailedIngestion",
"Posts",
]
@@ -0,0 +1,33 @@
"""Shared column definitions for the posts partitioned table family."""
from __future__ import annotations
from datetime import datetime
from sqlalchemy import BigInteger, SmallInteger, Text
from sqlalchemy.orm import Mapped, mapped_column
class PostsColumns:
"""Mixin providing all posts columns. Used by both the parent table and partitions."""
post_id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
user_id: Mapped[int] = mapped_column(BigInteger)
instance: Mapped[str]
date: Mapped[datetime] = mapped_column(primary_key=True)
text: Mapped[str] = mapped_column(Text)
langs: Mapped[str | None]
like_count: Mapped[int]
reply_count: Mapped[int]
repost_count: Mapped[int]
reply_to: Mapped[int | None] = mapped_column(BigInteger)
replied_author: Mapped[int | None] = mapped_column(BigInteger)
thread_root: Mapped[int | None] = mapped_column(BigInteger)
thread_root_author: Mapped[int | None] = mapped_column(BigInteger)
repost_from: Mapped[int | None] = mapped_column(BigInteger)
reposted_author: Mapped[int | None] = mapped_column(BigInteger)
quotes: Mapped[int | None] = mapped_column(BigInteger)
quoted_author: Mapped[int | None] = mapped_column(BigInteger)
labels: Mapped[str | None]
sent_label: Mapped[int | None] = mapped_column(SmallInteger)
sent_score: Mapped[float | None]
@@ -0,0 +1,17 @@
"""Table for storing JSONL lines that failed during post ingestion."""
from __future__ import annotations
from sqlalchemy import Text
from sqlalchemy.orm import Mapped, mapped_column
from python.orm.data_science_dev.base import DataScienceDevTableBase
class FailedIngestion(DataScienceDevTableBase):
"""Stores raw JSONL lines and their error messages when ingestion fails."""
__tablename__ = "failed_ingestion"
raw_line: Mapped[str] = mapped_column(Text)
error: Mapped[str] = mapped_column(Text)
@@ -0,0 +1,71 @@
"""Dynamically generated ORM classes for each weekly partition of the posts table.
Each class maps to a PostgreSQL partition table (e.g. posts_2024_01).
These are real ORM models tracked by Alembic autogenerate.
Uses ISO week numbering (datetime.isocalendar().week). ISO years can have
52 or 53 weeks, and week boundaries are always Monday to Monday.
"""
from __future__ import annotations
import sys
from datetime import UTC, datetime
from python.orm.data_science_dev.base import DataScienceDevBase
from python.orm.data_science_dev.posts.columns import PostsColumns
PARTITION_START_YEAR = 2023
PARTITION_END_YEAR = 2026
_current_module = sys.modules[__name__]
def iso_weeks_in_year(year: int) -> int:
"""Return the number of ISO weeks in a given year (52 or 53)."""
dec_28 = datetime(year, 12, 28, tzinfo=UTC)
return dec_28.isocalendar().week
def week_bounds(year: int, week: int) -> tuple[datetime, datetime]:
"""Return (start, end) datetimes for an ISO week.
Start = Monday 00:00:00 UTC of the given ISO week.
End = Monday 00:00:00 UTC of the following ISO week.
"""
start = datetime.fromisocalendar(year, week, 1).replace(tzinfo=UTC)
if week < iso_weeks_in_year(year):
end = datetime.fromisocalendar(year, week + 1, 1).replace(tzinfo=UTC)
else:
end = datetime.fromisocalendar(year + 1, 1, 1).replace(tzinfo=UTC)
return start, end
def _build_partition_classes() -> dict[str, type]:
"""Generate one ORM class per ISO week partition."""
classes: dict[str, type] = {}
for year in range(PARTITION_START_YEAR, PARTITION_END_YEAR + 1):
for week in range(1, iso_weeks_in_year(year) + 1):
class_name = f"PostsWeek{year}W{week:02d}"
table_name = f"posts_{year}_{week:02d}"
partition_class = type(
class_name,
(PostsColumns, DataScienceDevBase),
{
"__tablename__": table_name,
"__table_args__": ({"implicit_returning": False},),
},
)
classes[class_name] = partition_class
return classes
# Generate all partition classes and register them on this module
_partition_classes = _build_partition_classes()
for _name, _cls in _partition_classes.items():
setattr(_current_module, _name, _cls)
__all__ = list(_partition_classes.keys())
@@ -0,0 +1,13 @@
"""Posts parent table with PostgreSQL weekly range partitioning on date column."""
from __future__ import annotations
from python.orm.data_science_dev.base import DataScienceDevBase
from python.orm.data_science_dev.posts.columns import PostsColumns
class Posts(PostsColumns, DataScienceDevBase):
"""Parent partitioned table for posts, partitioned by week on `date`."""
__tablename__ = "posts"
__table_args__ = ({"postgresql_partition_by": "RANGE (date)"},)
+2 -10
View File
@@ -5,7 +5,7 @@ from __future__ import annotations
from datetime import datetime
from pgvector.sqlalchemy import Vector
from sqlalchemy import BigInteger, Boolean, DateTime, ForeignKey, Index, String, UniqueConstraint
from sqlalchemy import BigInteger, Boolean, DateTime, ForeignKey, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from python.orm.richie.base import TableBase, TableBaseBig
@@ -101,15 +101,7 @@ class EbookChunkEmbedding1024(TableBaseBig):
"""1024-dimensional chunk embedding."""
__tablename__ = "ebook_chunk_embedding_1024"
__table_args__ = (
UniqueConstraint("chunk_id", "model_id"),
Index(
"ix_ebook_chunk_embedding_1024_embedding_cosine",
"embedding",
postgresql_using="hnsw",
postgresql_ops={"embedding": "vector_cosine_ops"},
),
)
__table_args__ = (UniqueConstraint("chunk_id", "model_id"),)
chunk_id: Mapped[int] = mapped_column(ForeignKey("main.ebook_chunk.id", ondelete="CASCADE"))
model_id: Mapped[int] = mapped_column(ForeignKey("main.ebook_embedding_model.id", ondelete="CASCADE"))
+16
View File
@@ -0,0 +1,16 @@
"""Signal bot database ORM exports."""
from __future__ import annotations
from python.orm.signal_bot.base import SignalBotBase, SignalBotTableBase, SignalBotTableBaseSmall
from python.orm.signal_bot.models import DeadLetterMessage, DeviceRole, RoleRecord, SignalDevice
__all__ = [
"DeadLetterMessage",
"DeviceRole",
"RoleRecord",
"SignalBotBase",
"SignalBotTableBase",
"SignalBotTableBaseSmall",
"SignalDevice",
]
+52
View File
@@ -0,0 +1,52 @@
"""Signal bot database ORM base."""
from __future__ import annotations
from datetime import datetime
from sqlalchemy import DateTime, MetaData, SmallInteger, func
from sqlalchemy.ext.declarative import AbstractConcreteBase
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from python.orm.common import NAMING_CONVENTION
class SignalBotBase(DeclarativeBase):
"""Base class for signal_bot database ORM models."""
schema_name = "main"
metadata = MetaData(
schema=schema_name,
naming_convention=NAMING_CONVENTION,
)
class _TableMixin:
"""Shared timestamp columns for all table bases."""
created: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
)
updated: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
)
class SignalBotTableBaseSmall(_TableMixin, AbstractConcreteBase, SignalBotBase):
"""Table with SmallInteger primary key."""
__abstract__ = True
id: Mapped[int] = mapped_column(SmallInteger, primary_key=True)
class SignalBotTableBase(_TableMixin, AbstractConcreteBase, SignalBotBase):
"""Table with Integer primary key."""
__abstract__ = True
id: Mapped[int] = mapped_column(primary_key=True)
+62
View File
@@ -0,0 +1,62 @@
"""Signal bot device, role, and dead letter ORM models."""
from __future__ import annotations
from datetime import datetime
from sqlalchemy import DateTime, Enum, ForeignKey, SmallInteger, String, Text, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from python.orm.signal_bot.base import SignalBotTableBase, SignalBotTableBaseSmall
from python.signal_bot.models import MessageStatus, TrustLevel
class RoleRecord(SignalBotTableBaseSmall):
"""Lookup table for RBAC roles, keyed by smallint."""
__tablename__ = "role"
name: Mapped[str] = mapped_column(String(50), unique=True)
class DeviceRole(SignalBotTableBase):
"""Association between a device and a role."""
__tablename__ = "device_role"
__table_args__ = (
UniqueConstraint("device_id", "role_id", name="uq_device_role_device_role"),
{"schema": "main"},
)
device_id: Mapped[int] = mapped_column(ForeignKey("main.signal_device.id"))
role_id: Mapped[int] = mapped_column(SmallInteger, ForeignKey("main.role.id"))
class SignalDevice(SignalBotTableBase):
"""A Signal device tracked by phone number and safety number."""
__tablename__ = "signal_device"
phone_number: Mapped[str] = mapped_column(String(50), unique=True)
safety_number: Mapped[str | None]
trust_level: Mapped[TrustLevel] = mapped_column(
Enum(TrustLevel, name="trust_level", create_constraint=False, native_enum=False),
default=TrustLevel.UNVERIFIED,
)
last_seen: Mapped[datetime] = mapped_column(DateTime(timezone=True))
roles: Mapped[list[RoleRecord]] = relationship(secondary=DeviceRole.__table__)
class DeadLetterMessage(SignalBotTableBase):
"""A Signal message that failed processing and was sent to the dead letter queue."""
__tablename__ = "dead_letter_message"
source: Mapped[str]
message: Mapped[str] = mapped_column(Text)
received_at: Mapped[datetime] = mapped_column(DateTime(timezone=True))
status: Mapped[MessageStatus] = mapped_column(
Enum(MessageStatus, name="message_status", create_constraint=False, native_enum=False),
default=MessageStatus.UNPROCESSED,
)
+1
View File
@@ -0,0 +1 @@
"""Signal command and control bot."""
+1
View File
@@ -0,0 +1 @@
"""Signal bot commands."""
+137
View File
@@ -0,0 +1,137 @@
"""Van inventory command — parse receipts and item lists via LLM, push to API."""
from __future__ import annotations
import json
import logging
from typing import TYPE_CHECKING, Any
import httpx
from python.signal_bot.models import InventoryItem, InventoryUpdate
if TYPE_CHECKING:
from python.signal_bot.llm_client import LLMClient
from python.signal_bot.models import SignalMessage
from python.signal_bot.signal_client import SignalClient
logger = logging.getLogger(__name__)
SYSTEM_PROMPT = """\
You are an inventory assistant. Extract items from the input and return ONLY
a JSON array. Each element must have these fields:
- "name": item name (string)
- "quantity": numeric count or amount (default 1)
- "unit": unit of measure (e.g. "each", "lb", "oz", "gallon", "bag", "box")
- "category": category like "food", "tools", "supplies", etc.
- "notes": any extra detail (empty string if none)
Example output:
[{"name": "water bottles", "quantity": 6, "unit": "gallon", "category": "supplies", "notes": "1 gallon each"}]
Return ONLY the JSON array, no other text.\
"""
IMAGE_PROMPT = "Extract all items from this receipt or inventory photo."
TEXT_PROMPT = "Extract all items from this inventory list."
def parse_llm_response(raw: str) -> list[InventoryItem]:
"""Parse the LLM JSON response into InventoryItem list."""
text = raw.strip()
# Strip markdown code fences if present
if text.startswith("```"):
lines = text.split("\n")
lines = [line for line in lines if not line.startswith("```")]
text = "\n".join(lines)
items_data: list[dict[str, Any]] = json.loads(text)
return [InventoryItem.model_validate(item) for item in items_data]
def _upsert_item(api_url: str, item: InventoryItem) -> None:
"""Create or update an item via the van_inventory API.
Fetches existing items, and if one with the same name exists,
patches its quantity (summing). Otherwise creates a new item.
"""
base = api_url.rstrip("/")
response = httpx.get(f"{base}/api/items", timeout=10)
response.raise_for_status()
existing: list[dict[str, Any]] = response.json()
match = next((e for e in existing if e["name"].lower() == item.name.lower()), None)
if match:
new_qty = match["quantity"] + item.quantity
patch = {"quantity": new_qty}
if item.category:
patch["category"] = item.category
response = httpx.patch(f"{base}/api/items/{match['id']}", json=patch, timeout=10)
response.raise_for_status()
return
payload = {
"name": item.name,
"quantity": item.quantity,
"unit": item.unit,
"category": item.category or None,
}
response = httpx.post(f"{base}/api/items", json=payload, timeout=10)
response.raise_for_status()
def handle_inventory_update(
message: SignalMessage,
signal: SignalClient,
llm: LLMClient,
api_url: str,
) -> InventoryUpdate:
"""Process an inventory update from a Signal message.
Accepts either an image (receipt photo) or text list.
Uses the LLM to extract structured items, then pushes to the van_inventory API.
"""
try:
logger.info(f"Processing inventory update from {message.source}")
if message.attachments:
image_data = signal.get_attachment(message.attachments[0])
raw_response = llm.chat(
IMAGE_PROMPT,
image_data=image_data,
system=SYSTEM_PROMPT,
)
source_type = "receipt_photo"
elif message.message.strip():
raw_response = llm.chat(
f"{TEXT_PROMPT}\n\n{message.message}",
system=SYSTEM_PROMPT,
)
source_type = "text_list"
else:
signal.reply(message, "Send a photo of a receipt or a text list of items to update inventory.")
return InventoryUpdate()
logger.info(f"{raw_response=}")
new_items = parse_llm_response(raw_response)
logger.info(f"{new_items=}")
for item in new_items:
_upsert_item(api_url, item)
summary = _format_summary(new_items)
signal.reply(message, f"Inventory updated with {len(new_items)} item(s):\n{summary}")
return InventoryUpdate(items=new_items, raw_response=raw_response, source_type=source_type)
except Exception:
logger.exception("Failed to process inventory update")
signal.reply(message, "Failed to process inventory update. Check logs for details.")
return InventoryUpdate()
def _format_summary(items: list[InventoryItem]) -> str:
"""Format items into a readable summary."""
lines = [f" - {item.name} x{item.quantity} {item.unit} [{item.category}]" for item in items]
return "\n".join(lines)
+64
View File
@@ -0,0 +1,64 @@
"""Location command for the Signal bot."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any
import httpx
if TYPE_CHECKING:
from python.signal_bot.models import SignalMessage
from python.signal_bot.signal_client import SignalClient
logger = logging.getLogger(__name__)
def _get_entity_state(ha_url: str, ha_token: str, entity_id: str) -> dict[str, Any]:
"""Fetch an entity's state from Home Assistant."""
entity_url = f"{ha_url}/api/states/{entity_id}"
logger.debug(f"Fetching {entity_url=}")
response = httpx.get(
entity_url,
headers={"Authorization": f"Bearer {ha_token}"},
timeout=30,
)
response.raise_for_status()
return response.json()
def _format_location(latitude: str, longitude: str) -> str:
"""Render a friendly location response."""
return f"Van location: {latitude}, {longitude}\nhttps://maps.google.com/?q={latitude},{longitude}"
def handle_location_request(
message: SignalMessage,
signal: SignalClient,
ha_url: str | None,
ha_token: str | None,
) -> None:
"""Reply with van location from Home Assistant."""
if ha_url is None or ha_token is None:
signal.reply(message, "Location command is not configured (missing HA_URL or HA_TOKEN).")
return
lat_payload = None
lon_payload = None
try:
lat_payload = _get_entity_state(ha_url, ha_token, "sensor.van_last_known_latitude")
lon_payload = _get_entity_state(ha_url, ha_token, "sensor.van_last_known_longitude")
except httpx.HTTPError:
logger.exception("Couldn't fetch van location from Home Assistant right now.")
logger.debug(f"{ha_url=} {lat_payload=} {lon_payload=}")
signal.reply(message, "Couldn't fetch van location from Home Assistant right now.")
return
latitude = lat_payload.get("state", "")
longitude = lon_payload.get("state", "")
if not latitude or not longitude or latitude == "unavailable" or longitude == "unavailable":
signal.reply(message, "Van location is unavailable in Home Assistant right now.")
return
signal.reply(message, _format_location(latitude, longitude))
+284
View File
@@ -0,0 +1,284 @@
"""Device registry — tracks verified/unverified devices by safety number."""
from __future__ import annotations
import logging
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, NamedTuple
from sqlalchemy import delete, select
from sqlalchemy.orm import Session
from python.common import utcnow
from python.orm.signal_bot.models import RoleRecord, SignalDevice
from python.signal_bot.models import Role, TrustLevel
if TYPE_CHECKING:
from sqlalchemy.engine import Engine
from python.signal_bot.signal_client import SignalClient
logger = logging.getLogger(__name__)
_BLOCKED_TTL = timedelta(minutes=60)
_DEFAULT_TTL = timedelta(minutes=5)
class _CacheEntry(NamedTuple):
expires: datetime
trust_level: TrustLevel
has_safety_number: bool
safety_number: str | None
roles: list[Role]
class DeviceRegistry:
"""Manage device trust based on Signal safety numbers.
Devices start as UNVERIFIED. An admin verifies them over SSH by calling
``verify(phone_number)`` which marks the device VERIFIED and also tells
signal-cli to trust the identity.
Only VERIFIED devices may execute commands.
"""
def __init__(self, signal_client: SignalClient, engine: Engine) -> None:
self.signal_client = signal_client
self.engine = engine
self._contact_cache: dict[str, _CacheEntry] = {}
def is_verified(self, phone_number: str) -> bool:
"""Check if a phone number is verified."""
if entry := self._cached(phone_number):
return entry.trust_level == TrustLevel.VERIFIED
device = self._load_device(phone_number)
return device is not None and device.trust_level == TrustLevel.VERIFIED
def record_contact(self, phone_number: str, safety_number: str | None = None) -> None:
"""Record seeing a device. Creates entry if new, updates last_seen."""
now = utcnow()
entry = self._cached(phone_number)
if entry and entry.safety_number == safety_number:
return
with Session(self.engine) as session:
device = session.scalars(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).one_or_none()
if device:
if device.safety_number != safety_number and device.trust_level != TrustLevel.BLOCKED:
logger.warning(f"Safety number changed for {phone_number}, resetting to UNVERIFIED")
device.safety_number = safety_number
device.trust_level = TrustLevel.UNVERIFIED
device.last_seen = now
else:
device = SignalDevice(
phone_number=phone_number,
safety_number=safety_number,
trust_level=TrustLevel.UNVERIFIED,
last_seen=now,
)
session.add(device)
logger.info(f"New device registered: {phone_number}")
session.commit()
self._update_cache(phone_number, device)
def has_safety_number(self, phone_number: str) -> bool:
"""Check if a device has a safety number on file."""
if entry := self._cached(phone_number):
return entry.has_safety_number
device = self._load_device(phone_number)
return device is not None and device.safety_number is not None
def verify(self, phone_number: str) -> bool:
"""Mark a device as verified. Called by admin over SSH.
Returns True if the device was found and verified.
"""
with Session(self.engine) as session:
device = session.scalars(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).one_or_none()
if not device:
logger.warning(f"Cannot verify unknown device: {phone_number}")
return False
device.trust_level = TrustLevel.VERIFIED
self.signal_client.trust_identity(phone_number, trust_all_known_keys=True)
session.commit()
self._update_cache(phone_number, device)
logger.info(f"Device verified: {phone_number}")
return True
def block(self, phone_number: str) -> bool:
"""Block a device."""
return self._set_trust(phone_number, TrustLevel.BLOCKED, "Device blocked")
def unverify(self, phone_number: str) -> bool:
"""Reset a device to unverified."""
return self._set_trust(phone_number, TrustLevel.UNVERIFIED)
# -- role management ------------------------------------------------------
def get_roles(self, phone_number: str) -> list[Role]:
"""Return the roles for a device, defaulting to empty."""
if entry := self._cached(phone_number):
return entry.roles
device = self._load_device(phone_number)
return _extract_roles(device) if device else []
def has_role(self, phone_number: str, role: Role) -> bool:
"""Check if a device has a specific role or is admin."""
roles = self.get_roles(phone_number)
return Role.ADMIN in roles or role in roles
def grant_role(self, phone_number: str, role: Role) -> bool:
"""Add a role to a device. Called by admin over SSH."""
with Session(self.engine) as session:
device = session.scalars(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).one_or_none()
if not device:
logger.warning(f"Cannot grant role for unknown device: {phone_number}")
return False
if any(record.name == role for record in device.roles):
return True
role_record = session.scalars(select(RoleRecord).where(RoleRecord.name == role)).one_or_none()
if not role_record:
logger.warning(f"Unknown role: {role}")
return False
device.roles.append(role_record)
session.commit()
self._update_cache(phone_number, device)
logger.info(f"Device {phone_number} granted role {role}")
return True
def revoke_role(self, phone_number: str, role: Role) -> bool:
"""Remove a role from a device. Called by admin over SSH."""
with Session(self.engine) as session:
device = session.scalars(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).one_or_none()
if not device:
logger.warning(f"Cannot revoke role for unknown device: {phone_number}")
return False
device.roles = [record for record in device.roles if record.name != role]
session.commit()
self._update_cache(phone_number, device)
logger.info(f"Device {phone_number} revoked role {role}")
return True
def set_roles(self, phone_number: str, roles: list[Role]) -> bool:
"""Replace all roles for a device. Called by admin over SSH."""
with Session(self.engine) as session:
device = session.scalars(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).one_or_none()
if not device:
logger.warning(f"Cannot set roles for unknown device: {phone_number}")
return False
role_names = [str(role) for role in roles]
records = session.scalars(select(RoleRecord).where(RoleRecord.name.in_(role_names))).all()
device.roles = records
session.commit()
self._update_cache(phone_number, device)
logger.info(f"Device {phone_number} roles set to {role_names}")
return True
# -- queries --------------------------------------------------------------
def list_devices(self) -> list[SignalDevice]:
"""Return all known devices."""
with Session(self.engine) as session:
return list(session.scalars(select(SignalDevice)).all())
def sync_identities(self) -> None:
"""Pull identity list from signal-cli and record any new ones."""
identities = self.signal_client.get_identities()
for identity in identities:
number = identity.get("number", "")
safety = identity.get("safety_number", identity.get("fingerprint", ""))
if number:
self.record_contact(number, safety)
# -- internals ------------------------------------------------------------
def _cached(self, phone_number: str) -> _CacheEntry | None:
"""Return the cache entry if it exists and hasn't expired."""
entry = self._contact_cache.get(phone_number)
if entry and utcnow() < entry.expires:
return entry
return None
def _load_device(self, phone_number: str) -> SignalDevice | None:
"""Fetch a device by phone number (with joined roles)."""
with Session(self.engine) as session:
return session.scalars(select(SignalDevice).where(SignalDevice.phone_number == phone_number)).one_or_none()
def _update_cache(self, phone_number: str, device: SignalDevice) -> None:
"""Refresh the cache entry for a device."""
ttl = _BLOCKED_TTL if device.trust_level == TrustLevel.BLOCKED else _DEFAULT_TTL
self._contact_cache[phone_number] = _CacheEntry(
expires=utcnow() + ttl,
trust_level=device.trust_level,
has_safety_number=device.safety_number is not None,
safety_number=device.safety_number,
roles=_extract_roles(device),
)
def _set_trust(self, phone_number: str, level: str, log_msg: str | None = None) -> bool:
"""Update the trust level for a device."""
with Session(self.engine) as session:
device = session.scalars(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).one_or_none()
if not device:
return False
device.trust_level = level
session.commit()
self._update_cache(phone_number, device)
if log_msg:
logger.info(f"{log_msg}: {phone_number}")
return True
def _extract_roles(device: SignalDevice) -> list[Role]:
"""Convert a device's RoleRecord objects to a list of Role enums."""
return [Role(record.name) for record in device.roles]
def sync_roles(engine: Engine) -> None:
"""Sync the Role enum to the role table, adding new and removing stale entries."""
expected = {role.value for role in Role}
with Session(engine) as session:
existing = set(session.scalars(select(RoleRecord.name)).all())
to_add = expected - existing
to_remove = existing - expected
for name in to_add:
session.add(RoleRecord(name=name))
logger.info(f"Role added: {name}")
if to_remove:
session.execute(delete(RoleRecord).where(RoleRecord.name.in_(to_remove)))
for name in to_remove:
logger.info(f"Role removed: {name}")
session.commit()
+80
View File
@@ -0,0 +1,80 @@
"""Flexible LLM client for ollama backends."""
from __future__ import annotations
import base64
import logging
from typing import Any, Self
import httpx
logger = logging.getLogger(__name__)
class LLMClient:
"""Talk to an ollama instance.
Args:
model: Ollama model name.
host: Ollama host.
port: Ollama port.
temperature: Sampling temperature.
"""
def __init__(
self,
*,
model: str,
host: str,
port: int = 11434,
temperature: float = 0.1,
timeout: int = 300,
) -> None:
self.model = model
self.temperature = temperature
self._client = httpx.Client(base_url=f"http://{host}:{port}", timeout=timeout)
def chat(self, prompt: str, image_data: bytes | None = None, system: str | None = None) -> str:
"""Send a text prompt and return the response."""
messages: list[dict[str, Any]] = []
if system:
messages.append({"role": "system", "content": system})
user_msg = {"role": "user", "content": prompt}
if image_data:
user_msg["images"] = [base64.b64encode(image_data).decode()]
messages.append(user_msg)
return self._generate(messages)
def _generate(self, messages: list[dict[str, Any]]) -> str:
"""Call the ollama chat API."""
payload = {
"model": self.model,
"messages": messages,
"stream": False,
"options": {"temperature": self.temperature},
}
logger.info(f"LLM request to {self.model}")
response = self._client.post("/api/chat", json=payload)
response.raise_for_status()
data = response.json()
return data["message"]["content"]
def list_models(self) -> list[str]:
"""List available models on the ollama instance."""
response = self._client.get("/api/tags")
response.raise_for_status()
return [m["name"] for m in response.json().get("models", [])]
def __enter__(self) -> Self:
"""Enter the context manager."""
return self
def __exit__(self, *args: object) -> None:
"""Close the HTTP client on exit."""
self.close()
def close(self) -> None:
"""Close the HTTP client."""
self._client.close()
+239
View File
@@ -0,0 +1,239 @@
"""Signal command and control bot — main entry point."""
from __future__ import annotations
import logging
from dataclasses import dataclass
from os import getenv
from typing import TYPE_CHECKING, Annotated
if TYPE_CHECKING:
from collections.abc import Callable
import typer
from alembic.command import upgrade
from sqlalchemy.orm import Session
from tenacity import before_sleep_log, retry, stop_after_attempt, wait_exponential
from python.common import configure_logger, utcnow
from python.database_cli import DATABASES
from python.orm.common import get_postgres_engine
from python.orm.signal_bot.models import DeadLetterMessage
from python.signal_bot.commands.inventory import handle_inventory_update
from python.signal_bot.commands.location import handle_location_request
from python.signal_bot.device_registry import DeviceRegistry, sync_roles
from python.signal_bot.llm_client import LLMClient
from python.signal_bot.models import BotConfig, MessageStatus, Role, SignalMessage
from python.signal_bot.signal_client import SignalClient
logger = logging.getLogger(__name__)
@dataclass(frozen=True, slots=True)
class Command:
"""A registered bot command."""
action: Callable[[SignalMessage, str], None]
help_text: str
role: Role | None # None = no role required (always allowed)
class Bot:
"""Holds shared resources and dispatches incoming messages to command handlers."""
def __init__(
self,
signal: SignalClient,
llm: LLMClient,
registry: DeviceRegistry,
config: BotConfig,
) -> None:
self.signal = signal
self.llm = llm
self.registry = registry
self.config = config
self.commands: dict[str, Command] = {
"help": Command(action=self._help, help_text="show this help message", role=None),
"status": Command(action=self._status, help_text="show bot status", role=Role.STATUS),
"inventory": Command(
action=self._inventory,
help_text="update van inventory from a text list or receipt photo",
role=Role.INVENTORY,
),
"location": Command(
action=self._location,
help_text="get current van location",
role=Role.LOCATION,
),
}
# -- actions --------------------------------------------------------------
def _help(self, message: SignalMessage, _cmd: str) -> None:
"""Return help text filtered to the sender's roles."""
self.signal.reply(message, self._build_help(self.registry.get_roles(message.source)))
def _status(self, message: SignalMessage, _cmd: str) -> None:
"""Return the status of the bot."""
models = self.llm.list_models()
model_list = ", ".join(models[:10])
device_count = len(self.registry.list_devices())
self.signal.reply(
message,
f"Bot online.\nLLM: {self.llm.model}\nAvailable models: {model_list}\nKnown devices: {device_count}",
)
def _inventory(self, message: SignalMessage, _cmd: str) -> None:
"""Process an inventory update."""
handle_inventory_update(message, self.signal, self.llm, self.config.inventory_api_url)
def _location(self, message: SignalMessage, _cmd: str) -> None:
"""Reply with current van location."""
handle_location_request(message, self.signal, self.config.ha_url, self.config.ha_token)
# -- dispatch -------------------------------------------------------------
def _build_help(self, roles: list[Role]) -> str:
"""Build help text showing only the commands the user can access."""
is_admin = Role.ADMIN in roles
lines = ["Available commands:"]
for name, cmd in self.commands.items():
if cmd.role is None or is_admin or cmd.role in roles:
lines.append(f" {name:20s}{cmd.help_text}")
return "\n".join(lines)
def dispatch(self, message: SignalMessage) -> None:
"""Route an incoming message to the right command handler."""
source = message.source
if not self.registry.is_verified(source):
logger.info(f"Device {source} not verified, ignoring message")
return
if not self.registry.has_safety_number(source) and self.registry.has_role(source, Role.ADMIN):
logger.warning(f"Admin device {source} missing safety number, ignoring message")
return
text = message.message.strip()
parts = text.split()
if not parts and not message.attachments:
return
cmd = parts[0].lower() if parts else ""
logger.info(f"f{source=} running {cmd=} with {message=}")
command = self.commands.get(cmd)
if command is None:
if message.attachments:
command = self.commands["inventory"]
cmd = "inventory"
else:
return
if command.role is not None and not self.registry.has_role(source, command.role):
logger.warning(f"Device {source} denied access to {cmd!r}")
self.signal.reply(message, f"Permission denied: you do not have the '{command.role}' role.")
return
command.action(message, cmd)
def process_message(self, message: SignalMessage) -> None:
"""Process a single message, sending it to the dead letter queue after repeated failures."""
max_attempts = self.config.max_message_attempts
for attempt in range(1, max_attempts + 1):
try:
safety_number = self.signal.get_safety_number(message.source)
self.registry.record_contact(message.source, safety_number)
self.dispatch(message)
except Exception:
logger.exception(f"Failed to process message (attempt {attempt}/{max_attempts})")
else:
return
logger.error(f"Message from {message.source} failed {max_attempts} times, sending to dead letter queue")
with Session(self.config.engine) as session:
session.add(
DeadLetterMessage(
source=message.source,
message=message.message,
received_at=utcnow(),
status=MessageStatus.UNPROCESSED,
)
)
session.commit()
def run(self) -> None:
"""Listen for messages via WebSocket, reconnecting on failure."""
logger.info("Bot started — listening via WebSocket")
@retry(
stop=stop_after_attempt(self.config.max_retries),
wait=wait_exponential(multiplier=self.config.reconnect_delay, max=self.config.max_reconnect_delay),
before_sleep=before_sleep_log(logger, logging.WARNING),
reraise=True,
)
def _listen() -> None:
for message in self.signal.listen():
logger.info(f"Message from {message.source}: {message.message[:80]}")
self.process_message(message)
try:
_listen()
except Exception:
logger.critical("Max retries exceeded, shutting down")
raise
def main(
log_level: Annotated[str, typer.Option()] = "DEBUG",
llm_timeout: Annotated[int, typer.Option()] = 600,
) -> None:
"""Run the Signal command and control bot."""
configure_logger(log_level)
signal_api_url = getenv("SIGNAL_API_URL")
phone_number = getenv("SIGNAL_PHONE_NUMBER")
inventory_api_url = getenv("INVENTORY_API_URL")
if signal_api_url is None:
error = "SIGNAL_API_URL environment variable not set"
raise ValueError(error)
if phone_number is None:
error = "SIGNAL_PHONE_NUMBER environment variable not set"
raise ValueError(error)
if inventory_api_url is None:
error = "INVENTORY_API_URL environment variable not set"
raise ValueError(error)
signal_bot_config = DATABASES["signal_bot"].alembic_config()
upgrade(signal_bot_config, "head")
engine = get_postgres_engine(name="SIGNALBOT")
sync_roles(engine)
config = BotConfig(
signal_api_url=signal_api_url,
phone_number=phone_number,
inventory_api_url=inventory_api_url,
ha_url=getenv("HA_URL"),
ha_token=getenv("HA_TOKEN"),
engine=engine,
)
llm_host = getenv("LLM_HOST")
llm_model = getenv("LLM_MODEL", "qwen3-vl:32b")
llm_port = int(getenv("LLM_PORT", "11434"))
if llm_host is None:
error = "LLM_HOST environment variable not set"
raise ValueError(error)
with (
SignalClient(config.signal_api_url, config.phone_number) as signal,
LLMClient(model=llm_model, host=llm_host, port=llm_port, timeout=llm_timeout) as llm,
):
registry = DeviceRegistry(signal, engine)
bot = Bot(signal, llm, registry, config)
bot.run()
if __name__ == "__main__":
typer.run(main)
+97
View File
@@ -0,0 +1,97 @@
"""Models for the Signal command and control bot."""
from __future__ import annotations
from datetime import datetime # noqa: TC003 - pydantic needs this at runtime
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, ConfigDict
from sqlalchemy.engine import Engine # noqa: TC002 - pydantic needs this at runtime
class TrustLevel(StrEnum):
"""Device trust level."""
VERIFIED = "verified"
UNVERIFIED = "unverified"
BLOCKED = "blocked"
class Role(StrEnum):
"""RBAC roles — one per command, plus admin which grants all."""
ADMIN = "admin"
STATUS = "status"
INVENTORY = "inventory"
LOCATION = "location"
class MessageStatus(StrEnum):
"""Dead letter queue message status."""
UNPROCESSED = "unprocessed"
PROCESSED = "processed"
class Device(BaseModel):
"""A registered device tracked by safety number."""
phone_number: str
safety_number: str
trust_level: TrustLevel = TrustLevel.UNVERIFIED
first_seen: datetime
last_seen: datetime
class SignalMessage(BaseModel):
"""An incoming Signal message."""
source: str
timestamp: int
message: str = ""
attachments: list[str] = []
group_id: str | None = None
is_receipt: bool = False
class SignalEnvelope(BaseModel):
"""Raw envelope from signal-cli-rest-api."""
envelope: dict[str, Any]
account: str | None = None
class InventoryItem(BaseModel):
"""An item in the van inventory."""
name: str
quantity: float = 1
unit: str = "each"
category: str = ""
notes: str = ""
class InventoryUpdate(BaseModel):
"""Result of processing an inventory update."""
items: list[InventoryItem] = []
raw_response: str = ""
source_type: str = "" # "receipt_photo" or "text_list"
class BotConfig(BaseModel):
"""Top-level bot configuration."""
model_config = ConfigDict(arbitrary_types_allowed=True)
signal_api_url: str
phone_number: str
inventory_api_url: str
ha_url: str | None = None
ha_token: str | None = None
engine: Engine
reconnect_delay: int = 5
max_reconnect_delay: int = 300
max_retries: int = 10
max_message_attempts: int = 3
+141
View File
@@ -0,0 +1,141 @@
"""Client for the signal-cli-rest-api."""
from __future__ import annotations
import json
import logging
from typing import TYPE_CHECKING, Any, Self
import httpx
import websockets.sync.client
if TYPE_CHECKING:
from collections.abc import Generator
from python.signal_bot.models import SignalMessage
logger = logging.getLogger(__name__)
def _parse_envelope(envelope: dict[str, Any]) -> SignalMessage | None:
"""Parse a signal-cli envelope into a SignalMessage, or None if not a data message."""
data_message = envelope.get("dataMessage")
if not data_message:
return None
attachment_ids = [att["id"] for att in data_message.get("attachments", []) if "id" in att]
group_info = data_message.get("groupInfo")
group_id = group_info.get("groupId") if group_info else None
return SignalMessage(
source=envelope.get("source", ""),
timestamp=envelope.get("timestamp", 0),
message=data_message.get("message", "") or "",
attachments=attachment_ids,
group_id=group_id,
)
class SignalClient:
"""Communicate with signal-cli-rest-api.
Args:
base_url: URL of the signal-cli-rest-api (e.g. http://localhost:8989).
phone_number: The registered phone number to send/receive as.
"""
def __init__(self, base_url: str, phone_number: str) -> None:
self.base_url = base_url.rstrip("/")
self.phone_number = phone_number
self._client = httpx.Client(base_url=self.base_url, timeout=30)
def _ws_url(self) -> str:
"""Build the WebSocket URL from the base HTTP URL."""
url = self.base_url.replace("http://", "ws://").replace("https://", "wss://")
return f"{url}/v1/receive/{self.phone_number}"
def listen(self) -> Generator[SignalMessage]:
"""Connect via WebSocket and yield messages as they arrive."""
ws_url = self._ws_url()
logger.info(f"Connecting to WebSocket: {ws_url}")
with websockets.sync.client.connect(ws_url) as ws:
for raw in ws:
try:
data = json.loads(raw)
envelope = data.get("envelope", {})
message = _parse_envelope(envelope)
if message:
yield message
except json.JSONDecodeError:
logger.warning(f"Non-JSON WebSocket frame: {raw[:200]}")
def send(self, recipient: str, message: str) -> None:
"""Send a text message."""
payload = {
"message": message,
"number": self.phone_number,
"recipients": [recipient],
}
response = self._client.post("/v2/send", json=payload)
response.raise_for_status()
def send_to_group(self, group_id: str, message: str) -> None:
"""Send a message to a group."""
payload = {
"message": message,
"number": self.phone_number,
"recipients": [group_id],
}
response = self._client.post("/v2/send", json=payload)
response.raise_for_status()
def get_attachment(self, attachment_id: str) -> bytes:
"""Download an attachment by ID."""
response = self._client.get(f"/v1/attachments/{attachment_id}")
response.raise_for_status()
return response.content
def get_identities(self) -> list[dict[str, Any]]:
"""List known identities and their trust levels."""
response = self._client.get(f"/v1/identities/{self.phone_number}")
response.raise_for_status()
return response.json()
def get_safety_number(self, phone_number: str) -> str | None:
"""Look up the safety number for a contact from signal-cli's local store."""
for identity in self.get_identities():
if identity.get("number") == phone_number:
return identity.get("safety_number", identity.get("fingerprint", ""))
return None
def trust_identity(self, number_to_trust: str, *, trust_all_known_keys: bool = False) -> None:
"""Trust an identity (verify safety number)."""
payload: dict[str, Any] = {}
if trust_all_known_keys:
payload["trust_all_known_keys"] = True
response = self._client.put(
f"/v1/identities/{self.phone_number}/trust/{number_to_trust}",
json=payload,
)
response.raise_for_status()
def reply(self, message: SignalMessage, text: str) -> None:
"""Reply to a message, routing to group or individual."""
if message.group_id:
self.send_to_group(message.group_id, text)
else:
self.send(message.source, text)
def __enter__(self) -> Self:
"""Enter the context manager."""
return self
def __exit__(self, *args: object) -> None:
"""Close the HTTP client on exit."""
self.close()
def close(self) -> None:
"""Close the HTTP client."""
self._client.close()
@@ -16,13 +16,9 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Sequence
logger = logging.getLogger(__name__)
ESCAPE_KEY = 27
def configure_logger(level: str = "INFO") -> None:
"""Configure the logger.
Args:
level (str, optional): The logging level. Defaults to "INFO".
"""
@@ -36,17 +32,15 @@ def configure_logger(level: str = "INFO") -> None:
def bash_wrapper(command: str) -> str:
"""Execute a bash command and capture the output.
Args:
command (str): The bash command to be executed.
Returns:
Tuple[str, int]: A tuple containing the output of the command (stdout) as a string,
the error output (stderr) as a string (optional), and the return code as an integer.
"""
logger.debug(f"running {command=}")
logging.debug(f"running {command=}")
# This is a acceptable risk
process = Popen(command.split(), stdout=PIPE, stderr=PIPE)
process = Popen(command.split(), stdout=PIPE, stderr=PIPE) # noqa: S603
output, _ = process.communicate()
if process.returncode != 0:
error = f"Failed to run command {command=} return code {process.returncode=}"
@@ -57,7 +51,6 @@ def bash_wrapper(command: str) -> str:
def partition_disk(disk: str, swap_size: int, reserve: int = 0) -> None:
"""Partition a disk.
Args:
disk (str): The disk to partition.
swap_size (int): The size of the swap partition in GB.
@@ -65,7 +58,7 @@ def partition_disk(disk: str, swap_size: int, reserve: int = 0) -> None:
reserve (int, optional): The size of the reserve partition in GB. Defaults to 0.
minimum value is 0.
"""
logger.info(f"partitioning {disk=}")
logging.info(f"partitioning {disk=}")
swap_size = max(swap_size, 1)
reserve = max(reserve, 0)
@@ -73,16 +66,16 @@ def partition_disk(disk: str, swap_size: int, reserve: int = 0) -> None:
if reserve > 0:
msg = f"Creating swap partition on {disk=} with size {swap_size=}GiB and reserve {reserve=}GiB"
logger.info(msg)
logging.info(msg)
swap_start = swap_size + reserve
swap_partition = f"mkpart swap -{swap_start}GiB -{reserve}GiB "
else:
logger.info(f"Creating swap partition on {disk=} with size {swap_size=}GiB")
logging.info(f"Creating swap partition on {disk=} with size {swap_size=}GiB")
swap_start = swap_size
swap_partition = f"mkpart swap -{swap_start}GiB 100% "
logger.debug(f"{swap_partition=}")
logging.debug(f"{swap_partition=}")
create_partitions = (
f"parted --script --align=optimal {disk} -- "
@@ -94,14 +87,13 @@ def partition_disk(disk: str, swap_size: int, reserve: int = 0) -> None:
)
bash_wrapper(create_partitions)
logger.info(f"{disk=} successfully partitioned")
logging.info(f"{disk=} successfully partitioned")
def create_zfs_pool(pool_disks: Sequence[str], mnt_dir: str) -> None:
"""Create a ZFS pool.
Args:
pool_disks (Sequence[str]): A tuple of disks to use for the pool.
disks (Sequence[str]): A tuple of disks to use for the pool.
mnt_dir (str): The mount directory.
"""
if len(pool_disks) <= 0:
@@ -133,12 +125,13 @@ def create_zfs_pool(pool_disks: Sequence[str], mnt_dir: str) -> None:
bash_wrapper(zpool_create)
zpools = bash_wrapper("zpool list -o name")
if "root_pool" not in zpools.splitlines():
logger.critical("Failed to create root_pool")
logging.critical("Failed to create root_pool")
sys.exit(1)
def create_zfs_datasets() -> None:
"""Create ZFS datasets."""
bash_wrapper("zfs create -o canmount=noauto -o reservation=10G root_pool/root")
bash_wrapper("zfs create root_pool/home")
bash_wrapper("zfs create root_pool/var -o reservation=1G")
@@ -153,7 +146,7 @@ def create_zfs_datasets() -> None:
}
missing_datasets = expected_datasets.difference(datasets.splitlines())
if missing_datasets:
logger.critical(f"Failed to create pools {missing_datasets}")
logging.critical(f"Failed to create pools {missing_datasets}")
sys.exit(1)
@@ -166,8 +159,6 @@ def get_cpu_manufacturer() -> str:
for line in output.splitlines():
if "vendor_id" in line:
return id_vendor[line.split(": ")[1].strip()]
error = "Failed to get CPU manufacturer"
raise RuntimeError(error)
def get_boot_drive_id(disk: str) -> str:
@@ -176,8 +167,9 @@ def get_boot_drive_id(disk: str) -> str:
return output.splitlines()[1]
def create_nix_hardware_file(mnt_dir: str, disks: Sequence[str], *, encrypt: bool) -> None:
def create_nix_hardware_file(mnt_dir: str, disks: Sequence[str], encrypt: bool) -> None:
"""Create a NixOS hardware file."""
cpu_manufacturer = get_cpu_manufacturer()
devices = ""
@@ -201,15 +193,7 @@ def create_nix_hardware_file(mnt_dir: str, disks: Sequence[str], *, encrypt: boo
' imports = [ (modulesPath + "/installer/scan/not-detected.nix") ];\n\n'
" boot = {\n"
" initrd = {\n"
" availableKernelModules = [ \n"
' "ahci"\n'
' "ehci_pci"\n'
' "nvme"\n'
' "sd_mod"\n'
' "usb_storage"\n'
' "usbhid"\n'
' "xhci_pci"\n'
" ];\n"
' availableKernelModules = [ \n "ahci"\n "ehci_pci"\n "nvme"\n "sd_mod"\n "usb_storage"\n "usbhid"\n "xhci_pci"\n ];\n'
" kernelModules = [ ];\n"
f" {devices}"
" };\n"
@@ -223,18 +207,11 @@ def create_nix_hardware_file(mnt_dir: str, disks: Sequence[str], *, encrypt: boo
' "/nix" = {\n device = "root_pool/nix";\n fsType = "zfs";\n };\n\n'
' "/boot" = {\n'
f' device = "/dev/disk/by-uuid/{get_boot_drive_id(disks[0])}";\n'
' fsType = "vfat";\n'
" options = [\n"
' "fmask=0077"\n'
' "dmask=0077"\n'
" ];\n"
" };\n"
" };\n\n"
' fsType = "vfat";\n options = [\n "fmask=0077"\n "dmask=0077"\n ];\n };\n };\n\n'
" swapDevices = [ ];\n\n"
" networking.useDHCP = lib.mkDefault true;\n\n"
' nixpkgs.hostPlatform = lib.mkDefault "x86_64-linux";\n'
f" hardware.cpu.{cpu_manufacturer}.updateMicrocode = lib.mkDefault "
"config.hardware.enableRedistributableFirmware;\n"
f" hardware.cpu.{cpu_manufacturer}.updateMicrocode = lib.mkDefault config.hardware.enableRedistributableFirmware;\n"
f' networking.hostId = "{host_id}";\n'
"}\n"
)
@@ -242,7 +219,7 @@ def create_nix_hardware_file(mnt_dir: str, disks: Sequence[str], *, encrypt: boo
Path(f"{mnt_dir}/etc/nixos/hardware-configuration.nix").write_text(nix_hardware)
def install_nixos(mnt_dir: str, disks: Sequence[str], *, encrypt: bool) -> None:
def install_nixos(mnt_dir: str, disks: Sequence[str], encrypt: bool) -> None:
"""Install NixOS."""
bash_wrapper(f"mount -o X-mount.mkdir -t zfs root_pool/root {mnt_dir}")
bash_wrapper(f"mount -o X-mount.mkdir -t zfs root_pool/home {mnt_dir}/home")
@@ -253,16 +230,14 @@ def install_nixos(mnt_dir: str, disks: Sequence[str], *, encrypt: bool) -> None:
bash_wrapper(f"mkfs.vfat -n EFI {disk}-part1")
# set up mirroring afterwards if more than one disk
boot_partition = (
f"mount -t vfat -o fmask=0077,dmask=0077,iocharset=iso8859-1,X-mount.mkdir {disks[0]}-part1 {mnt_dir}/boot"
)
boot_partition = f"mount -t vfat -o fmask=0077,dmask=0077,iocharset=iso8859-1,X-mount.mkdir {disks[0]}-part1 {mnt_dir}/boot"
bash_wrapper(boot_partition)
bash_wrapper(f"nixos-generate-config --root {mnt_dir}")
create_nix_hardware_file(mnt_dir, disks, encrypt=encrypt)
create_nix_hardware_file(mnt_dir, disks, encrypt)
run(("nixos-install", "--root", mnt_dir), check=True)
run(("nixos-install", "--root", mnt_dir), check=True) # noqa: S603
def installer(
@@ -272,37 +247,27 @@ def installer(
encrypt_key: str | None,
) -> None:
"""Main."""
logger.info("Starting installation")
logging.info("Starting installation")
for disk in disks:
partition_disk(disk, swap_size, reserve)
if encrypt_key:
sleep(1)
key_input = encrypt_key.encode()
run(
("cryptsetup", "luksFormat", "--type", "luks2", f"{disk}-part2", "-"),
input=key_input,
check=True,
)
run(
(
"cryptsetup",
"luksOpen",
f"{disk}-part2",
f"luks-root-pool-{disk.split('/')[-1]}-part2",
"-",
),
input=key_input,
check=True,
)
for command in (
f'printf "{encrypt_key}" | cryptsetup luksFormat --type luks2 {disk}-part2 -',
f'printf "{encrypt_key}" | cryptsetup luksOpen {disk}-part2 luks-root-pool-{disk.split("/")[-1]}-part2 -',
):
run(command, shell=True, check=True)
mnt_dir = "/tmp/nix_install" # noqa: S108
Path(mnt_dir).mkdir(parents=True, exist_ok=True)
if encrypt_key:
pool_disks = [f"/dev/mapper/luks-root-pool-{disk.split('/')[-1]}-part2" for disk in disks]
pool_disks = [
f"/dev/mapper/luks-root-pool-{disk.split('/')[-1]}-part2" for disk in disks
]
else:
pool_disks = [f"{disk}-part2" for disk in disks]
@@ -310,73 +275,57 @@ def installer(
create_zfs_datasets()
install_nixos(mnt_dir, disks, encrypt=bool(encrypt_key))
install_nixos(mnt_dir, disks, encrypt_key)
logger.info("Installation complete")
logging.info("Installation complete")
class Cursor:
"""Track cursor position and constrain movement to screen bounds."""
def __init__(self) -> None:
"""Initialize cursor position and screen dimensions."""
def __init__(self):
self.x_position = 0
self.y_position = 0
self.height = 0
self.width = 0
def set_height(self, height: int) -> None:
"""Set the maximum screen height."""
def set_height(self, height: int):
self.height = height
def set_width(self, width: int) -> None:
"""Set the maximum screen width."""
def set_width(self, width: int):
self.width = width
def x_bounce_check(self, cursor: int) -> int:
"""Clamp an x position to the screen width."""
cursor = max(0, cursor)
return min(self.width - 1, cursor)
def y_bounce_check(self, cursor: int) -> int:
"""Clamp a y position to the screen height."""
cursor = max(0, cursor)
return min(self.height - 1, cursor)
def set_x(self, x: int) -> None:
"""Set the cursor x position."""
def set_x(self, x: int):
self.x_position = self.x_bounce_check(x)
def set_y(self, y: int) -> None:
"""Set the cursor y position."""
def set_y(self, y: int):
self.y_position = self.y_bounce_check(y)
def get_x(self) -> int:
"""Get the cursor x position."""
return self.x_position
def get_y(self) -> int:
"""Get the cursor y position."""
return self.y_position
def move_up(self) -> None:
"""Move the cursor up one row."""
def move_up(self):
self.set_y(self.y_position - 1)
def move_down(self) -> None:
"""Move the cursor down one row."""
def move_down(self):
self.set_y(self.y_position + 1)
def move_left(self) -> None:
"""Move the cursor left one column."""
def move_left(self):
self.set_x(self.x_position - 1)
def move_right(self) -> None:
"""Move the cursor right one column."""
def move_right(self):
self.set_x(self.x_position + 1)
def navigation(self, key: int) -> None:
"""Move the cursor for a curses navigation key."""
action = {
curses.KEY_DOWN: self.move_down,
curses.KEY_UP: self.move_up,
@@ -390,8 +339,7 @@ class Cursor:
class State:
"""State class to store the state of the program."""
def __init__(self) -> None:
"""Initialize installer menu state."""
def __init__(self):
self.key = 0
self.cursor = Cursor()
@@ -409,9 +357,11 @@ class State:
def get_device(raw_device: str) -> dict[str, str]:
"""Parse an lsblk key-value device row."""
raw_device_components = raw_device.split(" ")
return {thing.split("=")[0].lower(): thing.split("=")[1].strip('"') for thing in raw_device_components}
return {
thing.split("=")[0].lower(): thing.split("=")[1].strip('"')
for thing in raw_device_components
}
def get_devices() -> list[dict[str, str]]:
@@ -423,7 +373,6 @@ def get_devices() -> list[dict[str, str]]:
def get_device_id_mapping() -> dict[str, set[str]]:
"""Get a list of device ids.
Returns:
list[str]: the list of device ids
"""
@@ -438,8 +387,9 @@ def get_device_id_mapping() -> dict[str, set[str]]:
return device_id_mapping
def calculate_device_menu_padding(devices: list[dict[str, str]], column: str, padding: int = 0) -> int:
"""Calculate the width needed for a device menu column."""
def calculate_device_menu_padding(
devices: list[dict[str, str]], column: str, padding: int = 0
) -> int:
return max(len(device[column]) for device in devices) + padding
@@ -451,7 +401,6 @@ def draw_device_ids(
menu_width: list[int],
device_ids: set[str],
) -> tuple[State, int]:
"""Draw selectable device IDs for a device row."""
for device_id in sorted(device_ids):
row_number = row_number + 1
if row_number == state.cursor.get_y() and state.cursor.get_x() in menu_width:
@@ -480,9 +429,8 @@ def draw_device_menu(
state: State,
menu_start_y: int = 0,
menu_start_x: int = 0,
) -> tuple[State, int]:
"""Draw the device menu and handle user input.
) -> State:
"""draw the device menu and handle user input
Args:
std_screen (curses.window): the curses window to draw on
devices (list[dict[str, str]]): the list of devices to draw
@@ -490,7 +438,6 @@ def draw_device_menu(
state (State): the state object to update
menu_start_y (int, optional): the y position to start drawing the menu. Defaults to 0.
menu_start_x (int, optional): the x position to start drawing the menu. Defaults to 0.
Returns:
State: the updated state object
"""
@@ -501,9 +448,7 @@ def draw_device_menu(
type_padding = calculate_device_menu_padding(devices, "type", padding)
mountpoints_padding = calculate_device_menu_padding(devices, "mountpoints", padding)
device_header = (
f"{'Name':{name_padding}}{'Size':{size_padding}}{'Type':{type_padding}}{'Mountpoints':{mountpoints_padding}}"
)
device_header = f"{'Name':{name_padding}}{'Size':{size_padding}}{'Type':{type_padding}}{'Mountpoints':{mountpoints_padding}}"
menu_width = range(menu_start_x, len(device_header) + menu_start_x)
@@ -536,9 +481,8 @@ def draw_device_menu(
def debug_menu(std_screen: curses.window, key: int) -> None:
"""Draw debug information for the current curses screen."""
height, width = std_screen.getmaxyx()
width_height = f"Width: {width}, Height: {height}"
width_height = "Width: {}, Height: {}".format(width, height)
std_screen.addstr(height - 4, 0, width_height, curses.color_pair(5))
key_pressed = f"Last key pressed: {key}"[: width - 1]
@@ -546,7 +490,7 @@ def debug_menu(std_screen: curses.window, key: int) -> None:
key_pressed = "No key press detected..."[: width - 1]
std_screen.addstr(height - 3, 0, key_pressed)
for i in range(8):
for i in range(0, 8):
std_screen.addstr(height - 2, i * 3, f"{i}██", curses.color_pair(i))
@@ -556,11 +500,12 @@ def status_bar(
width: int,
height: int,
) -> None:
"""Draw the footer status bar."""
std_screen.attron(curses.A_REVERSE)
std_screen.attron(curses.color_pair(3))
status_bar = f"Press 'q' to exit | STATUS BAR | Pos: {cursor.get_x()}, {cursor.get_y()}"
status_bar = (
f"Press 'q' to exit | STATUS BAR | Pos: {cursor.get_x()}, {cursor.get_y()}"
)
std_screen.addstr(height - 1, 0, status_bar)
std_screen.addstr(height - 1, len(status_bar), " " * (width - len(status_bar) - 1))
@@ -569,15 +514,13 @@ def status_bar(
def set_color() -> None:
"""Initialize curses color pairs."""
curses.start_color()
curses.use_default_colors()
for i in range(curses.COLORS):
for i in range(0, curses.COLORS):
curses.init_pair(i + 1, i, -1)
def get_text_input(std_screen: curses.window, prompt: str, y: int, x: int) -> str:
"""Read text input from a curses screen."""
curses.echo()
std_screen.addstr(y, x, prompt)
input_str = ""
@@ -585,10 +528,10 @@ def get_text_input(std_screen: curses.window, prompt: str, y: int, x: int) -> st
key = std_screen.getch()
if key == ord("\n"):
break
if key == ESCAPE_KEY:
elif key == 27: # ESC key
input_str = ""
break
if key in (curses.KEY_BACKSPACE, ord("\b"), 127):
elif key in (curses.KEY_BACKSPACE, ord("\b"), 127):
input_str = input_str[:-1]
std_screen.addstr(y, x + len(prompt), input_str + " ")
else:
@@ -603,7 +546,6 @@ def swap_size_input(
state: State,
swap_offset: int,
) -> State:
"""Handle swap size input."""
swap_size_text = "Swap size (GB): "
std_screen.addstr(swap_offset, 0, f"{swap_size_text}{state.swap_size}")
if state.key == ord("\n") and state.cursor.get_y() == swap_offset:
@@ -615,7 +557,9 @@ def swap_size_input(
state.swap_size = int(swap_size_str)
state.show_swap_input = False
except ValueError:
std_screen.addstr(swap_offset, 0, "Invalid input. Press any key to continue.")
std_screen.addstr(
swap_offset, 0, "Invalid input. Press any key to continue."
)
std_screen.getch()
state.show_swap_input = False
@@ -627,19 +571,22 @@ def reserve_size_input(
state: State,
reserve_offset: int,
) -> State:
"""Handle reserve size input."""
reserve_size_text = "reserve size (GB): "
std_screen.addstr(reserve_offset, 0, f"{reserve_size_text}{state.reserve_size}")
if state.key == ord("\n") and state.cursor.get_y() == reserve_offset:
state.show_reserve_input = True
if state.show_reserve_input:
reserve_size_str = get_text_input(std_screen, reserve_size_text, reserve_offset, 0)
reserve_size_str = get_text_input(
std_screen, reserve_size_text, reserve_offset, 0
)
try:
state.reserve_size = int(reserve_size_str)
state.show_reserve_input = False
except ValueError:
std_screen.addstr(reserve_offset, 0, "Invalid input. Press any key to continue.")
std_screen.addstr(
reserve_offset, 0, "Invalid input. Press any key to continue."
)
std_screen.getch()
state.show_reserve_input = False
@@ -647,11 +594,9 @@ def reserve_size_input(
def draw_menu(std_screen: curses.window) -> State:
"""Draw the menu and handle user input.
"""draw the menu and handle user input
Args:
std_screen (curses.window): the curses window to draw on
Returns:
State: the state object
"""
@@ -711,18 +656,17 @@ def draw_menu(std_screen: curses.window) -> State:
def main() -> None:
"""Run the installer menu and start installation."""
configure_logger("DEBUG")
state = curses.wrapper(draw_menu)
encrypt_key = getenv("ENCRYPT_KEY")
logger.info("installing_nixos")
logger.info(f"disks: {state.selected_device_ids}")
logger.info(f"swap_size: {state.swap_size}")
logger.info(f"reserve: {state.reserve_size}")
logger.info(f"encrypted: {bool(encrypt_key)}")
logging.info("installing_nixos")
logging.info(f"disks: {state.selected_device_ids}")
logging.info(f"swap_size: {state.swap_size}")
logging.info(f"reserve: {state.reserve_size}")
logging.info(f"encrypted: {bool(encrypt_key)}")
sleep(3)
+1 -8
View File
@@ -28,14 +28,7 @@
networking = {
hostName = "bob";
hostId = "7c678a41";
firewall = {
enable = true;
allowedTCPPorts = [
8000
8001
8002
];
};
firewall.enable = true;
networkmanager.enable = true;
};
-5
View File
@@ -30,11 +30,6 @@
keyFile = "/dev/disk/by-id/usb-Samsung_Flash_Drive_FIT_0374620080067131-0:0";
};
};
zfs.extraPools = [
"storage"
];
kernelModules = [ "kvm-amd" ];
extraModulePackages = [ ];
};
-3
View File
@@ -17,9 +17,6 @@
allowedTCPPorts = [ ];
allowedUDPPorts = [ ];
};
allowedTCPPorts = [
8070
];
};
useNetworkd = true;
};
+40
View File
@@ -0,0 +1,40 @@
"""Shared test fixtures."""
from __future__ import annotations
from typing import TYPE_CHECKING
import pytest
from sqlalchemy import create_engine, event
from python.orm.signal_bot.base import SignalBotBase
if TYPE_CHECKING:
from collections.abc import Generator
from sqlalchemy.engine import Engine
@pytest.fixture(scope="session")
def sqlite_engine() -> Generator[Engine]:
"""Create an in-memory SQLite engine for testing."""
engine = create_engine("sqlite:///:memory:")
@event.listens_for(engine, "connect")
def _set_sqlite_pragma(dbapi_connection, _connection_record):
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
SignalBotBase.metadata.create_all(engine)
yield engine
engine.dispose()
@pytest.fixture
def engine(sqlite_engine: Engine) -> Generator[Engine]:
"""Yield the shared engine after cleaning all tables between tests."""
yield sqlite_engine
with sqlite_engine.begin() as connection:
for table in reversed(SignalBotBase.metadata.sorted_tables):
connection.execute(table.delete())
+6 -162
View File
@@ -2,7 +2,6 @@
from __future__ import annotations
import logging
from dataclasses import replace
from datetime import UTC, datetime
from os import environ
@@ -20,11 +19,7 @@ from python.ebook_search.bm25_corpus import (
BM25CorpusUnavailableError,
BM25Manifest,
ensure_bm25_corpus,
fetch_bm25_corpus_records,
load_bm25_corpus,
read_bm25_manifest,
score_bm25_corpus,
write_bm25_corpus,
)
from python.ebook_search.config import EbookSearchConfig, RerankConfig, load_config, normalize_embedding_model
from python.ebook_search.embeddings import MODEL_DIMENSIONS, ensure_embedding_models
@@ -38,14 +33,7 @@ from python.ebook_search.search import (
search_ebooks,
)
from python.ebook_search.timing import RuntimeStep
from python.orm.richie import (
EbookChapter,
EbookChunk,
EbookChunkEmbedding1024,
EbookEmbeddingModel,
EbookSource,
RichieBase,
)
from python.orm.richie import EbookEmbeddingModel, EbookSource, RichieBase
def test_chunk_text_uses_overlap() -> None:
@@ -98,49 +86,6 @@ def test_find_existing_source_matches_path_or_hash() -> None:
assert find_existing_source(session, Path("/new/book.epub"), "a" * 64) == source
def test_bm25_corpus_uses_existing_search_text_without_duplicate_metadata() -> None:
engine = create_engine("sqlite+pysqlite:///:memory:", future=True)
RichieBase.metadata.create_all(engine)
with sessionmaker(bind=engine, expire_on_commit=False, future=True)() as session:
source = EbookSource(
title="Book",
author="Author",
language=None,
publisher=None,
identifier=None,
file_path="/book.epub",
file_sha256="a" * 64,
file_mtime=datetime.now(tz=UTC),
file_size=10,
)
session.add(source)
session.flush()
chapter = EbookChapter(source_id=source.id, spine_index=0, title="Chapter", href=None)
session.add(chapter)
session.flush()
session.add(
EbookChunk(
id=1,
source_id=source.id,
chapter_id=chapter.id,
chunk_index=0,
text="content",
token_start=0,
token_count=1,
page_label=None,
content_sha256="b" * 64,
search_text="Book Author Chapter content",
)
)
session.commit()
records, texts = fetch_bm25_corpus_records(session)
assert texts == ["Book Author Chapter content"]
assert records[0]["chunk_id"] == 1
assert "bm25_text" not in records[0]
def test_reciprocal_rank_fusion_marks_hybrid_source() -> None:
vector_results = [SearchResult(chunk_id=1, text="a", source_title="A")]
lexical_results = [SearchResult(chunk_id=2, text="b", source_title="B")]
@@ -174,7 +119,7 @@ def test_search_ebooks_runs_vector_and_bm25_in_parallel(monkeypatch) -> None:
def fake_vector_candidates(received_engine, query, _config):
"""Return vector candidates after confirming BM25 has started."""
received_engines.append(received_engine)
assert query == "what is parallel"
assert query == "parallel"
vector_started.set()
assert bm25_started.wait(timeout=2)
return [SearchResult(chunk_id=1, text="vector", source_title="Vector", vector_score=0.9)]
@@ -190,14 +135,13 @@ def test_search_ebooks_runs_vector_and_bm25_in_parallel(monkeypatch) -> None:
monkeypatch.setattr("python.ebook_search.search.bm25_candidates", fake_bm25_candidates)
config = EbookSearchConfig(rerank=RerankConfig(enabled=False))
response = search_ebooks(engine, "what is parallel", config)
response = search_ebooks(engine, "parallel", config)
timings = {step.name: step for step in response.timings}
assert [result.chunk_id for result in response.results] == [1, 2]
assert timings["Embedding + vector search"].counts_toward_total is False
assert timings["BM25 search"].counts_toward_total is False
assert timings["Hybrid retrieval"].counts_toward_total is True
assert timings["BM25 query preparation"].counts_toward_total is True
assert received_engines == [engine]
@@ -240,106 +184,15 @@ def test_bm25_candidates_scores_whole_corpus(monkeypatch) -> None:
assert [result.bm25_score for result in results] == [1.5]
def test_bm25_candidates_returns_empty_when_corpus_is_unavailable(monkeypatch, caplog) -> None:
def test_bm25_candidates_raises_when_corpus_is_unavailable(monkeypatch) -> None:
def fake_load_bm25_corpus(_config):
raise BM25CorpusUnavailableError
monkeypatch.setattr("python.ebook_search.search.load_bm25_corpus", fake_load_bm25_corpus)
config = EbookSearchConfig(rerank=RerankConfig(enabled=False))
with caplog.at_level(logging.WARNING):
results = bm25_candidates("high", config)
assert results == []
assert "ebook_bm25_index_unavailable_skipping" in caplog.text
def test_write_bm25_corpus_publishes_dated_generation(tmp_path) -> None:
index_path = tmp_path / "bm25"
index_path.mkdir()
generations_path = index_path / "generations"
generations_path.mkdir()
old_generation = generations_path / "20260101T000000.000000Z"
old_generation.mkdir()
(old_generation / "sentinel").write_text("old", encoding="utf-8")
(index_path / "current").symlink_to(Path("generations") / old_generation.name, target_is_directory=True)
manifest = BM25Manifest(
created_at=datetime(2026, 6, 12, 1, 2, 3, 456789, tzinfo=UTC),
db_updated_at=None,
chunk_count=0,
)
write_bm25_corpus(index_path, [], [], manifest)
current_path = index_path / "current"
assert current_path.is_symlink()
assert current_path.readlink() == generations_path / "20260612T010203.456789Z"
assert old_generation.is_dir()
assert (old_generation / "sentinel").read_text(encoding="utf-8") == "old"
assert (generations_path / "20260612T010203.456789Z").is_dir()
assert read_bm25_manifest(index_path) == manifest
def test_write_bm25_corpus_keeps_current_generation_when_publish_fails(monkeypatch, tmp_path) -> None:
index_path = tmp_path / "bm25"
index_path.mkdir()
generations_path = index_path / "generations"
generations_path.mkdir()
old_generation = generations_path / "20260101T000000.000000Z"
old_generation.mkdir()
(old_generation / "sentinel").write_text("old", encoding="utf-8")
current_path = index_path / "current"
current_path.symlink_to(Path("generations") / old_generation.name, target_is_directory=True)
original_replace = Path.replace
def fail_current_replace(self, target):
if self.parent == index_path and self.name.startswith(".current.") and target == current_path:
msg = "current publish failed"
raise OSError(msg)
return original_replace(self, target)
monkeypatch.setattr(Path, "replace", fail_current_replace)
manifest = BM25Manifest(
created_at=datetime(2026, 6, 12, 1, 2, 3, 456789, tzinfo=UTC),
db_updated_at=None,
chunk_count=0,
)
with pytest.raises(OSError, match="current publish failed"):
write_bm25_corpus(index_path, [], [], manifest)
assert current_path.readlink() == Path("generations") / old_generation.name
assert (old_generation / "sentinel").read_text(encoding="utf-8") == "old"
assert not (generations_path / "20260612T010203.456789Z").exists()
def test_load_bm25_corpus_uses_current_generation(tmp_path) -> None:
load_bm25_corpus.cache_clear()
index_path = tmp_path / "bm25"
manifest = BM25Manifest(
created_at=datetime(2026, 6, 12, 1, 2, 3, 456789, tzinfo=UTC),
db_updated_at=None,
chunk_count=1,
)
record = {
"chunk_id": 2,
"text": "cached",
"source_title": "B",
"source_author": None,
"chapter_title": None,
"page_label": None,
}
write_bm25_corpus(index_path, [record], ["cached phrase"], manifest)
config = EbookSearchConfig(rerank=RerankConfig(enabled=False), bm25_index_dir=str(index_path))
try:
corpus = load_bm25_corpus(config)
finally:
load_bm25_corpus.cache_clear()
assert corpus.manifest == manifest
assert corpus.records[0]["chunk_id"] == 2
assert score_bm25_corpus("cached", corpus, limit=10)
with pytest.raises(BM25CorpusUnavailableError):
bm25_candidates("high", config)
def test_load_bm25_corpus_caches_disk_load(monkeypatch, tmp_path) -> None:
@@ -471,15 +324,6 @@ def test_ensure_embedding_models_registers_service_names() -> None:
]
def test_1024_embedding_table_has_cosine_hnsw_index() -> None:
indexes = {index.name: index for index in EbookChunkEmbedding1024.__table__.indexes}
index = indexes["ix_ebook_chunk_embedding_1024_embedding_cosine"]
assert [column.name for column in index.columns] == ["embedding"]
assert index.dialect_options["postgresql"]["using"] == "hnsw"
assert index.dialect_options["postgresql"]["ops"] == {"embedding": "vector_cosine_ops"}
def test_embedding_model_aliases_normalize_to_provider_names() -> None:
assert normalize_embedding_model() == "qwen3-embedding-0.6b"
+5 -5
View File
@@ -75,7 +75,7 @@ def test_reranking_enabled_reorders_candidates(monkeypatch: pytest.MonkeyPatch)
results = rerank_chunks("query", candidates(), RerankConfig())
assert [result.chunk_id for result in results] == [2, 1, 3]
assert [round(result.score, 3) for result in results] == [0.78, 0.37, 0.28]
assert [round(result.score, 3) for result in results] == [0.45, 0.1, 0.0]
assert [result.rerank_score for result in results] == [0.9, 0.1, 0.4]
@@ -100,8 +100,8 @@ def test_reranking_cannot_ignore_hybrid_score(monkeypatch: pytest.MonkeyPatch) -
results = rerank_chunks("query", candidates, RerankConfig())
assert [result.chunk_id for result in results] == [1, 2]
assert results[0].score == pytest.approx(0.79)
assert results[1].score == 0.7
assert results[0].score == 0.7
assert results[1].score == 0.0
assert results[1].rerank_score == 1.0
@@ -129,7 +129,7 @@ def test_malformed_vllm_rerank_json_does_not_crash_search(monkeypatch: pytest.Mo
results = rerank_chunks("query", candidates()[:1], RerankConfig())
assert results[0].score == 0.3
assert results[0].score == 0.0
def test_vllm_rerank_scores_are_clamped(monkeypatch: pytest.MonkeyPatch) -> None:
@@ -147,4 +147,4 @@ def test_vllm_rerank_scores_are_clamped(monkeypatch: pytest.MonkeyPatch) -> None
results = rerank_chunks("query", candidates()[:2], RerankConfig())
assert {result.chunk_id: result.rerank_score for result in results} == {1: 0.0, 2: 1.0}
assert [result.rerank_score for result in results] == [0.0, 1.0]
-38
View File
@@ -2,11 +2,9 @@
from __future__ import annotations
from compression import zstd
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from python.ebook_search.api.bm25_tasks import refresh_bm25_for_engine
from python.ebook_search.api.main import create_app
from python.ebook_search.config import EbookSearchConfig, RerankConfig
from python.ebook_search.embeddings import EmbeddingModelStats
@@ -25,19 +23,6 @@ def fake_get_postgres_engine(**_kwargs):
return create_engine("sqlite+pysqlite:///:memory:", future=True)
def test_search_page_uses_zstd_when_requested(monkeypatch) -> None:
patch_app_runtime(monkeypatch)
app = create_app()
app.state.config = EbookSearchConfig(rerank=RerankConfig(enabled=False))
with TestClient(app) as client:
response = client.get("/", headers={"accept-encoding": "zstd"})
assert response.status_code == 200
assert response.headers["content-encoding"] == "zstd"
assert b"EPUB Search" in zstd.decompress(response.content)
def test_ui_form_passes_rerank_flag_to_search_handler(monkeypatch) -> None:
captured: dict[str, object] = {}
@@ -247,29 +232,6 @@ def test_ui_scan_schedules_bm25_refresh_after_database_change(monkeypatch) -> No
assert scheduled is True
def test_bm25_refresh_clears_loaded_corpus_cache(monkeypatch) -> None:
refreshed: list[object] = []
cache_cleared = False
def fake_refresh_bm25_corpus(session, config):
refreshed.append((session, config))
def fake_cache_clear():
nonlocal cache_cleared
cache_cleared = True
monkeypatch.setattr("python.ebook_search.api.bm25_tasks.refresh_bm25_corpus", fake_refresh_bm25_corpus)
monkeypatch.setattr("python.ebook_search.api.bm25_tasks.load_bm25_corpus.cache_clear", fake_cache_clear)
engine = create_engine("sqlite+pysqlite:///:memory:", future=True)
config = EbookSearchConfig(rerank=RerankConfig(enabled=False))
refresh_bm25_for_engine(engine, config)
assert len(refreshed) == 1
assert refreshed[0][1] == config
assert cache_cleared is True
def test_admin_page_shows_embedding_counts_by_model(monkeypatch) -> None:
def fake_embedding_model_stats(_session):
return [
+3 -6
View File
@@ -21,7 +21,7 @@ def test_validate_system(mocker: MockerFixture, fs: FakeFilesystem) -> None:
"""test_validate_system."""
fs.create_file(
"/mock_snapshot_config.toml",
contents='zpools = ["root_pool", "storage", "media"]\nservices = ["docker"]\n',
contents='zpool = ["root_pool", "storage", "media"]\nservices = ["docker"]\n',
)
mocker.patch(f"{VALIDATE_SYSTEM}.systemd_tests", return_value=None)
@@ -33,10 +33,9 @@ def test_validate_system_errors(mocker: MockerFixture, fs: FakeFilesystem) -> No
"""test_validate_system_errors."""
fs.create_file(
"/mock_snapshot_config.toml",
contents='zpools = ["root_pool", "storage", "media"]\nservices = ["docker"]\n',
contents='zpool = ["root_pool", "storage", "media"]\nservices = ["docker"]\n',
)
mocker.patch(f"{VALIDATE_SYSTEM}.signal_alert")
mocker.patch(f"{VALIDATE_SYSTEM}.systemd_tests", return_value=["systemd_tests error"])
mocker.patch(f"{VALIDATE_SYSTEM}.zpool_tests", return_value=["zpool_tests error"])
@@ -50,11 +49,9 @@ def test_validate_system_execution(mocker: MockerFixture, fs: FakeFilesystem) ->
"""test_validate_system_execution."""
fs.create_file(
"/mock_snapshot_config.toml",
contents='zpools = ["root_pool", "storage", "media"]\nservices = ["docker"]\n',
contents='zpool = ["root_pool", "storage", "media"]\nservices = ["docker"]\n',
)
mocker.patch(f"{VALIDATE_SYSTEM}.signal_alert")
mocker.patch(f"{VALIDATE_SYSTEM}.systemd_tests", return_value=None)
mocker.patch(f"{VALIDATE_SYSTEM}.zpool_tests", side_effect=RuntimeError("zpool_tests error"))
with pytest.raises(SystemExit) as exception_info:
+321
View File
@@ -0,0 +1,321 @@
"""Tests for the Signal command and control bot."""
from __future__ import annotations
import json
from datetime import timedelta
from unittest.mock import MagicMock, patch
import pytest
from python.signal_bot.commands.inventory import (
_format_summary,
parse_llm_response,
)
from python.signal_bot.commands.location import _format_location, handle_location_request
from python.signal_bot.device_registry import _BLOCKED_TTL, _DEFAULT_TTL, DeviceRegistry, _CacheEntry
from python.signal_bot.llm_client import LLMClient
from python.signal_bot.main import Bot
from python.signal_bot.models import (
BotConfig,
InventoryItem,
SignalMessage,
TrustLevel,
)
from python.signal_bot.signal_client import SignalClient
class TestModels:
def test_trust_level_values(self):
assert TrustLevel.VERIFIED == "verified"
assert TrustLevel.UNVERIFIED == "unverified"
assert TrustLevel.BLOCKED == "blocked"
def test_signal_message_defaults(self):
msg = SignalMessage(source="+1234", timestamp=0)
assert msg.message == ""
assert msg.attachments == []
assert msg.group_id is None
def test_inventory_item_defaults(self):
item = InventoryItem(name="wrench")
assert item.quantity == 1
assert item.unit == "each"
assert item.category == ""
class TestInventoryParsing:
def test_parse_llm_response_basic(self):
raw = '[{"name": "water", "quantity": 6, "unit": "gallon", "category": "supplies", "notes": ""}]'
items = parse_llm_response(raw)
assert len(items) == 1
assert items[0].name == "water"
assert items[0].quantity == 6
assert items[0].unit == "gallon"
def test_parse_llm_response_with_code_fence(self):
raw = '```json\n[{"name": "tape", "quantity": 1, "unit": "each", "category": "tools", "notes": ""}]\n```'
items = parse_llm_response(raw)
assert len(items) == 1
assert items[0].name == "tape"
def test_parse_llm_response_invalid_json(self):
with pytest.raises(json.JSONDecodeError):
parse_llm_response("not json at all")
def test_format_summary(self):
items = [InventoryItem(name="water", quantity=6, unit="gallon", category="supplies")]
summary = _format_summary(items)
assert "water" in summary
assert "x6" in summary
assert "gallon" in summary
class TestDeviceRegistry:
@pytest.fixture
def signal_mock(self):
return MagicMock(spec=SignalClient)
@pytest.fixture
def registry(self, signal_mock, engine):
return DeviceRegistry(signal_mock, engine)
def test_new_device_is_unverified(self, registry):
registry.record_contact("+1234", "abc123")
assert not registry.is_verified("+1234")
def test_verify_device(self, registry):
registry.record_contact("+1234", "abc123")
assert registry.verify("+1234")
assert registry.is_verified("+1234")
def test_verify_unknown_device(self, registry):
assert not registry.verify("+9999")
def test_block_device(self, registry):
registry.record_contact("+1234", "abc123")
assert registry.block("+1234")
assert not registry.is_verified("+1234")
def test_safety_number_change_resets_trust(self, registry):
registry.record_contact("+1234", "abc123")
registry.verify("+1234")
assert registry.is_verified("+1234")
registry.record_contact("+1234", "different_safety_number")
assert not registry.is_verified("+1234")
def test_persistence(self, signal_mock, engine):
reg1 = DeviceRegistry(signal_mock, engine)
reg1.record_contact("+1234", "abc123")
reg1.verify("+1234")
reg2 = DeviceRegistry(signal_mock, engine)
assert reg2.is_verified("+1234")
def test_list_devices(self, registry):
registry.record_contact("+1234", "abc")
registry.record_contact("+5678", "def")
assert len(registry.list_devices()) == 2
class TestContactCache:
@pytest.fixture
def signal_mock(self):
return MagicMock(spec=SignalClient)
@pytest.fixture
def registry(self, signal_mock, engine):
return DeviceRegistry(signal_mock, engine)
def test_second_call_uses_cache(self, registry):
registry.record_contact("+1234", "abc")
assert "+1234" in registry._contact_cache
with patch.object(registry, "engine") as mock_engine:
registry.record_contact("+1234", "abc")
mock_engine.assert_not_called()
def test_unverified_gets_default_ttl(self, registry):
registry.record_contact("+1234", "abc")
from python.common import utcnow
entry = registry._contact_cache["+1234"]
expected = utcnow() + _DEFAULT_TTL
assert abs((entry.expires - expected).total_seconds()) < 2
assert entry.trust_level == TrustLevel.UNVERIFIED
assert entry.has_safety_number is True
def test_blocked_gets_blocked_ttl(self, registry):
registry.record_contact("+1234", "abc")
registry.block("+1234")
from python.common import utcnow
entry = registry._contact_cache["+1234"]
expected = utcnow() + _BLOCKED_TTL
assert abs((entry.expires - expected).total_seconds()) < 2
assert entry.trust_level == TrustLevel.BLOCKED
def test_verify_updates_cache(self, registry):
registry.record_contact("+1234", "abc")
registry.verify("+1234")
entry = registry._contact_cache["+1234"]
assert entry.trust_level == TrustLevel.VERIFIED
def test_block_updates_cache(self, registry):
registry.record_contact("+1234", "abc")
registry.block("+1234")
entry = registry._contact_cache["+1234"]
assert entry.trust_level == TrustLevel.BLOCKED
def test_unverify_updates_cache(self, registry):
registry.record_contact("+1234", "abc")
registry.verify("+1234")
registry.unverify("+1234")
entry = registry._contact_cache["+1234"]
assert entry.trust_level == TrustLevel.UNVERIFIED
def test_is_verified_uses_cache(self, registry):
registry.record_contact("+1234", "abc")
registry.verify("+1234")
with patch.object(registry, "engine") as mock_engine:
assert registry.is_verified("+1234") is True
mock_engine.assert_not_called()
def test_has_safety_number_uses_cache(self, registry):
registry.record_contact("+1234", "abc")
with patch.object(registry, "engine") as mock_engine:
assert registry.has_safety_number("+1234") is True
mock_engine.assert_not_called()
def test_no_safety_number_cached(self, registry):
registry.record_contact("+1234", None)
with patch.object(registry, "engine") as mock_engine:
assert registry.has_safety_number("+1234") is False
mock_engine.assert_not_called()
def test_expired_cache_hits_db(self, registry):
registry.record_contact("+1234", "abc")
old = registry._contact_cache["+1234"]
registry._contact_cache["+1234"] = _CacheEntry(
expires=old.expires - timedelta(minutes=10),
trust_level=old.trust_level,
has_safety_number=old.has_safety_number,
safety_number=old.safety_number,
roles=old.roles,
)
with patch("python.signal_bot.device_registry.Session") as mock_session_cls:
mock_session = MagicMock()
mock_session_cls.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_session_cls.return_value.__exit__ = MagicMock(return_value=False)
mock_device = MagicMock()
mock_device.trust_level = TrustLevel.UNVERIFIED
mock_session.scalars.return_value.one_or_none.return_value = mock_device
registry.record_contact("+1234", "abc")
mock_session.scalars.assert_called_once()
class TestLocationCommand:
def test_format_location(self):
response = _format_location("12.34", "56.78")
assert "12.34, 56.78" in response
assert "maps.google.com" in response
def test_handle_location_request_without_config(self):
signal = MagicMock(spec=SignalClient)
message = SignalMessage(source="+1234", timestamp=0, message="location")
handle_location_request(message, signal, None, None)
signal.reply.assert_called_once()
assert "not configured" in signal.reply.call_args[0][1]
class TestDispatch:
@pytest.fixture
def signal_mock(self):
return MagicMock(spec=SignalClient)
@pytest.fixture
def llm_mock(self):
return MagicMock(spec=LLMClient)
@pytest.fixture
def registry_mock(self):
mock = MagicMock(spec=DeviceRegistry)
mock.is_verified.return_value = True
mock.has_safety_number.return_value = True
mock.has_role.return_value = False
mock.get_roles.return_value = []
return mock
@pytest.fixture
def config(self, engine):
return BotConfig(
signal_api_url="http://localhost:8080",
phone_number="+1234567890",
inventory_api_url="http://localhost:9090",
engine=engine,
)
@pytest.fixture
def bot(self, signal_mock, llm_mock, registry_mock, config):
return Bot(signal_mock, llm_mock, registry_mock, config)
def test_unverified_device_ignored(self, bot, signal_mock, registry_mock):
registry_mock.is_verified.return_value = False
msg = SignalMessage(source="+1234", timestamp=0, message="help")
bot.dispatch(msg)
signal_mock.reply.assert_not_called()
def test_admin_without_safety_number_ignored(self, bot, signal_mock, registry_mock):
registry_mock.has_safety_number.return_value = False
registry_mock.has_role.return_value = True
msg = SignalMessage(source="+1234", timestamp=0, message="help")
bot.dispatch(msg)
signal_mock.reply.assert_not_called()
def test_non_admin_without_safety_number_allowed(self, bot, signal_mock, registry_mock):
registry_mock.has_safety_number.return_value = False
registry_mock.has_role.return_value = False
registry_mock.get_roles.return_value = []
msg = SignalMessage(source="+1234", timestamp=0, message="help")
bot.dispatch(msg)
signal_mock.reply.assert_called_once()
def test_help_command(self, bot, signal_mock, registry_mock):
msg = SignalMessage(source="+1234", timestamp=0, message="help")
bot.dispatch(msg)
signal_mock.reply.assert_called_once()
assert "Available commands" in signal_mock.reply.call_args[0][1]
def test_unknown_command_ignored(self, bot, signal_mock):
msg = SignalMessage(source="+1234", timestamp=0, message="foobar")
bot.dispatch(msg)
signal_mock.reply.assert_not_called()
def test_non_command_message_ignored(self, bot, signal_mock):
msg = SignalMessage(source="+1234", timestamp=0, message="hello there")
bot.dispatch(msg)
signal_mock.reply.assert_not_called()
def test_status_command(self, bot, signal_mock, llm_mock, registry_mock):
llm_mock.list_models.return_value = ["model1", "model2"]
llm_mock.model = "test:7b"
registry_mock.list_devices.return_value = []
registry_mock.has_role.return_value = True
msg = SignalMessage(source="+1234", timestamp=0, message="status")
bot.dispatch(msg)
signal_mock.reply.assert_called_once()
assert "Bot online" in signal_mock.reply.call_args[0][1]
def test_location_command(self, bot, signal_mock, registry_mock, config):
registry_mock.has_role.return_value = True
msg = SignalMessage(source="+1234", timestamp=0, message="location")
with patch("python.signal_bot.main.handle_location_request") as mock_location:
bot.dispatch(msg)
mock_location.assert_called_once_with(
msg,
signal_mock,
config.ha_url,
config.ha_token,
)