fixed tests and treeftm

This commit is contained in:
2026-03-18 19:19:08 -04:00
parent cc73dfc467
commit 3f4373d1f6
8 changed files with 115 additions and 83 deletions

View File

@@ -112,4 +112,5 @@ exclude_lines = [
[tool.pytest.ini_options]
addopts = "-n auto -ra"
testpaths = ["tests"]
# --cov=system_tools --cov-report=term-missing --cov-report=xml --cov-report=html --cov-branch

View File

@@ -76,7 +76,7 @@ def downgrade() -> None:
sa.Column(
"id",
sa.SMALLINT(),
server_default=sa.text("nextval(f'{schema}.role_id_seq'::regclass)"),
server_default=sa.text(f"nextval('{schema}.role_id_seq'::regclass)"),
autoincrement=True,
nullable=False,
),

View File

@@ -11,6 +11,7 @@ from python.orm.richie.contact import (
Need,
RelationshipType,
)
__all__ = [
"Bill",
"Contact",

View File

@@ -4,8 +4,7 @@ from __future__ import annotations
from datetime import datetime
from sqlalchemy import DateTime, ForeignKey, SmallInteger, String, Text, UniqueConstraint
from sqlalchemy.dialects.postgresql import ENUM
from sqlalchemy import DateTime, Enum, ForeignKey, SmallInteger, String, Text, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from python.orm.signal_bot.base import SignalBotTableBase, SignalBotTableBaseSmall
@@ -41,7 +40,7 @@ class SignalDevice(SignalBotTableBase):
phone_number: Mapped[str] = mapped_column(String(50), unique=True)
safety_number: Mapped[str | None]
trust_level: Mapped[TrustLevel] = mapped_column(
ENUM(TrustLevel, name="trust_level", create_type=True, schema="main"),
Enum(TrustLevel, name="trust_level", create_constraint=False, native_enum=False),
default=TrustLevel.UNVERIFIED,
)
last_seen: Mapped[datetime] = mapped_column(DateTime(timezone=True))
@@ -58,6 +57,6 @@ class DeadLetterMessage(SignalBotTableBase):
message: Mapped[str] = mapped_column(Text)
received_at: Mapped[datetime] = mapped_column(DateTime(timezone=True))
status: Mapped[MessageStatus] = mapped_column(
ENUM(MessageStatus, name="message_status", create_type=True, schema="main"),
Enum(MessageStatus, name="message_status", create_constraint=False, native_enum=False),
default=MessageStatus.UNPROCESSED,
)

View File

@@ -63,10 +63,9 @@ class DeviceRegistry:
return
with Session(self.engine) as session:
device = (
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
.scalar_one_or_none()
)
device = session.execute(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
if device:
if device.safety_number != safety_number and device.trust_level != TrustLevel.BLOCKED:
@@ -100,10 +99,9 @@ class DeviceRegistry:
Returns True if the device was found and verified.
"""
with Session(self.engine) as session:
device = (
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
.scalar_one_or_none()
)
device = session.execute(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
if not device:
logger.warning(f"Cannot verify unknown device: {phone_number}")
@@ -141,10 +139,9 @@ class DeviceRegistry:
def grant_role(self, phone_number: str, role: Role) -> bool:
"""Add a role to a device. Called by admin over SSH."""
with Session(self.engine) as session:
device = (
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
.scalar_one_or_none()
)
device = session.execute(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
if not device:
logger.warning(f"Cannot grant role for unknown device: {phone_number}")
@@ -153,9 +150,7 @@ class DeviceRegistry:
if any(record.name == role for record in device.roles):
return True
role_record = session.execute(
select(RoleRecord).where(RoleRecord.name == role)
).scalar_one_or_none()
role_record = session.execute(select(RoleRecord).where(RoleRecord.name == role)).scalar_one_or_none()
if not role_record:
logger.warning(f"Unknown role: {role}")
@@ -170,10 +165,9 @@ class DeviceRegistry:
def revoke_role(self, phone_number: str, role: Role) -> bool:
"""Remove a role from a device. Called by admin over SSH."""
with Session(self.engine) as session:
device = (
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
.scalar_one_or_none()
)
device = session.execute(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
if not device:
logger.warning(f"Cannot revoke role for unknown device: {phone_number}")
@@ -188,19 +182,16 @@ class DeviceRegistry:
def set_roles(self, phone_number: str, roles: list[Role]) -> bool:
"""Replace all roles for a device. Called by admin over SSH."""
with Session(self.engine) as session:
device = (
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
.scalar_one_or_none()
)
device = session.execute(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
if not device:
logger.warning(f"Cannot set roles for unknown device: {phone_number}")
return False
role_names = [str(role) for role in roles]
records = list(
session.execute(select(RoleRecord).where(RoleRecord.name.in_(role_names))).scalars().all()
)
records = list(session.execute(select(RoleRecord).where(RoleRecord.name.in_(role_names))).scalars().all())
device.roles = records
session.commit()
self._update_cache(phone_number, device)
@@ -235,10 +226,9 @@ class DeviceRegistry:
def _load_device(self, phone_number: str) -> SignalDevice | None:
"""Fetch a device by phone number (with joined roles)."""
with Session(self.engine) as session:
return (
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
.scalar_one_or_none()
)
return session.execute(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
def _update_cache(self, phone_number: str, device: SignalDevice) -> None:
"""Refresh the cache entry for a device."""
@@ -254,10 +244,9 @@ class DeviceRegistry:
def _set_trust(self, phone_number: str, level: str, log_msg: str | None = None) -> bool:
"""Update the trust level for a device."""
with Session(self.engine) as session:
device = (
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
.scalar_one_or_none()
)
device = session.execute(
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
).scalar_one_or_none()
if not device:
return False

View File

@@ -106,10 +106,14 @@ class Bot:
"""Route an incoming message to the right command handler."""
source = message.source
if not self.registry.is_verified(source) or not self.registry.has_safety_number(source):
if not self.registry.is_verified(source):
logger.info(f"Device {source} not verified, ignoring message")
return
if not self.registry.has_safety_number(source) and self.registry.has_role(source, Role.ADMIN):
logger.warning(f"Admin device {source} missing safety number, ignoring message")
return
text = message.message.strip()
parts = text.split()

40
tests/conftest.py Normal file
View File

@@ -0,0 +1,40 @@
"""Shared test fixtures."""
from __future__ import annotations
from typing import TYPE_CHECKING
import pytest
from sqlalchemy import create_engine, event
from python.orm.signal_bot.base import SignalBotBase
if TYPE_CHECKING:
from collections.abc import Generator
from sqlalchemy.engine import Engine
@pytest.fixture(scope="session")
def sqlite_engine() -> Generator[Engine]:
"""Create an in-memory SQLite engine for testing."""
engine = create_engine("sqlite:///:memory:")
@event.listens_for(engine, "connect")
def _set_sqlite_pragma(dbapi_connection, _connection_record):
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
SignalBotBase.metadata.create_all(engine)
yield engine
engine.dispose()
@pytest.fixture
def engine(sqlite_engine: Engine) -> Generator[Engine]:
"""Yield the shared engine after cleaning all tables between tests."""
yield sqlite_engine
with sqlite_engine.begin() as connection:
for table in reversed(SignalBotBase.metadata.sorted_tables):
connection.execute(table.delete())

View File

@@ -7,9 +7,7 @@ from datetime import timedelta
from unittest.mock import MagicMock, patch
import pytest
from sqlalchemy import create_engine
from python.orm.richie.base import RichieBase
from python.signal_bot.commands.inventory import (
_format_summary,
parse_llm_response,
@@ -17,7 +15,7 @@ from python.signal_bot.commands.inventory import (
from python.signal_bot.commands.location import _format_location, handle_location_request
from python.signal_bot.device_registry import _BLOCKED_TTL, _DEFAULT_TTL, DeviceRegistry, _CacheEntry
from python.signal_bot.llm_client import LLMClient
from python.signal_bot.main import dispatch
from python.signal_bot.main import Bot
from python.signal_bot.models import (
BotConfig,
InventoryItem,
@@ -78,12 +76,6 @@ class TestDeviceRegistry:
def signal_mock(self):
return MagicMock(spec=SignalClient)
@pytest.fixture
def engine(self):
engine = create_engine("sqlite://")
RichieBase.metadata.create_all(engine)
return engine
@pytest.fixture
def registry(self, signal_mock, engine):
return DeviceRegistry(signal_mock, engine)
@@ -131,12 +123,6 @@ class TestContactCache:
def signal_mock(self):
return MagicMock(spec=SignalClient)
@pytest.fixture
def engine(self):
engine = create_engine("sqlite://")
RichieBase.metadata.create_all(engine)
return engine
@pytest.fixture
def registry(self, signal_mock, engine):
return DeviceRegistry(signal_mock, engine)
@@ -215,6 +201,7 @@ class TestContactCache:
trust_level=old.trust_level,
has_safety_number=old.has_safety_number,
safety_number=old.safety_number,
roles=old.roles,
)
with patch("python.signal_bot.device_registry.Session") as mock_session_cls:
@@ -229,25 +216,15 @@ class TestContactCache:
class TestLocationCommand:
def test_format_location_from_attributes(self):
payload = {
"state": "whatever",
"attributes": {
"latitude": 12.34,
"longitude": 56.78,
"speed": "45 mph",
"last_updated": "2024-01-01T00:00:00+00:00",
},
}
response = _format_location(payload)
def test_format_location(self):
response = _format_location("12.34", "56.78")
assert "12.34, 56.78" in response
assert "maps.google.com" in response
assert "Speed: 45 mph" in response
def test_handle_location_request_without_config(self):
signal = MagicMock(spec=SignalClient)
message = SignalMessage(source="+1234", timestamp=0, message="location")
handle_location_request(message, signal, None, None, "sensor.gps_location")
handle_location_request(message, signal, None, None)
signal.reply.assert_called_once()
assert "not configured" in signal.reply.call_args[0][1]
@@ -266,11 +243,12 @@ class TestDispatch:
mock = MagicMock(spec=DeviceRegistry)
mock.is_verified.return_value = True
mock.has_safety_number.return_value = True
mock.has_role.return_value = False
mock.get_roles.return_value = []
return mock
@pytest.fixture
def config(self):
engine = create_engine("sqlite://")
def config(self, engine):
return BotConfig(
signal_api_url="http://localhost:8080",
phone_number="+1234567890",
@@ -278,46 +256,66 @@ class TestDispatch:
engine=engine,
)
def test_unverified_device_ignored(self, signal_mock, llm_mock, registry_mock, config):
@pytest.fixture
def bot(self, signal_mock, llm_mock, registry_mock, config):
return Bot(signal_mock, llm_mock, registry_mock, config)
def test_unverified_device_ignored(self, bot, signal_mock, registry_mock):
registry_mock.is_verified.return_value = False
msg = SignalMessage(source="+1234", timestamp=0, message="help")
dispatch(msg, signal_mock, llm_mock, registry_mock, config)
bot.dispatch(msg)
signal_mock.reply.assert_not_called()
def test_help_command(self, signal_mock, llm_mock, registry_mock, config):
def test_admin_without_safety_number_ignored(self, bot, signal_mock, registry_mock):
registry_mock.has_safety_number.return_value = False
registry_mock.has_role.return_value = True
msg = SignalMessage(source="+1234", timestamp=0, message="help")
dispatch(msg, signal_mock, llm_mock, registry_mock, config)
bot.dispatch(msg)
signal_mock.reply.assert_not_called()
def test_non_admin_without_safety_number_allowed(self, bot, signal_mock, registry_mock):
registry_mock.has_safety_number.return_value = False
registry_mock.has_role.return_value = False
registry_mock.get_roles.return_value = []
msg = SignalMessage(source="+1234", timestamp=0, message="help")
bot.dispatch(msg)
signal_mock.reply.assert_called_once()
def test_help_command(self, bot, signal_mock, registry_mock):
msg = SignalMessage(source="+1234", timestamp=0, message="help")
bot.dispatch(msg)
signal_mock.reply.assert_called_once()
assert "Available commands" in signal_mock.reply.call_args[0][1]
def test_unknown_command_ignored(self, signal_mock, llm_mock, registry_mock, config):
def test_unknown_command_ignored(self, bot, signal_mock):
msg = SignalMessage(source="+1234", timestamp=0, message="foobar")
dispatch(msg, signal_mock, llm_mock, registry_mock, config)
bot.dispatch(msg)
signal_mock.reply.assert_not_called()
def test_non_command_message_ignored(self, signal_mock, llm_mock, registry_mock, config):
def test_non_command_message_ignored(self, bot, signal_mock):
msg = SignalMessage(source="+1234", timestamp=0, message="hello there")
dispatch(msg, signal_mock, llm_mock, registry_mock, config)
bot.dispatch(msg)
signal_mock.reply.assert_not_called()
def test_status_command(self, signal_mock, llm_mock, registry_mock, config):
def test_status_command(self, bot, signal_mock, llm_mock, registry_mock):
llm_mock.list_models.return_value = ["model1", "model2"]
llm_mock.model = "test:7b"
registry_mock.list_devices.return_value = []
registry_mock.has_role.return_value = True
msg = SignalMessage(source="+1234", timestamp=0, message="status")
dispatch(msg, signal_mock, llm_mock, registry_mock, config)
bot.dispatch(msg)
signal_mock.reply.assert_called_once()
assert "Bot online" in signal_mock.reply.call_args[0][1]
def test_location_command(self, signal_mock, llm_mock, registry_mock, config):
def test_location_command(self, bot, signal_mock, registry_mock, config):
registry_mock.has_role.return_value = True
msg = SignalMessage(source="+1234", timestamp=0, message="location")
with patch("python.signal_bot.main.handle_location_request") as mock_location:
dispatch(msg, signal_mock, llm_mock, registry_mock, config)
bot.dispatch(msg)
mock_location.assert_called_once_with(
msg,
signal_mock,
config.ha_url,
config.ha_token,
config.ha_location_entity,
)