diff --git a/python/alembic/richie/versions/2026_03_16-adding_roles_to_signal_devices_2ef7ba690159.py b/python/alembic/richie/versions/2026_03_16-adding_roles_to_signal_devices_2ef7ba690159.py new file mode 100644 index 0000000..902e1b1 --- /dev/null +++ b/python/alembic/richie/versions/2026_03_16-adding_roles_to_signal_devices_2ef7ba690159.py @@ -0,0 +1,66 @@ +"""adding roles to signal devices. + +Revision ID: 2ef7ba690159 +Revises: a1b2c3d4e5f6 +Create Date: 2026-03-16 19:22:38.020350 + +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import sqlalchemy as sa +from alembic import op + +from python.orm import RichieBase + +if TYPE_CHECKING: + from collections.abc import Sequence + +# revision identifiers, used by Alembic. +revision: str = "2ef7ba690159" +down_revision: str | None = "a1b2c3d4e5f6" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + +schema = RichieBase.schema_name + + +def upgrade() -> None: + """Upgrade.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "role", + sa.Column("name", sa.String(length=50), nullable=False), + sa.Column("id", sa.SmallInteger(), nullable=False), + sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.PrimaryKeyConstraint("id", name=op.f("pk_role")), + sa.UniqueConstraint("name", name=op.f("uq_role_name")), + schema=schema, + ) + op.create_table( + "device_role", + sa.Column("device_id", sa.Integer(), nullable=False), + sa.Column("role_id", sa.SmallInteger(), nullable=False), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.ForeignKeyConstraint( + ["device_id"], [f"{schema}.signal_device.id"], name=op.f("fk_device_role_device_id_signal_device") + ), + sa.ForeignKeyConstraint(["role_id"], [f"{schema}.role.id"], name=op.f("fk_device_role_role_id_role")), + sa.PrimaryKeyConstraint("id", name=op.f("pk_device_role")), + sa.UniqueConstraint("device_id", "role_id", name="uq_device_role_device_role"), + schema=schema, + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("device_role", schema=schema) + op.drop_table("role", schema=schema) + # ### end Alembic commands ### diff --git a/python/orm/richie/__init__.py b/python/orm/richie/__init__.py index 6448952..85806c7 100644 --- a/python/orm/richie/__init__.py +++ b/python/orm/richie/__init__.py @@ -2,7 +2,7 @@ from __future__ import annotations -from python.orm.richie.base import RichieBase, TableBase +from python.orm.richie.base import RichieBase, TableBase, TableBaseBig, TableBaseSmall from python.orm.richie.congress import Bill, Legislator, Vote, VoteRecord from python.orm.richie.contact import ( Contact, @@ -12,7 +12,7 @@ from python.orm.richie.contact import ( RelationshipType, ) from python.orm.richie.dead_letter_message import DeadLetterMessage -from python.orm.richie.signal_device import SignalDevice +from python.orm.richie.signal_device import DeviceRole, RoleRecord, SignalDevice __all__ = [ "Bill", @@ -20,12 +20,16 @@ __all__ = [ "ContactNeed", "ContactRelationship", "DeadLetterMessage", + "DeviceRole", + "RoleRecord", "Legislator", "Need", "RelationshipType", "RichieBase", "SignalDevice", "TableBase", + "TableBaseBig", + "TableBaseSmall", "Vote", "VoteRecord", ] diff --git a/python/orm/richie/base.py b/python/orm/richie/base.py index 20b7231..9810726 100644 --- a/python/orm/richie/base.py +++ b/python/orm/richie/base.py @@ -4,7 +4,7 @@ from __future__ import annotations from datetime import datetime -from sqlalchemy import DateTime, MetaData, func +from sqlalchemy import BigInteger, DateTime, MetaData, SmallInteger, func from sqlalchemy.ext.declarative import AbstractConcreteBase from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column @@ -22,12 +22,9 @@ class RichieBase(DeclarativeBase): ) -class TableBase(AbstractConcreteBase, RichieBase): - """Abstract concrete base for richie tables with IDs and timestamps.""" +class _TableMixin: + """Shared timestamp columns for all table bases.""" - __abstract__ = True - - id: Mapped[int] = mapped_column(primary_key=True) created: Mapped[datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), @@ -37,3 +34,27 @@ class TableBase(AbstractConcreteBase, RichieBase): server_default=func.now(), onupdate=func.now(), ) + + +class TableBaseSmall(_TableMixin, AbstractConcreteBase, RichieBase): + """Table with SmallInteger primary key.""" + + __abstract__ = True + + id: Mapped[int] = mapped_column(SmallInteger, primary_key=True) + + +class TableBase(_TableMixin, AbstractConcreteBase, RichieBase): + """Table with Integer primary key.""" + + __abstract__ = True + + id: Mapped[int] = mapped_column(primary_key=True) + + +class TableBaseBig(_TableMixin, AbstractConcreteBase, RichieBase): + """Table with BigInteger primary key.""" + + __abstract__ = True + + id: Mapped[int] = mapped_column(BigInteger, primary_key=True) diff --git a/python/orm/richie/signal_device.py b/python/orm/richie/signal_device.py index c36a40a..ea62bc4 100644 --- a/python/orm/richie/signal_device.py +++ b/python/orm/richie/signal_device.py @@ -1,17 +1,38 @@ -"""Signal bot device registry models.""" +"""Signal bot device and role ORM models.""" from __future__ import annotations from datetime import datetime -from sqlalchemy import DateTime, String +from sqlalchemy import DateTime, ForeignKey, SmallInteger, String, UniqueConstraint from sqlalchemy.dialects.postgresql import ENUM -from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.orm import Mapped, mapped_column, relationship -from python.orm.richie.base import TableBase +from python.orm.richie.base import TableBase, TableBaseSmall from python.signal_bot.models import TrustLevel +class RoleRecord(TableBaseSmall): + """Lookup table for RBAC roles, keyed by smallint.""" + + __tablename__ = "role" + + name: Mapped[str] = mapped_column(String(50), unique=True) + + +class DeviceRole(TableBase): + """Association between a device and a role.""" + + __tablename__ = "device_role" + __table_args__ = ( + UniqueConstraint("device_id", "role_id", name="uq_device_role_device_role"), + {"schema": "main"}, + ) + + device_id: Mapped[int] = mapped_column(ForeignKey("main.signal_device.id")) + role_id: Mapped[int] = mapped_column(SmallInteger, ForeignKey("main.role.id")) + + class SignalDevice(TableBase): """A Signal device tracked by phone number and safety number.""" @@ -24,3 +45,5 @@ class SignalDevice(TableBase): default=TrustLevel.UNVERIFIED, ) last_seen: Mapped[datetime] = mapped_column(DateTime(timezone=True)) + + roles: Mapped[list[RoleRecord]] = relationship(secondary=DeviceRole.__table__) diff --git a/python/signal_bot/device_registry.py b/python/signal_bot/device_registry.py index f6445be..48c96de 100644 --- a/python/signal_bot/device_registry.py +++ b/python/signal_bot/device_registry.py @@ -10,8 +10,8 @@ from sqlalchemy import select from sqlalchemy.orm import Session from python.common import utcnow -from python.orm.richie.signal_device import SignalDevice -from python.signal_bot.models import TrustLevel +from python.orm.richie.signal_device import RoleRecord, SignalDevice +from python.signal_bot.models import Role, TrustLevel if TYPE_CHECKING: from sqlalchemy.engine import Engine @@ -29,6 +29,7 @@ class _CacheEntry(NamedTuple): trust_level: TrustLevel has_safety_number: bool safety_number: str | None + roles: list[Role] class DeviceRegistry: @@ -50,7 +51,7 @@ class DeviceRegistry: """Check if a phone number is verified.""" if entry := self._cached(phone_number): return entry.trust_level == TrustLevel.VERIFIED - device = self.get_device(phone_number) + device = self._load_device(phone_number) return device is not None and device.trust_level == TrustLevel.VERIFIED def record_contact(self, phone_number: str, safety_number: str | None = None) -> None: @@ -62,9 +63,10 @@ class DeviceRegistry: return with Session(self.engine) as session: - device = session.execute( - select(SignalDevice).where(SignalDevice.phone_number == phone_number) - ).scalar_one_or_none() + device = ( + session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number)) + .scalar_one_or_none() + ) if device: if device.safety_number != safety_number and device.trust_level != TrustLevel.BLOCKED: @@ -83,20 +85,13 @@ class DeviceRegistry: logger.info(f"New device registered: {phone_number}") session.commit() - - ttl = _BLOCKED_TTL if device.trust_level == TrustLevel.BLOCKED else _DEFAULT_TTL - self._contact_cache[phone_number] = _CacheEntry( - expires=now + ttl, - trust_level=device.trust_level, - has_safety_number=device.safety_number is not None, - safety_number=device.safety_number, - ) + self._update_cache(phone_number, device) def has_safety_number(self, phone_number: str) -> bool: """Check if a device has a safety number on file.""" if entry := self._cached(phone_number): return entry.has_safety_number - device = self.get_device(phone_number) + device = self._load_device(phone_number) return device is not None and device.safety_number is not None def verify(self, phone_number: str) -> bool: @@ -105,9 +100,10 @@ class DeviceRegistry: Returns True if the device was found and verified. """ with Session(self.engine) as session: - device = session.execute( - select(SignalDevice).where(SignalDevice.phone_number == phone_number) - ).scalar_one_or_none() + device = ( + session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number)) + .scalar_one_or_none() + ) if not device: logger.warning(f"Cannot verify unknown device: {phone_number}") @@ -116,12 +112,7 @@ class DeviceRegistry: device.trust_level = TrustLevel.VERIFIED self.signal_client.trust_identity(phone_number, trust_all_known_keys=True) session.commit() - self._contact_cache[phone_number] = _CacheEntry( - expires=utcnow() + _DEFAULT_TTL, - trust_level=TrustLevel.VERIFIED, - has_safety_number=device.safety_number is not None, - safety_number=device.safety_number, - ) + self._update_cache(phone_number, device) logger.info(f"Device verified: {phone_number}") return True @@ -133,6 +124,91 @@ class DeviceRegistry: """Reset a device to unverified.""" return self._set_trust(phone_number, TrustLevel.UNVERIFIED) + # -- role management ------------------------------------------------------ + + def get_roles(self, phone_number: str) -> list[Role]: + """Return the roles for a device, defaulting to empty.""" + if entry := self._cached(phone_number): + return entry.roles + device = self._load_device(phone_number) + return _extract_roles(device) if device else [] + + def has_role(self, phone_number: str, role: Role) -> bool: + """Check if a device has a specific role or is admin.""" + roles = self.get_roles(phone_number) + return Role.ADMIN in roles or role in roles + + def grant_role(self, phone_number: str, role: Role) -> bool: + """Add a role to a device. Called by admin over SSH.""" + with Session(self.engine) as session: + device = ( + session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number)) + .scalar_one_or_none() + ) + + if not device: + logger.warning(f"Cannot grant role for unknown device: {phone_number}") + return False + + if any(rr.name == role for rr in device.roles): + return True + + role_record = session.execute( + select(RoleRecord).where(RoleRecord.name == role) + ).scalar_one_or_none() + + if not role_record: + logger.warning(f"Unknown role: {role}") + return False + + device.roles.append(role_record) + session.commit() + self._update_cache(phone_number, device) + logger.info(f"Device {phone_number} granted role {role}") + return True + + def revoke_role(self, phone_number: str, role: Role) -> bool: + """Remove a role from a device. Called by admin over SSH.""" + with Session(self.engine) as session: + device = ( + session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number)) + .scalar_one_or_none() + ) + + if not device: + logger.warning(f"Cannot revoke role for unknown device: {phone_number}") + return False + + device.roles = [rr for rr in device.roles if rr.name != role] + session.commit() + self._update_cache(phone_number, device) + logger.info(f"Device {phone_number} revoked role {role}") + return True + + def set_roles(self, phone_number: str, roles: list[Role]) -> bool: + """Replace all roles for a device. Called by admin over SSH.""" + with Session(self.engine) as session: + device = ( + session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number)) + .scalar_one_or_none() + ) + + if not device: + logger.warning(f"Cannot set roles for unknown device: {phone_number}") + return False + + role_names = [str(r) for r in roles] + records = list( + session.execute(select(RoleRecord).where(RoleRecord.name.in_(role_names))).scalars().all() + ) + device.roles = records + session.commit() + self._update_cache(phone_number, device) + logger.info(f"Device {phone_number} roles set to {role_names}") + return True + + # -- queries -------------------------------------------------------------- + def list_devices(self) -> list[SignalDevice]: """Return all known devices.""" with Session(self.engine) as session: @@ -147,6 +223,8 @@ class DeviceRegistry: if number: self.record_contact(number, safety) + # -- internals ------------------------------------------------------------ + def _cached(self, phone_number: str) -> _CacheEntry | None: """Return the cache entry if it exists and hasn't expired.""" entry = self._contact_cache.get(phone_number) @@ -154,32 +232,44 @@ class DeviceRegistry: return entry return None - def get_device(self, phone_number: str) -> SignalDevice | None: - """Fetch a device by phone number.""" + def _load_device(self, phone_number: str) -> SignalDevice | None: + """Fetch a device by phone number (with joined roles).""" with Session(self.engine) as session: - return session.execute( - select(SignalDevice).where(SignalDevice.phone_number == phone_number) - ).scalar_one_or_none() + return ( + session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number)) + .scalar_one_or_none() + ) + + def _update_cache(self, phone_number: str, device: SignalDevice) -> None: + """Refresh the cache entry for a device.""" + ttl = _BLOCKED_TTL if device.trust_level == TrustLevel.BLOCKED else _DEFAULT_TTL + self._contact_cache[phone_number] = _CacheEntry( + expires=utcnow() + ttl, + trust_level=device.trust_level, + has_safety_number=device.safety_number is not None, + safety_number=device.safety_number, + roles=_extract_roles(device), + ) def _set_trust(self, phone_number: str, level: str, log_msg: str | None = None) -> bool: """Update the trust level for a device.""" with Session(self.engine) as session: - device = session.execute( - select(SignalDevice).where(SignalDevice.phone_number == phone_number) - ).scalar_one_or_none() + device = ( + session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number)) + .scalar_one_or_none() + ) if not device: return False device.trust_level = level session.commit() - ttl = _BLOCKED_TTL if level == TrustLevel.BLOCKED else _DEFAULT_TTL - self._contact_cache[phone_number] = _CacheEntry( - expires=utcnow() + ttl, - trust_level=level, - has_safety_number=device.safety_number is not None, - safety_number=device.safety_number, - ) + self._update_cache(phone_number, device) if log_msg: logger.info(f"{log_msg}: {phone_number}") return True + + +def _extract_roles(device: SignalDevice) -> list[Role]: + """Convert a device's RoleRecord objects to a list of Role enums.""" + return [Role(rr.name) for rr in device.roles] diff --git a/python/signal_bot/main.py b/python/signal_bot/main.py index 1eca710..08e98d9 100644 --- a/python/signal_bot/main.py +++ b/python/signal_bot/main.py @@ -3,8 +3,12 @@ from __future__ import annotations import logging +from dataclasses import dataclass from os import getenv -from typing import Annotated +from typing import TYPE_CHECKING, Annotated + +if TYPE_CHECKING: + from collections.abc import Callable import typer from sqlalchemy.orm import Session @@ -17,186 +21,165 @@ from python.signal_bot.commands.inventory import handle_inventory_update from python.signal_bot.commands.location import handle_location_request from python.signal_bot.device_registry import DeviceRegistry from python.signal_bot.llm_client import LLMClient -from python.signal_bot.models import BotConfig, MessageStatus, SignalMessage +from python.signal_bot.models import BotConfig, MessageStatus, Role, SignalMessage from python.signal_bot.signal_client import SignalClient logger = logging.getLogger(__name__) -HELP_TEXT = ( - "Available commands:\n" - " inventory — update van inventory from a text list\n" - " inventory (+ photo) — update van inventory from a receipt photo\n" - " status — show bot status\n" - " location — get current van location\n" - " help — show this help message\n" - "Send a receipt photo with the message 'inventory' to scan it.\n" -) +@dataclass(frozen=True, slots=True) +class Command: + """A registered bot command.""" + + action: Callable[[SignalMessage, str], None] + help_text: str + role: Role | None # None = no role required (always allowed) -def help_action( - signal: SignalClient, - message: SignalMessage, - _llm: LLMClient, - _registry: DeviceRegistry, - _config: BotConfig, - _cmd: str, -) -> None: - """Return the help text for the bot.""" - signal.reply(message, HELP_TEXT) +class Bot: + """Holds shared resources and dispatches incoming messages to command handlers.""" + def __init__( + self, + signal: SignalClient, + llm: LLMClient, + registry: DeviceRegistry, + config: BotConfig, + ) -> None: + self.signal = signal + self.llm = llm + self.registry = registry + self.config = config + self.commands: dict[str, Command] = { + "help": Command(action=self._help, help_text="show this help message", role=None), + "status": Command(action=self._status, help_text="show bot status", role=Role.STATUS), + "inventory": Command( + action=self._inventory, + help_text="update van inventory from a text list or receipt photo", + role=Role.INVENTORY, + ), + "location": Command( + action=self._location, + help_text="get current van location", + role=Role.LOCATION, + ), + } -def status_action( - signal: SignalClient, - message: SignalMessage, - llm: LLMClient, - registry: DeviceRegistry, - _config: BotConfig, - _cmd: str, -) -> None: - """Return the status of the bot.""" - models = llm.list_models() - model_list = ", ".join(models[:10]) - device_count = len(registry.list_devices()) - signal.reply( - message, - f"Bot online.\nLLM: {llm.model}\nAvailable models: {model_list}\nKnown devices: {device_count}", - ) + # -- actions -------------------------------------------------------------- + def _help(self, message: SignalMessage, _cmd: str) -> None: + """Return help text filtered to the sender's roles.""" + self.signal.reply(message, self._build_help(self.registry.get_roles(message.source))) -def unknown_action( - signal: SignalClient, - message: SignalMessage, - _llm: LLMClient, - _registry: DeviceRegistry, - _config: BotConfig, - cmd: str, -) -> None: - """Return an error message for an unknown command.""" - signal.reply(message, f"Unknown command: {cmd}\n\n{HELP_TEXT}") - - -def inventory_action( - signal: SignalClient, - message: SignalMessage, - llm: LLMClient, - _registry: DeviceRegistry, - config: BotConfig, - _cmd: str, -) -> None: - """Process an inventory update.""" - handle_inventory_update(message, signal, llm, config.inventory_api_url) - - -def location_action( - signal: SignalClient, - message: SignalMessage, - _llm: LLMClient, - _registry: DeviceRegistry, - config: BotConfig, - _cmd: str, -) -> None: - """Reply with current van location.""" - handle_location_request(message, signal, config.ha_url, config.ha_token, config.ha_location_entity) - - -def dispatch( - message: SignalMessage, - signal: SignalClient, - llm: LLMClient, - registry: DeviceRegistry, - config: BotConfig, -) -> None: - """Route an incoming message to the right command handler.""" - source = message.source - - if not registry.is_verified(source) or not registry.has_safety_number(source): - logger.info(f"Device {source} not verified, ignoring message") - return - - text = message.message.strip() - parts = text.split() - - if not parts and not message.attachments: - return - - cmd = parts[0].lower() if parts else "" - - commands = { - "help": help_action, - "status": status_action, - "inventory": inventory_action, - "location": location_action, - } - logger.info(f"f{source=} running {cmd=} with {message=}") - action = commands.get(cmd) - if action is None: - if message.attachments: - action = inventory_action - cmd = "inventory" - else: - return - - action(signal, message, llm, registry, config, cmd) - - -def _process_message( - message: SignalMessage, - signal: SignalClient, - llm: LLMClient, - registry: DeviceRegistry, - config: BotConfig, -) -> None: - """Process a single message, sending it to the dead letter queue after repeated failures.""" - max_attempts = config.max_message_attempts - for attempt in range(1, max_attempts + 1): - try: - safety_number = signal.get_safety_number(message.source) - registry.record_contact(message.source, safety_number) - dispatch(message, signal, llm, registry, config) - except Exception: - logger.exception(f"Failed to process message (attempt {attempt}/{max_attempts})") - else: - return - - logger.error(f"Message from {message.source} failed {max_attempts} times, sending to dead letter queue") - with Session(config.engine) as session: - session.add( - DeadLetterMessage( - source=message.source, - message=message.message, - received_at=utcnow(), - status=MessageStatus.UNPROCESSED, - ) + def _status(self, message: SignalMessage, _cmd: str) -> None: + """Return the status of the bot.""" + models = self.llm.list_models() + model_list = ", ".join(models[:10]) + device_count = len(self.registry.list_devices()) + self.signal.reply( + message, + f"Bot online.\nLLM: {self.llm.model}\nAvailable models: {model_list}\nKnown devices: {device_count}", ) - session.commit() + def _inventory(self, message: SignalMessage, _cmd: str) -> None: + """Process an inventory update.""" + handle_inventory_update(message, self.signal, self.llm, self.config.inventory_api_url) -def run_loop( - config: BotConfig, - signal: SignalClient, - llm: LLMClient, - registry: DeviceRegistry, -) -> None: - """Listen for messages via WebSocket, reconnecting on failure.""" - logger.info("Bot started — listening via WebSocket") + def _location(self, message: SignalMessage, _cmd: str) -> None: + """Reply with current van location.""" + handle_location_request( + message, self.signal, self.config.ha_url, self.config.ha_token, self.config.ha_location_entity + ) - @retry( - stop=stop_after_attempt(config.max_retries), - wait=wait_exponential(multiplier=config.reconnect_delay, max=config.max_reconnect_delay), - before_sleep=before_sleep_log(logger, logging.WARNING), - reraise=True, - ) - def _listen() -> None: - for message in signal.listen(): - logger.info(f"Message from {message.source}: {message.message[:80]}") - _process_message(message, signal, llm, registry, config) + # -- dispatch ------------------------------------------------------------- - try: - _listen() - except Exception: - logger.critical("Max retries exceeded, shutting down") - raise + def _build_help(self, roles: list[Role]) -> str: + """Build help text showing only the commands the user can access.""" + is_admin = Role.ADMIN in roles + lines = ["Available commands:"] + for name, cmd in self.commands.items(): + if cmd.role is None or is_admin or cmd.role in roles: + lines.append(f" {name:20s} — {cmd.help_text}") + return "\n".join(lines) + + def dispatch(self, message: SignalMessage) -> None: + """Route an incoming message to the right command handler.""" + source = message.source + + if not self.registry.is_verified(source) or not self.registry.has_safety_number(source): + logger.info(f"Device {source} not verified, ignoring message") + return + + text = message.message.strip() + parts = text.split() + + if not parts and not message.attachments: + return + + cmd = parts[0].lower() if parts else "" + + logger.info(f"f{source=} running {cmd=} with {message=}") + + command = self.commands.get(cmd) + if command is None: + if message.attachments: + command = self.commands["inventory"] + cmd = "inventory" + else: + return + + if command.role is not None and not self.registry.has_role(source, command.role): + logger.warning(f"Device {source} denied access to {cmd!r}") + self.signal.reply(message, f"Permission denied: you do not have the '{command.role}' role.") + return + + command.action(message, cmd) + + def process_message(self, message: SignalMessage) -> None: + """Process a single message, sending it to the dead letter queue after repeated failures.""" + max_attempts = self.config.max_message_attempts + for attempt in range(1, max_attempts + 1): + try: + safety_number = self.signal.get_safety_number(message.source) + self.registry.record_contact(message.source, safety_number) + self.dispatch(message) + except Exception: + logger.exception(f"Failed to process message (attempt {attempt}/{max_attempts})") + else: + return + + logger.error(f"Message from {message.source} failed {max_attempts} times, sending to dead letter queue") + with Session(self.config.engine) as session: + session.add( + DeadLetterMessage( + source=message.source, + message=message.message, + received_at=utcnow(), + status=MessageStatus.UNPROCESSED, + ) + ) + session.commit() + + def run(self) -> None: + """Listen for messages via WebSocket, reconnecting on failure.""" + logger.info("Bot started — listening via WebSocket") + + @retry( + stop=stop_after_attempt(self.config.max_retries), + wait=wait_exponential(multiplier=self.config.reconnect_delay, max=self.config.max_reconnect_delay), + before_sleep=before_sleep_log(logger, logging.WARNING), + reraise=True, + ) + def _listen() -> None: + for message in self.signal.listen(): + logger.info(f"Message from {message.source}: {message.message[:80]}") + self.process_message(message) + + try: + _listen() + except Exception: + logger.critical("Max retries exceeded, shutting down") + raise def main( @@ -242,7 +225,8 @@ def main( LLMClient(model=llm_model, host=llm_host, port=llm_port, timeout=llm_timeout) as llm, ): registry = DeviceRegistry(signal, engine) - run_loop(config, signal, llm, registry) + bot = Bot(signal, llm, registry, config) + bot.run() if __name__ == "__main__": diff --git a/python/signal_bot/models.py b/python/signal_bot/models.py index 0baddad..1c9a0be 100644 --- a/python/signal_bot/models.py +++ b/python/signal_bot/models.py @@ -18,6 +18,15 @@ class TrustLevel(StrEnum): BLOCKED = "blocked" +class Role(StrEnum): + """RBAC roles — one per command, plus admin which grants all.""" + + ADMIN = "admin" + STATUS = "status" + INVENTORY = "inventory" + LOCATION = "location" + + class MessageStatus(StrEnum): """Dead letter queue message status."""