diff --git a/python/signal_bot/device_registry.py b/python/signal_bot/device_registry.py index 48c96de..55970c4 100644 --- a/python/signal_bot/device_registry.py +++ b/python/signal_bot/device_registry.py @@ -6,7 +6,7 @@ import logging from datetime import datetime, timedelta from typing import TYPE_CHECKING, NamedTuple -from sqlalchemy import select +from sqlalchemy import delete, select from sqlalchemy.orm import Session from python.common import utcnow @@ -150,7 +150,7 @@ class DeviceRegistry: logger.warning(f"Cannot grant role for unknown device: {phone_number}") return False - if any(rr.name == role for rr in device.roles): + if any(record.name == role for record in device.roles): return True role_record = session.execute( @@ -179,7 +179,7 @@ class DeviceRegistry: logger.warning(f"Cannot revoke role for unknown device: {phone_number}") return False - device.roles = [rr for rr in device.roles if rr.name != role] + device.roles = [record for record in device.roles if record.name != role] session.commit() self._update_cache(phone_number, device) logger.info(f"Device {phone_number} revoked role {role}") @@ -197,7 +197,7 @@ class DeviceRegistry: logger.warning(f"Cannot set roles for unknown device: {phone_number}") return False - role_names = [str(r) for r in roles] + role_names = [str(role) for role in roles] records = list( session.execute(select(RoleRecord).where(RoleRecord.name.in_(role_names))).scalars().all() ) @@ -272,4 +272,26 @@ class DeviceRegistry: def _extract_roles(device: SignalDevice) -> list[Role]: """Convert a device's RoleRecord objects to a list of Role enums.""" - return [Role(rr.name) for rr in device.roles] + return [Role(record.name) for record in device.roles] + + +def sync_roles(engine: Engine) -> None: + """Sync the Role enum to the role table, adding new and removing stale entries.""" + expected = {role.value for role in Role} + + with Session(engine) as session: + existing = {record.name for record in session.execute(select(RoleRecord)).scalars().all()} + + to_add = expected - existing + to_remove = existing - expected + + for name in to_add: + session.add(RoleRecord(name=name)) + logger.info(f"Role added: {name}") + + if to_remove: + session.execute(delete(RoleRecord).where(RoleRecord.name.in_(to_remove))) + for name in to_remove: + logger.info(f"Role removed: {name}") + + session.commit() diff --git a/python/signal_bot/main.py b/python/signal_bot/main.py index 08e98d9..133e8d3 100644 --- a/python/signal_bot/main.py +++ b/python/signal_bot/main.py @@ -19,7 +19,7 @@ from python.orm.common import get_postgres_engine from python.orm.richie.dead_letter_message import DeadLetterMessage from python.signal_bot.commands.inventory import handle_inventory_update from python.signal_bot.commands.location import handle_location_request -from python.signal_bot.device_registry import DeviceRegistry +from python.signal_bot.device_registry import DeviceRegistry, sync_roles from python.signal_bot.llm_client import LLMClient from python.signal_bot.models import BotConfig, MessageStatus, Role, SignalMessage from python.signal_bot.signal_client import SignalClient @@ -203,6 +203,7 @@ def main( raise ValueError(error) engine = get_postgres_engine(name="SIGNALBOT") + sync_roles(engine) config = BotConfig( signal_api_url=signal_api_url, phone_number=phone_number,