fixed tests and treeftm

This commit is contained in:
2026-03-18 19:19:08 -04:00
parent cc73dfc467
commit 3f4373d1f6
8 changed files with 115 additions and 83 deletions

View File

@@ -112,4 +112,5 @@ exclude_lines = [
[tool.pytest.ini_options] [tool.pytest.ini_options]
addopts = "-n auto -ra" addopts = "-n auto -ra"
testpaths = ["tests"]
# --cov=system_tools --cov-report=term-missing --cov-report=xml --cov-report=html --cov-branch # --cov=system_tools --cov-report=term-missing --cov-report=xml --cov-report=html --cov-branch

View File

@@ -76,7 +76,7 @@ def downgrade() -> None:
sa.Column( sa.Column(
"id", "id",
sa.SMALLINT(), sa.SMALLINT(),
server_default=sa.text("nextval(f'{schema}.role_id_seq'::regclass)"), server_default=sa.text(f"nextval('{schema}.role_id_seq'::regclass)"),
autoincrement=True, autoincrement=True,
nullable=False, nullable=False,
), ),

View File

@@ -11,6 +11,7 @@ from python.orm.richie.contact import (
Need, Need,
RelationshipType, RelationshipType,
) )
__all__ = [ __all__ = [
"Bill", "Bill",
"Contact", "Contact",

View File

@@ -4,8 +4,7 @@ from __future__ import annotations
from datetime import datetime from datetime import datetime
from sqlalchemy import DateTime, ForeignKey, SmallInteger, String, Text, UniqueConstraint from sqlalchemy import DateTime, Enum, ForeignKey, SmallInteger, String, Text, UniqueConstraint
from sqlalchemy.dialects.postgresql import ENUM
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
from python.orm.signal_bot.base import SignalBotTableBase, SignalBotTableBaseSmall from python.orm.signal_bot.base import SignalBotTableBase, SignalBotTableBaseSmall
@@ -41,7 +40,7 @@ class SignalDevice(SignalBotTableBase):
phone_number: Mapped[str] = mapped_column(String(50), unique=True) phone_number: Mapped[str] = mapped_column(String(50), unique=True)
safety_number: Mapped[str | None] safety_number: Mapped[str | None]
trust_level: Mapped[TrustLevel] = mapped_column( trust_level: Mapped[TrustLevel] = mapped_column(
ENUM(TrustLevel, name="trust_level", create_type=True, schema="main"), Enum(TrustLevel, name="trust_level", create_constraint=False, native_enum=False),
default=TrustLevel.UNVERIFIED, default=TrustLevel.UNVERIFIED,
) )
last_seen: Mapped[datetime] = mapped_column(DateTime(timezone=True)) last_seen: Mapped[datetime] = mapped_column(DateTime(timezone=True))
@@ -58,6 +57,6 @@ class DeadLetterMessage(SignalBotTableBase):
message: Mapped[str] = mapped_column(Text) message: Mapped[str] = mapped_column(Text)
received_at: Mapped[datetime] = mapped_column(DateTime(timezone=True)) received_at: Mapped[datetime] = mapped_column(DateTime(timezone=True))
status: Mapped[MessageStatus] = mapped_column( status: Mapped[MessageStatus] = mapped_column(
ENUM(MessageStatus, name="message_status", create_type=True, schema="main"), Enum(MessageStatus, name="message_status", create_constraint=False, native_enum=False),
default=MessageStatus.UNPROCESSED, default=MessageStatus.UNPROCESSED,
) )

View File

@@ -63,10 +63,9 @@ class DeviceRegistry:
return return
with Session(self.engine) as session: with Session(self.engine) as session:
device = ( device = session.execute(
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number)) select(SignalDevice).where(SignalDevice.phone_number == phone_number)
.scalar_one_or_none() ).scalar_one_or_none()
)
if device: if device:
if device.safety_number != safety_number and device.trust_level != TrustLevel.BLOCKED: if device.safety_number != safety_number and device.trust_level != TrustLevel.BLOCKED:
@@ -100,10 +99,9 @@ class DeviceRegistry:
Returns True if the device was found and verified. Returns True if the device was found and verified.
""" """
with Session(self.engine) as session: with Session(self.engine) as session:
device = ( device = session.execute(
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number)) select(SignalDevice).where(SignalDevice.phone_number == phone_number)
.scalar_one_or_none() ).scalar_one_or_none()
)
if not device: if not device:
logger.warning(f"Cannot verify unknown device: {phone_number}") logger.warning(f"Cannot verify unknown device: {phone_number}")
@@ -141,10 +139,9 @@ class DeviceRegistry:
def grant_role(self, phone_number: str, role: Role) -> bool: def grant_role(self, phone_number: str, role: Role) -> bool:
"""Add a role to a device. Called by admin over SSH.""" """Add a role to a device. Called by admin over SSH."""
with Session(self.engine) as session: with Session(self.engine) as session:
device = ( device = session.execute(
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number)) select(SignalDevice).where(SignalDevice.phone_number == phone_number)
.scalar_one_or_none() ).scalar_one_or_none()
)
if not device: if not device:
logger.warning(f"Cannot grant role for unknown device: {phone_number}") logger.warning(f"Cannot grant role for unknown device: {phone_number}")
@@ -153,9 +150,7 @@ class DeviceRegistry:
if any(record.name == role for record in device.roles): if any(record.name == role for record in device.roles):
return True return True
role_record = session.execute( role_record = session.execute(select(RoleRecord).where(RoleRecord.name == role)).scalar_one_or_none()
select(RoleRecord).where(RoleRecord.name == role)
).scalar_one_or_none()
if not role_record: if not role_record:
logger.warning(f"Unknown role: {role}") logger.warning(f"Unknown role: {role}")
@@ -170,10 +165,9 @@ class DeviceRegistry:
def revoke_role(self, phone_number: str, role: Role) -> bool: def revoke_role(self, phone_number: str, role: Role) -> bool:
"""Remove a role from a device. Called by admin over SSH.""" """Remove a role from a device. Called by admin over SSH."""
with Session(self.engine) as session: with Session(self.engine) as session:
device = ( device = session.execute(
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number)) select(SignalDevice).where(SignalDevice.phone_number == phone_number)
.scalar_one_or_none() ).scalar_one_or_none()
)
if not device: if not device:
logger.warning(f"Cannot revoke role for unknown device: {phone_number}") logger.warning(f"Cannot revoke role for unknown device: {phone_number}")
@@ -188,19 +182,16 @@ class DeviceRegistry:
def set_roles(self, phone_number: str, roles: list[Role]) -> bool: def set_roles(self, phone_number: str, roles: list[Role]) -> bool:
"""Replace all roles for a device. Called by admin over SSH.""" """Replace all roles for a device. Called by admin over SSH."""
with Session(self.engine) as session: with Session(self.engine) as session:
device = ( device = session.execute(
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number)) select(SignalDevice).where(SignalDevice.phone_number == phone_number)
.scalar_one_or_none() ).scalar_one_or_none()
)
if not device: if not device:
logger.warning(f"Cannot set roles for unknown device: {phone_number}") logger.warning(f"Cannot set roles for unknown device: {phone_number}")
return False return False
role_names = [str(role) for role in roles] role_names = [str(role) for role in roles]
records = list( records = list(session.execute(select(RoleRecord).where(RoleRecord.name.in_(role_names))).scalars().all())
session.execute(select(RoleRecord).where(RoleRecord.name.in_(role_names))).scalars().all()
)
device.roles = records device.roles = records
session.commit() session.commit()
self._update_cache(phone_number, device) self._update_cache(phone_number, device)
@@ -235,10 +226,9 @@ class DeviceRegistry:
def _load_device(self, phone_number: str) -> SignalDevice | None: def _load_device(self, phone_number: str) -> SignalDevice | None:
"""Fetch a device by phone number (with joined roles).""" """Fetch a device by phone number (with joined roles)."""
with Session(self.engine) as session: with Session(self.engine) as session:
return ( return session.execute(
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number)) select(SignalDevice).where(SignalDevice.phone_number == phone_number)
.scalar_one_or_none() ).scalar_one_or_none()
)
def _update_cache(self, phone_number: str, device: SignalDevice) -> None: def _update_cache(self, phone_number: str, device: SignalDevice) -> None:
"""Refresh the cache entry for a device.""" """Refresh the cache entry for a device."""
@@ -254,10 +244,9 @@ class DeviceRegistry:
def _set_trust(self, phone_number: str, level: str, log_msg: str | None = None) -> bool: def _set_trust(self, phone_number: str, level: str, log_msg: str | None = None) -> bool:
"""Update the trust level for a device.""" """Update the trust level for a device."""
with Session(self.engine) as session: with Session(self.engine) as session:
device = ( device = session.execute(
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number)) select(SignalDevice).where(SignalDevice.phone_number == phone_number)
.scalar_one_or_none() ).scalar_one_or_none()
)
if not device: if not device:
return False return False

