moved device registry to postgresql

This commit is contained in:
2026-03-09 15:07:48 -04:00
parent 14338e34df
commit 7ad321e5e2
7 changed files with 177 additions and 79 deletions

View File

@@ -45,6 +45,18 @@ def dynamic_schema(filename: str, _options: dict[Any, Any]) -> None:
Path(filename).write_text(dynamic_schema_file) Path(filename).write_text(dynamic_schema_file)
@write_hooks.register("import_postgresql")
def import_postgresql(filename: str, _options: dict[Any, Any]) -> None:
"""Add postgresql dialect import when postgresql types are used."""
content = Path(filename).read_text()
if "postgresql." in content and "from sqlalchemy.dialects import postgresql" not in content:
content = content.replace(
"import sqlalchemy as sa\n",
"import sqlalchemy as sa\nfrom sqlalchemy.dialects import postgresql\n",
)
Path(filename).write_text(content)
@write_hooks.register("ruff") @write_hooks.register("ruff")
def ruff_check_and_format(filename: str, _options: dict[Any, Any]) -> None: def ruff_check_and_format(filename: str, _options: dict[Any, Any]) -> None:
"""Docstring for ruff_check_and_format.""" """Docstring for ruff_check_and_format."""

View File

@@ -0,0 +1,58 @@
"""adding SignalDevice for DeviceRegistry for signal bot.
Revision ID: 4c410c16e39c
Revises: 3f71565e38de
Create Date: 2026-03-09 14:51:24.228976
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
from python.orm import RichieBase
if TYPE_CHECKING:
from collections.abc import Sequence
# revision identifiers, used by Alembic.
revision: str = "4c410c16e39c"
down_revision: str | None = "3f71565e38de"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
schema = RichieBase.schema_name
def upgrade() -> None:
"""Upgrade."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"signal_device",
sa.Column("phone_number", sa.String(length=50), nullable=False),
sa.Column("safety_number", sa.String(), nullable=False),
sa.Column(
"trust_level",
postgresql.ENUM("VERIFIED", "UNVERIFIED", "BLOCKED", name="trust_level", schema=schema),
nullable=False,
),
sa.Column("last_seen", sa.DateTime(timezone=True), nullable=False),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.PrimaryKeyConstraint("id", name=op.f("pk_signal_device")),
sa.UniqueConstraint("phone_number", name=op.f("uq_signal_device_phone_number")),
schema=schema,
)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("signal_device", schema=schema)
# ### end Alembic commands ###

View File

