mirror of
https://github.com/RichieCahill/dotfiles.git
synced 2026-04-17 04:58:19 -04:00
added auth cashe
This commit is contained in:
@@ -3,7 +3,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING
|
from datetime import datetime, timedelta
|
||||||
|
from typing import TYPE_CHECKING, NamedTuple
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@@ -19,6 +20,16 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_BLOCKED_TTL = timedelta(minutes=60)
|
||||||
|
_DEFAULT_TTL = timedelta(minutes=5)
|
||||||
|
|
||||||
|
|
||||||
|
class _CacheEntry(NamedTuple):
|
||||||
|
expires: datetime
|
||||||
|
trust_level: TrustLevel
|
||||||
|
has_safety_number: bool
|
||||||
|
safety_number: str | None
|
||||||
|
|
||||||
|
|
||||||
class DeviceRegistry:
|
class DeviceRegistry:
|
||||||
"""Manage device trust based on Signal safety numbers.
|
"""Manage device trust based on Signal safety numbers.
|
||||||
@@ -33,15 +44,23 @@ class DeviceRegistry:
|
|||||||
def __init__(self, signal_client: SignalClient, engine: Engine) -> None:
|
def __init__(self, signal_client: SignalClient, engine: Engine) -> None:
|
||||||
self.signal_client = signal_client
|
self.signal_client = signal_client
|
||||||
self.engine = engine
|
self.engine = engine
|
||||||
|
self._contact_cache: dict[str, _CacheEntry] = {}
|
||||||
|
|
||||||
def is_verified(self, phone_number: str) -> bool:
|
def is_verified(self, phone_number: str) -> bool:
|
||||||
"""Check if a phone number is verified."""
|
"""Check if a phone number is verified."""
|
||||||
device = self._get(phone_number)
|
# if entry := self._cached(phone_number):
|
||||||
|
# return entry.trust_level == TrustLevel.VERIFIED
|
||||||
|
device = self.get_device(phone_number)
|
||||||
return device is not None and device.trust_level == TrustLevel.VERIFIED
|
return device is not None and device.trust_level == TrustLevel.VERIFIED
|
||||||
|
|
||||||
def record_contact(self, phone_number: str, safety_number: str | None = None) -> None:
|
def record_contact(self, phone_number: str, safety_number: str | None = None) -> None:
|
||||||
"""Record seeing a device. Creates entry if new, updates last_seen."""
|
"""Record seeing a device. Creates entry if new, updates last_seen."""
|
||||||
now = utcnow()
|
now = utcnow()
|
||||||
|
|
||||||
|
# entry = self._cached(phone_number)
|
||||||
|
# if entry and entry.safety_number == safety_number:
|
||||||
|
# return
|
||||||
|
|
||||||
with Session(self.engine) as session:
|
with Session(self.engine) as session:
|
||||||
device = session.execute(
|
device = session.execute(
|
||||||
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
|
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
|
||||||
@@ -65,9 +84,19 @@ class DeviceRegistry:
|
|||||||
|
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
|
ttl = _BLOCKED_TTL if device.trust_level == TrustLevel.BLOCKED else _DEFAULT_TTL
|
||||||
|
self._contact_cache[phone_number] = _CacheEntry(
|
||||||
|
expires=now + ttl,
|
||||||
|
trust_level=device.trust_level,
|
||||||
|
has_safety_number=device.safety_number is not None,
|
||||||
|
safety_number=device.safety_number,
|
||||||
|
)
|
||||||
|
|
||||||
def has_safety_number(self, phone_number: str) -> bool:
|
def has_safety_number(self, phone_number: str) -> bool:
|
||||||
"""Check if a device has a safety number on file."""
|
"""Check if a device has a safety number on file."""
|
||||||
device = self._get(phone_number)
|
# if entry := self._cached(phone_number):
|
||||||
|
# return entry.has_safety_number
|
||||||
|
device = self.get_device(phone_number)
|
||||||
return device is not None and device.safety_number is not None
|
return device is not None and device.safety_number is not None
|
||||||
|
|
||||||
def verify(self, phone_number: str) -> bool:
|
def verify(self, phone_number: str) -> bool:
|
||||||
@@ -87,6 +116,12 @@ class DeviceRegistry:
|
|||||||
device.trust_level = TrustLevel.VERIFIED
|
device.trust_level = TrustLevel.VERIFIED
|
||||||
self.signal_client.trust_identity(phone_number, trust_all_known_keys=True)
|
self.signal_client.trust_identity(phone_number, trust_all_known_keys=True)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
self._contact_cache[phone_number] = _CacheEntry(
|
||||||
|
expires=utcnow() + _DEFAULT_TTL,
|
||||||
|
trust_level=TrustLevel.VERIFIED,
|
||||||
|
has_safety_number=device.safety_number is not None,
|
||||||
|
safety_number=device.safety_number,
|
||||||
|
)
|
||||||
logger.info(f"Device verified: {phone_number}")
|
logger.info(f"Device verified: {phone_number}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -112,7 +147,14 @@ class DeviceRegistry:
|
|||||||
if number:
|
if number:
|
||||||
self.record_contact(number, safety)
|
self.record_contact(number, safety)
|
||||||
|
|
||||||
def _get(self, phone_number: str) -> SignalDevice | None:
|
def _cached(self, phone_number: str) -> _CacheEntry | None:
|
||||||
|
"""Return the cache entry if it exists and hasn't expired."""
|
||||||
|
entry = self._contact_cache.get(phone_number)
|
||||||
|
if entry and utcnow() < entry.expires:
|
||||||
|
return entry
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_device(self, phone_number: str) -> SignalDevice | None:
|
||||||
"""Fetch a device by phone number."""
|
"""Fetch a device by phone number."""
|
||||||
with Session(self.engine) as session:
|
with Session(self.engine) as session:
|
||||||
return session.execute(
|
return session.execute(
|
||||||
@@ -131,6 +173,13 @@ class DeviceRegistry:
|
|||||||
|
|
||||||
device.trust_level = level
|
device.trust_level = level
|
||||||
session.commit()
|
session.commit()
|
||||||
|
ttl = _BLOCKED_TTL if level == TrustLevel.BLOCKED else _DEFAULT_TTL
|
||||||
|
self._contact_cache[phone_number] = _CacheEntry(
|
||||||
|
expires=utcnow() + ttl,
|
||||||
|
trust_level=level,
|
||||||
|
has_safety_number=device.safety_number is not None,
|
||||||
|
safety_number=device.safety_number,
|
||||||
|
)
|
||||||
if log_msg:
|
if log_msg:
|
||||||
logger.info(f"{log_msg}: {phone_number}")
|
logger.info(f"{log_msg}: {phone_number}")
|
||||||
return True
|
return True
|
||||||
|
|||||||
@@ -3,7 +3,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from unittest.mock import MagicMock
|
from datetime import timedelta
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
@@ -13,7 +14,7 @@ from python.signal_bot.commands.inventory import (
|
|||||||
_format_summary,
|
_format_summary,
|
||||||
parse_llm_response,
|
parse_llm_response,
|
||||||
)
|
)
|
||||||
from python.signal_bot.device_registry import DeviceRegistry
|
from python.signal_bot.device_registry import _BLOCKED_TTL, _DEFAULT_TTL, DeviceRegistry, _CacheEntry
|
||||||
from python.signal_bot.llm_client import LLMClient
|
from python.signal_bot.llm_client import LLMClient
|
||||||
from python.signal_bot.main import dispatch
|
from python.signal_bot.main import dispatch
|
||||||
from python.signal_bot.models import (
|
from python.signal_bot.models import (
|
||||||
@@ -124,6 +125,108 @@ class TestDeviceRegistry:
|
|||||||
assert len(registry.list_devices()) == 2
|
assert len(registry.list_devices()) == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestContactCache:
|
||||||
|
@pytest.fixture
|
||||||
|
def signal_mock(self):
|
||||||
|
return MagicMock(spec=SignalClient)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def engine(self):
|
||||||
|
engine = create_engine("sqlite://")
|
||||||
|
RichieBase.metadata.create_all(engine)
|
||||||
|
return engine
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def registry(self, signal_mock, engine):
|
||||||
|
return DeviceRegistry(signal_mock, engine)
|
||||||
|
|
||||||
|
def test_second_call_uses_cache(self, registry):
|
||||||
|
registry.record_contact("+1234", "abc")
|
||||||
|
assert "+1234" in registry._contact_cache
|
||||||
|
|
||||||
|
with patch.object(registry, "engine") as mock_engine:
|
||||||
|
registry.record_contact("+1234", "abc")
|
||||||
|
mock_engine.assert_not_called()
|
||||||
|
|
||||||
|
def test_unverified_gets_default_ttl(self, registry):
|
||||||
|
registry.record_contact("+1234", "abc")
|
||||||
|
from python.common import utcnow
|
||||||
|
|
||||||
|
entry = registry._contact_cache["+1234"]
|
||||||
|
expected = utcnow() + _DEFAULT_TTL
|
||||||
|
assert abs((entry.expires - expected).total_seconds()) < 2
|
||||||
|
assert entry.trust_level == TrustLevel.UNVERIFIED
|
||||||
|
assert entry.has_safety_number is True
|
||||||
|
|
||||||
|
def test_blocked_gets_blocked_ttl(self, registry):
|
||||||
|
registry.record_contact("+1234", "abc")
|
||||||
|
registry.block("+1234")
|
||||||
|
from python.common import utcnow
|
||||||
|
|
||||||
|
entry = registry._contact_cache["+1234"]
|
||||||
|
expected = utcnow() + _BLOCKED_TTL
|
||||||
|
assert abs((entry.expires - expected).total_seconds()) < 2
|
||||||
|
assert entry.trust_level == TrustLevel.BLOCKED
|
||||||
|
|
||||||
|
def test_verify_updates_cache(self, registry):
|
||||||
|
registry.record_contact("+1234", "abc")
|
||||||
|
registry.verify("+1234")
|
||||||
|
entry = registry._contact_cache["+1234"]
|
||||||
|
assert entry.trust_level == TrustLevel.VERIFIED
|
||||||
|
|
||||||
|
def test_block_updates_cache(self, registry):
|
||||||
|
registry.record_contact("+1234", "abc")
|
||||||
|
registry.block("+1234")
|
||||||
|
entry = registry._contact_cache["+1234"]
|
||||||
|
assert entry.trust_level == TrustLevel.BLOCKED
|
||||||
|
|
||||||
|
def test_unverify_updates_cache(self, registry):
|
||||||
|
registry.record_contact("+1234", "abc")
|
||||||
|
registry.verify("+1234")
|
||||||
|
registry.unverify("+1234")
|
||||||
|
entry = registry._contact_cache["+1234"]
|
||||||
|
assert entry.trust_level == TrustLevel.UNVERIFIED
|
||||||
|
|
||||||
|
def test_is_verified_uses_cache(self, registry):
|
||||||
|
registry.record_contact("+1234", "abc")
|
||||||
|
registry.verify("+1234")
|
||||||
|
with patch.object(registry, "engine") as mock_engine:
|
||||||
|
assert registry.is_verified("+1234") is True
|
||||||
|
mock_engine.assert_not_called()
|
||||||
|
|
||||||
|
def test_has_safety_number_uses_cache(self, registry):
|
||||||
|
registry.record_contact("+1234", "abc")
|
||||||
|
with patch.object(registry, "engine") as mock_engine:
|
||||||
|
assert registry.has_safety_number("+1234") is True
|
||||||
|
mock_engine.assert_not_called()
|
||||||
|
|
||||||
|
def test_no_safety_number_cached(self, registry):
|
||||||
|
registry.record_contact("+1234", None)
|
||||||
|
with patch.object(registry, "engine") as mock_engine:
|
||||||
|
assert registry.has_safety_number("+1234") is False
|
||||||
|
mock_engine.assert_not_called()
|
||||||
|
|
||||||
|
def test_expired_cache_hits_db(self, registry):
|
||||||
|
registry.record_contact("+1234", "abc")
|
||||||
|
old = registry._contact_cache["+1234"]
|
||||||
|
registry._contact_cache["+1234"] = _CacheEntry(
|
||||||
|
expires=old.expires - timedelta(minutes=10),
|
||||||
|
trust_level=old.trust_level,
|
||||||
|
has_safety_number=old.has_safety_number,
|
||||||
|
safety_number=old.safety_number,
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("python.signal_bot.device_registry.Session") as mock_session_cls:
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session_cls.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||||
|
mock_session_cls.return_value.__exit__ = MagicMock(return_value=False)
|
||||||
|
mock_device = MagicMock()
|
||||||
|
mock_device.trust_level = TrustLevel.UNVERIFIED
|
||||||
|
mock_session.execute.return_value.scalar_one_or_none.return_value = mock_device
|
||||||
|
registry.record_contact("+1234", "abc")
|
||||||
|
mock_session.execute.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
class TestDispatch:
|
class TestDispatch:
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def signal_mock(self):
|
def signal_mock(self):
|
||||||
@@ -152,21 +255,20 @@ class TestDispatch:
|
|||||||
|
|
||||||
def test_unverified_device_ignored(self, signal_mock, llm_mock, registry_mock, config):
|
def test_unverified_device_ignored(self, signal_mock, llm_mock, registry_mock, config):
|
||||||
registry_mock.is_verified.return_value = False
|
registry_mock.is_verified.return_value = False
|
||||||
msg = SignalMessage(source="+1234", timestamp=0, message="!help")
|
msg = SignalMessage(source="+1234", timestamp=0, message="help")
|
||||||
dispatch(msg, signal_mock, llm_mock, registry_mock, config)
|
dispatch(msg, signal_mock, llm_mock, registry_mock, config)
|
||||||
signal_mock.reply.assert_not_called()
|
signal_mock.reply.assert_not_called()
|
||||||
|
|
||||||
def test_help_command(self, signal_mock, llm_mock, registry_mock, config):
|
def test_help_command(self, signal_mock, llm_mock, registry_mock, config):
|
||||||
msg = SignalMessage(source="+1234", timestamp=0, message="!help")
|
msg = SignalMessage(source="+1234", timestamp=0, message="help")
|
||||||
dispatch(msg, signal_mock, llm_mock, registry_mock, config)
|
dispatch(msg, signal_mock, llm_mock, registry_mock, config)
|
||||||
signal_mock.reply.assert_called_once()
|
signal_mock.reply.assert_called_once()
|
||||||
assert "Available commands" in signal_mock.reply.call_args[0][1]
|
assert "Available commands" in signal_mock.reply.call_args[0][1]
|
||||||
|
|
||||||
def test_unknown_command(self, signal_mock, llm_mock, registry_mock, config):
|
def test_unknown_command_ignored(self, signal_mock, llm_mock, registry_mock, config):
|
||||||
msg = SignalMessage(source="+1234", timestamp=0, message="!foobar")
|
msg = SignalMessage(source="+1234", timestamp=0, message="foobar")
|
||||||
dispatch(msg, signal_mock, llm_mock, registry_mock, config)
|
dispatch(msg, signal_mock, llm_mock, registry_mock, config)
|
||||||
signal_mock.reply.assert_called_once()
|
signal_mock.reply.assert_not_called()
|
||||||
assert "Unknown command" in signal_mock.reply.call_args[0][1]
|
|
||||||
|
|
||||||
def test_non_command_message_ignored(self, signal_mock, llm_mock, registry_mock, config):
|
def test_non_command_message_ignored(self, signal_mock, llm_mock, registry_mock, config):
|
||||||
msg = SignalMessage(source="+1234", timestamp=0, message="hello there")
|
msg = SignalMessage(source="+1234", timestamp=0, message="hello there")
|
||||||
@@ -177,7 +279,7 @@ class TestDispatch:
|
|||||||
llm_mock.list_models.return_value = ["model1", "model2"]
|
llm_mock.list_models.return_value = ["model1", "model2"]
|
||||||
llm_mock.model = "test:7b"
|
llm_mock.model = "test:7b"
|
||||||
registry_mock.list_devices.return_value = []
|
registry_mock.list_devices.return_value = []
|
||||||
msg = SignalMessage(source="+1234", timestamp=0, message="!status")
|
msg = SignalMessage(source="+1234", timestamp=0, message="status")
|
||||||
dispatch(msg, signal_mock, llm_mock, registry_mock, config)
|
dispatch(msg, signal_mock, llm_mock, registry_mock, config)
|
||||||
signal_mock.reply.assert_called_once()
|
signal_mock.reply.assert_called_once()
|
||||||
assert "Bot online" in signal_mock.reply.call_args[0][1]
|
assert "Bot online" in signal_mock.reply.call_args[0][1]
|
||||||
|
|||||||
Reference in New Issue
Block a user