View File

@@ -106,10 +106,14 @@ class Bot:
"""Route an incoming message to the right command handler.""" """Route an incoming message to the right command handler."""
source = message.source source = message.source
if not self.registry.is_verified(source) or not self.registry.has_safety_number(source): if not self.registry.is_verified(source):
logger.info(f"Device {source} not verified, ignoring message") logger.info(f"Device {source} not verified, ignoring message")
return return
if not self.registry.has_safety_number(source) and self.registry.has_role(source, Role.ADMIN):
logger.warning(f"Admin device {source} missing safety number, ignoring message")
return
text = message.message.strip() text = message.message.strip()
parts = text.split() parts = text.split()

40
tests/conftest.py Normal file
View File

@@ -0,0 +1,40 @@
"""Shared test fixtures."""
from __future__ import annotations
from typing import TYPE_CHECKING
import pytest
from sqlalchemy import create_engine, event
from python.orm.signal_bot.base import SignalBotBase
if TYPE_CHECKING:
from collections.abc import Generator
from sqlalchemy.engine import Engine
@pytest.fixture(scope="session")
def sqlite_engine() -> Generator[Engine]:
"""Create an in-memory SQLite engine for testing."""
engine = create_engine("sqlite:///:memory:")
@event.listens_for(engine, "connect")
def _set_sqlite_pragma(dbapi_connection, _connection_record):
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
SignalBotBase.metadata.create_all(engine)
yield engine
engine.dispose()
@pytest.fixture
def engine(sqlite_engine: Engine) -> Generator[Engine]:
"""Yield the shared engine after cleaning all tables between tests."""
yield sqlite_engine
with sqlite_engine.begin() as connection:
for table in reversed(SignalBotBase.metadata.sorted_tables):
connection.execute(table.delete())

