Compare commits

..

12 Commits

73 changed files with 1758 additions and 7021 deletions

View File

@@ -24,6 +24,7 @@
fastapi fastapi
fastapi-cli fastapi-cli
httpx httpx
python-multipart
mypy mypy
polars polars
psycopg psycopg

View File

@@ -7,7 +7,25 @@ requires-python = "~=3.13.0"
readme = "README.md" readme = "README.md"
license = "MIT" license = "MIT"
# these dependencies are a best effort and aren't guaranteed to work # these dependencies are a best effort and aren't guaranteed to work
dependencies = ["apprise", "apscheduler", "httpx", "polars", "pydantic", "pyyaml", "requests", "typer"] # for up-to-date dependencies, see overlays/default.nix
dependencies = [
"alembic",
"apprise",
"apscheduler",
"httpx",
"python-multipart",
"polars",
"psycopg[binary]",
"pydantic",
"pyyaml",
"requests",
"sqlalchemy",
"typer",
]
[project.scripts]
database = "python.database_cli:app"
van-inventory = "python.van_inventory.main:serve"
[dependency-groups] [dependency-groups]
dev = [ dev = [

View File

@@ -1,109 +0,0 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts
# Use forward slashes (/) also on windows to provide an os agnostic path
script_location = python/alembic
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
file_template = %%(year)d_%%(month).2d_%%(day).2d-%%(slug)s_%%(rev)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to alembic/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
# version_path_separator = newline
#
# Use os.pathsep. Default configuration used for new projects.
version_path_separator = os
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
sqlalchemy.url = driver://user:pass@localhost/dbname
revision_environment = true
[post_write_hooks]
hooks = dynamic_schema,ruff
dynamic_schema.type = dynamic_schema
ruff.type = ruff
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARNING
handlers = console
qualname =
[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

View File

@@ -9,20 +9,24 @@ from typing import TYPE_CHECKING, Any, Literal
from alembic import context from alembic import context
from alembic.script import write_hooks from alembic.script import write_hooks
from sqlalchemy.schema import CreateSchema
from python.common import bash_wrapper from python.common import bash_wrapper
from python.orm import RichieBase from python.orm.common import get_postgres_engine
from python.orm.base import get_postgres_engine
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import MutableMapping from collections.abc import MutableMapping
# this is the Alembic Config object, which provides from sqlalchemy.orm import DeclarativeBase
# access to the values within the .ini file in use.
config = context.config config = context.config
base_class: type[DeclarativeBase] = config.attributes.get("base")
if base_class is None:
error = "No base class provided. Use the database CLI to run alembic commands."
raise RuntimeError(error)
target_metadata = RichieBase.metadata target_metadata = base_class.metadata
logging.basicConfig( logging.basicConfig(
level="DEBUG", level="DEBUG",
datefmt="%Y-%m-%dT%H:%M:%S%z", datefmt="%Y-%m-%dT%H:%M:%S%z",
@@ -35,8 +39,9 @@ logging.basicConfig(
def dynamic_schema(filename: str, _options: dict[Any, Any]) -> None: def dynamic_schema(filename: str, _options: dict[Any, Any]) -> None:
"""Dynamic schema.""" """Dynamic schema."""
original_file = Path(filename).read_text() original_file = Path(filename).read_text()
dynamic_schema_file_part1 = original_file.replace(f"schema='{RichieBase.schema_name}'", "schema=schema") schema_name = base_class.schema_name
dynamic_schema_file = dynamic_schema_file_part1.replace(f"'{RichieBase.schema_name}.", "f'{schema}.") dynamic_schema_file_part1 = original_file.replace(f"schema='{schema_name}'", "schema=schema")
dynamic_schema_file = dynamic_schema_file_part1.replace(f"'{schema_name}.", "f'{schema}.")
Path(filename).write_text(dynamic_schema_file) Path(filename).write_text(dynamic_schema_file)
@@ -52,12 +57,12 @@ def include_name(
type_: Literal["schema", "table", "column", "index", "unique_constraint", "foreign_key_constraint"], type_: Literal["schema", "table", "column", "index", "unique_constraint", "foreign_key_constraint"],
_parent_names: MutableMapping[Literal["schema_name", "table_name", "schema_qualified_table_name"], str | None], _parent_names: MutableMapping[Literal["schema_name", "table_name", "schema_qualified_table_name"], str | None],
) -> bool: ) -> bool:
"""This filter table to be included in the migration. """Filter tables to be included in the migration.
Args: Args:
name (str): The name of the table. name (str): The name of the table.
type_ (str): The type of the table. type_ (str): The type of the table.
parent_names (list[str]): The names of the parent tables. _parent_names (MutableMapping): The names of the parent tables.
Returns: Returns:
bool: True if the table should be included, False otherwise. bool: True if the table should be included, False otherwise.
@@ -75,19 +80,30 @@ def run_migrations_online() -> None:
and associate a connection with the context. and associate a connection with the context.
""" """
connectable = get_postgres_engine() env_prefix = config.attributes.get("env_prefix", "POSTGRES")
connectable = get_postgres_engine(name=env_prefix)
with connectable.connect() as connection: with connectable.connect() as connection:
schema = base_class.schema_name
if not connectable.dialect.has_schema(connection, schema):
answer = input(f"Schema {schema!r} does not exist. Create it? [y/N] ")
if answer.lower() != "y":
error = f"Schema {schema!r} does not exist. Exiting."
raise SystemExit(error)
connection.execute(CreateSchema(schema))
connection.commit()
context.configure( context.configure(
connection=connection, connection=connection,
target_metadata=target_metadata, target_metadata=target_metadata,
include_schemas=True, include_schemas=True,
version_table_schema=RichieBase.schema_name, version_table_schema=schema,
include_name=include_name, include_name=include_name,
) )
with context.begin_transaction(): with context.begin_transaction():
context.run_migrations() context.run_migrations()
connection.commit()
run_migrations_online() run_migrations_online()

View File

@@ -0,0 +1,135 @@
"""add congress tracker tables.
Revision ID: 3f71565e38de
Revises: edd7dd61a3d2
Create Date: 2026-02-12 16:36:09.457303
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import sqlalchemy as sa
from alembic import op
from python.orm import RichieBase
if TYPE_CHECKING:
from collections.abc import Sequence
# revision identifiers, used by Alembic.
revision: str = "3f71565e38de"
down_revision: str | None = "edd7dd61a3d2"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
schema = RichieBase.schema_name
def upgrade() -> None:
"""Upgrade."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"bill",
sa.Column("congress", sa.Integer(), nullable=False),
sa.Column("bill_type", sa.String(), nullable=False),
sa.Column("number", sa.Integer(), nullable=False),
sa.Column("title", sa.String(), nullable=True),
sa.Column("title_short", sa.String(), nullable=True),
sa.Column("official_title", sa.String(), nullable=True),
sa.Column("status", sa.String(), nullable=True),
sa.Column("status_at", sa.Date(), nullable=True),
sa.Column("sponsor_bioguide_id", sa.String(), nullable=True),
sa.Column("subjects_top_term", sa.String(), nullable=True),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.PrimaryKeyConstraint("id", name=op.f("pk_bill")),
sa.UniqueConstraint("congress", "bill_type", "number", name="uq_bill_congress_type_number"),
schema=schema,
)
op.create_index("ix_bill_congress", "bill", ["congress"], unique=False, schema=schema)
op.create_table(
"legislator",
sa.Column("bioguide_id", sa.Text(), nullable=False),
sa.Column("thomas_id", sa.String(), nullable=True),
sa.Column("lis_id", sa.String(), nullable=True),
sa.Column("govtrack_id", sa.Integer(), nullable=True),
sa.Column("opensecrets_id", sa.String(), nullable=True),
sa.Column("fec_ids", sa.String(), nullable=True),
sa.Column("first_name", sa.String(), nullable=False),
sa.Column("last_name", sa.String(), nullable=False),
sa.Column("official_full_name", sa.String(), nullable=True),
sa.Column("nickname", sa.String(), nullable=True),
sa.Column("birthday", sa.Date(), nullable=True),
sa.Column("gender", sa.String(), nullable=True),
sa.Column("current_party", sa.String(), nullable=True),
sa.Column("current_state", sa.String(), nullable=True),
sa.Column("current_district", sa.Integer(), nullable=True),
sa.Column("current_chamber", sa.String(), nullable=True),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.PrimaryKeyConstraint("id", name=op.f("pk_legislator")),
schema=schema,
)
op.create_index(op.f("ix_legislator_bioguide_id"), "legislator", ["bioguide_id"], unique=True, schema=schema)
op.create_table(
"vote",
sa.Column("congress", sa.Integer(), nullable=False),
sa.Column("chamber", sa.String(), nullable=False),
sa.Column("session", sa.Integer(), nullable=False),
sa.Column("number", sa.Integer(), nullable=False),
sa.Column("vote_type", sa.String(), nullable=True),
sa.Column("question", sa.String(), nullable=True),
sa.Column("result", sa.String(), nullable=True),
sa.Column("result_text", sa.String(), nullable=True),
sa.Column("vote_date", sa.Date(), nullable=False),
sa.Column("yea_count", sa.Integer(), nullable=True),
sa.Column("nay_count", sa.Integer(), nullable=True),
sa.Column("not_voting_count", sa.Integer(), nullable=True),
sa.Column("present_count", sa.Integer(), nullable=True),
sa.Column("bill_id", sa.Integer(), nullable=True),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.ForeignKeyConstraint(["bill_id"], [f"{schema}.bill.id"], name=op.f("fk_vote_bill_id_bill")),
sa.PrimaryKeyConstraint("id", name=op.f("pk_vote")),
sa.UniqueConstraint("congress", "chamber", "session", "number", name="uq_vote_congress_chamber_session_number"),
schema=schema,
)
op.create_index("ix_vote_congress_chamber", "vote", ["congress", "chamber"], unique=False, schema=schema)
op.create_index("ix_vote_date", "vote", ["vote_date"], unique=False, schema=schema)
op.create_table(
"vote_record",
sa.Column("vote_id", sa.Integer(), nullable=False),
sa.Column("legislator_id", sa.Integer(), nullable=False),
sa.Column("position", sa.String(), nullable=False),
sa.ForeignKeyConstraint(
["legislator_id"],
[f"{schema}.legislator.id"],
name=op.f("fk_vote_record_legislator_id_legislator"),
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["vote_id"], [f"{schema}.vote.id"], name=op.f("fk_vote_record_vote_id_vote"), ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("vote_id", "legislator_id", name=op.f("pk_vote_record")),
schema=schema,
)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("vote_record", schema=schema)
op.drop_index("ix_vote_date", table_name="vote", schema=schema)
op.drop_index("ix_vote_congress_chamber", table_name="vote", schema=schema)
op.drop_table("vote", schema=schema)
op.drop_index(op.f("ix_legislator_bioguide_id"), table_name="legislator", schema=schema)
op.drop_table("legislator", schema=schema)
op.drop_index("ix_bill_congress", table_name="bill", schema=schema)
op.drop_table("bill", schema=schema)
# ### end Alembic commands ###

View File

@@ -13,7 +13,7 @@ from typing import TYPE_CHECKING
import sqlalchemy as sa import sqlalchemy as sa
from alembic import op from alembic import op
from python.orm import RichieBase from python.orm import ${config.attributes["base"].__name__}
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Sequence from collections.abc import Sequence
@@ -24,7 +24,7 @@ down_revision: str | None = ${repr(down_revision)}
branch_labels: str | Sequence[str] | None = ${repr(branch_labels)} branch_labels: str | Sequence[str] | None = ${repr(branch_labels)}
depends_on: str | Sequence[str] | None = ${repr(depends_on)} depends_on: str | Sequence[str] | None = ${repr(depends_on)}
schema=RichieBase.schema_name schema=${config.attributes["base"].__name__}.schema_name
def upgrade() -> None: def upgrade() -> None:
"""Upgrade.""" """Upgrade."""

View File

@@ -0,0 +1,80 @@
"""starting van invintory.
Revision ID: 15e733499804
Revises:
Create Date: 2026-03-08 00:18:20.759720
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import sqlalchemy as sa
from alembic import op
from python.orm import VanInventoryBase
if TYPE_CHECKING:
from collections.abc import Sequence
# revision identifiers, used by Alembic.
revision: str = "15e733499804"
down_revision: str | None = None
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
schema = VanInventoryBase.schema_name
def upgrade() -> None:
"""Upgrade."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"items",
sa.Column("name", sa.String(), nullable=False),
sa.Column("quantity", sa.Float(), nullable=False),
sa.Column("unit", sa.String(), nullable=False),
sa.Column("category", sa.String(), nullable=True),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.PrimaryKeyConstraint("id", name=op.f("pk_items")),
sa.UniqueConstraint("name", name=op.f("uq_items_name")),
schema=schema,
)
op.create_table(
"meals",
sa.Column("name", sa.String(), nullable=False),
sa.Column("instructions", sa.String(), nullable=True),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.PrimaryKeyConstraint("id", name=op.f("pk_meals")),
sa.UniqueConstraint("name", name=op.f("uq_meals_name")),
schema=schema,
)
op.create_table(
"meal_ingredients",
sa.Column("meal_id", sa.Integer(), nullable=False),
sa.Column("item_id", sa.Integer(), nullable=False),
sa.Column("quantity_needed", sa.Float(), nullable=False),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.ForeignKeyConstraint(["item_id"], [f"{schema}.items.id"], name=op.f("fk_meal_ingredients_item_id_items")),
sa.ForeignKeyConstraint(["meal_id"], [f"{schema}.meals.id"], name=op.f("fk_meal_ingredients_meal_id_meals")),
sa.PrimaryKeyConstraint("id", name=op.f("pk_meal_ingredients")),
sa.UniqueConstraint("meal_id", "item_id", name=op.f("uq_meal_ingredients_meal_id")),
schema=schema,
)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("meal_ingredients", schema=schema)
op.drop_table("meals", schema=schema)
op.drop_table("items", schema=schema)
# ### end Alembic commands ###

View File

@@ -16,7 +16,7 @@ from fastapi import FastAPI
from python.api.routers import contact_router, create_frontend_router from python.api.routers import contact_router, create_frontend_router
from python.common import configure_logger from python.common import configure_logger
from python.orm.base import get_postgres_engine from python.orm.common import get_postgres_engine
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -6,7 +6,7 @@ from sqlalchemy import select
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
from python.api.dependencies import DbSession from python.api.dependencies import DbSession
from python.orm.contact import Contact, ContactRelationship, Need, RelationshipType from python.orm.richie.contact import Contact, ContactRelationship, Need, RelationshipType
class NeedBase(BaseModel): class NeedBase(BaseModel):

114
python/database_cli.py Normal file
View File

@@ -0,0 +1,114 @@
"""CLI wrapper around alembic for multi-database support.
Usage:
database <db_name> <command> [args...]
Examples:
database van_inventory upgrade head
database van_inventory downgrade head-1
database van_inventory revision --autogenerate -m "add meals table"
database van_inventory check
database richie check
database richie upgrade head
"""
from __future__ import annotations
from dataclasses import dataclass
from importlib import import_module
from typing import TYPE_CHECKING, Annotated
import typer
from alembic.config import CommandLine, Config
if TYPE_CHECKING:
from sqlalchemy.orm import DeclarativeBase
@dataclass(frozen=True)
class DatabaseConfig:
"""Configuration for a database."""
env_prefix: str
version_location: str
base_module: str
base_class_name: str
models_module: str
script_location: str = "python/alembic"
file_template: str = "%%(year)d_%%(month).2d_%%(day).2d-%%(slug)s_%%(rev)s"
def get_base(self) -> type[DeclarativeBase]:
"""Import and return the Base class."""
module = import_module(self.base_module)
return getattr(module, self.base_class_name)
def import_models(self) -> None:
"""Import ORM models so alembic autogenerate can detect them."""
import_module(self.models_module)
def alembic_config(self) -> Config:
"""Build an alembic Config for this database."""
# Runtime import needed — Config is in TYPE_CHECKING for the return type annotation
from alembic.config import Config as AlembicConfig # noqa: PLC0415
cfg = AlembicConfig()
cfg.set_main_option("script_location", self.script_location)
cfg.set_main_option("file_template", self.file_template)
cfg.set_main_option("prepend_sys_path", ".")
cfg.set_main_option("version_path_separator", "os")
cfg.set_main_option("version_locations", self.version_location)
cfg.set_main_option("revision_environment", "true")
cfg.set_section_option("post_write_hooks", "hooks", "dynamic_schema,ruff")
cfg.set_section_option("post_write_hooks", "dynamic_schema.type", "dynamic_schema")
cfg.set_section_option("post_write_hooks", "ruff.type", "ruff")
cfg.attributes["base"] = self.get_base()
cfg.attributes["env_prefix"] = self.env_prefix
self.import_models()
return cfg
DATABASES: dict[str, DatabaseConfig] = {
"richie": DatabaseConfig(
env_prefix="RICHIE",
version_location="python/alembic/richie/versions",
base_module="python.orm.richie.base",
base_class_name="RichieBase",
models_module="python.orm.richie.contact",
),
"van_inventory": DatabaseConfig(
env_prefix="VAN_INVENTORY",
version_location="python/alembic/van_inventory/versions",
base_module="python.orm.van_inventory.base",
base_class_name="VanInventoryBase",
models_module="python.orm.van_inventory.models",
),
}
app = typer.Typer(help="Multi-database alembic wrapper.")
@app.command(
context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)
def main(
ctx: typer.Context,
db_name: Annotated[str, typer.Argument(help=f"Database name. Options: {', '.join(DATABASES)}")],
command: Annotated[str, typer.Argument(help="Alembic command (upgrade, downgrade, revision, check, etc.)")],
) -> None:
"""Run an alembic command against the specified database."""
db_config = DATABASES.get(db_name)
if not db_config:
typer.echo(f"Unknown database: {db_name!r}. Available: {', '.join(DATABASES)}", err=True)
raise typer.Exit(code=1)
alembic_cfg = db_config.alembic_config()
cmd_line = CommandLine()
options = cmd_line.parser.parse_args([command, *ctx.args])
cmd_line.run_cmd(alembic_cfg, options)
if __name__ == "__main__":
app()

View File

@@ -1,22 +1,9 @@
"""ORM package exports.""" """ORM package exports."""
from __future__ import annotations from python.orm.richie.base import RichieBase
from python.orm.van_inventory.base import VanInventoryBase
from python.orm.base import RichieBase, TableBase
from python.orm.contact import (
Contact,
ContactNeed,
ContactRelationship,
Need,
RelationshipType,
)
__all__ = [ __all__ = [
"Contact",
"ContactNeed",
"ContactRelationship",
"Need",
"RelationshipType",
"RichieBase", "RichieBase",
"TableBase", "VanInventoryBase",
] ]

View File

@@ -1,80 +0,0 @@
"""Base ORM definitions."""
from __future__ import annotations
from datetime import datetime
from os import getenv
from typing import cast
from sqlalchemy import DateTime, MetaData, create_engine, func
from sqlalchemy.engine import URL, Engine
from sqlalchemy.ext.declarative import AbstractConcreteBase
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
class RichieBase(DeclarativeBase):
"""Base class for all ORM models."""
schema_name = "main"
metadata = MetaData(
schema=schema_name,
naming_convention={
"ix": "ix_%(table_name)s_%(column_0_name)s",
"uq": "uq_%(table_name)s_%(column_0_name)s",
"ck": "ck_%(table_name)s_%(constraint_name)s",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"pk": "pk_%(table_name)s",
},
)
class TableBase(AbstractConcreteBase, RichieBase):
"""Abstract concrete base for tables with IDs and timestamps."""
__abstract__ = True
id: Mapped[int] = mapped_column(primary_key=True)
created: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
)
updated: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
)
def get_connection_info() -> tuple[str, str, str, str, str | None]:
"""Get connection info from environment variables."""
database = getenv("POSTGRES_DB")
host = getenv("POSTGRES_HOST")
port = getenv("POSTGRES_PORT")
username = getenv("POSTGRES_USER")
password = getenv("POSTGRES_PASSWORD")
if None in (database, host, port, username):
error = f"Missing environment variables for Postgres connection.\n{database=}\n{host=}\n{port=}\n{username=}\n"
raise ValueError(error)
return cast("tuple[str, str, str, str, str | None]", (database, host, port, username, password))
def get_postgres_engine(*, pool_pre_ping: bool = True) -> Engine:
"""Create a SQLAlchemy engine from environment variables."""
database, host, port, username, password = get_connection_info()
url = URL.create(
drivername="postgresql+psycopg",
username=username,
password=password,
host=host,
port=int(port),
database=database,
)
return create_engine(
url=url,
pool_pre_ping=pool_pre_ping,
pool_recycle=1800,
)

51
python/orm/common.py Normal file
View File

@@ -0,0 +1,51 @@
"""Shared ORM definitions."""
from __future__ import annotations
from os import getenv
from typing import cast
from sqlalchemy import create_engine
from sqlalchemy.engine import URL, Engine
NAMING_CONVENTION = {
"ix": "ix_%(table_name)s_%(column_0_name)s",
"uq": "uq_%(table_name)s_%(column_0_name)s",
"ck": "ck_%(table_name)s_%(constraint_name)s",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"pk": "pk_%(table_name)s",
}
def get_connection_info(name: str) -> tuple[str, str, str, str, str | None]:
"""Get connection info from environment variables."""
database = getenv(f"{name}_DB")
host = getenv(f"{name}_HOST")
port = getenv(f"{name}_PORT")
username = getenv(f"{name}_USER")
password = getenv(f"{name}_PASSWORD")
if None in (database, host, port, username):
error = f"Missing environment variables for Postgres connection.\n{database=}\n{host=}\n{port=}\n{username=}\n"
raise ValueError(error)
return cast("tuple[str, str, str, str, str | None]", (database, host, port, username, password))
def get_postgres_engine(*, name: str = "POSTGRES", pool_pre_ping: bool = True) -> Engine:
"""Create a SQLAlchemy engine from environment variables."""
database, host, port, username, password = get_connection_info(name)
url = URL.create(
drivername="postgresql+psycopg",
username=username,
password=password,
host=host,
port=int(port),
database=database,
)
return create_engine(
url=url,
pool_pre_ping=pool_pre_ping,
pool_recycle=1800,
)

View File

@@ -0,0 +1,27 @@
"""Richie database ORM exports."""
from __future__ import annotations
from python.orm.richie.base import RichieBase, TableBase
from python.orm.richie.congress import Bill, Legislator, Vote, VoteRecord
from python.orm.richie.contact import (
Contact,
ContactNeed,
ContactRelationship,
Need,
RelationshipType,
)
__all__ = [
"Bill",
"Contact",
"ContactNeed",
"ContactRelationship",
"Legislator",
"Need",
"RelationshipType",
"RichieBase",
"TableBase",
"Vote",
"VoteRecord",
]

39
python/orm/richie/base.py Normal file
View File

@@ -0,0 +1,39 @@
"""Richie database ORM base."""
from __future__ import annotations
from datetime import datetime
from sqlalchemy import DateTime, MetaData, func
from sqlalchemy.ext.declarative import AbstractConcreteBase
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from python.orm.common import NAMING_CONVENTION
class RichieBase(DeclarativeBase):
"""Base class for richie database ORM models."""
schema_name = "main"
metadata = MetaData(
schema=schema_name,
naming_convention=NAMING_CONVENTION,
)
class TableBase(AbstractConcreteBase, RichieBase):
"""Abstract concrete base for richie tables with IDs and timestamps."""
__abstract__ = True
id: Mapped[int] = mapped_column(primary_key=True)
created: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
)
updated: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
)

View File

@@ -0,0 +1,150 @@
"""Congress Tracker database models."""
from __future__ import annotations
from datetime import date
from sqlalchemy import ForeignKey, Index, Text, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from python.orm.richie.base import RichieBase, TableBase
class Legislator(TableBase):
"""Legislator model - members of Congress."""
__tablename__ = "legislator"
# Natural key - bioguide ID is the authoritative identifier
bioguide_id: Mapped[str] = mapped_column(Text, unique=True, index=True)
# Other IDs for cross-referencing
thomas_id: Mapped[str | None]
lis_id: Mapped[str | None]
govtrack_id: Mapped[int | None]
opensecrets_id: Mapped[str | None]
fec_ids: Mapped[str | None] # JSON array stored as string
# Name info
first_name: Mapped[str]
last_name: Mapped[str]
official_full_name: Mapped[str | None]
nickname: Mapped[str | None]
# Bio
birthday: Mapped[date | None]
gender: Mapped[str | None] # M/F
# Current term info (denormalized for query efficiency)
current_party: Mapped[str | None]
current_state: Mapped[str | None]
current_district: Mapped[int | None] # House only
current_chamber: Mapped[str | None] # rep/sen
# Relationships
vote_records: Mapped[list[VoteRecord]] = relationship(
"VoteRecord",
back_populates="legislator",
cascade="all, delete-orphan",
)
class Bill(TableBase):
"""Bill model - legislation introduced in Congress."""
__tablename__ = "bill"
# Composite natural key: congress + bill_type + number
congress: Mapped[int]
bill_type: Mapped[str] # hr, s, hres, sres, hjres, sjres
number: Mapped[int]
# Bill info
title: Mapped[str | None]
title_short: Mapped[str | None]
official_title: Mapped[str | None]
# Status
status: Mapped[str | None]
status_at: Mapped[date | None]
# Sponsor
sponsor_bioguide_id: Mapped[str | None]
# Subjects
subjects_top_term: Mapped[str | None]
# Relationships
votes: Mapped[list[Vote]] = relationship(
"Vote",
back_populates="bill",
)
__table_args__ = (
UniqueConstraint("congress", "bill_type", "number", name="uq_bill_congress_type_number"),
Index("ix_bill_congress", "congress"),
)
class Vote(TableBase):
"""Vote model - roll call votes in Congress."""
__tablename__ = "vote"
# Composite natural key: congress + chamber + session + number
congress: Mapped[int]
chamber: Mapped[str] # house/senate
session: Mapped[int]
number: Mapped[int]
# Vote details
vote_type: Mapped[str | None]
question: Mapped[str | None]
result: Mapped[str | None]
result_text: Mapped[str | None]
# Timing
vote_date: Mapped[date]
# Vote counts (denormalized for efficiency)
yea_count: Mapped[int | None]
nay_count: Mapped[int | None]
not_voting_count: Mapped[int | None]
present_count: Mapped[int | None]
# Related bill (optional - not all votes are on bills)
bill_id: Mapped[int | None] = mapped_column(ForeignKey("main.bill.id"))
# Relationships
bill: Mapped[Bill | None] = relationship("Bill", back_populates="votes")
vote_records: Mapped[list[VoteRecord]] = relationship(
"VoteRecord",
back_populates="vote",
cascade="all, delete-orphan",
)
__table_args__ = (
UniqueConstraint("congress", "chamber", "session", "number", name="uq_vote_congress_chamber_session_number"),
Index("ix_vote_date", "vote_date"),
Index("ix_vote_congress_chamber", "congress", "chamber"),
)
class VoteRecord(RichieBase):
"""Association table: Vote <-> Legislator with position."""
__tablename__ = "vote_record"
vote_id: Mapped[int] = mapped_column(
ForeignKey("main.vote.id", ondelete="CASCADE"),
primary_key=True,
)
legislator_id: Mapped[int] = mapped_column(
ForeignKey("main.legislator.id", ondelete="CASCADE"),
primary_key=True,
)
position: Mapped[str] # Yea, Nay, Not Voting, Present
# Relationships
vote: Mapped[Vote] = relationship("Vote", back_populates="vote_records")
legislator: Mapped[Legislator] = relationship("Legislator", back_populates="vote_records")

View File

@@ -7,7 +7,7 @@ from enum import Enum
from sqlalchemy import ForeignKey, String from sqlalchemy import ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
from python.orm.base import RichieBase, TableBase from python.orm.richie.base import RichieBase, TableBase
class RelationshipType(str, Enum): class RelationshipType(str, Enum):

View File

@@ -0,0 +1 @@
"""Van inventory database ORM exports."""

View File

@@ -0,0 +1,39 @@
"""Van inventory database ORM base."""
from __future__ import annotations
from datetime import datetime
from sqlalchemy import DateTime, MetaData, func
from sqlalchemy.ext.declarative import AbstractConcreteBase
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from python.orm.common import NAMING_CONVENTION
class VanInventoryBase(DeclarativeBase):
"""Base class for van_inventory database ORM models."""
schema_name = "main"
metadata = MetaData(
schema=schema_name,
naming_convention=NAMING_CONVENTION,
)
class VanTableBase(AbstractConcreteBase, VanInventoryBase):
"""Abstract concrete base for van_inventory tables with IDs and timestamps."""
__abstract__ = True
id: Mapped[int] = mapped_column(primary_key=True)
created: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
)
updated: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
)

View File

@@ -0,0 +1,46 @@
"""Van inventory ORM models."""
from __future__ import annotations
from sqlalchemy import ForeignKey, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from python.orm.van_inventory.base import VanTableBase
class Item(VanTableBase):
"""A food item in the van."""
__tablename__ = "items"
name: Mapped[str] = mapped_column(unique=True)
quantity: Mapped[float] = mapped_column(default=0)
unit: Mapped[str]
category: Mapped[str | None]
meal_ingredients: Mapped[list[MealIngredient]] = relationship(back_populates="item")
class Meal(VanTableBase):
"""A meal that can be made from items in the van."""
__tablename__ = "meals"
name: Mapped[str] = mapped_column(unique=True)
instructions: Mapped[str | None]
ingredients: Mapped[list[MealIngredient]] = relationship(back_populates="meal")
class MealIngredient(VanTableBase):
"""Links a meal to the items it requires, with quantities."""
__tablename__ = "meal_ingredients"
__table_args__ = (UniqueConstraint("meal_id", "item_id"),)
meal_id: Mapped[int] = mapped_column(ForeignKey("meals.id"))
item_id: Mapped[int] = mapped_column(ForeignKey("items.id"))
quantity_needed: Mapped[float]
meal: Mapped[Meal] = relationship(back_populates="ingredients")
item: Mapped[Item] = relationship(back_populates="meal_ingredients")

View File

@@ -0,0 +1 @@
"""Van inventory FastAPI application."""

View File

@@ -0,0 +1,16 @@
"""FastAPI dependencies for van inventory."""
from collections.abc import Iterator
from typing import Annotated
from fastapi import Depends, Request
from sqlalchemy.orm import Session
def get_db(request: Request) -> Iterator[Session]:
"""Get database session from app state."""
with Session(request.app.state.engine) as session:
yield session
DbSession = Annotated[Session, Depends(get_db)]

View File

@@ -0,0 +1,56 @@
"""FastAPI app for van inventory."""
from __future__ import annotations
import logging
from contextlib import asynccontextmanager
from pathlib import Path
from typing import TYPE_CHECKING, Annotated
import typer
import uvicorn
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from python.common import configure_logger
from python.orm.common import get_postgres_engine
from python.van_inventory.routers import api_router, frontend_router
STATIC_DIR = Path(__file__).resolve().parent / "static"
if TYPE_CHECKING:
from collections.abc import AsyncIterator
logger = logging.getLogger(__name__)
def create_app() -> FastAPI:
"""Create and configure the FastAPI application."""
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
app.state.engine = get_postgres_engine(name="VAN_INVENTORY")
yield
app.state.engine.dispose()
app = FastAPI(title="Van Inventory", lifespan=lifespan)
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
app.include_router(api_router)
app.include_router(frontend_router)
return app
def serve(
# Intentionally binds all interfaces — this is a LAN-only van server
host: Annotated[str, typer.Option("--host", "-h", help="Host to bind to")] = "0.0.0.0", # noqa: S104
port: Annotated[int, typer.Option("--port", "-p", help="Port to bind to")] = 8001,
log_level: Annotated[str, typer.Option("--log-level", "-l", help="Log level")] = "INFO",
) -> None:
"""Start the Van Inventory server."""
configure_logger(log_level)
app = create_app()
uvicorn.run(app, host=host, port=port)
if __name__ == "__main__":
typer.run(serve)

View File

@@ -0,0 +1,6 @@
"""Van inventory API routers."""
from python.van_inventory.routers.api import router as api_router
from python.van_inventory.routers.frontend import router as frontend_router
__all__ = ["api_router", "frontend_router"]

View File

