added bot class and rbac style auth with dynamic help msg base on roles

This commit is contained in:
2026-03-16 19:24:57 -04:00
parent a19b1c7e60
commit 7d2fbaea43
7 changed files with 414 additions and 217 deletions

View File

@@ -10,8 +10,8 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from python.common import utcnow
from python.orm.richie.signal_device import SignalDevice
from python.signal_bot.models import TrustLevel
from python.orm.richie.signal_device import RoleRecord, SignalDevice
from python.signal_bot.models import Role, TrustLevel
if TYPE_CHECKING:
from sqlalchemy.engine import Engine
@@ -29,6 +29,7 @@ class _CacheEntry(NamedTuple):
trust_level: TrustLevel
has_safety_number: bool
safety_number: str | None
roles: list[Role]
class DeviceRegistry:
@@ -50,7 +51,7 @@ class DeviceRegistry:
"""Check if a phone number is verified."""
if entry := self._cached(phone_number):
return entry.trust_level == TrustLevel.VERIFIED
device = self.get_device(phone_number)
device = self._load_device(phone_number)
return device is not None and device.trust_level == TrustLevel.VERIFIED
def record_contact(self, phone_number: str, safety_number: str | None = None) -> None:
@@ -62,9 +63,10 @@ class DeviceRegistry:
return
with Session(self.engine) as session:
device = session.execute(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
device = (
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
.scalar_one_or_none()
)
if device:
if device.safety_number != safety_number and device.trust_level != TrustLevel.BLOCKED:
@@ -83,20 +85,13 @@ class DeviceRegistry:
logger.info(f"New device registered: {phone_number}")
session.commit()
ttl = _BLOCKED_TTL if device.trust_level == TrustLevel.BLOCKED else _DEFAULT_TTL
self._contact_cache[phone_number] = _CacheEntry(
expires=now + ttl,
trust_level=device.trust_level,
has_safety_number=device.safety_number is not None,
safety_number=device.safety_number,
)
self._update_cache(phone_number, device)
def has_safety_number(self, phone_number: str) -> bool:
"""Check if a device has a safety number on file."""
if entry := self._cached(phone_number):
return entry.has_safety_number
device = self.get_device(phone_number)
device = self._load_device(phone_number)
return device is not None and device.safety_number is not None
def verify(self, phone_number: str) -> bool:
@@ -105,9 +100,10 @@ class DeviceRegistry:
Returns True if the device was found and verified.
"""
with Session(self.engine) as session:
device = session.execute(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
device = (
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
.scalar_one_or_none()
)
if not device:
logger.warning(f"Cannot verify unknown device: {phone_number}")
@@ -116,12 +112,7 @@ class DeviceRegistry:
device.trust_level = TrustLevel.VERIFIED
self.signal_client.trust_identity(phone_number, trust_all_known_keys=True)
session.commit()
self._contact_cache[phone_number] = _CacheEntry(
expires=utcnow() + _DEFAULT_TTL,
trust_level=TrustLevel.VERIFIED,
has_safety_number=device.safety_number is not None,
safety_number=device.safety_number,
)
self._update_cache(phone_number, device)
logger.info(f"Device verified: {phone_number}")
return True
@@ -133,6 +124,91 @@ class DeviceRegistry:
"""Reset a device to unverified."""
return self._set_trust(phone_number, TrustLevel.UNVERIFIED)
# -- role management ------------------------------------------------------
def get_roles(self, phone_number: str) -> list[Role]:
"""Return the roles for a device, defaulting to empty."""
if entry := self._cached(phone_number):
return entry.roles
device = self._load_device(phone_number)
return _extract_roles(device) if device else []
def has_role(self, phone_number: str, role: Role) -> bool:
"""Check if a device has a specific role or is admin."""
roles = self.get_roles(phone_number)
return Role.ADMIN in roles or role in roles
def grant_role(self, phone_number: str, role: Role) -> bool:
"""Add a role to a device. Called by admin over SSH."""
with Session(self.engine) as session:
device = (
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
.scalar_one_or_none()
)
if not device:
logger.warning(f"Cannot grant role for unknown device: {phone_number}")
return False
if any(rr.name == role for rr in device.roles):
return True
role_record = session.execute(
select(RoleRecord).where(RoleRecord.name == role)
).scalar_one_or_none()
if not role_record:
logger.warning(f"Unknown role: {role}")
return False
device.roles.append(role_record)
session.commit()
self._update_cache(phone_number, device)
logger.info(f"Device {phone_number} granted role {role}")
return True
def revoke_role(self, phone_number: str, role: Role) -> bool:
"""Remove a role from a device. Called by admin over SSH."""
with Session(self.engine) as session:
device = (
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
.scalar_one_or_none()
)
if not device:
logger.warning(f"Cannot revoke role for unknown device: {phone_number}")
return False
device.roles = [rr for rr in device.roles if rr.name != role]
session.commit()
self._update_cache(phone_number, device)
logger.info(f"Device {phone_number} revoked role {role}")
return True
def set_roles(self, phone_number: str, roles: list[Role]) -> bool:
"""Replace all roles for a device. Called by admin over SSH."""
with Session(self.engine) as session:
device = (
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
.scalar_one_or_none()
)
if not device:
logger.warning(f"Cannot set roles for unknown device: {phone_number}")
return False
role_names = [str(r) for r in roles]
records = list(
session.execute(select(RoleRecord).where(RoleRecord.name.in_(role_names))).scalars().all()
)
device.roles = records
session.commit()
self._update_cache(phone_number, device)
logger.info(f"Device {phone_number} roles set to {role_names}")
return True
# -- queries --------------------------------------------------------------
def list_devices(self) -> list[SignalDevice]:
"""Return all known devices."""
with Session(self.engine) as session:
@@ -147,6 +223,8 @@ class DeviceRegistry:
if number:
self.record_contact(number, safety)
# -- internals ------------------------------------------------------------
def _cached(self, phone_number: str) -> _CacheEntry | None:
"""Return the cache entry if it exists and hasn't expired."""
entry = self._contact_cache.get(phone_number)
@@ -154,32 +232,44 @@ class DeviceRegistry:
return entry
return None
def get_device(self, phone_number: str) -> SignalDevice | None:
"""Fetch a device by phone number."""
def _load_device(self, phone_number: str) -> SignalDevice | None:
"""Fetch a device by phone number (with joined roles)."""
with Session(self.engine) as session:
return session.execute(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
return (
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
.scalar_one_or_none()
)
def _update_cache(self, phone_number: str, device: SignalDevice) -> None:
"""Refresh the cache entry for a device."""
ttl = _BLOCKED_TTL if device.trust_level == TrustLevel.BLOCKED else _DEFAULT_TTL
self._contact_cache[phone_number] = _CacheEntry(
expires=utcnow() + ttl,
trust_level=device.trust_level,
has_safety_number=device.safety_number is not None,
safety_number=device.safety_number,
roles=_extract_roles(device),
)
def _set_trust(self, phone_number: str, level: str, log_msg: str | None = None) -> bool:
"""Update the trust level for a device."""
with Session(self.engine) as session:
device = session.execute(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
device = (
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
.scalar_one_or_none()
)
if not device:
return False
device.trust_level = level
session.commit()
ttl = _BLOCKED_TTL if level == TrustLevel.BLOCKED else _DEFAULT_TTL
self._contact_cache[phone_number] = _CacheEntry(
expires=utcnow() + ttl,
trust_level=level,
has_safety_number=device.safety_number is not None,
safety_number=device.safety_number,
)
self._update_cache(phone_number, device)
if log_msg:
logger.info(f"{log_msg}: {phone_number}")
return True
def _extract_roles(device: SignalDevice) -> list[Role]:
"""Convert a device's RoleRecord objects to a list of Role enums."""
return [Role(rr.name) for rr in device.roles]

View File

@@ -3,8 +3,12 @@
from __future__ import annotations
import logging
from dataclasses import dataclass
from os import getenv
from typing import Annotated
from typing import TYPE_CHECKING, Annotated
if TYPE_CHECKING:
from collections.abc import Callable
import typer
from sqlalchemy.orm import Session
@@ -17,186 +21,165 @@ from python.signal_bot.commands.inventory import handle_inventory_update
from python.signal_bot.commands.location import handle_location_request
from python.signal_bot.device_registry import DeviceRegistry
from python.signal_bot.llm_client import LLMClient
from python.signal_bot.models import BotConfig, MessageStatus, SignalMessage
from python.signal_bot.models import BotConfig, MessageStatus, Role, SignalMessage
from python.signal_bot.signal_client import SignalClient
logger = logging.getLogger(__name__)
HELP_TEXT = (
"Available commands:\n"
" inventory <text list> — update van inventory from a text list\n"
" inventory (+ photo) — update van inventory from a receipt photo\n"
" status — show bot status\n"
" location — get current van location\n"
" help — show this help message\n"
"Send a receipt photo with the message 'inventory' to scan it.\n"
)
@dataclass(frozen=True, slots=True)
class Command:
"""A registered bot command."""
action: Callable[[SignalMessage, str], None]
help_text: str
role: Role | None # None = no role required (always allowed)
def help_action(
signal: SignalClient,
message: SignalMessage,
_llm: LLMClient,
_registry: DeviceRegistry,
_config: BotConfig,
_cmd: str,
) -> None:
"""Return the help text for the bot."""
signal.reply(message, HELP_TEXT)
class Bot:
"""Holds shared resources and dispatches incoming messages to command handlers."""
def __init__(
self,
signal: SignalClient,
llm: LLMClient,
registry: DeviceRegistry,
config: BotConfig,
) -> None:
self.signal = signal
self.llm = llm
self.registry = registry
self.config = config
self.commands: dict[str, Command] = {
"help": Command(action=self._help, help_text="show this help message", role=None),
"status": Command(action=self._status, help_text="show bot status", role=Role.STATUS),
"inventory": Command(
action=self._inventory,
help_text="update van inventory from a text list or receipt photo",
role=Role.INVENTORY,
),
"location": Command(
action=self._location,
help_text="get current van location",
role=Role.LOCATION,
),
}
def status_action(
signal: SignalClient,
message: SignalMessage,
llm: LLMClient,
registry: DeviceRegistry,
_config: BotConfig,
_cmd: str,
) -> None:
"""Return the status of the bot."""
models = llm.list_models()
model_list = ", ".join(models[:10])
device_count = len(registry.list_devices())
signal.reply(
message,
f"Bot online.\nLLM: {llm.model}\nAvailable models: {model_list}\nKnown devices: {device_count}",
)
# -- actions --------------------------------------------------------------
def _help(self, message: SignalMessage, _cmd: str) -> None:
"""Return help text filtered to the sender's roles."""
self.signal.reply(message, self._build_help(self.registry.get_roles(message.source)))
def unknown_action(
signal: SignalClient,
message: SignalMessage,
_llm: LLMClient,
_registry: DeviceRegistry,
_config: BotConfig,
cmd: str,
) -> None:
"""Return an error message for an unknown command."""
signal.reply(message, f"Unknown command: {cmd}\n\n{HELP_TEXT}")
def inventory_action(
signal: SignalClient,
message: SignalMessage,
llm: LLMClient,
_registry: DeviceRegistry,
config: BotConfig,
_cmd: str,
) -> None:
"""Process an inventory update."""
handle_inventory_update(message, signal, llm, config.inventory_api_url)
def location_action(
signal: SignalClient,
message: SignalMessage,
_llm: LLMClient,
_registry: DeviceRegistry,
config: BotConfig,
_cmd: str,
) -> None:
"""Reply with current van location."""
handle_location_request(message, signal, config.ha_url, config.ha_token, config.ha_location_entity)
def dispatch(
message: SignalMessage,
signal: SignalClient,
llm: LLMClient,
registry: DeviceRegistry,
config: BotConfig,
) -> None:
"""Route an incoming message to the right command handler."""
source = message.source
if not registry.is_verified(source) or not registry.has_safety_number(source):
logger.info(f"Device {source} not verified, ignoring message")
return
text = message.message.strip()
parts = text.split()
if not parts and not message.attachments:
return
cmd = parts[0].lower() if parts else ""
commands = {
"help": help_action,
"status": status_action,
"inventory": inventory_action,
"location": location_action,
}
logger.info(f"f{source=} running {cmd=} with {message=}")
action = commands.get(cmd)
if action is None:
if message.attachments:
action = inventory_action
cmd = "inventory"
else:
return
action(signal, message, llm, registry, config, cmd)
def _process_message(
message: SignalMessage,
signal: SignalClient,
llm: LLMClient,
registry: DeviceRegistry,
config: BotConfig,
) -> None:
"""Process a single message, sending it to the dead letter queue after repeated failures."""
max_attempts = config.max_message_attempts
for attempt in range(1, max_attempts + 1):
try:
safety_number = signal.get_safety_number(message.source)
registry.record_contact(message.source, safety_number)
dispatch(message, signal, llm, registry, config)
except Exception:
logger.exception(f"Failed to process message (attempt {attempt}/{max_attempts})")
else:
return
logger.error(f"Message from {message.source} failed {max_attempts} times, sending to dead letter queue")
with Session(config.engine) as session:
session.add(
DeadLetterMessage(
source=message.source,
message=message.message,
received_at=utcnow(),
status=MessageStatus.UNPROCESSED,
)
def _status(self, message: SignalMessage, _cmd: str) -> None:
"""Return the status of the bot."""
models = self.llm.list_models()
model_list = ", ".join(models[:10])
device_count = len(self.registry.list_devices())
self.signal.reply(
message,
f"Bot online.\nLLM: {self.llm.model}\nAvailable models: {model_list}\nKnown devices: {device_count}",
)
session.commit()
def _inventory(self, message: SignalMessage, _cmd: str) -> None:
"""Process an inventory update."""
handle_inventory_update(message, self.signal, self.llm, self.config.inventory_api_url)
def run_loop(
config: BotConfig,
signal: SignalClient,
llm: LLMClient,
registry: DeviceRegistry,
) -> None:
"""Listen for messages via WebSocket, reconnecting on failure."""
logger.info("Bot started — listening via WebSocket")
def _location(self, message: SignalMessage, _cmd: str) -> None:
"""Reply with current van location."""
handle_location_request(
message, self.signal, self.config.ha_url, self.config.ha_token, self.config.ha_location_entity
)
@retry(
stop=stop_after_attempt(config.max_retries),
wait=wait_exponential(multiplier=config.reconnect_delay, max=config.max_reconnect_delay),
before_sleep=before_sleep_log(logger, logging.WARNING),
reraise=True,
)
def _listen() -> None:
for message in signal.listen():
logger.info(f"Message from {message.source}: {message.message[:80]}")
_process_message(message, signal, llm, registry, config)
# -- dispatch -------------------------------------------------------------
try:
_listen()
except Exception:
logger.critical("Max retries exceeded, shutting down")
raise
def _build_help(self, roles: list[Role]) -> str:
"""Build help text showing only the commands the user can access."""
is_admin = Role.ADMIN in roles
lines = ["Available commands:"]
for name, cmd in self.commands.items():
if cmd.role is None or is_admin or cmd.role in roles:
lines.append(f" {name:20s}{cmd.help_text}")
return "\n".join(lines)
def dispatch(self, message: SignalMessage) -> None:
"""Route an incoming message to the right command handler."""
source = message.source
if not self.registry.is_verified(source) or not self.registry.has_safety_number(source):
logger.info(f"Device {source} not verified, ignoring message")
return
text = message.message.strip()
parts = text.split()
if not parts and not message.attachments:
return
cmd = parts[0].lower() if parts else ""
logger.info(f"f{source=} running {cmd=} with {message=}")
command = self.commands.get(cmd)
if command is None:
if message.attachments:
command = self.commands["inventory"]
cmd = "inventory"
else:
return
if command.role is not None and not self.registry.has_role(source, command.role):
logger.warning(f"Device {source} denied access to {cmd!r}")
self.signal.reply(message, f"Permission denied: you do not have the '{command.role}' role.")
return
command.action(message, cmd)
def process_message(self, message: SignalMessage) -> None:
"""Process a single message, sending it to the dead letter queue after repeated failures."""
max_attempts = self.config.max_message_attempts
for attempt in range(1, max_attempts + 1):
try:
safety_number = self.signal.get_safety_number(message.source)
self.registry.record_contact(message.source, safety_number)
self.dispatch(message)
except Exception:
logger.exception(f"Failed to process message (attempt {attempt}/{max_attempts})")
else:
return
logger.error(f"Message from {message.source} failed {max_attempts} times, sending to dead letter queue")
with Session(self.config.engine) as session:
session.add(
DeadLetterMessage(
source=message.source,
message=message.message,
received_at=utcnow(),
status=MessageStatus.UNPROCESSED,
)
)
session.commit()
def run(self) -> None:
"""Listen for messages via WebSocket, reconnecting on failure."""
logger.info("Bot started — listening via WebSocket")
@retry(
stop=stop_after_attempt(self.config.max_retries),
wait=wait_exponential(multiplier=self.config.reconnect_delay, max=self.config.max_reconnect_delay),
before_sleep=before_sleep_log(logger, logging.WARNING),
reraise=True,
)
def _listen() -> None:
for message in self.signal.listen():
logger.info(f"Message from {message.source}: {message.message[:80]}")
self.process_message(message)
try:
_listen()
except Exception:
logger.critical("Max retries exceeded, shutting down")
raise
def main(
@@ -242,7 +225,8 @@ def main(
LLMClient(model=llm_model, host=llm_host, port=llm_port, timeout=llm_timeout) as llm,
):
registry = DeviceRegistry(signal, engine)
run_loop(config, signal, llm, registry)
bot = Bot(signal, llm, registry, config)
bot.run()
if __name__ == "__main__":

View File

@@ -18,6 +18,15 @@ class TrustLevel(StrEnum):
BLOCKED = "blocked"
class Role(StrEnum):
"""RBAC roles — one per command, plus admin which grants all."""
ADMIN = "admin"
STATUS = "status"
INVENTORY = "inventory"
LOCATION = "location"
class MessageStatus(StrEnum):
"""Dead letter queue message status."""