mirror of
https://github.com/RichieCahill/dotfiles.git
synced 2026-04-17 04:58:19 -04:00
moved device registry to postgresql
This commit is contained in:
@@ -2,15 +2,18 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from python.common import utcnow
|
||||
from python.signal_bot.models import Device, TrustLevel
|
||||
from python.orm.richie.signal_device import SignalDevice
|
||||
from python.signal_bot.models import TrustLevel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
from python.signal_bot.signal_client import SignalClient
|
||||
|
||||
@@ -25,90 +28,77 @@ class DeviceRegistry:
|
||||
signal-cli to trust the identity.
|
||||
|
||||
Only VERIFIED devices may execute commands.
|
||||
|
||||
Args:
|
||||
signal_client: The Signal API client (used to sync identities).
|
||||
registry_path: Path to the JSON file that persists device state.
|
||||
"""
|
||||
|
||||
def __init__(self, signal_client: SignalClient, registry_path: Path) -> None:
|
||||
def __init__(self, signal_client: SignalClient, engine: Engine) -> None:
|
||||
self.signal_client = signal_client
|
||||
self.registry_path = registry_path
|
||||
self._devices: dict[str, Device] = {}
|
||||
self._load()
|
||||
self.engine = engine
|
||||
|
||||
def is_verified(self, phone_number: str) -> bool:
|
||||
"""Check if a phone number is verified."""
|
||||
device = self._devices.get(phone_number)
|
||||
device = self._get(phone_number)
|
||||
return device is not None and device.trust_level == TrustLevel.VERIFIED
|
||||
|
||||
def is_blocked(self, phone_number: str) -> bool:
|
||||
"""Check if a phone number is blocked."""
|
||||
device = self._devices.get(phone_number)
|
||||
return device is not None and device.trust_level == TrustLevel.BLOCKED
|
||||
|
||||
def record_contact(self, phone_number: str, safety_number: str) -> Device:
|
||||
def record_contact(self, phone_number: str, safety_number: str) -> SignalDevice:
|
||||
"""Record seeing a device. Creates entry if new, updates last_seen."""
|
||||
now = utcnow()
|
||||
if phone_number in self._devices:
|
||||
device = self._devices[phone_number]
|
||||
if device.safety_number != safety_number:
|
||||
logger.warning(f"Safety number changed for {phone_number}, resetting to UNVERIFIED")
|
||||
device.safety_number = safety_number
|
||||
device.trust_level = TrustLevel.UNVERIFIED
|
||||
device.last_seen = now
|
||||
else:
|
||||
device = Device(
|
||||
phone_number=phone_number,
|
||||
safety_number=safety_number,
|
||||
trust_level=TrustLevel.UNVERIFIED,
|
||||
first_seen=now,
|
||||
last_seen=now,
|
||||
)
|
||||
self._devices[phone_number] = device
|
||||
logger.info(f"New device registered: {phone_number}")
|
||||
with Session(self.engine) as session:
|
||||
device = session.execute(
|
||||
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
|
||||
).scalar_one_or_none()
|
||||
|
||||
self._save()
|
||||
return device
|
||||
if device:
|
||||
if device.safety_number != safety_number:
|
||||
logger.warning(f"Safety number changed for {phone_number}, resetting to UNVERIFIED")
|
||||
device.safety_number = safety_number
|
||||
device.trust_level = TrustLevel.UNVERIFIED
|
||||
device.last_seen = now
|
||||
else:
|
||||
device = SignalDevice(
|
||||
phone_number=phone_number,
|
||||
safety_number=safety_number,
|
||||
trust_level=TrustLevel.UNVERIFIED,
|
||||
last_seen=now,
|
||||
)
|
||||
session.add(device)
|
||||
logger.info(f"New device registered: {phone_number}")
|
||||
|
||||
session.commit()
|
||||
session.refresh(device)
|
||||
return device
|
||||
|
||||
def verify(self, phone_number: str) -> bool:
|
||||
"""Mark a device as verified. Called by admin over SSH.
|
||||
|
||||
Returns True if the device was found and verified.
|
||||
"""
|
||||
device = self._devices.get(phone_number)
|
||||
if not device:
|
||||
logger.warning(f"Cannot verify unknown device: {phone_number}")
|
||||
return False
|
||||
with Session(self.engine) as session:
|
||||
device = session.execute(
|
||||
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
|
||||
).scalar_one_or_none()
|
||||
|
||||
device.trust_level = TrustLevel.VERIFIED
|
||||
self.signal_client.trust_identity(phone_number, trust_all_known_keys=True)
|
||||
self._save()
|
||||
logger.info(f"Device verified: {phone_number}")
|
||||
return True
|
||||
if not device:
|
||||
logger.warning(f"Cannot verify unknown device: {phone_number}")
|
||||
return False
|
||||
|
||||
device.trust_level = TrustLevel.VERIFIED
|
||||
self.signal_client.trust_identity(phone_number, trust_all_known_keys=True)
|
||||
session.commit()
|
||||
logger.info(f"Device verified: {phone_number}")
|
||||
return True
|
||||
|
||||
def block(self, phone_number: str) -> bool:
|
||||
"""Block a device."""
|
||||
device = self._devices.get(phone_number)
|
||||
if not device:
|
||||
return False
|
||||
device.trust_level = TrustLevel.BLOCKED
|
||||
self._save()
|
||||
logger.info(f"Device blocked: {phone_number}")
|
||||
return True
|
||||
return self._set_trust(phone_number, TrustLevel.BLOCKED, "Device blocked")
|
||||
|
||||
def unverify(self, phone_number: str) -> bool:
|
||||
"""Reset a device to unverified."""
|
||||
device = self._devices.get(phone_number)
|
||||
if not device:
|
||||
return False
|
||||
device.trust_level = TrustLevel.UNVERIFIED
|
||||
self._save()
|
||||
return True
|
||||
return self._set_trust(phone_number, TrustLevel.UNVERIFIED)
|
||||
|
||||
def list_devices(self) -> list[Device]:
|
||||
def list_devices(self) -> list[SignalDevice]:
|
||||
"""Return all known devices."""
|
||||
return list(self._devices.values())
|
||||
with Session(self.engine) as session:
|
||||
return list(session.execute(select(SignalDevice)).scalars().all())
|
||||
|
||||
def sync_identities(self) -> None:
|
||||
"""Pull identity list from signal-cli and record any new ones."""
|
||||
@@ -119,17 +109,25 @@ class DeviceRegistry:
|
||||
if number:
|
||||
self.record_contact(number, safety)
|
||||
|
||||
def _load(self) -> None:
|
||||
"""Load registry from disk."""
|
||||
if not self.registry_path.exists():
|
||||
return
|
||||
data: list[dict[str, Any]] = json.loads(self.registry_path.read_text())
|
||||
for entry in data:
|
||||
device = Device.model_validate(entry)
|
||||
self._devices[device.phone_number] = device
|
||||
def _get(self, phone_number: str) -> SignalDevice | None:
|
||||
"""Fetch a device by phone number."""
|
||||
with Session(self.engine) as session:
|
||||
return session.execute(
|
||||
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
|
||||
).scalar_one_or_none()
|
||||
|
||||
def _save(self) -> None:
|
||||
"""Persist registry to disk."""
|
||||
self.registry_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
data = [device.model_dump(mode="json") for device in self._devices.values()]
|
||||
self.registry_path.write_text(json.dumps(data, indent=2) + "\n")
|
||||
def _set_trust(self, phone_number: str, level: str, log_msg: str | None = None) -> bool:
|
||||
"""Update the trust level for a device."""
|
||||
with Session(self.engine) as session:
|
||||
device = session.execute(
|
||||
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if not device:
|
||||
return False
|
||||
|
||||
device.trust_level = level
|
||||
session.commit()
|
||||
if log_msg:
|
||||
logger.info(f"{log_msg}: {phone_number}")
|
||||
return True
|
||||
|
||||
@@ -10,6 +10,7 @@ from typing import Annotated
|
||||
import typer
|
||||
|
||||
from python.common import configure_logger
|
||||
from python.orm.common import get_postgres_engine
|
||||
from python.signal_bot.commands.inventory import handle_inventory_update
|
||||
from python.signal_bot.device_registry import DeviceRegistry
|
||||
from python.signal_bot.llm_client import LLMClient
|
||||
@@ -153,7 +154,6 @@ def main(
|
||||
llm_model: Annotated[str, typer.Option(envvar="LLM_MODEL")] = "qwen3-vl:32b",
|
||||
llm_port: Annotated[int, typer.Option(envvar="LLM_PORT")] = 11434,
|
||||
inventory_file: Annotated[str, typer.Option(envvar="INVENTORY_FILE")] = "/var/lib/signal-bot/van_inventory.json",
|
||||
registry_file: Annotated[str, typer.Option(envvar="REGISTRY_FILE")] = "/var/lib/signal-bot/device_registry.json",
|
||||
log_level: Annotated[str, typer.Option()] = "INFO",
|
||||
) -> None:
|
||||
"""Run the Signal command and control bot."""
|
||||
@@ -165,12 +165,13 @@ def main(
|
||||
inventory_file=inventory_file,
|
||||
)
|
||||
|
||||
engine = get_postgres_engine(name="RICHIE")
|
||||
|
||||
with (
|
||||
SignalClient(config.signal_api_url, config.phone_number) as signal,
|
||||
LLMClient(model=llm_model, host=llm_host, port=llm_port) as llm,
|
||||
):
|
||||
registry = DeviceRegistry(signal, Path(registry_file))
|
||||
registry = DeviceRegistry(signal, engine)
|
||||
run_loop(config, signal, llm, registry)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user