fixed tests and treeftm

This commit is contained in:
2026-03-18 19:19:08 -04:00
parent cc73dfc467
commit 3f4373d1f6
8 changed files with 115 additions and 83 deletions

View File

@@ -63,10 +63,9 @@ class DeviceRegistry:
return
with Session(self.engine) as session:
device = (
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
.scalar_one_or_none()
)
device = session.execute(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
if device:
if device.safety_number != safety_number and device.trust_level != TrustLevel.BLOCKED:
@@ -100,10 +99,9 @@ class DeviceRegistry:
Returns True if the device was found and verified.
"""
with Session(self.engine) as session:
device = (
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
.scalar_one_or_none()
)
device = session.execute(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
if not device:
logger.warning(f"Cannot verify unknown device: {phone_number}")
@@ -141,10 +139,9 @@ class DeviceRegistry:
def grant_role(self, phone_number: str, role: Role) -> bool:
"""Add a role to a device. Called by admin over SSH."""
with Session(self.engine) as session:
device = (
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
.scalar_one_or_none()
)
device = session.execute(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
if not device:
logger.warning(f"Cannot grant role for unknown device: {phone_number}")
@@ -153,9 +150,7 @@ class DeviceRegistry:
if any(record.name == role for record in device.roles):
return True
role_record = session.execute(
select(RoleRecord).where(RoleRecord.name == role)
).scalar_one_or_none()
role_record = session.execute(select(RoleRecord).where(RoleRecord.name == role)).scalar_one_or_none()
if not role_record:
logger.warning(f"Unknown role: {role}")
@@ -170,10 +165,9 @@ class DeviceRegistry:
def revoke_role(self, phone_number: str, role: Role) -> bool:
"""Remove a role from a device. Called by admin over SSH."""
with Session(self.engine) as session:
device = (
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
.scalar_one_or_none()
)
device = session.execute(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
if not device:
logger.warning(f"Cannot revoke role for unknown device: {phone_number}")
@@ -188,19 +182,16 @@ class DeviceRegistry:
def set_roles(self, phone_number: str, roles: list[Role]) -> bool:
"""Replace all roles for a device. Called by admin over SSH."""
with Session(self.engine) as session:
device = (
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
.scalar_one_or_none()
)
device = session.execute(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
if not device:
logger.warning(f"Cannot set roles for unknown device: {phone_number}")
return False
role_names = [str(role) for role in roles]
records = list(
session.execute(select(RoleRecord).where(RoleRecord.name.in_(role_names))).scalars().all()
)
records = list(session.execute(select(RoleRecord).where(RoleRecord.name.in_(role_names))).scalars().all())
device.roles = records
session.commit()
self._update_cache(phone_number, device)
@@ -235,10 +226,9 @@ class DeviceRegistry:
def _load_device(self, phone_number: str) -> SignalDevice | None:
"""Fetch a device by phone number (with joined roles)."""
with Session(self.engine) as session:
return (
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
.scalar_one_or_none()
)
return session.execute(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
def _update_cache(self, phone_number: str, device: SignalDevice) -> None:
"""Refresh the cache entry for a device."""
@@ -254,10 +244,9 @@ class DeviceRegistry:
def _set_trust(self, phone_number: str, level: str, log_msg: str | None = None) -> bool:
"""Update the trust level for a device."""
with Session(self.engine) as session:
device = (
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
.scalar_one_or_none()
)
device = session.execute(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
if not device:
return False

View File

@@ -106,10 +106,14 @@ class Bot:
"""Route an incoming message to the right command handler."""
source = message.source
if not self.registry.is_verified(source) or not self.registry.has_safety_number(source):
if not self.registry.is_verified(source):
logger.info(f"Device {source} not verified, ignoring message")
return
if not self.registry.has_safety_number(source) and self.registry.has_role(source, Role.ADMIN):
logger.warning(f"Admin device {source} missing safety number, ignoring message")
return
text = message.message.strip()
parts = text.split()