diff --git a/pyproject.toml b/pyproject.toml index 3f58e7b..72afef0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -112,4 +112,5 @@ exclude_lines = [ [tool.pytest.ini_options] addopts = "-n auto -ra" +testpaths = ["tests"] # --cov=system_tools --cov-report=term-missing --cov-report=xml --cov-report=html --cov-branch diff --git a/python/alembic/richie/versions/2026_03_18-seprating_signal_bot_database_6b275323f435.py b/python/alembic/richie/versions/2026_03_18-seprating_signal_bot_database_6b275323f435.py index 22d4621..89af2ef 100644 --- a/python/alembic/richie/versions/2026_03_18-seprating_signal_bot_database_6b275323f435.py +++ b/python/alembic/richie/versions/2026_03_18-seprating_signal_bot_database_6b275323f435.py @@ -76,7 +76,7 @@ def downgrade() -> None: sa.Column( "id", sa.SMALLINT(), - server_default=sa.text("nextval(f'{schema}.role_id_seq'::regclass)"), + server_default=sa.text(f"nextval('{schema}.role_id_seq'::regclass)"), autoincrement=True, nullable=False, ), diff --git a/python/orm/richie/__init__.py b/python/orm/richie/__init__.py index 0d11aef..7be641a 100644 --- a/python/orm/richie/__init__.py +++ b/python/orm/richie/__init__.py @@ -11,6 +11,7 @@ from python.orm.richie.contact import ( Need, RelationshipType, ) + __all__ = [ "Bill", "Contact", diff --git a/python/orm/signal_bot/models.py b/python/orm/signal_bot/models.py index 4dc6e45..126fee5 100644 --- a/python/orm/signal_bot/models.py +++ b/python/orm/signal_bot/models.py @@ -4,8 +4,7 @@ from __future__ import annotations from datetime import datetime -from sqlalchemy import DateTime, ForeignKey, SmallInteger, String, Text, UniqueConstraint -from sqlalchemy.dialects.postgresql import ENUM +from sqlalchemy import DateTime, Enum, ForeignKey, SmallInteger, String, Text, UniqueConstraint from sqlalchemy.orm import Mapped, mapped_column, relationship from python.orm.signal_bot.base import SignalBotTableBase, SignalBotTableBaseSmall @@ -41,7 +40,7 @@ class SignalDevice(SignalBotTableBase): phone_number: Mapped[str] = mapped_column(String(50), unique=True) safety_number: Mapped[str | None] trust_level: Mapped[TrustLevel] = mapped_column( - ENUM(TrustLevel, name="trust_level", create_type=True, schema="main"), + Enum(TrustLevel, name="trust_level", create_constraint=False, native_enum=False), default=TrustLevel.UNVERIFIED, ) last_seen: Mapped[datetime] = mapped_column(DateTime(timezone=True)) @@ -58,6 +57,6 @@ class DeadLetterMessage(SignalBotTableBase): message: Mapped[str] = mapped_column(Text) received_at: Mapped[datetime] = mapped_column(DateTime(timezone=True)) status: Mapped[MessageStatus] = mapped_column( - ENUM(MessageStatus, name="message_status", create_type=True, schema="main"), + Enum(MessageStatus, name="message_status", create_constraint=False, native_enum=False), default=MessageStatus.UNPROCESSED, ) diff --git a/python/signal_bot/device_registry.py b/python/signal_bot/device_registry.py index 5d0737a..d4dd299 100644 --- a/python/signal_bot/device_registry.py +++ b/python/signal_bot/device_registry.py @@ -63,10 +63,9 @@ class DeviceRegistry: return with Session(self.engine) as session: - device = ( - session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number)) - .scalar_one_or_none() - ) + device = session.execute( + select(SignalDevice).where(SignalDevice.phone_number == phone_number) + ).scalar_one_or_none() if device: if device.safety_number != safety_number and device.trust_level != TrustLevel.BLOCKED: @@ -100,10 +99,9 @@ class DeviceRegistry: Returns True if the device was found and verified. """ with Session(self.engine) as session: - device = ( - session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number)) - .scalar_one_or_none() - ) + device = session.execute( + select(SignalDevice).where(SignalDevice.phone_number == phone_number) + ).scalar_one_or_none() if not device: logger.warning(f"Cannot verify unknown device: {phone_number}") @@ -141,10 +139,9 @@ class DeviceRegistry: def grant_role(self, phone_number: str, role: Role) -> bool: """Add a role to a device. Called by admin over SSH.""" with Session(self.engine) as session: - device = ( - session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number)) - .scalar_one_or_none() - ) + device = session.execute( + select(SignalDevice).where(SignalDevice.phone_number == phone_number) + ).scalar_one_or_none() if not device: logger.warning(f"Cannot grant role for unknown device: {phone_number}") @@ -153,9 +150,7 @@ class DeviceRegistry: if any(record.name == role for record in device.roles): return True - role_record = session.execute( - select(RoleRecord).where(RoleRecord.name == role) - ).scalar_one_or_none() + role_record = session.execute(select(RoleRecord).where(RoleRecord.name == role)).scalar_one_or_none() if not role_record: logger.warning(f"Unknown role: {role}") @@ -170,10 +165,9 @@ class DeviceRegistry: def revoke_role(self, phone_number: str, role: Role) -> bool: """Remove a role from a device. Called by admin over SSH.""" with Session(self.engine) as session: - device = ( - session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number)) - .scalar_one_or_none() - ) + device = session.execute( + select(SignalDevice).where(SignalDevice.phone_number == phone_number) + ).scalar_one_or_none() if not device: logger.warning(f"Cannot revoke role for unknown device: {phone_number}") @@ -188,19 +182,16 @@ class DeviceRegistry: def set_roles(self, phone_number: str, roles: list[Role]) -> bool: """Replace all roles for a device. Called by admin over SSH.""" with Session(self.engine) as session: - device = ( - session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number)) - .scalar_one_or_none() - ) + device = session.execute( + select(SignalDevice).where(SignalDevice.phone_number == phone_number) + ).scalar_one_or_none() if not device: logger.warning(f"Cannot set roles for unknown device: {phone_number}") return False role_names = [str(role) for role in roles] - records = list( - session.execute(select(RoleRecord).where(RoleRecord.name.in_(role_names))).scalars().all() - ) + records = list(session.execute(select(RoleRecord).where(RoleRecord.name.in_(role_names))).scalars().all()) device.roles = records session.commit() self._update_cache(phone_number, device) @@ -235,10 +226,9 @@ class DeviceRegistry: def _load_device(self, phone_number: str) -> SignalDevice | None: """Fetch a device by phone number (with joined roles).""" with Session(self.engine) as session: - return ( - session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number)) - .scalar_one_or_none() - ) + return session.execute( + select(SignalDevice).where(SignalDevice.phone_number == phone_number) + ).scalar_one_or_none() def _update_cache(self, phone_number: str, device: SignalDevice) -> None: """Refresh the cache entry for a device.""" @@ -254,10 +244,9 @@ class DeviceRegistry: def _set_trust(self, phone_number: str, level: str, log_msg: str | None = None) -> bool: """Update the trust level for a device.""" with Session(self.engine) as session: - device = ( - session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number)) - .scalar_one_or_none() - ) + device = session.execute( + select(SignalDevice).where(SignalDevice.phone_number == phone_number) + ).scalar_one_or_none() if not device: return False diff --git a/python/signal_bot/main.py b/python/signal_bot/main.py index 5432765..bc847d2 100644 --- a/python/signal_bot/main.py +++ b/python/signal_bot/main.py @@ -106,10 +106,14 @@ class Bot: """Route an incoming message to the right command handler.""" source = message.source - if not self.registry.is_verified(source) or not self.registry.has_safety_number(source): + if not self.registry.is_verified(source): logger.info(f"Device {source} not verified, ignoring message") return + if not self.registry.has_safety_number(source) and self.registry.has_role(source, Role.ADMIN): + logger.warning(f"Admin device {source} missing safety number, ignoring message") + return + text = message.message.strip() parts = text.split() diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..f626133 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,40 @@ +"""Shared test fixtures.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest +from sqlalchemy import create_engine, event + +from python.orm.signal_bot.base import SignalBotBase + +if TYPE_CHECKING: + from collections.abc import Generator + + from sqlalchemy.engine import Engine + + +@pytest.fixture(scope="session") +def sqlite_engine() -> Generator[Engine]: + """Create an in-memory SQLite engine for testing.""" + engine = create_engine("sqlite:///:memory:") + + @event.listens_for(engine, "connect") + def _set_sqlite_pragma(dbapi_connection, _connection_record): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + SignalBotBase.metadata.create_all(engine) + yield engine + engine.dispose() + + +@pytest.fixture +def engine(sqlite_engine: Engine) -> Generator[Engine]: + """Yield the shared engine after cleaning all tables between tests.""" + yield sqlite_engine + with sqlite_engine.begin() as connection: + for table in reversed(SignalBotBase.metadata.sorted_tables): + connection.execute(table.delete()) diff --git a/tests/test_signal_bot.py b/tests/test_signal_bot.py index bd0be1c..82efdc9 100644 --- a/tests/test_signal_bot.py +++ b/tests/test_signal_bot.py @@ -7,9 +7,7 @@ from datetime import timedelta from unittest.mock import MagicMock, patch import pytest -from sqlalchemy import create_engine -from python.orm.richie.base import RichieBase from python.signal_bot.commands.inventory import ( _format_summary, parse_llm_response, @@ -17,7 +15,7 @@ from python.signal_bot.commands.inventory import ( from python.signal_bot.commands.location import _format_location, handle_location_request from python.signal_bot.device_registry import _BLOCKED_TTL, _DEFAULT_TTL, DeviceRegistry, _CacheEntry from python.signal_bot.llm_client import LLMClient -from python.signal_bot.main import dispatch +from python.signal_bot.main import Bot from python.signal_bot.models import ( BotConfig, InventoryItem, @@ -78,12 +76,6 @@ class TestDeviceRegistry: def signal_mock(self): return MagicMock(spec=SignalClient) - @pytest.fixture - def engine(self): - engine = create_engine("sqlite://") - RichieBase.metadata.create_all(engine) - return engine - @pytest.fixture def registry(self, signal_mock, engine): return DeviceRegistry(signal_mock, engine) @@ -131,12 +123,6 @@ class TestContactCache: def signal_mock(self): return MagicMock(spec=SignalClient) - @pytest.fixture - def engine(self): - engine = create_engine("sqlite://") - RichieBase.metadata.create_all(engine) - return engine - @pytest.fixture def registry(self, signal_mock, engine): return DeviceRegistry(signal_mock, engine) @@ -215,6 +201,7 @@ class TestContactCache: trust_level=old.trust_level, has_safety_number=old.has_safety_number, safety_number=old.safety_number, + roles=old.roles, ) with patch("python.signal_bot.device_registry.Session") as mock_session_cls: @@ -229,25 +216,15 @@ class TestContactCache: class TestLocationCommand: - def test_format_location_from_attributes(self): - payload = { - "state": "whatever", - "attributes": { - "latitude": 12.34, - "longitude": 56.78, - "speed": "45 mph", - "last_updated": "2024-01-01T00:00:00+00:00", - }, - } - response = _format_location(payload) + def test_format_location(self): + response = _format_location("12.34", "56.78") assert "12.34, 56.78" in response assert "maps.google.com" in response - assert "Speed: 45 mph" in response def test_handle_location_request_without_config(self): signal = MagicMock(spec=SignalClient) message = SignalMessage(source="+1234", timestamp=0, message="location") - handle_location_request(message, signal, None, None, "sensor.gps_location") + handle_location_request(message, signal, None, None) signal.reply.assert_called_once() assert "not configured" in signal.reply.call_args[0][1] @@ -266,11 +243,12 @@ class TestDispatch: mock = MagicMock(spec=DeviceRegistry) mock.is_verified.return_value = True mock.has_safety_number.return_value = True + mock.has_role.return_value = False + mock.get_roles.return_value = [] return mock @pytest.fixture - def config(self): - engine = create_engine("sqlite://") + def config(self, engine): return BotConfig( signal_api_url="http://localhost:8080", phone_number="+1234567890", @@ -278,46 +256,66 @@ class TestDispatch: engine=engine, ) - def test_unverified_device_ignored(self, signal_mock, llm_mock, registry_mock, config): + @pytest.fixture + def bot(self, signal_mock, llm_mock, registry_mock, config): + return Bot(signal_mock, llm_mock, registry_mock, config) + + def test_unverified_device_ignored(self, bot, signal_mock, registry_mock): registry_mock.is_verified.return_value = False msg = SignalMessage(source="+1234", timestamp=0, message="help") - dispatch(msg, signal_mock, llm_mock, registry_mock, config) + bot.dispatch(msg) signal_mock.reply.assert_not_called() - def test_help_command(self, signal_mock, llm_mock, registry_mock, config): + def test_admin_without_safety_number_ignored(self, bot, signal_mock, registry_mock): + registry_mock.has_safety_number.return_value = False + registry_mock.has_role.return_value = True msg = SignalMessage(source="+1234", timestamp=0, message="help") - dispatch(msg, signal_mock, llm_mock, registry_mock, config) + bot.dispatch(msg) + signal_mock.reply.assert_not_called() + + def test_non_admin_without_safety_number_allowed(self, bot, signal_mock, registry_mock): + registry_mock.has_safety_number.return_value = False + registry_mock.has_role.return_value = False + registry_mock.get_roles.return_value = [] + msg = SignalMessage(source="+1234", timestamp=0, message="help") + bot.dispatch(msg) + signal_mock.reply.assert_called_once() + + def test_help_command(self, bot, signal_mock, registry_mock): + msg = SignalMessage(source="+1234", timestamp=0, message="help") + bot.dispatch(msg) signal_mock.reply.assert_called_once() assert "Available commands" in signal_mock.reply.call_args[0][1] - def test_unknown_command_ignored(self, signal_mock, llm_mock, registry_mock, config): + def test_unknown_command_ignored(self, bot, signal_mock): msg = SignalMessage(source="+1234", timestamp=0, message="foobar") - dispatch(msg, signal_mock, llm_mock, registry_mock, config) + bot.dispatch(msg) signal_mock.reply.assert_not_called() - def test_non_command_message_ignored(self, signal_mock, llm_mock, registry_mock, config): + def test_non_command_message_ignored(self, bot, signal_mock): msg = SignalMessage(source="+1234", timestamp=0, message="hello there") - dispatch(msg, signal_mock, llm_mock, registry_mock, config) + bot.dispatch(msg) signal_mock.reply.assert_not_called() - def test_status_command(self, signal_mock, llm_mock, registry_mock, config): + def test_status_command(self, bot, signal_mock, llm_mock, registry_mock): llm_mock.list_models.return_value = ["model1", "model2"] llm_mock.model = "test:7b" registry_mock.list_devices.return_value = [] + registry_mock.has_role.return_value = True msg = SignalMessage(source="+1234", timestamp=0, message="status") - dispatch(msg, signal_mock, llm_mock, registry_mock, config) + bot.dispatch(msg) signal_mock.reply.assert_called_once() assert "Bot online" in signal_mock.reply.call_args[0][1] - def test_location_command(self, signal_mock, llm_mock, registry_mock, config): + def test_location_command(self, bot, signal_mock, registry_mock, config): + registry_mock.has_role.return_value = True msg = SignalMessage(source="+1234", timestamp=0, message="location") with patch("python.signal_bot.main.handle_location_request") as mock_location: - dispatch(msg, signal_mock, llm_mock, registry_mock, config) + bot.dispatch(msg) mock_location.assert_called_once_with( msg, signal_mock, config.ha_url, config.ha_token, - config.ha_location_entity, )