fixed tests and treeftm

This commit is contained in:
2026-03-18 19:19:08 -04:00
parent cc73dfc467
commit 3f4373d1f6
8 changed files with 115 additions and 83 deletions

View File

@@ -76,7 +76,7 @@ def downgrade() -> None:
sa.Column(
"id",
sa.SMALLINT(),
server_default=sa.text("nextval(f'{schema}.role_id_seq'::regclass)"),
server_default=sa.text(f"nextval('{schema}.role_id_seq'::regclass)"),
autoincrement=True,
nullable=False,
),

View File

@@ -11,6 +11,7 @@ from python.orm.richie.contact import (
Need,
RelationshipType,
)
__all__ = [
"Bill",
"Contact",

View File

@@ -4,8 +4,7 @@ from __future__ import annotations
from datetime import datetime
from sqlalchemy import DateTime, ForeignKey, SmallInteger, String, Text, UniqueConstraint
from sqlalchemy.dialects.postgresql import ENUM
from sqlalchemy import DateTime, Enum, ForeignKey, SmallInteger, String, Text, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from python.orm.signal_bot.base import SignalBotTableBase, SignalBotTableBaseSmall
@@ -41,7 +40,7 @@ class SignalDevice(SignalBotTableBase):
phone_number: Mapped[str] = mapped_column(String(50), unique=True)
safety_number: Mapped[str | None]
trust_level: Mapped[TrustLevel] = mapped_column(
ENUM(TrustLevel, name="trust_level", create_type=True, schema="main"),
Enum(TrustLevel, name="trust_level", create_constraint=False, native_enum=False),
default=TrustLevel.UNVERIFIED,
)
last_seen: Mapped[datetime] = mapped_column(DateTime(timezone=True))
@@ -58,6 +57,6 @@ class DeadLetterMessage(SignalBotTableBase):
message: Mapped[str] = mapped_column(Text)
received_at: Mapped[datetime] = mapped_column(DateTime(timezone=True))
status: Mapped[MessageStatus] = mapped_column(
ENUM(MessageStatus, name="message_status", create_type=True, schema="main"),
Enum(MessageStatus, name="message_status", create_constraint=False, native_enum=False),
default=MessageStatus.UNPROCESSED,
)

View File

@@ -63,10 +63,9 @@ class DeviceRegistry:
return
with Session(self.engine) as session:
device = (
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
.scalar_one_or_none()
)
device = session.execute(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
if device:
if device.safety_number != safety_number and device.trust_level != TrustLevel.BLOCKED:
@@ -100,10 +99,9 @@ class DeviceRegistry:
Returns True if the device was found and verified.
"""
with Session(self.engine) as session:
device = (
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
.scalar_one_or_none()
)
device = session.execute(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
if not device:
logger.warning(f"Cannot verify unknown device: {phone_number}")
@@ -141,10 +139,9 @@ class DeviceRegistry:
def grant_role(self, phone_number: str, role: Role) -> bool:
"""Add a role to a device. Called by admin over SSH."""
with Session(self.engine) as session:
device = (
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
.scalar_one_or_none()
)
device = session.execute(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
if not device:
logger.warning(f"Cannot grant role for unknown device: {phone_number}")
@@ -153,9 +150,7 @@ class DeviceRegistry:
if any(record.name == role for record in device.roles):
return True
role_record = session.execute(
select(RoleRecord).where(RoleRecord.name == role)
).scalar_one_or_none()
role_record = session.execute(select(RoleRecord).where(RoleRecord.name == role)).scalar_one_or_none()
if not role_record:
logger.warning(f"Unknown role: {role}")
@@ -170,10 +165,9 @@ class DeviceRegistry:
def revoke_role(self, phone_number: str, role: Role) -> bool:
"""Remove a role from a device. Called by admin over SSH."""
with Session(self.engine) as session:
device = (
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
.scalar_one_or_none()
)
device = session.execute(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
if not device:
logger.warning(f"Cannot revoke role for unknown device: {phone_number}")
@@ -188,19 +182,16 @@ class DeviceRegistry:
def set_roles(self, phone_number: str, roles: list[Role]) -> bool:
"""Replace all roles for a device. Called by admin over SSH."""
with Session(self.engine) as session:
device = (
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
.scalar_one_or_none()
)
device = session.execute(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
if not device:
logger.warning(f"Cannot set roles for unknown device: {phone_number}")
return False
role_names = [str(role) for role in roles]
records = list(
session.execute(select(RoleRecord).where(RoleRecord.name.in_(role_names))).scalars().all()
)
records = list(session.execute(select(RoleRecord).where(RoleRecord.name.in_(role_names))).scalars().all())
device.roles = records
session.commit()
self._update_cache(phone_number, device)
@@ -235,10 +226,9 @@ class DeviceRegistry:
def _load_device(self, phone_number: str) -> SignalDevice | None:
"""Fetch a device by phone number (with joined roles)."""
with Session(self.engine) as session:
return (
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
.scalar_one_or_none()
)
return session.execute(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
def _update_cache(self, phone_number: str, device: SignalDevice) -> None:
"""Refresh the cache entry for a device."""
@@ -254,10 +244,9 @@ class DeviceRegistry:
def _set_trust(self, phone_number: str, level: str, log_msg: str | None = None) -> bool:
"""Update the trust level for a device."""
with Session(self.engine) as session:
device = (
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
.scalar_one_or_none()
)
device = session.execute(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
if not device:
return False

View File

@@ -106,10 +106,14 @@ class Bot:
"""Route an incoming message to the right command handler."""
source = message.source
if not self.registry.is_verified(source) or not self.registry.has_safety_number(source):
if not self.registry.is_verified(source):
logger.info(f"Device {source} not verified, ignoring message")
return
if not self.registry.has_safety_number(source) and self.registry.has_role(source, Role.ADMIN):
logger.warning(f"Admin device {source} missing safety number, ignoring message")
return
text = message.message.strip()
parts = text.split()