mirror of
https://github.com/RichieCahill/dotfiles.git
synced 2026-04-17 04:58:19 -04:00
fixed tests and treeftm
This commit is contained in:
@@ -76,7 +76,7 @@ def downgrade() -> None:
|
||||
sa.Column(
|
||||
"id",
|
||||
sa.SMALLINT(),
|
||||
server_default=sa.text("nextval(f'{schema}.role_id_seq'::regclass)"),
|
||||
server_default=sa.text(f"nextval('{schema}.role_id_seq'::regclass)"),
|
||||
autoincrement=True,
|
||||
nullable=False,
|
||||
),
|
||||
|
||||
@@ -11,6 +11,7 @@ from python.orm.richie.contact import (
|
||||
Need,
|
||||
RelationshipType,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Bill",
|
||||
"Contact",
|
||||
|
||||
@@ -4,8 +4,7 @@ from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, ForeignKey, SmallInteger, String, Text, UniqueConstraint
|
||||
from sqlalchemy.dialects.postgresql import ENUM
|
||||
from sqlalchemy import DateTime, Enum, ForeignKey, SmallInteger, String, Text, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from python.orm.signal_bot.base import SignalBotTableBase, SignalBotTableBaseSmall
|
||||
@@ -41,7 +40,7 @@ class SignalDevice(SignalBotTableBase):
|
||||
phone_number: Mapped[str] = mapped_column(String(50), unique=True)
|
||||
safety_number: Mapped[str | None]
|
||||
trust_level: Mapped[TrustLevel] = mapped_column(
|
||||
ENUM(TrustLevel, name="trust_level", create_type=True, schema="main"),
|
||||
Enum(TrustLevel, name="trust_level", create_constraint=False, native_enum=False),
|
||||
default=TrustLevel.UNVERIFIED,
|
||||
)
|
||||
last_seen: Mapped[datetime] = mapped_column(DateTime(timezone=True))
|
||||
@@ -58,6 +57,6 @@ class DeadLetterMessage(SignalBotTableBase):
|
||||
message: Mapped[str] = mapped_column(Text)
|
||||
received_at: Mapped[datetime] = mapped_column(DateTime(timezone=True))
|
||||
status: Mapped[MessageStatus] = mapped_column(
|
||||
ENUM(MessageStatus, name="message_status", create_type=True, schema="main"),
|
||||
Enum(MessageStatus, name="message_status", create_constraint=False, native_enum=False),
|
||||
default=MessageStatus.UNPROCESSED,
|
||||
)
|
||||
|
||||
@@ -63,10 +63,9 @@ class DeviceRegistry:
|
||||
return
|
||||
|
||||
with Session(self.engine) as session:
|
||||
device = (
|
||||
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
|
||||
.scalar_one_or_none()
|
||||
)
|
||||
device = session.execute(
|
||||
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if device:
|
||||
if device.safety_number != safety_number and device.trust_level != TrustLevel.BLOCKED:
|
||||
@@ -100,10 +99,9 @@ class DeviceRegistry:
|
||||
Returns True if the device was found and verified.
|
||||
"""
|
||||
with Session(self.engine) as session:
|
||||
device = (
|
||||
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
|
||||
.scalar_one_or_none()
|
||||
)
|
||||
device = session.execute(
|
||||
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if not device:
|
||||
logger.warning(f"Cannot verify unknown device: {phone_number}")
|
||||
@@ -141,10 +139,9 @@ class DeviceRegistry:
|
||||
def grant_role(self, phone_number: str, role: Role) -> bool:
|
||||
"""Add a role to a device. Called by admin over SSH."""
|
||||
with Session(self.engine) as session:
|
||||
device = (
|
||||
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
|
||||
.scalar_one_or_none()
|
||||
)
|
||||
device = session.execute(
|
||||
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if not device:
|
||||
logger.warning(f"Cannot grant role for unknown device: {phone_number}")
|
||||
@@ -153,9 +150,7 @@ class DeviceRegistry:
|
||||
if any(record.name == role for record in device.roles):
|
||||
return True
|
||||
|
||||
role_record = session.execute(
|
||||
select(RoleRecord).where(RoleRecord.name == role)
|
||||
).scalar_one_or_none()
|
||||
role_record = session.execute(select(RoleRecord).where(RoleRecord.name == role)).scalar_one_or_none()
|
||||
|
||||
if not role_record:
|
||||
logger.warning(f"Unknown role: {role}")
|
||||
@@ -170,10 +165,9 @@ class DeviceRegistry:
|
||||
def revoke_role(self, phone_number: str, role: Role) -> bool:
|
||||
"""Remove a role from a device. Called by admin over SSH."""
|
||||
with Session(self.engine) as session:
|
||||
device = (
|
||||
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
|
||||
.scalar_one_or_none()
|
||||
)
|
||||
device = session.execute(
|
||||
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if not device:
|
||||
logger.warning(f"Cannot revoke role for unknown device: {phone_number}")
|
||||
@@ -188,19 +182,16 @@ class DeviceRegistry:
|
||||
def set_roles(self, phone_number: str, roles: list[Role]) -> bool:
|
||||
"""Replace all roles for a device. Called by admin over SSH."""
|
||||
with Session(self.engine) as session:
|
||||
device = (
|
||||
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
|
||||
.scalar_one_or_none()
|
||||
)
|
||||
device = session.execute(
|
||||
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if not device:
|
||||
logger.warning(f"Cannot set roles for unknown device: {phone_number}")
|
||||
return False
|
||||
|
||||
role_names = [str(role) for role in roles]
|
||||
records = list(
|
||||
session.execute(select(RoleRecord).where(RoleRecord.name.in_(role_names))).scalars().all()
|
||||
)
|
||||
records = list(session.execute(select(RoleRecord).where(RoleRecord.name.in_(role_names))).scalars().all())
|
||||
device.roles = records
|
||||
session.commit()
|
||||
self._update_cache(phone_number, device)
|
||||
@@ -235,10 +226,9 @@ class DeviceRegistry:
|
||||
def _load_device(self, phone_number: str) -> SignalDevice | None:
|
||||
"""Fetch a device by phone number (with joined roles)."""
|
||||
with Session(self.engine) as session:
|
||||
return (
|
||||
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
|
||||
.scalar_one_or_none()
|
||||
)
|
||||
return session.execute(
|
||||
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
|
||||
).scalar_one_or_none()
|
||||
|
||||
def _update_cache(self, phone_number: str, device: SignalDevice) -> None:
|
||||
"""Refresh the cache entry for a device."""
|
||||
@@ -254,10 +244,9 @@ class DeviceRegistry:
|
||||
def _set_trust(self, phone_number: str, level: str, log_msg: str | None = None) -> bool:
|
||||
"""Update the trust level for a device."""
|
||||
with Session(self.engine) as session:
|
||||
device = (
|
||||
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
|
||||
.scalar_one_or_none()
|
||||
)
|
||||
device = session.execute(
|
||||
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if not device:
|
||||
return False
|
||||
|
||||
@@ -106,10 +106,14 @@ class Bot:
|
||||
"""Route an incoming message to the right command handler."""
|
||||
source = message.source
|
||||
|
||||
if not self.registry.is_verified(source) or not self.registry.has_safety_number(source):
|
||||
if not self.registry.is_verified(source):
|
||||
logger.info(f"Device {source} not verified, ignoring message")
|
||||
return
|
||||
|
||||
if not self.registry.has_safety_number(source) and self.registry.has_role(source, Role.ADMIN):
|
||||
logger.warning(f"Admin device {source} missing safety number, ignoring message")
|
||||
return
|
||||
|
||||
text = message.message.strip()
|
||||
parts = text.split()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user