mirror of
https://github.com/RichieCahill/dotfiles.git
synced 2026-04-21 14:49:10 -04:00
Compare commits
1 Commits
claude/she
...
feature/ad
| Author | SHA1 | Date | |
|---|---|---|---|
| 65c4f1d23e |
3
.vscode/settings.json
vendored
3
.vscode/settings.json
vendored
@@ -232,7 +232,6 @@
|
|||||||
"pyopenweathermap",
|
"pyopenweathermap",
|
||||||
"pyownet",
|
"pyownet",
|
||||||
"pytest",
|
"pytest",
|
||||||
"qalculate",
|
|
||||||
"quicksuggest",
|
"quicksuggest",
|
||||||
"radarr",
|
"radarr",
|
||||||
"readahead",
|
"readahead",
|
||||||
@@ -257,7 +256,6 @@
|
|||||||
"sessionmaker",
|
"sessionmaker",
|
||||||
"sessionstore",
|
"sessionstore",
|
||||||
"shellcheck",
|
"shellcheck",
|
||||||
"signalbot",
|
|
||||||
"signon",
|
"signon",
|
||||||
"Signons",
|
"Signons",
|
||||||
"skia",
|
"skia",
|
||||||
@@ -307,7 +305,6 @@
|
|||||||
"useragent",
|
"useragent",
|
||||||
"usernamehw",
|
"usernamehw",
|
||||||
"userprefs",
|
"userprefs",
|
||||||
"vaninventory",
|
|
||||||
"vfat",
|
"vfat",
|
||||||
"victron",
|
"victron",
|
||||||
"virt",
|
"virt",
|
||||||
|
|||||||
@@ -24,6 +24,7 @@
|
|||||||
fastapi
|
fastapi
|
||||||
fastapi-cli
|
fastapi-cli
|
||||||
httpx
|
httpx
|
||||||
|
python-multipart
|
||||||
mypy
|
mypy
|
||||||
polars
|
polars
|
||||||
psycopg
|
psycopg
|
||||||
@@ -33,18 +34,15 @@
|
|||||||
pytest-cov
|
pytest-cov
|
||||||
pytest-mock
|
pytest-mock
|
||||||
pytest-xdist
|
pytest-xdist
|
||||||
python-multipart
|
|
||||||
requests
|
requests
|
||||||
ruff
|
ruff
|
||||||
scalene
|
scalene
|
||||||
sqlalchemy
|
sqlalchemy
|
||||||
sqlalchemy
|
sqlalchemy
|
||||||
tenacity
|
|
||||||
textual
|
textual
|
||||||
tinytuya
|
tinytuya
|
||||||
typer
|
typer
|
||||||
types-requests
|
types-requests
|
||||||
websockets
|
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -21,13 +21,11 @@ dependencies = [
|
|||||||
"requests",
|
"requests",
|
||||||
"sqlalchemy",
|
"sqlalchemy",
|
||||||
"typer",
|
"typer",
|
||||||
"websockets",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
database = "python.database_cli:app"
|
database = "python.database_cli:app"
|
||||||
van-inventory = "python.van_inventory.main:serve"
|
van-inventory = "python.van_inventory.main:serve"
|
||||||
sheet-music-ocr = "python.sheet_music_ocr.main:app"
|
|
||||||
|
|
||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
dev = [
|
dev = [
|
||||||
@@ -58,10 +56,7 @@ lint.ignore = [
|
|||||||
[tool.ruff.lint.per-file-ignores]
|
[tool.ruff.lint.per-file-ignores]
|
||||||
|
|
||||||
"tests/**" = [
|
"tests/**" = [
|
||||||
"ANN", # (perm) type annotations not needed in tests
|
"S101", # (perm) pytest needs asserts
|
||||||
"D", # (perm) docstrings not needed in tests
|
|
||||||
"PLR2004", # (perm) magic values are fine in test assertions
|
|
||||||
"S101", # (perm) pytest needs asserts
|
|
||||||
]
|
]
|
||||||
"python/stuff/**" = [
|
"python/stuff/**" = [
|
||||||
"T201", # (perm) I don't care about print statements dir
|
"T201", # (perm) I don't care about print statements dir
|
||||||
@@ -87,9 +82,6 @@ lint.ignore = [
|
|||||||
"python/alembic/**" = [
|
"python/alembic/**" = [
|
||||||
"INP001", # (perm) this creates LSP issues for alembic
|
"INP001", # (perm) this creates LSP issues for alembic
|
||||||
]
|
]
|
||||||
"python/signal_bot/**" = [
|
|
||||||
"D107", # (perm) class docstrings cover __init__
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.ruff.lint.pydocstyle]
|
[tool.ruff.lint.pydocstyle]
|
||||||
convention = "google"
|
convention = "google"
|
||||||
|
|||||||
@@ -45,18 +45,6 @@ def dynamic_schema(filename: str, _options: dict[Any, Any]) -> None:
|
|||||||
Path(filename).write_text(dynamic_schema_file)
|
Path(filename).write_text(dynamic_schema_file)
|
||||||
|
|
||||||
|
|
||||||
@write_hooks.register("import_postgresql")
|
|
||||||
def import_postgresql(filename: str, _options: dict[Any, Any]) -> None:
|
|
||||||
"""Add postgresql dialect import when postgresql types are used."""
|
|
||||||
content = Path(filename).read_text()
|
|
||||||
if "postgresql." in content and "from sqlalchemy.dialects import postgresql" not in content:
|
|
||||||
content = content.replace(
|
|
||||||
"import sqlalchemy as sa\n",
|
|
||||||
"import sqlalchemy as sa\nfrom sqlalchemy.dialects import postgresql\n",
|
|
||||||
)
|
|
||||||
Path(filename).write_text(content)
|
|
||||||
|
|
||||||
|
|
||||||
@write_hooks.register("ruff")
|
@write_hooks.register("ruff")
|
||||||
def ruff_check_and_format(filename: str, _options: dict[Any, Any]) -> None:
|
def ruff_check_and_format(filename: str, _options: dict[Any, Any]) -> None:
|
||||||
"""Docstring for ruff_check_and_format."""
|
"""Docstring for ruff_check_and_format."""
|
||||||
|
|||||||
@@ -1,58 +0,0 @@
|
|||||||
"""adding SignalDevice for DeviceRegistry for signal bot.
|
|
||||||
|
|
||||||
Revision ID: 4c410c16e39c
|
|
||||||
Revises: 3f71565e38de
|
|
||||||
Create Date: 2026-03-09 14:51:24.228976
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
from alembic import op
|
|
||||||
from sqlalchemy.dialects import postgresql
|
|
||||||
|
|
||||||
from python.orm import RichieBase
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from collections.abc import Sequence
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = "4c410c16e39c"
|
|
||||||
down_revision: str | None = "3f71565e38de"
|
|
||||||
branch_labels: str | Sequence[str] | None = None
|
|
||||||
depends_on: str | Sequence[str] | None = None
|
|
||||||
|
|
||||||
schema = RichieBase.schema_name
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
"""Upgrade."""
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
op.create_table(
|
|
||||||
"signal_device",
|
|
||||||
sa.Column("phone_number", sa.String(length=50), nullable=False),
|
|
||||||
sa.Column("safety_number", sa.String(), nullable=False),
|
|
||||||
sa.Column(
|
|
||||||
"trust_level",
|
|
||||||
postgresql.ENUM("VERIFIED", "UNVERIFIED", "BLOCKED", name="trust_level", schema=schema),
|
|
||||||
nullable=False,
|
|
||||||
),
|
|
||||||
sa.Column("last_seen", sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.Column("id", sa.Integer(), nullable=False),
|
|
||||||
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
|
||||||
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
|
||||||
sa.PrimaryKeyConstraint("id", name=op.f("pk_signal_device")),
|
|
||||||
sa.UniqueConstraint("phone_number", name=op.f("uq_signal_device_phone_number")),
|
|
||||||
schema=schema,
|
|
||||||
)
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
"""Downgrade."""
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
op.drop_table("signal_device", schema=schema)
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
@@ -1,41 +0,0 @@
|
|||||||
"""fixed safety number logic.
|
|
||||||
|
|
||||||
Revision ID: 99fec682516c
|
|
||||||
Revises: 4c410c16e39c
|
|
||||||
Create Date: 2026-03-09 16:25:25.085806
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
from alembic import op
|
|
||||||
|
|
||||||
from python.orm import RichieBase
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from collections.abc import Sequence
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = "99fec682516c"
|
|
||||||
down_revision: str | None = "4c410c16e39c"
|
|
||||||
branch_labels: str | Sequence[str] | None = None
|
|
||||||
depends_on: str | Sequence[str] | None = None
|
|
||||||
|
|
||||||
schema = RichieBase.schema_name
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
"""Upgrade."""
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
op.alter_column("signal_device", "safety_number", existing_type=sa.VARCHAR(), nullable=True, schema=schema)
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
"""Downgrade."""
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
op.alter_column("signal_device", "safety_number", existing_type=sa.VARCHAR(), nullable=False, schema=schema)
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
@@ -1,54 +0,0 @@
|
|||||||
"""add dead_letter_message table.
|
|
||||||
|
|
||||||
Revision ID: a1b2c3d4e5f6
|
|
||||||
Revises: 99fec682516c
|
|
||||||
Create Date: 2026-03-10 12:00:00.000000
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
from alembic import op
|
|
||||||
from sqlalchemy.dialects import postgresql
|
|
||||||
|
|
||||||
from python.orm import RichieBase
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from collections.abc import Sequence
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = "a1b2c3d4e5f6"
|
|
||||||
down_revision: str | None = "99fec682516c"
|
|
||||||
branch_labels: str | Sequence[str] | None = None
|
|
||||||
depends_on: str | Sequence[str] | None = None
|
|
||||||
|
|
||||||
schema = RichieBase.schema_name
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
"""Upgrade."""
|
|
||||||
op.create_table(
|
|
||||||
"dead_letter_message",
|
|
||||||
sa.Column("source", sa.String(), nullable=False),
|
|
||||||
sa.Column("message", sa.Text(), nullable=False),
|
|
||||||
sa.Column("received_at", sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.Column(
|
|
||||||
"status",
|
|
||||||
postgresql.ENUM("UNPROCESSED", "PROCESSED", name="message_status", schema=schema),
|
|
||||||
nullable=False,
|
|
||||||
),
|
|
||||||
sa.Column("id", sa.Integer(), nullable=False),
|
|
||||||
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
|
||||||
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
|
||||||
sa.PrimaryKeyConstraint("id", name=op.f("pk_dead_letter_message")),
|
|
||||||
schema=schema,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
"""Downgrade."""
|
|
||||||
op.drop_table("dead_letter_message", schema=schema)
|
|
||||||
op.execute(sa.text(f"DROP TYPE IF EXISTS {schema}.message_status"))
|
|
||||||
@@ -58,9 +58,8 @@ class DatabaseConfig:
|
|||||||
cfg.set_main_option("version_path_separator", "os")
|
cfg.set_main_option("version_path_separator", "os")
|
||||||
cfg.set_main_option("version_locations", self.version_location)
|
cfg.set_main_option("version_locations", self.version_location)
|
||||||
cfg.set_main_option("revision_environment", "true")
|
cfg.set_main_option("revision_environment", "true")
|
||||||
cfg.set_section_option("post_write_hooks", "hooks", "dynamic_schema,import_postgresql,ruff")
|
cfg.set_section_option("post_write_hooks", "hooks", "dynamic_schema,ruff")
|
||||||
cfg.set_section_option("post_write_hooks", "dynamic_schema.type", "dynamic_schema")
|
cfg.set_section_option("post_write_hooks", "dynamic_schema.type", "dynamic_schema")
|
||||||
cfg.set_section_option("post_write_hooks", "import_postgresql.type", "import_postgresql")
|
|
||||||
cfg.set_section_option("post_write_hooks", "ruff.type", "ruff")
|
cfg.set_section_option("post_write_hooks", "ruff.type", "ruff")
|
||||||
cfg.attributes["base"] = self.get_base()
|
cfg.attributes["base"] = self.get_base()
|
||||||
cfg.attributes["env_prefix"] = self.env_prefix
|
cfg.attributes["env_prefix"] = self.env_prefix
|
||||||
@@ -74,7 +73,7 @@ DATABASES: dict[str, DatabaseConfig] = {
|
|||||||
version_location="python/alembic/richie/versions",
|
version_location="python/alembic/richie/versions",
|
||||||
base_module="python.orm.richie.base",
|
base_module="python.orm.richie.base",
|
||||||
base_class_name="RichieBase",
|
base_class_name="RichieBase",
|
||||||
models_module="python.orm.richie",
|
models_module="python.orm.richie.contact",
|
||||||
),
|
),
|
||||||
"van_inventory": DatabaseConfig(
|
"van_inventory": DatabaseConfig(
|
||||||
env_prefix="VAN_INVENTORY",
|
env_prefix="VAN_INVENTORY",
|
||||||
|
|||||||
@@ -11,20 +11,16 @@ from python.orm.richie.contact import (
|
|||||||
Need,
|
Need,
|
||||||
RelationshipType,
|
RelationshipType,
|
||||||
)
|
)
|
||||||
from python.orm.richie.dead_letter_message import DeadLetterMessage
|
|
||||||
from python.orm.richie.signal_device import SignalDevice
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Bill",
|
"Bill",
|
||||||
"Contact",
|
"Contact",
|
||||||
"ContactNeed",
|
"ContactNeed",
|
||||||
"ContactRelationship",
|
"ContactRelationship",
|
||||||
"DeadLetterMessage",
|
|
||||||
"Legislator",
|
"Legislator",
|
||||||
"Need",
|
"Need",
|
||||||
"RelationshipType",
|
"RelationshipType",
|
||||||
"RichieBase",
|
"RichieBase",
|
||||||
"SignalDevice",
|
|
||||||
"TableBase",
|
"TableBase",
|
||||||
"Vote",
|
"Vote",
|
||||||
"VoteRecord",
|
"VoteRecord",
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from enum import StrEnum
|
from enum import Enum
|
||||||
|
|
||||||
from sqlalchemy import ForeignKey, String
|
from sqlalchemy import ForeignKey, String
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
@@ -10,7 +10,7 @@ from sqlalchemy.orm import Mapped, mapped_column, relationship
|
|||||||
from python.orm.richie.base import RichieBase, TableBase
|
from python.orm.richie.base import RichieBase, TableBase
|
||||||
|
|
||||||
|
|
||||||
class RelationshipType(StrEnum):
|
class RelationshipType(str, Enum):
|
||||||
"""Relationship types with default closeness weights.
|
"""Relationship types with default closeness weights.
|
||||||
|
|
||||||
Default weight is an integer 1-10 where 10 = closest relationship.
|
Default weight is an integer 1-10 where 10 = closest relationship.
|
||||||
|
|||||||
@@ -1,26 +0,0 @@
|
|||||||
"""Dead letter queue for Signal bot messages that fail processing."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from sqlalchemy import DateTime, Text
|
|
||||||
from sqlalchemy.dialects.postgresql import ENUM
|
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
|
||||||
|
|
||||||
from python.orm.richie.base import TableBase
|
|
||||||
from python.signal_bot.models import MessageStatus
|
|
||||||
|
|
||||||
|
|
||||||
class DeadLetterMessage(TableBase):
|
|
||||||
"""A Signal message that failed processing and was sent to the dead letter queue."""
|
|
||||||
|
|
||||||
__tablename__ = "dead_letter_message"
|
|
||||||
|
|
||||||
source: Mapped[str]
|
|
||||||
message: Mapped[str] = mapped_column(Text)
|
|
||||||
received_at: Mapped[datetime] = mapped_column(DateTime(timezone=True))
|
|
||||||
status: Mapped[MessageStatus] = mapped_column(
|
|
||||||
ENUM(MessageStatus, name="message_status", create_type=True, schema="main"),
|
|
||||||
default=MessageStatus.UNPROCESSED,
|
|
||||||
)
|
|
||||||
@@ -1,26 +0,0 @@
|
|||||||
"""Signal bot device registry models."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from sqlalchemy import DateTime, String
|
|
||||||
from sqlalchemy.dialects.postgresql import ENUM
|
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
|
||||||
|
|
||||||
from python.orm.richie.base import TableBase
|
|
||||||
from python.signal_bot.models import TrustLevel
|
|
||||||
|
|
||||||
|
|
||||||
class SignalDevice(TableBase):
|
|
||||||
"""A Signal device tracked by phone number and safety number."""
|
|
||||||
|
|
||||||
__tablename__ = "signal_device"
|
|
||||||
|
|
||||||
phone_number: Mapped[str] = mapped_column(String(50), unique=True)
|
|
||||||
safety_number: Mapped[str | None]
|
|
||||||
trust_level: Mapped[TrustLevel] = mapped_column(
|
|
||||||
ENUM(TrustLevel, name="trust_level", create_type=True, schema="main"),
|
|
||||||
default=TrustLevel.UNVERIFIED,
|
|
||||||
)
|
|
||||||
last_seen: Mapped[datetime] = mapped_column(DateTime(timezone=True))
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
"""Sheet music OCR tool using Audiveris."""
|
|
||||||
@@ -1,62 +0,0 @@
|
|||||||
"""Audiveris subprocess wrapper for optical music recognition."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import shutil
|
|
||||||
import subprocess
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
|
|
||||||
class AudiverisError(Exception):
|
|
||||||
"""Raised when Audiveris processing fails."""
|
|
||||||
|
|
||||||
|
|
||||||
def find_audiveris() -> str:
|
|
||||||
"""Find the Audiveris executable on PATH.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Path to the audiveris executable.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
AudiverisError: If Audiveris is not found.
|
|
||||||
"""
|
|
||||||
path = shutil.which("audiveris")
|
|
||||||
if not path:
|
|
||||||
msg = "Audiveris not found on PATH. Install it via 'nix develop' or add it to your environment."
|
|
||||||
raise AudiverisError(msg)
|
|
||||||
return path
|
|
||||||
|
|
||||||
|
|
||||||
def run_audiveris(input_path: Path, output_dir: Path) -> Path:
|
|
||||||
"""Run Audiveris on an input file and return the path to the generated .mxl.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_path: Path to the input sheet music file (PDF, PNG, JPG, TIFF).
|
|
||||||
output_dir: Directory where Audiveris will write its output.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Path to the generated .mxl file.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
AudiverisError: If Audiveris fails or produces no output.
|
|
||||||
"""
|
|
||||||
audiveris = find_audiveris()
|
|
||||||
result = subprocess.run(
|
|
||||||
[audiveris, "-batch", "-export", "-output", str(output_dir), str(input_path)],
|
|
||||||
capture_output=True,
|
|
||||||
text=True,
|
|
||||||
check=False,
|
|
||||||
)
|
|
||||||
if result.returncode != 0:
|
|
||||||
msg = f"Audiveris failed (exit {result.returncode}):\n{result.stderr}"
|
|
||||||
raise AudiverisError(msg)
|
|
||||||
|
|
||||||
mxl_files = list(output_dir.rglob("*.mxl"))
|
|
||||||
if not mxl_files:
|
|
||||||
msg = f"Audiveris produced no .mxl output in {output_dir}"
|
|
||||||
raise AudiverisError(msg)
|
|
||||||
|
|
||||||
return mxl_files[0]
|
|
||||||
@@ -1,123 +0,0 @@
|
|||||||
"""CLI tool for converting scanned sheet music to MusicXML.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
sheet-music-ocr convert scan.pdf
|
|
||||||
sheet-music-ocr convert scan.png -o output.mxml
|
|
||||||
sheet-music-ocr review output.mxml --provider claude
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import tempfile
|
|
||||||
import zipfile
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Annotated
|
|
||||||
|
|
||||||
import typer
|
|
||||||
|
|
||||||
from python.sheet_music_ocr.audiveris import AudiverisError, run_audiveris
|
|
||||||
from python.sheet_music_ocr.review import LLMProvider, ReviewError, review_mxml
|
|
||||||
|
|
||||||
SUPPORTED_EXTENSIONS = {".pdf", ".png", ".jpg", ".jpeg", ".tiff", ".tif"}
|
|
||||||
|
|
||||||
app = typer.Typer(help="Convert scanned sheet music to MusicXML using Audiveris.")
|
|
||||||
|
|
||||||
|
|
||||||
def extract_mxml_from_mxl(mxl_path: Path, output_path: Path) -> Path:
|
|
||||||
"""Extract the MusicXML file from an .mxl archive.
|
|
||||||
|
|
||||||
An .mxl file is a ZIP archive containing one or more .xml MusicXML files.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
mxl_path: Path to the .mxl file.
|
|
||||||
output_path: Path where the extracted .mxml file should be written.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The output path.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
FileNotFoundError: If no MusicXML file is found inside the archive.
|
|
||||||
"""
|
|
||||||
with zipfile.ZipFile(mxl_path, "r") as zf:
|
|
||||||
xml_names = [n for n in zf.namelist() if n.endswith(".xml") and not n.startswith("META-INF")]
|
|
||||||
if not xml_names:
|
|
||||||
msg = f"No MusicXML (.xml) file found inside {mxl_path}"
|
|
||||||
raise FileNotFoundError(msg)
|
|
||||||
with zf.open(xml_names[0]) as src, output_path.open("wb") as dst:
|
|
||||||
dst.write(src.read())
|
|
||||||
return output_path
|
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
|
||||||
def convert(
|
|
||||||
input_file: Annotated[Path, typer.Argument(help="Path to sheet music scan (PDF, PNG, JPG, TIFF).")],
|
|
||||||
output: Annotated[
|
|
||||||
Path | None,
|
|
||||||
typer.Option("--output", "-o", help="Output .mxml file path. Defaults to <input_stem>.mxml."),
|
|
||||||
] = None,
|
|
||||||
) -> None:
|
|
||||||
"""Convert a scanned sheet music file to MusicXML."""
|
|
||||||
if not input_file.exists():
|
|
||||||
typer.echo(f"Error: {input_file} does not exist.", err=True)
|
|
||||||
raise typer.Exit(code=1)
|
|
||||||
|
|
||||||
if input_file.suffix.lower() not in SUPPORTED_EXTENSIONS:
|
|
||||||
typer.echo(
|
|
||||||
f"Error: Unsupported format '{input_file.suffix}'. Supported: {', '.join(sorted(SUPPORTED_EXTENSIONS))}",
|
|
||||||
err=True,
|
|
||||||
)
|
|
||||||
raise typer.Exit(code=1)
|
|
||||||
|
|
||||||
output_path = output or input_file.with_suffix(".mxml")
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
|
||||||
try:
|
|
||||||
mxl_path = run_audiveris(input_file, Path(tmpdir))
|
|
||||||
except AudiverisError as e:
|
|
||||||
typer.echo(f"Error: {e}", err=True)
|
|
||||||
raise typer.Exit(code=1) from e
|
|
||||||
|
|
||||||
try:
|
|
||||||
extract_mxml_from_mxl(mxl_path, output_path)
|
|
||||||
except FileNotFoundError as e:
|
|
||||||
typer.echo(f"Error: {e}", err=True)
|
|
||||||
raise typer.Exit(code=1) from e
|
|
||||||
|
|
||||||
typer.echo(f"Written: {output_path}")
|
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
|
||||||
def review(
|
|
||||||
input_file: Annotated[Path, typer.Argument(help="Path to MusicXML (.mxml) file to review.")],
|
|
||||||
output: Annotated[
|
|
||||||
Path | None,
|
|
||||||
typer.Option("--output", "-o", help="Output path for corrected .mxml. Defaults to overwriting input."),
|
|
||||||
] = None,
|
|
||||||
provider: Annotated[
|
|
||||||
LLMProvider,
|
|
||||||
typer.Option("--provider", "-p", help="LLM provider to use."),
|
|
||||||
] = LLMProvider.CLAUDE,
|
|
||||||
) -> None:
|
|
||||||
"""Review and fix a MusicXML file using an LLM."""
|
|
||||||
if not input_file.exists():
|
|
||||||
typer.echo(f"Error: {input_file} does not exist.", err=True)
|
|
||||||
raise typer.Exit(code=1)
|
|
||||||
|
|
||||||
if input_file.suffix.lower() != ".mxml":
|
|
||||||
typer.echo("Error: Input file must be a .mxml file.", err=True)
|
|
||||||
raise typer.Exit(code=1)
|
|
||||||
|
|
||||||
output_path = output or input_file
|
|
||||||
|
|
||||||
try:
|
|
||||||
corrected = review_mxml(input_file, provider)
|
|
||||||
except ReviewError as e:
|
|
||||||
typer.echo(f"Error: {e}", err=True)
|
|
||||||
raise typer.Exit(code=1) from e
|
|
||||||
|
|
||||||
output_path.write_text(corrected, encoding="utf-8")
|
|
||||||
typer.echo(f"Reviewed: {output_path}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
app()
|
|
||||||
@@ -1,126 +0,0 @@
|
|||||||
"""LLM-based MusicXML review and correction.
|
|
||||||
|
|
||||||
Supports both Claude (Anthropic) and OpenAI APIs for reviewing
|
|
||||||
MusicXML output from Audiveris and suggesting/applying fixes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import enum
|
|
||||||
import os
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
REVIEW_PROMPT = """\
|
|
||||||
You are a music notation expert. Review the following MusicXML file produced by \
|
|
||||||
optical music recognition (Audiveris). Look for and fix common OCR errors including:
|
|
||||||
|
|
||||||
- Incorrect note pitches or durations
|
|
||||||
- Wrong or missing key signatures, time signatures, or clefs
|
|
||||||
- Incorrect rest durations or placements
|
|
||||||
- Missing or incorrect accidentals
|
|
||||||
- Wrong beam groupings or tuplets
|
|
||||||
- Garbled or misspelled lyrics and text annotations
|
|
||||||
- Missing or incorrect dynamic markings
|
|
||||||
- Incorrect measure numbers or barline types
|
|
||||||
- Voice/staff assignment errors
|
|
||||||
|
|
||||||
Return ONLY the corrected MusicXML. Do not include any explanation, commentary, or \
|
|
||||||
markdown formatting. Output the raw XML directly.
|
|
||||||
|
|
||||||
Here is the MusicXML to review:
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
_TIMEOUT = 300
|
|
||||||
|
|
||||||
|
|
||||||
class LLMProvider(enum.StrEnum):
|
|
||||||
"""Supported LLM providers."""
|
|
||||||
|
|
||||||
CLAUDE = "claude"
|
|
||||||
OPENAI = "openai"
|
|
||||||
|
|
||||||
|
|
||||||
class ReviewError(Exception):
|
|
||||||
"""Raised when LLM review fails."""
|
|
||||||
|
|
||||||
|
|
||||||
def _get_api_key(provider: LLMProvider) -> str:
|
|
||||||
env_var = "ANTHROPIC_API_KEY" if provider == LLMProvider.CLAUDE else "OPENAI_API_KEY"
|
|
||||||
key = os.environ.get(env_var)
|
|
||||||
if not key:
|
|
||||||
msg = f"{env_var} environment variable is not set."
|
|
||||||
raise ReviewError(msg)
|
|
||||||
return key
|
|
||||||
|
|
||||||
|
|
||||||
def _call_claude(content: str, api_key: str) -> str:
|
|
||||||
response = httpx.post(
|
|
||||||
"https://api.anthropic.com/v1/messages",
|
|
||||||
headers={
|
|
||||||
"x-api-key": api_key,
|
|
||||||
"anthropic-version": "2023-06-01",
|
|
||||||
"content-type": "application/json",
|
|
||||||
},
|
|
||||||
json={
|
|
||||||
"model": "claude-sonnet-4-20250514",
|
|
||||||
"max_tokens": 16384,
|
|
||||||
"messages": [{"role": "user", "content": REVIEW_PROMPT + content}],
|
|
||||||
},
|
|
||||||
timeout=_TIMEOUT,
|
|
||||||
)
|
|
||||||
if response.status_code != 200: # noqa: PLR2004
|
|
||||||
msg = f"Claude API error ({response.status_code}): {response.text}"
|
|
||||||
raise ReviewError(msg)
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
return data["content"][0]["text"]
|
|
||||||
|
|
||||||
|
|
||||||
def _call_openai(content: str, api_key: str) -> str:
|
|
||||||
response = httpx.post(
|
|
||||||
"https://api.openai.com/v1/chat/completions",
|
|
||||||
headers={
|
|
||||||
"Authorization": f"Bearer {api_key}",
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
},
|
|
||||||
json={
|
|
||||||
"model": "gpt-4o",
|
|
||||||
"messages": [{"role": "user", "content": REVIEW_PROMPT + content}],
|
|
||||||
"max_tokens": 16384,
|
|
||||||
},
|
|
||||||
timeout=_TIMEOUT,
|
|
||||||
)
|
|
||||||
if response.status_code != 200: # noqa: PLR2004
|
|
||||||
msg = f"OpenAI API error ({response.status_code}): {response.text}"
|
|
||||||
raise ReviewError(msg)
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
return data["choices"][0]["message"]["content"]
|
|
||||||
|
|
||||||
|
|
||||||
def review_mxml(mxml_path: Path, provider: LLMProvider) -> str:
|
|
||||||
"""Review a MusicXML file using an LLM and return corrected content.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
mxml_path: Path to the .mxml file to review.
|
|
||||||
provider: Which LLM provider to use.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The corrected MusicXML content as a string.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ReviewError: If the API call fails or the key is missing.
|
|
||||||
FileNotFoundError: If the input file does not exist.
|
|
||||||
"""
|
|
||||||
content = mxml_path.read_text(encoding="utf-8")
|
|
||||||
api_key = _get_api_key(provider)
|
|
||||||
|
|
||||||
if provider == LLMProvider.CLAUDE:
|
|
||||||
return _call_claude(content, api_key)
|
|
||||||
return _call_openai(content, api_key)
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
"""Signal command and control bot."""
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
"""Signal bot commands."""
|
|
||||||
@@ -1,137 +0,0 @@
|
|||||||
"""Van inventory command — parse receipts and item lists via LLM, push to API."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from typing import TYPE_CHECKING, Any
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
from python.signal_bot.models import InventoryItem, InventoryUpdate
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from python.signal_bot.llm_client import LLMClient
|
|
||||||
from python.signal_bot.models import SignalMessage
|
|
||||||
from python.signal_bot.signal_client import SignalClient
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
SYSTEM_PROMPT = """\
|
|
||||||
You are an inventory assistant. Extract items from the input and return ONLY
|
|
||||||
a JSON array. Each element must have these fields:
|
|
||||||
- "name": item name (string)
|
|
||||||
- "quantity": numeric count or amount (default 1)
|
|
||||||
- "unit": unit of measure (e.g. "each", "lb", "oz", "gallon", "bag", "box")
|
|
||||||
- "category": category like "food", "tools", "supplies", etc.
|
|
||||||
- "notes": any extra detail (empty string if none)
|
|
||||||
|
|
||||||
Example output:
|
|
||||||
[{"name": "water bottles", "quantity": 6, "unit": "gallon", "category": "supplies", "notes": "1 gallon each"}]
|
|
||||||
|
|
||||||
Return ONLY the JSON array, no other text.\
|
|
||||||
"""
|
|
||||||
|
|
||||||
IMAGE_PROMPT = "Extract all items from this receipt or inventory photo."
|
|
||||||
TEXT_PROMPT = "Extract all items from this inventory list."
|
|
||||||
|
|
||||||
|
|
||||||
def parse_llm_response(raw: str) -> list[InventoryItem]:
|
|
||||||
"""Parse the LLM JSON response into InventoryItem list."""
|
|
||||||
text = raw.strip()
|
|
||||||
# Strip markdown code fences if present
|
|
||||||
if text.startswith("```"):
|
|
||||||
lines = text.split("\n")
|
|
||||||
lines = [line for line in lines if not line.startswith("```")]
|
|
||||||
text = "\n".join(lines)
|
|
||||||
|
|
||||||
items_data: list[dict[str, Any]] = json.loads(text)
|
|
||||||
return [InventoryItem.model_validate(item) for item in items_data]
|
|
||||||
|
|
||||||
|
|
||||||
def _upsert_item(api_url: str, item: InventoryItem) -> None:
|
|
||||||
"""Create or update an item via the van_inventory API.
|
|
||||||
|
|
||||||
Fetches existing items, and if one with the same name exists,
|
|
||||||
patches its quantity (summing). Otherwise creates a new item.
|
|
||||||
"""
|
|
||||||
base = api_url.rstrip("/")
|
|
||||||
response = httpx.get(f"{base}/api/items", timeout=10)
|
|
||||||
response.raise_for_status()
|
|
||||||
existing: list[dict[str, Any]] = response.json()
|
|
||||||
|
|
||||||
match = next((e for e in existing if e["name"].lower() == item.name.lower()), None)
|
|
||||||
|
|
||||||
if match:
|
|
||||||
new_qty = match["quantity"] + item.quantity
|
|
||||||
patch = {"quantity": new_qty}
|
|
||||||
if item.category:
|
|
||||||
patch["category"] = item.category
|
|
||||||
response = httpx.patch(f"{base}/api/items/{match['id']}", json=patch, timeout=10)
|
|
||||||
response.raise_for_status()
|
|
||||||
return
|
|
||||||
payload = {
|
|
||||||
"name": item.name,
|
|
||||||
"quantity": item.quantity,
|
|
||||||
"unit": item.unit,
|
|
||||||
"category": item.category or None,
|
|
||||||
}
|
|
||||||
response = httpx.post(f"{base}/api/items", json=payload, timeout=10)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
|
|
||||||
def handle_inventory_update(
    message: SignalMessage,
    signal: SignalClient,
    llm: LLMClient,
    api_url: str,
) -> InventoryUpdate:
    """Process an inventory update from a Signal message.

    Accepts either an image (receipt photo) or text list.
    Uses the LLM to extract structured items, then pushes to the van_inventory API.

    Args:
        message: Incoming Signal message (text and/or attachments).
        signal: Client used to download attachments and send replies.
        llm: LLM client used to extract structured items.
        api_url: Base URL of the van_inventory API.

    Returns:
        An InventoryUpdate describing what was extracted and pushed; an
        all-defaults InventoryUpdate when there was nothing to process or
        processing failed.
    """
    try:
        logger.info(f"Processing inventory update from {message.source}")
        if message.attachments:
            # Only the first attachment is processed; any extras are ignored.
            image_data = signal.get_attachment(message.attachments[0])
            raw_response = llm.chat(
                IMAGE_PROMPT,
                image_data=image_data,
                system=SYSTEM_PROMPT,
            )
            source_type = "receipt_photo"
        elif message.message.strip():
            raw_response = llm.chat(
                f"{TEXT_PROMPT}\n\n{message.message}",
                system=SYSTEM_PROMPT,
            )
            source_type = "text_list"
        else:
            # Nothing usable in the message: tell the sender what to do and stop.
            signal.reply(message, "Send a photo of a receipt or a text list of items to update inventory.")
            return InventoryUpdate()

        logger.info(f"{raw_response=}")

        new_items = parse_llm_response(raw_response)

        logger.info(f"{new_items=}")

        # Push each extracted item to the inventory API (create or merge).
        for item in new_items:
            _upsert_item(api_url, item)

        summary = _format_summary(new_items)
        signal.reply(message, f"Inventory updated with {len(new_items)} item(s):\n{summary}")

        return InventoryUpdate(items=new_items, raw_response=raw_response, source_type=source_type)

    except Exception:
        # Broad catch is deliberate: a failed update should notify the sender
        # rather than crash the bot's message loop.
        logger.exception("Failed to process inventory update")
        signal.reply(message, "Failed to process inventory update. Check logs for details.")
        return InventoryUpdate()
|
|
||||||
|
|
||||||
|
|
||||||
def _format_summary(items: list[InventoryItem]) -> str:
|
|
||||||
"""Format items into a readable summary."""
|
|
||||||
lines = [f" - {item.name} x{item.quantity} {item.unit} [{item.category}]" for item in items]
|
|
||||||
return "\n".join(lines)
|
|
||||||
@@ -1,185 +0,0 @@
|
|||||||
"""Device registry — tracks verified/unverified devices by safety number."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import TYPE_CHECKING, NamedTuple
|
|
||||||
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from python.common import utcnow
|
|
||||||
from python.orm.richie.signal_device import SignalDevice
|
|
||||||
from python.signal_bot.models import TrustLevel
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from sqlalchemy.engine import Engine
|
|
||||||
|
|
||||||
from python.signal_bot.signal_client import SignalClient
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Cache TTL for BLOCKED devices — kept longer so repeat senders don't force
# a database read per message.
_BLOCKED_TTL = timedelta(minutes=60)
# Cache TTL for all other trust levels.
_DEFAULT_TTL = timedelta(minutes=5)
|
|
||||||
|
|
||||||
|
|
||||||
class _CacheEntry(NamedTuple):
    """Cached per-device trust state, avoiding a DB round-trip per message."""

    # Moment this entry stops being valid (compared against utcnow()).
    expires: datetime
    trust_level: TrustLevel
    # True when safety_number is not None (empty string still counts as present).
    has_safety_number: bool
    safety_number: str | None
|
|
||||||
|
|
||||||
|
|
||||||
class DeviceRegistry:
    """Manage device trust based on Signal safety numbers.

    Devices start as UNVERIFIED. An admin verifies them over SSH by calling
    ``verify(phone_number)`` which marks the device VERIFIED and also tells
    signal-cli to trust the identity.

    Only VERIFIED devices may execute commands.
    """

    def __init__(self, signal_client: SignalClient, engine: Engine) -> None:
        self.signal_client = signal_client
        self.engine = engine
        # In-process cache of trust state keyed by phone number; entries
        # expire per the TTL constants above.
        self._contact_cache: dict[str, _CacheEntry] = {}

    def is_verified(self, phone_number: str) -> bool:
        """Check if a phone number is verified."""
        if entry := self._cached(phone_number):
            return entry.trust_level == TrustLevel.VERIFIED
        device = self.get_device(phone_number)
        return device is not None and device.trust_level == TrustLevel.VERIFIED

    def record_contact(self, phone_number: str, safety_number: str | None = None) -> None:
        """Record seeing a device. Creates entry if new, updates last_seen."""
        now = utcnow()

        # Fresh cache hit with an unchanged safety number: skip the DB write.
        # (last_seen is therefore only refreshed at cache-TTL granularity.)
        entry = self._cached(phone_number)
        if entry and entry.safety_number == safety_number:
            return

        with Session(self.engine) as session:
            device = session.execute(
                select(SignalDevice).where(SignalDevice.phone_number == phone_number)
            ).scalar_one_or_none()

            if device:
                # A changed safety number demotes the device to UNVERIFIED so
                # an admin must re-verify; BLOCKED devices stay blocked.
                if device.safety_number != safety_number and device.trust_level != TrustLevel.BLOCKED:
                    logger.warning(f"Safety number changed for {phone_number}, resetting to UNVERIFIED")
                    device.safety_number = safety_number
                    device.trust_level = TrustLevel.UNVERIFIED
                device.last_seen = now
            else:
                device = SignalDevice(
                    phone_number=phone_number,
                    safety_number=safety_number,
                    trust_level=TrustLevel.UNVERIFIED,
                    last_seen=now,
                )
                session.add(device)
                logger.info(f"New device registered: {phone_number}")

            session.commit()

            ttl = _BLOCKED_TTL if device.trust_level == TrustLevel.BLOCKED else _DEFAULT_TTL
            self._contact_cache[phone_number] = _CacheEntry(
                expires=now + ttl,
                trust_level=device.trust_level,
                has_safety_number=device.safety_number is not None,
                safety_number=device.safety_number,
            )

    def has_safety_number(self, phone_number: str) -> bool:
        """Check if a device has a safety number on file.

        NOTE(review): an empty-string safety number counts as present (only
        None is "missing"), and sync_identities can record "" when signal-cli
        returns no fingerprint — confirm this is intended.
        """
        if entry := self._cached(phone_number):
            return entry.has_safety_number
        device = self.get_device(phone_number)
        return device is not None and device.safety_number is not None

    def verify(self, phone_number: str) -> bool:
        """Mark a device as verified. Called by admin over SSH.

        Returns True if the device was found and verified.
        """
        with Session(self.engine) as session:
            device = session.execute(
                select(SignalDevice).where(SignalDevice.phone_number == phone_number)
            ).scalar_one_or_none()

            if not device:
                logger.warning(f"Cannot verify unknown device: {phone_number}")
                return False

            device.trust_level = TrustLevel.VERIFIED
            # trust_identity runs before commit: if the signal-cli call
            # raises, the session exits without committing the DB change.
            self.signal_client.trust_identity(phone_number, trust_all_known_keys=True)
            session.commit()
            self._contact_cache[phone_number] = _CacheEntry(
                expires=utcnow() + _DEFAULT_TTL,
                trust_level=TrustLevel.VERIFIED,
                has_safety_number=device.safety_number is not None,
                safety_number=device.safety_number,
            )
            logger.info(f"Device verified: {phone_number}")
            return True

    def block(self, phone_number: str) -> bool:
        """Block a device."""
        return self._set_trust(phone_number, TrustLevel.BLOCKED, "Device blocked")

    def unverify(self, phone_number: str) -> bool:
        """Reset a device to unverified."""
        return self._set_trust(phone_number, TrustLevel.UNVERIFIED)

    def list_devices(self) -> list[SignalDevice]:
        """Return all known devices."""
        with Session(self.engine) as session:
            return list(session.execute(select(SignalDevice)).scalars().all())

    def sync_identities(self) -> None:
        """Pull identity list from signal-cli and record any new ones."""
        identities = self.signal_client.get_identities()
        for identity in identities:
            number = identity.get("number", "")
            # Fall back to the "fingerprint" key when "safety_number" is
            # absent; this can yield "" when neither key is present.
            safety = identity.get("safety_number", identity.get("fingerprint", ""))
            if number:
                self.record_contact(number, safety)

    def _cached(self, phone_number: str) -> _CacheEntry | None:
        """Return the cache entry if it exists and hasn't expired."""
        entry = self._contact_cache.get(phone_number)
        if entry and utcnow() < entry.expires:
            return entry
        return None

    def get_device(self, phone_number: str) -> SignalDevice | None:
        """Fetch a device by phone number."""
        with Session(self.engine) as session:
            return session.execute(
                select(SignalDevice).where(SignalDevice.phone_number == phone_number)
            ).scalar_one_or_none()

    def _set_trust(self, phone_number: str, level: str, log_msg: str | None = None) -> bool:
        """Update the trust level for a device and refresh its cache entry.

        Args:
            phone_number: Device to update.
            level: New trust level (a TrustLevel value; StrEnum, hence str-typed).
            log_msg: Optional message logged on success.

        Returns:
            True if the device existed and was updated.
        """
        with Session(self.engine) as session:
            device = session.execute(
                select(SignalDevice).where(SignalDevice.phone_number == phone_number)
            ).scalar_one_or_none()

            if not device:
                return False

            device.trust_level = level
            session.commit()
            ttl = _BLOCKED_TTL if level == TrustLevel.BLOCKED else _DEFAULT_TTL
            self._contact_cache[phone_number] = _CacheEntry(
                expires=utcnow() + ttl,
                trust_level=level,
                has_safety_number=device.safety_number is not None,
                safety_number=device.safety_number,
            )
            if log_msg:
                logger.info(f"{log_msg}: {phone_number}")
            return True
|
|
||||||
@@ -1,72 +0,0 @@
|
|||||||
"""Flexible LLM client for ollama backends."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import base64
|
|
||||||
import logging
|
|
||||||
from typing import Any, Self
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class LLMClient:
    """Talk to an ollama instance.

    Args:
        model: Ollama model name.
        host: Ollama host.
        port: Ollama port.
        temperature: Sampling temperature.
        timeout: HTTP request timeout in seconds; generation on large models
            can take minutes, so keep this generous.
    """

    def __init__(
        self,
        model: str,
        host: str,
        port: int = 11434,
        *,
        temperature: float = 0.1,
        timeout: int = 120,
    ) -> None:
        # BUGFIX: `timeout` was previously hard-coded to 120, but main()
        # constructs LLMClient(..., timeout=llm_timeout) — that call raised
        # TypeError. Accept it as a keyword-only parameter (default keeps the
        # old behavior for existing callers).
        self.model = model
        self.temperature = temperature
        self._client = httpx.Client(base_url=f"http://{host}:{port}", timeout=timeout)

    def chat(self, prompt: str, image_data: bytes | None = None, system: str | None = None) -> str:
        """Send a prompt (optionally with an image) and return the model's reply.

        Args:
            prompt: User prompt text.
            image_data: Optional raw image bytes; base64-encoded for ollama.
            system: Optional system prompt, sent as the first message.
        """
        messages: list[dict[str, Any]] = []
        if system:
            messages.append({"role": "system", "content": system})

        user_msg: dict[str, Any] = {"role": "user", "content": prompt}
        if image_data:
            # ollama expects images as base64 strings in the "images" field.
            user_msg["images"] = [base64.b64encode(image_data).decode()]

        messages.append(user_msg)
        return self._generate(messages)

    def _generate(self, messages: list[dict[str, Any]]) -> str:
        """Call the ollama chat API and return the assistant message content."""
        payload = {
            "model": self.model,
            "messages": messages,
            "stream": False,  # request one complete response, not a stream
            "options": {"temperature": self.temperature},
        }
        logger.info(f"LLM request to {self.model}")
        response = self._client.post("/api/chat", json=payload)
        response.raise_for_status()
        data = response.json()
        return data["message"]["content"]

    def list_models(self) -> list[str]:
        """List available model names on the ollama instance."""
        response = self._client.get("/api/tags")
        response.raise_for_status()
        return [m["name"] for m in response.json().get("models", [])]

    def __enter__(self) -> Self:
        """Enter the context manager."""
        return self

    def __exit__(self, *args: object) -> None:
        """Close the HTTP client on exit."""
        self.close()

    def close(self) -> None:
        """Close the HTTP client."""
        self._client.close()
|
|
||||||
@@ -1,231 +0,0 @@
|
|||||||
"""Signal command and control bot — main entry point."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from os import getenv
|
|
||||||
from typing import Annotated
|
|
||||||
|
|
||||||
import typer
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
from tenacity import before_sleep_log, retry, stop_after_attempt, wait_exponential
|
|
||||||
|
|
||||||
from python.common import configure_logger, utcnow
|
|
||||||
from python.orm.common import get_postgres_engine
|
|
||||||
from python.orm.richie.dead_letter_message import DeadLetterMessage
|
|
||||||
from python.signal_bot.commands.inventory import handle_inventory_update
|
|
||||||
from python.signal_bot.device_registry import DeviceRegistry
|
|
||||||
from python.signal_bot.llm_client import LLMClient
|
|
||||||
from python.signal_bot.models import BotConfig, MessageStatus, SignalMessage
|
|
||||||
from python.signal_bot.signal_client import SignalClient
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
# Help text shown by the `help` command (and appended by unknown_action).
HELP_TEXT = (
    "Available commands:\n"
    " inventory <text list> — update van inventory from a text list\n"
    " inventory (+ photo) — update van inventory from a receipt photo\n"
    " status — show bot status\n"
    " help — show this help message\n"
    "Send a receipt photo with the message 'inventory' to scan it.\n"
)
|
|
||||||
|
|
||||||
|
|
||||||
def help_action(
    signal: SignalClient,
    message: SignalMessage,
    _llm: LLMClient,
    _registry: DeviceRegistry,
    _config: BotConfig,
    _cmd: str,
) -> None:
    """Return the help text for the bot.

    Underscore-prefixed parameters are unused; they keep the signature
    uniform with the other handlers so dispatch() can call any of them.
    """
    signal.reply(message, HELP_TEXT)
|
|
||||||
|
|
||||||
|
|
||||||
def status_action(
    signal: SignalClient,
    message: SignalMessage,
    llm: LLMClient,
    registry: DeviceRegistry,
    _config: BotConfig,
    _cmd: str,
) -> None:
    """Return the status of the bot: model in use, available models, device count."""
    available = llm.list_models()
    shown = ", ".join(available[:10])  # cap the list so the reply stays short
    known = len(registry.list_devices())
    status_text = f"Bot online.\nLLM: {llm.model}\nAvailable models: {shown}\nKnown devices: {known}"
    signal.reply(message, status_text)
|
|
||||||
|
|
||||||
|
|
||||||
def unknown_action(
    signal: SignalClient,
    message: SignalMessage,
    _llm: LLMClient,
    _registry: DeviceRegistry,
    _config: BotConfig,
    cmd: str,
) -> None:
    """Return an error message for an unknown command.

    NOTE(review): dispatch() currently drops unknown text commands silently
    and never calls this handler — confirm whether it should be wired in.
    """
    signal.reply(message, f"Unknown command: {cmd}\n\n{HELP_TEXT}")
|
|
||||||
|
|
||||||
|
|
||||||
def inventory_action(
    signal: SignalClient,
    message: SignalMessage,
    llm: LLMClient,
    _registry: DeviceRegistry,
    config: BotConfig,
    _cmd: str,
) -> None:
    """Process an inventory update.

    Thin adapter matching the common handler signature; the real work
    happens in handle_inventory_update.
    """
    handle_inventory_update(message, signal, llm, config.inventory_api_url)
|
|
||||||
|
|
||||||
|
|
||||||
def dispatch(
    message: SignalMessage,
    signal: SignalClient,
    llm: LLMClient,
    registry: DeviceRegistry,
    config: BotConfig,
) -> None:
    """Route an incoming message to the right command handler.

    Messages from devices that are not VERIFIED, or have no safety number
    on file, are silently ignored. A message whose first word is not a
    known command but that carries an attachment is treated as an
    inventory update.
    """
    source = message.source

    if not registry.is_verified(source) or not registry.has_safety_number(source):
        logger.info(f"Device {source} not verified, ignoring message")
        return

    text = message.message.strip()
    parts = text.split()

    # Nothing to do for an empty message with no attachments.
    if not parts and not message.attachments:
        return

    cmd = parts[0].lower() if parts else ""

    commands = {
        "help": help_action,
        "status": status_action,
        "inventory": inventory_action,
    }
    # BUGFIX: the f-string previously began with a stray literal "f"
    # (f"f{source=} ..."), garbling every dispatch log line.
    logger.info(f"{source=} running {cmd=} with {message=}")
    action = commands.get(cmd)
    if action is None:
        if message.attachments:
            # A bare photo (or unknown text plus a photo) is assumed to be
            # an inventory update.
            action = inventory_action
            cmd = "inventory"
        else:
            # NOTE(review): unknown text commands are dropped silently here
            # even though unknown_action exists — confirm intended.
            return

    action(signal, message, llm, registry, config, cmd)
|
|
||||||
|
|
||||||
|
|
||||||
def _process_message(
    message: SignalMessage,
    signal: SignalClient,
    llm: LLMClient,
    registry: DeviceRegistry,
    config: BotConfig,
) -> None:
    """Process a single message, sending it to the dead letter queue after repeated failures.

    Each attempt re-fetches the sender's safety number and records the
    contact before dispatching, so trust state stays current across retries.
    """
    max_attempts = config.max_message_attempts
    for attempt in range(1, max_attempts + 1):
        try:
            safety_number = signal.get_safety_number(message.source)
            registry.record_contact(message.source, safety_number)
            dispatch(message, signal, llm, registry, config)
        except Exception:
            # Broad catch: any failure counts as one failed attempt; details
            # go to the log and the loop decides whether to retry.
            logger.exception(f"Failed to process message (attempt {attempt}/{max_attempts})")
        else:
            return

    # Every attempt failed: persist the message for later manual handling.
    logger.error(f"Message from {message.source} failed {max_attempts} times, sending to dead letter queue")
    with Session(config.engine) as session:
        session.add(
            DeadLetterMessage(
                source=message.source,
                message=message.message,
                received_at=utcnow(),
                status=MessageStatus.UNPROCESSED,
            )
        )
        session.commit()
|
|
||||||
|
|
||||||
|
|
||||||
def run_loop(
    config: BotConfig,
    signal: SignalClient,
    llm: LLMClient,
    registry: DeviceRegistry,
) -> None:
    """Listen for messages via WebSocket, reconnecting on failure.

    Reconnection uses tenacity's exponential backoff bounded by
    config.max_reconnect_delay, giving up after config.max_retries attempts.
    """
    logger.info("Bot started — listening via WebSocket")

    @retry(
        stop=stop_after_attempt(config.max_retries),
        wait=wait_exponential(multiplier=config.reconnect_delay, max=config.max_reconnect_delay),
        before_sleep=before_sleep_log(logger, logging.WARNING),
        reraise=True,
    )
    def _listen() -> None:
        # One pass over the WebSocket stream; a dropped connection raises
        # out of signal.listen() and is retried by the decorator above.
        for message in signal.listen():
            logger.info(f"Message from {message.source}: {message.message[:80]}")
            _process_message(message, signal, llm, registry, config)

    try:
        _listen()
    except Exception:
        # reraise=True propagates the final failure here after retries run out.
        logger.critical("Max retries exceeded, shutting down")
        raise
|
|
||||||
|
|
||||||
|
|
||||||
def main(
    log_level: Annotated[str, typer.Option()] = "INFO",
    llm_timeout: Annotated[int, typer.Option()] = 600,
) -> None:
    """Run the Signal command and control bot.

    Reads required settings from the environment (SIGNAL_API_URL,
    SIGNAL_PHONE_NUMBER, INVENTORY_API_URL, LLM_HOST) plus optional LLM
    settings (LLM_MODEL, LLM_PORT), then starts the listen loop.

    Raises:
        ValueError: If any required environment variable is unset.
    """
    configure_logger(log_level)
    signal_api_url = getenv("SIGNAL_API_URL")
    phone_number = getenv("SIGNAL_PHONE_NUMBER")
    inventory_api_url = getenv("INVENTORY_API_URL")

    if signal_api_url is None:
        error = "SIGNAL_API_URL environment variable not set"
        raise ValueError(error)
    if phone_number is None:
        error = "SIGNAL_PHONE_NUMBER environment variable not set"
        raise ValueError(error)
    if inventory_api_url is None:
        error = "INVENTORY_API_URL environment variable not set"
        raise ValueError(error)

    engine = get_postgres_engine(name="SIGNALBOT")
    config = BotConfig(
        signal_api_url=signal_api_url,
        phone_number=phone_number,
        inventory_api_url=inventory_api_url,
        engine=engine,
    )

    llm_host = getenv("LLM_HOST")
    llm_model = getenv("LLM_MODEL", "qwen3-vl:32b")
    llm_port = int(getenv("LLM_PORT", "11434"))
    if llm_host is None:
        error = "LLM_HOST environment variable not set"
        raise ValueError(error)

    # Both clients are context managers so their HTTP connections are
    # closed on shutdown.
    with (
        SignalClient(config.signal_api_url, config.phone_number) as signal,
        LLMClient(model=llm_model, host=llm_host, port=llm_port, timeout=llm_timeout) as llm,
    ):
        registry = DeviceRegistry(signal, engine)
        run_loop(config, signal, llm, registry)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # typer builds the CLI (options, --help) from main()'s signature.
    typer.run(main)
|
|
||||||
@@ -1,86 +0,0 @@
|
|||||||
"""Models for the Signal command and control bot."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from datetime import datetime # noqa: TC003 - pydantic needs this at runtime
|
|
||||||
from enum import StrEnum
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
|
||||||
from sqlalchemy.engine import Engine # noqa: TC002 - pydantic needs this at runtime
|
|
||||||
|
|
||||||
|
|
||||||
class TrustLevel(StrEnum):
    """Device trust level."""

    VERIFIED = "verified"  # admin-approved; may execute commands
    UNVERIFIED = "unverified"  # default for newly seen devices
    BLOCKED = "blocked"  # explicitly denied
|
|
||||||
|
|
||||||
|
|
||||||
class MessageStatus(StrEnum):
    """Dead letter queue message status."""

    UNPROCESSED = "unprocessed"  # awaiting reprocessing
    PROCESSED = "processed"
|
|
||||||
|
|
||||||
|
|
||||||
class Device(BaseModel):
    """A registered device tracked by safety number.

    NOTE(review): not referenced by the registry, which persists via the
    SignalDevice ORM model — confirm this pydantic mirror is still needed.
    """

    phone_number: str
    safety_number: str
    trust_level: TrustLevel = TrustLevel.UNVERIFIED
    first_seen: datetime
    last_seen: datetime
|
|
||||||
|
|
||||||
|
|
||||||
class SignalMessage(BaseModel):
    """An incoming Signal message."""

    # Sender's phone number (envelope "source").
    source: str
    timestamp: int
    message: str = ""
    # Attachment IDs, downloadable via SignalClient.get_attachment.
    attachments: list[str] = []
    group_id: str | None = None
    is_receipt: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
class SignalEnvelope(BaseModel):
    """Raw envelope from signal-cli-rest-api."""

    # Unparsed envelope payload as received over the WebSocket.
    envelope: dict[str, Any]
    account: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class InventoryItem(BaseModel):
    """An item in the van inventory, as extracted by the LLM."""

    name: str
    quantity: float = 1
    unit: str = "each"
    category: str = ""
    notes: str = ""
|
|
||||||
|
|
||||||
|
|
||||||
class InventoryUpdate(BaseModel):
    """Result of processing an inventory update.

    An all-defaults instance signals "nothing processed" (used by the
    empty-message and error paths of handle_inventory_update).
    """

    items: list[InventoryItem] = []
    raw_response: str = ""  # verbatim LLM output
    source_type: str = ""  # "receipt_photo" or "text_list"
|
|
||||||
|
|
||||||
|
|
||||||
class BotConfig(BaseModel):
    """Top-level bot configuration."""

    # engine is a plain SQLAlchemy object, not a pydantic-validated type.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    signal_api_url: str
    phone_number: str
    inventory_api_url: str
    engine: Engine
    # WebSocket reconnect backoff (seconds) and retry budget used by run_loop.
    reconnect_delay: int = 5
    max_reconnect_delay: int = 300
    max_retries: int = 10
    # Per-message retry budget before dead-lettering in _process_message.
    max_message_attempts: int = 3
|
|
||||||
@@ -1,141 +0,0 @@
|
|||||||
"""Client for the signal-cli-rest-api."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from typing import TYPE_CHECKING, Any, Self
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
import websockets.sync.client
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from collections.abc import Generator
|
|
||||||
|
|
||||||
from python.signal_bot.models import SignalMessage
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_envelope(envelope: dict[str, Any]) -> SignalMessage | None:
|
|
||||||
"""Parse a signal-cli envelope into a SignalMessage, or None if not a data message."""
|
|
||||||
data_message = envelope.get("dataMessage")
|
|
||||||
if not data_message:
|
|
||||||
return None
|
|
||||||
|
|
||||||
attachment_ids = [att["id"] for att in data_message.get("attachments", []) if "id" in att]
|
|
||||||
|
|
||||||
group_info = data_message.get("groupInfo")
|
|
||||||
group_id = group_info.get("groupId") if group_info else None
|
|
||||||
|
|
||||||
return SignalMessage(
|
|
||||||
source=envelope.get("source", ""),
|
|
||||||
timestamp=envelope.get("timestamp", 0),
|
|
||||||
message=data_message.get("message", "") or "",
|
|
||||||
attachments=attachment_ids,
|
|
||||||
group_id=group_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SignalClient:
    """Communicate with signal-cli-rest-api.

    Args:
        base_url: URL of the signal-cli-rest-api (e.g. http://localhost:8989).
        phone_number: The registered phone number to send/receive as.
    """

    def __init__(self, base_url: str, phone_number: str) -> None:
        self.base_url = base_url.rstrip("/")
        self.phone_number = phone_number
        self._client = httpx.Client(base_url=self.base_url, timeout=30)

    def _ws_url(self) -> str:
        """Build the WebSocket URL from the base HTTP URL."""
        url = self.base_url.replace("http://", "ws://").replace("https://", "wss://")
        return f"{url}/v1/receive/{self.phone_number}"

    def listen(self) -> Generator[SignalMessage]:
        """Connect via WebSocket and yield messages as they arrive.

        The generator ends when the connection closes; callers are expected
        to handle reconnection.
        """
        ws_url = self._ws_url()
        logger.info(f"Connecting to WebSocket: {ws_url}")

        with websockets.sync.client.connect(ws_url) as ws:
            for raw in ws:
                try:
                    data = json.loads(raw)
                    envelope = data.get("envelope", {})
                    message = _parse_envelope(envelope)
                    if message:
                        yield message
                except json.JSONDecodeError:
                    # One bad frame shouldn't kill the stream: log and continue.
                    logger.warning(f"Non-JSON WebSocket frame: {raw[:200]}")

    def send(self, recipient: str, message: str) -> None:
        """Send a text message to a single recipient."""
        payload = {
            "message": message,
            "number": self.phone_number,
            "recipients": [recipient],
        }
        response = self._client.post("/v2/send", json=payload)
        response.raise_for_status()

    def send_to_group(self, group_id: str, message: str) -> None:
        """Send a message to a group.

        NOTE(review): identical to send() with the group ID as recipient —
        confirm signal-cli-rest-api accepts a raw group ID in "recipients".
        """
        payload = {
            "message": message,
            "number": self.phone_number,
            "recipients": [group_id],
        }
        response = self._client.post("/v2/send", json=payload)
        response.raise_for_status()

    def get_attachment(self, attachment_id: str) -> bytes:
        """Download an attachment by ID and return its raw bytes."""
        response = self._client.get(f"/v1/attachments/{attachment_id}")
        response.raise_for_status()
        return response.content

    def get_identities(self) -> list[dict[str, Any]]:
        """List known identities and their trust levels."""
        response = self._client.get(f"/v1/identities/{self.phone_number}")
        response.raise_for_status()
        return response.json()

    def get_safety_number(self, phone_number: str) -> str | None:
        """Look up the safety number for a contact from signal-cli's local store.

        Returns None for unknown contacts; may return "" when the matching
        identity has neither "safety_number" nor "fingerprint".
        """
        for identity in self.get_identities():
            if identity.get("number") == phone_number:
                return identity.get("safety_number", identity.get("fingerprint", ""))
        return None

    def trust_identity(self, number_to_trust: str, *, trust_all_known_keys: bool = False) -> None:
        """Trust an identity (verify safety number)."""
        payload: dict[str, Any] = {}
        if trust_all_known_keys:
            payload["trust_all_known_keys"] = True
        response = self._client.put(
            f"/v1/identities/{self.phone_number}/trust/{number_to_trust}",
            json=payload,
        )
        response.raise_for_status()

    def reply(self, message: SignalMessage, text: str) -> None:
        """Reply to a message, routing to group or individual."""
        if message.group_id:
            self.send_to_group(message.group_id, text)
        else:
            self.send(message.source, text)

    def __enter__(self) -> Self:
        """Enter the context manager."""
        return self

    def __exit__(self, *args: object) -> None:
        """Close the HTTP client on exit."""
        self.close()

    def close(self) -> None:
        """Close the HTTP client."""
        self._client.close()
|
|
||||||
@@ -1,13 +1,13 @@
|
|||||||
"""Van weather service - fetches weather with masked GPS for privacy."""
|
"""Van weather service - fetches weather with masked GPS for privacy."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from typing import Annotated, Any
|
from typing import Annotated, Any
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
import typer
|
import typer
|
||||||
from apscheduler.schedulers.blocking import BlockingScheduler
|
from apscheduler.schedulers.blocking import BlockingScheduler
|
||||||
from tenacity import before_sleep_log, retry, stop_after_attempt, wait_fixed
|
|
||||||
|
|
||||||
from python.common import configure_logger
|
from python.common import configure_logger
|
||||||
from python.van_weather.models import Config, DailyForecast, HourlyForecast, Weather
|
from python.van_weather.models import Config, DailyForecast, HourlyForecast, Weather
|
||||||
@@ -29,25 +29,15 @@ CONDITION_MAP = {
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@retry(
|
|
||||||
stop=stop_after_attempt(3),
|
|
||||||
wait=wait_fixed(5),
|
|
||||||
before_sleep=before_sleep_log(logger, logging.WARNING),
|
|
||||||
reraise=True,
|
|
||||||
)
|
|
||||||
def get_ha_state(url: str, token: str, entity_id: str) -> float:
|
def get_ha_state(url: str, token: str, entity_id: str) -> float:
|
||||||
"""Get numeric state from Home Asasistant entity."""
|
"""Get numeric state from Home Assistant entity."""
|
||||||
response = requests.get(
|
response = requests.get(
|
||||||
f"{url}/api/states/{entity_id}",
|
f"{url}/api/states/{entity_id}",
|
||||||
headers={"Authorization": f"Bearer {token}"},
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
timeout=30,
|
timeout=30,
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
state = response.json()["state"]
|
return float(response.json()["state"])
|
||||||
if state in ("unavailable", "unknown"):
|
|
||||||
error = f"{entity_id} is {state}"
|
|
||||||
raise ValueError(error)
|
|
||||||
return float(state)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_daily_forecast(data: dict[str, dict[str, Any]]) -> list[DailyForecast]:
|
def parse_daily_forecast(data: dict[str, dict[str, Any]]) -> list[DailyForecast]:
|
||||||
@@ -65,9 +55,6 @@ def parse_daily_forecast(data: dict[str, dict[str, Any]]) -> list[DailyForecast]
|
|||||||
temperature=day.get("temperatureHigh"),
|
temperature=day.get("temperatureHigh"),
|
||||||
templow=day.get("temperatureLow"),
|
templow=day.get("temperatureLow"),
|
||||||
precipitation_probability=day.get("precipProbability"),
|
precipitation_probability=day.get("precipProbability"),
|
||||||
moon_phase=day.get("moonPhase"),
|
|
||||||
wind_gust=day.get("windGust"),
|
|
||||||
cloud_cover=day.get("cloudCover"),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -93,12 +80,6 @@ def parse_hourly_forecast(data: dict[str, dict[str, Any]]) -> list[HourlyForecas
|
|||||||
return hourly_forecasts
|
return hourly_forecasts
|
||||||
|
|
||||||
|
|
||||||
@retry(
|
|
||||||
stop=stop_after_attempt(3),
|
|
||||||
wait=wait_fixed(5),
|
|
||||||
before_sleep=before_sleep_log(logger, logging.WARNING),
|
|
||||||
reraise=True,
|
|
||||||
)
|
|
||||||
def fetch_weather(api_key: str, lat: float, lon: float) -> Weather:
|
def fetch_weather(api_key: str, lat: float, lon: float) -> Weather:
|
||||||
"""Fetch weather from Pirate Weather API."""
|
"""Fetch weather from Pirate Weather API."""
|
||||||
url = f"https://api.pirateweather.net/forecast/{api_key}/{lat},{lon}"
|
url = f"https://api.pirateweather.net/forecast/{api_key}/{lat},{lon}"
|
||||||
@@ -121,25 +102,29 @@ def fetch_weather(api_key: str, lat: float, lon: float) -> Weather:
|
|||||||
summary=current.get("summary"),
|
summary=current.get("summary"),
|
||||||
pressure=current.get("pressure"),
|
pressure=current.get("pressure"),
|
||||||
visibility=current.get("visibility"),
|
visibility=current.get("visibility"),
|
||||||
uv_index=current.get("uvIndex"),
|
|
||||||
ozone=current.get("ozone"),
|
|
||||||
nearest_storm_distance=current.get("nearestStormDistance"),
|
|
||||||
nearest_storm_bearing=current.get("nearestStormBearing"),
|
|
||||||
precip_probability=current.get("precipProbability"),
|
|
||||||
cloud_cover=current.get("cloudCover"),
|
|
||||||
daily_forecasts=daily_forecasts,
|
daily_forecasts=daily_forecasts,
|
||||||
hourly_forecasts=hourly_forecasts,
|
hourly_forecasts=hourly_forecasts,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@retry(
|
|
||||||
stop=stop_after_attempt(3),
|
|
||||||
wait=wait_fixed(5),
|
|
||||||
before_sleep=before_sleep_log(logger, logging.WARNING),
|
|
||||||
reraise=True,
|
|
||||||
)
|
|
||||||
def post_to_ha(url: str, token: str, weather: Weather) -> None:
|
def post_to_ha(url: str, token: str, weather: Weather) -> None:
|
||||||
"""Post weather data to Home Assistant as sensor entities."""
|
"""Post weather data to Home Assistant as sensor entities."""
|
||||||
|
max_retries = 6
|
||||||
|
retry_delay = 10
|
||||||
|
|
||||||
|
for attempt in range(1, max_retries + 1):
|
||||||
|
try:
|
||||||
|
_post_weather_data(url, token, weather)
|
||||||
|
except requests.RequestException:
|
||||||
|
if attempt == max_retries:
|
||||||
|
logger.exception(f"Failed to post weather to HA after {max_retries} attempts")
|
||||||
|
return
|
||||||
|
logger.warning(f"Post to HA failed (attempt {attempt}/{max_retries}), retrying in {retry_delay}s")
|
||||||
|
time.sleep(retry_delay)
|
||||||
|
|
||||||
|
|
||||||
|
def _post_weather_data(url: str, token: str, weather: Weather) -> None:
|
||||||
|
"""Post all weather data to Home Assistant. Raises on failure."""
|
||||||
headers = {"Authorization": f"Bearer {token}"}
|
headers = {"Authorization": f"Bearer {token}"}
|
||||||
|
|
||||||
# Post current weather as individual sensors
|
# Post current weather as individual sensors
|
||||||
@@ -176,30 +161,6 @@ def post_to_ha(url: str, token: str, weather: Weather) -> None:
|
|||||||
"state": weather.visibility,
|
"state": weather.visibility,
|
||||||
"attributes": {"unit_of_measurement": "mi"},
|
"attributes": {"unit_of_measurement": "mi"},
|
||||||
},
|
},
|
||||||
"sensor.van_weather_uv_index": {
|
|
||||||
"state": weather.uv_index,
|
|
||||||
"attributes": {"friendly_name": "Van Weather UV Index", "icon": "mdi:sun-wireless"},
|
|
||||||
},
|
|
||||||
"sensor.van_weather_ozone": {
|
|
||||||
"state": weather.ozone,
|
|
||||||
"attributes": {"unit_of_measurement": "DU", "icon": "mdi:earth"},
|
|
||||||
},
|
|
||||||
"sensor.van_weather_nearest_storm_distance": {
|
|
||||||
"state": weather.nearest_storm_distance,
|
|
||||||
"attributes": {"unit_of_measurement": "mi", "icon": "mdi:weather-lightning"},
|
|
||||||
},
|
|
||||||
"sensor.van_weather_nearest_storm_bearing": {
|
|
||||||
"state": weather.nearest_storm_bearing,
|
|
||||||
"attributes": {"unit_of_measurement": "°", "icon": "mdi:weather-lightning"},
|
|
||||||
},
|
|
||||||
"sensor.van_weather_precip_probability": {
|
|
||||||
"state": int((weather.precip_probability or 0) * 100),
|
|
||||||
"attributes": {"unit_of_measurement": "%", "icon": "mdi:weather-rainy"},
|
|
||||||
},
|
|
||||||
"sensor.van_weather_cloud_cover": {
|
|
||||||
"state": int((weather.cloud_cover or 0) * 100),
|
|
||||||
"attributes": {"unit_of_measurement": "%", "icon": "mdi:weather-cloudy"},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for entity_id, data in sensors.items():
|
for entity_id, data in sensors.items():
|
||||||
@@ -248,7 +209,7 @@ def post_to_ha(url: str, token: str, weather: Weather) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def update_weather(config: Config) -> None:
|
def update_weather(config: Config) -> None:
|
||||||
"""Fetch weather using last-known location, post to HA."""
|
"""Fetch GPS, mask it, get weather, post to HA."""
|
||||||
lat = get_ha_state(config.ha_url, config.ha_token, config.lat_entity)
|
lat = get_ha_state(config.ha_url, config.ha_token, config.lat_entity)
|
||||||
lon = get_ha_state(config.ha_url, config.ha_token, config.lon_entity)
|
lon = get_ha_state(config.ha_url, config.ha_token, config.lon_entity)
|
||||||
|
|
||||||
@@ -257,7 +218,7 @@ def update_weather(config: Config) -> None:
|
|||||||
|
|
||||||
logger.info(f"Masked location: {masked_lat}, {masked_lon}")
|
logger.info(f"Masked location: {masked_lat}, {masked_lon}")
|
||||||
|
|
||||||
weather = fetch_weather(config.pirate_weather_api_key, lat, lon)
|
weather = fetch_weather(config.pirate_weather_api_key, masked_lat, masked_lon)
|
||||||
logger.info(f"Weather: {weather.temperature}°F, {weather.condition}")
|
logger.info(f"Weather: {weather.temperature}°F, {weather.condition}")
|
||||||
|
|
||||||
post_to_ha(config.ha_url, config.ha_token, weather)
|
post_to_ha(config.ha_url, config.ha_token, weather)
|
||||||
|
|||||||
@@ -11,8 +11,8 @@ class Config(BaseModel):
|
|||||||
ha_url: str
|
ha_url: str
|
||||||
ha_token: str
|
ha_token: str
|
||||||
pirate_weather_api_key: str
|
pirate_weather_api_key: str
|
||||||
lat_entity: str = "sensor.van_last_known_latitude"
|
lat_entity: str = "sensor.gps_latitude"
|
||||||
lon_entity: str = "sensor.van_last_known_longitude"
|
lon_entity: str = "sensor.gps_longitude"
|
||||||
mask_decimals: int = 1 # ~11km accuracy
|
mask_decimals: int = 1 # ~11km accuracy
|
||||||
|
|
||||||
|
|
||||||
@@ -24,9 +24,6 @@ class DailyForecast(BaseModel):
|
|||||||
temperature: float | None = None # High
|
temperature: float | None = None # High
|
||||||
templow: float | None = None # Low
|
templow: float | None = None # Low
|
||||||
precipitation_probability: float | None = None
|
precipitation_probability: float | None = None
|
||||||
moon_phase: float | None = None
|
|
||||||
wind_gust: float | None = None
|
|
||||||
cloud_cover: float | None = None
|
|
||||||
|
|
||||||
@field_serializer("date_time")
|
@field_serializer("date_time")
|
||||||
def serialize_date_time(self, date_time: datetime) -> str:
|
def serialize_date_time(self, date_time: datetime) -> str:
|
||||||
@@ -60,11 +57,5 @@ class Weather(BaseModel):
|
|||||||
summary: str | None = None
|
summary: str | None = None
|
||||||
pressure: float | None = None
|
pressure: float | None = None
|
||||||
visibility: float | None = None
|
visibility: float | None = None
|
||||||
uv_index: float | None = None
|
|
||||||
ozone: float | None = None
|
|
||||||
nearest_storm_distance: float | None = None
|
|
||||||
nearest_storm_bearing: float | None = None
|
|
||||||
precip_probability: float | None = None
|
|
||||||
cloud_cover: float | None = None
|
|
||||||
daily_forecasts: list[DailyForecast] = []
|
daily_forecasts: list[DailyForecast] = []
|
||||||
hourly_forecasts: list[HourlyForecast] = []
|
hourly_forecasts: list[HourlyForecast] = []
|
||||||
|
|||||||
@@ -14,8 +14,6 @@
|
|||||||
ssh-to-age
|
ssh-to-age
|
||||||
gnupg
|
gnupg
|
||||||
age
|
age
|
||||||
|
|
||||||
audiveris
|
|
||||||
];
|
];
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
host = "0.0.0.0";
|
host = "0.0.0.0";
|
||||||
enable = true;
|
enable = true;
|
||||||
|
|
||||||
syncModels = false;
|
syncModels = true;
|
||||||
loadModels = [
|
loadModels = [
|
||||||
"codellama:7b"
|
"codellama:7b"
|
||||||
"deepscaler:1.5b"
|
"deepscaler:1.5b"
|
||||||
|
|||||||
@@ -57,30 +57,6 @@ automation:
|
|||||||
|
|
||||||
template:
|
template:
|
||||||
- sensor:
|
- sensor:
|
||||||
- name: Van Last Known Latitude
|
|
||||||
unique_id: van_last_known_latitude
|
|
||||||
unit_of_measurement: "°"
|
|
||||||
state: >-
|
|
||||||
{% set lat = states('sensor.gps_latitude')|float(none) %}
|
|
||||||
{% set fix = states('sensor.gps_fix')|int(0) %}
|
|
||||||
{% if lat is not none and fix > 0 %}
|
|
||||||
{{ lat }}
|
|
||||||
{% else %}
|
|
||||||
{{ this.state | default('unavailable', true) }}
|
|
||||||
{% endif %}
|
|
||||||
|
|
||||||
- name: Van Last Known Longitude
|
|
||||||
unique_id: van_last_known_longitude
|
|
||||||
unit_of_measurement: "°"
|
|
||||||
state: >-
|
|
||||||
{% set lon = states('sensor.gps_longitude')|float(none) %}
|
|
||||||
{% set fix = states('sensor.gps_fix')|int(0) %}
|
|
||||||
{% if lon is not none and fix > 0 %}
|
|
||||||
{{ lon }}
|
|
||||||
{% else %}
|
|
||||||
{{ this.state | default('unavailable', true) }}
|
|
||||||
{% endif %}
|
|
||||||
|
|
||||||
- name: GPS Location
|
- name: GPS Location
|
||||||
unique_id: gps_location
|
unique_id: gps_location
|
||||||
state: >-
|
state: >-
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ modbus:
|
|||||||
|
|
||||||
# GPS
|
# GPS
|
||||||
- name: GPS Latitude
|
- name: GPS Latitude
|
||||||
slave: 1
|
slave: 100
|
||||||
address: 2800
|
address: 2800
|
||||||
input_type: holding
|
input_type: holding
|
||||||
data_type: int32
|
data_type: int32
|
||||||
@@ -88,7 +88,7 @@ modbus:
|
|||||||
unique_id: gps_latitude
|
unique_id: gps_latitude
|
||||||
|
|
||||||
- name: GPS Longitude
|
- name: GPS Longitude
|
||||||
slave: 1
|
slave: 100
|
||||||
address: 2802
|
address: 2802
|
||||||
input_type: holding
|
input_type: holding
|
||||||
data_type: int32
|
data_type: int32
|
||||||
@@ -98,7 +98,7 @@ modbus:
|
|||||||
unique_id: gps_longitude
|
unique_id: gps_longitude
|
||||||
|
|
||||||
- name: GPS Course
|
- name: GPS Course
|
||||||
slave: 1
|
slave: 100
|
||||||
address: 2804
|
address: 2804
|
||||||
input_type: holding
|
input_type: holding
|
||||||
data_type: uint16
|
data_type: uint16
|
||||||
@@ -109,7 +109,7 @@ modbus:
|
|||||||
unique_id: gps_course
|
unique_id: gps_course
|
||||||
|
|
||||||
- name: GPS Speed
|
- name: GPS Speed
|
||||||
slave: 1
|
slave: 100
|
||||||
address: 2805
|
address: 2805
|
||||||
input_type: holding
|
input_type: holding
|
||||||
data_type: uint16
|
data_type: uint16
|
||||||
@@ -120,7 +120,7 @@ modbus:
|
|||||||
unique_id: gps_speed
|
unique_id: gps_speed
|
||||||
|
|
||||||
- name: GPS Fix
|
- name: GPS Fix
|
||||||
slave: 1
|
slave: 100
|
||||||
address: 2806
|
address: 2806
|
||||||
input_type: holding
|
input_type: holding
|
||||||
data_type: uint16
|
data_type: uint16
|
||||||
@@ -129,7 +129,7 @@ modbus:
|
|||||||
unique_id: gps_fix
|
unique_id: gps_fix
|
||||||
|
|
||||||
- name: GPS Satellites
|
- name: GPS Satellites
|
||||||
slave: 1
|
slave: 100
|
||||||
address: 2807
|
address: 2807
|
||||||
input_type: holding
|
input_type: holding
|
||||||
data_type: uint16
|
data_type: uint16
|
||||||
@@ -138,7 +138,7 @@ modbus:
|
|||||||
unique_id: gps_satellites
|
unique_id: gps_satellites
|
||||||
|
|
||||||
- name: GPS Altitude
|
- name: GPS Altitude
|
||||||
slave: 1
|
slave: 100
|
||||||
address: 2808
|
address: 2808
|
||||||
input_type: holding
|
input_type: holding
|
||||||
data_type: int32
|
data_type: int32
|
||||||
|
|||||||
@@ -6,13 +6,11 @@
|
|||||||
{
|
{
|
||||||
networking.firewall.allowedTCPPorts = [ 8001 ];
|
networking.firewall.allowedTCPPorts = [ 8001 ];
|
||||||
|
|
||||||
users = {
|
users.users.vaninventory = {
|
||||||
users.vaninventory = {
|
isSystemUser = true;
|
||||||
isSystemUser = true;
|
group = "vaninventory";
|
||||||
group = "vaninventory";
|
|
||||||
};
|
|
||||||
groups.vaninventory = { };
|
|
||||||
};
|
};
|
||||||
|
users.groups.vaninventory = { };
|
||||||
|
|
||||||
systemd.services.van_inventory = {
|
systemd.services.van_inventory = {
|
||||||
description = "Van Inventory API";
|
description = "Van Inventory API";
|
||||||
@@ -33,8 +31,8 @@
|
|||||||
|
|
||||||
serviceConfig = {
|
serviceConfig = {
|
||||||
Type = "simple";
|
Type = "simple";
|
||||||
User = "vaninventory";
|
User = "van-inventory";
|
||||||
Group = "vaninventory";
|
Group = "van-inventory";
|
||||||
ExecStart = "${pkgs.my_python}/bin/python -m python.van_inventory.main --host 0.0.0.0 --port 8001";
|
ExecStart = "${pkgs.my_python}/bin/python -m python.van_inventory.main --host 0.0.0.0 --port 8001";
|
||||||
Restart = "on-failure";
|
Restart = "on-failure";
|
||||||
RestartSec = "5s";
|
RestartSec = "5s";
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ in
|
|||||||
8989
|
8989
|
||||||
];
|
];
|
||||||
virtualisation.oci-containers.containers.signal_cli_rest_api = {
|
virtualisation.oci-containers.containers.signal_cli_rest_api = {
|
||||||
image = "bbernhard/signal-cli-rest-api:0.199-dev";
|
image = "bbernhard/signal-cli-rest-api:latest";
|
||||||
ports = [
|
ports = [
|
||||||
"8989:8080"
|
"8989:8080"
|
||||||
];
|
];
|
||||||
|
|||||||
@@ -30,14 +30,11 @@ in
|
|||||||
local hass hass trust
|
local hass hass trust
|
||||||
local gitea gitea trust
|
local gitea gitea trust
|
||||||
|
|
||||||
# signalbot
|
|
||||||
local richie signalbot trust
|
|
||||||
|
|
||||||
# math
|
# math
|
||||||
local postgres math trust
|
local postgres math trust
|
||||||
host postgres math 127.0.0.1/32 trust
|
host postgres math 127.0.0.1/32 trust
|
||||||
host postgres math ::1/128 trust
|
host postgres math ::1/128 trust
|
||||||
host postgres math 192.168.90.1/24 trust
|
host postgres math 192.168.90.1/24 trust
|
||||||
|
|
||||||
'';
|
'';
|
||||||
|
|
||||||
@@ -101,12 +98,6 @@ in
|
|||||||
replication = true;
|
replication = true;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
{
|
|
||||||
name = "signalbot";
|
|
||||||
ensureClauses = {
|
|
||||||
login = true;
|
|
||||||
};
|
|
||||||
}
|
|
||||||
];
|
];
|
||||||
ensureDatabases = [
|
ensureDatabases = [
|
||||||
"hass"
|
"hass"
|
||||||
|
|||||||
@@ -1,56 +0,0 @@
|
|||||||
{
|
|
||||||
pkgs,
|
|
||||||
inputs,
|
|
||||||
...
|
|
||||||
}:
|
|
||||||
let
|
|
||||||
vars = import ../vars.nix;
|
|
||||||
in
|
|
||||||
{
|
|
||||||
users = {
|
|
||||||
users.signalbot = {
|
|
||||||
isSystemUser = true;
|
|
||||||
group = "signalbot";
|
|
||||||
};
|
|
||||||
groups.signalbot = { };
|
|
||||||
};
|
|
||||||
|
|
||||||
systemd.services.signal-bot = {
|
|
||||||
description = "Signal command and control bot";
|
|
||||||
after = [
|
|
||||||
"network.target"
|
|
||||||
"podman-signal_cli_rest_api.service"
|
|
||||||
];
|
|
||||||
wants = [ "podman-signal_cli_rest_api.service" ];
|
|
||||||
wantedBy = [ "multi-user.target" ];
|
|
||||||
|
|
||||||
environment = {
|
|
||||||
PYTHONPATH = "${inputs.self}";
|
|
||||||
SIGNALBOT_DB = "richie";
|
|
||||||
SIGNALBOT_USER = "signalbot";
|
|
||||||
SIGNALBOT_HOST = "/run/postgresql";
|
|
||||||
SIGNALBOT_PORT = "5432";
|
|
||||||
};
|
|
||||||
|
|
||||||
serviceConfig = {
|
|
||||||
Type = "simple";
|
|
||||||
User = "signalbot";
|
|
||||||
Group = "signalbot";
|
|
||||||
EnvironmentFile = "${vars.secrets}/services/signal-bot";
|
|
||||||
ExecStart = "${pkgs.my_python}/bin/python -m python.signal_bot.main";
|
|
||||||
StateDirectory = "signal-bot";
|
|
||||||
Restart = "on-failure";
|
|
||||||
RestartSec = "10s";
|
|
||||||
StandardOutput = "journal";
|
|
||||||
StandardError = "journal";
|
|
||||||
NoNewPrivileges = true;
|
|
||||||
ProtectSystem = "strict";
|
|
||||||
ProtectHome = "read-only";
|
|
||||||
PrivateTmp = true;
|
|
||||||
ReadWritePaths = [ "/var/lib/signal-bot" ];
|
|
||||||
ReadOnlyPaths = [
|
|
||||||
"${inputs.self}"
|
|
||||||
];
|
|
||||||
};
|
|
||||||
};
|
|
||||||
}
|
|
||||||
@@ -1,294 +0,0 @@
|
|||||||
import zipfile
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
import pytest
|
|
||||||
from typer.testing import CliRunner
|
|
||||||
|
|
||||||
from python.sheet_music_ocr.audiveris import AudiverisError, find_audiveris, run_audiveris
|
|
||||||
from python.sheet_music_ocr.main import SUPPORTED_EXTENSIONS, app, extract_mxml_from_mxl
|
|
||||||
from python.sheet_music_ocr.review import LLMProvider, ReviewError, review_mxml
|
|
||||||
|
|
||||||
runner = CliRunner()
|
|
||||||
|
|
||||||
|
|
||||||
def make_mxl(path, xml_content=b"<score-partwise/>"):
|
|
||||||
"""Create a minimal .mxl (ZIP) file with a MusicXML inside."""
|
|
||||||
with zipfile.ZipFile(path, "w") as zf:
|
|
||||||
zf.writestr("score.xml", xml_content)
|
|
||||||
|
|
||||||
|
|
||||||
class TestExtractMxmlFromMxl:
|
|
||||||
def test_extracts_xml(self, tmp_path):
|
|
||||||
mxl = tmp_path / "test.mxl"
|
|
||||||
output = tmp_path / "output.mxml"
|
|
||||||
content = b"<score-partwise>hello</score-partwise>"
|
|
||||||
make_mxl(mxl, content)
|
|
||||||
|
|
||||||
result = extract_mxml_from_mxl(mxl, output)
|
|
||||||
|
|
||||||
assert result == output
|
|
||||||
assert output.read_bytes() == content
|
|
||||||
|
|
||||||
def test_skips_meta_inf(self, tmp_path):
|
|
||||||
mxl = tmp_path / "test.mxl"
|
|
||||||
output = tmp_path / "output.mxml"
|
|
||||||
with zipfile.ZipFile(mxl, "w") as zf:
|
|
||||||
zf.writestr("META-INF/container.xml", "<container/>")
|
|
||||||
zf.writestr("score.xml", b"<score/>")
|
|
||||||
|
|
||||||
extract_mxml_from_mxl(mxl, output)
|
|
||||||
|
|
||||||
assert output.read_bytes() == b"<score/>"
|
|
||||||
|
|
||||||
def test_raises_when_no_xml(self, tmp_path):
|
|
||||||
mxl = tmp_path / "test.mxl"
|
|
||||||
output = tmp_path / "output.mxml"
|
|
||||||
with zipfile.ZipFile(mxl, "w") as zf:
|
|
||||||
zf.writestr("readme.txt", "no xml here")
|
|
||||||
|
|
||||||
with pytest.raises(FileNotFoundError, match="No MusicXML"):
|
|
||||||
extract_mxml_from_mxl(mxl, output)
|
|
||||||
|
|
||||||
|
|
||||||
class TestFindAudiveris:
|
|
||||||
def test_raises_when_not_found(self):
|
|
||||||
with (
|
|
||||||
patch("python.sheet_music_ocr.audiveris.shutil.which", return_value=None),
|
|
||||||
pytest.raises(AudiverisError, match="not found"),
|
|
||||||
):
|
|
||||||
find_audiveris()
|
|
||||||
|
|
||||||
def test_returns_path_when_found(self):
|
|
||||||
with patch("python.sheet_music_ocr.audiveris.shutil.which", return_value="/usr/bin/audiveris"):
|
|
||||||
assert find_audiveris() == "/usr/bin/audiveris"
|
|
||||||
|
|
||||||
|
|
||||||
class TestRunAudiveris:
|
|
||||||
def test_raises_on_nonzero_exit(self, tmp_path):
|
|
||||||
with (
|
|
||||||
patch("python.sheet_music_ocr.audiveris.find_audiveris", return_value="audiveris"),
|
|
||||||
patch("python.sheet_music_ocr.audiveris.subprocess.run") as mock_run,
|
|
||||||
):
|
|
||||||
mock_run.return_value.returncode = 1
|
|
||||||
mock_run.return_value.stderr = "something went wrong"
|
|
||||||
|
|
||||||
with pytest.raises(AudiverisError, match="failed"):
|
|
||||||
run_audiveris(tmp_path / "input.pdf", tmp_path / "output")
|
|
||||||
|
|
||||||
def test_raises_when_no_mxl_produced(self, tmp_path):
|
|
||||||
output_dir = tmp_path / "output"
|
|
||||||
output_dir.mkdir()
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch("python.sheet_music_ocr.audiveris.find_audiveris", return_value="audiveris"),
|
|
||||||
patch("python.sheet_music_ocr.audiveris.subprocess.run") as mock_run,
|
|
||||||
):
|
|
||||||
mock_run.return_value.returncode = 0
|
|
||||||
|
|
||||||
with pytest.raises(AudiverisError, match=r"no \.mxl output"):
|
|
||||||
run_audiveris(tmp_path / "input.pdf", output_dir)
|
|
||||||
|
|
||||||
def test_returns_mxl_path(self, tmp_path):
|
|
||||||
output_dir = tmp_path / "output"
|
|
||||||
output_dir.mkdir()
|
|
||||||
mxl = output_dir / "score.mxl"
|
|
||||||
make_mxl(mxl)
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch("python.sheet_music_ocr.audiveris.find_audiveris", return_value="audiveris"),
|
|
||||||
patch("python.sheet_music_ocr.audiveris.subprocess.run") as mock_run,
|
|
||||||
):
|
|
||||||
mock_run.return_value.returncode = 0
|
|
||||||
|
|
||||||
result = run_audiveris(tmp_path / "input.pdf", output_dir)
|
|
||||||
assert result == mxl
|
|
||||||
|
|
||||||
|
|
||||||
class TestCli:
|
|
||||||
def test_missing_input_file(self, tmp_path):
|
|
||||||
result = runner.invoke(app, ["convert", str(tmp_path / "nonexistent.pdf")])
|
|
||||||
assert result.exit_code == 1
|
|
||||||
assert "does not exist" in result.output
|
|
||||||
|
|
||||||
def test_unsupported_format(self, tmp_path):
|
|
||||||
bad_file = tmp_path / "music.bmp"
|
|
||||||
bad_file.touch()
|
|
||||||
result = runner.invoke(app, ["convert", str(bad_file)])
|
|
||||||
assert result.exit_code == 1
|
|
||||||
assert "Unsupported format" in result.output
|
|
||||||
|
|
||||||
def test_supported_extensions_complete(self):
|
|
||||||
assert ".pdf" in SUPPORTED_EXTENSIONS
|
|
||||||
assert ".png" in SUPPORTED_EXTENSIONS
|
|
||||||
assert ".jpg" in SUPPORTED_EXTENSIONS
|
|
||||||
assert ".jpeg" in SUPPORTED_EXTENSIONS
|
|
||||||
assert ".tiff" in SUPPORTED_EXTENSIONS
|
|
||||||
|
|
||||||
def test_successful_conversion(self, tmp_path):
|
|
||||||
input_file = tmp_path / "score.pdf"
|
|
||||||
input_file.touch()
|
|
||||||
output_file = tmp_path / "score.mxml"
|
|
||||||
|
|
||||||
mxl_path = tmp_path / "tmp_mxl" / "score.mxl"
|
|
||||||
mxl_path.parent.mkdir()
|
|
||||||
make_mxl(mxl_path, b"<score-partwise/>")
|
|
||||||
|
|
||||||
with patch("python.sheet_music_ocr.main.run_audiveris", return_value=mxl_path):
|
|
||||||
result = runner.invoke(app, ["convert", str(input_file), "-o", str(output_file)])
|
|
||||||
|
|
||||||
assert result.exit_code == 0
|
|
||||||
assert output_file.exists()
|
|
||||||
assert "Written" in result.output
|
|
||||||
|
|
||||||
def test_default_output_path(self, tmp_path):
|
|
||||||
input_file = tmp_path / "score.png"
|
|
||||||
input_file.touch()
|
|
||||||
|
|
||||||
mxl_path = tmp_path / "tmp_mxl" / "score.mxl"
|
|
||||||
mxl_path.parent.mkdir()
|
|
||||||
make_mxl(mxl_path)
|
|
||||||
|
|
||||||
with patch("python.sheet_music_ocr.main.run_audiveris", return_value=mxl_path):
|
|
||||||
result = runner.invoke(app, ["convert", str(input_file)])
|
|
||||||
|
|
||||||
assert result.exit_code == 0
|
|
||||||
assert (tmp_path / "score.mxml").exists()
|
|
||||||
|
|
||||||
|
|
||||||
class TestReviewMxml:
|
|
||||||
def test_raises_when_no_api_key(self, tmp_path, monkeypatch):
|
|
||||||
mxml = tmp_path / "score.mxml"
|
|
||||||
mxml.write_text("<score-partwise/>")
|
|
||||||
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
|
|
||||||
|
|
||||||
with pytest.raises(ReviewError, match="ANTHROPIC_API_KEY"):
|
|
||||||
review_mxml(mxml, LLMProvider.CLAUDE)
|
|
||||||
|
|
||||||
def test_raises_when_no_openai_key(self, tmp_path, monkeypatch):
|
|
||||||
mxml = tmp_path / "score.mxml"
|
|
||||||
mxml.write_text("<score-partwise/>")
|
|
||||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
|
||||||
|
|
||||||
with pytest.raises(ReviewError, match="OPENAI_API_KEY"):
|
|
||||||
review_mxml(mxml, LLMProvider.OPENAI)
|
|
||||||
|
|
||||||
def test_claude_success(self, tmp_path, monkeypatch):
|
|
||||||
mxml = tmp_path / "score.mxml"
|
|
||||||
mxml.write_text("<score-partwise/>")
|
|
||||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key")
|
|
||||||
|
|
||||||
corrected = "<score-partwise><part/></score-partwise>"
|
|
||||||
mock_response = httpx.Response(
|
|
||||||
200,
|
|
||||||
json={"content": [{"text": corrected}]},
|
|
||||||
request=httpx.Request("POST", "https://api.anthropic.com/v1/messages"),
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch("python.sheet_music_ocr.review.httpx.post", return_value=mock_response):
|
|
||||||
result = review_mxml(mxml, LLMProvider.CLAUDE)
|
|
||||||
|
|
||||||
assert result == corrected
|
|
||||||
|
|
||||||
def test_openai_success(self, tmp_path, monkeypatch):
|
|
||||||
mxml = tmp_path / "score.mxml"
|
|
||||||
mxml.write_text("<score-partwise/>")
|
|
||||||
monkeypatch.setenv("OPENAI_API_KEY", "test-key")
|
|
||||||
|
|
||||||
corrected = "<score-partwise><part/></score-partwise>"
|
|
||||||
mock_response = httpx.Response(
|
|
||||||
200,
|
|
||||||
json={"choices": [{"message": {"content": corrected}}]},
|
|
||||||
request=httpx.Request("POST", "https://api.openai.com/v1/chat/completions"),
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch("python.sheet_music_ocr.review.httpx.post", return_value=mock_response):
|
|
||||||
result = review_mxml(mxml, LLMProvider.OPENAI)
|
|
||||||
|
|
||||||
assert result == corrected
|
|
||||||
|
|
||||||
def test_claude_api_error(self, tmp_path, monkeypatch):
|
|
||||||
mxml = tmp_path / "score.mxml"
|
|
||||||
mxml.write_text("<score-partwise/>")
|
|
||||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key")
|
|
||||||
|
|
||||||
mock_response = httpx.Response(
|
|
||||||
500,
|
|
||||||
text="Internal Server Error",
|
|
||||||
request=httpx.Request("POST", "https://api.anthropic.com/v1/messages"),
|
|
||||||
)
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch("python.sheet_music_ocr.review.httpx.post", return_value=mock_response),
|
|
||||||
pytest.raises(ReviewError, match="Claude API error"),
|
|
||||||
):
|
|
||||||
review_mxml(mxml, LLMProvider.CLAUDE)
|
|
||||||
|
|
||||||
def test_openai_api_error(self, tmp_path, monkeypatch):
|
|
||||||
mxml = tmp_path / "score.mxml"
|
|
||||||
mxml.write_text("<score-partwise/>")
|
|
||||||
monkeypatch.setenv("OPENAI_API_KEY", "test-key")
|
|
||||||
|
|
||||||
mock_response = httpx.Response(
|
|
||||||
429,
|
|
||||||
text="Rate limited",
|
|
||||||
request=httpx.Request("POST", "https://api.openai.com/v1/chat/completions"),
|
|
||||||
)
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch("python.sheet_music_ocr.review.httpx.post", return_value=mock_response),
|
|
||||||
pytest.raises(ReviewError, match="OpenAI API error"),
|
|
||||||
):
|
|
||||||
review_mxml(mxml, LLMProvider.OPENAI)
|
|
||||||
|
|
||||||
|
|
||||||
class TestReviewCli:
|
|
||||||
def test_missing_input_file(self, tmp_path):
|
|
||||||
result = runner.invoke(app, ["review", str(tmp_path / "nonexistent.mxml")])
|
|
||||||
assert result.exit_code == 1
|
|
||||||
assert "does not exist" in result.output
|
|
||||||
|
|
||||||
def test_wrong_extension(self, tmp_path):
|
|
||||||
bad_file = tmp_path / "score.pdf"
|
|
||||||
bad_file.touch()
|
|
||||||
result = runner.invoke(app, ["review", str(bad_file)])
|
|
||||||
assert result.exit_code == 1
|
|
||||||
assert ".mxml" in result.output
|
|
||||||
|
|
||||||
def test_successful_review(self, tmp_path, monkeypatch):
|
|
||||||
mxml = tmp_path / "score.mxml"
|
|
||||||
mxml.write_text("<score-partwise/>")
|
|
||||||
output = tmp_path / "corrected.mxml"
|
|
||||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key")
|
|
||||||
|
|
||||||
corrected = "<score-partwise><part/></score-partwise>"
|
|
||||||
mock_response = httpx.Response(
|
|
||||||
200,
|
|
||||||
json={"content": [{"text": corrected}]},
|
|
||||||
request=httpx.Request("POST", "https://api.anthropic.com/v1/messages"),
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch("python.sheet_music_ocr.review.httpx.post", return_value=mock_response):
|
|
||||||
result = runner.invoke(app, ["review", str(mxml), "-o", str(output)])
|
|
||||||
|
|
||||||
assert result.exit_code == 0
|
|
||||||
assert "Reviewed" in result.output
|
|
||||||
assert output.read_text() == corrected
|
|
||||||
|
|
||||||
def test_overwrites_input_by_default(self, tmp_path, monkeypatch):
|
|
||||||
mxml = tmp_path / "score.mxml"
|
|
||||||
mxml.write_text("<score-partwise/>")
|
|
||||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key")
|
|
||||||
|
|
||||||
corrected = "<score-partwise><part/></score-partwise>"
|
|
||||||
mock_response = httpx.Response(
|
|
||||||
200,
|
|
||||||
json={"content": [{"text": corrected}]},
|
|
||||||
request=httpx.Request("POST", "https://api.anthropic.com/v1/messages"),
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch("python.sheet_music_ocr.review.httpx.post", return_value=mock_response):
|
|
||||||
result = runner.invoke(app, ["review", str(mxml)])
|
|
||||||
|
|
||||||
assert result.exit_code == 0
|
|
||||||
assert mxml.read_text() == corrected
|
|
||||||
@@ -1,285 +0,0 @@
|
|||||||
"""Tests for the Signal command and control bot."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
from datetime import timedelta
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from sqlalchemy import create_engine
|
|
||||||
|
|
||||||
from python.orm.richie.base import RichieBase
|
|
||||||
from python.signal_bot.commands.inventory import (
|
|
||||||
_format_summary,
|
|
||||||
parse_llm_response,
|
|
||||||
)
|
|
||||||
from python.signal_bot.device_registry import _BLOCKED_TTL, _DEFAULT_TTL, DeviceRegistry, _CacheEntry
|
|
||||||
from python.signal_bot.llm_client import LLMClient
|
|
||||||
from python.signal_bot.main import dispatch
|
|
||||||
from python.signal_bot.models import (
|
|
||||||
BotConfig,
|
|
||||||
InventoryItem,
|
|
||||||
SignalMessage,
|
|
||||||
TrustLevel,
|
|
||||||
)
|
|
||||||
from python.signal_bot.signal_client import SignalClient
|
|
||||||
|
|
||||||
|
|
||||||
class TestModels:
|
|
||||||
def test_trust_level_values(self):
|
|
||||||
assert TrustLevel.VERIFIED == "verified"
|
|
||||||
assert TrustLevel.UNVERIFIED == "unverified"
|
|
||||||
assert TrustLevel.BLOCKED == "blocked"
|
|
||||||
|
|
||||||
def test_signal_message_defaults(self):
|
|
||||||
msg = SignalMessage(source="+1234", timestamp=0)
|
|
||||||
assert msg.message == ""
|
|
||||||
assert msg.attachments == []
|
|
||||||
assert msg.group_id is None
|
|
||||||
|
|
||||||
def test_inventory_item_defaults(self):
|
|
||||||
item = InventoryItem(name="wrench")
|
|
||||||
assert item.quantity == 1
|
|
||||||
assert item.unit == "each"
|
|
||||||
assert item.category == ""
|
|
||||||
|
|
||||||
|
|
||||||
class TestInventoryParsing:
    """Tests for LLM inventory-response parsing and summary formatting."""

    def test_parse_llm_response_basic(self):
        """A bare JSON array parses into typed inventory items."""
        payload = '[{"name": "water", "quantity": 6, "unit": "gallon", "category": "supplies", "notes": ""}]'
        parsed = parse_llm_response(payload)
        assert len(parsed) == 1
        water = parsed[0]
        assert water.name == "water"
        assert water.quantity == 6
        assert water.unit == "gallon"

    def test_parse_llm_response_with_code_fence(self):
        """Markdown code fences around the JSON array are stripped before parsing."""
        payload = '```json\n[{"name": "tape", "quantity": 1, "unit": "each", "category": "tools", "notes": ""}]\n```'
        parsed = parse_llm_response(payload)
        assert len(parsed) == 1
        assert parsed[0].name == "tape"

    def test_parse_llm_response_invalid_json(self):
        """Garbage input surfaces as a JSONDecodeError instead of being swallowed."""
        with pytest.raises(json.JSONDecodeError):
            parse_llm_response("not json at all")

    def test_format_summary(self):
        """The summary text mentions each item's name, quantity, and unit."""
        inventory = [InventoryItem(name="water", quantity=6, unit="gallon", category="supplies")]
        text = _format_summary(inventory)
        for fragment in ("water", "x6", "gallon"):
            assert fragment in text
|
|
||||||
|
|
||||||
|
|
||||||
class TestDeviceRegistry:
    """Behavioral tests for DeviceRegistry trust tracking and persistence."""

    @pytest.fixture
    def signal_mock(self):
        """SignalClient stand-in; registry tests never touch the network."""
        return MagicMock(spec=SignalClient)

    @pytest.fixture
    def engine(self):
        """Fresh in-memory SQLite engine with the project schema created."""
        in_memory = create_engine("sqlite://")
        RichieBase.metadata.create_all(in_memory)
        return in_memory

    @pytest.fixture
    def registry(self, signal_mock, engine):
        """Registry under test, wired to the mock client and throwaway DB."""
        return DeviceRegistry(signal_mock, engine)

    def test_new_device_is_unverified(self, registry):
        """First contact from an unknown number starts out unverified."""
        registry.record_contact("+1234", "abc123")
        assert not registry.is_verified("+1234")

    def test_verify_device(self, registry):
        """Verifying a known device succeeds and flips its trust state."""
        registry.record_contact("+1234", "abc123")
        assert registry.verify("+1234")
        assert registry.is_verified("+1234")

    def test_verify_unknown_device(self, registry):
        """Verifying a number that was never recorded is refused."""
        assert not registry.verify("+9999")

    def test_block_device(self, registry):
        """A blocked device is no longer considered verified."""
        registry.record_contact("+1234", "abc123")
        assert registry.block("+1234")
        assert not registry.is_verified("+1234")

    def test_safety_number_change_resets_trust(self, registry):
        """A changed safety number downgrades a verified device."""
        registry.record_contact("+1234", "abc123")
        registry.verify("+1234")
        assert registry.is_verified("+1234")
        # Same number arriving with a new safety number: trust must not survive.
        registry.record_contact("+1234", "different_safety_number")
        assert not registry.is_verified("+1234")

    def test_persistence(self, signal_mock, engine):
        """Trust state written by one registry instance is visible to another."""
        writer = DeviceRegistry(signal_mock, engine)
        writer.record_contact("+1234", "abc123")
        writer.verify("+1234")

        reader = DeviceRegistry(signal_mock, engine)
        assert reader.is_verified("+1234")

    def test_list_devices(self, registry):
        """Every recorded contact shows up in the device listing."""
        for number, safety in (("+1234", "abc"), ("+5678", "def")):
            registry.record_contact(number, safety)
        assert len(registry.list_devices()) == 2
|
|
||||||
|
|
||||||
|
|
||||||
class TestContactCache:
    """Tests for the in-process contact cache that fronts the registry DB."""

    @pytest.fixture
    def signal_mock(self):
        """SignalClient double; cache tests never perform I/O."""
        return MagicMock(spec=SignalClient)

    @pytest.fixture
    def engine(self):
        """Isolated in-memory database with all tables created."""
        in_memory = create_engine("sqlite://")
        RichieBase.metadata.create_all(in_memory)
        return in_memory

    @pytest.fixture
    def registry(self, signal_mock, engine):
        """DeviceRegistry whose private cache the tests inspect directly."""
        return DeviceRegistry(signal_mock, engine)

    def test_second_call_uses_cache(self, registry):
        """Repeating an identical contact must be served from the cache."""
        registry.record_contact("+1234", "abc")
        assert "+1234" in registry._contact_cache

        with patch.object(registry, "engine") as mock_engine:
            registry.record_contact("+1234", "abc")
            mock_engine.assert_not_called()

    def test_unverified_gets_default_ttl(self, registry):
        """Fresh unverified entries expire after the default TTL."""
        registry.record_contact("+1234", "abc")
        from python.common import utcnow

        cached = registry._contact_cache["+1234"]
        # Allow a couple of seconds of slack between record and check.
        assert abs((cached.expires - (utcnow() + _DEFAULT_TTL)).total_seconds()) < 2
        assert cached.trust_level == TrustLevel.UNVERIFIED
        assert cached.has_safety_number is True

    def test_blocked_gets_blocked_ttl(self, registry):
        """Blocking a contact moves its cache entry onto the blocked TTL."""
        registry.record_contact("+1234", "abc")
        registry.block("+1234")
        from python.common import utcnow

        cached = registry._contact_cache["+1234"]
        assert abs((cached.expires - (utcnow() + _BLOCKED_TTL)).total_seconds()) < 2
        assert cached.trust_level == TrustLevel.BLOCKED

    def test_verify_updates_cache(self, registry):
        """verify() mutates the cached trust level, not just the DB row."""
        registry.record_contact("+1234", "abc")
        registry.verify("+1234")
        assert registry._contact_cache["+1234"].trust_level == TrustLevel.VERIFIED

    def test_block_updates_cache(self, registry):
        """block() mutates the cached trust level, not just the DB row."""
        registry.record_contact("+1234", "abc")
        registry.block("+1234")
        assert registry._contact_cache["+1234"].trust_level == TrustLevel.BLOCKED

    def test_unverify_updates_cache(self, registry):
        """unverify() walks a verified cache entry back to unverified."""
        registry.record_contact("+1234", "abc")
        registry.verify("+1234")
        registry.unverify("+1234")
        assert registry._contact_cache["+1234"].trust_level == TrustLevel.UNVERIFIED

    def test_is_verified_uses_cache(self, registry):
        """is_verified() on a warm cache never opens a DB session."""
        registry.record_contact("+1234", "abc")
        registry.verify("+1234")
        with patch.object(registry, "engine") as mock_engine:
            assert registry.is_verified("+1234") is True
            mock_engine.assert_not_called()

    def test_has_safety_number_uses_cache(self, registry):
        """has_safety_number() on a warm cache never opens a DB session."""
        registry.record_contact("+1234", "abc")
        with patch.object(registry, "engine") as mock_engine:
            assert registry.has_safety_number("+1234") is True
            mock_engine.assert_not_called()

    def test_no_safety_number_cached(self, registry):
        """A missing safety number is cached as False, still avoiding the DB."""
        registry.record_contact("+1234", None)
        with patch.object(registry, "engine") as mock_engine:
            assert registry.has_safety_number("+1234") is False
            mock_engine.assert_not_called()

    def test_expired_cache_hits_db(self, registry):
        """An expired entry forces record_contact back through the database."""
        registry.record_contact("+1234", "abc")
        stale = registry._contact_cache["+1234"]
        # Rewind the expiry so the next lookup sees a stale entry.
        registry._contact_cache["+1234"] = _CacheEntry(
            expires=stale.expires - timedelta(minutes=10),
            trust_level=stale.trust_level,
            has_safety_number=stale.has_safety_number,
            safety_number=stale.safety_number,
        )

        with patch("python.signal_bot.device_registry.Session") as mock_session_cls:
            db_session = MagicMock()
            mock_session_cls.return_value.__enter__ = MagicMock(return_value=db_session)
            mock_session_cls.return_value.__exit__ = MagicMock(return_value=False)
            device_row = MagicMock()
            device_row.trust_level = TrustLevel.UNVERIFIED
            db_session.execute.return_value.scalar_one_or_none.return_value = device_row
            registry.record_contact("+1234", "abc")
            db_session.execute.assert_called_once()
|
|
||||||
|
|
||||||
|
|
||||||
class TestDispatch:
    """Tests for top-level message dispatch and command routing."""

    @pytest.fixture
    def signal_mock(self):
        """SignalClient double used to observe outgoing replies."""
        return MagicMock(spec=SignalClient)

    @pytest.fixture
    def llm_mock(self):
        """LLMClient double; dispatch should not need a real model."""
        return MagicMock(spec=LLMClient)

    @pytest.fixture
    def registry_mock(self):
        """DeviceRegistry double that trusts the sender by default."""
        device_registry = MagicMock(spec=DeviceRegistry)
        device_registry.is_verified.return_value = True
        device_registry.has_safety_number.return_value = True
        return device_registry

    @pytest.fixture
    def config(self):
        """Minimal BotConfig backed by an in-memory database engine."""
        return BotConfig(
            signal_api_url="http://localhost:8080",
            phone_number="+1234567890",
            inventory_api_url="http://localhost:9090",
            engine=create_engine("sqlite://"),
        )

    def test_unverified_device_ignored(self, signal_mock, llm_mock, registry_mock, config):
        """Messages from unverified senders never produce a reply."""
        registry_mock.is_verified.return_value = False
        incoming = SignalMessage(source="+1234", timestamp=0, message="help")
        dispatch(incoming, signal_mock, llm_mock, registry_mock, config)
        signal_mock.reply.assert_not_called()

    def test_help_command(self, signal_mock, llm_mock, registry_mock, config):
        """'help' replies exactly once with the command listing."""
        incoming = SignalMessage(source="+1234", timestamp=0, message="help")
        dispatch(incoming, signal_mock, llm_mock, registry_mock, config)
        signal_mock.reply.assert_called_once()
        reply_text = signal_mock.reply.call_args[0][1]
        assert "Available commands" in reply_text

    def test_unknown_command_ignored(self, signal_mock, llm_mock, registry_mock, config):
        """Unrecognized single-word commands are silently dropped."""
        incoming = SignalMessage(source="+1234", timestamp=0, message="foobar")
        dispatch(incoming, signal_mock, llm_mock, registry_mock, config)
        signal_mock.reply.assert_not_called()

    def test_non_command_message_ignored(self, signal_mock, llm_mock, registry_mock, config):
        """Ordinary chatter that is not a command gets no reply."""
        incoming = SignalMessage(source="+1234", timestamp=0, message="hello there")
        dispatch(incoming, signal_mock, llm_mock, registry_mock, config)
        signal_mock.reply.assert_not_called()

    def test_status_command(self, signal_mock, llm_mock, registry_mock, config):
        """'status' replies once and reports the bot as online."""
        llm_mock.list_models.return_value = ["model1", "model2"]
        llm_mock.model = "test:7b"
        registry_mock.list_devices.return_value = []
        incoming = SignalMessage(source="+1234", timestamp=0, message="status")
        dispatch(incoming, signal_mock, llm_mock, registry_mock, config)
        signal_mock.reply.assert_called_once()
        assert "Bot online" in signal_mock.reply.call_args[0][1]
|
|
||||||
@@ -12,7 +12,6 @@
|
|||||||
obs-studio
|
obs-studio
|
||||||
obsidian
|
obsidian
|
||||||
vlc
|
vlc
|
||||||
qalculate-gtk
|
|
||||||
# graphics tools
|
# graphics tools
|
||||||
gimp3
|
gimp3
|
||||||
xcursorgen
|
xcursorgen
|
||||||
|
|||||||
@@ -16,7 +16,6 @@
|
|||||||
obsidian
|
obsidian
|
||||||
prismlauncher
|
prismlauncher
|
||||||
prusa-slicer
|
prusa-slicer
|
||||||
qalculate-gtk
|
|
||||||
vlc
|
vlc
|
||||||
# browser
|
# browser
|
||||||
chromium
|
chromium
|
||||||
|
|||||||
Reference in New Issue
Block a user