View File

@@ -7,9 +7,7 @@ from datetime import timedelta
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
from sqlalchemy import create_engine
from python.orm.richie.base import RichieBase
from python.signal_bot.commands.inventory import ( from python.signal_bot.commands.inventory import (
_format_summary, _format_summary,
parse_llm_response, parse_llm_response,
@@ -17,7 +15,7 @@ from python.signal_bot.commands.inventory import (
from python.signal_bot.commands.location import _format_location, handle_location_request from python.signal_bot.commands.location import _format_location, handle_location_request
from python.signal_bot.device_registry import _BLOCKED_TTL, _DEFAULT_TTL, DeviceRegistry, _CacheEntry from python.signal_bot.device_registry import _BLOCKED_TTL, _DEFAULT_TTL, DeviceRegistry, _CacheEntry
from python.signal_bot.llm_client import LLMClient from python.signal_bot.llm_client import LLMClient
from python.signal_bot.main import dispatch from python.signal_bot.main import Bot
from python.signal_bot.models import ( from python.signal_bot.models import (
BotConfig, BotConfig,
InventoryItem, InventoryItem,
@@ -78,12 +76,6 @@ class TestDeviceRegistry:
def signal_mock(self): def signal_mock(self):
return MagicMock(spec=SignalClient) return MagicMock(spec=SignalClient)
@pytest.fixture
def engine(self):
engine = create_engine("sqlite://")
RichieBase.metadata.create_all(engine)
return engine
@pytest.fixture @pytest.fixture
def registry(self, signal_mock, engine): def registry(self, signal_mock, engine):
return DeviceRegistry(signal_mock, engine) return DeviceRegistry(signal_mock, engine)
@@ -131,12 +123,6 @@ class TestContactCache:
def signal_mock(self): def signal_mock(self):
return MagicMock(spec=SignalClient) return MagicMock(spec=SignalClient)
@pytest.fixture
def engine(self):
engine = create_engine("sqlite://")
RichieBase.metadata.create_all(engine)
return engine
@pytest.fixture @pytest.fixture
def registry(self, signal_mock, engine): def registry(self, signal_mock, engine):
return DeviceRegistry(signal_mock, engine) return DeviceRegistry(signal_mock, engine)
@@ -215,6 +201,7 @@ class TestContactCache:
trust_level=old.trust_level, trust_level=old.trust_level,
has_safety_number=old.has_safety_number, has_safety_number=old.has_safety_number,
safety_number=old.safety_number, safety_number=old.safety_number,
roles=old.roles,
) )
with patch("python.signal_bot.device_registry.Session") as mock_session_cls: with patch("python.signal_bot.device_registry.Session") as mock_session_cls:
@@ -229,25 +216,15 @@ class TestContactCache:
class TestLocationCommand: class TestLocationCommand:
def test_format_location_from_attributes(self): def test_format_location(self):
payload = { response = _format_location("12.34", "56.78")
"state": "whatever",
"attributes": {
"latitude": 12.34,
"longitude": 56.78,
"speed": "45 mph",
"last_updated": "2024-01-01T00:00:00+00:00",
},
}
response = _format_location(payload)
assert "12.34, 56.78" in response assert "12.34, 56.78" in response
assert "maps.google.com" in response assert "maps.google.com" in response
assert "Speed: 45 mph" in response
def test_handle_location_request_without_config(self): def test_handle_location_request_without_config(self):
signal = MagicMock(spec=SignalClient) signal = MagicMock(spec=SignalClient)
message = SignalMessage(source="+1234", timestamp=0, message="location") message = SignalMessage(source="+1234", timestamp=0, message="location")
handle_location_request(message, signal, None, None, "sensor.gps_location") handle_location_request(message, signal, None, None)
signal.reply.assert_called_once() signal.reply.assert_called_once()
assert "not configured" in signal.reply.call_args[0][1] assert "not configured" in signal.reply.call_args[0][1]
@@ -266,11 +243,12 @@ class TestDispatch:
mock = MagicMock(spec=DeviceRegistry) mock = MagicMock(spec=DeviceRegistry)
mock.is_verified.return_value = True mock.is_verified.return_value = True
mock.has_safety_number.return_value = True mock.has_safety_number.return_value = True
mock.has_role.return_value = False
mock.get_roles.return_value = []
return mock return mock
@pytest.fixture @pytest.fixture
def config(self): def config(self, engine):
engine = create_engine("sqlite://")
return BotConfig( return BotConfig(
signal_api_url="http://localhost:8080", signal_api_url="http://localhost:8080",
phone_number="+1234567890", phone_number="+1234567890",
@@ -278,46 +256,66 @@ class TestDispatch:
engine=engine, engine=engine,
) )
def test_unverified_device_ignored(self, signal_mock, llm_mock, registry_mock, config): @pytest.fixture
def bot(self, signal_mock, llm_mock, registry_mock, config):
return Bot(signal_mock, llm_mock, registry_mock, config)
def test_unverified_device_ignored(self, bot, signal_mock, registry_mock):
registry_mock.is_verified.return_value = False registry_mock.is_verified.return_value = False
msg = SignalMessage(source="+1234", timestamp=0, message="help") msg = SignalMessage(source="+1234", timestamp=0, message="help")
dispatch(msg, signal_mock, llm_mock, registry_mock, config) bot.dispatch(msg)
signal_mock.reply.assert_not_called() signal_mock.reply.assert_not_called()
def test_help_command(self, signal_mock, llm_mock, registry_mock, config): def test_admin_without_safety_number_ignored(self, bot, signal_mock, registry_mock):
registry_mock.has_safety_number.return_value = False
registry_mock.has_role.return_value = True
msg = SignalMessage(source="+1234", timestamp=0, message="help") msg = SignalMessage(source="+1234", timestamp=0, message="help")
dispatch(msg, signal_mock, llm_mock, registry_mock, config) bot.dispatch(msg)
signal_mock.reply.assert_not_called()
def test_non_admin_without_safety_number_allowed(self, bot, signal_mock, registry_mock):
registry_mock.has_safety_number.return_value = False
registry_mock.has_role.return_value = False
registry_mock.get_roles.return_value = []
msg = SignalMessage(source="+1234", timestamp=0, message="help")
bot.dispatch(msg)
signal_mock.reply.assert_called_once()
def test_help_command(self, bot, signal_mock, registry_mock):
msg = SignalMessage(source="+1234", timestamp=0, message="help")
bot.dispatch(msg)
signal_mock.reply.assert_called_once() signal_mock.reply.assert_called_once()
assert "Available commands" in signal_mock.reply.call_args[0][1] assert "Available commands" in signal_mock.reply.call_args[0][1]
def test_unknown_command_ignored(self, signal_mock, llm_mock, registry_mock, config): def test_unknown_command_ignored(self, bot, signal_mock):
msg = SignalMessage(source="+1234", timestamp=0, message="foobar") msg = SignalMessage(source="+1234", timestamp=0, message="foobar")
dispatch(msg, signal_mock, llm_mock, registry_mock, config) bot.dispatch(msg)
signal_mock.reply.assert_not_called() signal_mock.reply.assert_not_called()
def test_non_command_message_ignored(self, signal_mock, llm_mock, registry_mock, config): def test_non_command_message_ignored(self, bot, signal_mock):
msg = SignalMessage(source="+1234", timestamp=0, message="hello there") msg = SignalMessage(source="+1234", timestamp=0, message="hello there")
dispatch(msg, signal_mock, llm_mock, registry_mock, config) bot.dispatch(msg)
signal_mock.reply.assert_not_called() signal_mock.reply.assert_not_called()
def test_status_command(self, signal_mock, llm_mock, registry_mock, config): def test_status_command(self, bot, signal_mock, llm_mock, registry_mock):
llm_mock.list_models.return_value = ["model1", "model2"] llm_mock.list_models.return_value = ["model1", "model2"]
llm_mock.model = "test:7b" llm_mock.model = "test:7b"
registry_mock.list_devices.return_value = [] registry_mock.list_devices.return_value = []
registry_mock.has_role.return_value = True
msg = SignalMessage(source="+1234", timestamp=0, message="status") msg = SignalMessage(source="+1234", timestamp=0, message="status")
dispatch(msg, signal_mock, llm_mock, registry_mock, config) bot.dispatch(msg)
signal_mock.reply.assert_called_once() signal_mock.reply.assert_called_once()
assert "Bot online" in signal_mock.reply.call_args[0][1] assert "Bot online" in signal_mock.reply.call_args[0][1]
def test_location_command(self, signal_mock, llm_mock, registry_mock, config): def test_location_command(self, bot, signal_mock, registry_mock, config):
registry_mock.has_role.return_value = True
msg = SignalMessage(source="+1234", timestamp=0, message="location") msg = SignalMessage(source="+1234", timestamp=0, message="location")
with patch("python.signal_bot.main.handle_location_request") as mock_location: with patch("python.signal_bot.main.handle_location_request") as mock_location:
dispatch(msg, signal_mock, llm_mock, registry_mock, config) bot.dispatch(msg)
mock_location.assert_called_once_with( mock_location.assert_called_once_with(
msg, msg,
signal_mock, signal_mock,
config.ha_url, config.ha_url,
config.ha_token, config.ha_token,
config.ha_location_entity,
) )