mirror of
https://github.com/RichieCahill/dotfiles.git
synced 2026-04-17 13:08:19 -04:00
298 lines
11 KiB
Python
298 lines
11 KiB
Python
"""Device registry — tracks verified/unverified devices by safety number."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from datetime import datetime, timedelta
|
|
from typing import TYPE_CHECKING, NamedTuple
|
|
|
|
from sqlalchemy import delete, select
|
|
from sqlalchemy.orm import Session
|
|
|
|
from python.common import utcnow
|
|
from python.orm.signal_bot.models import RoleRecord, SignalDevice
|
|
from python.signal_bot.models import Role, TrustLevel
|
|
|
|
if TYPE_CHECKING:
|
|
from sqlalchemy.engine import Engine
|
|
|
|
from python.signal_bot.signal_client import SignalClient
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
_BLOCKED_TTL = timedelta(minutes=60)
|
|
_DEFAULT_TTL = timedelta(minutes=5)
|
|
|
|
|
|
class _CacheEntry(NamedTuple):
|
|
expires: datetime
|
|
trust_level: TrustLevel
|
|
has_safety_number: bool
|
|
safety_number: str | None
|
|
roles: list[Role]
|
|
|
|
|
|
class DeviceRegistry:
|
|
"""Manage device trust based on Signal safety numbers.
|
|
|
|
Devices start as UNVERIFIED. An admin verifies them over SSH by calling
|
|
``verify(phone_number)`` which marks the device VERIFIED and also tells
|
|
signal-cli to trust the identity.
|
|
|
|
Only VERIFIED devices may execute commands.
|
|
"""
|
|
|
|
def __init__(self, signal_client: SignalClient, engine: Engine) -> None:
|
|
self.signal_client = signal_client
|
|
self.engine = engine
|
|
self._contact_cache: dict[str, _CacheEntry] = {}
|
|
|
|
def is_verified(self, phone_number: str) -> bool:
|
|
"""Check if a phone number is verified."""
|
|
if entry := self._cached(phone_number):
|
|
return entry.trust_level == TrustLevel.VERIFIED
|
|
device = self._load_device(phone_number)
|
|
return device is not None and device.trust_level == TrustLevel.VERIFIED
|
|
|
|
def record_contact(self, phone_number: str, safety_number: str | None = None) -> None:
|
|
"""Record seeing a device. Creates entry if new, updates last_seen."""
|
|
now = utcnow()
|
|
|
|
entry = self._cached(phone_number)
|
|
if entry and entry.safety_number == safety_number:
|
|
return
|
|
|
|
with Session(self.engine) as session:
|
|
device = (
|
|
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
|
|
.scalar_one_or_none()
|
|
)
|
|
|
|
if device:
|
|
if device.safety_number != safety_number and device.trust_level != TrustLevel.BLOCKED:
|
|
logger.warning(f"Safety number changed for {phone_number}, resetting to UNVERIFIED")
|
|
device.safety_number = safety_number
|
|
device.trust_level = TrustLevel.UNVERIFIED
|
|
device.last_seen = now
|
|
else:
|
|
device = SignalDevice(
|
|
phone_number=phone_number,
|
|
safety_number=safety_number,
|
|
trust_level=TrustLevel.UNVERIFIED,
|
|
last_seen=now,
|
|
)
|
|
session.add(device)
|
|
logger.info(f"New device registered: {phone_number}")
|
|
|
|
session.commit()
|
|
self._update_cache(phone_number, device)
|
|
|
|
def has_safety_number(self, phone_number: str) -> bool:
|
|
"""Check if a device has a safety number on file."""
|
|
if entry := self._cached(phone_number):
|
|
return entry.has_safety_number
|
|
device = self._load_device(phone_number)
|
|
return device is not None and device.safety_number is not None
|
|
|
|
def verify(self, phone_number: str) -> bool:
|
|
"""Mark a device as verified. Called by admin over SSH.
|
|
|
|
Returns True if the device was found and verified.
|
|
"""
|
|
with Session(self.engine) as session:
|
|
device = (
|
|
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
|
|
.scalar_one_or_none()
|
|
)
|
|
|
|
if not device:
|
|
logger.warning(f"Cannot verify unknown device: {phone_number}")
|
|
return False
|
|
|
|
device.trust_level = TrustLevel.VERIFIED
|
|
self.signal_client.trust_identity(phone_number, trust_all_known_keys=True)
|
|
session.commit()
|
|
self._update_cache(phone_number, device)
|
|
logger.info(f"Device verified: {phone_number}")
|
|
return True
|
|
|
|
def block(self, phone_number: str) -> bool:
|
|
"""Block a device."""
|
|
return self._set_trust(phone_number, TrustLevel.BLOCKED, "Device blocked")
|
|
|
|
def unverify(self, phone_number: str) -> bool:
|
|
"""Reset a device to unverified."""
|
|
return self._set_trust(phone_number, TrustLevel.UNVERIFIED)
|
|
|
|
# -- role management ------------------------------------------------------
|
|
|
|
def get_roles(self, phone_number: str) -> list[Role]:
|
|
"""Return the roles for a device, defaulting to empty."""
|
|
if entry := self._cached(phone_number):
|
|
return entry.roles
|
|
device = self._load_device(phone_number)
|
|
return _extract_roles(device) if device else []
|
|
|
|
def has_role(self, phone_number: str, role: Role) -> bool:
|
|
"""Check if a device has a specific role or is admin."""
|
|
roles = self.get_roles(phone_number)
|
|
return Role.ADMIN in roles or role in roles
|
|
|
|
def grant_role(self, phone_number: str, role: Role) -> bool:
|
|
"""Add a role to a device. Called by admin over SSH."""
|
|
with Session(self.engine) as session:
|
|
device = (
|
|
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
|
|
.scalar_one_or_none()
|
|
)
|
|
|
|
if not device:
|
|
logger.warning(f"Cannot grant role for unknown device: {phone_number}")
|
|
return False
|
|
|
|
if any(record.name == role for record in device.roles):
|
|
return True
|
|
|
|
role_record = session.execute(
|
|
select(RoleRecord).where(RoleRecord.name == role)
|
|
).scalar_one_or_none()
|
|
|
|
if not role_record:
|
|
logger.warning(f"Unknown role: {role}")
|
|
return False
|
|
|
|
device.roles.append(role_record)
|
|
session.commit()
|
|
self._update_cache(phone_number, device)
|
|
logger.info(f"Device {phone_number} granted role {role}")
|
|
return True
|
|
|
|
def revoke_role(self, phone_number: str, role: Role) -> bool:
|
|
"""Remove a role from a device. Called by admin over SSH."""
|
|
with Session(self.engine) as session:
|
|
device = (
|
|
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
|
|
.scalar_one_or_none()
|
|
)
|
|
|
|
if not device:
|
|
logger.warning(f"Cannot revoke role for unknown device: {phone_number}")
|
|
return False
|
|
|
|
device.roles = [record for record in device.roles if record.name != role]
|
|
session.commit()
|
|
self._update_cache(phone_number, device)
|
|
logger.info(f"Device {phone_number} revoked role {role}")
|
|
return True
|
|
|
|
def set_roles(self, phone_number: str, roles: list[Role]) -> bool:
|
|
"""Replace all roles for a device. Called by admin over SSH."""
|
|
with Session(self.engine) as session:
|
|
device = (
|
|
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
|
|
.scalar_one_or_none()
|
|
)
|
|
|
|
if not device:
|
|
logger.warning(f"Cannot set roles for unknown device: {phone_number}")
|
|
return False
|
|
|
|
role_names = [str(role) for role in roles]
|
|
records = list(
|
|
session.execute(select(RoleRecord).where(RoleRecord.name.in_(role_names))).scalars().all()
|
|
)
|
|
device.roles = records
|
|
session.commit()
|
|
self._update_cache(phone_number, device)
|
|
logger.info(f"Device {phone_number} roles set to {role_names}")
|
|
return True
|
|
|
|
# -- queries --------------------------------------------------------------
|
|
|
|
def list_devices(self) -> list[SignalDevice]:
|
|
"""Return all known devices."""
|
|
with Session(self.engine) as session:
|
|
return list(session.execute(select(SignalDevice)).scalars().all())
|
|
|
|
def sync_identities(self) -> None:
|
|
"""Pull identity list from signal-cli and record any new ones."""
|
|
identities = self.signal_client.get_identities()
|
|
for identity in identities:
|
|
number = identity.get("number", "")
|
|
safety = identity.get("safety_number", identity.get("fingerprint", ""))
|
|
if number:
|
|
self.record_contact(number, safety)
|
|
|
|
# -- internals ------------------------------------------------------------
|
|
|
|
def _cached(self, phone_number: str) -> _CacheEntry | None:
|
|
"""Return the cache entry if it exists and hasn't expired."""
|
|
entry = self._contact_cache.get(phone_number)
|
|
if entry and utcnow() < entry.expires:
|
|
return entry
|
|
return None
|
|
|
|
def _load_device(self, phone_number: str) -> SignalDevice | None:
|
|
"""Fetch a device by phone number (with joined roles)."""
|
|
with Session(self.engine) as session:
|
|
return (
|
|
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
|
|
.scalar_one_or_none()
|
|
)
|
|
|
|
def _update_cache(self, phone_number: str, device: SignalDevice) -> None:
|
|
"""Refresh the cache entry for a device."""
|
|
ttl = _BLOCKED_TTL if device.trust_level == TrustLevel.BLOCKED else _DEFAULT_TTL
|
|
self._contact_cache[phone_number] = _CacheEntry(
|
|
expires=utcnow() + ttl,
|
|
trust_level=device.trust_level,
|
|
has_safety_number=device.safety_number is not None,
|
|
safety_number=device.safety_number,
|
|
roles=_extract_roles(device),
|
|
)
|
|
|
|
def _set_trust(self, phone_number: str, level: str, log_msg: str | None = None) -> bool:
|
|
"""Update the trust level for a device."""
|
|
with Session(self.engine) as session:
|
|
device = (
|
|
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
|
|
.scalar_one_or_none()
|
|
)
|
|
|
|
if not device:
|
|
return False
|
|
|
|
device.trust_level = level
|
|
session.commit()
|
|
self._update_cache(phone_number, device)
|
|
if log_msg:
|
|
logger.info(f"{log_msg}: {phone_number}")
|
|
return True
|
|
|
|
|
|
def _extract_roles(device: SignalDevice) -> list[Role]:
|
|
"""Convert a device's RoleRecord objects to a list of Role enums."""
|
|
return [Role(record.name) for record in device.roles]
|
|
|
|
|
|
def sync_roles(engine: Engine) -> None:
|
|
"""Sync the Role enum to the role table, adding new and removing stale entries."""
|
|
expected = {role.value for role in Role}
|
|
|
|
with Session(engine) as session:
|
|
existing = {record.name for record in session.execute(select(RoleRecord)).scalars().all()}
|
|
|
|
to_add = expected - existing
|
|
to_remove = existing - expected
|
|
|
|
for name in to_add:
|
|
session.add(RoleRecord(name=name))
|
|
logger.info(f"Role added: {name}")
|
|
|
|
if to_remove:
|
|
session.execute(delete(RoleRecord).where(RoleRecord.name.in_(to_remove)))
|
|
for name in to_remove:
|
|
logger.info(f"Role removed: {name}")
|
|
|
|
session.commit()
|