diff --git a/python/signal_bot/device_registry.py b/python/signal_bot/device_registry.py index 460cf83..fd1cde0 100644 --- a/python/signal_bot/device_registry.py +++ b/python/signal_bot/device_registry.py @@ -3,7 +3,8 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING +from datetime import datetime, timedelta +from typing import TYPE_CHECKING, NamedTuple from sqlalchemy import select from sqlalchemy.orm import Session @@ -19,6 +20,16 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +_BLOCKED_TTL = timedelta(minutes=60) +_DEFAULT_TTL = timedelta(minutes=5) + + +class _CacheEntry(NamedTuple): + expires: datetime + trust_level: TrustLevel + has_safety_number: bool + safety_number: str | None + class DeviceRegistry: """Manage device trust based on Signal safety numbers. @@ -33,15 +44,23 @@ class DeviceRegistry: def __init__(self, signal_client: SignalClient, engine: Engine) -> None: self.signal_client = signal_client self.engine = engine + self._contact_cache: dict[str, _CacheEntry] = {} def is_verified(self, phone_number: str) -> bool: """Check if a phone number is verified.""" - device = self._get(phone_number) + # if entry := self._cached(phone_number): + # return entry.trust_level == TrustLevel.VERIFIED + device = self.get_device(phone_number) return device is not None and device.trust_level == TrustLevel.VERIFIED def record_contact(self, phone_number: str, safety_number: str | None = None) -> None: """Record seeing a device. Creates entry if new, updates last_seen.""" now = utcnow() + + # entry = self._cached(phone_number) + # if entry and entry.safety_number == safety_number: + # return + with Session(self.engine) as session: device = session.execute( select(SignalDevice).where(SignalDevice.phone_number == phone_number) @@ -65,9 +84,19 @@ class DeviceRegistry: session.commit() + ttl = _BLOCKED_TTL if device.trust_level == TrustLevel.BLOCKED else _DEFAULT_TTL + self._contact_cache[phone_number] = _CacheEntry( + expires=now + ttl, + trust_level=device.trust_level, + has_safety_number=device.safety_number is not None, + safety_number=device.safety_number, + ) + def has_safety_number(self, phone_number: str) -> bool: """Check if a device has a safety number on file.""" - device = self._get(phone_number) + # if entry := self._cached(phone_number): + # return entry.has_safety_number + device = self.get_device(phone_number) return device is not None and device.safety_number is not None def verify(self, phone_number: str) -> bool: @@ -87,6 +116,12 @@ class DeviceRegistry: device.trust_level = TrustLevel.VERIFIED self.signal_client.trust_identity(phone_number, trust_all_known_keys=True) session.commit() + self._contact_cache[phone_number] = _CacheEntry( + expires=utcnow() + _DEFAULT_TTL, + trust_level=TrustLevel.VERIFIED, + has_safety_number=device.safety_number is not None, + safety_number=device.safety_number, + ) logger.info(f"Device verified: {phone_number}") return True @@ -112,7 +147,14 @@ class DeviceRegistry: if number: self.record_contact(number, safety) - def _get(self, phone_number: str) -> SignalDevice | None: + def _cached(self, phone_number: str) -> _CacheEntry | None: + """Return the cache entry if it exists and hasn't expired.""" + entry = self._contact_cache.get(phone_number) + if entry and utcnow() < entry.expires: + return entry + return None + + def get_device(self, phone_number: str) -> SignalDevice | None: """Fetch a device by phone number.""" with Session(self.engine) as session: return session.execute( @@ -131,6 +173,13 @@ class DeviceRegistry: device.trust_level = level session.commit() + ttl = _BLOCKED_TTL if level == TrustLevel.BLOCKED else _DEFAULT_TTL + self._contact_cache[phone_number] = _CacheEntry( + expires=utcnow() + ttl, + trust_level=level, + has_safety_number=device.safety_number is not None, + safety_number=device.safety_number, + ) if log_msg: logger.info(f"{log_msg}: {phone_number}") return True diff --git a/tests/test_signal_bot.py b/tests/test_signal_bot.py index 533704d..8a2a12b 100644 --- a/tests/test_signal_bot.py +++ b/tests/test_signal_bot.py @@ -3,7 +3,8 @@ from __future__ import annotations import json -from unittest.mock import MagicMock +from datetime import timedelta +from unittest.mock import MagicMock, patch import pytest from sqlalchemy import create_engine @@ -13,7 +14,7 @@ from python.signal_bot.commands.inventory import ( _format_summary, parse_llm_response, ) -from python.signal_bot.device_registry import DeviceRegistry +from python.signal_bot.device_registry import _BLOCKED_TTL, _DEFAULT_TTL, DeviceRegistry, _CacheEntry from python.signal_bot.llm_client import LLMClient from python.signal_bot.main import dispatch from python.signal_bot.models import ( @@ -124,6 +125,108 @@ class TestDeviceRegistry: assert len(registry.list_devices()) == 2 +class TestContactCache: + @pytest.fixture + def signal_mock(self): + return MagicMock(spec=SignalClient) + + @pytest.fixture + def engine(self): + engine = create_engine("sqlite://") + RichieBase.metadata.create_all(engine) + return engine + + @pytest.fixture + def registry(self, signal_mock, engine): + return DeviceRegistry(signal_mock, engine) + + def test_second_call_uses_cache(self, registry): + registry.record_contact("+1234", "abc") + assert "+1234" in registry._contact_cache + + with patch.object(registry, "engine") as mock_engine: + registry.record_contact("+1234", "abc") + mock_engine.assert_not_called() + + def test_unverified_gets_default_ttl(self, registry): + registry.record_contact("+1234", "abc") + from python.common import utcnow + + entry = registry._contact_cache["+1234"] + expected = utcnow() + _DEFAULT_TTL + assert abs((entry.expires - expected).total_seconds()) < 2 + assert entry.trust_level == TrustLevel.UNVERIFIED + assert entry.has_safety_number is True + + def test_blocked_gets_blocked_ttl(self, registry): + registry.record_contact("+1234", "abc") + registry.block("+1234") + from python.common import utcnow + + entry = registry._contact_cache["+1234"] + expected = utcnow() + _BLOCKED_TTL + assert abs((entry.expires - expected).total_seconds()) < 2 + assert entry.trust_level == TrustLevel.BLOCKED + + def test_verify_updates_cache(self, registry): + registry.record_contact("+1234", "abc") + registry.verify("+1234") + entry = registry._contact_cache["+1234"] + assert entry.trust_level == TrustLevel.VERIFIED + + def test_block_updates_cache(self, registry): + registry.record_contact("+1234", "abc") + registry.block("+1234") + entry = registry._contact_cache["+1234"] + assert entry.trust_level == TrustLevel.BLOCKED + + def test_unverify_updates_cache(self, registry): + registry.record_contact("+1234", "abc") + registry.verify("+1234") + registry.unverify("+1234") + entry = registry._contact_cache["+1234"] + assert entry.trust_level == TrustLevel.UNVERIFIED + + def test_is_verified_uses_cache(self, registry): + registry.record_contact("+1234", "abc") + registry.verify("+1234") + with patch.object(registry, "engine") as mock_engine: + assert registry.is_verified("+1234") is True + mock_engine.assert_not_called() + + def test_has_safety_number_uses_cache(self, registry): + registry.record_contact("+1234", "abc") + with patch.object(registry, "engine") as mock_engine: + assert registry.has_safety_number("+1234") is True + mock_engine.assert_not_called() + + def test_no_safety_number_cached(self, registry): + registry.record_contact("+1234", None) + with patch.object(registry, "engine") as mock_engine: + assert registry.has_safety_number("+1234") is False + mock_engine.assert_not_called() + + def test_expired_cache_hits_db(self, registry): + registry.record_contact("+1234", "abc") + old = registry._contact_cache["+1234"] + registry._contact_cache["+1234"] = _CacheEntry( + expires=old.expires - timedelta(minutes=10), + trust_level=old.trust_level, + has_safety_number=old.has_safety_number, + safety_number=old.safety_number, + ) + + with patch("python.signal_bot.device_registry.Session") as mock_session_cls: + mock_session = MagicMock() + mock_session_cls.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_session_cls.return_value.__exit__ = MagicMock(return_value=False) + mock_device = MagicMock() + mock_device.trust_level = TrustLevel.UNVERIFIED + mock_session.execute.return_value.scalar_one_or_none.return_value = mock_device + registry.record_contact("+1234", "abc") + mock_session.execute.assert_called_once() + + class TestDispatch: @pytest.fixture def signal_mock(self): @@ -152,21 +255,20 @@ class TestDispatch: def test_unverified_device_ignored(self, signal_mock, llm_mock, registry_mock, config): registry_mock.is_verified.return_value = False - msg = SignalMessage(source="+1234", timestamp=0, message="!help") + msg = SignalMessage(source="+1234", timestamp=0, message="help") dispatch(msg, signal_mock, llm_mock, registry_mock, config) signal_mock.reply.assert_not_called() def test_help_command(self, signal_mock, llm_mock, registry_mock, config): - msg = SignalMessage(source="+1234", timestamp=0, message="!help") + msg = SignalMessage(source="+1234", timestamp=0, message="help") dispatch(msg, signal_mock, llm_mock, registry_mock, config) signal_mock.reply.assert_called_once() assert "Available commands" in signal_mock.reply.call_args[0][1] - def test_unknown_command(self, signal_mock, llm_mock, registry_mock, config): - msg = SignalMessage(source="+1234", timestamp=0, message="!foobar") + def test_unknown_command_ignored(self, signal_mock, llm_mock, registry_mock, config): + msg = SignalMessage(source="+1234", timestamp=0, message="foobar") dispatch(msg, signal_mock, llm_mock, registry_mock, config) - signal_mock.reply.assert_called_once() - assert "Unknown command" in signal_mock.reply.call_args[0][1] + signal_mock.reply.assert_not_called() def test_non_command_message_ignored(self, signal_mock, llm_mock, registry_mock, config): msg = SignalMessage(source="+1234", timestamp=0, message="hello there") @@ -177,7 +279,7 @@ class TestDispatch: llm_mock.list_models.return_value = ["model1", "model2"] llm_mock.model = "test:7b" registry_mock.list_devices.return_value = [] - msg = SignalMessage(source="+1234", timestamp=0, message="!status") + msg = SignalMessage(source="+1234", timestamp=0, message="status") dispatch(msg, signal_mock, llm_mock, registry_mock, config) signal_mock.reply.assert_called_once() assert "Bot online" in signal_mock.reply.call_args[0][1]