@@ -0,0 +1,314 @@
"""Van inventory API router."""
from __future__ import annotations
from typing import TYPE_CHECKING
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from python.orm.van_inventory.models import Item, Meal, MealIngredient
if TYPE_CHECKING:
from python.van_inventory.dependencies import DbSession
# --- Schemas ---
class ItemCreate(BaseModel):
"""Schema for creating an item."""
name: str
quantity: float = Field(default=0, ge=0)
unit: str
category: str | None = None
class ItemUpdate(BaseModel):
"""Schema for updating an item."""
name: str | None = None
quantity: float | None = Field(default=None, ge=0)
unit: str | None = None
category: str | None = None
class ItemResponse(BaseModel):
"""Schema for item response."""
id: int
name: str
quantity: float
unit: str
category: str | None
model_config = {"from_attributes": True}
class IngredientCreate(BaseModel):
"""Schema for adding an ingredient to a meal."""
item_id: int
quantity_needed: float = Field(gt=0)
class MealCreate(BaseModel):
"""Schema for creating a meal."""
name: str
instructions: str | None = None
ingredients: list[IngredientCreate] = []
class MealUpdate(BaseModel):
"""Schema for updating a meal."""
name: str | None = None
instructions: str | None = None
class IngredientResponse(BaseModel):
"""Schema for ingredient response."""
item_id: int
item_name: str
quantity_needed: float
unit: str
model_config = {"from_attributes": True}
class MealResponse(BaseModel):
"""Schema for meal response."""
id: int
name: str
instructions: str | None
ingredients: list[IngredientResponse] = []
model_config = {"from_attributes": True}
@classmethod
def from_meal(cls, meal: Meal) -> MealResponse:
"""Build a MealResponse from an ORM Meal with loaded ingredients."""
return cls(
id=meal.id,
name=meal.name,
instructions=meal.instructions,
ingredients=[
IngredientResponse(
item_id=mi.item_id,
item_name=mi.item.name,
quantity_needed=mi.quantity_needed,
unit=mi.item.unit,
)
for mi in meal.ingredients
],
)
class ShoppingItem(BaseModel):
"""An item needed for a meal that is short on stock."""
item_name: str
unit: str
needed: float
have: float
short: float
class MealAvailability(BaseModel):
"""Availability status for a meal."""
meal_id: int
meal_name: str
can_make: bool
missing: list[ShoppingItem] = []
# --- Routes ---
router = APIRouter(prefix="/api", tags=["van_inventory"])
# Items
@router.post("/items", response_model=ItemResponse)
def create_item(item: ItemCreate, db: DbSession) -> Item:
"""Create a new inventory item."""
db_item = Item(**item.model_dump())
db.add(db_item)
db.commit()
db.refresh(db_item)
return db_item
@router.get("/items", response_model=list[ItemResponse])
def list_items(db: DbSession) -> list[Item]:
"""List all inventory items."""
return list(db.scalars(select(Item).order_by(Item.name)).all())
@router.get("/items/{item_id}", response_model=ItemResponse)
def get_item(item_id: int, db: DbSession) -> Item:
"""Get an item by ID."""
item = db.get(Item, item_id)
if not item:
raise HTTPException(status_code=404, detail="Item not found")
return item
@router.patch("/items/{item_id}", response_model=ItemResponse)
def update_item(item_id: int, item: ItemUpdate, db: DbSession) -> Item:
"""Update an item by ID."""
db_item = db.get(Item, item_id)
if not db_item:
raise HTTPException(status_code=404, detail="Item not found")
for key, value in item.model_dump(exclude_unset=True).items():
setattr(db_item, key, value)
db.commit()
db.refresh(db_item)
return db_item
@router.delete("/items/{item_id}")
def delete_item(item_id: int, db: DbSession) -> dict[str, bool]:
"""Delete an item by ID."""
item = db.get(Item, item_id)
if not item:
raise HTTPException(status_code=404, detail="Item not found")
db.delete(item)
db.commit()
return {"deleted": True}
# Meals
@router.post("/meals", response_model=MealResponse)
def create_meal(meal: MealCreate, db: DbSession) -> MealResponse:
"""Create a new meal with optional ingredients."""
for ing in meal.ingredients:
if not db.get(Item, ing.item_id):
raise HTTPException(status_code=422, detail=f"Item {ing.item_id} not found")
db_meal = Meal(name=meal.name, instructions=meal.instructions)
db.add(db_meal)
db.flush()
for ing in meal.ingredients:
db.add(MealIngredient(meal_id=db_meal.id, item_id=ing.item_id, quantity_needed=ing.quantity_needed))
db.commit()
db_meal = db.scalar(
select(Meal)
.where(Meal.id == db_meal.id)
.options(selectinload(Meal.ingredients).selectinload(MealIngredient.item))
)
return MealResponse.from_meal(db_meal)
@router.get("/meals", response_model=list[MealResponse])
def list_meals(db: DbSession) -> list[MealResponse]:
"""List all meals with ingredients."""
meals = list(
db.scalars(
select(Meal).options(selectinload(Meal.ingredients).selectinload(MealIngredient.item)).order_by(Meal.name)
).all()
)
return [MealResponse.from_meal(m) for m in meals]
@router.get("/meals/availability", response_model=list[MealAvailability])
def check_all_meals(db: DbSession) -> list[MealAvailability]:
"""Check which meals can be made with current inventory."""
meals = list(
db.scalars(select(Meal).options(selectinload(Meal.ingredients).selectinload(MealIngredient.item))).all()
)
return [_check_meal(m) for m in meals]
@router.get("/meals/{meal_id}", response_model=MealResponse)
def get_meal(meal_id: int, db: DbSession) -> MealResponse:
"""Get a meal by ID with ingredients."""
meal = db.scalar(
select(Meal).where(Meal.id == meal_id).options(selectinload(Meal.ingredients).selectinload(MealIngredient.item))
)
if not meal:
raise HTTPException(status_code=404, detail="Meal not found")
return MealResponse.from_meal(meal)
@router.delete("/meals/{meal_id}")
def delete_meal(meal_id: int, db: DbSession) -> dict[str, bool]:
"""Delete a meal by ID."""
meal = db.get(Meal, meal_id)
if not meal:
raise HTTPException(status_code=404, detail="Meal not found")
db.delete(meal)
db.commit()
return {"deleted": True}
@router.post("/meals/{meal_id}/ingredients", response_model=MealResponse)
def add_ingredient(meal_id: int, ingredient: IngredientCreate, db: DbSession) -> MealResponse:
"""Add an ingredient to a meal."""
meal = db.get(Meal, meal_id)
if not meal:
raise HTTPException(status_code=404, detail="Meal not found")
if not db.get(Item, ingredient.item_id):
raise HTTPException(status_code=422, detail="Item not found")
existing = db.scalar(
select(MealIngredient).where(MealIngredient.meal_id == meal_id, MealIngredient.item_id == ingredient.item_id)
)
if existing:
raise HTTPException(status_code=409, detail="Ingredient already exists for this meal")
db.add(MealIngredient(meal_id=meal_id, item_id=ingredient.item_id, quantity_needed=ingredient.quantity_needed))
db.commit()
meal = db.scalar(
select(Meal).where(Meal.id == meal_id).options(selectinload(Meal.ingredients).selectinload(MealIngredient.item))
)
return MealResponse.from_meal(meal)
@router.delete("/meals/{meal_id}/ingredients/{item_id}")
def remove_ingredient(meal_id: int, item_id: int, db: DbSession) -> dict[str, bool]:
"""Remove an ingredient from a meal."""
mi = db.scalar(select(MealIngredient).where(MealIngredient.meal_id == meal_id, MealIngredient.item_id == item_id))
if not mi:
raise HTTPException(status_code=404, detail="Ingredient not found")
db.delete(mi)
db.commit()
return {"deleted": True}
@router.get("/meals/{meal_id}/availability", response_model=MealAvailability)
def check_meal(meal_id: int, db: DbSession) -> MealAvailability:
"""Check if a specific meal can be made and what's missing."""
meal = db.scalar(
select(Meal).where(Meal.id == meal_id).options(selectinload(Meal.ingredients).selectinload(MealIngredient.item))
)
if not meal:
raise HTTPException(status_code=404, detail="Meal not found")
return _check_meal(meal)
def _check_meal(meal: Meal) -> MealAvailability:
missing = [
ShoppingItem(
item_name=mi.item.name,
unit=mi.item.unit,
needed=mi.quantity_needed,
have=mi.item.quantity,
short=mi.quantity_needed - mi.item.quantity,
)
for mi in meal.ingredients
if mi.item.quantity < mi.quantity_needed
]
return MealAvailability(
meal_id=meal.id,
meal_name=meal.name,
can_make=len(missing) == 0,
missing=missing,
)

View File

@@ -0,0 +1,198 @@
"""HTMX frontend routes for van inventory."""
from __future__ import annotations
from pathlib import Path
from typing import Annotated
from fastapi import APIRouter, Form, HTTPException, Request
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from python.orm.van_inventory.models import Item, Meal, MealIngredient
# FastAPI needs DbSession at runtime to resolve the Depends() annotation
from python.van_inventory.dependencies import DbSession # noqa: TC001
from python.van_inventory.routers.api import _check_meal
TEMPLATE_DIR = Path(__file__).resolve().parent.parent / "templates"
templates = Jinja2Templates(directory=TEMPLATE_DIR)
router = APIRouter(tags=["frontend"])
# --- Items ---
@router.get("/", response_class=HTMLResponse)
def items_page(request: Request, db: DbSession) -> HTMLResponse:
"""Render the inventory page."""
items = list(db.scalars(select(Item).order_by(Item.name)).all())
return templates.TemplateResponse(request, "items.html", {"items": items})
@router.post("/items", response_class=HTMLResponse)
def htmx_create_item(
request: Request,
db: DbSession,
name: Annotated[str, Form()],
quantity: Annotated[float, Form()] = 0,
unit: Annotated[str, Form()] = "",
category: Annotated[str | None, Form()] = None,
) -> HTMLResponse:
"""Create an item and return updated item rows."""
if quantity < 0:
raise HTTPException(status_code=422, detail="Quantity must not be negative")
db.add(Item(name=name, quantity=quantity, unit=unit, category=category or None))
db.commit()
items = list(db.scalars(select(Item).order_by(Item.name)).all())
return templates.TemplateResponse(request, "partials/item_rows.html", {"items": items})
@router.patch("/items/{item_id}", response_class=HTMLResponse)
def htmx_update_item(
request: Request,
item_id: int,
db: DbSession,
quantity: Annotated[float, Form()],
) -> HTMLResponse:
"""Update an item's quantity and return updated item rows."""
if quantity < 0:
raise HTTPException(status_code=422, detail="Quantity must not be negative")
item = db.get(Item, item_id)
if item:
item.quantity = quantity
db.commit()
items = list(db.scalars(select(Item).order_by(Item.name)).all())
return templates.TemplateResponse(request, "partials/item_rows.html", {"items": items})
@router.delete("/items/{item_id}", response_class=HTMLResponse)
def htmx_delete_item(request: Request, item_id: int, db: DbSession) -> HTMLResponse:
"""Delete an item and return updated item rows."""
item = db.get(Item, item_id)
if item:
db.delete(item)
db.commit()
items = list(db.scalars(select(Item).order_by(Item.name)).all())
return templates.TemplateResponse(request, "partials/item_rows.html", {"items": items})
# --- Meals ---
def _load_meals(db: DbSession) -> list[Meal]:
return list(
db.scalars(
select(Meal).options(selectinload(Meal.ingredients).selectinload(MealIngredient.item)).order_by(Meal.name)
).all()
)
@router.get("/meals", response_class=HTMLResponse)
def meals_page(request: Request, db: DbSession) -> HTMLResponse:
"""Render the meals page."""
meals = _load_meals(db)
return templates.TemplateResponse(request, "meals.html", {"meals": meals})
@router.post("/meals", response_class=HTMLResponse)
def htmx_create_meal(
request: Request,
db: DbSession,
name: Annotated[str, Form()],
instructions: Annotated[str | None, Form()] = None,
) -> HTMLResponse:
"""Create a meal and return updated meal rows."""
db.add(Meal(name=name, instructions=instructions or None))
db.commit()
meals = _load_meals(db)
return templates.TemplateResponse(request, "partials/meal_rows.html", {"meals": meals})
@router.delete("/meals/{meal_id}", response_class=HTMLResponse)
def htmx_delete_meal(request: Request, meal_id: int, db: DbSession) -> HTMLResponse:
"""Delete a meal and return updated meal rows."""
meal = db.get(Meal, meal_id)
if meal:
db.delete(meal)
db.commit()
meals = _load_meals(db)
return templates.TemplateResponse(request, "partials/meal_rows.html", {"meals": meals})
# --- Meal detail ---
def _load_meal(db: DbSession, meal_id: int) -> Meal | None:
return db.scalar(
select(Meal).where(Meal.id == meal_id).options(selectinload(Meal.ingredients).selectinload(MealIngredient.item))
)
@router.get("/meals/{meal_id}", response_class=HTMLResponse)
def meal_detail_page(request: Request, meal_id: int, db: DbSession) -> HTMLResponse:
"""Render the meal detail page."""
meal = _load_meal(db, meal_id)
if not meal:
raise HTTPException(status_code=404, detail="Meal not found")
items = list(db.scalars(select(Item).order_by(Item.name)).all())
return templates.TemplateResponse(request, "meal_detail.html", {"meal": meal, "items": items})
@router.post("/meals/{meal_id}/ingredients", response_class=HTMLResponse)
def htmx_add_ingredient(
request: Request,
meal_id: int,
db: DbSession,
item_id: Annotated[int, Form()],
quantity_needed: Annotated[float, Form()],
) -> HTMLResponse:
"""Add an ingredient to a meal and return updated ingredient rows."""
if quantity_needed <= 0:
raise HTTPException(status_code=422, detail="Quantity must be positive")
meal = db.get(Meal, meal_id)
if not meal:
raise HTTPException(status_code=404, detail="Meal not found")
if not db.get(Item, item_id):
raise HTTPException(status_code=422, detail="Item not found")
existing = db.scalar(
select(MealIngredient).where(MealIngredient.meal_id == meal_id, MealIngredient.item_id == item_id)
)
if existing:
raise HTTPException(status_code=409, detail="Ingredient already exists for this meal")
db.add(MealIngredient(meal_id=meal_id, item_id=item_id, quantity_needed=quantity_needed))
db.commit()
meal = _load_meal(db, meal_id)
return templates.TemplateResponse(request, "partials/ingredient_rows.html", {"meal": meal})
@router.delete("/meals/{meal_id}/ingredients/{item_id}", response_class=HTMLResponse)
def htmx_remove_ingredient(
request: Request,
meal_id: int,
item_id: int,
db: DbSession,
) -> HTMLResponse:
"""Remove an ingredient from a meal and return updated ingredient rows."""
mi = db.scalar(select(MealIngredient).where(MealIngredient.meal_id == meal_id, MealIngredient.item_id == item_id))
if mi:
db.delete(mi)
db.commit()
meal = _load_meal(db, meal_id)
return templates.TemplateResponse(request, "partials/ingredient_rows.html", {"meal": meal})
# --- Availability ---
@router.get("/availability", response_class=HTMLResponse)
def availability_page(request: Request, db: DbSession) -> HTMLResponse:
"""Render the meal availability page."""
meals = list(
db.scalars(select(Meal).options(selectinload(Meal.ingredients).selectinload(MealIngredient.item))).all()
)
availability = [_check_meal(m) for m in meals]
return templates.TemplateResponse(request, "availability.html", {"availability": availability})

View File

@@ -0,0 +1,212 @@
:root {
--neon-pink: #ff2a6d;
--neon-cyan: #05d9e8;
--neon-yellow: #f9f002;
--neon-purple: #d300c5;
--bg-dark: #0a0a0f;
--bg-panel: #0d0d1a;
--bg-input: #111128;
--border: #1a1a3e;
--text: #c0c0d0;
--text-dim: #8e8ea0;
}
* { box-sizing: border-box; margin: 0; padding: 0; }
body {
font-family: 'Share Tech Mono', monospace;
max-width: 900px;
margin: 0 auto;
padding: 1rem;
background: var(--bg-dark);
color: var(--text);
position: relative;
}
/* Scanline overlay */
body::before {
content: '';
position: fixed;
top: 0; left: 0; right: 0; bottom: 0;
background: repeating-linear-gradient(
0deg,
transparent,
transparent 2px,
rgba(0, 0, 0, 0.08) 2px,
rgba(0, 0, 0, 0.08) 4px
);
pointer-events: none;
z-index: 9999;
}
h1, h2, h3 {
font-family: 'Orbitron', sans-serif;
margin-bottom: 0.5rem;
color: var(--neon-cyan);
text-shadow: 0 0 10px rgba(5, 217, 232, 0.5), 0 0 40px rgba(5, 217, 232, 0.2);
text-transform: uppercase;
letter-spacing: 2px;
}
a { color: var(--neon-pink); text-decoration: none; transition: all 0.2s; }
a:hover {
text-shadow: 0 0 8px rgba(255, 42, 109, 0.8), 0 0 20px rgba(255, 42, 109, 0.4);
}
nav {
display: flex;
gap: 1.5rem;
padding: 1rem 0;
border-bottom: 1px solid var(--border);
margin-bottom: 1.5rem;
position: relative;
}
nav::after {
content: '';
position: absolute;
bottom: -1px;
left: 0;
right: 0;
height: 1px;
background: linear-gradient(90deg, var(--neon-pink), var(--neon-cyan), var(--neon-purple));
opacity: 0.6;
}
nav a {
font-family: 'Orbitron', sans-serif;
font-weight: 700;
font-size: 0.85rem;
letter-spacing: 1px;
text-transform: uppercase;
padding: 0.3rem 0;
border-bottom: 2px solid transparent;
transition: all 0.2s;
}
nav a:hover {
border-bottom-color: var(--neon-pink);
text-shadow: 0 0 8px rgba(255, 42, 109, 0.8);
}
table {
width: 100%;
border-collapse: collapse;
margin: 1rem 0;
border: 1px solid var(--border);
}
th, td {
text-align: left;
padding: 0.6rem 0.75rem;
border-bottom: 1px solid var(--border);
}
th {
font-family: 'Orbitron', sans-serif;
color: var(--neon-cyan);
font-size: 0.7rem;
text-transform: uppercase;
letter-spacing: 2px;
background: var(--bg-panel);
border-bottom: 1px solid var(--neon-cyan);
text-shadow: 0 0 6px rgba(5, 217, 232, 0.3);
}
tr:hover td {
background: rgba(5, 217, 232, 0.03);
}
form {
display: flex;
flex-wrap: wrap;
gap: 0.5rem;
align-items: end;
margin: 1rem 0;
padding: 1rem;
border: 1px solid var(--border);
background: var(--bg-panel);
}
input, select {
padding: 0.5rem 0.6rem;
border: 1px solid var(--border);
border-radius: 2px;
background: var(--bg-input);
color: var(--neon-cyan);
font-family: 'Share Tech Mono', monospace;
transition: all 0.2s;
}
input:focus, select:focus {
outline: none;
border-color: var(--neon-cyan);
box-shadow: 0 0 8px rgba(5, 217, 232, 0.3), inset 0 0 8px rgba(5, 217, 232, 0.05);
}
button {
padding: 0.5rem 1.2rem;
border: 1px solid var(--neon-pink);
border-radius: 2px;
background: transparent;
color: var(--neon-pink);
cursor: pointer;
font-family: 'Orbitron', sans-serif;
font-weight: 700;
font-size: 0.7rem;
letter-spacing: 1px;
text-transform: uppercase;
transition: all 0.2s;
}
button:hover {
background: var(--neon-pink);
color: var(--bg-dark);
box-shadow: 0 0 15px rgba(255, 42, 109, 0.5), 0 0 30px rgba(255, 42, 109, 0.2);
}
button.danger {
border-color: var(--text-dim);
color: var(--text-dim);
}
button.danger:hover {
border-color: var(--neon-pink);
background: var(--neon-pink);
color: var(--bg-dark);
box-shadow: 0 0 15px rgba(255, 42, 109, 0.5);
}
.badge {
display: inline-block;
padding: 0.2rem 0.6rem;
border-radius: 2px;
font-family: 'Orbitron', sans-serif;
font-size: 0.65rem;
font-weight: 700;
letter-spacing: 1px;
text-transform: uppercase;
}
.badge.yes {
background: rgba(5, 217, 232, 0.1);
color: var(--neon-cyan);
border: 1px solid var(--neon-cyan);
text-shadow: 0 0 6px rgba(5, 217, 232, 0.5);
}
.badge.no {
background: rgba(255, 42, 109, 0.1);
color: var(--neon-pink);
border: 1px solid var(--neon-pink);
text-shadow: 0 0 6px rgba(255, 42, 109, 0.5);
}
.missing-list { font-size: 0.85rem; color: var(--text-dim); }
label {
font-size: 0.75rem;
color: var(--text-dim);
display: flex;
flex-direction: column;
gap: 0.2rem;
text-transform: uppercase;
letter-spacing: 1px;
}
.flash {
padding: 0.5rem 1rem;
margin: 0.5rem 0;
border-radius: 2px;
background: rgba(5, 217, 232, 0.1);
color: var(--neon-cyan);
border: 1px solid var(--neon-cyan);
}

View File

@@ -0,0 +1,30 @@
{% extends "base.html" %}
{% block title %}What Can I Make? - Van{% endblock %}
{% block content %}
<h1>What Can I Make?</h1>
<table>
<thead>
<tr><th>Meal</th><th>Status</th><th>Missing</th></tr>
</thead>
<tbody>
{% for meal in availability %}
<tr>
<td><a href="/meals/{{ meal.meal_id }}">{{ meal.meal_name }}</a></td>
<td>
{% if meal.can_make %}
<span class="badge yes">Ready</span>
{% else %}
<span class="badge no">Missing items</span>
{% endif %}
</td>
<td class="missing-list">
{% for m in meal.missing %}
{{ m.item_name }}: need {{ m.short }} more {{ m.unit }}{% if not loop.last %}, {% endif %}
{% endfor %}
</td>
</tr>
{% endfor %}
</tbody>
</table>
{% endblock %}

View File

@@ -0,0 +1,20 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>{% block title %}Van Inventory{% endblock %}</title>
<script src="https://unpkg.com/htmx.org@2.0.4"></script>
<link rel="preconnect" href="https://fonts.googleapis.com">
<link href="https://fonts.googleapis.com/css2?family=Orbitron:wght@400;700;900&family=Share+Tech+Mono&display=swap" rel="stylesheet">
<link rel="stylesheet" href="/static/style.css">
</head>
<body>
<nav>
<a href="/">Inventory</a>
<a href="/meals">Meals</a>
<a href="/availability">What Can I Make?</a>
</nav>
{% block content %}{% endblock %}
</body>
</html>

View File

@@ -0,0 +1,17 @@
{% extends "base.html" %}
{% block title %}Inventory - Van{% endblock %}
{% block content %}
<h1>Van Inventory</h1>
<form hx-post="/items" hx-target="#item-list" hx-swap="innerHTML" hx-on::after-request="if(event.detail.successful) this.reset()">
<label>Name <input type="text" name="name" required></label>
<label>Qty <input type="number" name="quantity" step="any" value="0" min="0" required></label>
<label>Unit <input type="text" name="unit" required placeholder="lbs, cans, etc"></label>
<label>Category <input type="text" name="category" placeholder="optional"></label>
<button type="submit">Add Item</button>
</form>
<div id="item-list">
{% include "partials/item_rows.html" %}
</div>
{% endblock %}

View File

@@ -0,0 +1,24 @@
{% extends "base.html" %}
{% block title %}{{ meal.name }} - Van{% endblock %}
{% block content %}
<h1>{{ meal.name }}</h1>
{% if meal.instructions %}<p>{{ meal.instructions }}</p>{% endif %}
<h2>Ingredients</h2>
<form hx-post="/meals/{{ meal.id }}/ingredients" hx-target="#ingredient-list" hx-swap="innerHTML" hx-on::after-request="if(event.detail.successful) this.reset()">
<label>Item
<select name="item_id" required>
<option value="">--</option>
{% for item in items %}
<option value="{{ item.id }}">{{ item.name }} ({{ item.unit }})</option>
{% endfor %}
</select>
</label>
<label>Qty needed <input type="number" name="quantity_needed" step="any" min="0.01" required></label>
<button type="submit">Add</button>
</form>
<div id="ingredient-list">
{% include "partials/ingredient_rows.html" %}
</div>
{% endblock %}

View File

@@ -0,0 +1,15 @@
{% extends "base.html" %}
{% block title %}Meals - Van{% endblock %}
{% block content %}
<h1>Meals</h1>
<form hx-post="/meals" hx-target="#meal-list" hx-swap="innerHTML" hx-on::after-request="if(event.detail.successful) this.reset()">
<label>Name <input type="text" name="name" required></label>
<label>Instructions <input type="text" name="instructions" placeholder="optional"></label>
<button type="submit">Add Meal</button>
</form>
<div id="meal-list">
{% include "partials/meal_rows.html" %}
</div>
{% endblock %}

View File

@@ -0,0 +1,16 @@
<table>
<thead>
<tr><th>Item</th><th>Needed</th><th>Have</th><th>Unit</th><th></th></tr>
</thead>
<tbody>
{% for mi in meal.ingredients %}
<tr>
<td>{{ mi.item.name }}</td>
<td>{{ mi.quantity_needed }}</td>
<td>{{ mi.item.quantity }}</td>
<td>{{ mi.item.unit }}</td>
<td><button class="danger" hx-delete="/meals/{{ meal.id }}/ingredients/{{ mi.item_id }}" hx-target="#ingredient-list" hx-swap="innerHTML" hx-confirm="Remove {{ mi.item.name }}?">X</button></td>
</tr>
{% endfor %}
</tbody>
</table>

View File

@@ -0,0 +1,21 @@
<table>
<thead>
<tr><th>Name</th><th>Qty</th><th>Unit</th><th>Category</th><th></th></tr>
</thead>
<tbody>
{% for item in items %}
<tr>
<td>{{ item.name }}</td>
<td>
<form hx-patch="/items/{{ item.id }}" hx-target="#item-list" hx-swap="innerHTML" style="display:inline; margin:0;">
<input type="number" name="quantity" value="{{ item.quantity }}" step="any" min="0" style="width:5rem">
<button type="submit" style="padding:0.2rem 0.5rem; font-size:0.8rem;">Update</button>
</form>
</td>
<td>{{ item.unit }}</td>
<td>{{ item.category or "" }}</td>
<td><button class="danger" hx-delete="/items/{{ item.id }}" hx-target="#item-list" hx-swap="innerHTML" hx-confirm="Delete {{ item.name }}?">X</button></td>
</tr>
{% endfor %}
</tbody>
</table>

View File

@@ -0,0 +1,15 @@
<table>
<thead>
<tr><th>Name</th><th>Ingredients</th><th>Instructions</th><th></th></tr>
</thead>
<tbody>
{% for meal in meals %}
<tr>
<td><a href="/meals/{{ meal.id }}">{{ meal.name }}</a></td>
<td>{{ meal.ingredients | length }}</td>
<td>{{ (meal.instructions or "")[:50] }}</td>
<td><button class="danger" hx-delete="/meals/{{ meal.id }}" hx-target="#meal-list" hx-swap="innerHTML" hx-confirm="Delete {{ meal.name }}?">X</button></td>
</tr>
{% endfor %}
</tbody>
</table>

View File

