From 7ad321e5e202377429581cf8ad60009db4a2f130 Mon Sep 17 00:00:00 2001 From: Richie Cahill Date: Mon, 9 Mar 2026 15:07:48 -0400 Subject: [PATCH] moved device registry to postgresql --- python/alembic/env.py | 12 ++ ...device_for_deviceregistry__4c410c16e39c.py | 58 +++++++ python/database_cli.py | 5 +- python/orm/richie/__init__.py | 2 + python/orm/richie/signal_device.py | 26 +++ python/signal_bot/device_registry.py | 148 +++++++++--------- python/signal_bot/main.py | 5 +- 7 files changed, 177 insertions(+), 79 deletions(-) create mode 100644 python/alembic/richie/versions/2026_03_09-adding_signaldevice_for_deviceregistry__4c410c16e39c.py create mode 100644 python/orm/richie/signal_device.py diff --git a/python/alembic/env.py b/python/alembic/env.py index 667e04b..34a2a07 100644 --- a/python/alembic/env.py +++ b/python/alembic/env.py @@ -45,6 +45,18 @@ def dynamic_schema(filename: str, _options: dict[Any, Any]) -> None: Path(filename).write_text(dynamic_schema_file) +@write_hooks.register("import_postgresql") +def import_postgresql(filename: str, _options: dict[Any, Any]) -> None: + """Add postgresql dialect import when postgresql types are used.""" + content = Path(filename).read_text() + if "postgresql." in content and "from sqlalchemy.dialects import postgresql" not in content: + content = content.replace( + "import sqlalchemy as sa\n", + "import sqlalchemy as sa\nfrom sqlalchemy.dialects import postgresql\n", + ) + Path(filename).write_text(content) + + @write_hooks.register("ruff") def ruff_check_and_format(filename: str, _options: dict[Any, Any]) -> None: """Docstring for ruff_check_and_format.""" diff --git a/python/alembic/richie/versions/2026_03_09-adding_signaldevice_for_deviceregistry__4c410c16e39c.py b/python/alembic/richie/versions/2026_03_09-adding_signaldevice_for_deviceregistry__4c410c16e39c.py new file mode 100644 index 0000000..3583e38 --- /dev/null +++ b/python/alembic/richie/versions/2026_03_09-adding_signaldevice_for_deviceregistry__4c410c16e39c.py @@ -0,0 +1,58 @@ +"""adding SignalDevice for DeviceRegistry for signal bot. + +Revision ID: 4c410c16e39c +Revises: 3f71565e38de +Create Date: 2026-03-09 14:51:24.228976 + +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +from python.orm import RichieBase + +if TYPE_CHECKING: + from collections.abc import Sequence + +# revision identifiers, used by Alembic. +revision: str = "4c410c16e39c" +down_revision: str | None = "3f71565e38de" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + +schema = RichieBase.schema_name + + +def upgrade() -> None: + """Upgrade.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "signal_device", + sa.Column("phone_number", sa.String(length=50), nullable=False), + sa.Column("safety_number", sa.String(), nullable=False), + sa.Column( + "trust_level", + postgresql.ENUM("VERIFIED", "UNVERIFIED", "BLOCKED", name="trust_level", schema=schema), + nullable=False, + ), + sa.Column("last_seen", sa.DateTime(timezone=True), nullable=False), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.PrimaryKeyConstraint("id", name=op.f("pk_signal_device")), + sa.UniqueConstraint("phone_number", name=op.f("uq_signal_device_phone_number")), + schema=schema, + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("signal_device", schema=schema) + # ### end Alembic commands ### diff --git a/python/database_cli.py b/python/database_cli.py index e3603c3..e28c19b 100644 --- a/python/database_cli.py +++ b/python/database_cli.py @@ -58,8 +58,9 @@ class DatabaseConfig: cfg.set_main_option("version_path_separator", "os") cfg.set_main_option("version_locations", self.version_location) cfg.set_main_option("revision_environment", "true") - cfg.set_section_option("post_write_hooks", "hooks", "dynamic_schema,ruff") + cfg.set_section_option("post_write_hooks", "hooks", "dynamic_schema,import_postgresql,ruff") cfg.set_section_option("post_write_hooks", "dynamic_schema.type", "dynamic_schema") + cfg.set_section_option("post_write_hooks", "import_postgresql.type", "import_postgresql") cfg.set_section_option("post_write_hooks", "ruff.type", "ruff") cfg.attributes["base"] = self.get_base() cfg.attributes["env_prefix"] = self.env_prefix @@ -73,7 +74,7 @@ DATABASES: dict[str, DatabaseConfig] = { version_location="python/alembic/richie/versions", base_module="python.orm.richie.base", base_class_name="RichieBase", - models_module="python.orm.richie.contact", + models_module="python.orm.richie", ), "van_inventory": DatabaseConfig( env_prefix="VAN_INVENTORY", diff --git a/python/orm/richie/__init__.py b/python/orm/richie/__init__.py index 762387d..0ed7d9a 100644 --- a/python/orm/richie/__init__.py +++ b/python/orm/richie/__init__.py @@ -11,6 +11,7 @@ from python.orm.richie.contact import ( Need, RelationshipType, ) +from python.orm.richie.signal_device import SignalDevice __all__ = [ "Bill", @@ -21,6 +22,7 @@ __all__ = [ "Need", "RelationshipType", "RichieBase", + "SignalDevice", "TableBase", "Vote", "VoteRecord", diff --git a/python/orm/richie/signal_device.py b/python/orm/richie/signal_device.py new file mode 100644 index 0000000..a2650df --- /dev/null +++ b/python/orm/richie/signal_device.py @@ -0,0 +1,26 @@ +"""Signal bot device registry models.""" + +from __future__ import annotations + +from datetime import datetime + +from sqlalchemy import DateTime, String +from sqlalchemy.dialects.postgresql import ENUM +from sqlalchemy.orm import Mapped, mapped_column + +from python.orm.richie.base import TableBase +from python.signal_bot.models import TrustLevel + + +class SignalDevice(TableBase): + """A Signal device tracked by phone number and safety number.""" + + __tablename__ = "signal_device" + + phone_number: Mapped[str] = mapped_column(String(50), unique=True) + safety_number: Mapped[str] + trust_level: Mapped[TrustLevel] = mapped_column( + ENUM(TrustLevel, name="trust_level", create_type=True, schema="main"), + default=TrustLevel.UNVERIFIED, + ) + last_seen: Mapped[datetime] = mapped_column(DateTime(timezone=True)) diff --git a/python/signal_bot/device_registry.py b/python/signal_bot/device_registry.py index 3f5f03f..4fafe4d 100644 --- a/python/signal_bot/device_registry.py +++ b/python/signal_bot/device_registry.py @@ -2,15 +2,18 @@ from __future__ import annotations -import json import logging -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING + +from sqlalchemy import select +from sqlalchemy.orm import Session from python.common import utcnow -from python.signal_bot.models import Device, TrustLevel +from python.orm.richie.signal_device import SignalDevice +from python.signal_bot.models import TrustLevel if TYPE_CHECKING: - from pathlib import Path + from sqlalchemy.engine import Engine from python.signal_bot.signal_client import SignalClient @@ -25,90 +28,77 @@ class DeviceRegistry: signal-cli to trust the identity. Only VERIFIED devices may execute commands. - - Args: - signal_client: The Signal API client (used to sync identities). - registry_path: Path to the JSON file that persists device state. """ - def __init__(self, signal_client: SignalClient, registry_path: Path) -> None: + def __init__(self, signal_client: SignalClient, engine: Engine) -> None: self.signal_client = signal_client - self.registry_path = registry_path - self._devices: dict[str, Device] = {} - self._load() + self.engine = engine def is_verified(self, phone_number: str) -> bool: """Check if a phone number is verified.""" - device = self._devices.get(phone_number) + device = self._get(phone_number) return device is not None and device.trust_level == TrustLevel.VERIFIED - def is_blocked(self, phone_number: str) -> bool: - """Check if a phone number is blocked.""" - device = self._devices.get(phone_number) - return device is not None and device.trust_level == TrustLevel.BLOCKED - - def record_contact(self, phone_number: str, safety_number: str) -> Device: + def record_contact(self, phone_number: str, safety_number: str) -> SignalDevice: """Record seeing a device. Creates entry if new, updates last_seen.""" now = utcnow() - if phone_number in self._devices: - device = self._devices[phone_number] - if device.safety_number != safety_number: - logger.warning(f"Safety number changed for {phone_number}, resetting to UNVERIFIED") - device.safety_number = safety_number - device.trust_level = TrustLevel.UNVERIFIED - device.last_seen = now - else: - device = Device( - phone_number=phone_number, - safety_number=safety_number, - trust_level=TrustLevel.UNVERIFIED, - first_seen=now, - last_seen=now, - ) - self._devices[phone_number] = device - logger.info(f"New device registered: {phone_number}") + with Session(self.engine) as session: + device = session.execute( + select(SignalDevice).where(SignalDevice.phone_number == phone_number) + ).scalar_one_or_none() - self._save() - return device + if device: + if device.safety_number != safety_number: + logger.warning(f"Safety number changed for {phone_number}, resetting to UNVERIFIED") + device.safety_number = safety_number + device.trust_level = TrustLevel.UNVERIFIED + device.last_seen = now + else: + device = SignalDevice( + phone_number=phone_number, + safety_number=safety_number, + trust_level=TrustLevel.UNVERIFIED, + last_seen=now, + ) + session.add(device) + logger.info(f"New device registered: {phone_number}") + + session.commit() + session.refresh(device) + return device def verify(self, phone_number: str) -> bool: """Mark a device as verified. Called by admin over SSH. Returns True if the device was found and verified. """ - device = self._devices.get(phone_number) - if not device: - logger.warning(f"Cannot verify unknown device: {phone_number}") - return False + with Session(self.engine) as session: + device = session.execute( + select(SignalDevice).where(SignalDevice.phone_number == phone_number) + ).scalar_one_or_none() - device.trust_level = TrustLevel.VERIFIED - self.signal_client.trust_identity(phone_number, trust_all_known_keys=True) - self._save() - logger.info(f"Device verified: {phone_number}") - return True + if not device: + logger.warning(f"Cannot verify unknown device: {phone_number}") + return False + + device.trust_level = TrustLevel.VERIFIED + self.signal_client.trust_identity(phone_number, trust_all_known_keys=True) + session.commit() + logger.info(f"Device verified: {phone_number}") + return True def block(self, phone_number: str) -> bool: """Block a device.""" - device = self._devices.get(phone_number) - if not device: - return False - device.trust_level = TrustLevel.BLOCKED - self._save() - logger.info(f"Device blocked: {phone_number}") - return True + return self._set_trust(phone_number, TrustLevel.BLOCKED, "Device blocked") def unverify(self, phone_number: str) -> bool: """Reset a device to unverified.""" - device = self._devices.get(phone_number) - if not device: - return False - device.trust_level = TrustLevel.UNVERIFIED - self._save() - return True + return self._set_trust(phone_number, TrustLevel.UNVERIFIED) - def list_devices(self) -> list[Device]: + def list_devices(self) -> list[SignalDevice]: """Return all known devices.""" - return list(self._devices.values()) + with Session(self.engine) as session: + return list(session.execute(select(SignalDevice)).scalars().all()) def sync_identities(self) -> None: """Pull identity list from signal-cli and record any new ones.""" @@ -119,17 +109,25 @@ class DeviceRegistry: if number: self.record_contact(number, safety) - def _load(self) -> None: - """Load registry from disk.""" - if not self.registry_path.exists(): - return - data: list[dict[str, Any]] = json.loads(self.registry_path.read_text()) - for entry in data: - device = Device.model_validate(entry) - self._devices[device.phone_number] = device + def _get(self, phone_number: str) -> SignalDevice | None: + """Fetch a device by phone number.""" + with Session(self.engine) as session: + return session.execute( + select(SignalDevice).where(SignalDevice.phone_number == phone_number) + ).scalar_one_or_none() - def _save(self) -> None: - """Persist registry to disk.""" - self.registry_path.parent.mkdir(parents=True, exist_ok=True) - data = [device.model_dump(mode="json") for device in self._devices.values()] - self.registry_path.write_text(json.dumps(data, indent=2) + "\n") + def _set_trust(self, phone_number: str, level: str, log_msg: str | None = None) -> bool: + """Update the trust level for a device.""" + with Session(self.engine) as session: + device = session.execute( + select(SignalDevice).where(SignalDevice.phone_number == phone_number) + ).scalar_one_or_none() + + if not device: + return False + + device.trust_level = level + session.commit() + if log_msg: + logger.info(f"{log_msg}: {phone_number}") + return True diff --git a/python/signal_bot/main.py b/python/signal_bot/main.py index c6ce0bd..6b16595 100644 --- a/python/signal_bot/main.py +++ b/python/signal_bot/main.py @@ -10,6 +10,7 @@ from typing import Annotated import typer from python.common import configure_logger +from python.orm.common import get_postgres_engine from python.signal_bot.commands.inventory import handle_inventory_update from python.signal_bot.device_registry import DeviceRegistry from python.signal_bot.llm_client import LLMClient @@ -153,7 +154,6 @@ def main( llm_model: Annotated[str, typer.Option(envvar="LLM_MODEL")] = "qwen3-vl:32b", llm_port: Annotated[int, typer.Option(envvar="LLM_PORT")] = 11434, inventory_file: Annotated[str, typer.Option(envvar="INVENTORY_FILE")] = "/var/lib/signal-bot/van_inventory.json", - registry_file: Annotated[str, typer.Option(envvar="REGISTRY_FILE")] = "/var/lib/signal-bot/device_registry.json", log_level: Annotated[str, typer.Option()] = "INFO", ) -> None: """Run the Signal command and control bot.""" @@ -165,12 +165,13 @@ def main( inventory_file=inventory_file, ) + engine = get_postgres_engine(name="RICHIE") with ( SignalClient(config.signal_api_url, config.phone_number) as signal, LLMClient(model=llm_model, host=llm_host, port=llm_port) as llm, ): - registry = DeviceRegistry(signal, Path(registry_file)) + registry = DeviceRegistry(signal, engine) run_loop(config, signal, llm, registry)