mirror of
https://github.com/RichieCahill/dotfiles.git
synced 2026-04-17 04:58:19 -04:00
added sync_roles
This commit is contained in:
@@ -6,7 +6,7 @@ import logging
|
|||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import TYPE_CHECKING, NamedTuple
|
from typing import TYPE_CHECKING, NamedTuple
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import delete, select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from python.common import utcnow
|
from python.common import utcnow
|
||||||
@@ -150,7 +150,7 @@ class DeviceRegistry:
|
|||||||
logger.warning(f"Cannot grant role for unknown device: {phone_number}")
|
logger.warning(f"Cannot grant role for unknown device: {phone_number}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if any(rr.name == role for rr in device.roles):
|
if any(record.name == role for record in device.roles):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
role_record = session.execute(
|
role_record = session.execute(
|
||||||
@@ -179,7 +179,7 @@ class DeviceRegistry:
|
|||||||
logger.warning(f"Cannot revoke role for unknown device: {phone_number}")
|
logger.warning(f"Cannot revoke role for unknown device: {phone_number}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
device.roles = [rr for rr in device.roles if rr.name != role]
|
device.roles = [record for record in device.roles if record.name != role]
|
||||||
session.commit()
|
session.commit()
|
||||||
self._update_cache(phone_number, device)
|
self._update_cache(phone_number, device)
|
||||||
logger.info(f"Device {phone_number} revoked role {role}")
|
logger.info(f"Device {phone_number} revoked role {role}")
|
||||||
@@ -197,7 +197,7 @@ class DeviceRegistry:
|
|||||||
logger.warning(f"Cannot set roles for unknown device: {phone_number}")
|
logger.warning(f"Cannot set roles for unknown device: {phone_number}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
role_names = [str(r) for r in roles]
|
role_names = [str(role) for role in roles]
|
||||||
records = list(
|
records = list(
|
||||||
session.execute(select(RoleRecord).where(RoleRecord.name.in_(role_names))).scalars().all()
|
session.execute(select(RoleRecord).where(RoleRecord.name.in_(role_names))).scalars().all()
|
||||||
)
|
)
|
||||||
@@ -272,4 +272,26 @@ class DeviceRegistry:
|
|||||||
|
|
||||||
def _extract_roles(device: SignalDevice) -> list[Role]:
|
def _extract_roles(device: SignalDevice) -> list[Role]:
|
||||||
"""Convert a device's RoleRecord objects to a list of Role enums."""
|
"""Convert a device's RoleRecord objects to a list of Role enums."""
|
||||||
return [Role(rr.name) for rr in device.roles]
|
return [Role(record.name) for record in device.roles]
|
||||||
|
|
||||||
|
|
||||||
|
def sync_roles(engine: Engine) -> None:
|
||||||
|
"""Sync the Role enum to the role table, adding new and removing stale entries."""
|
||||||
|
expected = {role.value for role in Role}
|
||||||
|
|
||||||
|
with Session(engine) as session:
|
||||||
|
existing = {record.name for record in session.execute(select(RoleRecord)).scalars().all()}
|
||||||
|
|
||||||
|
to_add = expected - existing
|
||||||
|
to_remove = existing - expected
|
||||||
|
|
||||||
|
for name in to_add:
|
||||||
|
session.add(RoleRecord(name=name))
|
||||||
|
logger.info(f"Role added: {name}")
|
||||||
|
|
||||||
|
if to_remove:
|
||||||
|
session.execute(delete(RoleRecord).where(RoleRecord.name.in_(to_remove)))
|
||||||
|
for name in to_remove:
|
||||||
|
logger.info(f"Role removed: {name}")
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from python.orm.common import get_postgres_engine
|
|||||||
from python.orm.richie.dead_letter_message import DeadLetterMessage
|
from python.orm.richie.dead_letter_message import DeadLetterMessage
|
||||||
from python.signal_bot.commands.inventory import handle_inventory_update
|
from python.signal_bot.commands.inventory import handle_inventory_update
|
||||||
from python.signal_bot.commands.location import handle_location_request
|
from python.signal_bot.commands.location import handle_location_request
|
||||||
from python.signal_bot.device_registry import DeviceRegistry
|
from python.signal_bot.device_registry import DeviceRegistry, sync_roles
|
||||||
from python.signal_bot.llm_client import LLMClient
|
from python.signal_bot.llm_client import LLMClient
|
||||||
from python.signal_bot.models import BotConfig, MessageStatus, Role, SignalMessage
|
from python.signal_bot.models import BotConfig, MessageStatus, Role, SignalMessage
|
||||||
from python.signal_bot.signal_client import SignalClient
|
from python.signal_bot.signal_client import SignalClient
|
||||||
@@ -203,6 +203,7 @@ def main(
|
|||||||
raise ValueError(error)
|
raise ValueError(error)
|
||||||
|
|
||||||
engine = get_postgres_engine(name="SIGNALBOT")
|
engine = get_postgres_engine(name="SIGNALBOT")
|
||||||
|
sync_roles(engine)
|
||||||
config = BotConfig(
|
config = BotConfig(
|
||||||
signal_api_url=signal_api_url,
|
signal_api_url=signal_api_url,
|
||||||
phone_number=phone_number,
|
phone_number=phone_number,
|
||||||
|
|||||||
Reference in New Issue
Block a user