@@ -58,8 +58,9 @@ class DatabaseConfig:
cfg.set_main_option("version_path_separator", "os") cfg.set_main_option("version_path_separator", "os")
cfg.set_main_option("version_locations", self.version_location) cfg.set_main_option("version_locations", self.version_location)
cfg.set_main_option("revision_environment", "true") cfg.set_main_option("revision_environment", "true")
cfg.set_section_option("post_write_hooks", "hooks", "dynamic_schema,ruff") cfg.set_section_option("post_write_hooks", "hooks", "dynamic_schema,import_postgresql,ruff")
cfg.set_section_option("post_write_hooks", "dynamic_schema.type", "dynamic_schema") cfg.set_section_option("post_write_hooks", "dynamic_schema.type", "dynamic_schema")
cfg.set_section_option("post_write_hooks", "import_postgresql.type", "import_postgresql")
cfg.set_section_option("post_write_hooks", "ruff.type", "ruff") cfg.set_section_option("post_write_hooks", "ruff.type", "ruff")
cfg.attributes["base"] = self.get_base() cfg.attributes["base"] = self.get_base()
cfg.attributes["env_prefix"] = self.env_prefix cfg.attributes["env_prefix"] = self.env_prefix
@@ -73,7 +74,7 @@ DATABASES: dict[str, DatabaseConfig] = {
version_location="python/alembic/richie/versions", version_location="python/alembic/richie/versions",
base_module="python.orm.richie.base", base_module="python.orm.richie.base",
base_class_name="RichieBase", base_class_name="RichieBase",
models_module="python.orm.richie.contact", models_module="python.orm.richie",
), ),
"van_inventory": DatabaseConfig( "van_inventory": DatabaseConfig(
env_prefix="VAN_INVENTORY", env_prefix="VAN_INVENTORY",

View File

@@ -11,6 +11,7 @@ from python.orm.richie.contact import (
Need, Need,
RelationshipType, RelationshipType,
) )
from python.orm.richie.signal_device import SignalDevice
__all__ = [ __all__ = [
"Bill", "Bill",
@@ -21,6 +22,7 @@ __all__ = [
"Need", "Need",
"RelationshipType", "RelationshipType",
"RichieBase", "RichieBase",
"SignalDevice",
"TableBase", "TableBase",
"Vote", "Vote",
"VoteRecord", "VoteRecord",

View File

@@ -0,0 +1,26 @@
"""Signal bot device registry models."""
from __future__ import annotations
from datetime import datetime
from sqlalchemy import DateTime, String
from sqlalchemy.dialects.postgresql import ENUM
from sqlalchemy.orm import Mapped, mapped_column
from python.orm.richie.base import TableBase
from python.signal_bot.models import TrustLevel
class SignalDevice(TableBase):
"""A Signal device tracked by phone number and safety number."""
__tablename__ = "signal_device"
phone_number: Mapped[str] = mapped_column(String(50), unique=True)
safety_number: Mapped[str]
trust_level: Mapped[TrustLevel] = mapped_column(
ENUM(TrustLevel, name="trust_level", create_type=True, schema="main"),
default=TrustLevel.UNVERIFIED,
)
last_seen: Mapped[datetime] = mapped_column(DateTime(timezone=True))

View File

@@ -2,15 +2,18 @@
from __future__ import annotations from __future__ import annotations
import json
import logging import logging
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING
from sqlalchemy import select
from sqlalchemy.orm import Session
from python.common import utcnow from python.common import utcnow
from python.signal_bot.models import Device, TrustLevel from python.orm.richie.signal_device import SignalDevice
from python.signal_bot.models import TrustLevel
if TYPE_CHECKING: if TYPE_CHECKING:
from pathlib import Path from sqlalchemy.engine import Engine
from python.signal_bot.signal_client import SignalClient from python.signal_bot.signal_client import SignalClient
@@ -25,50 +28,43 @@ class DeviceRegistry:
signal-cli to trust the identity. signal-cli to trust the identity.
Only VERIFIED devices may execute commands. Only VERIFIED devices may execute commands.
Args:
signal_client: The Signal API client (used to sync identities).
registry_path: Path to the JSON file that persists device state.
""" """
def __init__(self, signal_client: SignalClient, registry_path: Path) -> None: def __init__(self, signal_client: SignalClient, engine: Engine) -> None:
self.signal_client = signal_client self.signal_client = signal_client
self.registry_path = registry_path self.engine = engine
self._devices: dict[str, Device] = {}
self._load()
def is_verified(self, phone_number: str) -> bool: def is_verified(self, phone_number: str) -> bool:
"""Check if a phone number is verified.""" """Check if a phone number is verified."""
device = self._devices.get(phone_number) device = self._get(phone_number)
return device is not None and device.trust_level == TrustLevel.VERIFIED return device is not None and device.trust_level == TrustLevel.VERIFIED
def is_blocked(self, phone_number: str) -> bool: def record_contact(self, phone_number: str, safety_number: str) -> SignalDevice:
"""Check if a phone number is blocked."""
device = self._devices.get(phone_number)
return device is not None and device.trust_level == TrustLevel.BLOCKED
def record_contact(self, phone_number: str, safety_number: str) -> Device:
"""Record seeing a device. Creates entry if new, updates last_seen.""" """Record seeing a device. Creates entry if new, updates last_seen."""
now = utcnow() now = utcnow()
if phone_number in self._devices: with Session(self.engine) as session:
device = self._devices[phone_number] device = session.execute(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
if device:
if device.safety_number != safety_number: if device.safety_number != safety_number:
logger.warning(f"Safety number changed for {phone_number}, resetting to UNVERIFIED") logger.warning(f"Safety number changed for {phone_number}, resetting to UNVERIFIED")
device.safety_number = safety_number device.safety_number = safety_number
device.trust_level = TrustLevel.UNVERIFIED device.trust_level = TrustLevel.UNVERIFIED
device.last_seen = now device.last_seen = now
else: else:
device = Device( device = SignalDevice(
phone_number=phone_number, phone_number=phone_number,
safety_number=safety_number, safety_number=safety_number,
trust_level=TrustLevel.UNVERIFIED, trust_level=TrustLevel.UNVERIFIED,
first_seen=now,
last_seen=now, last_seen=now,
) )
self._devices[phone_number] = device session.add(device)
logger.info(f"New device registered: {phone_number}") logger.info(f"New device registered: {phone_number}")
self._save() session.commit()
session.refresh(device)
return device return device
def verify(self, phone_number: str) -> bool: def verify(self, phone_number: str) -> bool:
@@ -76,39 +72,33 @@ class DeviceRegistry:
Returns True if the device was found and verified. Returns True if the device was found and verified.
""" """
device = self._devices.get(phone_number) with Session(self.engine) as session:
device = session.execute(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
if not device: if not device:
logger.warning(f"Cannot verify unknown device: {phone_number}") logger.warning(f"Cannot verify unknown device: {phone_number}")
return False return False
device.trust_level = TrustLevel.VERIFIED device.trust_level = TrustLevel.VERIFIED
self.signal_client.trust_identity(phone_number, trust_all_known_keys=True) self.signal_client.trust_identity(phone_number, trust_all_known_keys=True)
self._save() session.commit()
logger.info(f"Device verified: {phone_number}") logger.info(f"Device verified: {phone_number}")
return True return True
def block(self, phone_number: str) -> bool: def block(self, phone_number: str) -> bool:
"""Block a device.""" """Block a device."""
device = self._devices.get(phone_number) return self._set_trust(phone_number, TrustLevel.BLOCKED, "Device blocked")
if not device:
return False
device.trust_level = TrustLevel.BLOCKED
self._save()
logger.info(f"Device blocked: {phone_number}")
return True
def unverify(self, phone_number: str) -> bool: def unverify(self, phone_number: str) -> bool:
"""Reset a device to unverified.""" """Reset a device to unverified."""
device = self._devices.get(phone_number) return self._set_trust(phone_number, TrustLevel.UNVERIFIED)
if not device:
return False
device.trust_level = TrustLevel.UNVERIFIED
self._save()
return True
def list_devices(self) -> list[Device]: def list_devices(self) -> list[SignalDevice]:
"""Return all known devices.""" """Return all known devices."""
return list(self._devices.values()) with Session(self.engine) as session:
return list(session.execute(select(SignalDevice)).scalars().all())
def sync_identities(self) -> None: def sync_identities(self) -> None:
"""Pull identity list from signal-cli and record any new ones.""" """Pull identity list from signal-cli and record any new ones."""
@@ -119,17 +109,25 @@ class DeviceRegistry:
if number: if number:
self.record_contact(number, safety) self.record_contact(number, safety)
def _load(self) -> None: def _get(self, phone_number: str) -> SignalDevice | None:
"""Load registry from disk.""" """Fetch a device by phone number."""
if not self.registry_path.exists(): with Session(self.engine) as session:
return return session.execute(
data: list[dict[str, Any]] = json.loads(self.registry_path.read_text()) select(SignalDevice).where(SignalDevice.phone_number == phone_number)
for entry in data: ).scalar_one_or_none()
device = Device.model_validate(entry)
self._devices[device.phone_number] = device
def _save(self) -> None: def _set_trust(self, phone_number: str, level: str, log_msg: str | None = None) -> bool:
"""Persist registry to disk.""" """Update the trust level for a device."""
self.registry_path.parent.mkdir(parents=True, exist_ok=True) with Session(self.engine) as session:
data = [device.model_dump(mode="json") for device in self._devices.values()] device = session.execute(
self.registry_path.write_text(json.dumps(data, indent=2) + "\n") select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
if not device:
return False
device.trust_level = level
session.commit()
if log_msg:
logger.info(f"{log_msg}: {phone_number}")
return True

View File

@@ -10,6 +10,7 @@ from typing import Annotated
import typer import typer
from python.common import configure_logger from python.common import configure_logger
from python.orm.common import get_postgres_engine
from python.signal_bot.commands.inventory import handle_inventory_update from python.signal_bot.commands.inventory import handle_inventory_update
from python.signal_bot.device_registry import DeviceRegistry from python.signal_bot.device_registry import DeviceRegistry
from python.signal_bot.llm_client import LLMClient from python.signal_bot.llm_client import LLMClient
@@ -153,7 +154,6 @@ def main(
llm_model: Annotated[str, typer.Option(envvar="LLM_MODEL")] = "qwen3-vl:32b", llm_model: Annotated[str, typer.Option(envvar="LLM_MODEL")] = "qwen3-vl:32b",
llm_port: Annotated[int, typer.Option(envvar="LLM_PORT")] = 11434, llm_port: Annotated[int, typer.Option(envvar="LLM_PORT")] = 11434,
inventory_file: Annotated[str, typer.Option(envvar="INVENTORY_FILE")] = "/var/lib/signal-bot/van_inventory.json", inventory_file: Annotated[str, typer.Option(envvar="INVENTORY_FILE")] = "/var/lib/signal-bot/van_inventory.json",
registry_file: Annotated[str, typer.Option(envvar="REGISTRY_FILE")] = "/var/lib/signal-bot/device_registry.json",
log_level: Annotated[str, typer.Option()] = "INFO", log_level: Annotated[str, typer.Option()] = "INFO",
) -> None: ) -> None:
"""Run the Signal command and control bot.""" """Run the Signal command and control bot."""
@@ -165,12 +165,13 @@ def main(
inventory_file=inventory_file, inventory_file=inventory_file,
) )
engine = get_postgres_engine(name="RICHIE")
with ( with (
SignalClient(config.signal_api_url, config.phone_number) as signal, SignalClient(config.signal_api_url, config.phone_number) as signal,
LLMClient(model=llm_model, host=llm_host, port=llm_port) as llm, LLMClient(model=llm_model, host=llm_host, port=llm_port) as llm,
): ):
registry = DeviceRegistry(signal, Path(registry_file)) registry = DeviceRegistry(signal, engine)
run_loop(config, signal, llm, registry) run_loop(config, signal, llm, registry)