@@ -11,9 +11,10 @@
authentication = pkgs.lib.mkOverride 10 '' authentication = pkgs.lib.mkOverride 10 ''
# admins # admins
# These are required for the nixos postgresql setup
local all postgres trust local all postgres trust
host all postgres 127.0.0.1/32 trust host all postgres 127.0.0.1/32 trust
host all postgres ::1/128 trust host all postgres ::1/128 trust
local all richie trust local all richie trust
host all richie 127.0.0.1/32 trust host all richie 127.0.0.1/32 trust
@@ -21,6 +22,8 @@
host all richie 192.168.90.1/24 trust host all richie 192.168.90.1/24 trust
host all richie 192.168.99.1/24 trust host all richie 192.168.99.1/24 trust
local vaninventory vaninventory trust
#type database DBuser origin-address auth-method #type database DBuser origin-address auth-method
local hass hass trust local hass hass trust
@@ -62,6 +65,13 @@
replication = true; replication = true;
}; };
} }
{
name = "vaninventory";
ensureDBOwnership = true;
ensureClauses = {
login = true;
};
}
{ {
name = "hass"; name = "hass";
ensureDBOwnership = true; ensureDBOwnership = true;
@@ -76,6 +86,7 @@
ensureDatabases = [ ensureDatabases = [
"hass" "hass"
"richie" "richie"
"vaninventory"
]; ];
# Thank you NotAShelf # Thank you NotAShelf
# https://github.com/NotAShelf/nyx/blob/d407b4d6e5ab7f60350af61a3d73a62a5e9ac660/modules/core/roles/server/system/services/databases/postgresql.nix#L74 # https://github.com/NotAShelf/nyx/blob/d407b4d6e5ab7f60350af61a3d73a62a5e9ac660/modules/core/roles/server/system/services/databases/postgresql.nix#L74

View File

@@ -0,0 +1,48 @@
{
pkgs,
inputs,
...
}:
{
networking.firewall.allowedTCPPorts = [ 8001 ];
users.users.vaninventory = {
isSystemUser = true;
group = "vaninventory";
};
users.groups.vaninventory = { };
systemd.services.van_inventory = {
description = "Van Inventory API";
after = [
"network.target"
"postgresql.service"
];
requires = [ "postgresql.service" ];
wantedBy = [ "multi-user.target" ];
environment = {
PYTHONPATH = "${inputs.self}/";
VAN_INVENTORY_DB = "vaninventory";
VAN_INVENTORY_USER = "vaninventory";
VAN_INVENTORY_HOST = "/run/postgresql";
VAN_INVENTORY_PORT = "5432";
};
serviceConfig = {
Type = "simple";
User = "van-inventory";
Group = "van-inventory";
ExecStart = "${pkgs.my_python}/bin/python -m python.van_inventory.main --host 0.0.0.0 --port 8001";
Restart = "on-failure";
RestartSec = "5s";
StandardOutput = "journal";
StandardError = "journal";
NoNewPrivileges = true;
ProtectSystem = "strict";
ProtectHome = "read-only";
PrivateTmp = true;
ReadOnlyPaths = [ "${inputs.self}" ];
};
};
}

View File

@@ -1,236 +0,0 @@
"""Tests for python/api modules."""
from __future__ import annotations
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from python.api.routers.contact import (
ContactBase,
ContactCreate,
ContactListResponse,
ContactRelationshipCreate,
ContactRelationshipResponse,
ContactRelationshipUpdate,
ContactUpdate,
GraphData,
GraphEdge,
GraphNode,
NeedBase,
NeedCreate,
NeedResponse,
RelationshipTypeInfo,
router,
)
from python.api.routers.frontend import create_frontend_router
from python.orm.contact import RelationshipType
# --- Pydantic schema tests ---
def test_need_base() -> None:
"""Test NeedBase schema."""
need = NeedBase(name="ADHD", description="Attention deficit")
assert need.name == "ADHD"
def test_need_create() -> None:
"""Test NeedCreate schema."""
need = NeedCreate(name="Light Sensitive")
assert need.name == "Light Sensitive"
assert need.description is None
def test_need_response() -> None:
"""Test NeedResponse schema."""
need = NeedResponse(id=1, name="ADHD", description="test")
assert need.id == 1
def test_contact_base() -> None:
"""Test ContactBase schema."""
contact = ContactBase(name="John")
assert contact.name == "John"
assert contact.age is None
assert contact.bio is None
def test_contact_create() -> None:
"""Test ContactCreate schema."""
contact = ContactCreate(name="John", need_ids=[1, 2])
assert contact.need_ids == [1, 2]
def test_contact_create_no_needs() -> None:
"""Test ContactCreate with no needs."""
contact = ContactCreate(name="John")
assert contact.need_ids == []
def test_contact_update() -> None:
"""Test ContactUpdate schema."""
update = ContactUpdate(name="Jane", age=30)
assert update.name == "Jane"
assert update.age == 30
def test_contact_update_partial() -> None:
"""Test ContactUpdate with partial data."""
update = ContactUpdate(age=25)
assert update.name is None
assert update.age == 25
def test_contact_list_response() -> None:
"""Test ContactListResponse schema."""
contact = ContactListResponse(id=1, name="John")
assert contact.id == 1
def test_contact_relationship_create() -> None:
"""Test ContactRelationshipCreate schema."""
rel = ContactRelationshipCreate(
related_contact_id=2,
relationship_type=RelationshipType.FRIEND,
)
assert rel.related_contact_id == 2
assert rel.closeness_weight is None
def test_contact_relationship_create_with_weight() -> None:
"""Test ContactRelationshipCreate with custom weight."""
rel = ContactRelationshipCreate(
related_contact_id=2,
relationship_type=RelationshipType.SPOUSE,
closeness_weight=10,
)
assert rel.closeness_weight == 10
def test_contact_relationship_update() -> None:
"""Test ContactRelationshipUpdate schema."""
update = ContactRelationshipUpdate(closeness_weight=8)
assert update.relationship_type is None
assert update.closeness_weight == 8
def test_contact_relationship_response() -> None:
"""Test ContactRelationshipResponse schema."""
resp = ContactRelationshipResponse(
contact_id=1,
related_contact_id=2,
relationship_type="friend",
closeness_weight=6,
)
assert resp.contact_id == 1
def test_relationship_type_info() -> None:
"""Test RelationshipTypeInfo schema."""
info = RelationshipTypeInfo(value="spouse", display_name="Spouse", default_weight=10)
assert info.value == "spouse"
def test_graph_node() -> None:
"""Test GraphNode schema."""
node = GraphNode(id=1, name="John", current_job="Dev")
assert node.id == 1
def test_graph_edge() -> None:
"""Test GraphEdge schema."""
edge = GraphEdge(source=1, target=2, relationship_type="friend", closeness_weight=6)
assert edge.source == 1
def test_graph_data() -> None:
"""Test GraphData schema."""
data = GraphData(
nodes=[GraphNode(id=1, name="John")],
edges=[GraphEdge(source=1, target=2, relationship_type="friend", closeness_weight=6)],
)
assert len(data.nodes) == 1
assert len(data.edges) == 1
# --- frontend router test ---
def test_create_frontend_router(tmp_path: Path) -> None:
"""Test create_frontend_router creates router."""
# Create required assets dir and index.html
assets_dir = tmp_path / "assets"
assets_dir.mkdir()
index = tmp_path / "index.html"
index.write_text("<html></html>")
router = create_frontend_router(tmp_path)
assert router is not None
# --- API main tests ---
def test_create_app() -> None:
"""Test create_app creates FastAPI app."""
with patch("python.api.main.get_postgres_engine"):
from python.api.main import create_app
app = create_app()
assert app is not None
assert app.title == "Contact Database API"
def test_create_app_with_frontend(tmp_path: Path) -> None:
"""Test create_app with frontend directory."""
assets_dir = tmp_path / "assets"
assets_dir.mkdir()
index = tmp_path / "index.html"
index.write_text("<html></html>")
with patch("python.api.main.get_postgres_engine"):
from python.api.main import create_app
app = create_app(frontend_dir=tmp_path)
assert app is not None
def test_build_frontend_none() -> None:
"""Test build_frontend with None returns None."""
from python.api.main import build_frontend
result = build_frontend(None)
assert result is None
def test_build_frontend_missing_dir() -> None:
"""Test build_frontend with missing directory raises."""
from python.api.main import build_frontend
with pytest.raises(FileExistsError):
build_frontend(Path("/nonexistent/path"))
# --- dependencies test ---
def test_db_session_dependency() -> None:
"""Test get_db dependency."""
from python.api.dependencies import get_db
mock_engine = create_engine("sqlite:///:memory:")
mock_request = MagicMock()
mock_request.app.state.engine = mock_engine
gen = get_db(mock_request)
session = next(gen)
assert isinstance(session, Session)
try:
next(gen)
except StopIteration:
pass

View File

@@ -1,469 +0,0 @@
"""Integration tests for API router using SQLite in-memory database."""
from __future__ import annotations
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from python.api.routers.contact import (
ContactCreate,
ContactRelationshipCreate,
ContactRelationshipUpdate,
ContactUpdate,
NeedCreate,
add_contact_relationship,
add_need_to_contact,
create_contact,
create_need,
delete_contact,
delete_need,
get_contact,
get_contact_relationships,
get_need,
get_relationship_graph,
list_contacts,
list_needs,
list_relationship_types,
RelationshipTypeInfo,
remove_contact_relationship,
remove_need_from_contact,
update_contact,
update_contact_relationship,
)
from python.orm.base import RichieBase
from python.orm.contact import Contact, ContactNeed, ContactRelationship, Need, RelationshipType
import pytest
def _create_db() -> Session:
"""Create in-memory SQLite database with schema."""
engine = create_engine("sqlite:///:memory:")
# Create tables without schema prefix for SQLite
RichieBase.metadata.create_all(engine, checkfirst=True)
return Session(engine)
@pytest.fixture
def db() -> Session:
"""Database session fixture."""
engine = create_engine("sqlite:///:memory:")
# SQLite doesn't support schemas, so we need to drop the schema reference
from sqlalchemy import MetaData
meta = MetaData()
for table in RichieBase.metadata.sorted_tables:
# Create table without schema
table.to_metadata(meta)
meta.create_all(engine)
session = Session(engine)
yield session
session.close()
# --- Need CRUD tests ---
def test_create_need(db: Session) -> None:
"""Test creating a need."""
need = create_need(NeedCreate(name="ADHD", description="Attention deficit"), db)
assert need.name == "ADHD"
assert need.id is not None
def test_list_needs(db: Session) -> None:
"""Test listing needs."""
create_need(NeedCreate(name="ADHD"), db)
create_need(NeedCreate(name="Light Sensitive"), db)
needs = list_needs(db)
assert len(needs) == 2
def test_get_need(db: Session) -> None:
"""Test getting a need by ID."""
created = create_need(NeedCreate(name="ADHD"), db)
need = get_need(created.id, db)
assert need.name == "ADHD"
def test_get_need_not_found(db: Session) -> None:
"""Test getting a need that doesn't exist."""
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
get_need(999, db)
assert exc_info.value.status_code == 404
def test_delete_need(db: Session) -> None:
"""Test deleting a need."""
created = create_need(NeedCreate(name="ADHD"), db)
result = delete_need(created.id, db)
assert result == {"deleted": True}
def test_delete_need_not_found(db: Session) -> None:
"""Test deleting a need that doesn't exist."""
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
delete_need(999, db)
assert exc_info.value.status_code == 404
# --- Contact CRUD tests ---
def test_create_contact(db: Session) -> None:
"""Test creating a contact."""
contact = create_contact(ContactCreate(name="John"), db)
assert contact.name == "John"
assert contact.id is not None
def test_create_contact_with_needs(db: Session) -> None:
"""Test creating a contact with needs."""
need = create_need(NeedCreate(name="ADHD"), db)
contact = create_contact(ContactCreate(name="John", need_ids=[need.id]), db)
assert len(contact.needs) == 1
def test_list_contacts(db: Session) -> None:
"""Test listing contacts."""
create_contact(ContactCreate(name="John"), db)
create_contact(ContactCreate(name="Jane"), db)
contacts = list_contacts(db)
assert len(contacts) == 2
def test_list_contacts_pagination(db: Session) -> None:
"""Test listing contacts with pagination."""
for i in range(5):
create_contact(ContactCreate(name=f"Contact {i}"), db)
contacts = list_contacts(db, skip=2, limit=2)
assert len(contacts) == 2
def test_get_contact(db: Session) -> None:
"""Test getting a contact by ID."""
created = create_contact(ContactCreate(name="John"), db)
contact = get_contact(created.id, db)
assert contact.name == "John"
def test_get_contact_not_found(db: Session) -> None:
"""Test getting a contact that doesn't exist."""
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
get_contact(999, db)
assert exc_info.value.status_code == 404
def test_update_contact(db: Session) -> None:
"""Test updating a contact."""
created = create_contact(ContactCreate(name="John"), db)
updated = update_contact(created.id, ContactUpdate(name="Jane", age=30), db)
assert updated.name == "Jane"
assert updated.age == 30
def test_update_contact_with_needs(db: Session) -> None:
"""Test updating a contact's needs."""
need = create_need(NeedCreate(name="ADHD"), db)
created = create_contact(ContactCreate(name="John"), db)
updated = update_contact(created.id, ContactUpdate(need_ids=[need.id]), db)
assert len(updated.needs) == 1
def test_update_contact_not_found(db: Session) -> None:
"""Test updating a contact that doesn't exist."""
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
update_contact(999, ContactUpdate(name="Jane"), db)
assert exc_info.value.status_code == 404
def test_delete_contact(db: Session) -> None:
"""Test deleting a contact."""
created = create_contact(ContactCreate(name="John"), db)
result = delete_contact(created.id, db)
assert result == {"deleted": True}
def test_delete_contact_not_found(db: Session) -> None:
"""Test deleting a contact that doesn't exist."""
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
delete_contact(999, db)
assert exc_info.value.status_code == 404
# --- Need-Contact association tests ---
def test_add_need_to_contact(db: Session) -> None:
"""Test adding a need to a contact."""
need = create_need(NeedCreate(name="ADHD"), db)
contact = create_contact(ContactCreate(name="John"), db)
result = add_need_to_contact(contact.id, need.id, db)
assert result == {"added": True}
def test_add_need_to_contact_contact_not_found(db: Session) -> None:
"""Test adding need to nonexistent contact."""
from fastapi import HTTPException
need = create_need(NeedCreate(name="ADHD"), db)
with pytest.raises(HTTPException) as exc_info:
add_need_to_contact(999, need.id, db)
assert exc_info.value.status_code == 404
def test_add_need_to_contact_need_not_found(db: Session) -> None:
"""Test adding nonexistent need to contact."""
from fastapi import HTTPException
contact = create_contact(ContactCreate(name="John"), db)
with pytest.raises(HTTPException) as exc_info:
add_need_to_contact(contact.id, 999, db)
assert exc_info.value.status_code == 404
def test_remove_need_from_contact(db: Session) -> None:
"""Test removing a need from a contact."""
need = create_need(NeedCreate(name="ADHD"), db)
contact = create_contact(ContactCreate(name="John", need_ids=[need.id]), db)
result = remove_need_from_contact(contact.id, need.id, db)
assert result == {"removed": True}
def test_remove_need_from_contact_contact_not_found(db: Session) -> None:
"""Test removing need from nonexistent contact."""
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
remove_need_from_contact(999, 1, db)
assert exc_info.value.status_code == 404
def test_remove_need_from_contact_need_not_found(db: Session) -> None:
"""Test removing nonexistent need from contact."""
from fastapi import HTTPException
contact = create_contact(ContactCreate(name="John"), db)
with pytest.raises(HTTPException) as exc_info:
remove_need_from_contact(contact.id, 999, db)
assert exc_info.value.status_code == 404
# --- Relationship tests ---
def test_add_contact_relationship(db: Session) -> None:
"""Test adding a relationship between contacts."""
c1 = create_contact(ContactCreate(name="John"), db)
c2 = create_contact(ContactCreate(name="Jane"), db)
rel = add_contact_relationship(
c1.id,
ContactRelationshipCreate(related_contact_id=c2.id, relationship_type=RelationshipType.FRIEND),
db,
)
assert rel.contact_id == c1.id
assert rel.related_contact_id == c2.id
def test_add_contact_relationship_default_weight(db: Session) -> None:
"""Test relationship uses default weight from type."""
c1 = create_contact(ContactCreate(name="John"), db)
c2 = create_contact(ContactCreate(name="Jane"), db)
rel = add_contact_relationship(
c1.id,
ContactRelationshipCreate(related_contact_id=c2.id, relationship_type=RelationshipType.SPOUSE),
db,
)
assert rel.closeness_weight == RelationshipType.SPOUSE.default_weight
def test_add_contact_relationship_custom_weight(db: Session) -> None:
"""Test relationship with custom weight."""
c1 = create_contact(ContactCreate(name="John"), db)
c2 = create_contact(ContactCreate(name="Jane"), db)
rel = add_contact_relationship(
c1.id,
ContactRelationshipCreate(related_contact_id=c2.id, relationship_type=RelationshipType.FRIEND, closeness_weight=8),
db,
)
assert rel.closeness_weight == 8
def test_add_contact_relationship_contact_not_found(db: Session) -> None:
"""Test adding relationship with nonexistent contact."""
from fastapi import HTTPException
c2 = create_contact(ContactCreate(name="Jane"), db)
with pytest.raises(HTTPException) as exc_info:
add_contact_relationship(
999,
ContactRelationshipCreate(related_contact_id=c2.id, relationship_type=RelationshipType.FRIEND),
db,
)
assert exc_info.value.status_code == 404
def test_add_contact_relationship_related_not_found(db: Session) -> None:
"""Test adding relationship with nonexistent related contact."""
from fastapi import HTTPException
c1 = create_contact(ContactCreate(name="John"), db)
with pytest.raises(HTTPException) as exc_info:
add_contact_relationship(
c1.id,
ContactRelationshipCreate(related_contact_id=999, relationship_type=RelationshipType.FRIEND),
db,
)
assert exc_info.value.status_code == 404
def test_add_contact_relationship_self(db: Session) -> None:
"""Test cannot relate contact to itself."""
from fastapi import HTTPException
c1 = create_contact(ContactCreate(name="John"), db)
with pytest.raises(HTTPException) as exc_info:
add_contact_relationship(
c1.id,
ContactRelationshipCreate(related_contact_id=c1.id, relationship_type=RelationshipType.FRIEND),
db,
)
assert exc_info.value.status_code == 400
def test_get_contact_relationships(db: Session) -> None:
"""Test getting relationships for a contact."""
c1 = create_contact(ContactCreate(name="John"), db)
c2 = create_contact(ContactCreate(name="Jane"), db)
add_contact_relationship(
c1.id,
ContactRelationshipCreate(related_contact_id=c2.id, relationship_type=RelationshipType.FRIEND),
db,
)
rels = get_contact_relationships(c1.id, db)
assert len(rels) == 1
def test_get_contact_relationships_not_found(db: Session) -> None:
"""Test getting relationships for nonexistent contact."""
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
get_contact_relationships(999, db)
assert exc_info.value.status_code == 404
def test_update_contact_relationship(db: Session) -> None:
"""Test updating a relationship."""
c1 = create_contact(ContactCreate(name="John"), db)
c2 = create_contact(ContactCreate(name="Jane"), db)
add_contact_relationship(
c1.id,
ContactRelationshipCreate(related_contact_id=c2.id, relationship_type=RelationshipType.FRIEND),
db,
)
updated = update_contact_relationship(
c1.id,
c2.id,
ContactRelationshipUpdate(closeness_weight=9),
db,
)
assert updated.closeness_weight == 9
def test_update_contact_relationship_type(db: Session) -> None:
"""Test updating relationship type."""
c1 = create_contact(ContactCreate(name="John"), db)
c2 = create_contact(ContactCreate(name="Jane"), db)
add_contact_relationship(
c1.id,
ContactRelationshipCreate(related_contact_id=c2.id, relationship_type=RelationshipType.FRIEND),
db,
)
updated = update_contact_relationship(
c1.id,
c2.id,
ContactRelationshipUpdate(relationship_type=RelationshipType.BEST_FRIEND),
db,
)
assert updated.relationship_type == "best_friend"
def test_update_contact_relationship_not_found(db: Session) -> None:
"""Test updating nonexistent relationship."""
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
update_contact_relationship(
999,
998,
ContactRelationshipUpdate(closeness_weight=5),
db,
)
assert exc_info.value.status_code == 404
def test_remove_contact_relationship(db: Session) -> None:
"""Test removing a relationship."""
c1 = create_contact(ContactCreate(name="John"), db)
c2 = create_contact(ContactCreate(name="Jane"), db)
add_contact_relationship(
c1.id,
ContactRelationshipCreate(related_contact_id=c2.id, relationship_type=RelationshipType.FRIEND),
db,
)
result = remove_contact_relationship(c1.id, c2.id, db)
assert result == {"deleted": True}
def test_remove_contact_relationship_not_found(db: Session) -> None:
"""Test removing nonexistent relationship."""
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
remove_contact_relationship(999, 998, db)
assert exc_info.value.status_code == 404
# --- list_relationship_types ---
def test_list_relationship_types() -> None:
"""Test listing relationship types."""
types = list_relationship_types()
assert len(types) == len(RelationshipType)
assert all(isinstance(t, RelationshipTypeInfo) for t in types)
# --- graph tests ---
def test_get_relationship_graph(db: Session) -> None:
"""Test getting relationship graph."""
c1 = create_contact(ContactCreate(name="John"), db)
c2 = create_contact(ContactCreate(name="Jane"), db)
add_contact_relationship(
c1.id,
ContactRelationshipCreate(related_contact_id=c2.id, relationship_type=RelationshipType.FRIEND),
db,
)
graph = get_relationship_graph(db)
assert len(graph.nodes) == 2
assert len(graph.edges) == 1
def test_get_relationship_graph_empty(db: Session) -> None:
"""Test getting empty relationship graph."""
graph = get_relationship_graph(db)
assert len(graph.nodes) == 0
assert len(graph.edges) == 0

View File

@@ -1,66 +0,0 @@
"""Extended tests for python/api/main.py."""
from __future__ import annotations
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from python.api.main import build_frontend, create_app
def test_build_frontend_runs_npm(tmp_path: Path) -> None:
"""Test build_frontend runs npm commands."""
source_dir = tmp_path / "frontend"
source_dir.mkdir()
(source_dir / "package.json").write_text('{"name": "test"}')
dist_dir = tmp_path / "build" / "dist"
dist_dir.mkdir(parents=True)
(dist_dir / "index.html").write_text("<html></html>")
def mock_copytree(src: Path, dst: Path, dirs_exist_ok: bool = False) -> None:
if "dist" in str(src):
Path(dst).mkdir(parents=True, exist_ok=True)
(Path(dst) / "index.html").write_text("<html></html>")
with (
patch("python.api.main.subprocess.run") as mock_run,
patch("python.api.main.shutil.copytree") as mock_copy,
patch("python.api.main.shutil.rmtree"),
patch("python.api.main.tempfile.mkdtemp") as mock_mkdtemp,
):
# First mkdtemp for build dir, second for output dir
build_dir = str(tmp_path / "build")
output_dir = str(tmp_path / "output")
mock_mkdtemp.side_effect = [build_dir, output_dir]
# dist_dir exists check
with patch("pathlib.Path.exists", return_value=True):
result = build_frontend(source_dir, cache_dir=tmp_path / ".npm")
assert mock_run.call_count == 2 # npm install + npm run build
def test_build_frontend_no_dist(tmp_path: Path) -> None:
"""Test build_frontend raises when dist directory not found."""
source_dir = tmp_path / "frontend"
source_dir.mkdir()
(source_dir / "package.json").write_text('{"name": "test"}')
with (
patch("python.api.main.subprocess.run"),
patch("python.api.main.shutil.copytree"),
patch("python.api.main.tempfile.mkdtemp", return_value=str(tmp_path / "build")),
pytest.raises(FileNotFoundError, match="Build output not found"),
):
build_frontend(source_dir)
def test_create_app_includes_contact_router() -> None:
"""Test create_app includes contact router."""
app = create_app()
routes = [r.path for r in app.routes]
# Should have API routes
assert any("/api" in r for r in routes)

View File

@@ -1,61 +0,0 @@
"""Tests for api/main.py serve function and frontend router."""
from __future__ import annotations
from pathlib import Path
from unittest.mock import patch
import pytest
from python.api.main import build_frontend, create_app, serve
def test_build_frontend_none_source() -> None:
"""Test build_frontend returns None when no source dir."""
result = build_frontend(None)
assert result is None
def test_build_frontend_nonexistent_dir(tmp_path: Path) -> None:
"""Test build_frontend raises for nonexistent directory."""
with pytest.raises(FileExistsError):
build_frontend(tmp_path / "nonexistent")
def test_create_app_with_frontend(tmp_path: Path) -> None:
"""Test create_app with frontend directory."""
# Create a minimal frontend dir with assets
assets = tmp_path / "assets"
assets.mkdir()
(tmp_path / "index.html").write_text("<html></html>")
app = create_app(frontend_dir=tmp_path)
routes = [r.path for r in app.routes]
assert any("/api" in r for r in routes)
def test_serve_calls_uvicorn() -> None:
"""Test serve function calls uvicorn.run."""
with (
patch("python.api.main.uvicorn.run") as mock_run,
patch("python.api.main.build_frontend", return_value=None),
patch("python.api.main.configure_logger"),
patch.dict("os.environ", {"HOME": "/tmp"}),
):
serve(host="localhost", port=8000, log_level="INFO")
mock_run.assert_called_once()
def test_serve_with_frontend_dir(tmp_path: Path) -> None:
"""Test serve function with frontend dir."""
assets = tmp_path / "assets"
assets.mkdir()
(tmp_path / "index.html").write_text("<html></html>")
with (
patch("python.api.main.uvicorn.run") as mock_run,
patch("python.api.main.build_frontend", return_value=tmp_path),
patch("python.api.main.configure_logger"),
patch.dict("os.environ", {"HOME": "/tmp"}),
):
serve(host="localhost", frontend_dir=tmp_path, port=8000, log_level="INFO")
mock_run.assert_called_once()

View File

@@ -1,364 +0,0 @@
"""Tests for python/eval_warnings/main.py."""
from __future__ import annotations
import subprocess
from pathlib import Path
from typing import TYPE_CHECKING
from unittest.mock import MagicMock, patch
from zipfile import ZipFile
from io import BytesIO
import pytest
from python.eval_warnings.main import (
EvalWarning,
FileChange,
apply_changes,
compute_warning_hash,
check_duplicate_pr,
download_logs,
extract_referenced_files,
parse_changes,
parse_warnings,
query_ollama,
run_cmd,
create_pr,
)
if TYPE_CHECKING:
pass
def test_eval_warning_frozen() -> None:
"""Test EvalWarning is frozen dataclass."""
w = EvalWarning(system="test", message="warning: test msg")
assert w.system == "test"
assert w.message == "warning: test msg"
def test_file_change() -> None:
"""Test FileChange dataclass."""
fc = FileChange(file_path="test.nix", original="old", fixed="new")
assert fc.file_path == "test.nix"
def test_run_cmd() -> None:
"""Test run_cmd."""
result = run_cmd(["echo", "hello"])
assert result.stdout.strip() == "hello"
def test_run_cmd_check_false() -> None:
"""Test run_cmd with check=False."""
result = run_cmd(["ls", "/nonexistent"], check=False)
assert result.returncode != 0
def test_parse_warnings_basic() -> None:
"""Test parse_warnings extracts warnings."""
logs = {
"build-server1/2_Build.txt": "warning: test warning\nsome other line\ntrace: warning: another warning\n",
}
warnings = parse_warnings(logs)
assert len(warnings) == 2
def test_parse_warnings_ignores_untrusted_flake() -> None:
"""Test parse_warnings ignores untrusted flake settings."""
logs = {
"build-server1/2_Build.txt": "warning: ignoring untrusted flake configuration setting foo\n",
}
warnings = parse_warnings(logs)
assert len(warnings) == 0
def test_parse_warnings_strips_timestamp() -> None:
"""Test parse_warnings strips timestamps."""
logs = {
"build-server1/2_Build.txt": "2024-01-01T00:00:00.000Z warning: test msg\n",
}
warnings = parse_warnings(logs)
assert len(warnings) == 1
w = warnings.pop()
assert w.message == "warning: test msg"
assert w.system == "server1"
def test_parse_warnings_empty() -> None:
"""Test parse_warnings with no warnings."""
logs = {"build-server1/2_Build.txt": "all good\n"}
warnings = parse_warnings(logs)
assert len(warnings) == 0
def test_compute_warning_hash() -> None:
"""Test compute_warning_hash returns consistent 8-char hash."""
warnings = {EvalWarning(system="s1", message="msg1")}
h = compute_warning_hash(warnings)
assert len(h) == 8
# Same input -> same hash
assert compute_warning_hash(warnings) == h
def test_compute_warning_hash_different() -> None:
"""Test different warnings produce different hashes."""
w1 = {EvalWarning(system="s1", message="msg1")}
w2 = {EvalWarning(system="s1", message="msg2")}
assert compute_warning_hash(w1) != compute_warning_hash(w2)
def test_extract_referenced_files(tmp_path: Path) -> None:
"""Test extract_referenced_files reads existing files."""
nix_file = tmp_path / "test.nix"
nix_file.write_text("{ pkgs }: pkgs")
warnings = {EvalWarning(system="s1", message=f"warning: in /nix/store/abc-source/{nix_file}")}
# Won't find the file since it uses absolute paths resolved differently
files = extract_referenced_files(warnings)
# Result depends on actual file resolution
assert isinstance(files, dict)
def test_check_duplicate_pr_no_duplicate() -> None:
"""Test check_duplicate_pr when no duplicate exists."""
mock_result = MagicMock()
mock_result.returncode = 0
mock_result.stdout = "fix: resolve nix eval warnings (abcd1234)\nfix: other (efgh5678)\n"
with patch("python.eval_warnings.main.run_cmd", return_value=mock_result):
assert check_duplicate_pr("xxxxxxxx") is False
def test_check_duplicate_pr_found() -> None:
"""Test check_duplicate_pr when duplicate exists."""
mock_result = MagicMock()
mock_result.returncode = 0
mock_result.stdout = "fix: resolve nix eval warnings (abcd1234)\n"
with patch("python.eval_warnings.main.run_cmd", return_value=mock_result):
assert check_duplicate_pr("abcd1234") is True
def test_check_duplicate_pr_error() -> None:
"""Test check_duplicate_pr raises on error."""
mock_result = MagicMock()
mock_result.returncode = 1
mock_result.stderr = "gh error"
with (
patch("python.eval_warnings.main.run_cmd", return_value=mock_result),
pytest.raises(RuntimeError, match="Failed to check for duplicate PRs"),
):
check_duplicate_pr("test")
def test_parse_changes_basic() -> None:
"""Test parse_changes with valid response."""
response = """## **REASONING**
Some reasoning here.
## **CHANGES**
FILE: test.nix
<<<<<<< ORIGINAL
old line
=======
new line
>>>>>>> FIXED
"""
changes = parse_changes(response)
assert len(changes) == 1
assert changes[0].file_path == "test.nix"
assert changes[0].original == "old line"
assert changes[0].fixed == "new line"
def test_parse_changes_no_changes_section() -> None:
"""Test parse_changes with missing CHANGES section."""
response = "Some text without changes"
changes = parse_changes(response)
assert changes == []
def test_parse_changes_multiple() -> None:
"""Test parse_changes with multiple file changes."""
response = """**CHANGES**
FILE: file1.nix
<<<<<<< ORIGINAL
old1
=======
new1
>>>>>>> FIXED
FILE: file2.nix
<<<<<<< ORIGINAL
old2
=======
new2
>>>>>>> FIXED
"""
changes = parse_changes(response)
assert len(changes) == 2
def test_apply_changes(tmp_path: Path) -> None:
"""Test apply_changes applies changes to files."""
test_file = tmp_path / "test.nix"
test_file.write_text("old content here")
changes = [FileChange(file_path=str(test_file), original="old content", fixed="new content")]
with patch("python.eval_warnings.main.Path.cwd", return_value=tmp_path):
applied = apply_changes(changes)
assert applied == 1
assert "new content here" in test_file.read_text()
def test_apply_changes_file_not_found(tmp_path: Path) -> None:
"""Test apply_changes skips missing files."""
changes = [FileChange(file_path=str(tmp_path / "missing.nix"), original="old", fixed="new")]
with patch("python.eval_warnings.main.Path.cwd", return_value=tmp_path):
applied = apply_changes(changes)
assert applied == 0
def test_apply_changes_original_not_found(tmp_path: Path) -> None:
"""Test apply_changes skips if original text not in file."""
test_file = tmp_path / "test.nix"
test_file.write_text("different content")
changes = [FileChange(file_path=str(test_file), original="not found", fixed="new")]
with patch("python.eval_warnings.main.Path.cwd", return_value=tmp_path):
applied = apply_changes(changes)
assert applied == 0
def test_apply_changes_path_traversal(tmp_path: Path) -> None:
"""Test apply_changes blocks path traversal."""
changes = [FileChange(file_path="/etc/passwd", original="old", fixed="new")]
with patch("python.eval_warnings.main.Path.cwd", return_value=tmp_path):
applied = apply_changes(changes)
assert applied == 0
def test_query_ollama_success() -> None:
"""Test query_ollama returns response."""
warnings = {EvalWarning(system="s1", message="warning: test")}
files = {"test.nix": "{ pkgs }: pkgs"}
mock_response = MagicMock()
mock_response.json.return_value = {"response": "some fix suggestion"}
mock_response.raise_for_status.return_value = None
with patch("python.eval_warnings.main.post", return_value=mock_response):
result = query_ollama(warnings, files, "http://localhost:11434")
assert result == "some fix suggestion"
def test_query_ollama_failure() -> None:
"""Test query_ollama returns None on failure."""
from httpx import HTTPError
warnings = {EvalWarning(system="s1", message="warning: test")}
files = {}
with patch("python.eval_warnings.main.post", side_effect=HTTPError("fail")):
result = query_ollama(warnings, files, "http://localhost:11434")
assert result is None
def test_download_logs_success() -> None:
"""Test download_logs extracts build log files from zip."""
# Create a zip file in memory
buf = BytesIO()
with ZipFile(buf, "w") as zf:
zf.writestr("build-server1/2_Build.txt", "warning: test")
zf.writestr("other-file.txt", "not a build log")
zip_bytes = buf.getvalue()
mock_result = MagicMock()
mock_result.returncode = 0
mock_result.stdout = zip_bytes
with patch("python.eval_warnings.main.subprocess.run", return_value=mock_result):
logs = download_logs("12345", "owner/repo")
assert "build-server1/2_Build.txt" in logs
assert "other-file.txt" not in logs
def test_download_logs_failure() -> None:
"""Test download_logs raises on failure."""
mock_result = MagicMock()
mock_result.returncode = 1
mock_result.stderr = b"error"
with (
patch("python.eval_warnings.main.subprocess.run", return_value=mock_result),
pytest.raises(RuntimeError, match="Failed to download logs"),
):
download_logs("12345", "owner/repo")
def test_create_pr() -> None:
"""Test create_pr creates branch and PR."""
warnings = {EvalWarning(system="s1", message="warning: test")}
llm_response = "**REASONING**\nSome fix.\n**CHANGES**\nstuff"
mock_diff_result = MagicMock()
mock_diff_result.returncode = 1 # changes exist
call_count = 0
def mock_run_cmd(cmd: list[str], *, check: bool = True) -> MagicMock:
nonlocal call_count
call_count += 1
result = MagicMock()
result.returncode = 0
result.stdout = ""
if "diff" in cmd:
result.returncode = 1
return result
with patch("python.eval_warnings.main.run_cmd", side_effect=mock_run_cmd):
create_pr("abcd1234", warnings, llm_response, "https://example.com/run/1")
assert call_count > 0
def test_create_pr_no_changes() -> None:
"""Test create_pr does nothing when no file changes."""
warnings = {EvalWarning(system="s1", message="warning: test")}
llm_response = "**REASONING**\nNo changes needed.\n**CHANGES**\n"
def mock_run_cmd(cmd: list[str], *, check: bool = True) -> MagicMock:
result = MagicMock()
result.returncode = 0
result.stdout = ""
return result
with patch("python.eval_warnings.main.run_cmd", side_effect=mock_run_cmd):
create_pr("abcd1234", warnings, llm_response, "https://example.com/run/1")
def test_create_pr_no_reasoning() -> None:
"""Test create_pr handles missing REASONING section."""
warnings = {EvalWarning(system="s1", message="warning: test")}
llm_response = "No reasoning here"
def mock_run_cmd(cmd: list[str], *, check: bool = True) -> MagicMock:
result = MagicMock()
result.returncode = 0 if "diff" not in cmd else 1
result.stdout = ""
return result
with patch("python.eval_warnings.main.run_cmd", side_effect=mock_run_cmd):
create_pr("abcd1234", warnings, llm_response, "https://example.com/run/1")

View File

@@ -1,77 +0,0 @@
"""Extended tests for python/eval_warnings/main.py."""
from __future__ import annotations
import os
from pathlib import Path
from unittest.mock import MagicMock, patch
from python.eval_warnings.main import (
EvalWarning,
extract_referenced_files,
)
def test_extract_referenced_files_nix_store_paths(tmp_path: Path) -> None:
"""Test extracting files from nix store paths."""
# Create matching directory structure
systems_dir = tmp_path / "systems"
systems_dir.mkdir()
nix_file = systems_dir / "test.nix"
nix_file.write_text("{ pkgs }: pkgs")
warnings = {
EvalWarning(
system="s1",
message="warning: in /nix/store/abc-source/systems/test.nix:5: deprecated",
)
}
# Change to tmp_path so relative paths work
old_cwd = os.getcwd()
try:
os.chdir(tmp_path)
files = extract_referenced_files(warnings)
finally:
os.chdir(old_cwd)
assert "systems/test.nix" in files
assert files["systems/test.nix"] == "{ pkgs }: pkgs"
def test_extract_referenced_files_no_files_found() -> None:
"""Test extract_referenced_files when no files are found."""
warnings = {
EvalWarning(
system="s1",
message="warning: something generic without file paths",
)
}
files = extract_referenced_files(warnings)
# Either empty or has flake.nix fallback
assert isinstance(files, dict)
def test_extract_referenced_files_repo_relative_paths(tmp_path: Path) -> None:
"""Test extracting repo-relative file paths."""
# Create the referenced file
systems_dir = tmp_path / "systems" / "foo"
systems_dir.mkdir(parents=True)
nix_file = systems_dir / "bar.nix"
nix_file.write_text("{ config }: {}")
warnings = {
EvalWarning(
system="s1",
message="warning: in systems/foo/bar.nix:10: test",
)
}
old_cwd = os.getcwd()
try:
os.chdir(tmp_path)
files = extract_referenced_files(warnings)
finally:
os.chdir(old_cwd)
assert "systems/foo/bar.nix" in files

View File

@@ -1,115 +0,0 @@
"""Tests for eval_warnings/main.py main() entry point."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
def test_eval_warnings_main_no_warnings() -> None:
"""Test main() when no warnings are found."""
from python.eval_warnings.main import main
with (
patch("python.eval_warnings.main.configure_logger"),
patch("python.eval_warnings.main.download_logs", return_value="clean log"),
patch("python.eval_warnings.main.parse_warnings", return_value=set()),
):
main(
run_id="123",
repo="owner/repo",
ollama_url="http://localhost:11434",
run_url="http://example.com/run",
log_level="INFO",
)
def test_eval_warnings_main_duplicate_pr() -> None:
"""Test main() when a duplicate PR exists."""
from python.eval_warnings.main import main, EvalWarning
warnings = {EvalWarning(system="s1", message="test")}
with (
patch("python.eval_warnings.main.configure_logger"),
patch("python.eval_warnings.main.download_logs", return_value="log"),
patch("python.eval_warnings.main.parse_warnings", return_value=warnings),
patch("python.eval_warnings.main.compute_warning_hash", return_value="abc123"),
patch("python.eval_warnings.main.check_duplicate_pr", return_value=True),
):
main(
run_id="123",
repo="owner/repo",
ollama_url="http://localhost:11434",
run_url="http://example.com/run",
)
def test_eval_warnings_main_no_llm_response() -> None:
"""Test main() when LLM returns no response."""
from python.eval_warnings.main import main, EvalWarning
warnings = {EvalWarning(system="s1", message="test")}
with (
patch("python.eval_warnings.main.configure_logger"),
patch("python.eval_warnings.main.download_logs", return_value="log"),
patch("python.eval_warnings.main.parse_warnings", return_value=warnings),
patch("python.eval_warnings.main.compute_warning_hash", return_value="abc123"),
patch("python.eval_warnings.main.check_duplicate_pr", return_value=False),
patch("python.eval_warnings.main.extract_referenced_files", return_value={}),
patch("python.eval_warnings.main.query_ollama", return_value=None),
):
main(
run_id="123",
repo="owner/repo",
ollama_url="http://localhost:11434",
run_url="http://example.com/run",
)
def test_eval_warnings_main_no_changes_applied() -> None:
"""Test main() when no changes are applied."""
from python.eval_warnings.main import main, EvalWarning
warnings = {EvalWarning(system="s1", message="test")}
with (
patch("python.eval_warnings.main.configure_logger"),
patch("python.eval_warnings.main.download_logs", return_value="log"),
patch("python.eval_warnings.main.parse_warnings", return_value=warnings),
patch("python.eval_warnings.main.compute_warning_hash", return_value="abc123"),
patch("python.eval_warnings.main.check_duplicate_pr", return_value=False),
patch("python.eval_warnings.main.extract_referenced_files", return_value={}),
patch("python.eval_warnings.main.query_ollama", return_value="some response"),
patch("python.eval_warnings.main.parse_changes", return_value=[]),
patch("python.eval_warnings.main.apply_changes", return_value=0),
):
main(
run_id="123",
repo="owner/repo",
ollama_url="http://localhost:11434",
run_url="http://example.com/run",
)
def test_eval_warnings_main_full_success() -> None:
"""Test main() full success path."""
from python.eval_warnings.main import main, EvalWarning
warnings = {EvalWarning(system="s1", message="test")}
with (
patch("python.eval_warnings.main.configure_logger"),
patch("python.eval_warnings.main.download_logs", return_value="log"),
patch("python.eval_warnings.main.parse_warnings", return_value=warnings),
patch("python.eval_warnings.main.compute_warning_hash", return_value="abc123"),
patch("python.eval_warnings.main.check_duplicate_pr", return_value=False),
patch("python.eval_warnings.main.extract_referenced_files", return_value={}),
patch("python.eval_warnings.main.query_ollama", return_value="response"),
patch("python.eval_warnings.main.parse_changes", return_value=[{"file": "a.nix"}]),
patch("python.eval_warnings.main.apply_changes", return_value=1),
patch("python.eval_warnings.main.create_pr") as mock_pr,
):
main(
run_id="123",
repo="owner/repo",
ollama_url="http://localhost:11434",
run_url="http://example.com/run",
)
mock_pr.assert_called_once()

View File

@@ -1,248 +0,0 @@
"""Tests for python/heater modules."""
from __future__ import annotations
import sys
from typing import TYPE_CHECKING
from unittest.mock import MagicMock, patch
from python.heater.models import ActionResult, DeviceConfig, HeaterStatus
if TYPE_CHECKING:
pass
# --- models tests ---
def test_device_config() -> None:
"""Test DeviceConfig creation."""
config = DeviceConfig(device_id="abc123", ip="192.168.1.1", local_key="key123")
assert config.device_id == "abc123"
assert config.ip == "192.168.1.1"
assert config.local_key == "key123"
assert config.version == 3.5
def test_device_config_custom_version() -> None:
"""Test DeviceConfig with custom version."""
config = DeviceConfig(device_id="abc", ip="10.0.0.1", local_key="key", version=3.3)
assert config.version == 3.3
def test_heater_status_defaults() -> None:
"""Test HeaterStatus default values."""
status = HeaterStatus(power=True)
assert status.power is True
assert status.setpoint is None
assert status.state is None
assert status.error_code is None
assert status.raw_dps == {}
def test_heater_status_full() -> None:
"""Test HeaterStatus with all fields."""
status = HeaterStatus(
power=True,
setpoint=72,
state="Heat",
error_code=0,
raw_dps={"1": True, "101": 72},
)
assert status.power is True
assert status.setpoint == 72
assert status.state == "Heat"
def test_action_result_success() -> None:
"""Test ActionResult success."""
result = ActionResult(success=True, action="on", power=True)
assert result.success is True
assert result.action == "on"
assert result.power is True
assert result.error is None
def test_action_result_failure() -> None:
"""Test ActionResult failure."""
result = ActionResult(success=False, action="on", error="Connection failed")
assert result.success is False
assert result.error == "Connection failed"
# --- controller tests (with mocked tinytuya) ---
def _get_controller_class() -> type:
"""Import HeaterController with mocked tinytuya."""
mock_tinytuya = MagicMock()
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
# Force reimport
if "python.heater.controller" in sys.modules:
del sys.modules["python.heater.controller"]
from python.heater.controller import HeaterController
return HeaterController
def test_heater_controller_status_success() -> None:
"""Test HeaterController.status returns correct status."""
mock_tinytuya = MagicMock()
mock_device = MagicMock()
mock_device.status.return_value = {"dps": {"1": True, "101": 72, "102": "Heat", "108": 0}}
mock_tinytuya.Device.return_value = mock_device
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
if "python.heater.controller" in sys.modules:
del sys.modules["python.heater.controller"]
from python.heater.controller import HeaterController
config = DeviceConfig(device_id="abc", ip="10.0.0.1", local_key="key")
controller = HeaterController(config)
status = controller.status()
assert status.power is True
assert status.setpoint == 72
assert status.state == "Heat"
def test_heater_controller_status_error() -> None:
"""Test HeaterController.status handles device error."""
mock_tinytuya = MagicMock()
mock_device = MagicMock()
mock_device.status.return_value = {"Error": "Connection timeout"}
mock_tinytuya.Device.return_value = mock_device
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
if "python.heater.controller" in sys.modules:
del sys.modules["python.heater.controller"]
from python.heater.controller import HeaterController
config = DeviceConfig(device_id="abc", ip="10.0.0.1", local_key="key")
controller = HeaterController(config)
status = controller.status()
assert status.power is False
def test_heater_controller_turn_on() -> None:
"""Test HeaterController.turn_on."""
mock_tinytuya = MagicMock()
mock_device = MagicMock()
mock_device.set_value.return_value = None
mock_tinytuya.Device.return_value = mock_device
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
if "python.heater.controller" in sys.modules:
del sys.modules["python.heater.controller"]
from python.heater.controller import HeaterController
config = DeviceConfig(device_id="abc", ip="10.0.0.1", local_key="key")
controller = HeaterController(config)
result = controller.turn_on()
assert result.success is True
assert result.action == "on"
assert result.power is True
def test_heater_controller_turn_on_error() -> None:
"""Test HeaterController.turn_on handles errors."""
mock_tinytuya = MagicMock()
mock_device = MagicMock()
mock_device.set_value.side_effect = ConnectionError("timeout")
mock_tinytuya.Device.return_value = mock_device
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
if "python.heater.controller" in sys.modules:
del sys.modules["python.heater.controller"]
from python.heater.controller import HeaterController
config = DeviceConfig(device_id="abc", ip="10.0.0.1", local_key="key")
controller = HeaterController(config)
result = controller.turn_on()
assert result.success is False
assert "timeout" in result.error
def test_heater_controller_turn_off() -> None:
"""Test HeaterController.turn_off."""
mock_tinytuya = MagicMock()
mock_device = MagicMock()
mock_device.set_value.return_value = None
mock_tinytuya.Device.return_value = mock_device
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
if "python.heater.controller" in sys.modules:
del sys.modules["python.heater.controller"]
from python.heater.controller import HeaterController
config = DeviceConfig(device_id="abc", ip="10.0.0.1", local_key="key")
controller = HeaterController(config)
result = controller.turn_off()
assert result.success is True
assert result.action == "off"
assert result.power is False
def test_heater_controller_turn_off_error() -> None:
"""Test HeaterController.turn_off handles errors."""
mock_tinytuya = MagicMock()
mock_device = MagicMock()
mock_device.set_value.side_effect = ConnectionError("timeout")
mock_tinytuya.Device.return_value = mock_device
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
if "python.heater.controller" in sys.modules:
del sys.modules["python.heater.controller"]
from python.heater.controller import HeaterController
config = DeviceConfig(device_id="abc", ip="10.0.0.1", local_key="key")
controller = HeaterController(config)
result = controller.turn_off()
assert result.success is False
def test_heater_controller_toggle_on_to_off() -> None:
"""Test HeaterController.toggle when heater is on."""
mock_tinytuya = MagicMock()
mock_device = MagicMock()
mock_device.status.return_value = {"dps": {"1": True}}
mock_device.set_value.return_value = None
mock_tinytuya.Device.return_value = mock_device
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
if "python.heater.controller" in sys.modules:
del sys.modules["python.heater.controller"]
from python.heater.controller import HeaterController
config = DeviceConfig(device_id="abc", ip="10.0.0.1", local_key="key")
controller = HeaterController(config)
result = controller.toggle()
assert result.success is True
assert result.action == "off"
def test_heater_controller_toggle_off_to_on() -> None:
"""Test HeaterController.toggle when heater is off."""
mock_tinytuya = MagicMock()
mock_device = MagicMock()
mock_device.status.return_value = {"dps": {"1": False}}
mock_device.set_value.return_value = None
mock_tinytuya.Device.return_value = mock_device
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
if "python.heater.controller" in sys.modules:
del sys.modules["python.heater.controller"]
from python.heater.controller import HeaterController
config = DeviceConfig(device_id="abc", ip="10.0.0.1", local_key="key")
controller = HeaterController(config)
result = controller.toggle()
assert result.success is True
assert result.action == "on"

View File

@@ -1,43 +0,0 @@
"""Tests for python/heater/main.py."""
from __future__ import annotations
import sys
from unittest.mock import MagicMock, patch
from python.heater.models import ActionResult, DeviceConfig, HeaterStatus
def test_create_app() -> None:
"""Test create_app creates FastAPI app."""
mock_tinytuya = MagicMock()
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
if "python.heater.controller" in sys.modules:
del sys.modules["python.heater.controller"]
if "python.heater.main" in sys.modules:
del sys.modules["python.heater.main"]
from python.heater.main import create_app
config = DeviceConfig(device_id="abc", ip="10.0.0.1", local_key="key")
app = create_app(config)
assert app is not None
assert app.title == "Heater Control API"
def test_serve_missing_params() -> None:
"""Test serve raises with missing parameters."""
import typer
mock_tinytuya = MagicMock()
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
if "python.heater.controller" in sys.modules:
del sys.modules["python.heater.controller"]
if "python.heater.main" in sys.modules:
del sys.modules["python.heater.main"]
from python.heater.main import serve
with patch("python.heater.main.configure_logger"):
try:
serve(host="0.0.0.0", port=8124, log_level="INFO")
except (typer.Exit, SystemExit):
pass

View File

@@ -1,165 +0,0 @@
"""Extended tests for python/heater/main.py - FastAPI routes."""
from __future__ import annotations
import sys
from unittest.mock import MagicMock, patch
from python.heater.models import ActionResult, DeviceConfig, HeaterStatus
def test_heater_app_routes() -> None:
"""Test heater app has expected routes."""
mock_tinytuya = MagicMock()
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
if "python.heater.controller" in sys.modules:
del sys.modules["python.heater.controller"]
if "python.heater.main" in sys.modules:
del sys.modules["python.heater.main"]
from python.heater.main import create_app
config = DeviceConfig(device_id="abc", ip="10.0.0.1", local_key="key")
app = create_app(config)
route_paths = [r.path for r in app.routes]
assert "/status" in route_paths
assert "/on" in route_paths
assert "/off" in route_paths
assert "/toggle" in route_paths
def test_heater_get_status_route() -> None:
"""Test /status route handler."""
mock_tinytuya = MagicMock()
mock_device = MagicMock()
mock_device.status.return_value = {"dps": {"1": True, "101": 72}}
mock_tinytuya.Device.return_value = mock_device
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
if "python.heater.controller" in sys.modules:
del sys.modules["python.heater.controller"]
if "python.heater.main" in sys.modules:
del sys.modules["python.heater.main"]
from python.heater.main import create_app
from python.heater.controller import HeaterController
config = DeviceConfig(device_id="abc", ip="10.0.0.1", local_key="key")
app = create_app(config)
# Simulate lifespan by setting controller
app.state.controller = HeaterController(config)
# Find and call the status handler
for route in app.routes:
if hasattr(route, "path") and route.path == "/status":
result = route.endpoint()
assert result.power is True
break
def test_heater_on_route() -> None:
"""Test /on route handler."""
mock_tinytuya = MagicMock()
mock_device = MagicMock()
mock_device.set_value.return_value = None
mock_tinytuya.Device.return_value = mock_device
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
if "python.heater.controller" in sys.modules:
del sys.modules["python.heater.controller"]
if "python.heater.main" in sys.modules:
del sys.modules["python.heater.main"]
from python.heater.main import create_app
from python.heater.controller import HeaterController
config = DeviceConfig(device_id="abc", ip="10.0.0.1", local_key="key")
app = create_app(config)
app.state.controller = HeaterController(config)
for route in app.routes:
if hasattr(route, "path") and route.path == "/on":
result = route.endpoint()
assert result.success is True
break
def test_heater_off_route() -> None:
"""Test /off route handler."""
mock_tinytuya = MagicMock()
mock_device = MagicMock()
mock_device.set_value.return_value = None
mock_tinytuya.Device.return_value = mock_device
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
if "python.heater.controller" in sys.modules:
del sys.modules["python.heater.controller"]
if "python.heater.main" in sys.modules:
del sys.modules["python.heater.main"]
from python.heater.main import create_app
from python.heater.controller import HeaterController
config = DeviceConfig(device_id="abc", ip="10.0.0.1", local_key="key")
app = create_app(config)
app.state.controller = HeaterController(config)
for route in app.routes:
if hasattr(route, "path") and route.path == "/off":
result = route.endpoint()
assert result.success is True
break
def test_heater_toggle_route() -> None:
"""Test /toggle route handler."""
mock_tinytuya = MagicMock()
mock_device = MagicMock()
mock_device.status.return_value = {"dps": {"1": True}}
mock_device.set_value.return_value = None
mock_tinytuya.Device.return_value = mock_device
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
if "python.heater.controller" in sys.modules:
del sys.modules["python.heater.controller"]
if "python.heater.main" in sys.modules:
del sys.modules["python.heater.main"]
from python.heater.main import create_app
from python.heater.controller import HeaterController
config = DeviceConfig(device_id="abc", ip="10.0.0.1", local_key="key")
app = create_app(config)
app.state.controller = HeaterController(config)
for route in app.routes:
if hasattr(route, "path") and route.path == "/toggle":
result = route.endpoint()
assert result.success is True
break
def test_heater_on_route_failure() -> None:
"""Test /on route raises HTTPException on failure."""
mock_tinytuya = MagicMock()
mock_device = MagicMock()
mock_device.set_value.side_effect = ConnectionError("fail")
mock_tinytuya.Device.return_value = mock_device
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
if "python.heater.controller" in sys.modules:
del sys.modules["python.heater.controller"]
if "python.heater.main" in sys.modules:
del sys.modules["python.heater.main"]
from python.heater.main import create_app
from python.heater.controller import HeaterController
from fastapi import HTTPException
config = DeviceConfig(device_id="abc", ip="10.0.0.1", local_key="key")
app = create_app(config)
app.state.controller = HeaterController(config)
import pytest
for route in app.routes:
if hasattr(route, "path") and route.path == "/on":
with pytest.raises(HTTPException):
route.endpoint()
break

View File

@@ -1,103 +0,0 @@
"""Tests for heater/main.py serve function and lifespan."""
from __future__ import annotations
import sys
from unittest.mock import MagicMock, patch
import pytest
from click.exceptions import Exit
from python.heater.models import DeviceConfig
def test_serve_missing_params() -> None:
"""Test serve raises when device params are missing."""
mock_tinytuya = MagicMock()
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
if "python.heater.controller" in sys.modules:
del sys.modules["python.heater.controller"]
if "python.heater.main" in sys.modules:
del sys.modules["python.heater.main"]
from python.heater.main import serve
with pytest.raises(Exit):
serve(host="localhost", port=8124, log_level="INFO")
def test_serve_with_params() -> None:
"""Test serve starts uvicorn when params provided."""
mock_tinytuya = MagicMock()
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
if "python.heater.controller" in sys.modules:
del sys.modules["python.heater.controller"]
if "python.heater.main" in sys.modules:
del sys.modules["python.heater.main"]
from python.heater.main import serve
with patch("python.heater.main.uvicorn.run") as mock_run:
serve(
host="localhost",
port=8124,
log_level="INFO",
device_id="abc",
device_ip="10.0.0.1",
local_key="key123",
)
mock_run.assert_called_once()
def test_heater_off_route_failure() -> None:
"""Test /off route raises HTTPException on failure."""
mock_tinytuya = MagicMock()
mock_device = MagicMock()
mock_device.set_value.side_effect = ConnectionError("fail")
mock_tinytuya.Device.return_value = mock_device
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
if "python.heater.controller" in sys.modules:
del sys.modules["python.heater.controller"]
if "python.heater.main" in sys.modules:
del sys.modules["python.heater.main"]
from python.heater.main import create_app
from python.heater.controller import HeaterController
from fastapi import HTTPException
config = DeviceConfig(device_id="abc", ip="10.0.0.1", local_key="key")
app = create_app(config)
app.state.controller = HeaterController(config)
for route in app.routes:
if hasattr(route, "path") and route.path == "/off":
with pytest.raises(HTTPException):
route.endpoint()
break
def test_heater_toggle_route_failure() -> None:
"""Test /toggle route raises HTTPException on failure."""
mock_tinytuya = MagicMock()
mock_device = MagicMock()
# toggle calls status() first then set_value - make set_value fail
mock_device.status.return_value = {"dps": {"1": True}}
mock_device.set_value.side_effect = ConnectionError("fail")
mock_tinytuya.Device.return_value = mock_device
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
if "python.heater.controller" in sys.modules:
del sys.modules["python.heater.controller"]
if "python.heater.main" in sys.modules:
del sys.modules["python.heater.main"]
from python.heater.main import create_app
from python.heater.controller import HeaterController
from fastapi import HTTPException
config = DeviceConfig(device_id="abc", ip="10.0.0.1", local_key="key")
app = create_app(config)
app.state.controller = HeaterController(config)
for route in app.routes:
if hasattr(route, "path") and route.path == "/toggle":
with pytest.raises(HTTPException):
route.endpoint()
break

View File

@@ -1,191 +0,0 @@
"""Tests for python/installer modules."""
from __future__ import annotations
import curses
from unittest.mock import MagicMock, patch
import pytest
from python.installer.tui import (
Cursor,
State,
calculate_device_menu_padding,
get_device,
)
# --- Cursor tests ---
def test_cursor_init() -> None:
"""Test Cursor initialization."""
c = Cursor()
assert c.get_x() == 0
assert c.get_y() == 0
assert c.height == 0
assert c.width == 0
def test_cursor_set_height_width() -> None:
"""Test Cursor set_height and set_width."""
c = Cursor()
c.set_height(100)
c.set_width(200)
assert c.height == 100
assert c.width == 200
def test_cursor_bounce_check() -> None:
"""Test Cursor bounce checks."""
c = Cursor()
c.set_height(10)
c.set_width(20)
assert c.x_bounce_check(-1) == 0
assert c.x_bounce_check(25) == 19
assert c.x_bounce_check(5) == 5
assert c.y_bounce_check(-1) == 0
assert c.y_bounce_check(15) == 9
assert c.y_bounce_check(5) == 5
def test_cursor_set_x_y() -> None:
"""Test Cursor set_x and set_y."""
c = Cursor()
c.set_height(10)
c.set_width(20)
c.set_x(5)
c.set_y(3)
assert c.get_x() == 5
assert c.get_y() == 3
def test_cursor_set_x_y_bounds() -> None:
"""Test Cursor set_x and set_y with bounds."""
c = Cursor()
c.set_height(10)
c.set_width(20)
c.set_x(-5)
assert c.get_x() == 0
c.set_y(100)
assert c.get_y() == 9
def test_cursor_move_up() -> None:
"""Test Cursor move_up."""
c = Cursor()
c.set_height(10)
c.set_width(20)
c.set_y(5)
c.move_up()
assert c.get_y() == 4
def test_cursor_move_down() -> None:
"""Test Cursor move_down."""
c = Cursor()
c.set_height(10)
c.set_width(20)
c.set_y(5)
c.move_down()
assert c.get_y() == 6
def test_cursor_move_left() -> None:
"""Test Cursor move_left."""
c = Cursor()
c.set_height(10)
c.set_width(20)
c.set_x(5)
c.move_left()
assert c.get_x() == 4
def test_cursor_move_right() -> None:
"""Test Cursor move_right."""
c = Cursor()
c.set_height(10)
c.set_width(20)
c.set_x(5)
c.move_right()
assert c.get_x() == 6
def test_cursor_navigation() -> None:
"""Test Cursor navigation with arrow keys."""
c = Cursor()
c.set_height(10)
c.set_width(20)
c.set_x(5)
c.set_y(5)
c.navigation(curses.KEY_UP)
assert c.get_y() == 4
c.navigation(curses.KEY_DOWN)
assert c.get_y() == 5
c.navigation(curses.KEY_LEFT)
assert c.get_x() == 4
c.navigation(curses.KEY_RIGHT)
assert c.get_x() == 5
def test_cursor_navigation_unknown_key() -> None:
"""Test Cursor navigation with unknown key (no-op)."""
c = Cursor()
c.set_height(10)
c.set_width(20)
c.set_x(5)
c.set_y(5)
c.navigation(999) # Unknown key
assert c.get_x() == 5
assert c.get_y() == 5
# --- State tests ---
def test_state_init() -> None:
"""Test State initialization."""
s = State()
assert s.key == 0
assert s.swap_size == 0
assert s.reserve_size == 0
assert s.selected_device_ids == set()
assert s.show_swap_input is False
assert s.show_reserve_input is False
def test_state_get_selected_devices() -> None:
"""Test State.get_selected_devices."""
s = State()
s.selected_device_ids = {"/dev/sda", "/dev/sdb"}
result = s.get_selected_devices()
assert isinstance(result, tuple)
assert set(result) == {"/dev/sda", "/dev/sdb"}
# --- get_device tests ---
def test_get_device() -> None:
"""Test get_device parses device string."""
raw = 'NAME="/dev/sda" SIZE="100G" TYPE="disk" MOUNTPOINTS=""'
device = get_device(raw)
assert device["name"] == "/dev/sda"
assert device["size"] == "100G"
assert device["type"] == "disk"
# --- calculate_device_menu_padding ---
def test_calculate_device_menu_padding() -> None:
"""Test calculate_device_menu_padding."""
devices = [
{"name": "/dev/sda", "size": "100G"},
{"name": "/dev/nvme0n1", "size": "500G"},
]
padding = calculate_device_menu_padding(devices, "name", 2)
assert padding == len("/dev/nvme0n1") + 2

View File

@@ -1,168 +0,0 @@
"""Extended tests for python/installer modules."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from python.installer.__main__ import (
bash_wrapper,
create_zfs_pool,
get_cpu_manufacturer,
partition_disk,
)
from python.installer.tui import (
Cursor,
State,
bash_wrapper as tui_bash_wrapper,
get_device,
calculate_device_menu_padding,
)
# --- installer __main__ tests ---
def test_installer_bash_wrapper_success() -> None:
"""Test installer bash_wrapper on success."""
result = bash_wrapper("echo hello")
assert result.strip() == "hello"
def test_installer_bash_wrapper_error() -> None:
"""Test installer bash_wrapper raises on error."""
with pytest.raises(RuntimeError, match="Failed to run command"):
bash_wrapper("ls /nonexistent/path/that/does/not/exist")
def test_partition_disk() -> None:
"""Test partition_disk calls commands correctly."""
with patch("python.installer.__main__.bash_wrapper") as mock_bash:
partition_disk("/dev/sda", swap_size=8, reserve=0)
assert mock_bash.call_count == 2
def test_partition_disk_with_reserve() -> None:
"""Test partition_disk with reserve space."""
with patch("python.installer.__main__.bash_wrapper") as mock_bash:
partition_disk("/dev/sda", swap_size=8, reserve=10)
assert mock_bash.call_count == 2
def test_partition_disk_minimum_swap() -> None:
"""Test partition_disk enforces minimum swap size."""
with patch("python.installer.__main__.bash_wrapper") as mock_bash:
partition_disk("/dev/sda", swap_size=0, reserve=-1)
# swap_size should be clamped to 1, reserve to 0
assert mock_bash.call_count == 2
def test_create_zfs_pool_single_disk() -> None:
"""Test create_zfs_pool with single disk."""
with patch("python.installer.__main__.bash_wrapper") as mock_bash:
mock_bash.return_value = "NAME\nroot_pool\n"
create_zfs_pool(["/dev/sda-part2"], "/mnt")
assert mock_bash.call_count == 2
def test_create_zfs_pool_mirror() -> None:
"""Test create_zfs_pool with mirror disks."""
with patch("python.installer.__main__.bash_wrapper") as mock_bash:
mock_bash.return_value = "NAME\nroot_pool\n"
create_zfs_pool(["/dev/sda-part2", "/dev/sdb-part2"], "/mnt")
assert mock_bash.call_count == 2
def test_create_zfs_pool_no_disks() -> None:
"""Test create_zfs_pool raises with no disks."""
with pytest.raises(ValueError, match="disks must be a tuple"):
create_zfs_pool([], "/mnt")
def test_get_cpu_manufacturer_amd() -> None:
"""Test get_cpu_manufacturer with AMD CPU."""
output = "vendor_id\t: AuthenticAMD\nmodel name\t: AMD Ryzen 9\n"
with patch("python.installer.__main__.bash_wrapper", return_value=output):
assert get_cpu_manufacturer() == "amd"
def test_get_cpu_manufacturer_intel() -> None:
"""Test get_cpu_manufacturer with Intel CPU."""
output = "vendor_id\t: GenuineIntel\nmodel name\t: Intel Core i9\n"
with patch("python.installer.__main__.bash_wrapper", return_value=output):
assert get_cpu_manufacturer() == "intel"
def test_get_cpu_manufacturer_unknown() -> None:
"""Test get_cpu_manufacturer with unknown CPU raises."""
output = "model name\t: Unknown CPU\n"
with (
patch("python.installer.__main__.bash_wrapper", return_value=output),
pytest.raises(RuntimeError, match="Failed to get CPU manufacturer"),
):
get_cpu_manufacturer()
# --- tui bash_wrapper tests ---
def test_tui_bash_wrapper_success() -> None:
"""Test tui bash_wrapper success."""
result = tui_bash_wrapper("echo hello")
assert result.strip() == "hello"
def test_tui_bash_wrapper_error() -> None:
"""Test tui bash_wrapper raises on error."""
with pytest.raises(RuntimeError, match="Failed to run command"):
tui_bash_wrapper("ls /nonexistent/path/that/does/not/exist")
# --- Cursor boundary tests ---
def test_cursor_move_at_boundaries() -> None:
"""Test cursor doesn't go below 0."""
c = Cursor()
c.set_height(10)
c.set_width(20)
c.set_x(0)
c.set_y(0)
c.move_up()
assert c.get_y() == 0
c.move_left()
assert c.get_x() == 0
def test_cursor_move_at_max_boundaries() -> None:
"""Test cursor doesn't exceed max."""
c = Cursor()
c.set_height(5)
c.set_width(10)
c.set_x(9)
c.set_y(4)
c.move_down()
assert c.get_y() == 4
c.move_right()
assert c.get_x() == 9
# --- get_device additional ---
def test_get_device_with_mountpoint() -> None:
"""Test get_device with mountpoint."""
raw = 'NAME="/dev/sda1" SIZE="512M" TYPE="part" MOUNTPOINTS="/boot"'
device = get_device(raw)
assert device["mountpoints"] == "/boot"
# --- State additional ---
def test_state_selected_devices_empty() -> None:
"""Test State get_selected_devices when empty."""
s = State()
result = s.get_selected_devices()
assert result == ()

View File

@@ -1,50 +0,0 @@
"""Extended tests for python/installer/__main__.py."""
from __future__ import annotations
import sys
from unittest.mock import MagicMock, patch
import pytest
from python.installer.__main__ import (
create_zfs_datasets,
create_zfs_pool,
get_boot_drive_id,
partition_disk,
)
def test_create_zfs_datasets() -> None:
"""Test create_zfs_datasets creates expected datasets."""
with patch("python.installer.__main__.bash_wrapper") as mock_bash:
mock_bash.return_value = "NAME\nroot_pool\nroot_pool/root\nroot_pool/home\nroot_pool/var\nroot_pool/nix\n"
create_zfs_datasets()
assert mock_bash.call_count == 5 # 4 create + 1 list
def test_create_zfs_datasets_missing(monkeypatch: pytest.MonkeyPatch) -> None:
"""Test create_zfs_datasets exits on missing datasets."""
with (
patch("python.installer.__main__.bash_wrapper") as mock_bash,
pytest.raises(SystemExit),
):
mock_bash.return_value = "NAME\nroot_pool\n"
create_zfs_datasets()
def test_create_zfs_pool_failure(monkeypatch: pytest.MonkeyPatch) -> None:
"""Test create_zfs_pool exits on failure."""
with (
patch("python.installer.__main__.bash_wrapper") as mock_bash,
pytest.raises(SystemExit),
):
mock_bash.return_value = "NAME\n"
create_zfs_pool(["/dev/sda-part2"], "/mnt")
def test_get_boot_drive_id() -> None:
"""Test get_boot_drive_id extracts UUID."""
with patch("python.installer.__main__.bash_wrapper", return_value="UUID\nABCD-1234\n"):
result = get_boot_drive_id("/dev/sda")
assert result == "ABCD-1234"

View File

@@ -1,312 +0,0 @@
"""Additional tests for python/installer/__main__.py covering missing lines."""
from __future__ import annotations
from unittest.mock import MagicMock, call, patch
import pytest
from python.installer.__main__ import (
create_nix_hardware_file,
install_nixos,
installer,
main,
)
# --- create_nix_hardware_file (lines 167-218) ---
def test_create_nix_hardware_file_no_encrypt() -> None:
"""Test create_nix_hardware_file without encryption."""
with (
patch("python.installer.__main__.get_cpu_manufacturer", return_value="amd"),
patch("python.installer.__main__.get_boot_drive_id", return_value="ABCD-1234"),
patch("python.installer.__main__.getrandbits", return_value=0xDEADBEEF),
patch("python.installer.__main__.Path") as mock_path,
):
create_nix_hardware_file("/mnt", ["/dev/sda"], encrypt=None)
mock_path.assert_called_once_with("/mnt/etc/nixos/hardware-configuration.nix")
written_content = mock_path.return_value.write_text.call_args[0][0]
assert "kvm-amd" in written_content
assert "ABCD-1234" in written_content
assert "deadbeef" in written_content
assert "luks" not in written_content
def test_create_nix_hardware_file_with_encrypt() -> None:
"""Test create_nix_hardware_file with encryption enabled."""
with (
patch("python.installer.__main__.get_cpu_manufacturer", return_value="intel"),
patch("python.installer.__main__.get_boot_drive_id", return_value="EFGH-5678"),
patch("python.installer.__main__.getrandbits", return_value=0x12345678),
patch("python.installer.__main__.Path") as mock_path,
):
create_nix_hardware_file("/mnt", ["/dev/sda"], encrypt="mykey")
written_content = mock_path.return_value.write_text.call_args[0][0]
assert "kvm-intel" in written_content
assert "EFGH-5678" in written_content
assert "12345678" in written_content
assert "luks" in written_content
assert "luks-root-pool-sda-part2" in written_content
assert "bypassWorkqueues" in written_content
assert "allowDiscards" in written_content
def test_create_nix_hardware_file_content_structure() -> None:
"""Test create_nix_hardware_file generates correct Nix structure."""
with (
patch("python.installer.__main__.get_cpu_manufacturer", return_value="amd"),
patch("python.installer.__main__.get_boot_drive_id", return_value="UUID-1234"),
patch("python.installer.__main__.getrandbits", return_value=0xAABBCCDD),
patch("python.installer.__main__.Path") as mock_path,
):
create_nix_hardware_file("/mnt", ["/dev/sda"], encrypt=None)
written_content = mock_path.return_value.write_text.call_args[0][0]
assert "{ config, lib, modulesPath, ... }:" in written_content
assert "boot =" in written_content
assert "fileSystems" in written_content
assert "root_pool/root" in written_content
assert "root_pool/home" in written_content
assert "root_pool/var" in written_content
assert "root_pool/nix" in written_content
assert "networking.hostId" in written_content
assert "x86_64-linux" in written_content
# --- install_nixos (lines 221-241) ---
def test_install_nixos_single_disk() -> None:
"""Test install_nixos mounts filesystems and runs nixos-install."""
with (
patch("python.installer.__main__.bash_wrapper") as mock_bash,
patch("python.installer.__main__.run") as mock_run,
patch("python.installer.__main__.create_nix_hardware_file") as mock_hw,
):
install_nixos("/mnt", ["/dev/sda"], encrypt=None)
# 4 mount commands + 1 mkfs.vfat + 1 boot mount + 1 nixos-generate-config = 7 bash_wrapper calls
assert mock_bash.call_count == 7
mock_hw.assert_called_once_with("/mnt", ["/dev/sda"], None)
mock_run.assert_called_once_with(("nixos-install", "--root", "/mnt"), check=True)
def test_install_nixos_multiple_disks() -> None:
"""Test install_nixos formats all disk EFI partitions."""
with (
patch("python.installer.__main__.bash_wrapper") as mock_bash,
patch("python.installer.__main__.run") as mock_run,
patch("python.installer.__main__.create_nix_hardware_file") as mock_hw,
):
install_nixos("/mnt", ["/dev/sda", "/dev/sdb"], encrypt="key")
# 4 mount + 2 mkfs.vfat + 1 boot mount + 1 generate-config = 8
assert mock_bash.call_count == 8
# Check mkfs.vfat called for both disks
bash_calls = [str(c) for c in mock_bash.call_args_list]
assert any("mkfs.vfat" in c and "sda" in c for c in bash_calls)
assert any("mkfs.vfat" in c and "sdb" in c for c in bash_calls)
mock_hw.assert_called_once_with("/mnt", ["/dev/sda", "/dev/sdb"], "key")
def test_install_nixos_mounts_zfs_datasets() -> None:
"""Test install_nixos mounts all required ZFS datasets."""
with (
patch("python.installer.__main__.bash_wrapper") as mock_bash,
patch("python.installer.__main__.run"),
patch("python.installer.__main__.create_nix_hardware_file"),
):
install_nixos("/mnt", ["/dev/sda"], encrypt=None)
bash_calls = [str(c) for c in mock_bash.call_args_list]
assert any("root_pool/root" in c for c in bash_calls)
assert any("root_pool/home" in c for c in bash_calls)
assert any("root_pool/var" in c for c in bash_calls)
assert any("root_pool/nix" in c for c in bash_calls)
# --- installer (lines 244-280) ---
def test_installer_no_encrypt() -> None:
"""Test installer flow without encryption."""
with (
patch("python.installer.__main__.partition_disk") as mock_partition,
patch("python.installer.__main__.Popen") as mock_popen,
patch("python.installer.__main__.Path") as mock_path,
patch("python.installer.__main__.create_zfs_pool") as mock_pool,
patch("python.installer.__main__.create_zfs_datasets") as mock_datasets,
patch("python.installer.__main__.install_nixos") as mock_install,
):
installer(
disks=("/dev/sda",),
swap_size=8,
reserve=0,
encrypt_key=None,
)
mock_partition.assert_called_once_with("/dev/sda", 8, 0)
mock_pool.assert_called_once_with(["/dev/sda-part2"], "/tmp/nix_install")
mock_datasets.assert_called_once()
mock_install.assert_called_once_with("/tmp/nix_install", ("/dev/sda",), None)
def test_installer_with_encrypt() -> None:
"""Test installer flow with encryption enabled."""
with (
patch("python.installer.__main__.partition_disk") as mock_partition,
patch("python.installer.__main__.Popen") as mock_popen,
patch("python.installer.__main__.sleep") as mock_sleep,
patch("python.installer.__main__.run") as mock_run,
patch("python.installer.__main__.Path") as mock_path,
patch("python.installer.__main__.create_zfs_pool") as mock_pool,
patch("python.installer.__main__.create_zfs_datasets") as mock_datasets,
patch("python.installer.__main__.install_nixos") as mock_install,
):
installer(
disks=("/dev/sda",),
swap_size=8,
reserve=10,
encrypt_key="secret",
)
mock_partition.assert_called_once_with("/dev/sda", 8, 10)
mock_sleep.assert_called_once_with(1)
# cryptsetup luksFormat and luksOpen
assert mock_run.call_count == 2
mock_pool.assert_called_once_with(
["/dev/mapper/luks-root-pool-sda-part2"],
"/tmp/nix_install",
)
mock_datasets.assert_called_once()
mock_install.assert_called_once_with("/tmp/nix_install", ("/dev/sda",), "secret")
def test_installer_multiple_disks_no_encrypt() -> None:
"""Test installer with multiple disks and no encryption."""
with (
patch("python.installer.__main__.partition_disk") as mock_partition,
patch("python.installer.__main__.Popen") as mock_popen,
patch("python.installer.__main__.Path") as mock_path,
patch("python.installer.__main__.create_zfs_pool") as mock_pool,
patch("python.installer.__main__.create_zfs_datasets") as mock_datasets,
patch("python.installer.__main__.install_nixos") as mock_install,
):
installer(
disks=("/dev/sda", "/dev/sdb"),
swap_size=4,
reserve=0,
encrypt_key=None,
)
assert mock_partition.call_count == 2
mock_pool.assert_called_once_with(
["/dev/sda-part2", "/dev/sdb-part2"],
"/tmp/nix_install",
)
def test_installer_multiple_disks_with_encrypt() -> None:
"""Test installer with multiple disks and encryption."""
with (
patch("python.installer.__main__.partition_disk") as mock_partition,
patch("python.installer.__main__.Popen") as mock_popen,
patch("python.installer.__main__.sleep") as mock_sleep,
patch("python.installer.__main__.run") as mock_run,
patch("python.installer.__main__.Path") as mock_path,
patch("python.installer.__main__.create_zfs_pool") as mock_pool,
patch("python.installer.__main__.create_zfs_datasets") as mock_datasets,
patch("python.installer.__main__.install_nixos") as mock_install,
):
installer(
disks=("/dev/sda", "/dev/sdb"),
swap_size=4,
reserve=2,
encrypt_key="key123",
)
assert mock_partition.call_count == 2
assert mock_sleep.call_count == 2
# 2 disks x 2 cryptsetup commands = 4
assert mock_run.call_count == 4
mock_pool.assert_called_once_with(
["/dev/mapper/luks-root-pool-sda-part2", "/dev/mapper/luks-root-pool-sdb-part2"],
"/tmp/nix_install",
)
# --- main (lines 283-299) ---
def test_main_calls_installer() -> None:
"""Test main function orchestrates TUI and installer."""
mock_state = MagicMock()
mock_state.selected_device_ids = {"/dev/disk/by-id/ata-DISK1"}
mock_state.get_selected_devices.return_value = ("/dev/disk/by-id/ata-DISK1",)
mock_state.swap_size = 8
mock_state.reserve_size = 0
with (
patch("python.installer.__main__.configure_logger"),
patch("python.installer.__main__.curses.wrapper", return_value=mock_state),
patch("python.installer.__main__.getenv", return_value=None),
patch("python.installer.__main__.sleep"),
patch("python.installer.__main__.installer") as mock_installer,
):
main()
mock_installer.assert_called_once_with(
disks=("/dev/disk/by-id/ata-DISK1",),
swap_size=8,
reserve=0,
encrypt_key=None,
)
def test_main_with_encrypt_key() -> None:
"""Test main function passes encrypt key from environment."""
mock_state = MagicMock()
mock_state.selected_device_ids = {"/dev/disk/by-id/ata-DISK1"}
mock_state.get_selected_devices.return_value = ("/dev/disk/by-id/ata-DISK1",)
mock_state.swap_size = 16
mock_state.reserve_size = 5
with (
patch("python.installer.__main__.configure_logger"),
patch("python.installer.__main__.curses.wrapper", return_value=mock_state),
patch("python.installer.__main__.getenv", return_value="my_encrypt_key"),
patch("python.installer.__main__.sleep"),
patch("python.installer.__main__.installer") as mock_installer,
):
main()
mock_installer.assert_called_once_with(
disks=("/dev/disk/by-id/ata-DISK1",),
swap_size=16,
reserve=5,
encrypt_key="my_encrypt_key",
)
def test_main_calls_sleep() -> None:
"""Test main function sleeps for 3 seconds before installing."""
mock_state = MagicMock()
mock_state.selected_device_ids = set()
mock_state.get_selected_devices.return_value = ()
mock_state.swap_size = 0
mock_state.reserve_size = 0
with (
patch("python.installer.__main__.configure_logger"),
patch("python.installer.__main__.curses.wrapper", return_value=mock_state),
patch("python.installer.__main__.getenv", return_value=None),
patch("python.installer.__main__.sleep") as mock_sleep,
patch("python.installer.__main__.installer"),
):
main()
mock_sleep.assert_called_once_with(3)

View File

@@ -1,70 +0,0 @@
"""Extended tests for python/installer/tui.py."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
from python.installer.tui import (
Cursor,
State,
bash_wrapper,
calculate_device_menu_padding,
get_device,
get_devices,
status_bar,
)
def test_get_devices() -> None:
"""Test get_devices parses lsblk output."""
mock_output = (
'NAME="/dev/sda" SIZE="100G" TYPE="disk" MOUNTPOINTS=""\n'
'NAME="/dev/sda1" SIZE="512M" TYPE="part" MOUNTPOINTS="/boot"\n'
)
with patch("python.installer.tui.bash_wrapper", return_value=mock_output):
devices = get_devices()
assert len(devices) == 2
assert devices[0]["name"] == "/dev/sda"
assert devices[1]["name"] == "/dev/sda1"
def test_calculate_device_menu_padding_with_padding() -> None:
"""Test calculate_device_menu_padding with custom padding."""
devices = [
{"name": "abc", "size": "100G"},
{"name": "abcdef", "size": "500G"},
]
result = calculate_device_menu_padding(devices, "name", 5)
assert result == len("abcdef") + 5
def test_calculate_device_menu_padding_zero() -> None:
"""Test calculate_device_menu_padding with zero padding."""
devices = [{"name": "abc"}]
result = calculate_device_menu_padding(devices, "name", 0)
assert result == 3
def test_status_bar() -> None:
"""Test status_bar renders without error."""
import curses as _curses
mock_screen = MagicMock()
cursor = Cursor()
cursor.set_height(50)
cursor.set_width(100)
cursor.set_x(5)
cursor.set_y(10)
with patch.object(_curses, "color_pair", return_value=0), patch.object(_curses, "A_REVERSE", 0):
status_bar(mock_screen, cursor, 100, 50)
assert mock_screen.addstr.call_count > 0
def test_get_device_various_formats() -> None:
"""Test get_device with different formats."""
raw = 'NAME="/dev/nvme0n1p1" SIZE="1T" TYPE="nvme" MOUNTPOINTS="/"'
device = get_device(raw)
assert device["name"] == "/dev/nvme0n1p1"
assert device["size"] == "1T"
assert device["type"] == "nvme"
assert device["mountpoints"] == "/"

View File

@@ -1,515 +0,0 @@
"""Additional tests for python/installer/tui.py covering missing lines."""
from __future__ import annotations
import curses
from unittest.mock import MagicMock, call, patch
from python.installer.tui import (
State,
debug_menu,
draw_device_ids,
draw_device_menu,
draw_menu,
get_device_id_mapping,
get_text_input,
reserve_size_input,
set_color,
swap_size_input,
)
# --- set_color (lines 153-156) ---
def test_set_color() -> None:
"""Test set_color initializes curses colors."""
with (
patch("python.installer.tui.curses.start_color") as mock_start,
patch("python.installer.tui.curses.use_default_colors") as mock_defaults,
patch("python.installer.tui.curses.init_pair") as mock_init_pair,
patch.object(curses, "COLORS", 8, create=True),
):
set_color()
mock_start.assert_called_once()
mock_defaults.assert_called_once()
assert mock_init_pair.call_count == 8
mock_init_pair.assert_any_call(1, 0, -1)
mock_init_pair.assert_any_call(8, 7, -1)
# --- debug_menu (lines 166-175) ---
def test_debug_menu_with_key_pressed() -> None:
"""Test debug_menu when a key has been pressed."""
mock_screen = MagicMock()
mock_screen.getmaxyx.return_value = (40, 80)
with patch("python.installer.tui.curses.color_pair", return_value=0):
debug_menu(mock_screen, ord("a"))
# Should show width/height, key pressed, and color blocks
assert mock_screen.addstr.call_count >= 3
def test_debug_menu_no_key_pressed() -> None:
"""Test debug_menu when no key has been pressed (key=0)."""
mock_screen = MagicMock()
mock_screen.getmaxyx.return_value = (40, 80)
with patch("python.installer.tui.curses.color_pair", return_value=0):
debug_menu(mock_screen, 0)
# Check that "No key press detected..." is displayed
calls = [str(c) for c in mock_screen.addstr.call_args_list]
assert any("No key press detected" in c for c in calls)
# --- get_text_input (lines 190-208) ---
def test_get_text_input_enter_key() -> None:
"""Test get_text_input returns input when Enter is pressed."""
mock_screen = MagicMock()
mock_screen.getch.side_effect = [ord("h"), ord("i"), ord("\n")]
with patch("python.installer.tui.curses.echo"), patch("python.installer.tui.curses.noecho"):
result = get_text_input(mock_screen, "Prompt: ", 5, 0)
assert result == "hi"
def test_get_text_input_escape_key() -> None:
"""Test get_text_input returns empty string when Escape is pressed."""
mock_screen = MagicMock()
mock_screen.getch.side_effect = [ord("h"), ord("i"), 27] # 27 = ESC
with patch("python.installer.tui.curses.echo"), patch("python.installer.tui.curses.noecho"):
result = get_text_input(mock_screen, "Prompt: ", 5, 0)
assert result == ""
def test_get_text_input_backspace() -> None:
"""Test get_text_input handles backspace correctly."""
mock_screen = MagicMock()
mock_screen.getch.side_effect = [ord("h"), ord("i"), 127, ord("\n")] # 127 = backspace
with patch("python.installer.tui.curses.echo"), patch("python.installer.tui.curses.noecho"):
result = get_text_input(mock_screen, "Prompt: ", 5, 0)
assert result == "h"
def test_get_text_input_curses_backspace() -> None:
"""Test get_text_input handles curses KEY_BACKSPACE."""
mock_screen = MagicMock()
mock_screen.getch.side_effect = [ord("a"), ord("b"), curses.KEY_BACKSPACE, ord("\n")]
with patch("python.installer.tui.curses.echo"), patch("python.installer.tui.curses.noecho"):
result = get_text_input(mock_screen, "Prompt: ", 5, 0)
assert result == "a"
# --- swap_size_input (lines 226-241) ---
def test_swap_size_input_no_trigger() -> None:
"""Test swap_size_input when not triggered (no enter on swap row)."""
mock_screen = MagicMock()
state = State()
state.key = ord("a")
result = swap_size_input(mock_screen, state, swap_offset=5)
assert result.swap_size == 0
assert result.show_swap_input is False
def test_swap_size_input_enter_triggers_input() -> None:
"""Test swap_size_input when Enter is pressed on the swap row."""
mock_screen = MagicMock()
state = State()
state.cursor.set_height(20)
state.cursor.set_width(80)
state.cursor.set_y(5)
state.key = ord("\n")
with patch("python.installer.tui.get_text_input", return_value="16"):
result = swap_size_input(mock_screen, state, swap_offset=5)
assert result.swap_size == 16
assert result.show_swap_input is False
def test_swap_size_input_invalid_value() -> None:
"""Test swap_size_input with invalid (non-integer) input."""
mock_screen = MagicMock()
state = State()
state.cursor.set_height(20)
state.cursor.set_width(80)
state.cursor.set_y(5)
state.key = ord("\n")
with patch("python.installer.tui.get_text_input", return_value="abc"):
result = swap_size_input(mock_screen, state, swap_offset=5)
assert result.swap_size == 0
assert result.show_swap_input is False
# Should have shown "Invalid input" message and waited for a key
mock_screen.getch.assert_called_once()
def test_swap_size_input_already_showing() -> None:
"""Test swap_size_input when show_swap_input is already True."""
mock_screen = MagicMock()
state = State()
state.show_swap_input = True
state.key = 0
with patch("python.installer.tui.get_text_input", return_value="8"):
result = swap_size_input(mock_screen, state, swap_offset=5)
assert result.swap_size == 8
assert result.show_swap_input is False
# --- reserve_size_input (lines 259-274) ---
def test_reserve_size_input_no_trigger() -> None:
"""Test reserve_size_input when not triggered."""
mock_screen = MagicMock()
state = State()
state.key = ord("a")
result = reserve_size_input(mock_screen, state, reserve_offset=6)
assert result.reserve_size == 0
assert result.show_reserve_input is False
def test_reserve_size_input_enter_triggers_input() -> None:
"""Test reserve_size_input when Enter is pressed on the reserve row."""
mock_screen = MagicMock()
state = State()
state.cursor.set_height(20)
state.cursor.set_width(80)
state.cursor.set_y(6)
state.key = ord("\n")
with patch("python.installer.tui.get_text_input", return_value="32"):
result = reserve_size_input(mock_screen, state, reserve_offset=6)
assert result.reserve_size == 32
assert result.show_reserve_input is False
def test_reserve_size_input_invalid_value() -> None:
"""Test reserve_size_input with invalid (non-integer) input."""
mock_screen = MagicMock()
state = State()
state.cursor.set_height(20)
state.cursor.set_width(80)
state.cursor.set_y(6)
state.key = ord("\n")
with patch("python.installer.tui.get_text_input", return_value="xyz"):
result = reserve_size_input(mock_screen, state, reserve_offset=6)
assert result.reserve_size == 0
assert result.show_reserve_input is False
mock_screen.getch.assert_called_once()
def test_reserve_size_input_already_showing() -> None:
"""Test reserve_size_input when show_reserve_input is already True."""
mock_screen = MagicMock()
state = State()
state.show_reserve_input = True
state.key = 0
with patch("python.installer.tui.get_text_input", return_value="10"):
result = reserve_size_input(mock_screen, state, reserve_offset=6)
assert result.reserve_size == 10
assert result.show_reserve_input is False
# --- get_device_id_mapping (lines 308-316) ---
def test_get_device_id_mapping() -> None:
"""Test get_device_id_mapping returns correct mapping."""
find_output = "/dev/disk/by-id/ata-DISK1\n/dev/disk/by-id/ata-DISK2\n"
def mock_bash(cmd: str) -> str:
if cmd.startswith("find"):
return find_output
if "ata-DISK1" in cmd:
return "/dev/sda\n"
if "ata-DISK2" in cmd:
return "/dev/sda\n"
return ""
with patch("python.installer.tui.bash_wrapper", side_effect=mock_bash):
result = get_device_id_mapping()
assert "/dev/sda" in result
assert "/dev/disk/by-id/ata-DISK1" in result["/dev/sda"]
assert "/dev/disk/by-id/ata-DISK2" in result["/dev/sda"]
def test_get_device_id_mapping_multiple_devices() -> None:
"""Test get_device_id_mapping with multiple different devices."""
find_output = "/dev/disk/by-id/ata-DISK1\n/dev/disk/by-id/nvme-DISK2\n"
def mock_bash(cmd: str) -> str:
if cmd.startswith("find"):
return find_output
if "ata-DISK1" in cmd:
return "/dev/sda\n"
if "nvme-DISK2" in cmd:
return "/dev/nvme0n1\n"
return ""
with patch("python.installer.tui.bash_wrapper", side_effect=mock_bash):
result = get_device_id_mapping()
assert "/dev/sda" in result
assert "/dev/nvme0n1" in result
# --- draw_device_ids (lines 354-372) ---
def test_draw_device_ids_no_selection() -> None:
"""Test draw_device_ids without selecting any device."""
mock_screen = MagicMock()
state = State()
state.cursor.set_height(40)
state.cursor.set_width(80)
state.key = 0
device_ids = {"/dev/disk/by-id/ata-DISK1"}
menu_width = list(range(0, 60))
with (
patch("python.installer.tui.curses.A_BOLD", 1),
patch("python.installer.tui.curses.color_pair", return_value=0),
):
result_state, result_row = draw_device_ids(state, 2, 0, mock_screen, menu_width, device_ids)
assert result_row == 3
assert len(result_state.selected_device_ids) == 0
def test_draw_device_ids_select_device() -> None:
"""Test draw_device_ids selecting a device with space key."""
mock_screen = MagicMock()
state = State()
state.cursor.set_height(40)
state.cursor.set_width(80)
state.cursor.set_y(3)
state.cursor.set_x(0)
state.key = ord(" ")
device_ids = {"/dev/disk/by-id/ata-DISK1"}
menu_width = list(range(0, 60))
with (
patch("python.installer.tui.curses.A_BOLD", 1),
patch("python.installer.tui.curses.color_pair", return_value=0),
):
result_state, result_row = draw_device_ids(state, 2, 0, mock_screen, menu_width, device_ids)
assert "/dev/disk/by-id/ata-DISK1" in result_state.selected_device_ids
def test_draw_device_ids_deselect_device() -> None:
"""Test draw_device_ids deselecting an already selected device."""
mock_screen = MagicMock()
state = State()
state.cursor.set_height(40)
state.cursor.set_width(80)
state.cursor.set_y(3)
state.cursor.set_x(0)
state.key = ord(" ")
state.selected_device_ids.add("/dev/disk/by-id/ata-DISK1")
device_ids = {"/dev/disk/by-id/ata-DISK1"}
menu_width = list(range(0, 60))
with (
patch("python.installer.tui.curses.A_BOLD", 1),
patch("python.installer.tui.curses.color_pair", return_value=0),
):
result_state, _ = draw_device_ids(state, 2, 0, mock_screen, menu_width, device_ids)
assert "/dev/disk/by-id/ata-DISK1" not in result_state.selected_device_ids
def test_draw_device_ids_selected_device_color() -> None:
"""Test draw_device_ids applies color to already selected devices."""
mock_screen = MagicMock()
state = State()
state.cursor.set_height(40)
state.cursor.set_width(80)
state.key = 0
state.selected_device_ids.add("/dev/disk/by-id/ata-DISK1")
device_ids = {"/dev/disk/by-id/ata-DISK1"}
menu_width = list(range(0, 60))
with (
patch("python.installer.tui.curses.A_BOLD", 1),
patch("python.installer.tui.curses.color_pair", return_value=7) as mock_color,
):
draw_device_ids(state, 2, 0, mock_screen, menu_width, device_ids)
mock_screen.attron.assert_any_call(7)
# --- draw_device_menu (lines 396-434) ---
def test_draw_device_menu() -> None:
"""Test draw_device_menu renders devices and calls draw_device_ids."""
mock_screen = MagicMock()
state = State()
state.cursor.set_height(40)
state.cursor.set_width(80)
state.key = 0
devices = [
{"name": "/dev/sda", "size": "100G", "type": "disk", "mountpoints": ""},
]
device_id_mapping = {
"/dev/sda": {"/dev/disk/by-id/ata-DISK1"},
}
with (
patch("python.installer.tui.curses.color_pair", return_value=0),
patch("python.installer.tui.curses.A_BOLD", 1),
):
result_state, row_number = draw_device_menu(
mock_screen, devices, device_id_mapping, state, menu_start_y=0, menu_start_x=0
)
assert mock_screen.addstr.call_count > 0
assert row_number > 0
def test_draw_device_menu_multiple_devices() -> None:
"""Test draw_device_menu with multiple devices."""
mock_screen = MagicMock()
state = State()
state.cursor.set_height(40)
state.cursor.set_width(80)
state.key = 0
devices = [
{"name": "/dev/sda", "size": "100G", "type": "disk", "mountpoints": ""},
{"name": "/dev/sdb", "size": "200G", "type": "disk", "mountpoints": ""},
]
device_id_mapping = {
"/dev/sda": {"/dev/disk/by-id/ata-DISK1"},
"/dev/sdb": {"/dev/disk/by-id/ata-DISK2"},
}
with (
patch("python.installer.tui.curses.color_pair", return_value=0),
patch("python.installer.tui.curses.A_BOLD", 1),
):
result_state, row_number = draw_device_menu(
mock_screen, devices, device_id_mapping, state, menu_start_y=0, menu_start_x=0
)
# 2 devices + 2 device ids = at least 4 rows past the header
assert row_number >= 4
def test_draw_device_menu_no_device_ids() -> None:
"""Test draw_device_menu when a device has no IDs."""
mock_screen = MagicMock()
state = State()
state.cursor.set_height(40)
state.cursor.set_width(80)
state.key = 0
devices = [
{"name": "/dev/sda", "size": "100G", "type": "disk", "mountpoints": ""},
]
device_id_mapping: dict[str, set[str]] = {
"/dev/sda": set(),
}
with (
patch("python.installer.tui.curses.color_pair", return_value=0),
patch("python.installer.tui.curses.A_BOLD", 1),
):
result_state, row_number = draw_device_menu(
mock_screen, devices, device_id_mapping, state, menu_start_y=0, menu_start_x=0
)
# Should still work; row_number reflects only the device row (no id rows)
assert row_number >= 2
# --- draw_menu (lines 447-498) ---
def test_draw_menu_quit_immediately() -> None:
"""Test draw_menu exits when 'q' is pressed immediately."""
mock_screen = MagicMock()
mock_screen.getmaxyx.return_value = (40, 80)
mock_screen.getch.return_value = ord("q")
devices = [
{"name": "/dev/sda", "size": "100G", "type": "disk", "mountpoints": ""},
]
device_id_mapping = {"/dev/sda": {"/dev/disk/by-id/ata-DISK1"}}
with (
patch("python.installer.tui.set_color"),
patch("python.installer.tui.get_devices", return_value=devices),
patch("python.installer.tui.get_device_id_mapping", return_value=device_id_mapping),
patch("python.installer.tui.draw_device_menu", return_value=(State(), 5)),
patch("python.installer.tui.swap_size_input"),
patch("python.installer.tui.reserve_size_input"),
patch("python.installer.tui.status_bar"),
patch("python.installer.tui.debug_menu"),
patch("python.installer.tui.curses.color_pair", return_value=0),
):
result = draw_menu(mock_screen)
assert isinstance(result, State)
mock_screen.clear.assert_called()
mock_screen.refresh.assert_called()
def test_draw_menu_navigation_then_quit() -> None:
"""Test draw_menu handles navigation keys before quitting."""
mock_screen = MagicMock()
mock_screen.getmaxyx.return_value = (40, 80)
# Simulate pressing down arrow then 'q'
mock_screen.getch.side_effect = [curses.KEY_DOWN, ord("q")]
devices = [
{"name": "/dev/sda", "size": "100G", "type": "disk", "mountpoints": ""},
]
device_id_mapping = {"/dev/sda": set()}
with (
patch("python.installer.tui.set_color"),
patch("python.installer.tui.get_devices", return_value=devices),
patch("python.installer.tui.get_device_id_mapping", return_value=device_id_mapping),
patch("python.installer.tui.draw_device_menu", return_value=(State(), 5)),
patch("python.installer.tui.swap_size_input"),
patch("python.installer.tui.reserve_size_input"),
patch("python.installer.tui.status_bar"),
patch("python.installer.tui.debug_menu"),
patch("python.installer.tui.curses.color_pair", return_value=0),
):
result = draw_menu(mock_screen)
assert isinstance(result, State)

View File

@@ -1,129 +0,0 @@
"""Tests for python/orm modules."""
from __future__ import annotations
from os import environ
from typing import TYPE_CHECKING
from unittest.mock import MagicMock, patch
import pytest
from python.orm.base import RichieBase, TableBase, get_connection_info, get_postgres_engine
from python.orm.contact import ContactNeed, ContactRelationship, RelationshipType
if TYPE_CHECKING:
pass
def test_richie_base_schema_name() -> None:
"""Test RichieBase has correct schema name."""
assert RichieBase.schema_name == "main"
def test_richie_base_metadata_naming() -> None:
"""Test RichieBase metadata has naming conventions."""
assert RichieBase.metadata.schema == "main"
naming = RichieBase.metadata.naming_convention
assert naming is not None
assert "ix" in naming
assert "uq" in naming
assert "ck" in naming
assert "fk" in naming
assert "pk" in naming
def test_table_base_abstract() -> None:
"""Test TableBase is abstract."""
assert TableBase.__abstract__ is True
def test_get_connection_info_success() -> None:
"""Test get_connection_info with all env vars set."""
env = {
"POSTGRES_DB": "testdb",
"POSTGRES_HOST": "localhost",
"POSTGRES_PORT": "5432",
"POSTGRES_USER": "testuser",
"POSTGRES_PASSWORD": "testpass",
}
with patch.dict(environ, env, clear=False):
result = get_connection_info()
assert result == ("testdb", "localhost", "5432", "testuser", "testpass")
def test_get_connection_info_no_password() -> None:
"""Test get_connection_info with no password."""
env = {
"POSTGRES_DB": "testdb",
"POSTGRES_HOST": "localhost",
"POSTGRES_PORT": "5432",
"POSTGRES_USER": "testuser",
}
# Clear password if set
cleaned = {k: v for k, v in environ.items() if k != "POSTGRES_PASSWORD"}
cleaned.update(env)
with patch.dict(environ, cleaned, clear=True):
result = get_connection_info()
assert result == ("testdb", "localhost", "5432", "testuser", None)
def test_get_connection_info_missing_vars() -> None:
"""Test get_connection_info raises with missing env vars."""
with patch.dict(environ, {}, clear=True), pytest.raises(ValueError, match="Missing environment variables"):
get_connection_info()
def test_get_postgres_engine() -> None:
"""Test get_postgres_engine creates an engine."""
env = {
"POSTGRES_DB": "testdb",
"POSTGRES_HOST": "localhost",
"POSTGRES_PORT": "5432",
"POSTGRES_USER": "testuser",
"POSTGRES_PASSWORD": "testpass",
}
mock_engine = MagicMock()
with patch.dict(environ, env, clear=False), patch("python.orm.base.create_engine", return_value=mock_engine):
engine = get_postgres_engine()
assert engine is mock_engine
# --- Contact ORM tests ---
def test_relationship_type_values() -> None:
"""Test RelationshipType enum values."""
assert RelationshipType.SPOUSE.value == "spouse"
assert RelationshipType.OTHER.value == "other"
def test_relationship_type_default_weight() -> None:
"""Test RelationshipType default weights."""
assert RelationshipType.SPOUSE.default_weight == 10
assert RelationshipType.ACQUAINTANCE.default_weight == 3
assert RelationshipType.OTHER.default_weight == 2
assert RelationshipType.PARENT.default_weight == 9
def test_relationship_type_display_name() -> None:
"""Test RelationshipType display_name."""
assert RelationshipType.BEST_FRIEND.display_name == "Best Friend"
assert RelationshipType.AUNT_UNCLE.display_name == "Aunt Uncle"
assert RelationshipType.SPOUSE.display_name == "Spouse"
def test_all_relationship_types_have_weights() -> None:
"""Test all relationship types have valid weights."""
for rt in RelationshipType:
weight = rt.default_weight
assert 1 <= weight <= 10
def test_contact_need_table_name() -> None:
"""Test ContactNeed table name."""
assert ContactNeed.__tablename__ == "contact_need"
def test_contact_relationship_table_name() -> None:
"""Test ContactRelationship table name."""
assert ContactRelationship.__tablename__ == "contact_relationship"

View File

@@ -1,674 +0,0 @@
"""Tests for python/splendor modules."""
from __future__ import annotations
import random
from unittest.mock import patch
from python.splendor.base import (
BASE_COLORS,
GEM_COLORS,
Action,
BuyCard,
BuyCardReserved,
Card,
GameConfig,
GameState,
Noble,
PlayerState,
ReserveCard,
TakeDifferent,
TakeDouble,
apply_action,
apply_buy_card,
apply_buy_card_reserved,
apply_reserve_card,
apply_take_different,
apply_take_double,
auto_discard_tokens,
check_nobles_for_player,
create_random_cards,
create_random_cards_tier,
create_random_nobles,
enforce_token_limit,
get_default_starting_tokens,
get_legal_actions,
load_cards,
load_nobles,
new_game,
run_game,
)
from python.splendor.bot import (
PersonalizedBot,
PersonalizedBot2,
RandomBot,
buy_card,
buy_card_reserved,
can_bot_afford,
check_cards_in_tier,
take_tokens,
)
from python.splendor.public_state import (
Observation,
ObsCard,
ObsNoble,
ObsPlayer,
to_observation,
_encode_card,
_encode_noble,
_encode_player,
)
from python.splendor.sim import SimStrategy, simulate_step
import pytest
# --- Helper to create a simple game ---
def _make_card(tier: int = 1, points: int = 0, color: str = "white", cost: dict | None = None) -> Card:
if cost is None:
cost = dict.fromkeys(GEM_COLORS, 0)
return Card(tier=tier, points=points, color=color, cost=cost)
def _make_noble(name: str = "Noble", points: int = 3, reqs: dict | None = None) -> Noble:
if reqs is None:
reqs = {"white": 3, "blue": 3, "green": 3}
return Noble(name=name, points=points, requirements=reqs)
def _make_game(num_players: int = 2) -> tuple[GameState, list[RandomBot]]:
bots = [RandomBot(f"bot{i}") for i in range(num_players)]
cards = create_random_cards()
nobles = create_random_nobles()
config = GameConfig(cards=cards, nobles=nobles)
game = new_game(bots, config)
return game, bots
# --- PlayerState tests ---
def test_player_state_defaults() -> None:
"""Test PlayerState default values."""
bot = RandomBot("test")
p = PlayerState(strategy=bot)
assert p.total_tokens() == 0
assert p.score == 0
assert p.card_score == 0
assert p.noble_score == 0
def test_player_add_card() -> None:
"""Test adding a card to player."""
bot = RandomBot("test")
p = PlayerState(strategy=bot)
card = _make_card(points=3)
p.add_card(card)
assert len(p.cards) == 1
assert p.card_score == 3
assert p.score == 3
def test_player_add_noble() -> None:
"""Test adding a noble to player."""
bot = RandomBot("test")
p = PlayerState(strategy=bot)
noble = _make_noble(points=3)
p.add_noble(noble)
assert len(p.nobles) == 1
assert p.noble_score == 3
assert p.score == 3
def test_player_can_afford_free_card() -> None:
"""Test can_afford with a free card."""
bot = RandomBot("test")
p = PlayerState(strategy=bot)
card = _make_card(cost=dict.fromkeys(GEM_COLORS, 0))
assert p.can_afford(card) is True
def test_player_can_afford_with_tokens() -> None:
"""Test can_afford with tokens."""
bot = RandomBot("test")
p = PlayerState(strategy=bot)
p.tokens["white"] = 3
card = _make_card(cost={**dict.fromkeys(GEM_COLORS, 0), "white": 3})
assert p.can_afford(card) is True
def test_player_cannot_afford() -> None:
"""Test can_afford returns False when not enough."""
bot = RandomBot("test")
p = PlayerState(strategy=bot)
card = _make_card(cost={**dict.fromkeys(GEM_COLORS, 0), "white": 5})
assert p.can_afford(card) is False
def test_player_can_afford_with_gold() -> None:
"""Test can_afford uses gold tokens."""
bot = RandomBot("test")
p = PlayerState(strategy=bot)
p.tokens["gold"] = 3
card = _make_card(cost={**dict.fromkeys(GEM_COLORS, 0), "white": 3})
assert p.can_afford(card) is True
def test_player_pay_for_card() -> None:
"""Test pay_for_card transfers tokens."""
bot = RandomBot("test")
p = PlayerState(strategy=bot)
p.tokens["white"] = 3
card = _make_card(color="white", cost={**dict.fromkeys(GEM_COLORS, 0), "white": 2})
payment = p.pay_for_card(card)
assert payment["white"] == 2
assert p.tokens["white"] == 1
assert len(p.cards) == 1
assert p.discounts["white"] == 1
def test_player_pay_for_card_cannot_afford() -> None:
"""Test pay_for_card raises when cannot afford."""
bot = RandomBot("test")
p = PlayerState(strategy=bot)
card = _make_card(cost={**dict.fromkeys(GEM_COLORS, 0), "white": 5})
with pytest.raises(ValueError, match="cannot afford"):
p.pay_for_card(card)
# --- GameState tests ---
def test_get_default_starting_tokens() -> None:
"""Test starting token counts."""
tokens = get_default_starting_tokens(2)
assert tokens["gold"] == 5
assert tokens["white"] == 4 # (4-6+10)//2 = 4
tokens = get_default_starting_tokens(3)
assert tokens["white"] == 5
tokens = get_default_starting_tokens(4)
assert tokens["white"] == 7
def test_new_game() -> None:
"""Test new_game creates valid state."""
game, _ = _make_game(2)
assert len(game.players) == 2
assert game.bank["gold"] == 5
assert len(game.available_nobles) == 3 # 2 players + 1
def test_game_next_player() -> None:
"""Test next_player cycles."""
game, _ = _make_game(2)
assert game.current_player_index == 0
game.next_player()
assert game.current_player_index == 1
game.next_player()
assert game.current_player_index == 0
def test_game_current_player() -> None:
"""Test current_player property."""
game, _ = _make_game(2)
assert game.current_player is game.players[0]
def test_game_check_winner_simple_no_winner() -> None:
"""Test check_winner_simple with no winner."""
game, _ = _make_game(2)
assert game.check_winner_simple() is None
def test_game_check_winner_simple_winner() -> None:
"""Test check_winner_simple with winner."""
game, _ = _make_game(2)
# Give player enough points
for _ in range(15):
game.players[0].add_card(_make_card(points=1))
winner = game.check_winner_simple()
assert winner is game.players[0]
assert game.finished is True
def test_game_refill_table() -> None:
"""Test refill_table fills from decks."""
game, _ = _make_game(2)
# Table should be filled initially
for tier in (1, 2, 3):
assert len(game.table_by_tier[tier]) <= game.config.table_cards_per_tier
# --- Action tests ---
def test_apply_take_different() -> None:
"""Test take different colors."""
game, bots = _make_game(2)
strategy = bots[0]
action = TakeDifferent(colors=["white", "blue", "green"])
apply_take_different(game, strategy, action)
p = game.players[0]
assert p.tokens["white"] == 1
assert p.tokens["blue"] == 1
assert p.tokens["green"] == 1
def test_apply_take_different_invalid() -> None:
"""Test take different with too many colors is truncated."""
game, bots = _make_game(2)
strategy = bots[0]
# 4 colors should be rejected
action = TakeDifferent(colors=["white", "blue", "green", "red"])
apply_take_different(game, strategy, action)
def test_apply_take_double() -> None:
"""Test take double."""
game, bots = _make_game(2)
strategy = bots[0]
action = TakeDouble(color="white")
apply_take_double(game, strategy, action)
p = game.players[0]
assert p.tokens["white"] == 2
def test_apply_take_double_insufficient() -> None:
"""Test take double fails when bank has insufficient."""
game, bots = _make_game(2)
strategy = bots[0]
game.bank["white"] = 2 # Below minimum_tokens_to_buy_2
action = TakeDouble(color="white")
apply_take_double(game, strategy, action)
p = game.players[0]
assert p.tokens["white"] == 0 # No change
def test_apply_buy_card() -> None:
"""Test buy a card."""
game, bots = _make_game(2)
strategy = bots[0]
# Give the player enough tokens
game.players[0].tokens["white"] = 10
game.players[0].tokens["blue"] = 10
game.players[0].tokens["green"] = 10
game.players[0].tokens["red"] = 10
game.players[0].tokens["black"] = 10
if game.table_by_tier[1]:
action = BuyCard(tier=1, index=0)
apply_buy_card(game, strategy, action)
def test_apply_buy_card_reserved() -> None:
"""Test buy a reserved card."""
game, bots = _make_game(2)
strategy = bots[0]
card = _make_card(cost=dict.fromkeys(GEM_COLORS, 0))
game.players[0].reserved.append(card)
action = BuyCardReserved(index=0)
apply_buy_card_reserved(game, strategy, action)
assert len(game.players[0].reserved) == 0
assert len(game.players[0].cards) == 1
def test_apply_reserve_card_from_table() -> None:
"""Test reserve a card from table."""
game, bots = _make_game(2)
strategy = bots[0]
if game.table_by_tier[1]:
action = ReserveCard(tier=1, index=0, from_deck=False)
apply_reserve_card(game, strategy, action)
assert len(game.players[0].reserved) == 1
def test_apply_reserve_card_from_deck() -> None:
"""Test reserve a card from deck."""
game, bots = _make_game(2)
strategy = bots[0]
action = ReserveCard(tier=1, index=None, from_deck=True)
apply_reserve_card(game, strategy, action)
assert len(game.players[0].reserved) == 1
def test_apply_reserve_card_limit() -> None:
"""Test reserve limit."""
game, bots = _make_game(2)
strategy = bots[0]
# Fill reserves
for _ in range(3):
game.players[0].reserved.append(_make_card())
action = ReserveCard(tier=1, index=0, from_deck=False)
apply_reserve_card(game, strategy, action)
assert len(game.players[0].reserved) == 3 # No change
def test_apply_action_unknown_type() -> None:
"""Test apply_action with unknown action type."""
class FakeAction(Action):
pass
game, bots = _make_game(2)
with pytest.raises(ValueError, match="Unknown action type"):
apply_action(game, bots[0], FakeAction())
def test_apply_action_dispatches() -> None:
"""Test apply_action dispatches to correct handler."""
game, bots = _make_game(2)
action = TakeDifferent(colors=["white"])
apply_action(game, bots[0], action)
# --- auto_discard_tokens ---
def test_auto_discard_tokens() -> None:
"""Test auto_discard_tokens."""
bot = RandomBot("test")
p = PlayerState(strategy=bot)
p.tokens["white"] = 5
p.tokens["blue"] = 3
discards = auto_discard_tokens(p, 2)
assert sum(discards.values()) == 2
# --- enforce_token_limit ---
def test_enforce_token_limit_under() -> None:
"""Test enforce_token_limit when under limit."""
game, bots = _make_game(2)
p = game.players[0]
p.tokens["white"] = 3
enforce_token_limit(game, bots[0], p)
assert p.tokens["white"] == 3 # No change
def test_enforce_token_limit_over() -> None:
"""Test enforce_token_limit when over limit."""
game, bots = _make_game(2)
p = game.players[0]
for color in BASE_COLORS:
p.tokens[color] = 5
enforce_token_limit(game, bots[0], p)
assert p.total_tokens() <= game.config.token_limit
# --- check_nobles_for_player ---
def test_check_nobles_no_qualification() -> None:
"""Test check_nobles when player doesn't qualify."""
game, bots = _make_game(2)
check_nobles_for_player(game, bots[0], game.players[0])
assert len(game.players[0].nobles) == 0
def test_check_nobles_qualification() -> None:
"""Test check_nobles when player qualifies."""
game, bots = _make_game(2)
p = game.players[0]
# Give enough discounts to qualify for ALL nobles (ensures at least one match)
for color in BASE_COLORS:
p.discounts[color] = 10
check_nobles_for_player(game, bots[0], p)
assert len(p.nobles) >= 1
# --- get_legal_actions ---
def test_get_legal_actions() -> None:
"""Test get_legal_actions returns valid actions."""
game, _ = _make_game(2)
actions = get_legal_actions(game)
assert len(actions) > 0
def test_get_legal_actions_explicit_player() -> None:
"""Test get_legal_actions with explicit player."""
game, _ = _make_game(2)
actions = get_legal_actions(game, game.players[1])
assert len(actions) > 0
# --- create_random helpers ---
def test_create_random_cards() -> None:
"""Test create_random_cards."""
random.seed(42)
cards = create_random_cards()
assert len(cards) > 0
tiers = {c.tier for c in cards}
assert tiers == {1, 2, 3}
def test_create_random_cards_tier() -> None:
"""Test create_random_cards_tier."""
cards = create_random_cards_tier(1, 3, [0, 1], [0, 1])
assert len(cards) == 15 # 5 colors * 3 per color
def test_create_random_nobles() -> None:
"""Test create_random_nobles."""
nobles = create_random_nobles()
assert len(nobles) == 8
assert all(n.points == 3 for n in nobles)
# --- load_cards / load_nobles ---
def test_load_cards(tmp_path: Path) -> None:
"""Test load_cards from file."""
import json
from pathlib import Path
cards_data = [
{"tier": 1, "points": 0, "color": "white", "cost": {"white": 0, "blue": 1}},
]
file = tmp_path / "cards.json"
file.write_text(json.dumps(cards_data))
cards = load_cards(file)
assert len(cards) == 1
def test_load_nobles(tmp_path: Path) -> None:
"""Test load_nobles from file."""
import json
from pathlib import Path
nobles_data = [
{"name": "Noble 1", "points": 3, "requirements": {"white": 3, "blue": 3}},
]
file = tmp_path / "nobles.json"
file.write_text(json.dumps(nobles_data))
nobles = load_nobles(file)
assert len(nobles) == 1
# --- run_game ---
def test_run_game() -> None:
"""Test run_game completes."""
random.seed(42)
game, _ = _make_game(2)
winner, turns = run_game(game)
assert winner is not None
assert turns > 0
def test_run_game_concede() -> None:
"""Test run_game handles player conceding."""
class ConcedingBot(RandomBot):
def choose_action(self, game: GameState, player: PlayerState) -> Action | None:
return None
bots = [ConcedingBot("bot1"), RandomBot("bot2")]
cards = create_random_cards()
nobles = create_random_nobles()
config = GameConfig(cards=cards, nobles=nobles)
game = new_game(bots, config)
winner, turns = run_game(game)
assert winner is not None
# --- Bot tests ---
def test_random_bot_choose_action() -> None:
"""Test RandomBot.choose_action returns valid action."""
random.seed(42)
game, bots = _make_game(2)
action = bots[0].choose_action(game, game.players[0])
assert action is not None
def test_personalized_bot_choose_action() -> None:
"""Test PersonalizedBot.choose_action."""
random.seed(42)
bot = PersonalizedBot("pbot")
game, _ = _make_game(2)
game.players[0].strategy = bot
action = bot.choose_action(game, game.players[0])
assert action is not None
def test_personalized_bot2_choose_action() -> None:
"""Test PersonalizedBot2.choose_action."""
random.seed(42)
bot = PersonalizedBot2("pbot2")
game, _ = _make_game(2)
game.players[0].strategy = bot
action = bot.choose_action(game, game.players[0])
assert action is not None
def test_can_bot_afford() -> None:
"""Test can_bot_afford function."""
bot = RandomBot("test")
p = PlayerState(strategy=bot)
card = _make_card(cost=dict.fromkeys(GEM_COLORS, 0))
assert can_bot_afford(p, card) is True
def test_check_cards_in_tier() -> None:
"""Test check_cards_in_tier."""
bot = RandomBot("test")
p = PlayerState(strategy=bot)
free_card = _make_card(cost=dict.fromkeys(GEM_COLORS, 0))
expensive_card = _make_card(cost={**dict.fromkeys(GEM_COLORS, 0), "white": 10})
result = check_cards_in_tier([free_card, expensive_card], p)
assert result == [0]
def test_buy_card_function() -> None:
"""Test buy_card helper function."""
game, _ = _make_game(2)
p = game.players[0]
# Give player enough tokens
for c in BASE_COLORS:
p.tokens[c] = 10
result = buy_card(game, p)
assert result is not None or True # May or may not find affordable card
def test_buy_card_reserved_function() -> None:
"""Test buy_card_reserved helper function."""
bot = RandomBot("test")
p = PlayerState(strategy=bot)
# No reserved cards
assert buy_card_reserved(p) is None
# With affordable reserved card
card = _make_card(cost=dict.fromkeys(GEM_COLORS, 0))
p.reserved.append(card)
result = buy_card_reserved(p)
assert isinstance(result, BuyCardReserved)
def test_take_tokens_function() -> None:
"""Test take_tokens helper function."""
game, _ = _make_game(2)
result = take_tokens(game)
assert result is not None
def test_take_tokens_empty_bank() -> None:
"""Test take_tokens with empty bank."""
game, _ = _make_game(2)
for c in BASE_COLORS:
game.bank[c] = 0
result = take_tokens(game)
assert result is None
# --- public_state tests ---
def test_encode_card() -> None:
"""Test _encode_card."""
card = _make_card(tier=1, points=2, color="blue", cost={"white": 1, "blue": 2})
obs = _encode_card(card)
assert isinstance(obs, ObsCard)
assert obs.tier == 1
assert obs.points == 2
def test_encode_noble() -> None:
"""Test _encode_noble."""
noble = _make_noble(points=3, reqs={"white": 3, "blue": 3, "green": 3})
obs = _encode_noble(noble)
assert isinstance(obs, ObsNoble)
assert obs.points == 3
def test_encode_player() -> None:
"""Test _encode_player."""
bot = RandomBot("test")
p = PlayerState(strategy=bot)
obs = _encode_player(p)
assert isinstance(obs, ObsPlayer)
assert obs.score == 0
def test_to_observation() -> None:
"""Test to_observation creates full observation."""
game, _ = _make_game(2)
obs = to_observation(game)
assert isinstance(obs, Observation)
assert len(obs.players) == 2
assert obs.current_player == 0
# --- sim tests ---
def test_sim_strategy_choose_action_raises() -> None:
"""Test SimStrategy.choose_action raises."""
sim = SimStrategy("sim")
game, _ = _make_game(2)
with pytest.raises(RuntimeError, match="should not be used"):
sim.choose_action(game, game.players[0])
def test_simulate_step() -> None:
"""Test simulate_step returns deep copy."""
random.seed(42)
game, _ = _make_game(2)
action = TakeDifferent(colors=["white", "blue", "green"])
# SimStrategy() in source is missing name arg - patch it
with patch("python.splendor.sim.SimStrategy", lambda: SimStrategy("sim")):
next_state = simulate_step(game, action)
assert next_state is not game
assert next_state.current_player_index != game.current_player_index or len(game.players) == 1

View File

@@ -1,246 +0,0 @@
"""Extra tests for splendor/base.py covering missed lines and branches."""
from __future__ import annotations
import random
from python.splendor.base import (
BASE_COLORS,
GEM_COLORS,
BuyCard,
BuyCardReserved,
Card,
GameConfig,
Noble,
ReserveCard,
TakeDifferent,
TakeDouble,
apply_action,
apply_buy_card,
apply_buy_card_reserved,
apply_reserve_card,
apply_take_different,
apply_take_double,
auto_discard_tokens,
check_nobles_for_player,
create_random_cards,
create_random_nobles,
enforce_token_limit,
get_legal_actions,
new_game,
run_game,
)
from python.splendor.bot import RandomBot
def _make_game(num_players: int = 2):
random.seed(42)
bots = [RandomBot(f"bot{i}") for i in range(num_players)]
cards = create_random_cards()
nobles = create_random_nobles()
config = GameConfig(cards=cards, nobles=nobles)
game = new_game(bots, config)
return game, bots
def test_auto_discard_tokens_all_zero() -> None:
"""Test auto_discard when all tokens are zero."""
game, _ = _make_game()
p = game.players[0]
for c in GEM_COLORS:
p.tokens[c] = 0
result = auto_discard_tokens(p, 3)
assert sum(result.values()) == 0 # Can't discard from empty
def test_enforce_token_limit_with_fallback() -> None:
"""Test enforce_token_limit uses auto_discard as fallback."""
game, bots = _make_game()
p = game.players[0]
strategy = bots[0]
# Give player many tokens to force discard
for c in BASE_COLORS:
p.tokens[c] = 5
enforce_token_limit(game, strategy, p)
assert p.total_tokens() <= game.config.token_limit
def test_apply_take_different_invalid_color() -> None:
"""Test take different with gold (non-base) color."""
game, bots = _make_game()
action = TakeDifferent(colors=["gold"])
apply_take_different(game, bots[0], action)
# Gold is not in BASE_COLORS, so no tokens should be taken
def test_apply_take_double_invalid_color() -> None:
"""Test take double with gold (non-base) color."""
game, bots = _make_game()
action = TakeDouble(color="gold")
apply_take_double(game, bots[0], action)
def test_apply_take_double_insufficient_bank() -> None:
"""Test take double when bank has fewer than minimum."""
game, bots = _make_game()
game.bank["white"] = 2 # Below minimum_tokens_to_buy_2 (4)
action = TakeDouble(color="white")
apply_take_double(game, bots[0], action)
def test_apply_buy_card_invalid_tier() -> None:
"""Test buy card with invalid tier."""
game, bots = _make_game()
action = BuyCard(tier=99, index=0)
apply_buy_card(game, bots[0], action)
def test_apply_buy_card_invalid_index() -> None:
"""Test buy card with out-of-range index."""
game, bots = _make_game()
action = BuyCard(tier=1, index=99)
apply_buy_card(game, bots[0], action)
def test_apply_buy_card_cannot_afford() -> None:
"""Test buy card when player can't afford."""
game, bots = _make_game()
# Zero out all tokens
for c in GEM_COLORS:
game.players[0].tokens[c] = 0
# Find an expensive card
for tier, row in game.table_by_tier.items():
for idx, card in enumerate(row):
if any(v > 0 for v in card.cost.values()):
action = BuyCard(tier=tier, index=idx)
apply_buy_card(game, bots[0], action)
return
def test_apply_buy_card_reserved_invalid_index() -> None:
"""Test buy reserved card with out-of-range index."""
game, bots = _make_game()
action = BuyCardReserved(index=99)
apply_buy_card_reserved(game, bots[0], action)
def test_apply_buy_card_reserved_cannot_afford() -> None:
"""Test buy reserved card when can't afford."""
game, bots = _make_game()
expensive = Card(tier=3, points=5, color="white", cost={
"white": 10, "blue": 10, "green": 10, "red": 10, "black": 10, "gold": 0
})
game.players[0].reserved.append(expensive)
for c in GEM_COLORS:
game.players[0].tokens[c] = 0
action = BuyCardReserved(index=0)
apply_buy_card_reserved(game, bots[0], action)
def test_apply_reserve_card_at_limit() -> None:
"""Test reserve card when at reserve limit."""
game, bots = _make_game()
p = game.players[0]
# Fill up reserved slots
for _ in range(game.config.reserve_limit):
p.reserved.append(Card(tier=1, points=0, color="white", cost=dict.fromkeys(GEM_COLORS, 0)))
action = ReserveCard(tier=1, index=0, from_deck=False)
apply_reserve_card(game, bots[0], action)
assert len(p.reserved) == game.config.reserve_limit
def test_apply_reserve_card_invalid_tier() -> None:
"""Test reserve face-up card with invalid tier."""
game, bots = _make_game()
action = ReserveCard(tier=99, index=0, from_deck=False)
apply_reserve_card(game, bots[0], action)
def test_apply_reserve_card_invalid_index() -> None:
"""Test reserve face-up card with None index."""
game, bots = _make_game()
action = ReserveCard(tier=1, index=None, from_deck=False)
apply_reserve_card(game, bots[0], action)
def test_apply_reserve_card_from_empty_deck() -> None:
"""Test reserve from deck when deck is empty."""
game, bots = _make_game()
game.decks_by_tier[1] = [] # Empty the deck
action = ReserveCard(tier=1, index=None, from_deck=True)
apply_reserve_card(game, bots[0], action)
def test_apply_reserve_card_no_gold() -> None:
"""Test reserve card when bank has no gold."""
game, bots = _make_game()
game.bank["gold"] = 0
action = ReserveCard(tier=1, index=0, from_deck=True)
reserved_before = len(game.players[0].reserved)
apply_reserve_card(game, bots[0], action)
if len(game.players[0].reserved) > reserved_before:
assert game.players[0].tokens["gold"] == 0
def test_check_nobles_multiple_candidates() -> None:
"""Test check_nobles when player qualifies for multiple nobles."""
game, bots = _make_game()
p = game.players[0]
# Give player huge discounts to qualify for everything
for c in BASE_COLORS:
p.discounts[c] = 20
check_nobles_for_player(game, bots[0], p)
def test_check_nobles_chosen_not_in_available() -> None:
"""Test check_nobles when chosen noble is somehow not available."""
game, bots = _make_game()
p = game.players[0]
for c in BASE_COLORS:
p.discounts[c] = 20
# This tests the normal path - chosen should be in available
def test_run_game_turn_limit() -> None:
"""Test run_game respects turn limit."""
random.seed(99)
bots = [RandomBot(f"bot{i}") for i in range(2)]
cards = create_random_cards()
nobles = create_random_nobles()
config = GameConfig(cards=cards, nobles=nobles, turn_limit=5)
game = new_game(bots, config)
winner, turns = run_game(game)
assert turns <= 5
def test_run_game_action_none() -> None:
"""Test run_game stops when strategy returns None."""
from unittest.mock import MagicMock
bots = [RandomBot(f"bot{i}") for i in range(2)]
cards = create_random_cards()
nobles = create_random_nobles()
config = GameConfig(cards=cards, nobles=nobles)
game = new_game(bots, config)
# Make the first player's strategy return None
game.players[0].strategy.choose_action = MagicMock(return_value=None)
winner, turns = run_game(game)
assert turns == 1
def test_get_valid_actions_with_reserved() -> None:
"""Test get_valid_actions includes BuyCardReserved when player has reserved cards."""
game, _ = _make_game()
p = game.players[0]
# Give player a free reserved card
free_card = Card(tier=1, points=0, color="white", cost=dict.fromkeys(GEM_COLORS, 0))
p.reserved.append(free_card)
actions = get_legal_actions(game)
assert any(isinstance(a, BuyCardReserved) for a in actions)
def test_get_legal_actions_reserve_from_deck() -> None:
"""Test get_legal_actions includes ReserveCard from deck."""
game, _ = _make_game()
actions = get_legal_actions(game)
assert any(isinstance(a, ReserveCard) and a.from_deck for a in actions)
assert any(isinstance(a, ReserveCard) and not a.from_deck for a in actions)

View File

@@ -1,143 +0,0 @@
"""Tests for PersonalizedBot3 and PersonalizedBot4 edge cases."""
from __future__ import annotations
import random
from python.splendor.base import (
BASE_COLORS,
GEM_COLORS,
BuyCard,
Card,
GameConfig,
GameState,
PlayerState,
ReserveCard,
TakeDifferent,
create_random_cards,
create_random_nobles,
new_game,
run_game,
)
from python.splendor.bot import (
PersonalizedBot2,
PersonalizedBot3,
PersonalizedBot4,
RandomBot,
)
def _make_card(tier: int = 1, points: int = 0, color: str = "white", cost: dict | None = None) -> Card:
if cost is None:
cost = dict.fromkeys(GEM_COLORS, 0)
return Card(tier=tier, points=points, color=color, cost=cost)
def _make_game(bots: list) -> GameState:
cards = create_random_cards()
nobles = create_random_nobles()
config = GameConfig(cards=cards, nobles=nobles, turn_limit=100)
return new_game(bots, config)
def test_personalized_bot3_reserves_from_deck() -> None:
"""Test PersonalizedBot3 reserves from deck when no tokens."""
random.seed(42)
bot = PersonalizedBot3("pbot3")
game = _make_game([bot, RandomBot("r")])
p = game.players[0]
p.strategy = bot
# Clear bank to force reserve
for c in BASE_COLORS:
game.bank[c] = 0
# Clear table to prevent buys
for tier in (1, 2, 3):
game.table_by_tier[tier] = []
action = bot.choose_action(game, p)
assert isinstance(action, (ReserveCard, TakeDifferent))
def test_personalized_bot3_fallback_take_different() -> None:
"""Test PersonalizedBot3 falls back to TakeDifferent."""
random.seed(42)
bot = PersonalizedBot3("pbot3")
game = _make_game([bot, RandomBot("r")])
p = game.players[0]
p.strategy = bot
# Empty everything
for c in BASE_COLORS:
game.bank[c] = 0
for tier in (1, 2, 3):
game.table_by_tier[tier] = []
game.decks_by_tier[tier] = []
action = bot.choose_action(game, p)
assert isinstance(action, TakeDifferent)
def test_personalized_bot4_reserves_from_deck() -> None:
"""Test PersonalizedBot4 reserves from deck."""
random.seed(42)
bot = PersonalizedBot4("pbot4")
game = _make_game([bot, RandomBot("r")])
p = game.players[0]
p.strategy = bot
for c in BASE_COLORS:
game.bank[c] = 0
for tier in (1, 2, 3):
game.table_by_tier[tier] = []
action = bot.choose_action(game, p)
assert isinstance(action, (ReserveCard, TakeDifferent))
def test_personalized_bot4_fallback() -> None:
"""Test PersonalizedBot4 fallback with empty everything."""
random.seed(42)
bot = PersonalizedBot4("pbot4")
game = _make_game([bot, RandomBot("r")])
p = game.players[0]
p.strategy = bot
for c in BASE_COLORS:
game.bank[c] = 0
for tier in (1, 2, 3):
game.table_by_tier[tier] = []
game.decks_by_tier[tier] = []
action = bot.choose_action(game, p)
assert isinstance(action, TakeDifferent)
def test_personalized_bot2_fallback_empty_colors() -> None:
"""Test PersonalizedBot2 with very few available colors."""
random.seed(42)
bot = PersonalizedBot2("pbot2")
game = _make_game([bot, RandomBot("r")])
p = game.players[0]
p.strategy = bot
# No table cards, no affordable reserved
for tier in (1, 2, 3):
game.table_by_tier[tier] = []
# Set exactly 2 colors
for c in BASE_COLORS:
game.bank[c] = 0
game.bank["white"] = 1
game.bank["blue"] = 1
action = bot.choose_action(game, p)
assert action is not None
def test_full_game_with_bot3_and_bot4() -> None:
"""Test a full game with bot3 and bot4."""
random.seed(42)
bots = [PersonalizedBot3("b3"), PersonalizedBot4("b4")]
game = _make_game(bots)
winner, turns = run_game(game)
assert winner is not None

View File

@@ -1,230 +0,0 @@
"""Extended tests for python/splendor/bot.py to improve coverage."""
from __future__ import annotations
import random
from python.splendor.base import (
BASE_COLORS,
GEM_COLORS,
BuyCard,
BuyCardReserved,
Card,
GameConfig,
GameState,
PlayerState,
ReserveCard,
TakeDifferent,
create_random_cards,
create_random_nobles,
new_game,
run_game,
)
from python.splendor.bot import (
PersonalizedBot,
PersonalizedBot2,
PersonalizedBot3,
PersonalizedBot4,
RandomBot,
buy_card,
buy_card_reserved,
estimate_value_of_card,
estimate_value_of_token,
take_tokens,
)
def _make_card(tier: int = 1, points: int = 0, color: str = "white", cost: dict | None = None) -> Card:
if cost is None:
cost = dict.fromkeys(GEM_COLORS, 0)
return Card(tier=tier, points=points, color=color, cost=cost)
def _make_game(num_players: int = 2) -> tuple[GameState, list]:
bots = [RandomBot(f"bot{i}") for i in range(num_players)]
cards = create_random_cards()
nobles = create_random_nobles()
config = GameConfig(cards=cards, nobles=nobles)
game = new_game(bots, config)
return game, bots
def test_random_bot_buys_affordable() -> None:
"""Test RandomBot buys affordable cards."""
random.seed(1)
game, bots = _make_game(2)
p = game.players[0]
for c in BASE_COLORS:
p.tokens[c] = 10
# Should sometimes buy
actions = [bots[0].choose_action(game, p) for _ in range(20)]
buy_actions = [a for a in actions if isinstance(a, BuyCard)]
assert len(buy_actions) > 0
def test_random_bot_reserves() -> None:
"""Test RandomBot reserves cards sometimes."""
random.seed(3)
game, bots = _make_game(2)
actions = [bots[0].choose_action(game, game.players[0]) for _ in range(50)]
reserve_actions = [a for a in actions if isinstance(a, ReserveCard)]
assert len(reserve_actions) > 0
def test_random_bot_choose_discard() -> None:
"""Test RandomBot.choose_discard."""
bot = RandomBot("test")
p = PlayerState(strategy=bot)
p.tokens["white"] = 5
p.tokens["blue"] = 3
discards = bot.choose_discard(None, p, 2)
assert sum(discards.values()) == 2
def test_personalized_bot_takes_different() -> None:
"""Test PersonalizedBot takes different when no affordable cards."""
random.seed(42)
bot = PersonalizedBot("pbot")
game, _ = _make_game(2)
p = game.players[0]
action = bot.choose_action(game, p)
assert action is not None
def test_personalized_bot_choose_discard() -> None:
"""Test PersonalizedBot.choose_discard."""
bot = PersonalizedBot("pbot")
p = PlayerState(strategy=bot)
p.tokens["white"] = 5
discards = bot.choose_discard(None, p, 2)
assert sum(discards.values()) == 2
def test_personalized_bot2_buys_reserved() -> None:
"""Test PersonalizedBot2 buys reserved cards."""
random.seed(42)
bot = PersonalizedBot2("pbot2")
game, _ = _make_game(2)
p = game.players[0]
p.strategy = bot
# Add affordable reserved card
card = _make_card(cost=dict.fromkeys(GEM_COLORS, 0))
p.reserved.append(card)
# Clear table cards to force reserved buy
for tier in (1, 2, 3):
game.table_by_tier[tier] = []
action = bot.choose_action(game, p)
assert isinstance(action, BuyCardReserved)
def test_personalized_bot2_reserves_from_deck() -> None:
"""Test PersonalizedBot2 reserves from deck when few colors available."""
random.seed(42)
bot = PersonalizedBot2("pbot2")
game, _ = _make_game(2)
p = game.players[0]
p.strategy = bot
# Clear table and set only 2 bank colors
for tier in (1, 2, 3):
game.table_by_tier[tier] = []
for c in BASE_COLORS:
game.bank[c] = 0
game.bank["white"] = 1
game.bank["blue"] = 1
action = bot.choose_action(game, p)
assert isinstance(action, (ReserveCard, TakeDifferent))
def test_personalized_bot2_choose_discard() -> None:
"""Test PersonalizedBot2.choose_discard."""
bot = PersonalizedBot2("pbot2")
p = PlayerState(strategy=bot)
p.tokens["red"] = 5
discards = bot.choose_discard(None, p, 2)
assert sum(discards.values()) == 2
def test_personalized_bot3_choose_action() -> None:
"""Test PersonalizedBot3.choose_action."""
random.seed(42)
bot = PersonalizedBot3("pbot3")
game, _ = _make_game(2)
p = game.players[0]
p.strategy = bot
action = bot.choose_action(game, p)
assert action is not None
def test_personalized_bot3_choose_discard() -> None:
"""Test PersonalizedBot3.choose_discard."""
bot = PersonalizedBot3("pbot3")
p = PlayerState(strategy=bot)
p.tokens["green"] = 5
discards = bot.choose_discard(None, p, 2)
assert sum(discards.values()) == 2
def test_personalized_bot4_choose_action() -> None:
"""Test PersonalizedBot4.choose_action."""
random.seed(42)
bot = PersonalizedBot4("pbot4")
game, _ = _make_game(2)
p = game.players[0]
p.strategy = bot
action = bot.choose_action(game, p)
assert action is not None
def test_personalized_bot4_filter_actions() -> None:
"""Test PersonalizedBot4.filter_actions."""
bot = PersonalizedBot4("pbot4")
actions = [
TakeDifferent(colors=["white", "blue", "green"]),
TakeDifferent(colors=["white", "blue"]),
BuyCard(tier=1, index=0),
]
filtered = bot.filter_actions(actions)
# Should keep 3-color TakeDifferent and BuyCard, remove 2-color TakeDifferent
assert len(filtered) == 2
def test_personalized_bot4_choose_discard() -> None:
"""Test PersonalizedBot4.choose_discard."""
bot = PersonalizedBot4("pbot4")
p = PlayerState(strategy=bot)
p.tokens["black"] = 5
discards = bot.choose_discard(None, p, 2)
assert sum(discards.values()) == 2
def test_estimate_value_of_card() -> None:
"""Test estimate_value_of_card."""
game, _ = _make_game(2)
p = game.players[0]
result = estimate_value_of_card(game, p, "white")
assert isinstance(result, int)
def test_estimate_value_of_token() -> None:
"""Test estimate_value_of_token."""
game, _ = _make_game(2)
p = game.players[0]
result = estimate_value_of_token(game, p, "white")
assert isinstance(result, int)
def test_full_game_with_personalized_bots() -> None:
"""Test a full game with different bot types."""
random.seed(42)
bots = [
RandomBot("random"),
PersonalizedBot("p1"),
PersonalizedBot2("p2"),
]
cards = create_random_cards()
nobles = create_random_nobles()
config = GameConfig(cards=cards, nobles=nobles, turn_limit=200)
game = new_game(bots, config)
winner, turns = run_game(game)
assert winner is not None
assert turns > 0

View File

@@ -1,156 +0,0 @@
"""Tests for python/splendor/human.py - non-TUI parts."""
from __future__ import annotations
from python.splendor.human import (
COST_ABBR,
COLOR_ABBR_TO_FULL,
COLOR_STYLE,
color_token,
fmt_gem,
fmt_number,
format_card,
format_cost,
format_discounts,
format_noble,
format_tokens,
parse_color_token,
)
from python.splendor.base import Card, GEM_COLORS, Noble
import pytest
# --- parse_color_token ---
def test_parse_color_token_full_names() -> None:
"""Test parsing full color names."""
assert parse_color_token("white") == "white"
assert parse_color_token("blue") == "blue"
assert parse_color_token("green") == "green"
assert parse_color_token("red") == "red"
assert parse_color_token("black") == "black"
def test_parse_color_token_abbreviations() -> None:
"""Test parsing abbreviated color names."""
assert parse_color_token("w") == "white"
assert parse_color_token("b") == "blue"
assert parse_color_token("g") == "green"
assert parse_color_token("r") == "red"
assert parse_color_token("k") == "black"
assert parse_color_token("o") == "gold"
def test_parse_color_token_case_insensitive() -> None:
"""Test parsing is case insensitive."""
assert parse_color_token("WHITE") == "white"
assert parse_color_token("B") == "blue"
def test_parse_color_token_unknown() -> None:
"""Test parsing unknown color raises."""
with pytest.raises(ValueError, match="Unknown color"):
parse_color_token("purple")
# --- format functions ---
def test_format_cost() -> None:
"""Test format_cost formats correctly."""
cost = {"white": 2, "blue": 1, "green": 0, "red": 0, "black": 0, "gold": 0}
result = format_cost(cost)
assert "W:" in result
assert "B:" in result
def test_format_cost_empty() -> None:
"""Test format_cost with all zeros."""
cost = dict.fromkeys(GEM_COLORS, 0)
result = format_cost(cost)
assert result == "-"
def test_format_card() -> None:
"""Test format_card."""
card = Card(tier=1, points=2, color="white", cost={"white": 0, "blue": 1, "green": 0, "red": 0, "black": 0, "gold": 0})
result = format_card(card)
assert "T1" in result
assert "P2" in result
def test_format_noble() -> None:
"""Test format_noble."""
noble = Noble(name="Noble 1", points=3, requirements={"white": 3, "blue": 3, "green": 3})
result = format_noble(noble)
assert "Noble 1" in result
assert "+3" in result
def test_format_tokens() -> None:
"""Test format_tokens."""
tokens = {"white": 2, "blue": 1, "green": 0, "red": 0, "black": 0, "gold": 0}
result = format_tokens(tokens)
assert "white:" in result
def test_format_discounts() -> None:
"""Test format_discounts."""
discounts = {"white": 2, "blue": 1, "green": 0, "red": 0, "black": 0, "gold": 0}
result = format_discounts(discounts)
assert "W:" in result
def test_format_discounts_empty() -> None:
"""Test format_discounts with all zeros."""
discounts = dict.fromkeys(GEM_COLORS, 0)
result = format_discounts(discounts)
assert result == "-"
# --- formatting helpers ---
def test_color_token() -> None:
"""Test color_token."""
result = color_token("white", 3)
assert "white" in result
assert "3" in result
def test_fmt_gem() -> None:
"""Test fmt_gem."""
result = fmt_gem("blue")
assert "blue" in result
def test_fmt_number() -> None:
"""Test fmt_number."""
result = fmt_number(42)
assert "42" in result
# --- constants ---
def test_cost_abbr_all_colors() -> None:
"""Test COST_ABBR has all gem colors."""
for color in GEM_COLORS:
assert color in COST_ABBR
def test_color_abbr_to_full() -> None:
"""Test COLOR_ABBR_TO_FULL mappings."""
assert COLOR_ABBR_TO_FULL["w"] == "white"
assert COLOR_ABBR_TO_FULL["o"] == "gold"
def test_color_style_all_colors() -> None:
"""Test COLOR_STYLE has all gem colors."""
for color in GEM_COLORS:
assert color in COLOR_STYLE
fg, bg = COLOR_STYLE[color]
assert isinstance(fg, str)
assert isinstance(bg, str)

View File

@@ -1,262 +0,0 @@
"""Tests for splendor/human.py command handlers and TUI widgets."""
from __future__ import annotations
import random
from unittest.mock import MagicMock, patch, PropertyMock
from python.splendor.base import (
BASE_COLORS,
GEM_COLORS,
BuyCard,
BuyCardReserved,
Card,
GameConfig,
GameState,
Noble,
PlayerState,
ReserveCard,
TakeDifferent,
TakeDouble,
create_random_cards,
create_random_nobles,
new_game,
)
from python.splendor.bot import RandomBot
from python.splendor.human import (
ActionApp,
Board,
DiscardApp,
NobleChoiceApp,
)
def _make_game(num_players: int = 2):
random.seed(42)
bots = [RandomBot(f"bot{i}") for i in range(num_players)]
cards = create_random_cards()
nobles = create_random_nobles()
config = GameConfig(cards=cards, nobles=nobles)
game = new_game(bots, config)
return game, bots
# --- ActionApp command handlers ---
def test_action_app_cmd_1_basic() -> None:
"""Test _cmd_1 take different colors."""
game, _ = _make_game()
app = ActionApp(game, game.players[0])
app._update_prompt = MagicMock()
app.exit = MagicMock()
result = app._cmd_1(["1", "white", "blue", "green"])
assert result is None
assert isinstance(app.result, TakeDifferent)
assert app.result.colors == ["white", "blue", "green"]
def test_action_app_cmd_1_abbreviations() -> None:
"""Test _cmd_1 with abbreviated colors."""
game, _ = _make_game()
app = ActionApp(game, game.players[0])
app.exit = MagicMock()
result = app._cmd_1(["1", "w", "b", "g"])
assert result is None
assert isinstance(app.result, TakeDifferent)
def test_action_app_cmd_1_no_colors() -> None:
"""Test _cmd_1 with no colors."""
game, _ = _make_game()
app = ActionApp(game, game.players[0])
result = app._cmd_1(["1"])
assert result is not None # Error message
def test_action_app_cmd_1_empty_bank() -> None:
"""Test _cmd_1 with empty bank color."""
game, _ = _make_game()
game.bank["white"] = 0
app = ActionApp(game, game.players[0])
result = app._cmd_1(["1", "white"])
assert result is not None # Error message
def test_action_app_cmd_2() -> None:
"""Test _cmd_2 take double."""
game, _ = _make_game()
app = ActionApp(game, game.players[0])
app.exit = MagicMock()
result = app._cmd_2(["2", "white"])
assert result is None
assert isinstance(app.result, TakeDouble)
def test_action_app_cmd_2_no_color() -> None:
"""Test _cmd_2 with no color."""
game, _ = _make_game()
app = ActionApp(game, game.players[0])
result = app._cmd_2(["2"])
assert result is not None
def test_action_app_cmd_2_insufficient_bank() -> None:
"""Test _cmd_2 with insufficient bank."""
game, _ = _make_game()
game.bank["white"] = 2
app = ActionApp(game, game.players[0])
result = app._cmd_2(["2", "white"])
assert result is not None
def test_action_app_cmd_3() -> None:
"""Test _cmd_3 buy card."""
game, _ = _make_game()
app = ActionApp(game, game.players[0])
app.exit = MagicMock()
result = app._cmd_3(["3", "1", "0"])
assert result is None
assert isinstance(app.result, BuyCard)
def test_action_app_cmd_3_no_args() -> None:
"""Test _cmd_3 with insufficient args."""
game, _ = _make_game()
app = ActionApp(game, game.players[0])
result = app._cmd_3(["3"])
assert result is not None
def test_action_app_cmd_4() -> None:
"""Test _cmd_4 buy reserved card - source has bug passing tier= to BuyCardReserved."""
game, _ = _make_game()
card = Card(tier=1, points=0, color="white", cost=dict.fromkeys(GEM_COLORS, 0))
game.players[0].reserved.append(card)
app = ActionApp(game, game.players[0])
app.exit = MagicMock()
# BuyCardReserved doesn't accept tier=, so the source code has a bug here
import pytest
with pytest.raises(TypeError):
app._cmd_4(["4", "0"])
def test_action_app_cmd_4_no_args() -> None:
"""Test _cmd_4 with no args."""
game, _ = _make_game()
app = ActionApp(game, game.players[0])
result = app._cmd_4(["4"])
assert result is not None
def test_action_app_cmd_4_out_of_range() -> None:
"""Test _cmd_4 with out of range index."""
game, _ = _make_game()
app = ActionApp(game, game.players[0])
result = app._cmd_4(["4", "0"])
assert result is not None
def test_action_app_cmd_5() -> None:
"""Test _cmd_5 reserve face-up card."""
game, _ = _make_game()
app = ActionApp(game, game.players[0])
app.exit = MagicMock()
result = app._cmd_5(["5", "1", "0"])
assert result is None
assert isinstance(app.result, ReserveCard)
assert app.result.from_deck is False
def test_action_app_cmd_5_no_args() -> None:
"""Test _cmd_5 with no args."""
game, _ = _make_game()
app = ActionApp(game, game.players[0])
result = app._cmd_5(["5"])
assert result is not None
def test_action_app_cmd_6() -> None:
"""Test _cmd_6 reserve from deck."""
game, _ = _make_game()
app = ActionApp(game, game.players[0])
app.exit = MagicMock()
result = app._cmd_6(["6", "1"])
assert result is None
assert isinstance(app.result, ReserveCard)
assert app.result.from_deck is True
def test_action_app_cmd_6_no_args() -> None:
"""Test _cmd_6 with no args."""
game, _ = _make_game()
app = ActionApp(game, game.players[0])
result = app._cmd_6(["6"])
assert result is not None
def test_action_app_unknown_cmd() -> None:
"""Test unknown command."""
game, _ = _make_game()
app = ActionApp(game, game.players[0])
result = app._unknown_cmd(["99"])
assert result == "Unknown command."
# --- ActionApp init ---
def test_action_app_init() -> None:
"""Test ActionApp initialization."""
game, _ = _make_game()
app = ActionApp(game, game.players[0])
assert app.result is None
assert app.message == ""
assert app.game is game
assert app.player is game.players[0]
# --- DiscardApp ---
def test_discard_app_init() -> None:
"""Test DiscardApp initialization."""
game, _ = _make_game()
app = DiscardApp(game, game.players[0])
assert app.discards == dict.fromkeys(GEM_COLORS, 0)
assert app.message == ""
def test_discard_app_remaining_to_discard() -> None:
"""Test DiscardApp._remaining_to_discard."""
game, _ = _make_game()
p = game.players[0]
for c in BASE_COLORS:
p.tokens[c] = 5
app = DiscardApp(game, p)
remaining = app._remaining_to_discard()
assert remaining == p.total_tokens() - game.config.token_limit
# --- NobleChoiceApp ---
def test_noble_choice_app_init() -> None:
"""Test NobleChoiceApp initialization."""
game, _ = _make_game()
nobles = game.available_nobles[:2]
app = NobleChoiceApp(game, game.players[0], nobles)
assert app.result is None
assert app.nobles == nobles
assert app.message == ""
# --- Board ---
def test_board_init() -> None:
"""Test Board initialization."""
game, _ = _make_game()
board = Board(game, game.players[0])
assert board.game is game
assert board.me is game.players[0]

View File

@@ -1,54 +0,0 @@
"""Tests for python/splendor/human.py TUI classes."""
from __future__ import annotations
import random
from python.splendor.base import (
GEM_COLORS,
GameConfig,
PlayerState,
create_random_cards,
create_random_nobles,
new_game,
)
from python.splendor.bot import RandomBot
from python.splendor.human import TuiHuman
def _make_game(num_players: int = 2):
bots = [RandomBot(f"bot{i}") for i in range(num_players)]
cards = create_random_cards()
nobles = create_random_nobles()
config = GameConfig(cards=cards, nobles=nobles)
game = new_game(bots, config)
return game, bots
def test_tui_human_choose_action_no_tty() -> None:
"""Test TuiHuman returns None when not a TTY."""
random.seed(42)
game, _ = _make_game(2)
human = TuiHuman("test")
# In test environment, stdout is not a TTY
result = human.choose_action(game, game.players[0])
assert result is None
def test_tui_human_choose_discard_no_tty() -> None:
"""Test TuiHuman returns empty discards when not a TTY."""
random.seed(42)
game, _ = _make_game(2)
human = TuiHuman("test")
result = human.choose_discard(game, game.players[0], 2)
assert result == dict.fromkeys(GEM_COLORS, 0)
def test_tui_human_choose_noble_no_tty() -> None:
"""Test TuiHuman returns first noble when not a TTY."""
random.seed(42)
game, _ = _make_game(2)
human = TuiHuman("test")
nobles = game.available_nobles[:2]
result = human.choose_noble(game, game.players[0], nobles)
assert result == nobles[0]

View File

@@ -1,647 +0,0 @@
"""Tests for splendor/human.py Textual widgets and TUI apps.
Covers Board (compose, on_mount, refresh_content, render methods),
ActionApp/DiscardApp/NobleChoiceApp (compose, on_mount, _update_prompt,
on_input_submitted), and TuiHuman tty paths.
"""
from __future__ import annotations
import random
import sys
from unittest.mock import patch
import pytest
from python.splendor.base import (
BASE_COLORS,
GEM_COLORS,
Card,
GameConfig,
GameState,
Noble,
PlayerState,
TakeDifferent,
create_random_cards,
create_random_nobles,
new_game,
)
from python.splendor.bot import RandomBot
from python.splendor.human import (
ActionApp,
Board,
DiscardApp,
NobleChoiceApp,
TuiHuman,
)
def _make_game(num_players: int = 2):
random.seed(42)
bots = [RandomBot(f"bot{i}") for i in range(num_players)]
cards = create_random_cards()
nobles = create_random_nobles()
config = GameConfig(cards=cards, nobles=nobles)
game = new_game(bots, config)
return game, bots
def _patch_player_names(game: GameState) -> None:
"""Add .name attribute to each PlayerState (delegates to strategy.name)."""
for p in game.players:
p.name = p.strategy.name # type: ignore[attr-defined]
# ---------------------------------------------------------------------------
# Board widget tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_board_compose_and_mount() -> None:
"""Board.compose yields expected widget tree; on_mount populates them."""
game, _ = _make_game()
_patch_player_names(game)
app = ActionApp(game, game.players[0])
async with app.run_test() as pilot:
board = app.query_one(Board)
assert board is not None
# Verify sub-widgets exist
assert app.query_one("#bank_box") is not None
assert app.query_one("#tier1_box") is not None
assert app.query_one("#tier2_box") is not None
assert app.query_one("#tier3_box") is not None
assert app.query_one("#nobles_box") is not None
assert app.query_one("#players_box") is not None
app.exit()
@pytest.mark.asyncio
async def test_board_render_bank() -> None:
"""Board._render_bank writes bank info to bank_box."""
game, _ = _make_game()
_patch_player_names(game)
app = ActionApp(game, game.players[0])
async with app.run_test() as pilot:
board = app.query_one(Board)
board._render_bank()
app.exit()
@pytest.mark.asyncio
async def test_board_render_tiers() -> None:
"""Board._render_tiers populates tier boxes."""
game, _ = _make_game()
_patch_player_names(game)
app = ActionApp(game, game.players[0])
async with app.run_test() as pilot:
board = app.query_one(Board)
board._render_tiers()
app.exit()
@pytest.mark.asyncio
async def test_board_render_tiers_empty() -> None:
"""Board._render_tiers handles empty tiers."""
game, _ = _make_game()
_patch_player_names(game)
for tier in game.table_by_tier:
game.table_by_tier[tier] = []
app = ActionApp(game, game.players[0])
async with app.run_test() as pilot:
board = app.query_one(Board)
board._render_tiers()
app.exit()
@pytest.mark.asyncio
async def test_board_render_nobles() -> None:
"""Board._render_nobles shows noble info."""
game, _ = _make_game()
_patch_player_names(game)
app = ActionApp(game, game.players[0])
async with app.run_test() as pilot:
board = app.query_one(Board)
board._render_nobles()
app.exit()
@pytest.mark.asyncio
async def test_board_render_nobles_empty() -> None:
"""Board._render_nobles handles no nobles."""
game, _ = _make_game()
_patch_player_names(game)
game.available_nobles = []
app = ActionApp(game, game.players[0])
async with app.run_test() as pilot:
board = app.query_one(Board)
board._render_nobles()
app.exit()
@pytest.mark.asyncio
async def test_board_render_players() -> None:
"""Board._render_players shows all player info."""
game, _ = _make_game()
_patch_player_names(game)
app = ActionApp(game, game.players[0])
async with app.run_test() as pilot:
board = app.query_one(Board)
board._render_players()
app.exit()
@pytest.mark.asyncio
async def test_board_render_players_with_nobles_and_cards() -> None:
"""Board._render_players handles players with nobles, cards, and reserved."""
game, _ = _make_game()
_patch_player_names(game)
p = game.players[0]
card = Card(
tier=1, points=1, color="white", cost=dict.fromkeys(GEM_COLORS, 0),
)
p.cards.append(card)
reserved = Card(
tier=2, points=2, color="blue", cost=dict.fromkeys(GEM_COLORS, 0),
)
p.reserved.append(reserved)
noble = Noble(
name="TestNoble", points=3, requirements=dict.fromkeys(GEM_COLORS, 0),
)
p.nobles.append(noble)
app = ActionApp(game, p)
async with app.run_test() as pilot:
board = app.query_one(Board)
board._render_players()
app.exit()
@pytest.mark.asyncio
async def test_board_refresh_content() -> None:
"""Board.refresh_content calls all render sub-methods."""
game, _ = _make_game()
_patch_player_names(game)
app = ActionApp(game, game.players[0])
async with app.run_test() as pilot:
board = app.query_one(Board)
board.refresh_content()
app.exit()
# ---------------------------------------------------------------------------
# ActionApp tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_action_app_compose_and_mount() -> None:
"""ActionApp composes command_zone, board, footer and sets up prompt."""
game, _ = _make_game()
_patch_player_names(game)
app = ActionApp(game, game.players[0])
async with app.run_test() as pilot:
from textual.widgets import Footer, Input, Static
assert app.query_one("#input_line", Input) is not None
assert app.query_one("#prompt", Static) is not None
assert app.query_one("#board", Board) is not None
assert app.query_one(Footer) is not None
app.exit()
@pytest.mark.asyncio
async def test_action_app_update_prompt() -> None:
"""ActionApp._update_prompt writes action menu to prompt widget."""
game, _ = _make_game()
_patch_player_names(game)
app = ActionApp(game, game.players[0])
async with app.run_test() as pilot:
app._update_prompt()
app.exit()
@pytest.mark.asyncio
async def test_action_app_update_prompt_with_message() -> None:
"""ActionApp._update_prompt includes error message when set."""
game, _ = _make_game()
_patch_player_names(game)
app = ActionApp(game, game.players[0])
async with app.run_test() as pilot:
app.message = "Some error occurred"
app._update_prompt()
app.exit()
@pytest.mark.asyncio
async def test_action_app_on_input_submitted_quit() -> None:
"""ActionApp exits on 'q' input via pilot keyboard."""
game, _ = _make_game()
_patch_player_names(game)
app = ActionApp(game, game.players[0])
async with app.run_test() as pilot:
await pilot.press("q", "enter")
await pilot.pause()
assert app.result is None
@pytest.mark.asyncio
async def test_action_app_on_input_submitted_quit_word() -> None:
"""ActionApp exits on 'quit' input."""
game, _ = _make_game()
_patch_player_names(game)
app = ActionApp(game, game.players[0])
async with app.run_test() as pilot:
await pilot.press("q", "u", "i", "t", "enter")
await pilot.pause()
assert app.result is None
@pytest.mark.asyncio
async def test_action_app_on_input_submitted_zero() -> None:
"""ActionApp exits on '0' input."""
game, _ = _make_game()
_patch_player_names(game)
app = ActionApp(game, game.players[0])
async with app.run_test() as pilot:
await pilot.press("0", "enter")
await pilot.pause()
assert app.result is None
@pytest.mark.asyncio
async def test_action_app_on_input_submitted_empty() -> None:
"""ActionApp ignores empty input."""
game, _ = _make_game()
_patch_player_names(game)
app = ActionApp(game, game.players[0])
async with app.run_test() as pilot:
await pilot.press("enter")
await pilot.pause()
assert app.result is None
app.exit()
@pytest.mark.asyncio
async def test_action_app_on_input_submitted_valid_cmd() -> None:
"""ActionApp processes valid command '1 w b g' and exits."""
game, _ = _make_game()
_patch_player_names(game)
app = ActionApp(game, game.players[0])
async with app.run_test() as pilot:
for ch in "1 w b g":
await pilot.press(ch)
await pilot.press("enter")
await pilot.pause()
assert isinstance(app.result, TakeDifferent)
@pytest.mark.asyncio
async def test_action_app_on_input_submitted_error() -> None:
"""ActionApp shows error message for bad command."""
game, _ = _make_game()
_patch_player_names(game)
app = ActionApp(game, game.players[0])
async with app.run_test() as pilot:
for ch in "xyz":
await pilot.press(ch)
await pilot.press("enter")
await pilot.pause()
assert app.message == "Unknown command."
app.exit()
@pytest.mark.asyncio
async def test_action_app_on_input_submitted_cmd_error() -> None:
"""ActionApp shows error from a valid command number but bad args."""
game, _ = _make_game()
_patch_player_names(game)
app = ActionApp(game, game.players[0])
async with app.run_test() as pilot:
await pilot.press("1", "enter")
await pilot.pause()
assert app.message != ""
app.exit()
# ---------------------------------------------------------------------------
# DiscardApp tests
# ---------------------------------------------------------------------------
def _make_discard_game(excess: int = 1):
"""Create a game where player 0 has excess tokens over the limit."""
game, _bots = _make_game()
_patch_player_names(game)
p = game.players[0]
for c in GEM_COLORS:
p.tokens[c] = 0
p.tokens["white"] = game.config.token_limit + excess
return game, p
@pytest.mark.asyncio
async def test_discard_app_compose_and_mount() -> None:
"""DiscardApp composes header, command_zone, board, footer."""
game, p = _make_discard_game(2)
app = DiscardApp(game, p)
async with app.run_test() as pilot:
from textual.widgets import Footer, Header, Input, Static
assert app.query_one(Header) is not None
assert app.query_one("#input_line", Input) is not None
assert app.query_one("#prompt", Static) is not None
assert app.query_one("#board", Board) is not None
assert app.query_one(Footer) is not None
app.exit()
@pytest.mark.asyncio
async def test_discard_app_update_prompt() -> None:
"""DiscardApp._update_prompt shows remaining discards info."""
game, p = _make_discard_game(2)
app = DiscardApp(game, p)
async with app.run_test() as pilot:
app._update_prompt()
app.exit()
@pytest.mark.asyncio
async def test_discard_app_update_prompt_with_message() -> None:
"""DiscardApp._update_prompt includes error message."""
game, p = _make_discard_game(2)
app = DiscardApp(game, p)
async with app.run_test() as pilot:
app.message = "No more blue tokens"
app._update_prompt()
app.exit()
@pytest.mark.asyncio
async def test_discard_app_on_input_submitted_empty() -> None:
"""DiscardApp ignores empty input."""
game, p = _make_discard_game(2)
app = DiscardApp(game, p)
async with app.run_test() as pilot:
await pilot.press("enter")
await pilot.pause()
assert all(v == 0 for v in app.discards.values())
app.exit()
@pytest.mark.asyncio
async def test_discard_app_on_input_submitted_unknown_color() -> None:
"""DiscardApp shows error for unknown color."""
game, p = _make_discard_game(2)
app = DiscardApp(game, p)
async with app.run_test() as pilot:
for ch in "purple":
await pilot.press(ch)
await pilot.press("enter")
await pilot.pause()
assert "Unknown color" in app.message
app.exit()
@pytest.mark.asyncio
async def test_discard_app_on_input_submitted_no_tokens() -> None:
"""DiscardApp shows error when no tokens of that color available."""
game, p = _make_discard_game(2)
app = DiscardApp(game, p)
async with app.run_test() as pilot:
for ch in "blue":
await pilot.press(ch)
await pilot.press("enter")
await pilot.pause()
assert "No more" in app.message
app.exit()
@pytest.mark.asyncio
async def test_discard_app_on_input_submitted_valid_finishes() -> None:
"""DiscardApp increments discard and exits when done (excess=1)."""
game, p = _make_discard_game(excess=1)
app = DiscardApp(game, p)
async with app.run_test() as pilot:
await pilot.press("w", "enter")
await pilot.pause()
assert app.discards["white"] == 1
@pytest.mark.asyncio
async def test_discard_app_on_input_submitted_not_done_yet() -> None:
"""DiscardApp stays open when more discards still needed (excess=2)."""
game, p = _make_discard_game(excess=2)
app = DiscardApp(game, p)
async with app.run_test() as pilot:
await pilot.press("w", "enter")
await pilot.pause()
assert app.discards["white"] == 1
assert app.message == ""
await pilot.press("w", "enter")
await pilot.pause()
assert app.discards["white"] == 2
# ---------------------------------------------------------------------------
# NobleChoiceApp tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_noble_choice_app_compose_and_mount() -> None:
"""NobleChoiceApp composes header, command_zone, board, footer."""
game, _ = _make_game()
_patch_player_names(game)
nobles = game.available_nobles[:2]
app = NobleChoiceApp(game, game.players[0], nobles)
async with app.run_test() as pilot:
from textual.widgets import Footer, Header, Input, Static
assert app.query_one(Header) is not None
assert app.query_one("#input_line", Input) is not None
assert app.query_one("#prompt", Static) is not None
assert app.query_one("#board", Board) is not None
assert app.query_one(Footer) is not None
app.exit()
@pytest.mark.asyncio
async def test_noble_choice_app_update_prompt() -> None:
"""NobleChoiceApp._update_prompt lists available nobles."""
game, _ = _make_game()
_patch_player_names(game)
nobles = game.available_nobles[:2]
app = NobleChoiceApp(game, game.players[0], nobles)
async with app.run_test() as pilot:
app._update_prompt()
app.exit()
@pytest.mark.asyncio
async def test_noble_choice_app_update_prompt_with_message() -> None:
"""NobleChoiceApp._update_prompt includes error message."""
game, _ = _make_game()
_patch_player_names(game)
nobles = game.available_nobles[:2]
app = NobleChoiceApp(game, game.players[0], nobles)
async with app.run_test() as pilot:
app.message = "Index out of range."
app._update_prompt()
app.exit()
@pytest.mark.asyncio
async def test_noble_choice_app_on_input_submitted_empty() -> None:
"""NobleChoiceApp ignores empty input."""
game, _ = _make_game()
_patch_player_names(game)
nobles = game.available_nobles[:2]
app = NobleChoiceApp(game, game.players[0], nobles)
async with app.run_test() as pilot:
await pilot.press("enter")
await pilot.pause()
assert app.result is None
app.exit()
@pytest.mark.asyncio
async def test_noble_choice_app_on_input_submitted_not_int() -> None:
"""NobleChoiceApp shows error for non-integer input."""
game, _ = _make_game()
_patch_player_names(game)
nobles = game.available_nobles[:2]
app = NobleChoiceApp(game, game.players[0], nobles)
async with app.run_test() as pilot:
for ch in "abc":
await pilot.press(ch)
await pilot.press("enter")
await pilot.pause()
assert "valid integer" in app.message
app.exit()
@pytest.mark.asyncio
async def test_noble_choice_app_on_input_submitted_out_of_range() -> None:
"""NobleChoiceApp shows error for index out of range."""
game, _ = _make_game()
_patch_player_names(game)
nobles = game.available_nobles[:2]
app = NobleChoiceApp(game, game.players[0], nobles)
async with app.run_test() as pilot:
await pilot.press("9", "enter")
await pilot.pause()
assert "out of range" in app.message.lower()
app.exit()
@pytest.mark.asyncio
async def test_noble_choice_app_on_input_submitted_valid() -> None:
"""NobleChoiceApp selects noble and exits on valid index."""
game, _ = _make_game()
_patch_player_names(game)
nobles = game.available_nobles[:2]
app = NobleChoiceApp(game, game.players[0], nobles)
async with app.run_test() as pilot:
await pilot.press("0", "enter")
await pilot.pause()
assert app.result is nobles[0]
@pytest.mark.asyncio
async def test_noble_choice_app_on_input_submitted_second_noble() -> None:
"""NobleChoiceApp selects second noble."""
game, _ = _make_game()
_patch_player_names(game)
nobles = game.available_nobles[:2]
app = NobleChoiceApp(game, game.players[0], nobles)
async with app.run_test() as pilot:
await pilot.press("1", "enter")
await pilot.pause()
assert app.result is nobles[1]
# ---------------------------------------------------------------------------
# TuiHuman tty path tests
# ---------------------------------------------------------------------------
def test_tui_human_choose_action_tty() -> None:
"""TuiHuman.choose_action runs ActionApp when stdout is a tty."""
random.seed(42)
game, _ = _make_game()
human = TuiHuman("test")
with patch.object(sys.stdout, "isatty", return_value=True):
with patch.object(ActionApp, "run") as mock_run:
result = human.choose_action(game, game.players[0])
mock_run.assert_called_once()
assert result is None
def test_tui_human_choose_discard_tty() -> None:
"""TuiHuman.choose_discard runs DiscardApp when stdout is a tty."""
random.seed(42)
game, _ = _make_game()
human = TuiHuman("test")
with patch.object(sys.stdout, "isatty", return_value=True):
with patch.object(DiscardApp, "run") as mock_run:
result = human.choose_discard(game, game.players[0], 2)
mock_run.assert_called_once()
assert result == dict.fromkeys(GEM_COLORS, 0)
def test_tui_human_choose_noble_tty() -> None:
"""TuiHuman.choose_noble runs NobleChoiceApp when stdout is a tty."""
random.seed(42)
game, _ = _make_game()
nobles = game.available_nobles[:2]
human = TuiHuman("test")
with patch.object(sys.stdout, "isatty", return_value=True):
with patch.object(NobleChoiceApp, "run") as mock_run:
result = human.choose_noble(game, game.players[0], nobles)
mock_run.assert_called_once()
assert result is None

View File

@@ -1,35 +0,0 @@
"""Tests for python/splendor/main.py."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
def test_splendor_main_import() -> None:
"""Test that splendor main module can be imported."""
from python.splendor.main import main
assert callable(main)
def test_splendor_main_calls_run_game() -> None:
"""Test main creates human + bot and runs game."""
# main() uses wrong signature for new_game (passes strings instead of strategies)
# so we just verify it can be called with mocked internals
with (
patch("python.splendor.main.TuiHuman") as mock_tui,
patch("python.splendor.main.RandomBot") as mock_bot,
patch("python.splendor.main.new_game") as mock_new_game,
patch("python.splendor.main.run_game") as mock_run_game,
):
mock_tui.return_value = MagicMock()
mock_bot.return_value = MagicMock()
mock_new_game.return_value = MagicMock()
mock_run_game.return_value = (MagicMock(), 10)
from python.splendor.main import main
main()
mock_new_game.assert_called_once()
mock_run_game.assert_called_once()

View File

@@ -1,88 +0,0 @@
"""Tests for python/splendor/simulat.py."""
from __future__ import annotations
import json
import random
from pathlib import Path
from unittest.mock import patch
from python.splendor.base import load_cards, load_nobles
from python.splendor.simulat import main
def test_simulat_main(tmp_path: Path) -> None:
"""Test simulat main function with mock game data."""
random.seed(42)
# Create temporary game data
cards_dir = tmp_path / "game_data" / "cards"
nobles_dir = tmp_path / "game_data" / "nobles"
cards_dir.mkdir(parents=True)
nobles_dir.mkdir(parents=True)
cards = []
for tier in (1, 2, 3):
for color in ("white", "blue", "green", "red", "black"):
cards.append({
"tier": tier,
"points": tier,
"color": color,
"cost": {"white": tier, "blue": 0, "green": 0, "red": 0, "black": 0, "gold": 0},
})
(cards_dir / "default.json").write_text(json.dumps(cards))
nobles = [
{"name": f"Noble {i}", "points": 3, "requirements": {"white": 3, "blue": 3, "green": 3}}
for i in range(5)
]
(nobles_dir / "default.json").write_text(json.dumps(nobles))
# Patch Path(__file__).parent to point to tmp_path
fake_parent = tmp_path
with patch("python.splendor.simulat.Path") as mock_path_cls:
mock_path_cls.return_value.__truediv__ = Path.__truediv__
mock_file = mock_path_cls().__truediv__("simulat.py")
# Make Path(__file__).parent return tmp_path
mock_path_cls.reset_mock()
mock_path_instance = mock_path_cls.return_value
mock_path_instance.parent = fake_parent
# Actually just patch load_cards and load_nobles
cards_data = load_cards(cards_dir / "default.json")
nobles_data = load_nobles(nobles_dir / "default.json")
with (
patch("python.splendor.simulat.load_cards", return_value=cards_data),
patch("python.splendor.simulat.load_nobles", return_value=nobles_data),
):
main()
def test_load_cards_and_nobles(tmp_path: Path) -> None:
"""Test that load_cards and load_nobles work correctly."""
cards_dir = tmp_path / "cards"
cards_dir.mkdir()
cards = [
{
"tier": 1,
"points": 0,
"color": "white",
"cost": {"white": 1, "blue": 0, "green": 0, "red": 0, "black": 0, "gold": 0},
}
]
cards_file = cards_dir / "default.json"
cards_file.write_text(json.dumps(cards))
loaded = load_cards(cards_file)
assert len(loaded) == 1
assert loaded[0].color == "white"
nobles_dir = tmp_path / "nobles"
nobles_dir.mkdir()
nobles = [{"name": "Noble A", "points": 3, "requirements": {"white": 3}}]
nobles_file = nobles_dir / "default.json"
nobles_file.write_text(json.dumps(nobles))
loaded_nobles = load_nobles(nobles_file)
assert len(loaded_nobles) == 1
assert loaded_nobles[0].name == "Noble A"

View File

@@ -1,162 +0,0 @@
"""Tests for python/stuff modules."""
from __future__ import annotations
from python.stuff.capasitor import (
calculate_capacitor_capacity,
calculate_pack_capacity,
calculate_pack_capacity2,
)
from python.stuff.voltage_drop import (
Length,
LengthUnit,
MaterialType,
Temperature,
TemperatureUnit,
calculate_awg_diameter_mm,
calculate_resistance_per_meter,
calculate_wire_area_m2,
get_material_resistivity,
max_wire_length,
voltage_drop,
)
# --- capasitor tests ---
def test_calculate_capacitor_capacity() -> None:
"""Test capacitor capacity calculation."""
result = calculate_capacitor_capacity(voltage=2.7, farads=500)
assert isinstance(result, float)
def test_calculate_pack_capacity() -> None:
"""Test pack capacity calculation."""
result = calculate_pack_capacity(cells=10, cell_voltage=2.7, farads=500)
assert isinstance(result, float)
def test_calculate_pack_capacity2() -> None:
"""Test pack capacity2 calculation returns capacity and cost."""
capacity, cost = calculate_pack_capacity2(cells=10, cell_voltage=2.7, farads=3000, cell_cost=11.60)
assert isinstance(capacity, float)
assert cost == 11.60 * 10
# --- voltage_drop tests ---
def test_temperature_celsius() -> None:
"""Test Temperature with celsius."""
t = Temperature(20.0, TemperatureUnit.CELSIUS)
assert float(t) == 20.0
def test_temperature_fahrenheit() -> None:
"""Test Temperature with fahrenheit."""
t = Temperature(100.0, TemperatureUnit.FAHRENHEIT)
assert isinstance(float(t), float)
def test_temperature_kelvin() -> None:
"""Test Temperature with kelvin."""
t = Temperature(300.0, TemperatureUnit.KELVIN)
assert isinstance(float(t), float)
def test_temperature_default_unit() -> None:
"""Test Temperature defaults to celsius."""
t = Temperature(25.0)
assert float(t) == 25.0
def test_length_meters() -> None:
"""Test Length in meters."""
length = Length(10.0, LengthUnit.METERS)
assert float(length) == 10.0
def test_length_feet() -> None:
"""Test Length in feet."""
length = Length(10.0, LengthUnit.FEET)
assert abs(float(length) - 3.048) < 0.001
def test_length_inches() -> None:
"""Test Length in inches."""
length = Length(100.0, LengthUnit.INCHES)
assert abs(float(length) - 2.54) < 0.001
def test_length_feet_method() -> None:
"""Test Length.feet() conversion."""
length = Length(1.0, LengthUnit.METERS)
assert abs(length.feet() - 3.2808) < 0.001
def test_get_material_resistivity_default_temp() -> None:
"""Test material resistivity with default temperature."""
r = get_material_resistivity(MaterialType.COPPER)
assert r > 0
def test_get_material_resistivity_with_temp() -> None:
"""Test material resistivity with explicit temperature."""
r = get_material_resistivity(MaterialType.ALUMINUM, Temperature(50.0))
assert r > 0
def test_get_material_resistivity_all_materials() -> None:
"""Test resistivity for all materials."""
for material in MaterialType:
r = get_material_resistivity(material)
assert r > 0
def test_calculate_awg_diameter_mm() -> None:
"""Test AWG diameter calculation."""
d = calculate_awg_diameter_mm(10)
assert d > 0
def test_calculate_wire_area_m2() -> None:
"""Test wire area calculation."""
area = calculate_wire_area_m2(10)
assert area > 0
def test_calculate_resistance_per_meter() -> None:
"""Test resistance per meter calculation."""
r = calculate_resistance_per_meter(10)
assert r > 0
def test_voltage_drop_calculation() -> None:
"""Test voltage drop calculation."""
vd = voltage_drop(
gauge=10,
material=MaterialType.CCA,
length=Length(20, LengthUnit.FEET),
current_a=20,
)
assert vd > 0
def test_max_wire_length_default_temp() -> None:
"""Test max wire length with default temperature."""
result = max_wire_length(gauge=10, material=MaterialType.CCA, current_amps=20)
assert float(result) > 0
assert result.feet() > 0
def test_max_wire_length_with_temp() -> None:
"""Test max wire length with explicit temperature."""
result = max_wire_length(
gauge=10,
material=MaterialType.COPPER,
current_amps=10,
voltage_drop=0.5,
temperature=Temperature(30.0),
)
assert float(result) > 0

View File

@@ -1,10 +0,0 @@
"""Tests for capasitor main function."""
from __future__ import annotations
from python.stuff.capasitor import main
def test_capasitor_main(capsys: object) -> None:
"""Test capasitor main function runs."""
main()

View File

@@ -1,17 +0,0 @@
"""Tests for python/stuff/thing.py."""
from __future__ import annotations
from python.stuff.thing import caculat_batry_specs
def test_caculat_batry_specs() -> None:
"""Test battery specs calculation."""
capacity, voltage = caculat_batry_specs(
cell_amp_hour=300,
cell_voltage=3.2,
cells_per_pack=8,
packs=2,
)
assert voltage == 3.2 * 8
assert capacity == voltage * 300 * 2

View File

@@ -1,38 +0,0 @@
"""Tests for python/testing/logging modules."""
from __future__ import annotations
from python.testing.logging.bar import bar
from python.testing.logging.configure_logger import configure_logger
from python.testing.logging.foo import foo
from python.testing.logging.main import main
def test_bar() -> None:
"""Test bar function."""
bar()
def test_configure_logger_default() -> None:
"""Test configure_logger with default level."""
configure_logger()
def test_configure_logger_debug() -> None:
"""Test configure_logger with debug level."""
configure_logger("DEBUG")
def test_configure_logger_with_test() -> None:
"""Test configure_logger with test name."""
configure_logger("INFO", "TEST")
def test_foo() -> None:
"""Test foo function."""
foo()
def test_main() -> None:
"""Test main function."""
main()

View File

@@ -1,265 +0,0 @@
"""Tests for python/van_weather modules."""
from __future__ import annotations
from datetime import UTC, datetime
from typing import TYPE_CHECKING
from unittest.mock import MagicMock, patch
from python.van_weather.models import Config, DailyForecast, HourlyForecast, Weather
from python.van_weather.main import (
CONDITION_MAP,
fetch_weather,
get_ha_state,
parse_daily_forecast,
parse_hourly_forecast,
post_to_ha,
update_weather,
_post_weather_data,
)
if TYPE_CHECKING:
pass
# --- models tests ---
def test_config() -> None:
"""Test Config creation."""
config = Config(ha_url="http://ha.local", ha_token="token123", pirate_weather_api_key="key123")
assert config.ha_url == "http://ha.local"
assert config.lat_entity == "sensor.gps_latitude"
assert config.lon_entity == "sensor.gps_longitude"
assert config.mask_decimals == 1
def test_daily_forecast() -> None:
"""Test DailyForecast creation and serialization."""
dt = datetime(2024, 1, 1, tzinfo=UTC)
forecast = DailyForecast(
date_time=dt,
condition="sunny",
temperature=75.0,
templow=55.0,
precipitation_probability=0.1,
)
assert forecast.condition == "sunny"
serialized = forecast.model_dump()
assert serialized["date_time"] == dt.isoformat()
def test_hourly_forecast() -> None:
"""Test HourlyForecast creation and serialization."""
dt = datetime(2024, 1, 1, 12, 0, tzinfo=UTC)
forecast = HourlyForecast(
date_time=dt,
condition="cloudy",
temperature=65.0,
precipitation_probability=0.3,
)
assert forecast.temperature == 65.0
serialized = forecast.model_dump()
assert serialized["date_time"] == dt.isoformat()
def test_weather_defaults() -> None:
"""Test Weather default values."""
weather = Weather()
assert weather.temperature is None
assert weather.daily_forecasts == []
assert weather.hourly_forecasts == []
def test_weather_full() -> None:
"""Test Weather with all fields."""
weather = Weather(
temperature=72.0,
feels_like=70.0,
humidity=0.5,
wind_speed=10.0,
wind_bearing=180.0,
condition="sunny",
summary="Clear",
pressure=1013.0,
visibility=10.0,
)
assert weather.temperature == 72.0
assert weather.condition == "sunny"
# --- main tests ---
def test_condition_map() -> None:
"""Test CONDITION_MAP has expected entries."""
assert CONDITION_MAP["clear-day"] == "sunny"
assert CONDITION_MAP["rain"] == "rainy"
assert CONDITION_MAP["snow"] == "snowy"
def test_get_ha_state() -> None:
"""Test get_ha_state."""
mock_response = MagicMock()
mock_response.json.return_value = {"state": "45.123"}
mock_response.raise_for_status.return_value = None
with patch("python.van_weather.main.requests.get", return_value=mock_response) as mock_get:
result = get_ha_state("http://ha.local", "token", "sensor.lat")
assert result == 45.123
mock_get.assert_called_once()
def test_parse_daily_forecast() -> None:
"""Test parse_daily_forecast."""
data = {
"daily": {
"data": [
{
"time": 1704067200,
"icon": "clear-day",
"temperatureHigh": 75.0,
"temperatureLow": 55.0,
"precipProbability": 0.1,
},
{
"time": 1704153600,
"icon": "rain",
},
]
}
}
result = parse_daily_forecast(data)
assert len(result) == 2
assert result[0].condition == "sunny"
assert result[0].temperature == 75.0
def test_parse_daily_forecast_empty() -> None:
"""Test parse_daily_forecast with empty data."""
result = parse_daily_forecast({})
assert result == []
def test_parse_daily_forecast_no_timestamp() -> None:
"""Test parse_daily_forecast skips entries without time."""
data = {"daily": {"data": [{"icon": "clear-day"}]}}
result = parse_daily_forecast(data)
assert result == []
def test_parse_hourly_forecast() -> None:
"""Test parse_hourly_forecast."""
data = {
"hourly": {
"data": [
{
"time": 1704067200,
"icon": "cloudy",
"temperature": 65.0,
"precipProbability": 0.3,
},
]
}
}
result = parse_hourly_forecast(data)
assert len(result) == 1
assert result[0].condition == "cloudy"
def test_parse_hourly_forecast_empty() -> None:
"""Test parse_hourly_forecast with empty data."""
result = parse_hourly_forecast({})
assert result == []
def test_parse_hourly_forecast_no_timestamp() -> None:
"""Test parse_hourly_forecast skips entries without time."""
data = {"hourly": {"data": [{"icon": "rain"}]}}
result = parse_hourly_forecast(data)
assert result == []
def test_fetch_weather() -> None:
"""Test fetch_weather."""
mock_response = MagicMock()
mock_response.json.return_value = {
"currently": {
"temperature": 72.0,
"apparentTemperature": 70.0,
"humidity": 0.5,
"windSpeed": 10.0,
"windBearing": 180.0,
"icon": "clear-day",
"summary": "Clear",
"pressure": 1013.0,
"visibility": 10.0,
},
"daily": {"data": []},
"hourly": {"data": []},
}
mock_response.raise_for_status.return_value = None
with patch("python.van_weather.main.requests.get", return_value=mock_response):
weather = fetch_weather("apikey", 45.0, -122.0)
assert weather.temperature == 72.0
assert weather.condition == "sunny"
def test_post_weather_data() -> None:
"""Test _post_weather_data."""
weather = Weather(
temperature=72.0,
feels_like=70.0,
humidity=0.5,
wind_speed=10.0,
wind_bearing=180.0,
condition="sunny",
pressure=1013.0,
visibility=10.0,
)
mock_response = MagicMock()
mock_response.raise_for_status.return_value = None
with patch("python.van_weather.main.requests.post", return_value=mock_response) as mock_post:
_post_weather_data("http://ha.local", "token", weather)
assert mock_post.call_count > 0
def test_post_to_ha_retry_on_failure() -> None:
"""Test post_to_ha retries on failure."""
weather = Weather(temperature=72.0)
import requests
with (
patch("python.van_weather.main._post_weather_data", side_effect=requests.RequestException("fail")),
patch("python.van_weather.main.time.sleep"),
):
post_to_ha("http://ha.local", "token", weather)
def test_post_to_ha_success() -> None:
"""Test post_to_ha calls _post_weather_data on each attempt."""
weather = Weather(temperature=72.0)
with patch("python.van_weather.main._post_weather_data") as mock_post:
post_to_ha("http://ha.local", "token", weather)
# The function loops through all attempts even on success (no break)
assert mock_post.call_count == 6
def test_update_weather() -> None:
"""Test update_weather orchestration."""
config = Config(ha_url="http://ha.local", ha_token="token", pirate_weather_api_key="key")
with (
patch("python.van_weather.main.get_ha_state", side_effect=[45.123, -122.456]),
patch("python.van_weather.main.fetch_weather", return_value=Weather(temperature=72.0, condition="sunny")),
patch("python.van_weather.main.post_to_ha"),
):
update_weather(config)

View File

@@ -1,28 +0,0 @@
"""Tests for van_weather/main.py main() function."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
from python.van_weather.main import main
def test_van_weather_main() -> None:
"""Test main sets up scheduler."""
with (
patch("python.van_weather.main.BlockingScheduler") as mock_sched_cls,
patch("python.van_weather.main.configure_logger"),
):
mock_sched = MagicMock()
mock_sched_cls.return_value = mock_sched
main(
ha_url="http://ha.local",
ha_token="token",
api_key="key",
interval=60,
log_level="INFO",
)
mock_sched.add_job.assert_called_once()
mock_sched.start.assert_called_once()

View File

@@ -1,361 +0,0 @@
"""Tests for python/zfs/dataset.py covering missing lines."""
from __future__ import annotations
import json
from unittest.mock import patch
import pytest
from python.zfs.dataset import Dataset, Snapshot, _zfs_list
DATASET = "python.zfs.dataset"
SAMPLE_SNAPSHOT_DATA = {
"createtxg": "123",
"properties": {
"creation": {"value": "1620000000"},
"defer_destroy": {"value": "off"},
"guid": {"value": "456"},
"objsetid": {"value": "789"},
"referenced": {"value": "1024"},
"used": {"value": "512"},
"userrefs": {"value": "0"},
"version": {"value": "1"},
"written": {"value": "2048"},
},
"name": "pool/dataset@snap1",
}
SAMPLE_DATASET_DATA = {
"output_version": {"vers_major": 0, "vers_minor": 1, "command": "zfs list"},
"datasets": {
"pool/dataset": {
"properties": {
"aclinherit": {"value": "restricted"},
"aclmode": {"value": "discard"},
"acltype": {"value": "off"},
"available": {"value": "1000000"},
"canmount": {"value": "on"},
"checksum": {"value": "on"},
"clones": {"value": ""},
"compression": {"value": "lz4"},
"copies": {"value": "1"},
"createtxg": {"value": "1234"},
"creation": {"value": "1620000000"},
"dedup": {"value": "off"},
"devices": {"value": "on"},
"encryption": {"value": "off"},
"exec": {"value": "on"},
"filesystem_limit": {"value": "none"},
"guid": {"value": "5678"},
"keystatus": {"value": "none"},
"logbias": {"value": "latency"},
"mlslabel": {"value": "none"},
"mounted": {"value": "yes"},
"mountpoint": {"value": "/pool/dataset"},
"quota": {"value": "0"},
"readonly": {"value": "off"},
"recordsize": {"value": "131072"},
"redundant_metadata": {"value": "all"},
"referenced": {"value": "512000"},
"refquota": {"value": "0"},
"refreservation": {"value": "0"},
"reservation": {"value": "0"},
"setuid": {"value": "on"},
"sharenfs": {"value": "off"},
"snapdir": {"value": "hidden"},
"snapshot_limit": {"value": "none"},
"sync": {"value": "standard"},
"used": {"value": "1024000"},
"usedbychildren": {"value": "512000"},
"usedbydataset": {"value": "256000"},
"usedbysnapshots": {"value": "256000"},
"version": {"value": "5"},
"volmode": {"value": "default"},
"volsize": {"value": "none"},
"vscan": {"value": "off"},
"written": {"value": "4096"},
"xattr": {"value": "on"},
}
}
},
}
def _make_dataset() -> Dataset:
"""Create a Dataset instance with mocked _zfs_list."""
with patch(f"{DATASET}._zfs_list", return_value=SAMPLE_DATASET_DATA):
return Dataset("pool/dataset")
# --- _zfs_list version check error (line 29) ---
def test_zfs_list_returns_data_on_valid_version() -> None:
"""Test _zfs_list returns parsed data when version is correct."""
valid_data = {
"output_version": {"vers_major": 0, "vers_minor": 1, "command": "zfs list"},
"datasets": {},
}
with patch(f"{DATASET}.bash_wrapper", return_value=(json.dumps(valid_data), 0)):
result = _zfs_list("zfs list pool -pHj -o all")
assert result == valid_data
def test_zfs_list_raises_on_wrong_vers_minor() -> None:
"""Test _zfs_list raises RuntimeError when vers_minor is wrong."""
bad_data = {
"output_version": {"vers_major": 0, "vers_minor": 2, "command": "zfs list"},
}
with (
patch(f"{DATASET}.bash_wrapper", return_value=(json.dumps(bad_data), 0)),
pytest.raises(RuntimeError, match="Datasets are not in the correct format"),
):
_zfs_list("zfs list pool -pHj -o all")
def test_zfs_list_raises_on_wrong_command() -> None:
"""Test _zfs_list raises RuntimeError when command field is wrong."""
bad_data = {
"output_version": {"vers_major": 0, "vers_minor": 1, "command": "zpool list"},
}
with (
patch(f"{DATASET}.bash_wrapper", return_value=(json.dumps(bad_data), 0)),
pytest.raises(RuntimeError, match="Datasets are not in the correct format"),
):
_zfs_list("zfs list pool -pHj -o all")
# --- Snapshot.__repr__() (line 52) ---
def test_snapshot_repr() -> None:
"""Test Snapshot __repr__ returns correct format."""
snapshot = Snapshot(SAMPLE_SNAPSHOT_DATA)
result = repr(snapshot)
assert result == "name=snap1 used=512 refer=1024"
def test_snapshot_repr_different_values() -> None:
"""Test Snapshot __repr__ with different values."""
data = {
**SAMPLE_SNAPSHOT_DATA,
"name": "pool/dataset@daily-2024-01-01",
"properties": {
**SAMPLE_SNAPSHOT_DATA["properties"],
"used": {"value": "999"},
"referenced": {"value": "5555"},
},
}
snapshot = Snapshot(data)
assert "daily-2024-01-01" in repr(snapshot)
assert "999" in repr(snapshot)
assert "5555" in repr(snapshot)
# --- Dataset.get_snapshots() (lines 113-115) ---
def test_dataset_get_snapshots() -> None:
"""Test Dataset.get_snapshots returns list of Snapshot objects."""
dataset = _make_dataset()
snapshot_list_data = {
"output_version": {"vers_major": 0, "vers_minor": 1, "command": "zfs list"},
"datasets": {
"pool/dataset@snap1": SAMPLE_SNAPSHOT_DATA,
"pool/dataset@snap2": {
**SAMPLE_SNAPSHOT_DATA,
"name": "pool/dataset@snap2",
},
},
}
with patch(f"{DATASET}._zfs_list", return_value=snapshot_list_data):
snapshots = dataset.get_snapshots()
assert snapshots is not None
assert len(snapshots) == 2
assert all(isinstance(s, Snapshot) for s in snapshots)
def test_dataset_get_snapshots_empty() -> None:
"""Test Dataset.get_snapshots returns empty list when no snapshots."""
dataset = _make_dataset()
snapshot_list_data = {
"output_version": {"vers_major": 0, "vers_minor": 1, "command": "zfs list"},
"datasets": {},
}
with patch(f"{DATASET}._zfs_list", return_value=snapshot_list_data):
snapshots = dataset.get_snapshots()
assert snapshots == []
# --- Dataset.create_snapshot() (lines 123-133) ---
def test_dataset_create_snapshot_success() -> None:
"""Test create_snapshot returns success message when return code is 0."""
dataset = _make_dataset()
with patch(f"{DATASET}.bash_wrapper", return_value=("", 0)):
result = dataset.create_snapshot("my-snap")
assert result == "snapshot created"
def test_dataset_create_snapshot_already_exists() -> None:
"""Test create_snapshot returns message when snapshot already exists."""
dataset = _make_dataset()
snapshot_list_data = {
"output_version": {"vers_major": 0, "vers_minor": 1, "command": "zfs list"},
"datasets": {
"pool/dataset@my-snap": SAMPLE_SNAPSHOT_DATA,
},
}
with (
patch(f"{DATASET}.bash_wrapper", return_value=("dataset already exists", 1)),
patch(f"{DATASET}._zfs_list", return_value=snapshot_list_data),
):
# The snapshot data has name "pool/dataset@snap1" which extracts to "snap1"
# We need the snapshot name to match, so use "snap1"
result = dataset.create_snapshot("snap1")
assert "already exists" in result
def test_dataset_create_snapshot_failure() -> None:
"""Test create_snapshot returns failure message on unknown error."""
dataset = _make_dataset()
snapshot_list_data = {
"output_version": {"vers_major": 0, "vers_minor": 1, "command": "zfs list"},
"datasets": {},
}
with (
patch(f"{DATASET}.bash_wrapper", return_value=("some error", 1)),
patch(f"{DATASET}._zfs_list", return_value=snapshot_list_data),
):
result = dataset.create_snapshot("new-snap")
assert "Failed to create snapshot" in result
def test_dataset_create_snapshot_failure_no_snapshots() -> None:
"""Test create_snapshot failure when get_snapshots returns empty list."""
dataset = _make_dataset()
# get_snapshots returns empty list (falsy), so the if branch is skipped
snapshot_list_data = {
"output_version": {"vers_major": 0, "vers_minor": 1, "command": "zfs list"},
"datasets": {},
}
with (
patch(f"{DATASET}.bash_wrapper", return_value=("error", 1)),
patch(f"{DATASET}._zfs_list", return_value=snapshot_list_data),
):
result = dataset.create_snapshot("nonexistent")
assert "Failed to create snapshot" in result
# --- Dataset.delete_snapshot() (lines 141-148) ---
def test_dataset_delete_snapshot_success() -> None:
"""Test delete_snapshot returns None on success."""
dataset = _make_dataset()
with patch(f"{DATASET}.bash_wrapper", return_value=("", 0)):
result = dataset.delete_snapshot("my-snap")
assert result is None
def test_dataset_delete_snapshot_dependent_clones() -> None:
"""Test delete_snapshot returns message when snapshot has dependent clones."""
dataset = _make_dataset()
error_msg = "cannot destroy 'pool/dataset@my-snap': snapshot has dependent clones"
with patch(f"{DATASET}.bash_wrapper", return_value=(error_msg, 1)):
result = dataset.delete_snapshot("my-snap")
assert result == "snapshot has dependent clones"
def test_dataset_delete_snapshot_other_failure() -> None:
"""Test delete_snapshot raises RuntimeError on other failures."""
dataset = _make_dataset()
with (
patch(f"{DATASET}.bash_wrapper", return_value=("some other error", 1)),
pytest.raises(RuntimeError, match="Failed to delete snapshot"),
):
dataset.delete_snapshot("my-snap")
# --- Dataset.__repr__() (line 152) ---
def test_dataset_repr() -> None:
"""Test Dataset __repr__ includes all attributes."""
dataset = _make_dataset()
result = repr(dataset)
expected_attrs = [
"aclinherit",
"aclmode",
"acltype",
"available",
"canmount",
"checksum",
"clones",
"compression",
"copies",
"createtxg",
"creation",
"dedup",
"devices",
"encryption",
"exec",
"filesystem_limit",
"guid",
"keystatus",
"logbias",
"mlslabel",
"mounted",
"mountpoint",
"name",
"quota",
"readonly",
"recordsize",
"redundant_metadata",
"referenced",
"refquota",
"refreservation",
"reservation",
"setuid",
"sharenfs",
"snapdir",
"snapshot_limit",
"sync",
"used",
"usedbychildren",
"usedbydataset",
"usedbysnapshots",
"version",
"volmode",
"volsize",
"vscan",
"written",
"xattr",
]
for attr in expected_attrs:
assert f"self.{attr}=" in result, f"Missing {attr} in repr"