added bot class and rbac style auth with dynamic help msg base on roles

This commit is contained in:
2026-03-16 19:24:57 -04:00
parent a19b1c7e60
commit 7d2fbaea43
7 changed files with 414 additions and 217 deletions

View File

@@ -10,8 +10,8 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from python.common import utcnow
from python.orm.richie.signal_device import SignalDevice
from python.signal_bot.models import TrustLevel
from python.orm.richie.signal_device import RoleRecord, SignalDevice
from python.signal_bot.models import Role, TrustLevel
if TYPE_CHECKING:
from sqlalchemy.engine import Engine
@@ -29,6 +29,7 @@ class _CacheEntry(NamedTuple):
trust_level: TrustLevel
has_safety_number: bool
safety_number: str | None
roles: list[Role]
class DeviceRegistry:
@@ -50,7 +51,7 @@ class DeviceRegistry:
"""Check if a phone number is verified."""
if entry := self._cached(phone_number):
return entry.trust_level == TrustLevel.VERIFIED
device = self.get_device(phone_number)
device = self._load_device(phone_number)
return device is not None and device.trust_level == TrustLevel.VERIFIED
def record_contact(self, phone_number: str, safety_number: str | None = None) -> None:
@@ -62,9 +63,10 @@ class DeviceRegistry:
return
with Session(self.engine) as session:
device = session.execute(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
device = (
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
.scalar_one_or_none()
)
if device:
if device.safety_number != safety_number and device.trust_level != TrustLevel.BLOCKED:
@@ -83,20 +85,13 @@ class DeviceRegistry:
logger.info(f"New device registered: {phone_number}")
session.commit()
ttl = _BLOCKED_TTL if device.trust_level == TrustLevel.BLOCKED else _DEFAULT_TTL
self._contact_cache[phone_number] = _CacheEntry(
expires=now + ttl,
trust_level=device.trust_level,
has_safety_number=device.safety_number is not None,
safety_number=device.safety_number,
)
self._update_cache(phone_number, device)
def has_safety_number(self, phone_number: str) -> bool:
"""Check if a device has a safety number on file."""
if entry := self._cached(phone_number):
return entry.has_safety_number
device = self.get_device(phone_number)
device = self._load_device(phone_number)
return device is not None and device.safety_number is not None
def verify(self, phone_number: str) -> bool:
@@ -105,9 +100,10 @@ class DeviceRegistry:
Returns True if the device was found and verified.
"""
with Session(self.engine) as session:
device = session.execute(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
device = (
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
.scalar_one_or_none()
)
if not device:
logger.warning(f"Cannot verify unknown device: {phone_number}")
@@ -116,12 +112,7 @@ class DeviceRegistry:
device.trust_level = TrustLevel.VERIFIED
self.signal_client.trust_identity(phone_number, trust_all_known_keys=True)
session.commit()
self._contact_cache[phone_number] = _CacheEntry(
expires=utcnow() + _DEFAULT_TTL,
trust_level=TrustLevel.VERIFIED,
has_safety_number=device.safety_number is not None,
safety_number=device.safety_number,
)
self._update_cache(phone_number, device)
logger.info(f"Device verified: {phone_number}")
return True
@@ -133,6 +124,91 @@ class DeviceRegistry:
"""Reset a device to unverified."""
return self._set_trust(phone_number, TrustLevel.UNVERIFIED)
# -- role management ------------------------------------------------------
def get_roles(self, phone_number: str) -> list[Role]:
"""Return the roles for a device, defaulting to empty."""
if entry := self._cached(phone_number):
return entry.roles
device = self._load_device(phone_number)
return _extract_roles(device) if device else []
def has_role(self, phone_number: str, role: Role) -> bool:
"""Check if a device has a specific role or is admin."""
roles = self.get_roles(phone_number)
return Role.ADMIN in roles or role in roles
def grant_role(self, phone_number: str, role: Role) -> bool:
"""Add a role to a device. Called by admin over SSH."""
with Session(self.engine) as session:
device = (
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
.scalar_one_or_none()
)
if not device:
logger.warning(f"Cannot grant role for unknown device: {phone_number}")
return False
if any(rr.name == role for rr in device.roles):
return True
role_record = session.execute(
select(RoleRecord).where(RoleRecord.name == role)
).scalar_one_or_none()
if not role_record:
logger.warning(f"Unknown role: {role}")
return False
device.roles.append(role_record)
session.commit()
self._update_cache(phone_number, device)
logger.info(f"Device {phone_number} granted role {role}")
return True
def revoke_role(self, phone_number: str, role: Role) -> bool:
"""Remove a role from a device. Called by admin over SSH."""
with Session(self.engine) as session:
device = (
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
.scalar_one_or_none()
)
if not device:
logger.warning(f"Cannot revoke role for unknown device: {phone_number}")
return False
device.roles = [rr for rr in device.roles if rr.name != role]
session.commit()
self._update_cache(phone_number, device)
logger.info(f"Device {phone_number} revoked role {role}")
return True
def set_roles(self, phone_number: str, roles: list[Role]) -> bool:
"""Replace all roles for a device. Called by admin over SSH."""
with Session(self.engine) as session:
device = (
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
.scalar_one_or_none()
)
if not device:
logger.warning(f"Cannot set roles for unknown device: {phone_number}")
return False
role_names = [str(r) for r in roles]
records = list(
session.execute(select(RoleRecord).where(RoleRecord.name.in_(role_names))).scalars().all()
)
device.roles = records
session.commit()
self._update_cache(phone_number, device)
logger.info(f"Device {phone_number} roles set to {role_names}")
return True
# -- queries --------------------------------------------------------------
def list_devices(self) -> list[SignalDevice]:
"""Return all known devices."""
with Session(self.engine) as session:
@@ -147,6 +223,8 @@ class DeviceRegistry:
if number:
self.record_contact(number, safety)
# -- internals ------------------------------------------------------------
def _cached(self, phone_number: str) -> _CacheEntry | None:
"""Return the cache entry if it exists and hasn't expired."""
entry = self._contact_cache.get(phone_number)
@@ -154,32 +232,44 @@ class DeviceRegistry:
return entry
return None
def get_device(self, phone_number: str) -> SignalDevice | None:
"""Fetch a device by phone number."""
def _load_device(self, phone_number: str) -> SignalDevice | None:
"""Fetch a device by phone number (with joined roles)."""
with Session(self.engine) as session:
return session.execute(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
return (
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
.scalar_one_or_none()
)
def _update_cache(self, phone_number: str, device: SignalDevice) -> None:
"""Refresh the cache entry for a device."""
ttl = _BLOCKED_TTL if device.trust_level == TrustLevel.BLOCKED else _DEFAULT_TTL
self._contact_cache[phone_number] = _CacheEntry(
expires=utcnow() + ttl,
trust_level=device.trust_level,
has_safety_number=device.safety_number is not None,
safety_number=device.safety_number,
roles=_extract_roles(device),
)
def _set_trust(self, phone_number: str, level: str, log_msg: str | None = None) -> bool:
"""Update the trust level for a device."""
with Session(self.engine) as session:
device = session.execute(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
device = (
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
.scalar_one_or_none()
)
if not device:
return False
device.trust_level = level
session.commit()
ttl = _BLOCKED_TTL if level == TrustLevel.BLOCKED else _DEFAULT_TTL
self._contact_cache[phone_number] = _CacheEntry(
expires=utcnow() + ttl,
trust_level=level,
has_safety_number=device.safety_number is not None,
safety_number=device.safety_number,
)
self._update_cache(phone_number, device)
if log_msg:
logger.info(f"{log_msg}: {phone_number}")
return True
def _extract_roles(device: SignalDevice) -> list[Role]:
"""Convert a device's RoleRecord objects to a list of Role enums."""
return [Role(rr.name) for rr in device.roles]