added sync_roles

This commit is contained in:
2026-03-17 09:58:20 -04:00
parent 7d2fbaea43
commit 1b3e6725ea
2 changed files with 29 additions and 6 deletions

View File

@@ -6,7 +6,7 @@ import logging
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, NamedTuple
from sqlalchemy import select
from sqlalchemy import delete, select
from sqlalchemy.orm import Session
from python.common import utcnow
@@ -150,7 +150,7 @@ class DeviceRegistry:
logger.warning(f"Cannot grant role for unknown device: {phone_number}")
return False
if any(rr.name == role for rr in device.roles):
if any(record.name == role for record in device.roles):
return True
role_record = session.execute(
@@ -179,7 +179,7 @@ class DeviceRegistry:
logger.warning(f"Cannot revoke role for unknown device: {phone_number}")
return False
device.roles = [rr for rr in device.roles if rr.name != role]
device.roles = [record for record in device.roles if record.name != role]
session.commit()
self._update_cache(phone_number, device)
logger.info(f"Device {phone_number} revoked role {role}")
@@ -197,7 +197,7 @@ class DeviceRegistry:
logger.warning(f"Cannot set roles for unknown device: {phone_number}")
return False
role_names = [str(r) for r in roles]
role_names = [str(role) for role in roles]
records = list(
session.execute(select(RoleRecord).where(RoleRecord.name.in_(role_names))).scalars().all()
)
@@ -272,4 +272,26 @@ class DeviceRegistry:
def _extract_roles(device: SignalDevice) -> list[Role]:
"""Convert a device's RoleRecord objects to a list of Role enums."""
return [Role(rr.name) for rr in device.roles]
return [Role(record.name) for record in device.roles]
def sync_roles(engine: Engine) -> None:
"""Sync the Role enum to the role table, adding new and removing stale entries."""
expected = {role.value for role in Role}
with Session(engine) as session:
existing = {record.name for record in session.execute(select(RoleRecord)).scalars().all()}
to_add = expected - existing
to_remove = existing - expected
for name in to_add:
session.add(RoleRecord(name=name))
logger.info(f"Role added: {name}")
if to_remove:
session.execute(delete(RoleRecord).where(RoleRecord.name.in_(to_remove)))
for name in to_remove:
logger.info(f"Role removed: {name}")
session.commit()

View File

@@ -19,7 +19,7 @@ from python.orm.common import get_postgres_engine
from python.orm.richie.dead_letter_message import DeadLetterMessage
from python.signal_bot.commands.inventory import handle_inventory_update
from python.signal_bot.commands.location import handle_location_request
from python.signal_bot.device_registry import DeviceRegistry
from python.signal_bot.device_registry import DeviceRegistry, sync_roles
from python.signal_bot.llm_client import LLMClient
from python.signal_bot.models import BotConfig, MessageStatus, Role, SignalMessage
from python.signal_bot.signal_client import SignalClient
@@ -203,6 +203,7 @@ def main(
raise ValueError(error)
engine = get_postgres_engine(name="SIGNALBOT")
sync_roles(engine)
config = BotConfig(
signal_api_url=signal_api_url,
phone_number=phone_number,