mirror of
https://github.com/RichieCahill/dotfiles.git
synced 2026-04-17 13:08:19 -04:00
fixed tests and treeftm
This commit is contained in:
@@ -112,4 +112,5 @@ exclude_lines = [
|
|||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
addopts = "-n auto -ra"
|
addopts = "-n auto -ra"
|
||||||
|
testpaths = ["tests"]
|
||||||
# --cov=system_tools --cov-report=term-missing --cov-report=xml --cov-report=html --cov-branch
|
# --cov=system_tools --cov-report=term-missing --cov-report=xml --cov-report=html --cov-branch
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ def downgrade() -> None:
|
|||||||
sa.Column(
|
sa.Column(
|
||||||
"id",
|
"id",
|
||||||
sa.SMALLINT(),
|
sa.SMALLINT(),
|
||||||
server_default=sa.text("nextval(f'{schema}.role_id_seq'::regclass)"),
|
server_default=sa.text(f"nextval('{schema}.role_id_seq'::regclass)"),
|
||||||
autoincrement=True,
|
autoincrement=True,
|
||||||
nullable=False,
|
nullable=False,
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from python.orm.richie.contact import (
|
|||||||
Need,
|
Need,
|
||||||
RelationshipType,
|
RelationshipType,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Bill",
|
"Bill",
|
||||||
"Contact",
|
"Contact",
|
||||||
|
|||||||
@@ -4,8 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from sqlalchemy import DateTime, ForeignKey, SmallInteger, String, Text, UniqueConstraint
|
from sqlalchemy import DateTime, Enum, ForeignKey, SmallInteger, String, Text, UniqueConstraint
|
||||||
from sqlalchemy.dialects.postgresql import ENUM
|
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from python.orm.signal_bot.base import SignalBotTableBase, SignalBotTableBaseSmall
|
from python.orm.signal_bot.base import SignalBotTableBase, SignalBotTableBaseSmall
|
||||||
@@ -41,7 +40,7 @@ class SignalDevice(SignalBotTableBase):
|
|||||||
phone_number: Mapped[str] = mapped_column(String(50), unique=True)
|
phone_number: Mapped[str] = mapped_column(String(50), unique=True)
|
||||||
safety_number: Mapped[str | None]
|
safety_number: Mapped[str | None]
|
||||||
trust_level: Mapped[TrustLevel] = mapped_column(
|
trust_level: Mapped[TrustLevel] = mapped_column(
|
||||||
ENUM(TrustLevel, name="trust_level", create_type=True, schema="main"),
|
Enum(TrustLevel, name="trust_level", create_constraint=False, native_enum=False),
|
||||||
default=TrustLevel.UNVERIFIED,
|
default=TrustLevel.UNVERIFIED,
|
||||||
)
|
)
|
||||||
last_seen: Mapped[datetime] = mapped_column(DateTime(timezone=True))
|
last_seen: Mapped[datetime] = mapped_column(DateTime(timezone=True))
|
||||||
@@ -58,6 +57,6 @@ class DeadLetterMessage(SignalBotTableBase):
|
|||||||
message: Mapped[str] = mapped_column(Text)
|
message: Mapped[str] = mapped_column(Text)
|
||||||
received_at: Mapped[datetime] = mapped_column(DateTime(timezone=True))
|
received_at: Mapped[datetime] = mapped_column(DateTime(timezone=True))
|
||||||
status: Mapped[MessageStatus] = mapped_column(
|
status: Mapped[MessageStatus] = mapped_column(
|
||||||
ENUM(MessageStatus, name="message_status", create_type=True, schema="main"),
|
Enum(MessageStatus, name="message_status", create_constraint=False, native_enum=False),
|
||||||
default=MessageStatus.UNPROCESSED,
|
default=MessageStatus.UNPROCESSED,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -63,10 +63,9 @@ class DeviceRegistry:
|
|||||||
return
|
return
|
||||||
|
|
||||||
with Session(self.engine) as session:
|
with Session(self.engine) as session:
|
||||||
device = (
|
device = session.execute(
|
||||||
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
|
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
|
||||||
.scalar_one_or_none()
|
).scalar_one_or_none()
|
||||||
)
|
|
||||||
|
|
||||||
if device:
|
if device:
|
||||||
if device.safety_number != safety_number and device.trust_level != TrustLevel.BLOCKED:
|
if device.safety_number != safety_number and device.trust_level != TrustLevel.BLOCKED:
|
||||||
@@ -100,10 +99,9 @@ class DeviceRegistry:
|
|||||||
Returns True if the device was found and verified.
|
Returns True if the device was found and verified.
|
||||||
"""
|
"""
|
||||||
with Session(self.engine) as session:
|
with Session(self.engine) as session:
|
||||||
device = (
|
device = session.execute(
|
||||||
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
|
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
|
||||||
.scalar_one_or_none()
|
).scalar_one_or_none()
|
||||||
)
|
|
||||||
|
|
||||||
if not device:
|
if not device:
|
||||||
logger.warning(f"Cannot verify unknown device: {phone_number}")
|
logger.warning(f"Cannot verify unknown device: {phone_number}")
|
||||||
@@ -141,10 +139,9 @@ class DeviceRegistry:
|
|||||||
def grant_role(self, phone_number: str, role: Role) -> bool:
|
def grant_role(self, phone_number: str, role: Role) -> bool:
|
||||||
"""Add a role to a device. Called by admin over SSH."""
|
"""Add a role to a device. Called by admin over SSH."""
|
||||||
with Session(self.engine) as session:
|
with Session(self.engine) as session:
|
||||||
device = (
|
device = session.execute(
|
||||||
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
|
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
|
||||||
.scalar_one_or_none()
|
).scalar_one_or_none()
|
||||||
)
|
|
||||||
|
|
||||||
if not device:
|
if not device:
|
||||||
logger.warning(f"Cannot grant role for unknown device: {phone_number}")
|
logger.warning(f"Cannot grant role for unknown device: {phone_number}")
|
||||||
@@ -153,9 +150,7 @@ class DeviceRegistry:
|
|||||||
if any(record.name == role for record in device.roles):
|
if any(record.name == role for record in device.roles):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
role_record = session.execute(
|
role_record = session.execute(select(RoleRecord).where(RoleRecord.name == role)).scalar_one_or_none()
|
||||||
select(RoleRecord).where(RoleRecord.name == role)
|
|
||||||
).scalar_one_or_none()
|
|
||||||
|
|
||||||
if not role_record:
|
if not role_record:
|
||||||
logger.warning(f"Unknown role: {role}")
|
logger.warning(f"Unknown role: {role}")
|
||||||
@@ -170,10 +165,9 @@ class DeviceRegistry:
|
|||||||
def revoke_role(self, phone_number: str, role: Role) -> bool:
|
def revoke_role(self, phone_number: str, role: Role) -> bool:
|
||||||
"""Remove a role from a device. Called by admin over SSH."""
|
"""Remove a role from a device. Called by admin over SSH."""
|
||||||
with Session(self.engine) as session:
|
with Session(self.engine) as session:
|
||||||
device = (
|
device = session.execute(
|
||||||
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
|
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
|
||||||
.scalar_one_or_none()
|
).scalar_one_or_none()
|
||||||
)
|
|
||||||
|
|
||||||
if not device:
|
if not device:
|
||||||
logger.warning(f"Cannot revoke role for unknown device: {phone_number}")
|
logger.warning(f"Cannot revoke role for unknown device: {phone_number}")
|
||||||
@@ -188,19 +182,16 @@ class DeviceRegistry:
|
|||||||
def set_roles(self, phone_number: str, roles: list[Role]) -> bool:
|
def set_roles(self, phone_number: str, roles: list[Role]) -> bool:
|
||||||
"""Replace all roles for a device. Called by admin over SSH."""
|
"""Replace all roles for a device. Called by admin over SSH."""
|
||||||
with Session(self.engine) as session:
|
with Session(self.engine) as session:
|
||||||
device = (
|
device = session.execute(
|
||||||
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
|
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
|
||||||
.scalar_one_or_none()
|
).scalar_one_or_none()
|
||||||
)
|
|
||||||
|
|
||||||
if not device:
|
if not device:
|
||||||
logger.warning(f"Cannot set roles for unknown device: {phone_number}")
|
logger.warning(f"Cannot set roles for unknown device: {phone_number}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
role_names = [str(role) for role in roles]
|
role_names = [str(role) for role in roles]
|
||||||
records = list(
|
records = list(session.execute(select(RoleRecord).where(RoleRecord.name.in_(role_names))).scalars().all())
|
||||||
session.execute(select(RoleRecord).where(RoleRecord.name.in_(role_names))).scalars().all()
|
|
||||||
)
|
|
||||||
device.roles = records
|
device.roles = records
|
||||||
session.commit()
|
session.commit()
|
||||||
self._update_cache(phone_number, device)
|
self._update_cache(phone_number, device)
|
||||||
@@ -235,10 +226,9 @@ class DeviceRegistry:
|
|||||||
def _load_device(self, phone_number: str) -> SignalDevice | None:
|
def _load_device(self, phone_number: str) -> SignalDevice | None:
|
||||||
"""Fetch a device by phone number (with joined roles)."""
|
"""Fetch a device by phone number (with joined roles)."""
|
||||||
with Session(self.engine) as session:
|
with Session(self.engine) as session:
|
||||||
return (
|
return session.execute(
|
||||||
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
|
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
|
||||||
.scalar_one_or_none()
|
).scalar_one_or_none()
|
||||||
)
|
|
||||||
|
|
||||||
def _update_cache(self, phone_number: str, device: SignalDevice) -> None:
|
def _update_cache(self, phone_number: str, device: SignalDevice) -> None:
|
||||||
"""Refresh the cache entry for a device."""
|
"""Refresh the cache entry for a device."""
|
||||||
@@ -254,10 +244,9 @@ class DeviceRegistry:
|
|||||||
def _set_trust(self, phone_number: str, level: str, log_msg: str | None = None) -> bool:
|
def _set_trust(self, phone_number: str, level: str, log_msg: str | None = None) -> bool:
|
||||||
"""Update the trust level for a device."""
|
"""Update the trust level for a device."""
|
||||||
with Session(self.engine) as session:
|
with Session(self.engine) as session:
|
||||||
device = (
|
device = session.execute(
|
||||||
session.execute(select(SignalDevice).where(SignalDevice.phone_number == phone_number))
|
select(SignalDevice).where(SignalDevice.phone_number == phone_number)
|
||||||
.scalar_one_or_none()
|
).scalar_one_or_none()
|
||||||
)
|
|
||||||
|
|
||||||
if not device:
|
if not device:
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -106,10 +106,14 @@ class Bot:
|
|||||||
"""Route an incoming message to the right command handler."""
|
"""Route an incoming message to the right command handler."""
|
||||||
source = message.source
|
source = message.source
|
||||||
|
|
||||||
if not self.registry.is_verified(source) or not self.registry.has_safety_number(source):
|
if not self.registry.is_verified(source):
|
||||||
logger.info(f"Device {source} not verified, ignoring message")
|
logger.info(f"Device {source} not verified, ignoring message")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if not self.registry.has_safety_number(source) and self.registry.has_role(source, Role.ADMIN):
|
||||||
|
logger.warning(f"Admin device {source} missing safety number, ignoring message")
|
||||||
|
return
|
||||||
|
|
||||||
text = message.message.strip()
|
text = message.message.strip()
|
||||||
parts = text.split()
|
parts = text.split()
|
||||||
|
|
||||||
|
|||||||
40
tests/conftest.py
Normal file
40
tests/conftest.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
"""Shared test fixtures."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import create_engine, event
|
||||||
|
|
||||||
|
from python.orm.signal_bot.base import SignalBotBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Generator
|
||||||
|
|
||||||
|
from sqlalchemy.engine import Engine
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def sqlite_engine() -> Generator[Engine]:
|
||||||
|
"""Create an in-memory SQLite engine for testing."""
|
||||||
|
engine = create_engine("sqlite:///:memory:")
|
||||||
|
|
||||||
|
@event.listens_for(engine, "connect")
|
||||||
|
def _set_sqlite_pragma(dbapi_connection, _connection_record):
|
||||||
|
cursor = dbapi_connection.cursor()
|
||||||
|
cursor.execute("PRAGMA foreign_keys=ON")
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
|
SignalBotBase.metadata.create_all(engine)
|
||||||
|
yield engine
|
||||||
|
engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def engine(sqlite_engine: Engine) -> Generator[Engine]:
|
||||||
|
"""Yield the shared engine after cleaning all tables between tests."""
|
||||||
|
yield sqlite_engine
|
||||||
|
with sqlite_engine.begin() as connection:
|
||||||
|
for table in reversed(SignalBotBase.metadata.sorted_tables):
|
||||||
|
connection.execute(table.delete())
|
||||||
@@ -7,9 +7,7 @@ from datetime import timedelta
|
|||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy import create_engine
|
|
||||||
|
|
||||||
from python.orm.richie.base import RichieBase
|
|
||||||
from python.signal_bot.commands.inventory import (
|
from python.signal_bot.commands.inventory import (
|
||||||
_format_summary,
|
_format_summary,
|
||||||
parse_llm_response,
|
parse_llm_response,
|
||||||
@@ -17,7 +15,7 @@ from python.signal_bot.commands.inventory import (
|
|||||||
from python.signal_bot.commands.location import _format_location, handle_location_request
|
from python.signal_bot.commands.location import _format_location, handle_location_request
|
||||||
from python.signal_bot.device_registry import _BLOCKED_TTL, _DEFAULT_TTL, DeviceRegistry, _CacheEntry
|
from python.signal_bot.device_registry import _BLOCKED_TTL, _DEFAULT_TTL, DeviceRegistry, _CacheEntry
|
||||||
from python.signal_bot.llm_client import LLMClient
|
from python.signal_bot.llm_client import LLMClient
|
||||||
from python.signal_bot.main import dispatch
|
from python.signal_bot.main import Bot
|
||||||
from python.signal_bot.models import (
|
from python.signal_bot.models import (
|
||||||
BotConfig,
|
BotConfig,
|
||||||
InventoryItem,
|
InventoryItem,
|
||||||
@@ -78,12 +76,6 @@ class TestDeviceRegistry:
|
|||||||
def signal_mock(self):
|
def signal_mock(self):
|
||||||
return MagicMock(spec=SignalClient)
|
return MagicMock(spec=SignalClient)
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def engine(self):
|
|
||||||
engine = create_engine("sqlite://")
|
|
||||||
RichieBase.metadata.create_all(engine)
|
|
||||||
return engine
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def registry(self, signal_mock, engine):
|
def registry(self, signal_mock, engine):
|
||||||
return DeviceRegistry(signal_mock, engine)
|
return DeviceRegistry(signal_mock, engine)
|
||||||
@@ -131,12 +123,6 @@ class TestContactCache:
|
|||||||
def signal_mock(self):
|
def signal_mock(self):
|
||||||
return MagicMock(spec=SignalClient)
|
return MagicMock(spec=SignalClient)
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def engine(self):
|
|
||||||
engine = create_engine("sqlite://")
|
|
||||||
RichieBase.metadata.create_all(engine)
|
|
||||||
return engine
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def registry(self, signal_mock, engine):
|
def registry(self, signal_mock, engine):
|
||||||
return DeviceRegistry(signal_mock, engine)
|
return DeviceRegistry(signal_mock, engine)
|
||||||
@@ -215,6 +201,7 @@ class TestContactCache:
|
|||||||
trust_level=old.trust_level,
|
trust_level=old.trust_level,
|
||||||
has_safety_number=old.has_safety_number,
|
has_safety_number=old.has_safety_number,
|
||||||
safety_number=old.safety_number,
|
safety_number=old.safety_number,
|
||||||
|
roles=old.roles,
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch("python.signal_bot.device_registry.Session") as mock_session_cls:
|
with patch("python.signal_bot.device_registry.Session") as mock_session_cls:
|
||||||
@@ -229,25 +216,15 @@ class TestContactCache:
|
|||||||
|
|
||||||
|
|
||||||
class TestLocationCommand:
|
class TestLocationCommand:
|
||||||
def test_format_location_from_attributes(self):
|
def test_format_location(self):
|
||||||
payload = {
|
response = _format_location("12.34", "56.78")
|
||||||
"state": "whatever",
|
|
||||||
"attributes": {
|
|
||||||
"latitude": 12.34,
|
|
||||||
"longitude": 56.78,
|
|
||||||
"speed": "45 mph",
|
|
||||||
"last_updated": "2024-01-01T00:00:00+00:00",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
response = _format_location(payload)
|
|
||||||
assert "12.34, 56.78" in response
|
assert "12.34, 56.78" in response
|
||||||
assert "maps.google.com" in response
|
assert "maps.google.com" in response
|
||||||
assert "Speed: 45 mph" in response
|
|
||||||
|
|
||||||
def test_handle_location_request_without_config(self):
|
def test_handle_location_request_without_config(self):
|
||||||
signal = MagicMock(spec=SignalClient)
|
signal = MagicMock(spec=SignalClient)
|
||||||
message = SignalMessage(source="+1234", timestamp=0, message="location")
|
message = SignalMessage(source="+1234", timestamp=0, message="location")
|
||||||
handle_location_request(message, signal, None, None, "sensor.gps_location")
|
handle_location_request(message, signal, None, None)
|
||||||
signal.reply.assert_called_once()
|
signal.reply.assert_called_once()
|
||||||
assert "not configured" in signal.reply.call_args[0][1]
|
assert "not configured" in signal.reply.call_args[0][1]
|
||||||
|
|
||||||
@@ -266,11 +243,12 @@ class TestDispatch:
|
|||||||
mock = MagicMock(spec=DeviceRegistry)
|
mock = MagicMock(spec=DeviceRegistry)
|
||||||
mock.is_verified.return_value = True
|
mock.is_verified.return_value = True
|
||||||
mock.has_safety_number.return_value = True
|
mock.has_safety_number.return_value = True
|
||||||
|
mock.has_role.return_value = False
|
||||||
|
mock.get_roles.return_value = []
|
||||||
return mock
|
return mock
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def config(self):
|
def config(self, engine):
|
||||||
engine = create_engine("sqlite://")
|
|
||||||
return BotConfig(
|
return BotConfig(
|
||||||
signal_api_url="http://localhost:8080",
|
signal_api_url="http://localhost:8080",
|
||||||
phone_number="+1234567890",
|
phone_number="+1234567890",
|
||||||
@@ -278,46 +256,66 @@ class TestDispatch:
|
|||||||
engine=engine,
|
engine=engine,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_unverified_device_ignored(self, signal_mock, llm_mock, registry_mock, config):
|
@pytest.fixture
|
||||||
|
def bot(self, signal_mock, llm_mock, registry_mock, config):
|
||||||
|
return Bot(signal_mock, llm_mock, registry_mock, config)
|
||||||
|
|
||||||
|
def test_unverified_device_ignored(self, bot, signal_mock, registry_mock):
|
||||||
registry_mock.is_verified.return_value = False
|
registry_mock.is_verified.return_value = False
|
||||||
msg = SignalMessage(source="+1234", timestamp=0, message="help")
|
msg = SignalMessage(source="+1234", timestamp=0, message="help")
|
||||||
dispatch(msg, signal_mock, llm_mock, registry_mock, config)
|
bot.dispatch(msg)
|
||||||
signal_mock.reply.assert_not_called()
|
signal_mock.reply.assert_not_called()
|
||||||
|
|
||||||
def test_help_command(self, signal_mock, llm_mock, registry_mock, config):
|
def test_admin_without_safety_number_ignored(self, bot, signal_mock, registry_mock):
|
||||||
|
registry_mock.has_safety_number.return_value = False
|
||||||
|
registry_mock.has_role.return_value = True
|
||||||
msg = SignalMessage(source="+1234", timestamp=0, message="help")
|
msg = SignalMessage(source="+1234", timestamp=0, message="help")
|
||||||
dispatch(msg, signal_mock, llm_mock, registry_mock, config)
|
bot.dispatch(msg)
|
||||||
|
signal_mock.reply.assert_not_called()
|
||||||
|
|
||||||
|
def test_non_admin_without_safety_number_allowed(self, bot, signal_mock, registry_mock):
|
||||||
|
registry_mock.has_safety_number.return_value = False
|
||||||
|
registry_mock.has_role.return_value = False
|
||||||
|
registry_mock.get_roles.return_value = []
|
||||||
|
msg = SignalMessage(source="+1234", timestamp=0, message="help")
|
||||||
|
bot.dispatch(msg)
|
||||||
|
signal_mock.reply.assert_called_once()
|
||||||
|
|
||||||
|
def test_help_command(self, bot, signal_mock, registry_mock):
|
||||||
|
msg = SignalMessage(source="+1234", timestamp=0, message="help")
|
||||||
|
bot.dispatch(msg)
|
||||||
signal_mock.reply.assert_called_once()
|
signal_mock.reply.assert_called_once()
|
||||||
assert "Available commands" in signal_mock.reply.call_args[0][1]
|
assert "Available commands" in signal_mock.reply.call_args[0][1]
|
||||||
|
|
||||||
def test_unknown_command_ignored(self, signal_mock, llm_mock, registry_mock, config):
|
def test_unknown_command_ignored(self, bot, signal_mock):
|
||||||
msg = SignalMessage(source="+1234", timestamp=0, message="foobar")
|
msg = SignalMessage(source="+1234", timestamp=0, message="foobar")
|
||||||
dispatch(msg, signal_mock, llm_mock, registry_mock, config)
|
bot.dispatch(msg)
|
||||||
signal_mock.reply.assert_not_called()
|
signal_mock.reply.assert_not_called()
|
||||||
|
|
||||||
def test_non_command_message_ignored(self, signal_mock, llm_mock, registry_mock, config):
|
def test_non_command_message_ignored(self, bot, signal_mock):
|
||||||
msg = SignalMessage(source="+1234", timestamp=0, message="hello there")
|
msg = SignalMessage(source="+1234", timestamp=0, message="hello there")
|
||||||
dispatch(msg, signal_mock, llm_mock, registry_mock, config)
|
bot.dispatch(msg)
|
||||||
signal_mock.reply.assert_not_called()
|
signal_mock.reply.assert_not_called()
|
||||||
|
|
||||||
def test_status_command(self, signal_mock, llm_mock, registry_mock, config):
|
def test_status_command(self, bot, signal_mock, llm_mock, registry_mock):
|
||||||
llm_mock.list_models.return_value = ["model1", "model2"]
|
llm_mock.list_models.return_value = ["model1", "model2"]
|
||||||
llm_mock.model = "test:7b"
|
llm_mock.model = "test:7b"
|
||||||
registry_mock.list_devices.return_value = []
|
registry_mock.list_devices.return_value = []
|
||||||
|
registry_mock.has_role.return_value = True
|
||||||
msg = SignalMessage(source="+1234", timestamp=0, message="status")
|
msg = SignalMessage(source="+1234", timestamp=0, message="status")
|
||||||
dispatch(msg, signal_mock, llm_mock, registry_mock, config)
|
bot.dispatch(msg)
|
||||||
signal_mock.reply.assert_called_once()
|
signal_mock.reply.assert_called_once()
|
||||||
assert "Bot online" in signal_mock.reply.call_args[0][1]
|
assert "Bot online" in signal_mock.reply.call_args[0][1]
|
||||||
|
|
||||||
def test_location_command(self, signal_mock, llm_mock, registry_mock, config):
|
def test_location_command(self, bot, signal_mock, registry_mock, config):
|
||||||
|
registry_mock.has_role.return_value = True
|
||||||
msg = SignalMessage(source="+1234", timestamp=0, message="location")
|
msg = SignalMessage(source="+1234", timestamp=0, message="location")
|
||||||
with patch("python.signal_bot.main.handle_location_request") as mock_location:
|
with patch("python.signal_bot.main.handle_location_request") as mock_location:
|
||||||
dispatch(msg, signal_mock, llm_mock, registry_mock, config)
|
bot.dispatch(msg)
|
||||||
|
|
||||||
mock_location.assert_called_once_with(
|
mock_location.assert_called_once_with(
|
||||||
msg,
|
msg,
|
||||||
signal_mock,
|
signal_mock,
|
||||||
config.ha_url,
|
config.ha_url,
|
||||||
config.ha_token,
|
config.ha_token,
|
||||||
config.ha_location_entity,
|
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user