mirror of
https://github.com/RichieCahill/dotfiles.git
synced 2026-04-17 04:58:19 -04:00
added auth cashe
This commit is contained in:
@@ -3,7 +3,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
from datetime import datetime, timedelta
|
||||
from typing import TYPE_CHECKING, NamedTuple
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -19,6 +20,16 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_BLOCKED_TTL = timedelta(minutes=60)
|
||||
_DEFAULT_TTL = timedelta(minutes=5)
|
||||
|
||||
|
||||
class _CacheEntry(NamedTuple):
|
||||
expires: datetime
|
||||
trust_level: TrustLevel
|
||||
has_safety_number: bool
|
||||
safety_number: str | None
|
||||
|
||||
|
||||
class DeviceRegistry:
|
||||
"""Manage device trust based on Signal safety numbers.
|
||||
@@ -33,15 +44,23 @@ class DeviceRegistry:
|
||||
def __init__(self, signal_client: SignalClient, engine: Engine) -> None:
|
||||
self.signal_client = signal_client
|
||||
self.engine = engine
|
||||
self._contact_cache: dict[str, _CacheEntry] = {}
|
||||
|
||||
def is_verified(self, phone_number: str) -> bool:
|
||||
"""Check if a phone number is verified."""
|
||||
device = self._get(phone_number)
|
||||
# if entry := self._cached(phone_number):
|
||||
# return entry.trust_level == TrustLevel.VERIFIED
|
||||
device = self.get_device(phone_number)
|
||||
return device is not None and device.trust_level == TrustLevel.VERIFIED
|
||||
|
||||
def record_contact(self, phone_number: str, safety_number: str | None = None) -> None:
|
||||
"""Record seeing a device. Creates entry if new, updates last_seen."""
|
||||
now = utcnow()
|
||||
|
||||
# entry = self._cached(phone_number)
|
||||
# if entry and entry.safety_number == safety_number:
|
||||
# return
|
||||
|
||||
with Session(self.engine) as session:
|
||||
device = session.execute(
|
||||
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
|
||||
@@ -65,9 +84,19 @@ class DeviceRegistry:
|
||||
|
||||
session.commit()
|
||||
|
||||
ttl = _BLOCKED_TTL if device.trust_level == TrustLevel.BLOCKED else _DEFAULT_TTL
|
||||
self._contact_cache[phone_number] = _CacheEntry(
|
||||
expires=now + ttl,
|
||||
trust_level=device.trust_level,
|
||||
has_safety_number=device.safety_number is not None,
|
||||
safety_number=device.safety_number,
|
||||
)
|
||||
|
||||
def has_safety_number(self, phone_number: str) -> bool:
|
||||
"""Check if a device has a safety number on file."""
|
||||
device = self._get(phone_number)
|
||||
# if entry := self._cached(phone_number):
|
||||
# return entry.has_safety_number
|
||||
device = self.get_device(phone_number)
|
||||
return device is not None and device.safety_number is not None
|
||||
|
||||
def verify(self, phone_number: str) -> bool:
|
||||
@@ -87,6 +116,12 @@ class DeviceRegistry:
|
||||
device.trust_level = TrustLevel.VERIFIED
|
||||
self.signal_client.trust_identity(phone_number, trust_all_known_keys=True)
|
||||
session.commit()
|
||||
self._contact_cache[phone_number] = _CacheEntry(
|
||||
expires=utcnow() + _DEFAULT_TTL,
|
||||
trust_level=TrustLevel.VERIFIED,
|
||||
has_safety_number=device.safety_number is not None,
|
||||
safety_number=device.safety_number,
|
||||
)
|
||||
logger.info(f"Device verified: {phone_number}")
|
||||
return True
|
||||
|
||||
@@ -112,7 +147,14 @@ class DeviceRegistry:
|
||||
if number:
|
||||
self.record_contact(number, safety)
|
||||
|
||||
def _get(self, phone_number: str) -> SignalDevice | None:
|
||||
def _cached(self, phone_number: str) -> _CacheEntry | None:
|
||||
"""Return the cache entry if it exists and hasn't expired."""
|
||||
entry = self._contact_cache.get(phone_number)
|
||||
if entry and utcnow() < entry.expires:
|
||||
return entry
|
||||
return None
|
||||
|
||||
def get_device(self, phone_number: str) -> SignalDevice | None:
|
||||
"""Fetch a device by phone number."""
|
||||
with Session(self.engine) as session:
|
||||
return session.execute(
|
||||
@@ -131,6 +173,13 @@ class DeviceRegistry:
|
||||
|
||||
device.trust_level = level
|
||||
session.commit()
|
||||
ttl = _BLOCKED_TTL if level == TrustLevel.BLOCKED else _DEFAULT_TTL
|
||||
self._contact_cache[phone_number] = _CacheEntry(
|
||||
expires=utcnow() + ttl,
|
||||
trust_level=level,
|
||||
has_safety_number=device.safety_number is not None,
|
||||
safety_number=device.safety_number,
|
||||
)
|
||||
if log_msg:
|
||||
logger.info(f"{log_msg}: {phone_number}")
|
||||
return True
|
||||
|
||||
Reference in New Issue
Block a user