mirror of
https://github.com/RichieCahill/dotfiles.git
synced 2026-04-21 06:39:09 -04:00
Compare commits
12 Commits
claude/imp
...
feature/ad
| Author | SHA1 | Date | |
|---|---|---|---|
| 65c4f1d23e | |||
| 75a67294ea | |||
| 58b25f2e89 | |||
| 568bf8dd38 | |||
| 82851eb287 | |||
| b7bce0bcb9 | |||
| 583af965ad | |||
| ec80bf1c5f | |||
| bd490334f5 | |||
| e893ea0f57 | |||
| 18f149b831 | |||
| 69f5b87e5f |
@@ -24,6 +24,7 @@
|
||||
fastapi
|
||||
fastapi-cli
|
||||
httpx
|
||||
python-multipart
|
||||
mypy
|
||||
polars
|
||||
psycopg
|
||||
|
||||
@@ -7,7 +7,25 @@ requires-python = "~=3.13.0"
|
||||
readme = "README.md"
|
||||
license = "MIT"
|
||||
# these dependencies are a best effort and aren't guaranteed to work
|
||||
dependencies = ["apprise", "apscheduler", "httpx", "polars", "pydantic", "pyyaml", "requests", "typer"]
|
||||
# for up-to-date dependencies, see overlays/default.nix
|
||||
dependencies = [
|
||||
"alembic",
|
||||
"apprise",
|
||||
"apscheduler",
|
||||
"httpx",
|
||||
"python-multipart",
|
||||
"polars",
|
||||
"psycopg[binary]",
|
||||
"pydantic",
|
||||
"pyyaml",
|
||||
"requests",
|
||||
"sqlalchemy",
|
||||
"typer",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
database = "python.database_cli:app"
|
||||
van-inventory = "python.van_inventory.main:serve"
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
|
||||
@@ -1,109 +0,0 @@
|
||||
# A generic, single database configuration.
|
||||
|
||||
[alembic]
|
||||
# path to migration scripts
|
||||
# Use forward slashes (/) also on windows to provide an os agnostic path
|
||||
script_location = python/alembic
|
||||
|
||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||
# Uncomment the line below if you want the files to be prepended with date and time
|
||||
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
|
||||
# for all available tokens
|
||||
file_template = %%(year)d_%%(month).2d_%%(day).2d-%%(slug)s_%%(rev)s
|
||||
|
||||
# sys.path path, will be prepended to sys.path if present.
|
||||
# defaults to the current working directory.
|
||||
prepend_sys_path = .
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
# as well as the filename.
|
||||
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
|
||||
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
|
||||
# string value is passed to ZoneInfo()
|
||||
# leave blank for localtime
|
||||
# timezone =
|
||||
|
||||
# max length of characters to apply to the "slug" field
|
||||
# truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during
|
||||
# the 'revision' command, regardless of autogenerate
|
||||
# revision_environment = false
|
||||
|
||||
# set to 'true' to allow .pyc and .pyo files without
|
||||
# a source .py file to be detected as revisions in the
|
||||
# versions/ directory
|
||||
# sourceless = false
|
||||
|
||||
# version location specification; This defaults
|
||||
# to alembic/versions. When using multiple version
|
||||
# directories, initial revisions must be specified with --version-path.
|
||||
# The path separator used here should be the separator specified by "version_path_separator" below.
|
||||
# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions
|
||||
|
||||
# version path separator; As mentioned above, this is the character used to split
|
||||
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
|
||||
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
|
||||
# Valid values for version_path_separator are:
|
||||
#
|
||||
# version_path_separator = :
|
||||
# version_path_separator = ;
|
||||
# version_path_separator = space
|
||||
# version_path_separator = newline
|
||||
#
|
||||
# Use os.pathsep. Default configuration used for new projects.
|
||||
version_path_separator = os
|
||||
|
||||
# set to 'true' to search source files recursively
|
||||
# in each "version_locations" directory
|
||||
# new in Alembic version 1.10
|
||||
# recursive_version_locations = false
|
||||
|
||||
# the output encoding used when revision files
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||
|
||||
revision_environment = true
|
||||
|
||||
[post_write_hooks]
|
||||
|
||||
hooks = dynamic_schema,ruff
|
||||
dynamic_schema.type = dynamic_schema
|
||||
|
||||
ruff.type = ruff
|
||||
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARNING
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARNING
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
@@ -9,20 +9,24 @@ from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from alembic import context
|
||||
from alembic.script import write_hooks
|
||||
from sqlalchemy.schema import CreateSchema
|
||||
|
||||
from python.common import bash_wrapper
|
||||
from python.orm import RichieBase
|
||||
from python.orm.base import get_postgres_engine
|
||||
from python.orm.common import get_postgres_engine
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import MutableMapping
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
config = context.config
|
||||
|
||||
base_class: type[DeclarativeBase] = config.attributes.get("base")
|
||||
if base_class is None:
|
||||
error = "No base class provided. Use the database CLI to run alembic commands."
|
||||
raise RuntimeError(error)
|
||||
|
||||
target_metadata = RichieBase.metadata
|
||||
target_metadata = base_class.metadata
|
||||
logging.basicConfig(
|
||||
level="DEBUG",
|
||||
datefmt="%Y-%m-%dT%H:%M:%S%z",
|
||||
@@ -35,8 +39,9 @@ logging.basicConfig(
|
||||
def dynamic_schema(filename: str, _options: dict[Any, Any]) -> None:
|
||||
"""Dynamic schema."""
|
||||
original_file = Path(filename).read_text()
|
||||
dynamic_schema_file_part1 = original_file.replace(f"schema='{RichieBase.schema_name}'", "schema=schema")
|
||||
dynamic_schema_file = dynamic_schema_file_part1.replace(f"'{RichieBase.schema_name}.", "f'{schema}.")
|
||||
schema_name = base_class.schema_name
|
||||
dynamic_schema_file_part1 = original_file.replace(f"schema='{schema_name}'", "schema=schema")
|
||||
dynamic_schema_file = dynamic_schema_file_part1.replace(f"'{schema_name}.", "f'{schema}.")
|
||||
Path(filename).write_text(dynamic_schema_file)
|
||||
|
||||
|
||||
@@ -52,12 +57,12 @@ def include_name(
|
||||
type_: Literal["schema", "table", "column", "index", "unique_constraint", "foreign_key_constraint"],
|
||||
_parent_names: MutableMapping[Literal["schema_name", "table_name", "schema_qualified_table_name"], str | None],
|
||||
) -> bool:
|
||||
"""This filter table to be included in the migration.
|
||||
"""Filter tables to be included in the migration.
|
||||
|
||||
Args:
|
||||
name (str): The name of the table.
|
||||
type_ (str): The type of the table.
|
||||
parent_names (list[str]): The names of the parent tables.
|
||||
_parent_names (MutableMapping): The names of the parent tables.
|
||||
|
||||
Returns:
|
||||
bool: True if the table should be included, False otherwise.
|
||||
@@ -75,19 +80,30 @@ def run_migrations_online() -> None:
|
||||
and associate a connection with the context.
|
||||
|
||||
"""
|
||||
connectable = get_postgres_engine()
|
||||
env_prefix = config.attributes.get("env_prefix", "POSTGRES")
|
||||
connectable = get_postgres_engine(name=env_prefix)
|
||||
|
||||
with connectable.connect() as connection:
|
||||
schema = base_class.schema_name
|
||||
if not connectable.dialect.has_schema(connection, schema):
|
||||
answer = input(f"Schema {schema!r} does not exist. Create it? [y/N] ")
|
||||
if answer.lower() != "y":
|
||||
error = f"Schema {schema!r} does not exist. Exiting."
|
||||
raise SystemExit(error)
|
||||
connection.execute(CreateSchema(schema))
|
||||
connection.commit()
|
||||
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata,
|
||||
include_schemas=True,
|
||||
version_table_schema=RichieBase.schema_name,
|
||||
version_table_schema=schema,
|
||||
include_name=include_name,
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
connection.commit()
|
||||
|
||||
|
||||
run_migrations_online()
|
||||
|
||||
@@ -0,0 +1,135 @@
|
||||
"""add congress tracker tables.
|
||||
|
||||
Revision ID: 3f71565e38de
|
||||
Revises: edd7dd61a3d2
|
||||
Create Date: 2026-02-12 16:36:09.457303
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
from python.orm import RichieBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "3f71565e38de"
|
||||
down_revision: str | None = "edd7dd61a3d2"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
schema = RichieBase.schema_name
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"bill",
|
||||
sa.Column("congress", sa.Integer(), nullable=False),
|
||||
sa.Column("bill_type", sa.String(), nullable=False),
|
||||
sa.Column("number", sa.Integer(), nullable=False),
|
||||
sa.Column("title", sa.String(), nullable=True),
|
||||
sa.Column("title_short", sa.String(), nullable=True),
|
||||
sa.Column("official_title", sa.String(), nullable=True),
|
||||
sa.Column("status", sa.String(), nullable=True),
|
||||
sa.Column("status_at", sa.Date(), nullable=True),
|
||||
sa.Column("sponsor_bioguide_id", sa.String(), nullable=True),
|
||||
sa.Column("subjects_top_term", sa.String(), nullable=True),
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("pk_bill")),
|
||||
sa.UniqueConstraint("congress", "bill_type", "number", name="uq_bill_congress_type_number"),
|
||||
schema=schema,
|
||||
)
|
||||
op.create_index("ix_bill_congress", "bill", ["congress"], unique=False, schema=schema)
|
||||
op.create_table(
|
||||
"legislator",
|
||||
sa.Column("bioguide_id", sa.Text(), nullable=False),
|
||||
sa.Column("thomas_id", sa.String(), nullable=True),
|
||||
sa.Column("lis_id", sa.String(), nullable=True),
|
||||
sa.Column("govtrack_id", sa.Integer(), nullable=True),
|
||||
sa.Column("opensecrets_id", sa.String(), nullable=True),
|
||||
sa.Column("fec_ids", sa.String(), nullable=True),
|
||||
sa.Column("first_name", sa.String(), nullable=False),
|
||||
sa.Column("last_name", sa.String(), nullable=False),
|
||||
sa.Column("official_full_name", sa.String(), nullable=True),
|
||||
sa.Column("nickname", sa.String(), nullable=True),
|
||||
sa.Column("birthday", sa.Date(), nullable=True),
|
||||
sa.Column("gender", sa.String(), nullable=True),
|
||||
sa.Column("current_party", sa.String(), nullable=True),
|
||||
sa.Column("current_state", sa.String(), nullable=True),
|
||||
sa.Column("current_district", sa.Integer(), nullable=True),
|
||||
sa.Column("current_chamber", sa.String(), nullable=True),
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("pk_legislator")),
|
||||
schema=schema,
|
||||
)
|
||||
op.create_index(op.f("ix_legislator_bioguide_id"), "legislator", ["bioguide_id"], unique=True, schema=schema)
|
||||
op.create_table(
|
||||
"vote",
|
||||
sa.Column("congress", sa.Integer(), nullable=False),
|
||||
sa.Column("chamber", sa.String(), nullable=False),
|
||||
sa.Column("session", sa.Integer(), nullable=False),
|
||||
sa.Column("number", sa.Integer(), nullable=False),
|
||||
sa.Column("vote_type", sa.String(), nullable=True),
|
||||
sa.Column("question", sa.String(), nullable=True),
|
||||
sa.Column("result", sa.String(), nullable=True),
|
||||
sa.Column("result_text", sa.String(), nullable=True),
|
||||
sa.Column("vote_date", sa.Date(), nullable=False),
|
||||
sa.Column("yea_count", sa.Integer(), nullable=True),
|
||||
sa.Column("nay_count", sa.Integer(), nullable=True),
|
||||
sa.Column("not_voting_count", sa.Integer(), nullable=True),
|
||||
sa.Column("present_count", sa.Integer(), nullable=True),
|
||||
sa.Column("bill_id", sa.Integer(), nullable=True),
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.ForeignKeyConstraint(["bill_id"], [f"{schema}.bill.id"], name=op.f("fk_vote_bill_id_bill")),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("pk_vote")),
|
||||
sa.UniqueConstraint("congress", "chamber", "session", "number", name="uq_vote_congress_chamber_session_number"),
|
||||
schema=schema,
|
||||
)
|
||||
op.create_index("ix_vote_congress_chamber", "vote", ["congress", "chamber"], unique=False, schema=schema)
|
||||
op.create_index("ix_vote_date", "vote", ["vote_date"], unique=False, schema=schema)
|
||||
op.create_table(
|
||||
"vote_record",
|
||||
sa.Column("vote_id", sa.Integer(), nullable=False),
|
||||
sa.Column("legislator_id", sa.Integer(), nullable=False),
|
||||
sa.Column("position", sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["legislator_id"],
|
||||
[f"{schema}.legislator.id"],
|
||||
name=op.f("fk_vote_record_legislator_id_legislator"),
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["vote_id"], [f"{schema}.vote.id"], name=op.f("fk_vote_record_vote_id_vote"), ondelete="CASCADE"
|
||||
),
|
||||
sa.PrimaryKeyConstraint("vote_id", "legislator_id", name=op.f("pk_vote_record")),
|
||||
schema=schema,
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("vote_record", schema=schema)
|
||||
op.drop_index("ix_vote_date", table_name="vote", schema=schema)
|
||||
op.drop_index("ix_vote_congress_chamber", table_name="vote", schema=schema)
|
||||
op.drop_table("vote", schema=schema)
|
||||
op.drop_index(op.f("ix_legislator_bioguide_id"), table_name="legislator", schema=schema)
|
||||
op.drop_table("legislator", schema=schema)
|
||||
op.drop_index("ix_bill_congress", table_name="bill", schema=schema)
|
||||
op.drop_table("bill", schema=schema)
|
||||
# ### end Alembic commands ###
|
||||
@@ -13,7 +13,7 @@ from typing import TYPE_CHECKING
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
from python.orm import RichieBase
|
||||
from python.orm import ${config.attributes["base"].__name__}
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
@@ -24,7 +24,7 @@ down_revision: str | None = ${repr(down_revision)}
|
||||
branch_labels: str | Sequence[str] | None = ${repr(branch_labels)}
|
||||
depends_on: str | Sequence[str] | None = ${repr(depends_on)}
|
||||
|
||||
schema=RichieBase.schema_name
|
||||
schema=${config.attributes["base"].__name__}.schema_name
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade."""
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
"""starting van invintory.
|
||||
|
||||
Revision ID: 15e733499804
|
||||
Revises:
|
||||
Create Date: 2026-03-08 00:18:20.759720
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
from python.orm import VanInventoryBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "15e733499804"
|
||||
down_revision: str | None = None
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
schema = VanInventoryBase.schema_name
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"items",
|
||||
sa.Column("name", sa.String(), nullable=False),
|
||||
sa.Column("quantity", sa.Float(), nullable=False),
|
||||
sa.Column("unit", sa.String(), nullable=False),
|
||||
sa.Column("category", sa.String(), nullable=True),
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("pk_items")),
|
||||
sa.UniqueConstraint("name", name=op.f("uq_items_name")),
|
||||
schema=schema,
|
||||
)
|
||||
op.create_table(
|
||||
"meals",
|
||||
sa.Column("name", sa.String(), nullable=False),
|
||||
sa.Column("instructions", sa.String(), nullable=True),
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("pk_meals")),
|
||||
sa.UniqueConstraint("name", name=op.f("uq_meals_name")),
|
||||
schema=schema,
|
||||
)
|
||||
op.create_table(
|
||||
"meal_ingredients",
|
||||
sa.Column("meal_id", sa.Integer(), nullable=False),
|
||||
sa.Column("item_id", sa.Integer(), nullable=False),
|
||||
sa.Column("quantity_needed", sa.Float(), nullable=False),
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.ForeignKeyConstraint(["item_id"], [f"{schema}.items.id"], name=op.f("fk_meal_ingredients_item_id_items")),
|
||||
sa.ForeignKeyConstraint(["meal_id"], [f"{schema}.meals.id"], name=op.f("fk_meal_ingredients_meal_id_meals")),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("pk_meal_ingredients")),
|
||||
sa.UniqueConstraint("meal_id", "item_id", name=op.f("uq_meal_ingredients_meal_id")),
|
||||
schema=schema,
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("meal_ingredients", schema=schema)
|
||||
op.drop_table("meals", schema=schema)
|
||||
op.drop_table("items", schema=schema)
|
||||
# ### end Alembic commands ###
|
||||
@@ -16,7 +16,7 @@ from fastapi import FastAPI
|
||||
|
||||
from python.api.routers import contact_router, create_frontend_router
|
||||
from python.common import configure_logger
|
||||
from python.orm.base import get_postgres_engine
|
||||
from python.orm.common import get_postgres_engine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from python.api.dependencies import DbSession
|
||||
from python.orm.contact import Contact, ContactRelationship, Need, RelationshipType
|
||||
from python.orm.richie.contact import Contact, ContactRelationship, Need, RelationshipType
|
||||
|
||||
|
||||
class NeedBase(BaseModel):
|
||||
|
||||
114
python/database_cli.py
Normal file
114
python/database_cli.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""CLI wrapper around alembic for multi-database support.
|
||||
|
||||
Usage:
|
||||
database <db_name> <command> [args...]
|
||||
|
||||
Examples:
|
||||
database van_inventory upgrade head
|
||||
database van_inventory downgrade head-1
|
||||
database van_inventory revision --autogenerate -m "add meals table"
|
||||
database van_inventory check
|
||||
database richie check
|
||||
database richie upgrade head
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from importlib import import_module
|
||||
from typing import TYPE_CHECKING, Annotated
|
||||
|
||||
import typer
|
||||
from alembic.config import CommandLine, Config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DatabaseConfig:
|
||||
"""Configuration for a database."""
|
||||
|
||||
env_prefix: str
|
||||
version_location: str
|
||||
base_module: str
|
||||
base_class_name: str
|
||||
models_module: str
|
||||
script_location: str = "python/alembic"
|
||||
file_template: str = "%%(year)d_%%(month).2d_%%(day).2d-%%(slug)s_%%(rev)s"
|
||||
|
||||
def get_base(self) -> type[DeclarativeBase]:
|
||||
"""Import and return the Base class."""
|
||||
module = import_module(self.base_module)
|
||||
return getattr(module, self.base_class_name)
|
||||
|
||||
def import_models(self) -> None:
|
||||
"""Import ORM models so alembic autogenerate can detect them."""
|
||||
import_module(self.models_module)
|
||||
|
||||
def alembic_config(self) -> Config:
|
||||
"""Build an alembic Config for this database."""
|
||||
# Runtime import needed — Config is in TYPE_CHECKING for the return type annotation
|
||||
from alembic.config import Config as AlembicConfig # noqa: PLC0415
|
||||
|
||||
cfg = AlembicConfig()
|
||||
cfg.set_main_option("script_location", self.script_location)
|
||||
cfg.set_main_option("file_template", self.file_template)
|
||||
cfg.set_main_option("prepend_sys_path", ".")
|
||||
cfg.set_main_option("version_path_separator", "os")
|
||||
cfg.set_main_option("version_locations", self.version_location)
|
||||
cfg.set_main_option("revision_environment", "true")
|
||||
cfg.set_section_option("post_write_hooks", "hooks", "dynamic_schema,ruff")
|
||||
cfg.set_section_option("post_write_hooks", "dynamic_schema.type", "dynamic_schema")
|
||||
cfg.set_section_option("post_write_hooks", "ruff.type", "ruff")
|
||||
cfg.attributes["base"] = self.get_base()
|
||||
cfg.attributes["env_prefix"] = self.env_prefix
|
||||
self.import_models()
|
||||
return cfg
|
||||
|
||||
|
||||
DATABASES: dict[str, DatabaseConfig] = {
|
||||
"richie": DatabaseConfig(
|
||||
env_prefix="RICHIE",
|
||||
version_location="python/alembic/richie/versions",
|
||||
base_module="python.orm.richie.base",
|
||||
base_class_name="RichieBase",
|
||||
models_module="python.orm.richie.contact",
|
||||
),
|
||||
"van_inventory": DatabaseConfig(
|
||||
env_prefix="VAN_INVENTORY",
|
||||
version_location="python/alembic/van_inventory/versions",
|
||||
base_module="python.orm.van_inventory.base",
|
||||
base_class_name="VanInventoryBase",
|
||||
models_module="python.orm.van_inventory.models",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
app = typer.Typer(help="Multi-database alembic wrapper.")
|
||||
|
||||
|
||||
@app.command(
|
||||
context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
|
||||
)
|
||||
def main(
|
||||
ctx: typer.Context,
|
||||
db_name: Annotated[str, typer.Argument(help=f"Database name. Options: {', '.join(DATABASES)}")],
|
||||
command: Annotated[str, typer.Argument(help="Alembic command (upgrade, downgrade, revision, check, etc.)")],
|
||||
) -> None:
|
||||
"""Run an alembic command against the specified database."""
|
||||
db_config = DATABASES.get(db_name)
|
||||
if not db_config:
|
||||
typer.echo(f"Unknown database: {db_name!r}. Available: {', '.join(DATABASES)}", err=True)
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
alembic_cfg = db_config.alembic_config()
|
||||
|
||||
cmd_line = CommandLine()
|
||||
options = cmd_line.parser.parse_args([command, *ctx.args])
|
||||
cmd_line.run_cmd(alembic_cfg, options)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
|
||||
@@ -1,22 +1,9 @@
|
||||
"""ORM package exports."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from python.orm.base import RichieBase, TableBase
|
||||
from python.orm.contact import (
|
||||
Contact,
|
||||
ContactNeed,
|
||||
ContactRelationship,
|
||||
Need,
|
||||
RelationshipType,
|
||||
)
|
||||
from python.orm.richie.base import RichieBase
|
||||
from python.orm.van_inventory.base import VanInventoryBase
|
||||
|
||||
__all__ = [
|
||||
"Contact",
|
||||
"ContactNeed",
|
||||
"ContactRelationship",
|
||||
"Need",
|
||||
"RelationshipType",
|
||||
"RichieBase",
|
||||
"TableBase",
|
||||
"VanInventoryBase",
|
||||
]
|
||||
|
||||
@@ -1,80 +0,0 @@
|
||||
"""Base ORM definitions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from os import getenv
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import DateTime, MetaData, create_engine, func
|
||||
from sqlalchemy.engine import URL, Engine
|
||||
from sqlalchemy.ext.declarative import AbstractConcreteBase
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
|
||||
class RichieBase(DeclarativeBase):
|
||||
"""Base class for all ORM models."""
|
||||
|
||||
schema_name = "main"
|
||||
|
||||
metadata = MetaData(
|
||||
schema=schema_name,
|
||||
naming_convention={
|
||||
"ix": "ix_%(table_name)s_%(column_0_name)s",
|
||||
"uq": "uq_%(table_name)s_%(column_0_name)s",
|
||||
"ck": "ck_%(table_name)s_%(constraint_name)s",
|
||||
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
|
||||
"pk": "pk_%(table_name)s",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class TableBase(AbstractConcreteBase, RichieBase):
|
||||
"""Abstract concrete base for tables with IDs and timestamps."""
|
||||
|
||||
__abstract__ = True
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
created: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
)
|
||||
updated: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
|
||||
|
||||
def get_connection_info() -> tuple[str, str, str, str, str | None]:
|
||||
"""Get connection info from environment variables."""
|
||||
database = getenv("POSTGRES_DB")
|
||||
host = getenv("POSTGRES_HOST")
|
||||
port = getenv("POSTGRES_PORT")
|
||||
username = getenv("POSTGRES_USER")
|
||||
password = getenv("POSTGRES_PASSWORD")
|
||||
|
||||
if None in (database, host, port, username):
|
||||
error = f"Missing environment variables for Postgres connection.\n{database=}\n{host=}\n{port=}\n{username=}\n"
|
||||
raise ValueError(error)
|
||||
return cast("tuple[str, str, str, str, str | None]", (database, host, port, username, password))
|
||||
|
||||
|
||||
def get_postgres_engine(*, pool_pre_ping: bool = True) -> Engine:
|
||||
"""Create a SQLAlchemy engine from environment variables."""
|
||||
database, host, port, username, password = get_connection_info()
|
||||
|
||||
url = URL.create(
|
||||
drivername="postgresql+psycopg",
|
||||
username=username,
|
||||
password=password,
|
||||
host=host,
|
||||
port=int(port),
|
||||
database=database,
|
||||
)
|
||||
|
||||
return create_engine(
|
||||
url=url,
|
||||
pool_pre_ping=pool_pre_ping,
|
||||
pool_recycle=1800,
|
||||
)
|
||||
51
python/orm/common.py
Normal file
51
python/orm/common.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""Shared ORM definitions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from os import getenv
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.engine import URL, Engine
|
||||
|
||||
NAMING_CONVENTION = {
|
||||
"ix": "ix_%(table_name)s_%(column_0_name)s",
|
||||
"uq": "uq_%(table_name)s_%(column_0_name)s",
|
||||
"ck": "ck_%(table_name)s_%(constraint_name)s",
|
||||
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
|
||||
"pk": "pk_%(table_name)s",
|
||||
}
|
||||
|
||||
|
||||
def get_connection_info(name: str) -> tuple[str, str, str, str, str | None]:
|
||||
"""Get connection info from environment variables."""
|
||||
database = getenv(f"{name}_DB")
|
||||
host = getenv(f"{name}_HOST")
|
||||
port = getenv(f"{name}_PORT")
|
||||
username = getenv(f"{name}_USER")
|
||||
password = getenv(f"{name}_PASSWORD")
|
||||
|
||||
if None in (database, host, port, username):
|
||||
error = f"Missing environment variables for Postgres connection.\n{database=}\n{host=}\n{port=}\n{username=}\n"
|
||||
raise ValueError(error)
|
||||
return cast("tuple[str, str, str, str, str | None]", (database, host, port, username, password))
|
||||
|
||||
|
||||
def get_postgres_engine(*, name: str = "POSTGRES", pool_pre_ping: bool = True) -> Engine:
|
||||
"""Create a SQLAlchemy engine from environment variables."""
|
||||
database, host, port, username, password = get_connection_info(name)
|
||||
|
||||
url = URL.create(
|
||||
drivername="postgresql+psycopg",
|
||||
username=username,
|
||||
password=password,
|
||||
host=host,
|
||||
port=int(port),
|
||||
database=database,
|
||||
)
|
||||
|
||||
return create_engine(
|
||||
url=url,
|
||||
pool_pre_ping=pool_pre_ping,
|
||||
pool_recycle=1800,
|
||||
)
|
||||
27
python/orm/richie/__init__.py
Normal file
27
python/orm/richie/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""Richie database ORM exports."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from python.orm.richie.base import RichieBase, TableBase
|
||||
from python.orm.richie.congress import Bill, Legislator, Vote, VoteRecord
|
||||
from python.orm.richie.contact import (
|
||||
Contact,
|
||||
ContactNeed,
|
||||
ContactRelationship,
|
||||
Need,
|
||||
RelationshipType,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Bill",
|
||||
"Contact",
|
||||
"ContactNeed",
|
||||
"ContactRelationship",
|
||||
"Legislator",
|
||||
"Need",
|
||||
"RelationshipType",
|
||||
"RichieBase",
|
||||
"TableBase",
|
||||
"Vote",
|
||||
"VoteRecord",
|
||||
]
|
||||
39
python/orm/richie/base.py
Normal file
39
python/orm/richie/base.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""Richie database ORM base."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, MetaData, func
|
||||
from sqlalchemy.ext.declarative import AbstractConcreteBase
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
from python.orm.common import NAMING_CONVENTION
|
||||
|
||||
|
||||
class RichieBase(DeclarativeBase):
|
||||
"""Base class for richie database ORM models."""
|
||||
|
||||
schema_name = "main"
|
||||
|
||||
metadata = MetaData(
|
||||
schema=schema_name,
|
||||
naming_convention=NAMING_CONVENTION,
|
||||
)
|
||||
|
||||
|
||||
class TableBase(AbstractConcreteBase, RichieBase):
|
||||
"""Abstract concrete base for richie tables with IDs and timestamps."""
|
||||
|
||||
__abstract__ = True
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
created: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
)
|
||||
updated: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
150
python/orm/richie/congress.py
Normal file
150
python/orm/richie/congress.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""Congress Tracker database models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
|
||||
from sqlalchemy import ForeignKey, Index, Text, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from python.orm.richie.base import RichieBase, TableBase
|
||||
|
||||
|
||||
class Legislator(TableBase):
|
||||
"""Legislator model - members of Congress."""
|
||||
|
||||
__tablename__ = "legislator"
|
||||
|
||||
# Natural key - bioguide ID is the authoritative identifier
|
||||
bioguide_id: Mapped[str] = mapped_column(Text, unique=True, index=True)
|
||||
|
||||
# Other IDs for cross-referencing
|
||||
thomas_id: Mapped[str | None]
|
||||
lis_id: Mapped[str | None]
|
||||
govtrack_id: Mapped[int | None]
|
||||
opensecrets_id: Mapped[str | None]
|
||||
fec_ids: Mapped[str | None] # JSON array stored as string
|
||||
|
||||
# Name info
|
||||
first_name: Mapped[str]
|
||||
last_name: Mapped[str]
|
||||
official_full_name: Mapped[str | None]
|
||||
nickname: Mapped[str | None]
|
||||
|
||||
# Bio
|
||||
birthday: Mapped[date | None]
|
||||
gender: Mapped[str | None] # M/F
|
||||
|
||||
# Current term info (denormalized for query efficiency)
|
||||
current_party: Mapped[str | None]
|
||||
current_state: Mapped[str | None]
|
||||
current_district: Mapped[int | None] # House only
|
||||
current_chamber: Mapped[str | None] # rep/sen
|
||||
|
||||
# Relationships
|
||||
vote_records: Mapped[list[VoteRecord]] = relationship(
|
||||
"VoteRecord",
|
||||
back_populates="legislator",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
|
||||
class Bill(TableBase):
|
||||
"""Bill model - legislation introduced in Congress."""
|
||||
|
||||
__tablename__ = "bill"
|
||||
|
||||
# Composite natural key: congress + bill_type + number
|
||||
congress: Mapped[int]
|
||||
bill_type: Mapped[str] # hr, s, hres, sres, hjres, sjres
|
||||
number: Mapped[int]
|
||||
|
||||
# Bill info
|
||||
title: Mapped[str | None]
|
||||
title_short: Mapped[str | None]
|
||||
official_title: Mapped[str | None]
|
||||
|
||||
# Status
|
||||
status: Mapped[str | None]
|
||||
status_at: Mapped[date | None]
|
||||
|
||||
# Sponsor
|
||||
sponsor_bioguide_id: Mapped[str | None]
|
||||
|
||||
# Subjects
|
||||
subjects_top_term: Mapped[str | None]
|
||||
|
||||
# Relationships
|
||||
votes: Mapped[list[Vote]] = relationship(
|
||||
"Vote",
|
||||
back_populates="bill",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("congress", "bill_type", "number", name="uq_bill_congress_type_number"),
|
||||
Index("ix_bill_congress", "congress"),
|
||||
)
|
||||
|
||||
|
||||
class Vote(TableBase):
|
||||
"""Vote model - roll call votes in Congress."""
|
||||
|
||||
__tablename__ = "vote"
|
||||
|
||||
# Composite natural key: congress + chamber + session + number
|
||||
congress: Mapped[int]
|
||||
chamber: Mapped[str] # house/senate
|
||||
session: Mapped[int]
|
||||
number: Mapped[int]
|
||||
|
||||
# Vote details
|
||||
vote_type: Mapped[str | None]
|
||||
question: Mapped[str | None]
|
||||
result: Mapped[str | None]
|
||||
result_text: Mapped[str | None]
|
||||
|
||||
# Timing
|
||||
vote_date: Mapped[date]
|
||||
|
||||
# Vote counts (denormalized for efficiency)
|
||||
yea_count: Mapped[int | None]
|
||||
nay_count: Mapped[int | None]
|
||||
not_voting_count: Mapped[int | None]
|
||||
present_count: Mapped[int | None]
|
||||
|
||||
# Related bill (optional - not all votes are on bills)
|
||||
bill_id: Mapped[int | None] = mapped_column(ForeignKey("main.bill.id"))
|
||||
|
||||
# Relationships
|
||||
bill: Mapped[Bill | None] = relationship("Bill", back_populates="votes")
|
||||
vote_records: Mapped[list[VoteRecord]] = relationship(
|
||||
"VoteRecord",
|
||||
back_populates="vote",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("congress", "chamber", "session", "number", name="uq_vote_congress_chamber_session_number"),
|
||||
Index("ix_vote_date", "vote_date"),
|
||||
Index("ix_vote_congress_chamber", "congress", "chamber"),
|
||||
)
|
||||
|
||||
|
||||
class VoteRecord(RichieBase):
|
||||
"""Association table: Vote <-> Legislator with position."""
|
||||
|
||||
__tablename__ = "vote_record"
|
||||
|
||||
vote_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("main.vote.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
)
|
||||
legislator_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("main.legislator.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
)
|
||||
position: Mapped[str] # Yea, Nay, Not Voting, Present
|
||||
|
||||
# Relationships
|
||||
vote: Mapped[Vote] = relationship("Vote", back_populates="vote_records")
|
||||
legislator: Mapped[Legislator] = relationship("Legislator", back_populates="vote_records")
|
||||
@@ -7,7 +7,7 @@ from enum import Enum
|
||||
from sqlalchemy import ForeignKey, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from python.orm.base import RichieBase, TableBase
|
||||
from python.orm.richie.base import RichieBase, TableBase
|
||||
|
||||
|
||||
class RelationshipType(str, Enum):
|
||||
1
python/orm/van_inventory/__init__.py
Normal file
1
python/orm/van_inventory/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Van inventory database ORM exports."""
|
||||
39
python/orm/van_inventory/base.py
Normal file
39
python/orm/van_inventory/base.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""Van inventory database ORM base."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, MetaData, func
|
||||
from sqlalchemy.ext.declarative import AbstractConcreteBase
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
from python.orm.common import NAMING_CONVENTION
|
||||
|
||||
|
||||
class VanInventoryBase(DeclarativeBase):
|
||||
"""Base class for van_inventory database ORM models."""
|
||||
|
||||
schema_name = "main"
|
||||
|
||||
metadata = MetaData(
|
||||
schema=schema_name,
|
||||
naming_convention=NAMING_CONVENTION,
|
||||
)
|
||||
|
||||
|
||||
class VanTableBase(AbstractConcreteBase, VanInventoryBase):
|
||||
"""Abstract concrete base for van_inventory tables with IDs and timestamps."""
|
||||
|
||||
__abstract__ = True
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
created: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
)
|
||||
updated: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
46
python/orm/van_inventory/models.py
Normal file
46
python/orm/van_inventory/models.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Van inventory ORM models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import ForeignKey, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from python.orm.van_inventory.base import VanTableBase
|
||||
|
||||
|
||||
class Item(VanTableBase):
|
||||
"""A food item in the van."""
|
||||
|
||||
__tablename__ = "items"
|
||||
|
||||
name: Mapped[str] = mapped_column(unique=True)
|
||||
quantity: Mapped[float] = mapped_column(default=0)
|
||||
unit: Mapped[str]
|
||||
category: Mapped[str | None]
|
||||
|
||||
meal_ingredients: Mapped[list[MealIngredient]] = relationship(back_populates="item")
|
||||
|
||||
|
||||
class Meal(VanTableBase):
|
||||
"""A meal that can be made from items in the van."""
|
||||
|
||||
__tablename__ = "meals"
|
||||
|
||||
name: Mapped[str] = mapped_column(unique=True)
|
||||
instructions: Mapped[str | None]
|
||||
|
||||
ingredients: Mapped[list[MealIngredient]] = relationship(back_populates="meal")
|
||||
|
||||
|
||||
class MealIngredient(VanTableBase):
|
||||
"""Links a meal to the items it requires, with quantities."""
|
||||
|
||||
__tablename__ = "meal_ingredients"
|
||||
__table_args__ = (UniqueConstraint("meal_id", "item_id"),)
|
||||
|
||||
meal_id: Mapped[int] = mapped_column(ForeignKey("meals.id"))
|
||||
item_id: Mapped[int] = mapped_column(ForeignKey("items.id"))
|
||||
quantity_needed: Mapped[float]
|
||||
|
||||
meal: Mapped[Meal] = relationship(back_populates="ingredients")
|
||||
item: Mapped[Item] = relationship(back_populates="meal_ingredients")
|
||||
1
python/van_inventory/__init__.py
Normal file
1
python/van_inventory/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Van inventory FastAPI application."""
|
||||
16
python/van_inventory/dependencies.py
Normal file
16
python/van_inventory/dependencies.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""FastAPI dependencies for van inventory."""
|
||||
|
||||
from collections.abc import Iterator
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
def get_db(request: Request) -> Iterator[Session]:
|
||||
"""Get database session from app state."""
|
||||
with Session(request.app.state.engine) as session:
|
||||
yield session
|
||||
|
||||
|
||||
DbSession = Annotated[Session, Depends(get_db)]
|
||||
56
python/van_inventory/main.py
Normal file
56
python/van_inventory/main.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""FastAPI app for van inventory."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Annotated
|
||||
|
||||
import typer
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from python.common import configure_logger
|
||||
from python.orm.common import get_postgres_engine
|
||||
from python.van_inventory.routers import api_router, frontend_router
|
||||
|
||||
STATIC_DIR = Path(__file__).resolve().parent / "static"
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
"""Create and configure the FastAPI application."""
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
|
||||
app.state.engine = get_postgres_engine(name="VAN_INVENTORY")
|
||||
yield
|
||||
app.state.engine.dispose()
|
||||
|
||||
app = FastAPI(title="Van Inventory", lifespan=lifespan)
|
||||
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
|
||||
app.include_router(api_router)
|
||||
app.include_router(frontend_router)
|
||||
return app
|
||||
|
||||
|
||||
def serve(
|
||||
# Intentionally binds all interfaces — this is a LAN-only van server
|
||||
host: Annotated[str, typer.Option("--host", "-h", help="Host to bind to")] = "0.0.0.0", # noqa: S104
|
||||
port: Annotated[int, typer.Option("--port", "-p", help="Port to bind to")] = 8001,
|
||||
log_level: Annotated[str, typer.Option("--log-level", "-l", help="Log level")] = "INFO",
|
||||
) -> None:
|
||||
"""Start the Van Inventory server."""
|
||||
configure_logger(log_level)
|
||||
app = create_app()
|
||||
uvicorn.run(app, host=host, port=port)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
typer.run(serve)
|
||||
6
python/van_inventory/routers/__init__.py
Normal file
6
python/van_inventory/routers/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Van inventory API routers."""
|
||||
|
||||
from python.van_inventory.routers.api import router as api_router
|
||||
from python.van_inventory.routers.frontend import router as frontend_router
|
||||
|
||||
__all__ = ["api_router", "frontend_router"]
|
||||
314
python/van_inventory/routers/api.py
Normal file
314
python/van_inventory/routers/api.py
Normal file
@@ -0,0 +1,314 @@
|
||||
"""Van inventory API router."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from python.orm.van_inventory.models import Item, Meal, MealIngredient
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from python.van_inventory.dependencies import DbSession
|
||||
|
||||
|
||||
# --- Schemas ---
|
||||
|
||||
|
||||
class ItemCreate(BaseModel):
|
||||
"""Schema for creating an item."""
|
||||
|
||||
name: str
|
||||
quantity: float = Field(default=0, ge=0)
|
||||
unit: str
|
||||
category: str | None = None
|
||||
|
||||
|
||||
class ItemUpdate(BaseModel):
|
||||
"""Schema for updating an item."""
|
||||
|
||||
name: str | None = None
|
||||
quantity: float | None = Field(default=None, ge=0)
|
||||
unit: str | None = None
|
||||
category: str | None = None
|
||||
|
||||
|
||||
class ItemResponse(BaseModel):
|
||||
"""Schema for item response."""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
quantity: float
|
||||
unit: str
|
||||
category: str | None
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class IngredientCreate(BaseModel):
|
||||
"""Schema for adding an ingredient to a meal."""
|
||||
|
||||
item_id: int
|
||||
quantity_needed: float = Field(gt=0)
|
||||
|
||||
|
||||
class MealCreate(BaseModel):
|
||||
"""Schema for creating a meal."""
|
||||
|
||||
name: str
|
||||
instructions: str | None = None
|
||||
ingredients: list[IngredientCreate] = []
|
||||
|
||||
|
||||
class MealUpdate(BaseModel):
|
||||
"""Schema for updating a meal."""
|
||||
|
||||
name: str | None = None
|
||||
instructions: str | None = None
|
||||
|
||||
|
||||
class IngredientResponse(BaseModel):
|
||||
"""Schema for ingredient response."""
|
||||
|
||||
item_id: int
|
||||
item_name: str
|
||||
quantity_needed: float
|
||||
unit: str
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class MealResponse(BaseModel):
|
||||
"""Schema for meal response."""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
instructions: str | None
|
||||
ingredients: list[IngredientResponse] = []
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
@classmethod
|
||||
def from_meal(cls, meal: Meal) -> MealResponse:
|
||||
"""Build a MealResponse from an ORM Meal with loaded ingredients."""
|
||||
return cls(
|
||||
id=meal.id,
|
||||
name=meal.name,
|
||||
instructions=meal.instructions,
|
||||
ingredients=[
|
||||
IngredientResponse(
|
||||
item_id=mi.item_id,
|
||||
item_name=mi.item.name,
|
||||
quantity_needed=mi.quantity_needed,
|
||||
unit=mi.item.unit,
|
||||
)
|
||||
for mi in meal.ingredients
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class ShoppingItem(BaseModel):
|
||||
"""An item needed for a meal that is short on stock."""
|
||||
|
||||
item_name: str
|
||||
unit: str
|
||||
needed: float
|
||||
have: float
|
||||
short: float
|
||||
|
||||
|
||||
class MealAvailability(BaseModel):
|
||||
"""Availability status for a meal."""
|
||||
|
||||
meal_id: int
|
||||
meal_name: str
|
||||
can_make: bool
|
||||
missing: list[ShoppingItem] = []
|
||||
|
||||
|
||||
# --- Routes ---
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["van_inventory"])
|
||||
|
||||
|
||||
# Items
|
||||
|
||||
|
||||
@router.post("/items", response_model=ItemResponse)
|
||||
def create_item(item: ItemCreate, db: DbSession) -> Item:
|
||||
"""Create a new inventory item."""
|
||||
db_item = Item(**item.model_dump())
|
||||
db.add(db_item)
|
||||
db.commit()
|
||||
db.refresh(db_item)
|
||||
return db_item
|
||||
|
||||
|
||||
@router.get("/items", response_model=list[ItemResponse])
|
||||
def list_items(db: DbSession) -> list[Item]:
|
||||
"""List all inventory items."""
|
||||
return list(db.scalars(select(Item).order_by(Item.name)).all())
|
||||
|
||||
|
||||
@router.get("/items/{item_id}", response_model=ItemResponse)
|
||||
def get_item(item_id: int, db: DbSession) -> Item:
|
||||
"""Get an item by ID."""
|
||||
item = db.get(Item, item_id)
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Item not found")
|
||||
return item
|
||||
|
||||
|
||||
@router.patch("/items/{item_id}", response_model=ItemResponse)
|
||||
def update_item(item_id: int, item: ItemUpdate, db: DbSession) -> Item:
|
||||
"""Update an item by ID."""
|
||||
db_item = db.get(Item, item_id)
|
||||
if not db_item:
|
||||
raise HTTPException(status_code=404, detail="Item not found")
|
||||
for key, value in item.model_dump(exclude_unset=True).items():
|
||||
setattr(db_item, key, value)
|
||||
db.commit()
|
||||
db.refresh(db_item)
|
||||
return db_item
|
||||
|
||||
|
||||
@router.delete("/items/{item_id}")
|
||||
def delete_item(item_id: int, db: DbSession) -> dict[str, bool]:
|
||||
"""Delete an item by ID."""
|
||||
item = db.get(Item, item_id)
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Item not found")
|
||||
db.delete(item)
|
||||
db.commit()
|
||||
return {"deleted": True}
|
||||
|
||||
|
||||
# Meals
|
||||
|
||||
|
||||
@router.post("/meals", response_model=MealResponse)
|
||||
def create_meal(meal: MealCreate, db: DbSession) -> MealResponse:
|
||||
"""Create a new meal with optional ingredients."""
|
||||
for ing in meal.ingredients:
|
||||
if not db.get(Item, ing.item_id):
|
||||
raise HTTPException(status_code=422, detail=f"Item {ing.item_id} not found")
|
||||
db_meal = Meal(name=meal.name, instructions=meal.instructions)
|
||||
db.add(db_meal)
|
||||
db.flush()
|
||||
for ing in meal.ingredients:
|
||||
db.add(MealIngredient(meal_id=db_meal.id, item_id=ing.item_id, quantity_needed=ing.quantity_needed))
|
||||
db.commit()
|
||||
db_meal = db.scalar(
|
||||
select(Meal)
|
||||
.where(Meal.id == db_meal.id)
|
||||
.options(selectinload(Meal.ingredients).selectinload(MealIngredient.item))
|
||||
)
|
||||
return MealResponse.from_meal(db_meal)
|
||||
|
||||
|
||||
@router.get("/meals", response_model=list[MealResponse])
|
||||
def list_meals(db: DbSession) -> list[MealResponse]:
|
||||
"""List all meals with ingredients."""
|
||||
meals = list(
|
||||
db.scalars(
|
||||
select(Meal).options(selectinload(Meal.ingredients).selectinload(MealIngredient.item)).order_by(Meal.name)
|
||||
).all()
|
||||
)
|
||||
return [MealResponse.from_meal(m) for m in meals]
|
||||
|
||||
|
||||
@router.get("/meals/availability", response_model=list[MealAvailability])
|
||||
def check_all_meals(db: DbSession) -> list[MealAvailability]:
|
||||
"""Check which meals can be made with current inventory."""
|
||||
meals = list(
|
||||
db.scalars(select(Meal).options(selectinload(Meal.ingredients).selectinload(MealIngredient.item))).all()
|
||||
)
|
||||
return [_check_meal(m) for m in meals]
|
||||
|
||||
|
||||
@router.get("/meals/{meal_id}", response_model=MealResponse)
|
||||
def get_meal(meal_id: int, db: DbSession) -> MealResponse:
|
||||
"""Get a meal by ID with ingredients."""
|
||||
meal = db.scalar(
|
||||
select(Meal).where(Meal.id == meal_id).options(selectinload(Meal.ingredients).selectinload(MealIngredient.item))
|
||||
)
|
||||
if not meal:
|
||||
raise HTTPException(status_code=404, detail="Meal not found")
|
||||
return MealResponse.from_meal(meal)
|
||||
|
||||
|
||||
@router.delete("/meals/{meal_id}")
|
||||
def delete_meal(meal_id: int, db: DbSession) -> dict[str, bool]:
|
||||
"""Delete a meal by ID."""
|
||||
meal = db.get(Meal, meal_id)
|
||||
if not meal:
|
||||
raise HTTPException(status_code=404, detail="Meal not found")
|
||||
db.delete(meal)
|
||||
db.commit()
|
||||
return {"deleted": True}
|
||||
|
||||
|
||||
@router.post("/meals/{meal_id}/ingredients", response_model=MealResponse)
|
||||
def add_ingredient(meal_id: int, ingredient: IngredientCreate, db: DbSession) -> MealResponse:
|
||||
"""Add an ingredient to a meal."""
|
||||
meal = db.get(Meal, meal_id)
|
||||
if not meal:
|
||||
raise HTTPException(status_code=404, detail="Meal not found")
|
||||
if not db.get(Item, ingredient.item_id):
|
||||
raise HTTPException(status_code=422, detail="Item not found")
|
||||
existing = db.scalar(
|
||||
select(MealIngredient).where(MealIngredient.meal_id == meal_id, MealIngredient.item_id == ingredient.item_id)
|
||||
)
|
||||
if existing:
|
||||
raise HTTPException(status_code=409, detail="Ingredient already exists for this meal")
|
||||
db.add(MealIngredient(meal_id=meal_id, item_id=ingredient.item_id, quantity_needed=ingredient.quantity_needed))
|
||||
db.commit()
|
||||
meal = db.scalar(
|
||||
select(Meal).where(Meal.id == meal_id).options(selectinload(Meal.ingredients).selectinload(MealIngredient.item))
|
||||
)
|
||||
return MealResponse.from_meal(meal)
|
||||
|
||||
|
||||
@router.delete("/meals/{meal_id}/ingredients/{item_id}")
|
||||
def remove_ingredient(meal_id: int, item_id: int, db: DbSession) -> dict[str, bool]:
|
||||
"""Remove an ingredient from a meal."""
|
||||
mi = db.scalar(select(MealIngredient).where(MealIngredient.meal_id == meal_id, MealIngredient.item_id == item_id))
|
||||
if not mi:
|
||||
raise HTTPException(status_code=404, detail="Ingredient not found")
|
||||
db.delete(mi)
|
||||
db.commit()
|
||||
return {"deleted": True}
|
||||
|
||||
|
||||
@router.get("/meals/{meal_id}/availability", response_model=MealAvailability)
|
||||
def check_meal(meal_id: int, db: DbSession) -> MealAvailability:
|
||||
"""Check if a specific meal can be made and what's missing."""
|
||||
meal = db.scalar(
|
||||
select(Meal).where(Meal.id == meal_id).options(selectinload(Meal.ingredients).selectinload(MealIngredient.item))
|
||||
)
|
||||
if not meal:
|
||||
raise HTTPException(status_code=404, detail="Meal not found")
|
||||
return _check_meal(meal)
|
||||
|
||||
|
||||
def _check_meal(meal: Meal) -> MealAvailability:
|
||||
missing = [
|
||||
ShoppingItem(
|
||||
item_name=mi.item.name,
|
||||
unit=mi.item.unit,
|
||||
needed=mi.quantity_needed,
|
||||
have=mi.item.quantity,
|
||||
short=mi.quantity_needed - mi.item.quantity,
|
||||
)
|
||||
for mi in meal.ingredients
|
||||
if mi.item.quantity < mi.quantity_needed
|
||||
]
|
||||
return MealAvailability(
|
||||
meal_id=meal.id,
|
||||
meal_name=meal.name,
|
||||
can_make=len(missing) == 0,
|
||||
missing=missing,
|
||||
)
|
||||
198
python/van_inventory/routers/frontend.py
Normal file
198
python/van_inventory/routers/frontend.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""HTMX frontend routes for van inventory."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Form, HTTPException, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from python.orm.van_inventory.models import Item, Meal, MealIngredient
|
||||
|
||||
# FastAPI needs DbSession at runtime to resolve the Depends() annotation
|
||||
from python.van_inventory.dependencies import DbSession # noqa: TC001
|
||||
from python.van_inventory.routers.api import _check_meal
|
||||
|
||||
TEMPLATE_DIR = Path(__file__).resolve().parent.parent / "templates"
|
||||
templates = Jinja2Templates(directory=TEMPLATE_DIR)
|
||||
|
||||
router = APIRouter(tags=["frontend"])
|
||||
|
||||
|
||||
# --- Items ---
|
||||
|
||||
|
||||
@router.get("/", response_class=HTMLResponse)
|
||||
def items_page(request: Request, db: DbSession) -> HTMLResponse:
|
||||
"""Render the inventory page."""
|
||||
items = list(db.scalars(select(Item).order_by(Item.name)).all())
|
||||
return templates.TemplateResponse(request, "items.html", {"items": items})
|
||||
|
||||
|
||||
@router.post("/items", response_class=HTMLResponse)
|
||||
def htmx_create_item(
|
||||
request: Request,
|
||||
db: DbSession,
|
||||
name: Annotated[str, Form()],
|
||||
quantity: Annotated[float, Form()] = 0,
|
||||
unit: Annotated[str, Form()] = "",
|
||||
category: Annotated[str | None, Form()] = None,
|
||||
) -> HTMLResponse:
|
||||
"""Create an item and return updated item rows."""
|
||||
if quantity < 0:
|
||||
raise HTTPException(status_code=422, detail="Quantity must not be negative")
|
||||
db.add(Item(name=name, quantity=quantity, unit=unit, category=category or None))
|
||||
db.commit()
|
||||
items = list(db.scalars(select(Item).order_by(Item.name)).all())
|
||||
return templates.TemplateResponse(request, "partials/item_rows.html", {"items": items})
|
||||
|
||||
|
||||
@router.patch("/items/{item_id}", response_class=HTMLResponse)
|
||||
def htmx_update_item(
|
||||
request: Request,
|
||||
item_id: int,
|
||||
db: DbSession,
|
||||
quantity: Annotated[float, Form()],
|
||||
) -> HTMLResponse:
|
||||
"""Update an item's quantity and return updated item rows."""
|
||||
if quantity < 0:
|
||||
raise HTTPException(status_code=422, detail="Quantity must not be negative")
|
||||
item = db.get(Item, item_id)
|
||||
if item:
|
||||
item.quantity = quantity
|
||||
db.commit()
|
||||
items = list(db.scalars(select(Item).order_by(Item.name)).all())
|
||||
return templates.TemplateResponse(request, "partials/item_rows.html", {"items": items})
|
||||
|
||||
|
||||
@router.delete("/items/{item_id}", response_class=HTMLResponse)
|
||||
def htmx_delete_item(request: Request, item_id: int, db: DbSession) -> HTMLResponse:
|
||||
"""Delete an item and return updated item rows."""
|
||||
item = db.get(Item, item_id)
|
||||
if item:
|
||||
db.delete(item)
|
||||
db.commit()
|
||||
items = list(db.scalars(select(Item).order_by(Item.name)).all())
|
||||
return templates.TemplateResponse(request, "partials/item_rows.html", {"items": items})
|
||||
|
||||
|
||||
# --- Meals ---
|
||||
|
||||
|
||||
def _load_meals(db: DbSession) -> list[Meal]:
|
||||
return list(
|
||||
db.scalars(
|
||||
select(Meal).options(selectinload(Meal.ingredients).selectinload(MealIngredient.item)).order_by(Meal.name)
|
||||
).all()
|
||||
)
|
||||
|
||||
|
||||
@router.get("/meals", response_class=HTMLResponse)
|
||||
def meals_page(request: Request, db: DbSession) -> HTMLResponse:
|
||||
"""Render the meals page."""
|
||||
meals = _load_meals(db)
|
||||
return templates.TemplateResponse(request, "meals.html", {"meals": meals})
|
||||
|
||||
|
||||
@router.post("/meals", response_class=HTMLResponse)
|
||||
def htmx_create_meal(
|
||||
request: Request,
|
||||
db: DbSession,
|
||||
name: Annotated[str, Form()],
|
||||
instructions: Annotated[str | None, Form()] = None,
|
||||
) -> HTMLResponse:
|
||||
"""Create a meal and return updated meal rows."""
|
||||
db.add(Meal(name=name, instructions=instructions or None))
|
||||
db.commit()
|
||||
meals = _load_meals(db)
|
||||
return templates.TemplateResponse(request, "partials/meal_rows.html", {"meals": meals})
|
||||
|
||||
|
||||
@router.delete("/meals/{meal_id}", response_class=HTMLResponse)
|
||||
def htmx_delete_meal(request: Request, meal_id: int, db: DbSession) -> HTMLResponse:
|
||||
"""Delete a meal and return updated meal rows."""
|
||||
meal = db.get(Meal, meal_id)
|
||||
if meal:
|
||||
db.delete(meal)
|
||||
db.commit()
|
||||
meals = _load_meals(db)
|
||||
return templates.TemplateResponse(request, "partials/meal_rows.html", {"meals": meals})
|
||||
|
||||
|
||||
# --- Meal detail ---
|
||||
|
||||
|
||||
def _load_meal(db: DbSession, meal_id: int) -> Meal | None:
|
||||
return db.scalar(
|
||||
select(Meal).where(Meal.id == meal_id).options(selectinload(Meal.ingredients).selectinload(MealIngredient.item))
|
||||
)
|
||||
|
||||
|
||||
@router.get("/meals/{meal_id}", response_class=HTMLResponse)
|
||||
def meal_detail_page(request: Request, meal_id: int, db: DbSession) -> HTMLResponse:
|
||||
"""Render the meal detail page."""
|
||||
meal = _load_meal(db, meal_id)
|
||||
if not meal:
|
||||
raise HTTPException(status_code=404, detail="Meal not found")
|
||||
items = list(db.scalars(select(Item).order_by(Item.name)).all())
|
||||
return templates.TemplateResponse(request, "meal_detail.html", {"meal": meal, "items": items})
|
||||
|
||||
|
||||
@router.post("/meals/{meal_id}/ingredients", response_class=HTMLResponse)
|
||||
def htmx_add_ingredient(
|
||||
request: Request,
|
||||
meal_id: int,
|
||||
db: DbSession,
|
||||
item_id: Annotated[int, Form()],
|
||||
quantity_needed: Annotated[float, Form()],
|
||||
) -> HTMLResponse:
|
||||
"""Add an ingredient to a meal and return updated ingredient rows."""
|
||||
if quantity_needed <= 0:
|
||||
raise HTTPException(status_code=422, detail="Quantity must be positive")
|
||||
meal = db.get(Meal, meal_id)
|
||||
if not meal:
|
||||
raise HTTPException(status_code=404, detail="Meal not found")
|
||||
if not db.get(Item, item_id):
|
||||
raise HTTPException(status_code=422, detail="Item not found")
|
||||
existing = db.scalar(
|
||||
select(MealIngredient).where(MealIngredient.meal_id == meal_id, MealIngredient.item_id == item_id)
|
||||
)
|
||||
if existing:
|
||||
raise HTTPException(status_code=409, detail="Ingredient already exists for this meal")
|
||||
db.add(MealIngredient(meal_id=meal_id, item_id=item_id, quantity_needed=quantity_needed))
|
||||
db.commit()
|
||||
meal = _load_meal(db, meal_id)
|
||||
return templates.TemplateResponse(request, "partials/ingredient_rows.html", {"meal": meal})
|
||||
|
||||
|
||||
@router.delete("/meals/{meal_id}/ingredients/{item_id}", response_class=HTMLResponse)
|
||||
def htmx_remove_ingredient(
|
||||
request: Request,
|
||||
meal_id: int,
|
||||
item_id: int,
|
||||
db: DbSession,
|
||||
) -> HTMLResponse:
|
||||
"""Remove an ingredient from a meal and return updated ingredient rows."""
|
||||
mi = db.scalar(select(MealIngredient).where(MealIngredient.meal_id == meal_id, MealIngredient.item_id == item_id))
|
||||
if mi:
|
||||
db.delete(mi)
|
||||
db.commit()
|
||||
meal = _load_meal(db, meal_id)
|
||||
return templates.TemplateResponse(request, "partials/ingredient_rows.html", {"meal": meal})
|
||||
|
||||
|
||||
# --- Availability ---
|
||||
|
||||
|
||||
@router.get("/availability", response_class=HTMLResponse)
|
||||
def availability_page(request: Request, db: DbSession) -> HTMLResponse:
|
||||
"""Render the meal availability page."""
|
||||
meals = list(
|
||||
db.scalars(select(Meal).options(selectinload(Meal.ingredients).selectinload(MealIngredient.item))).all()
|
||||
)
|
||||
availability = [_check_meal(m) for m in meals]
|
||||
return templates.TemplateResponse(request, "availability.html", {"availability": availability})
|
||||
212
python/van_inventory/static/style.css
Normal file
212
python/van_inventory/static/style.css
Normal file
@@ -0,0 +1,212 @@
|
||||
:root {
|
||||
--neon-pink: #ff2a6d;
|
||||
--neon-cyan: #05d9e8;
|
||||
--neon-yellow: #f9f002;
|
||||
--neon-purple: #d300c5;
|
||||
--bg-dark: #0a0a0f;
|
||||
--bg-panel: #0d0d1a;
|
||||
--bg-input: #111128;
|
||||
--border: #1a1a3e;
|
||||
--text: #c0c0d0;
|
||||
--text-dim: #8e8ea0;
|
||||
}
|
||||
|
||||
* { box-sizing: border-box; margin: 0; padding: 0; }
|
||||
|
||||
body {
|
||||
font-family: 'Share Tech Mono', monospace;
|
||||
max-width: 900px;
|
||||
margin: 0 auto;
|
||||
padding: 1rem;
|
||||
background: var(--bg-dark);
|
||||
color: var(--text);
|
||||
position: relative;
|
||||
}
|
||||
|
||||
/* Scanline overlay */
|
||||
body::before {
|
||||
content: '';
|
||||
position: fixed;
|
||||
top: 0; left: 0; right: 0; bottom: 0;
|
||||
background: repeating-linear-gradient(
|
||||
0deg,
|
||||
transparent,
|
||||
transparent 2px,
|
||||
rgba(0, 0, 0, 0.08) 2px,
|
||||
rgba(0, 0, 0, 0.08) 4px
|
||||
);
|
||||
pointer-events: none;
|
||||
z-index: 9999;
|
||||
}
|
||||
|
||||
h1, h2, h3 {
|
||||
font-family: 'Orbitron', sans-serif;
|
||||
margin-bottom: 0.5rem;
|
||||
color: var(--neon-cyan);
|
||||
text-shadow: 0 0 10px rgba(5, 217, 232, 0.5), 0 0 40px rgba(5, 217, 232, 0.2);
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 2px;
|
||||
}
|
||||
|
||||
a { color: var(--neon-pink); text-decoration: none; transition: all 0.2s; }
|
||||
a:hover {
|
||||
text-shadow: 0 0 8px rgba(255, 42, 109, 0.8), 0 0 20px rgba(255, 42, 109, 0.4);
|
||||
}
|
||||
|
||||
nav {
|
||||
display: flex;
|
||||
gap: 1.5rem;
|
||||
padding: 1rem 0;
|
||||
border-bottom: 1px solid var(--border);
|
||||
margin-bottom: 1.5rem;
|
||||
position: relative;
|
||||
}
|
||||
nav::after {
|
||||
content: '';
|
||||
position: absolute;
|
||||
bottom: -1px;
|
||||
left: 0;
|
||||
right: 0;
|
||||
height: 1px;
|
||||
background: linear-gradient(90deg, var(--neon-pink), var(--neon-cyan), var(--neon-purple));
|
||||
opacity: 0.6;
|
||||
}
|
||||
nav a {
|
||||
font-family: 'Orbitron', sans-serif;
|
||||
font-weight: 700;
|
||||
font-size: 0.85rem;
|
||||
letter-spacing: 1px;
|
||||
text-transform: uppercase;
|
||||
padding: 0.3rem 0;
|
||||
border-bottom: 2px solid transparent;
|
||||
transition: all 0.2s;
|
||||
}
|
||||
nav a:hover {
|
||||
border-bottom-color: var(--neon-pink);
|
||||
text-shadow: 0 0 8px rgba(255, 42, 109, 0.8);
|
||||
}
|
||||
|
||||
table {
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
margin: 1rem 0;
|
||||
border: 1px solid var(--border);
|
||||
}
|
||||
th, td {
|
||||
text-align: left;
|
||||
padding: 0.6rem 0.75rem;
|
||||
border-bottom: 1px solid var(--border);
|
||||
}
|
||||
th {
|
||||
font-family: 'Orbitron', sans-serif;
|
||||
color: var(--neon-cyan);
|
||||
font-size: 0.7rem;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 2px;
|
||||
background: var(--bg-panel);
|
||||
border-bottom: 1px solid var(--neon-cyan);
|
||||
text-shadow: 0 0 6px rgba(5, 217, 232, 0.3);
|
||||
}
|
||||
tr:hover td {
|
||||
background: rgba(5, 217, 232, 0.03);
|
||||
}
|
||||
|
||||
form {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 0.5rem;
|
||||
align-items: end;
|
||||
margin: 1rem 0;
|
||||
padding: 1rem;
|
||||
border: 1px solid var(--border);
|
||||
background: var(--bg-panel);
|
||||
}
|
||||
|
||||
input, select {
|
||||
padding: 0.5rem 0.6rem;
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 2px;
|
||||
background: var(--bg-input);
|
||||
color: var(--neon-cyan);
|
||||
font-family: 'Share Tech Mono', monospace;
|
||||
transition: all 0.2s;
|
||||
}
|
||||
input:focus, select:focus {
|
||||
outline: none;
|
||||
border-color: var(--neon-cyan);
|
||||
box-shadow: 0 0 8px rgba(5, 217, 232, 0.3), inset 0 0 8px rgba(5, 217, 232, 0.05);
|
||||
}
|
||||
|
||||
button {
|
||||
padding: 0.5rem 1.2rem;
|
||||
border: 1px solid var(--neon-pink);
|
||||
border-radius: 2px;
|
||||
background: transparent;
|
||||
color: var(--neon-pink);
|
||||
cursor: pointer;
|
||||
font-family: 'Orbitron', sans-serif;
|
||||
font-weight: 700;
|
||||
font-size: 0.7rem;
|
||||
letter-spacing: 1px;
|
||||
text-transform: uppercase;
|
||||
transition: all 0.2s;
|
||||
}
|
||||
button:hover {
|
||||
background: var(--neon-pink);
|
||||
color: var(--bg-dark);
|
||||
box-shadow: 0 0 15px rgba(255, 42, 109, 0.5), 0 0 30px rgba(255, 42, 109, 0.2);
|
||||
}
|
||||
button.danger {
|
||||
border-color: var(--text-dim);
|
||||
color: var(--text-dim);
|
||||
}
|
||||
button.danger:hover {
|
||||
border-color: var(--neon-pink);
|
||||
background: var(--neon-pink);
|
||||
color: var(--bg-dark);
|
||||
box-shadow: 0 0 15px rgba(255, 42, 109, 0.5);
|
||||
}
|
||||
|
||||
.badge {
|
||||
display: inline-block;
|
||||
padding: 0.2rem 0.6rem;
|
||||
border-radius: 2px;
|
||||
font-family: 'Orbitron', sans-serif;
|
||||
font-size: 0.65rem;
|
||||
font-weight: 700;
|
||||
letter-spacing: 1px;
|
||||
text-transform: uppercase;
|
||||
}
|
||||
.badge.yes {
|
||||
background: rgba(5, 217, 232, 0.1);
|
||||
color: var(--neon-cyan);
|
||||
border: 1px solid var(--neon-cyan);
|
||||
text-shadow: 0 0 6px rgba(5, 217, 232, 0.5);
|
||||
}
|
||||
.badge.no {
|
||||
background: rgba(255, 42, 109, 0.1);
|
||||
color: var(--neon-pink);
|
||||
border: 1px solid var(--neon-pink);
|
||||
text-shadow: 0 0 6px rgba(255, 42, 109, 0.5);
|
||||
}
|
||||
|
||||
.missing-list { font-size: 0.85rem; color: var(--text-dim); }
|
||||
|
||||
label {
|
||||
font-size: 0.75rem;
|
||||
color: var(--text-dim);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.2rem;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 1px;
|
||||
}
|
||||
|
||||
.flash {
|
||||
padding: 0.5rem 1rem;
|
||||
margin: 0.5rem 0;
|
||||
border-radius: 2px;
|
||||
background: rgba(5, 217, 232, 0.1);
|
||||
color: var(--neon-cyan);
|
||||
border: 1px solid var(--neon-cyan);
|
||||
}
|
||||
30
python/van_inventory/templates/availability.html
Normal file
30
python/van_inventory/templates/availability.html
Normal file
@@ -0,0 +1,30 @@
|
||||
{% extends "base.html" %}
|
||||
{% block title %}What Can I Make? - Van{% endblock %}
|
||||
{% block content %}
|
||||
<h1>What Can I Make?</h1>
|
||||
|
||||
<table>
|
||||
<thead>
|
||||
<tr><th>Meal</th><th>Status</th><th>Missing</th></tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{% for meal in availability %}
|
||||
<tr>
|
||||
<td><a href="/meals/{{ meal.meal_id }}">{{ meal.meal_name }}</a></td>
|
||||
<td>
|
||||
{% if meal.can_make %}
|
||||
<span class="badge yes">Ready</span>
|
||||
{% else %}
|
||||
<span class="badge no">Missing items</span>
|
||||
{% endif %}
|
||||
</td>
|
||||
<td class="missing-list">
|
||||
{% for m in meal.missing %}
|
||||
{{ m.item_name }}: need {{ m.short }} more {{ m.unit }}{% if not loop.last %}, {% endif %}
|
||||
{% endfor %}
|
||||
</td>
|
||||
</tr>
|
||||
{% endfor %}
|
||||
</tbody>
|
||||
</table>
|
||||
{% endblock %}
|
||||
20
python/van_inventory/templates/base.html
Normal file
20
python/van_inventory/templates/base.html
Normal file
@@ -0,0 +1,20 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>{% block title %}Van Inventory{% endblock %}</title>
|
||||
<script src="https://unpkg.com/htmx.org@2.0.4"></script>
|
||||
<link rel="preconnect" href="https://fonts.googleapis.com">
|
||||
<link href="https://fonts.googleapis.com/css2?family=Orbitron:wght@400;700;900&family=Share+Tech+Mono&display=swap" rel="stylesheet">
|
||||
<link rel="stylesheet" href="/static/style.css">
|
||||
</head>
|
||||
<body>
|
||||
<nav>
|
||||
<a href="/">Inventory</a>
|
||||
<a href="/meals">Meals</a>
|
||||
<a href="/availability">What Can I Make?</a>
|
||||
</nav>
|
||||
{% block content %}{% endblock %}
|
||||
</body>
|
||||
</html>
|
||||
17
python/van_inventory/templates/items.html
Normal file
17
python/van_inventory/templates/items.html
Normal file
@@ -0,0 +1,17 @@
|
||||
{% extends "base.html" %}
|
||||
{% block title %}Inventory - Van{% endblock %}
|
||||
{% block content %}
|
||||
<h1>Van Inventory</h1>
|
||||
|
||||
<form hx-post="/items" hx-target="#item-list" hx-swap="innerHTML" hx-on::after-request="if(event.detail.successful) this.reset()">
|
||||
<label>Name <input type="text" name="name" required></label>
|
||||
<label>Qty <input type="number" name="quantity" step="any" value="0" min="0" required></label>
|
||||
<label>Unit <input type="text" name="unit" required placeholder="lbs, cans, etc"></label>
|
||||
<label>Category <input type="text" name="category" placeholder="optional"></label>
|
||||
<button type="submit">Add Item</button>
|
||||
</form>
|
||||
|
||||
<div id="item-list">
|
||||
{% include "partials/item_rows.html" %}
|
||||
</div>
|
||||
{% endblock %}
|
||||
24
python/van_inventory/templates/meal_detail.html
Normal file
24
python/van_inventory/templates/meal_detail.html
Normal file
@@ -0,0 +1,24 @@
|
||||
{% extends "base.html" %}
|
||||
{% block title %}{{ meal.name }} - Van{% endblock %}
|
||||
{% block content %}
|
||||
<h1>{{ meal.name }}</h1>
|
||||
{% if meal.instructions %}<p>{{ meal.instructions }}</p>{% endif %}
|
||||
|
||||
<h2>Ingredients</h2>
|
||||
<form hx-post="/meals/{{ meal.id }}/ingredients" hx-target="#ingredient-list" hx-swap="innerHTML" hx-on::after-request="if(event.detail.successful) this.reset()">
|
||||
<label>Item
|
||||
<select name="item_id" required>
|
||||
<option value="">--</option>
|
||||
{% for item in items %}
|
||||
<option value="{{ item.id }}">{{ item.name }} ({{ item.unit }})</option>
|
||||
{% endfor %}
|
||||
</select>
|
||||
</label>
|
||||
<label>Qty needed <input type="number" name="quantity_needed" step="any" min="0.01" required></label>
|
||||
<button type="submit">Add</button>
|
||||
</form>
|
||||
|
||||
<div id="ingredient-list">
|
||||
{% include "partials/ingredient_rows.html" %}
|
||||
</div>
|
||||
{% endblock %}
|
||||
15
python/van_inventory/templates/meals.html
Normal file
15
python/van_inventory/templates/meals.html
Normal file
@@ -0,0 +1,15 @@
|
||||
{% extends "base.html" %}
|
||||
{% block title %}Meals - Van{% endblock %}
|
||||
{% block content %}
|
||||
<h1>Meals</h1>
|
||||
|
||||
<form hx-post="/meals" hx-target="#meal-list" hx-swap="innerHTML" hx-on::after-request="if(event.detail.successful) this.reset()">
|
||||
<label>Name <input type="text" name="name" required></label>
|
||||
<label>Instructions <input type="text" name="instructions" placeholder="optional"></label>
|
||||
<button type="submit">Add Meal</button>
|
||||
</form>
|
||||
|
||||
<div id="meal-list">
|
||||
{% include "partials/meal_rows.html" %}
|
||||
</div>
|
||||
{% endblock %}
|
||||
16
python/van_inventory/templates/partials/ingredient_rows.html
Normal file
16
python/van_inventory/templates/partials/ingredient_rows.html
Normal file
@@ -0,0 +1,16 @@
|
||||
<table>
|
||||
<thead>
|
||||
<tr><th>Item</th><th>Needed</th><th>Have</th><th>Unit</th><th></th></tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{% for mi in meal.ingredients %}
|
||||
<tr>
|
||||
<td>{{ mi.item.name }}</td>
|
||||
<td>{{ mi.quantity_needed }}</td>
|
||||
<td>{{ mi.item.quantity }}</td>
|
||||
<td>{{ mi.item.unit }}</td>
|
||||
<td><button class="danger" hx-delete="/meals/{{ meal.id }}/ingredients/{{ mi.item_id }}" hx-target="#ingredient-list" hx-swap="innerHTML" hx-confirm="Remove {{ mi.item.name }}?">X</button></td>
|
||||
</tr>
|
||||
{% endfor %}
|
||||
</tbody>
|
||||
</table>
|
||||
21
python/van_inventory/templates/partials/item_rows.html
Normal file
21
python/van_inventory/templates/partials/item_rows.html
Normal file
@@ -0,0 +1,21 @@
|
||||
<table>
|
||||
<thead>
|
||||
<tr><th>Name</th><th>Qty</th><th>Unit</th><th>Category</th><th></th></tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{% for item in items %}
|
||||
<tr>
|
||||
<td>{{ item.name }}</td>
|
||||
<td>
|
||||
<form hx-patch="/items/{{ item.id }}" hx-target="#item-list" hx-swap="innerHTML" style="display:inline; margin:0;">
|
||||
<input type="number" name="quantity" value="{{ item.quantity }}" step="any" min="0" style="width:5rem">
|
||||
<button type="submit" style="padding:0.2rem 0.5rem; font-size:0.8rem;">Update</button>
|
||||
</form>
|
||||
</td>
|
||||
<td>{{ item.unit }}</td>
|
||||
<td>{{ item.category or "" }}</td>
|
||||
<td><button class="danger" hx-delete="/items/{{ item.id }}" hx-target="#item-list" hx-swap="innerHTML" hx-confirm="Delete {{ item.name }}?">X</button></td>
|
||||
</tr>
|
||||
{% endfor %}
|
||||
</tbody>
|
||||
</table>
|
||||
15
python/van_inventory/templates/partials/meal_rows.html
Normal file
15
python/van_inventory/templates/partials/meal_rows.html
Normal file
@@ -0,0 +1,15 @@
|
||||
<table>
|
||||
<thead>
|
||||
<tr><th>Name</th><th>Ingredients</th><th>Instructions</th><th></th></tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{% for meal in meals %}
|
||||
<tr>
|
||||
<td><a href="/meals/{{ meal.id }}">{{ meal.name }}</a></td>
|
||||
<td>{{ meal.ingredients | length }}</td>
|
||||
<td>{{ (meal.instructions or "")[:50] }}</td>
|
||||
<td><button class="danger" hx-delete="/meals/{{ meal.id }}" hx-target="#meal-list" hx-swap="innerHTML" hx-confirm="Delete {{ meal.name }}?">X</button></td>
|
||||
</tr>
|
||||
{% endfor %}
|
||||
</tbody>
|
||||
</table>
|
||||
@@ -11,6 +11,7 @@
|
||||
authentication = pkgs.lib.mkOverride 10 ''
|
||||
|
||||
# admins
|
||||
# These are required for the nixos postgresql setup
|
||||
local all postgres trust
|
||||
host all postgres 127.0.0.1/32 trust
|
||||
host all postgres ::1/128 trust
|
||||
@@ -21,6 +22,8 @@
|
||||
host all richie 192.168.90.1/24 trust
|
||||
host all richie 192.168.99.1/24 trust
|
||||
|
||||
local vaninventory vaninventory trust
|
||||
|
||||
#type database DBuser origin-address auth-method
|
||||
local hass hass trust
|
||||
|
||||
@@ -62,6 +65,13 @@
|
||||
replication = true;
|
||||
};
|
||||
}
|
||||
{
|
||||
name = "vaninventory";
|
||||
ensureDBOwnership = true;
|
||||
ensureClauses = {
|
||||
login = true;
|
||||
};
|
||||
}
|
||||
{
|
||||
name = "hass";
|
||||
ensureDBOwnership = true;
|
||||
@@ -76,6 +86,7 @@
|
||||
ensureDatabases = [
|
||||
"hass"
|
||||
"richie"
|
||||
"vaninventory"
|
||||
];
|
||||
# Thank you NotAShelf
|
||||
# https://github.com/NotAShelf/nyx/blob/d407b4d6e5ab7f60350af61a3d73a62a5e9ac660/modules/core/roles/server/system/services/databases/postgresql.nix#L74
|
||||
|
||||
48
systems/brain/services/van_inventory.nix
Normal file
48
systems/brain/services/van_inventory.nix
Normal file
@@ -0,0 +1,48 @@
|
||||
{
|
||||
pkgs,
|
||||
inputs,
|
||||
...
|
||||
}:
|
||||
{
|
||||
networking.firewall.allowedTCPPorts = [ 8001 ];
|
||||
|
||||
users.users.vaninventory = {
|
||||
isSystemUser = true;
|
||||
group = "vaninventory";
|
||||
};
|
||||
users.groups.vaninventory = { };
|
||||
|
||||
systemd.services.van_inventory = {
|
||||
description = "Van Inventory API";
|
||||
after = [
|
||||
"network.target"
|
||||
"postgresql.service"
|
||||
];
|
||||
requires = [ "postgresql.service" ];
|
||||
wantedBy = [ "multi-user.target" ];
|
||||
|
||||
environment = {
|
||||
PYTHONPATH = "${inputs.self}/";
|
||||
VAN_INVENTORY_DB = "vaninventory";
|
||||
VAN_INVENTORY_USER = "vaninventory";
|
||||
VAN_INVENTORY_HOST = "/run/postgresql";
|
||||
VAN_INVENTORY_PORT = "5432";
|
||||
};
|
||||
|
||||
serviceConfig = {
|
||||
Type = "simple";
|
||||
User = "van-inventory";
|
||||
Group = "van-inventory";
|
||||
ExecStart = "${pkgs.my_python}/bin/python -m python.van_inventory.main --host 0.0.0.0 --port 8001";
|
||||
Restart = "on-failure";
|
||||
RestartSec = "5s";
|
||||
StandardOutput = "journal";
|
||||
StandardError = "journal";
|
||||
NoNewPrivileges = true;
|
||||
ProtectSystem = "strict";
|
||||
ProtectHome = "read-only";
|
||||
PrivateTmp = true;
|
||||
ReadOnlyPaths = [ "${inputs.self}" ];
|
||||
};
|
||||
};
|
||||
}
|
||||
@@ -1,236 +0,0 @@
|
||||
"""Tests for python/api modules."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from python.api.routers.contact import (
|
||||
ContactBase,
|
||||
ContactCreate,
|
||||
ContactListResponse,
|
||||
ContactRelationshipCreate,
|
||||
ContactRelationshipResponse,
|
||||
ContactRelationshipUpdate,
|
||||
ContactUpdate,
|
||||
GraphData,
|
||||
GraphEdge,
|
||||
GraphNode,
|
||||
NeedBase,
|
||||
NeedCreate,
|
||||
NeedResponse,
|
||||
RelationshipTypeInfo,
|
||||
router,
|
||||
)
|
||||
from python.api.routers.frontend import create_frontend_router
|
||||
from python.orm.contact import RelationshipType
|
||||
|
||||
|
||||
# --- Pydantic schema tests ---
|
||||
|
||||
|
||||
def test_need_base() -> None:
|
||||
"""Test NeedBase schema."""
|
||||
need = NeedBase(name="ADHD", description="Attention deficit")
|
||||
assert need.name == "ADHD"
|
||||
|
||||
|
||||
def test_need_create() -> None:
|
||||
"""Test NeedCreate schema."""
|
||||
need = NeedCreate(name="Light Sensitive")
|
||||
assert need.name == "Light Sensitive"
|
||||
assert need.description is None
|
||||
|
||||
|
||||
def test_need_response() -> None:
|
||||
"""Test NeedResponse schema."""
|
||||
need = NeedResponse(id=1, name="ADHD", description="test")
|
||||
assert need.id == 1
|
||||
|
||||
|
||||
def test_contact_base() -> None:
|
||||
"""Test ContactBase schema."""
|
||||
contact = ContactBase(name="John")
|
||||
assert contact.name == "John"
|
||||
assert contact.age is None
|
||||
assert contact.bio is None
|
||||
|
||||
|
||||
def test_contact_create() -> None:
|
||||
"""Test ContactCreate schema."""
|
||||
contact = ContactCreate(name="John", need_ids=[1, 2])
|
||||
assert contact.need_ids == [1, 2]
|
||||
|
||||
|
||||
def test_contact_create_no_needs() -> None:
|
||||
"""Test ContactCreate with no needs."""
|
||||
contact = ContactCreate(name="John")
|
||||
assert contact.need_ids == []
|
||||
|
||||
|
||||
def test_contact_update() -> None:
|
||||
"""Test ContactUpdate schema."""
|
||||
update = ContactUpdate(name="Jane", age=30)
|
||||
assert update.name == "Jane"
|
||||
assert update.age == 30
|
||||
|
||||
|
||||
def test_contact_update_partial() -> None:
|
||||
"""Test ContactUpdate with partial data."""
|
||||
update = ContactUpdate(age=25)
|
||||
assert update.name is None
|
||||
assert update.age == 25
|
||||
|
||||
|
||||
def test_contact_list_response() -> None:
|
||||
"""Test ContactListResponse schema."""
|
||||
contact = ContactListResponse(id=1, name="John")
|
||||
assert contact.id == 1
|
||||
|
||||
|
||||
def test_contact_relationship_create() -> None:
|
||||
"""Test ContactRelationshipCreate schema."""
|
||||
rel = ContactRelationshipCreate(
|
||||
related_contact_id=2,
|
||||
relationship_type=RelationshipType.FRIEND,
|
||||
)
|
||||
assert rel.related_contact_id == 2
|
||||
assert rel.closeness_weight is None
|
||||
|
||||
|
||||
def test_contact_relationship_create_with_weight() -> None:
|
||||
"""Test ContactRelationshipCreate with custom weight."""
|
||||
rel = ContactRelationshipCreate(
|
||||
related_contact_id=2,
|
||||
relationship_type=RelationshipType.SPOUSE,
|
||||
closeness_weight=10,
|
||||
)
|
||||
assert rel.closeness_weight == 10
|
||||
|
||||
|
||||
def test_contact_relationship_update() -> None:
|
||||
"""Test ContactRelationshipUpdate schema."""
|
||||
update = ContactRelationshipUpdate(closeness_weight=8)
|
||||
assert update.relationship_type is None
|
||||
assert update.closeness_weight == 8
|
||||
|
||||
|
||||
def test_contact_relationship_response() -> None:
|
||||
"""Test ContactRelationshipResponse schema."""
|
||||
resp = ContactRelationshipResponse(
|
||||
contact_id=1,
|
||||
related_contact_id=2,
|
||||
relationship_type="friend",
|
||||
closeness_weight=6,
|
||||
)
|
||||
assert resp.contact_id == 1
|
||||
|
||||
|
||||
def test_relationship_type_info() -> None:
|
||||
"""Test RelationshipTypeInfo schema."""
|
||||
info = RelationshipTypeInfo(value="spouse", display_name="Spouse", default_weight=10)
|
||||
assert info.value == "spouse"
|
||||
|
||||
|
||||
def test_graph_node() -> None:
|
||||
"""Test GraphNode schema."""
|
||||
node = GraphNode(id=1, name="John", current_job="Dev")
|
||||
assert node.id == 1
|
||||
|
||||
|
||||
def test_graph_edge() -> None:
|
||||
"""Test GraphEdge schema."""
|
||||
edge = GraphEdge(source=1, target=2, relationship_type="friend", closeness_weight=6)
|
||||
assert edge.source == 1
|
||||
|
||||
|
||||
def test_graph_data() -> None:
|
||||
"""Test GraphData schema."""
|
||||
data = GraphData(
|
||||
nodes=[GraphNode(id=1, name="John")],
|
||||
edges=[GraphEdge(source=1, target=2, relationship_type="friend", closeness_weight=6)],
|
||||
)
|
||||
assert len(data.nodes) == 1
|
||||
assert len(data.edges) == 1
|
||||
|
||||
|
||||
# --- frontend router test ---
|
||||
|
||||
|
||||
def test_create_frontend_router(tmp_path: Path) -> None:
|
||||
"""Test create_frontend_router creates router."""
|
||||
# Create required assets dir and index.html
|
||||
assets_dir = tmp_path / "assets"
|
||||
assets_dir.mkdir()
|
||||
index = tmp_path / "index.html"
|
||||
index.write_text("<html></html>")
|
||||
|
||||
router = create_frontend_router(tmp_path)
|
||||
assert router is not None
|
||||
|
||||
|
||||
# --- API main tests ---
|
||||
|
||||
|
||||
def test_create_app() -> None:
|
||||
"""Test create_app creates FastAPI app."""
|
||||
with patch("python.api.main.get_postgres_engine"):
|
||||
from python.api.main import create_app
|
||||
|
||||
app = create_app()
|
||||
assert app is not None
|
||||
assert app.title == "Contact Database API"
|
||||
|
||||
|
||||
def test_create_app_with_frontend(tmp_path: Path) -> None:
|
||||
"""Test create_app with frontend directory."""
|
||||
assets_dir = tmp_path / "assets"
|
||||
assets_dir.mkdir()
|
||||
index = tmp_path / "index.html"
|
||||
index.write_text("<html></html>")
|
||||
|
||||
with patch("python.api.main.get_postgres_engine"):
|
||||
from python.api.main import create_app
|
||||
|
||||
app = create_app(frontend_dir=tmp_path)
|
||||
assert app is not None
|
||||
|
||||
|
||||
def test_build_frontend_none() -> None:
|
||||
"""Test build_frontend with None returns None."""
|
||||
from python.api.main import build_frontend
|
||||
|
||||
result = build_frontend(None)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_build_frontend_missing_dir() -> None:
|
||||
"""Test build_frontend with missing directory raises."""
|
||||
from python.api.main import build_frontend
|
||||
|
||||
with pytest.raises(FileExistsError):
|
||||
build_frontend(Path("/nonexistent/path"))
|
||||
|
||||
|
||||
# --- dependencies test ---
|
||||
|
||||
|
||||
def test_db_session_dependency() -> None:
|
||||
"""Test get_db dependency."""
|
||||
from python.api.dependencies import get_db
|
||||
|
||||
mock_engine = create_engine("sqlite:///:memory:")
|
||||
mock_request = MagicMock()
|
||||
mock_request.app.state.engine = mock_engine
|
||||
|
||||
gen = get_db(mock_request)
|
||||
session = next(gen)
|
||||
assert isinstance(session, Session)
|
||||
try:
|
||||
next(gen)
|
||||
except StopIteration:
|
||||
pass
|
||||
@@ -1,469 +0,0 @@
|
||||
"""Integration tests for API router using SQLite in-memory database."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from python.api.routers.contact import (
|
||||
ContactCreate,
|
||||
ContactRelationshipCreate,
|
||||
ContactRelationshipUpdate,
|
||||
ContactUpdate,
|
||||
NeedCreate,
|
||||
add_contact_relationship,
|
||||
add_need_to_contact,
|
||||
create_contact,
|
||||
create_need,
|
||||
delete_contact,
|
||||
delete_need,
|
||||
get_contact,
|
||||
get_contact_relationships,
|
||||
get_need,
|
||||
get_relationship_graph,
|
||||
list_contacts,
|
||||
list_needs,
|
||||
list_relationship_types,
|
||||
RelationshipTypeInfo,
|
||||
remove_contact_relationship,
|
||||
remove_need_from_contact,
|
||||
update_contact,
|
||||
update_contact_relationship,
|
||||
)
|
||||
from python.orm.base import RichieBase
|
||||
from python.orm.contact import Contact, ContactNeed, ContactRelationship, Need, RelationshipType
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _create_db() -> Session:
|
||||
"""Create in-memory SQLite database with schema."""
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
# Create tables without schema prefix for SQLite
|
||||
RichieBase.metadata.create_all(engine, checkfirst=True)
|
||||
return Session(engine)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db() -> Session:
|
||||
"""Database session fixture."""
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
# SQLite doesn't support schemas, so we need to drop the schema reference
|
||||
from sqlalchemy import MetaData
|
||||
|
||||
meta = MetaData()
|
||||
for table in RichieBase.metadata.sorted_tables:
|
||||
# Create table without schema
|
||||
table.to_metadata(meta)
|
||||
meta.create_all(engine)
|
||||
session = Session(engine)
|
||||
yield session
|
||||
session.close()
|
||||
|
||||
|
||||
# --- Need CRUD tests ---
|
||||
|
||||
|
||||
def test_create_need(db: Session) -> None:
|
||||
"""Test creating a need."""
|
||||
need = create_need(NeedCreate(name="ADHD", description="Attention deficit"), db)
|
||||
assert need.name == "ADHD"
|
||||
assert need.id is not None
|
||||
|
||||
|
||||
def test_list_needs(db: Session) -> None:
|
||||
"""Test listing needs."""
|
||||
create_need(NeedCreate(name="ADHD"), db)
|
||||
create_need(NeedCreate(name="Light Sensitive"), db)
|
||||
needs = list_needs(db)
|
||||
assert len(needs) == 2
|
||||
|
||||
|
||||
def test_get_need(db: Session) -> None:
|
||||
"""Test getting a need by ID."""
|
||||
created = create_need(NeedCreate(name="ADHD"), db)
|
||||
need = get_need(created.id, db)
|
||||
assert need.name == "ADHD"
|
||||
|
||||
|
||||
def test_get_need_not_found(db: Session) -> None:
|
||||
"""Test getting a need that doesn't exist."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_need(999, db)
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
def test_delete_need(db: Session) -> None:
|
||||
"""Test deleting a need."""
|
||||
created = create_need(NeedCreate(name="ADHD"), db)
|
||||
result = delete_need(created.id, db)
|
||||
assert result == {"deleted": True}
|
||||
|
||||
|
||||
def test_delete_need_not_found(db: Session) -> None:
|
||||
"""Test deleting a need that doesn't exist."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
delete_need(999, db)
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
# --- Contact CRUD tests ---
|
||||
|
||||
|
||||
def test_create_contact(db: Session) -> None:
|
||||
"""Test creating a contact."""
|
||||
contact = create_contact(ContactCreate(name="John"), db)
|
||||
assert contact.name == "John"
|
||||
assert contact.id is not None
|
||||
|
||||
|
||||
def test_create_contact_with_needs(db: Session) -> None:
|
||||
"""Test creating a contact with needs."""
|
||||
need = create_need(NeedCreate(name="ADHD"), db)
|
||||
contact = create_contact(ContactCreate(name="John", need_ids=[need.id]), db)
|
||||
assert len(contact.needs) == 1
|
||||
|
||||
|
||||
def test_list_contacts(db: Session) -> None:
|
||||
"""Test listing contacts."""
|
||||
create_contact(ContactCreate(name="John"), db)
|
||||
create_contact(ContactCreate(name="Jane"), db)
|
||||
contacts = list_contacts(db)
|
||||
assert len(contacts) == 2
|
||||
|
||||
|
||||
def test_list_contacts_pagination(db: Session) -> None:
|
||||
"""Test listing contacts with pagination."""
|
||||
for i in range(5):
|
||||
create_contact(ContactCreate(name=f"Contact {i}"), db)
|
||||
contacts = list_contacts(db, skip=2, limit=2)
|
||||
assert len(contacts) == 2
|
||||
|
||||
|
||||
def test_get_contact(db: Session) -> None:
|
||||
"""Test getting a contact by ID."""
|
||||
created = create_contact(ContactCreate(name="John"), db)
|
||||
contact = get_contact(created.id, db)
|
||||
assert contact.name == "John"
|
||||
|
||||
|
||||
def test_get_contact_not_found(db: Session) -> None:
|
||||
"""Test getting a contact that doesn't exist."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_contact(999, db)
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
def test_update_contact(db: Session) -> None:
|
||||
"""Test updating a contact."""
|
||||
created = create_contact(ContactCreate(name="John"), db)
|
||||
updated = update_contact(created.id, ContactUpdate(name="Jane", age=30), db)
|
||||
assert updated.name == "Jane"
|
||||
assert updated.age == 30
|
||||
|
||||
|
||||
def test_update_contact_with_needs(db: Session) -> None:
|
||||
"""Test updating a contact's needs."""
|
||||
need = create_need(NeedCreate(name="ADHD"), db)
|
||||
created = create_contact(ContactCreate(name="John"), db)
|
||||
updated = update_contact(created.id, ContactUpdate(need_ids=[need.id]), db)
|
||||
assert len(updated.needs) == 1
|
||||
|
||||
|
||||
def test_update_contact_not_found(db: Session) -> None:
|
||||
"""Test updating a contact that doesn't exist."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
update_contact(999, ContactUpdate(name="Jane"), db)
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
def test_delete_contact(db: Session) -> None:
|
||||
"""Test deleting a contact."""
|
||||
created = create_contact(ContactCreate(name="John"), db)
|
||||
result = delete_contact(created.id, db)
|
||||
assert result == {"deleted": True}
|
||||
|
||||
|
||||
def test_delete_contact_not_found(db: Session) -> None:
|
||||
"""Test deleting a contact that doesn't exist."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
delete_contact(999, db)
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
# --- Need-Contact association tests ---
|
||||
|
||||
|
||||
def test_add_need_to_contact(db: Session) -> None:
|
||||
"""Test adding a need to a contact."""
|
||||
need = create_need(NeedCreate(name="ADHD"), db)
|
||||
contact = create_contact(ContactCreate(name="John"), db)
|
||||
result = add_need_to_contact(contact.id, need.id, db)
|
||||
assert result == {"added": True}
|
||||
|
||||
|
||||
def test_add_need_to_contact_contact_not_found(db: Session) -> None:
|
||||
"""Test adding need to nonexistent contact."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
need = create_need(NeedCreate(name="ADHD"), db)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
add_need_to_contact(999, need.id, db)
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
def test_add_need_to_contact_need_not_found(db: Session) -> None:
|
||||
"""Test adding nonexistent need to contact."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
contact = create_contact(ContactCreate(name="John"), db)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
add_need_to_contact(contact.id, 999, db)
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
def test_remove_need_from_contact(db: Session) -> None:
|
||||
"""Test removing a need from a contact."""
|
||||
need = create_need(NeedCreate(name="ADHD"), db)
|
||||
contact = create_contact(ContactCreate(name="John", need_ids=[need.id]), db)
|
||||
result = remove_need_from_contact(contact.id, need.id, db)
|
||||
assert result == {"removed": True}
|
||||
|
||||
|
||||
def test_remove_need_from_contact_contact_not_found(db: Session) -> None:
|
||||
"""Test removing need from nonexistent contact."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
remove_need_from_contact(999, 1, db)
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
def test_remove_need_from_contact_need_not_found(db: Session) -> None:
|
||||
"""Test removing nonexistent need from contact."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
contact = create_contact(ContactCreate(name="John"), db)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
remove_need_from_contact(contact.id, 999, db)
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
# --- Relationship tests ---
|
||||
|
||||
|
||||
def test_add_contact_relationship(db: Session) -> None:
|
||||
"""Test adding a relationship between contacts."""
|
||||
c1 = create_contact(ContactCreate(name="John"), db)
|
||||
c2 = create_contact(ContactCreate(name="Jane"), db)
|
||||
rel = add_contact_relationship(
|
||||
c1.id,
|
||||
ContactRelationshipCreate(related_contact_id=c2.id, relationship_type=RelationshipType.FRIEND),
|
||||
db,
|
||||
)
|
||||
assert rel.contact_id == c1.id
|
||||
assert rel.related_contact_id == c2.id
|
||||
|
||||
|
||||
def test_add_contact_relationship_default_weight(db: Session) -> None:
|
||||
"""Test relationship uses default weight from type."""
|
||||
c1 = create_contact(ContactCreate(name="John"), db)
|
||||
c2 = create_contact(ContactCreate(name="Jane"), db)
|
||||
rel = add_contact_relationship(
|
||||
c1.id,
|
||||
ContactRelationshipCreate(related_contact_id=c2.id, relationship_type=RelationshipType.SPOUSE),
|
||||
db,
|
||||
)
|
||||
assert rel.closeness_weight == RelationshipType.SPOUSE.default_weight
|
||||
|
||||
|
||||
def test_add_contact_relationship_custom_weight(db: Session) -> None:
|
||||
"""Test relationship with custom weight."""
|
||||
c1 = create_contact(ContactCreate(name="John"), db)
|
||||
c2 = create_contact(ContactCreate(name="Jane"), db)
|
||||
rel = add_contact_relationship(
|
||||
c1.id,
|
||||
ContactRelationshipCreate(related_contact_id=c2.id, relationship_type=RelationshipType.FRIEND, closeness_weight=8),
|
||||
db,
|
||||
)
|
||||
assert rel.closeness_weight == 8
|
||||
|
||||
|
||||
def test_add_contact_relationship_contact_not_found(db: Session) -> None:
|
||||
"""Test adding relationship with nonexistent contact."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
c2 = create_contact(ContactCreate(name="Jane"), db)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
add_contact_relationship(
|
||||
999,
|
||||
ContactRelationshipCreate(related_contact_id=c2.id, relationship_type=RelationshipType.FRIEND),
|
||||
db,
|
||||
)
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
def test_add_contact_relationship_related_not_found(db: Session) -> None:
|
||||
"""Test adding relationship with nonexistent related contact."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
c1 = create_contact(ContactCreate(name="John"), db)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
add_contact_relationship(
|
||||
c1.id,
|
||||
ContactRelationshipCreate(related_contact_id=999, relationship_type=RelationshipType.FRIEND),
|
||||
db,
|
||||
)
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
def test_add_contact_relationship_self(db: Session) -> None:
|
||||
"""Test cannot relate contact to itself."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
c1 = create_contact(ContactCreate(name="John"), db)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
add_contact_relationship(
|
||||
c1.id,
|
||||
ContactRelationshipCreate(related_contact_id=c1.id, relationship_type=RelationshipType.FRIEND),
|
||||
db,
|
||||
)
|
||||
assert exc_info.value.status_code == 400
|
||||
|
||||
|
||||
def test_get_contact_relationships(db: Session) -> None:
|
||||
"""Test getting relationships for a contact."""
|
||||
c1 = create_contact(ContactCreate(name="John"), db)
|
||||
c2 = create_contact(ContactCreate(name="Jane"), db)
|
||||
add_contact_relationship(
|
||||
c1.id,
|
||||
ContactRelationshipCreate(related_contact_id=c2.id, relationship_type=RelationshipType.FRIEND),
|
||||
db,
|
||||
)
|
||||
rels = get_contact_relationships(c1.id, db)
|
||||
assert len(rels) == 1
|
||||
|
||||
|
||||
def test_get_contact_relationships_not_found(db: Session) -> None:
|
||||
"""Test getting relationships for nonexistent contact."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_contact_relationships(999, db)
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
def test_update_contact_relationship(db: Session) -> None:
|
||||
"""Test updating a relationship."""
|
||||
c1 = create_contact(ContactCreate(name="John"), db)
|
||||
c2 = create_contact(ContactCreate(name="Jane"), db)
|
||||
add_contact_relationship(
|
||||
c1.id,
|
||||
ContactRelationshipCreate(related_contact_id=c2.id, relationship_type=RelationshipType.FRIEND),
|
||||
db,
|
||||
)
|
||||
updated = update_contact_relationship(
|
||||
c1.id,
|
||||
c2.id,
|
||||
ContactRelationshipUpdate(closeness_weight=9),
|
||||
db,
|
||||
)
|
||||
assert updated.closeness_weight == 9
|
||||
|
||||
|
||||
def test_update_contact_relationship_type(db: Session) -> None:
|
||||
"""Test updating relationship type."""
|
||||
c1 = create_contact(ContactCreate(name="John"), db)
|
||||
c2 = create_contact(ContactCreate(name="Jane"), db)
|
||||
add_contact_relationship(
|
||||
c1.id,
|
||||
ContactRelationshipCreate(related_contact_id=c2.id, relationship_type=RelationshipType.FRIEND),
|
||||
db,
|
||||
)
|
||||
updated = update_contact_relationship(
|
||||
c1.id,
|
||||
c2.id,
|
||||
ContactRelationshipUpdate(relationship_type=RelationshipType.BEST_FRIEND),
|
||||
db,
|
||||
)
|
||||
assert updated.relationship_type == "best_friend"
|
||||
|
||||
|
||||
def test_update_contact_relationship_not_found(db: Session) -> None:
|
||||
"""Test updating nonexistent relationship."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
update_contact_relationship(
|
||||
999,
|
||||
998,
|
||||
ContactRelationshipUpdate(closeness_weight=5),
|
||||
db,
|
||||
)
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
def test_remove_contact_relationship(db: Session) -> None:
|
||||
"""Test removing a relationship."""
|
||||
c1 = create_contact(ContactCreate(name="John"), db)
|
||||
c2 = create_contact(ContactCreate(name="Jane"), db)
|
||||
add_contact_relationship(
|
||||
c1.id,
|
||||
ContactRelationshipCreate(related_contact_id=c2.id, relationship_type=RelationshipType.FRIEND),
|
||||
db,
|
||||
)
|
||||
result = remove_contact_relationship(c1.id, c2.id, db)
|
||||
assert result == {"deleted": True}
|
||||
|
||||
|
||||
def test_remove_contact_relationship_not_found(db: Session) -> None:
|
||||
"""Test removing nonexistent relationship."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
remove_contact_relationship(999, 998, db)
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
# --- list_relationship_types ---
|
||||
|
||||
|
||||
def test_list_relationship_types() -> None:
|
||||
"""Test listing relationship types."""
|
||||
types = list_relationship_types()
|
||||
assert len(types) == len(RelationshipType)
|
||||
assert all(isinstance(t, RelationshipTypeInfo) for t in types)
|
||||
|
||||
|
||||
# --- graph tests ---
|
||||
|
||||
|
||||
def test_get_relationship_graph(db: Session) -> None:
|
||||
"""Test getting relationship graph."""
|
||||
c1 = create_contact(ContactCreate(name="John"), db)
|
||||
c2 = create_contact(ContactCreate(name="Jane"), db)
|
||||
add_contact_relationship(
|
||||
c1.id,
|
||||
ContactRelationshipCreate(related_contact_id=c2.id, relationship_type=RelationshipType.FRIEND),
|
||||
db,
|
||||
)
|
||||
graph = get_relationship_graph(db)
|
||||
assert len(graph.nodes) == 2
|
||||
assert len(graph.edges) == 1
|
||||
|
||||
|
||||
def test_get_relationship_graph_empty(db: Session) -> None:
|
||||
"""Test getting empty relationship graph."""
|
||||
graph = get_relationship_graph(db)
|
||||
assert len(graph.nodes) == 0
|
||||
assert len(graph.edges) == 0
|
||||
@@ -1,66 +0,0 @@
|
||||
"""Extended tests for python/api/main.py."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from python.api.main import build_frontend, create_app
|
||||
|
||||
|
||||
def test_build_frontend_runs_npm(tmp_path: Path) -> None:
|
||||
"""Test build_frontend runs npm commands."""
|
||||
source_dir = tmp_path / "frontend"
|
||||
source_dir.mkdir()
|
||||
(source_dir / "package.json").write_text('{"name": "test"}')
|
||||
|
||||
dist_dir = tmp_path / "build" / "dist"
|
||||
dist_dir.mkdir(parents=True)
|
||||
(dist_dir / "index.html").write_text("<html></html>")
|
||||
|
||||
def mock_copytree(src: Path, dst: Path, dirs_exist_ok: bool = False) -> None:
|
||||
if "dist" in str(src):
|
||||
Path(dst).mkdir(parents=True, exist_ok=True)
|
||||
(Path(dst) / "index.html").write_text("<html></html>")
|
||||
|
||||
with (
|
||||
patch("python.api.main.subprocess.run") as mock_run,
|
||||
patch("python.api.main.shutil.copytree") as mock_copy,
|
||||
patch("python.api.main.shutil.rmtree"),
|
||||
patch("python.api.main.tempfile.mkdtemp") as mock_mkdtemp,
|
||||
):
|
||||
# First mkdtemp for build dir, second for output dir
|
||||
build_dir = str(tmp_path / "build")
|
||||
output_dir = str(tmp_path / "output")
|
||||
mock_mkdtemp.side_effect = [build_dir, output_dir]
|
||||
|
||||
# dist_dir exists check
|
||||
with patch("pathlib.Path.exists", return_value=True):
|
||||
result = build_frontend(source_dir, cache_dir=tmp_path / ".npm")
|
||||
|
||||
assert mock_run.call_count == 2 # npm install + npm run build
|
||||
|
||||
|
||||
def test_build_frontend_no_dist(tmp_path: Path) -> None:
|
||||
"""Test build_frontend raises when dist directory not found."""
|
||||
source_dir = tmp_path / "frontend"
|
||||
source_dir.mkdir()
|
||||
(source_dir / "package.json").write_text('{"name": "test"}')
|
||||
|
||||
with (
|
||||
patch("python.api.main.subprocess.run"),
|
||||
patch("python.api.main.shutil.copytree"),
|
||||
patch("python.api.main.tempfile.mkdtemp", return_value=str(tmp_path / "build")),
|
||||
pytest.raises(FileNotFoundError, match="Build output not found"),
|
||||
):
|
||||
build_frontend(source_dir)
|
||||
|
||||
|
||||
def test_create_app_includes_contact_router() -> None:
|
||||
"""Test create_app includes contact router."""
|
||||
app = create_app()
|
||||
routes = [r.path for r in app.routes]
|
||||
# Should have API routes
|
||||
assert any("/api" in r for r in routes)
|
||||
@@ -1,61 +0,0 @@
|
||||
"""Tests for api/main.py serve function and frontend router."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from python.api.main import build_frontend, create_app, serve
|
||||
|
||||
|
||||
def test_build_frontend_none_source() -> None:
|
||||
"""Test build_frontend returns None when no source dir."""
|
||||
result = build_frontend(None)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_build_frontend_nonexistent_dir(tmp_path: Path) -> None:
|
||||
"""Test build_frontend raises for nonexistent directory."""
|
||||
with pytest.raises(FileExistsError):
|
||||
build_frontend(tmp_path / "nonexistent")
|
||||
|
||||
|
||||
def test_create_app_with_frontend(tmp_path: Path) -> None:
|
||||
"""Test create_app with frontend directory."""
|
||||
# Create a minimal frontend dir with assets
|
||||
assets = tmp_path / "assets"
|
||||
assets.mkdir()
|
||||
(tmp_path / "index.html").write_text("<html></html>")
|
||||
|
||||
app = create_app(frontend_dir=tmp_path)
|
||||
routes = [r.path for r in app.routes]
|
||||
assert any("/api" in r for r in routes)
|
||||
|
||||
|
||||
def test_serve_calls_uvicorn() -> None:
|
||||
"""Test serve function calls uvicorn.run."""
|
||||
with (
|
||||
patch("python.api.main.uvicorn.run") as mock_run,
|
||||
patch("python.api.main.build_frontend", return_value=None),
|
||||
patch("python.api.main.configure_logger"),
|
||||
patch.dict("os.environ", {"HOME": "/tmp"}),
|
||||
):
|
||||
serve(host="localhost", port=8000, log_level="INFO")
|
||||
mock_run.assert_called_once()
|
||||
|
||||
|
||||
def test_serve_with_frontend_dir(tmp_path: Path) -> None:
|
||||
"""Test serve function with frontend dir."""
|
||||
assets = tmp_path / "assets"
|
||||
assets.mkdir()
|
||||
(tmp_path / "index.html").write_text("<html></html>")
|
||||
with (
|
||||
patch("python.api.main.uvicorn.run") as mock_run,
|
||||
patch("python.api.main.build_frontend", return_value=tmp_path),
|
||||
patch("python.api.main.configure_logger"),
|
||||
patch.dict("os.environ", {"HOME": "/tmp"}),
|
||||
):
|
||||
serve(host="localhost", frontend_dir=tmp_path, port=8000, log_level="INFO")
|
||||
mock_run.assert_called_once()
|
||||
@@ -1,364 +0,0 @@
|
||||
"""Tests for python/eval_warnings/main.py."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest.mock import MagicMock, patch
|
||||
from zipfile import ZipFile
|
||||
from io import BytesIO
|
||||
|
||||
import pytest
|
||||
|
||||
from python.eval_warnings.main import (
|
||||
EvalWarning,
|
||||
FileChange,
|
||||
apply_changes,
|
||||
compute_warning_hash,
|
||||
check_duplicate_pr,
|
||||
download_logs,
|
||||
extract_referenced_files,
|
||||
parse_changes,
|
||||
parse_warnings,
|
||||
query_ollama,
|
||||
run_cmd,
|
||||
create_pr,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
def test_eval_warning_frozen() -> None:
|
||||
"""Test EvalWarning is frozen dataclass."""
|
||||
w = EvalWarning(system="test", message="warning: test msg")
|
||||
assert w.system == "test"
|
||||
assert w.message == "warning: test msg"
|
||||
|
||||
|
||||
def test_file_change() -> None:
|
||||
"""Test FileChange dataclass."""
|
||||
fc = FileChange(file_path="test.nix", original="old", fixed="new")
|
||||
assert fc.file_path == "test.nix"
|
||||
|
||||
|
||||
def test_run_cmd() -> None:
|
||||
"""Test run_cmd."""
|
||||
result = run_cmd(["echo", "hello"])
|
||||
assert result.stdout.strip() == "hello"
|
||||
|
||||
|
||||
def test_run_cmd_check_false() -> None:
|
||||
"""Test run_cmd with check=False."""
|
||||
result = run_cmd(["ls", "/nonexistent"], check=False)
|
||||
assert result.returncode != 0
|
||||
|
||||
|
||||
def test_parse_warnings_basic() -> None:
|
||||
"""Test parse_warnings extracts warnings."""
|
||||
logs = {
|
||||
"build-server1/2_Build.txt": "warning: test warning\nsome other line\ntrace: warning: another warning\n",
|
||||
}
|
||||
warnings = parse_warnings(logs)
|
||||
assert len(warnings) == 2
|
||||
|
||||
|
||||
def test_parse_warnings_ignores_untrusted_flake() -> None:
|
||||
"""Test parse_warnings ignores untrusted flake settings."""
|
||||
logs = {
|
||||
"build-server1/2_Build.txt": "warning: ignoring untrusted flake configuration setting foo\n",
|
||||
}
|
||||
warnings = parse_warnings(logs)
|
||||
assert len(warnings) == 0
|
||||
|
||||
|
||||
def test_parse_warnings_strips_timestamp() -> None:
|
||||
"""Test parse_warnings strips timestamps."""
|
||||
logs = {
|
||||
"build-server1/2_Build.txt": "2024-01-01T00:00:00.000Z warning: test msg\n",
|
||||
}
|
||||
warnings = parse_warnings(logs)
|
||||
assert len(warnings) == 1
|
||||
w = warnings.pop()
|
||||
assert w.message == "warning: test msg"
|
||||
assert w.system == "server1"
|
||||
|
||||
|
||||
def test_parse_warnings_empty() -> None:
|
||||
"""Test parse_warnings with no warnings."""
|
||||
logs = {"build-server1/2_Build.txt": "all good\n"}
|
||||
warnings = parse_warnings(logs)
|
||||
assert len(warnings) == 0
|
||||
|
||||
|
||||
def test_compute_warning_hash() -> None:
|
||||
"""Test compute_warning_hash returns consistent 8-char hash."""
|
||||
warnings = {EvalWarning(system="s1", message="msg1")}
|
||||
h = compute_warning_hash(warnings)
|
||||
assert len(h) == 8
|
||||
# Same input -> same hash
|
||||
assert compute_warning_hash(warnings) == h
|
||||
|
||||
|
||||
def test_compute_warning_hash_different() -> None:
|
||||
"""Test different warnings produce different hashes."""
|
||||
w1 = {EvalWarning(system="s1", message="msg1")}
|
||||
w2 = {EvalWarning(system="s1", message="msg2")}
|
||||
assert compute_warning_hash(w1) != compute_warning_hash(w2)
|
||||
|
||||
|
||||
def test_extract_referenced_files(tmp_path: Path) -> None:
|
||||
"""Test extract_referenced_files reads existing files."""
|
||||
nix_file = tmp_path / "test.nix"
|
||||
nix_file.write_text("{ pkgs }: pkgs")
|
||||
|
||||
warnings = {EvalWarning(system="s1", message=f"warning: in /nix/store/abc-source/{nix_file}")}
|
||||
# Won't find the file since it uses absolute paths resolved differently
|
||||
files = extract_referenced_files(warnings)
|
||||
# Result depends on actual file resolution
|
||||
assert isinstance(files, dict)
|
||||
|
||||
|
||||
def test_check_duplicate_pr_no_duplicate() -> None:
|
||||
"""Test check_duplicate_pr when no duplicate exists."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.returncode = 0
|
||||
mock_result.stdout = "fix: resolve nix eval warnings (abcd1234)\nfix: other (efgh5678)\n"
|
||||
|
||||
with patch("python.eval_warnings.main.run_cmd", return_value=mock_result):
|
||||
assert check_duplicate_pr("xxxxxxxx") is False
|
||||
|
||||
|
||||
def test_check_duplicate_pr_found() -> None:
|
||||
"""Test check_duplicate_pr when duplicate exists."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.returncode = 0
|
||||
mock_result.stdout = "fix: resolve nix eval warnings (abcd1234)\n"
|
||||
|
||||
with patch("python.eval_warnings.main.run_cmd", return_value=mock_result):
|
||||
assert check_duplicate_pr("abcd1234") is True
|
||||
|
||||
|
||||
def test_check_duplicate_pr_error() -> None:
|
||||
"""Test check_duplicate_pr raises on error."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.returncode = 1
|
||||
mock_result.stderr = "gh error"
|
||||
|
||||
with (
|
||||
patch("python.eval_warnings.main.run_cmd", return_value=mock_result),
|
||||
pytest.raises(RuntimeError, match="Failed to check for duplicate PRs"),
|
||||
):
|
||||
check_duplicate_pr("test")
|
||||
|
||||
|
||||
def test_parse_changes_basic() -> None:
|
||||
"""Test parse_changes with valid response."""
|
||||
response = """## **REASONING**
|
||||
Some reasoning here.
|
||||
|
||||
## **CHANGES**
|
||||
FILE: test.nix
|
||||
<<<<<<< ORIGINAL
|
||||
old line
|
||||
=======
|
||||
new line
|
||||
>>>>>>> FIXED
|
||||
"""
|
||||
changes = parse_changes(response)
|
||||
assert len(changes) == 1
|
||||
assert changes[0].file_path == "test.nix"
|
||||
assert changes[0].original == "old line"
|
||||
assert changes[0].fixed == "new line"
|
||||
|
||||
|
||||
def test_parse_changes_no_changes_section() -> None:
|
||||
"""Test parse_changes with missing CHANGES section."""
|
||||
response = "Some text without changes"
|
||||
changes = parse_changes(response)
|
||||
assert changes == []
|
||||
|
||||
|
||||
def test_parse_changes_multiple() -> None:
|
||||
"""Test parse_changes with multiple file changes."""
|
||||
response = """**CHANGES**
|
||||
FILE: file1.nix
|
||||
<<<<<<< ORIGINAL
|
||||
old1
|
||||
=======
|
||||
new1
|
||||
>>>>>>> FIXED
|
||||
FILE: file2.nix
|
||||
<<<<<<< ORIGINAL
|
||||
old2
|
||||
=======
|
||||
new2
|
||||
>>>>>>> FIXED
|
||||
"""
|
||||
changes = parse_changes(response)
|
||||
assert len(changes) == 2
|
||||
|
||||
|
||||
def test_apply_changes(tmp_path: Path) -> None:
|
||||
"""Test apply_changes applies changes to files."""
|
||||
test_file = tmp_path / "test.nix"
|
||||
test_file.write_text("old content here")
|
||||
|
||||
changes = [FileChange(file_path=str(test_file), original="old content", fixed="new content")]
|
||||
|
||||
with patch("python.eval_warnings.main.Path.cwd", return_value=tmp_path):
|
||||
applied = apply_changes(changes)
|
||||
|
||||
assert applied == 1
|
||||
assert "new content here" in test_file.read_text()
|
||||
|
||||
|
||||
def test_apply_changes_file_not_found(tmp_path: Path) -> None:
|
||||
"""Test apply_changes skips missing files."""
|
||||
changes = [FileChange(file_path=str(tmp_path / "missing.nix"), original="old", fixed="new")]
|
||||
|
||||
with patch("python.eval_warnings.main.Path.cwd", return_value=tmp_path):
|
||||
applied = apply_changes(changes)
|
||||
|
||||
assert applied == 0
|
||||
|
||||
|
||||
def test_apply_changes_original_not_found(tmp_path: Path) -> None:
|
||||
"""Test apply_changes skips if original text not in file."""
|
||||
test_file = tmp_path / "test.nix"
|
||||
test_file.write_text("different content")
|
||||
|
||||
changes = [FileChange(file_path=str(test_file), original="not found", fixed="new")]
|
||||
|
||||
with patch("python.eval_warnings.main.Path.cwd", return_value=tmp_path):
|
||||
applied = apply_changes(changes)
|
||||
|
||||
assert applied == 0
|
||||
|
||||
|
||||
def test_apply_changes_path_traversal(tmp_path: Path) -> None:
|
||||
"""Test apply_changes blocks path traversal."""
|
||||
changes = [FileChange(file_path="/etc/passwd", original="old", fixed="new")]
|
||||
|
||||
with patch("python.eval_warnings.main.Path.cwd", return_value=tmp_path):
|
||||
applied = apply_changes(changes)
|
||||
|
||||
assert applied == 0
|
||||
|
||||
|
||||
def test_query_ollama_success() -> None:
|
||||
"""Test query_ollama returns response."""
|
||||
warnings = {EvalWarning(system="s1", message="warning: test")}
|
||||
files = {"test.nix": "{ pkgs }: pkgs"}
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"response": "some fix suggestion"}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
|
||||
with patch("python.eval_warnings.main.post", return_value=mock_response):
|
||||
result = query_ollama(warnings, files, "http://localhost:11434")
|
||||
|
||||
assert result == "some fix suggestion"
|
||||
|
||||
|
||||
def test_query_ollama_failure() -> None:
|
||||
"""Test query_ollama returns None on failure."""
|
||||
from httpx import HTTPError
|
||||
|
||||
warnings = {EvalWarning(system="s1", message="warning: test")}
|
||||
files = {}
|
||||
|
||||
with patch("python.eval_warnings.main.post", side_effect=HTTPError("fail")):
|
||||
result = query_ollama(warnings, files, "http://localhost:11434")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_download_logs_success() -> None:
|
||||
"""Test download_logs extracts build log files from zip."""
|
||||
# Create a zip file in memory
|
||||
buf = BytesIO()
|
||||
with ZipFile(buf, "w") as zf:
|
||||
zf.writestr("build-server1/2_Build.txt", "warning: test")
|
||||
zf.writestr("other-file.txt", "not a build log")
|
||||
zip_bytes = buf.getvalue()
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.returncode = 0
|
||||
mock_result.stdout = zip_bytes
|
||||
|
||||
with patch("python.eval_warnings.main.subprocess.run", return_value=mock_result):
|
||||
logs = download_logs("12345", "owner/repo")
|
||||
|
||||
assert "build-server1/2_Build.txt" in logs
|
||||
assert "other-file.txt" not in logs
|
||||
|
||||
|
||||
def test_download_logs_failure() -> None:
|
||||
"""Test download_logs raises on failure."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.returncode = 1
|
||||
mock_result.stderr = b"error"
|
||||
|
||||
with (
|
||||
patch("python.eval_warnings.main.subprocess.run", return_value=mock_result),
|
||||
pytest.raises(RuntimeError, match="Failed to download logs"),
|
||||
):
|
||||
download_logs("12345", "owner/repo")
|
||||
|
||||
|
||||
def test_create_pr() -> None:
|
||||
"""Test create_pr creates branch and PR."""
|
||||
warnings = {EvalWarning(system="s1", message="warning: test")}
|
||||
llm_response = "**REASONING**\nSome fix.\n**CHANGES**\nstuff"
|
||||
|
||||
mock_diff_result = MagicMock()
|
||||
mock_diff_result.returncode = 1 # changes exist
|
||||
|
||||
call_count = 0
|
||||
|
||||
def mock_run_cmd(cmd: list[str], *, check: bool = True) -> MagicMock:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
result = MagicMock()
|
||||
result.returncode = 0
|
||||
result.stdout = ""
|
||||
if "diff" in cmd:
|
||||
result.returncode = 1
|
||||
return result
|
||||
|
||||
with patch("python.eval_warnings.main.run_cmd", side_effect=mock_run_cmd):
|
||||
create_pr("abcd1234", warnings, llm_response, "https://example.com/run/1")
|
||||
|
||||
assert call_count > 0
|
||||
|
||||
|
||||
def test_create_pr_no_changes() -> None:
|
||||
"""Test create_pr does nothing when no file changes."""
|
||||
warnings = {EvalWarning(system="s1", message="warning: test")}
|
||||
llm_response = "**REASONING**\nNo changes needed.\n**CHANGES**\n"
|
||||
|
||||
def mock_run_cmd(cmd: list[str], *, check: bool = True) -> MagicMock:
|
||||
result = MagicMock()
|
||||
result.returncode = 0
|
||||
result.stdout = ""
|
||||
return result
|
||||
|
||||
with patch("python.eval_warnings.main.run_cmd", side_effect=mock_run_cmd):
|
||||
create_pr("abcd1234", warnings, llm_response, "https://example.com/run/1")
|
||||
|
||||
|
||||
def test_create_pr_no_reasoning() -> None:
|
||||
"""Test create_pr handles missing REASONING section."""
|
||||
warnings = {EvalWarning(system="s1", message="warning: test")}
|
||||
llm_response = "No reasoning here"
|
||||
|
||||
def mock_run_cmd(cmd: list[str], *, check: bool = True) -> MagicMock:
|
||||
result = MagicMock()
|
||||
result.returncode = 0 if "diff" not in cmd else 1
|
||||
result.stdout = ""
|
||||
return result
|
||||
|
||||
with patch("python.eval_warnings.main.run_cmd", side_effect=mock_run_cmd):
|
||||
create_pr("abcd1234", warnings, llm_response, "https://example.com/run/1")
|
||||
@@ -1,77 +0,0 @@
|
||||
"""Extended tests for python/eval_warnings/main.py."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from python.eval_warnings.main import (
|
||||
EvalWarning,
|
||||
extract_referenced_files,
|
||||
)
|
||||
|
||||
|
||||
def test_extract_referenced_files_nix_store_paths(tmp_path: Path) -> None:
|
||||
"""Test extracting files from nix store paths."""
|
||||
# Create matching directory structure
|
||||
systems_dir = tmp_path / "systems"
|
||||
systems_dir.mkdir()
|
||||
nix_file = systems_dir / "test.nix"
|
||||
nix_file.write_text("{ pkgs }: pkgs")
|
||||
|
||||
warnings = {
|
||||
EvalWarning(
|
||||
system="s1",
|
||||
message="warning: in /nix/store/abc-source/systems/test.nix:5: deprecated",
|
||||
)
|
||||
}
|
||||
|
||||
# Change to tmp_path so relative paths work
|
||||
old_cwd = os.getcwd()
|
||||
try:
|
||||
os.chdir(tmp_path)
|
||||
files = extract_referenced_files(warnings)
|
||||
finally:
|
||||
os.chdir(old_cwd)
|
||||
|
||||
assert "systems/test.nix" in files
|
||||
assert files["systems/test.nix"] == "{ pkgs }: pkgs"
|
||||
|
||||
|
||||
def test_extract_referenced_files_no_files_found() -> None:
|
||||
"""Test extract_referenced_files when no files are found."""
|
||||
warnings = {
|
||||
EvalWarning(
|
||||
system="s1",
|
||||
message="warning: something generic without file paths",
|
||||
)
|
||||
}
|
||||
files = extract_referenced_files(warnings)
|
||||
# Either empty or has flake.nix fallback
|
||||
assert isinstance(files, dict)
|
||||
|
||||
|
||||
def test_extract_referenced_files_repo_relative_paths(tmp_path: Path) -> None:
|
||||
"""Test extracting repo-relative file paths."""
|
||||
# Create the referenced file
|
||||
systems_dir = tmp_path / "systems" / "foo"
|
||||
systems_dir.mkdir(parents=True)
|
||||
nix_file = systems_dir / "bar.nix"
|
||||
nix_file.write_text("{ config }: {}")
|
||||
|
||||
warnings = {
|
||||
EvalWarning(
|
||||
system="s1",
|
||||
message="warning: in systems/foo/bar.nix:10: test",
|
||||
)
|
||||
}
|
||||
|
||||
old_cwd = os.getcwd()
|
||||
try:
|
||||
os.chdir(tmp_path)
|
||||
files = extract_referenced_files(warnings)
|
||||
finally:
|
||||
os.chdir(old_cwd)
|
||||
|
||||
assert "systems/foo/bar.nix" in files
|
||||
@@ -1,115 +0,0 @@
|
||||
"""Tests for eval_warnings/main.py main() entry point."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
def test_eval_warnings_main_no_warnings() -> None:
|
||||
"""Test main() when no warnings are found."""
|
||||
from python.eval_warnings.main import main
|
||||
|
||||
with (
|
||||
patch("python.eval_warnings.main.configure_logger"),
|
||||
patch("python.eval_warnings.main.download_logs", return_value="clean log"),
|
||||
patch("python.eval_warnings.main.parse_warnings", return_value=set()),
|
||||
):
|
||||
main(
|
||||
run_id="123",
|
||||
repo="owner/repo",
|
||||
ollama_url="http://localhost:11434",
|
||||
run_url="http://example.com/run",
|
||||
log_level="INFO",
|
||||
)
|
||||
|
||||
|
||||
def test_eval_warnings_main_duplicate_pr() -> None:
|
||||
"""Test main() when a duplicate PR exists."""
|
||||
from python.eval_warnings.main import main, EvalWarning
|
||||
|
||||
warnings = {EvalWarning(system="s1", message="test")}
|
||||
with (
|
||||
patch("python.eval_warnings.main.configure_logger"),
|
||||
patch("python.eval_warnings.main.download_logs", return_value="log"),
|
||||
patch("python.eval_warnings.main.parse_warnings", return_value=warnings),
|
||||
patch("python.eval_warnings.main.compute_warning_hash", return_value="abc123"),
|
||||
patch("python.eval_warnings.main.check_duplicate_pr", return_value=True),
|
||||
):
|
||||
main(
|
||||
run_id="123",
|
||||
repo="owner/repo",
|
||||
ollama_url="http://localhost:11434",
|
||||
run_url="http://example.com/run",
|
||||
)
|
||||
|
||||
|
||||
def test_eval_warnings_main_no_llm_response() -> None:
|
||||
"""Test main() when LLM returns no response."""
|
||||
from python.eval_warnings.main import main, EvalWarning
|
||||
|
||||
warnings = {EvalWarning(system="s1", message="test")}
|
||||
with (
|
||||
patch("python.eval_warnings.main.configure_logger"),
|
||||
patch("python.eval_warnings.main.download_logs", return_value="log"),
|
||||
patch("python.eval_warnings.main.parse_warnings", return_value=warnings),
|
||||
patch("python.eval_warnings.main.compute_warning_hash", return_value="abc123"),
|
||||
patch("python.eval_warnings.main.check_duplicate_pr", return_value=False),
|
||||
patch("python.eval_warnings.main.extract_referenced_files", return_value={}),
|
||||
patch("python.eval_warnings.main.query_ollama", return_value=None),
|
||||
):
|
||||
main(
|
||||
run_id="123",
|
||||
repo="owner/repo",
|
||||
ollama_url="http://localhost:11434",
|
||||
run_url="http://example.com/run",
|
||||
)
|
||||
|
||||
|
||||
def test_eval_warnings_main_no_changes_applied() -> None:
|
||||
"""Test main() when no changes are applied."""
|
||||
from python.eval_warnings.main import main, EvalWarning
|
||||
|
||||
warnings = {EvalWarning(system="s1", message="test")}
|
||||
with (
|
||||
patch("python.eval_warnings.main.configure_logger"),
|
||||
patch("python.eval_warnings.main.download_logs", return_value="log"),
|
||||
patch("python.eval_warnings.main.parse_warnings", return_value=warnings),
|
||||
patch("python.eval_warnings.main.compute_warning_hash", return_value="abc123"),
|
||||
patch("python.eval_warnings.main.check_duplicate_pr", return_value=False),
|
||||
patch("python.eval_warnings.main.extract_referenced_files", return_value={}),
|
||||
patch("python.eval_warnings.main.query_ollama", return_value="some response"),
|
||||
patch("python.eval_warnings.main.parse_changes", return_value=[]),
|
||||
patch("python.eval_warnings.main.apply_changes", return_value=0),
|
||||
):
|
||||
main(
|
||||
run_id="123",
|
||||
repo="owner/repo",
|
||||
ollama_url="http://localhost:11434",
|
||||
run_url="http://example.com/run",
|
||||
)
|
||||
|
||||
|
||||
def test_eval_warnings_main_full_success() -> None:
|
||||
"""Test main() full success path."""
|
||||
from python.eval_warnings.main import main, EvalWarning
|
||||
|
||||
warnings = {EvalWarning(system="s1", message="test")}
|
||||
with (
|
||||
patch("python.eval_warnings.main.configure_logger"),
|
||||
patch("python.eval_warnings.main.download_logs", return_value="log"),
|
||||
patch("python.eval_warnings.main.parse_warnings", return_value=warnings),
|
||||
patch("python.eval_warnings.main.compute_warning_hash", return_value="abc123"),
|
||||
patch("python.eval_warnings.main.check_duplicate_pr", return_value=False),
|
||||
patch("python.eval_warnings.main.extract_referenced_files", return_value={}),
|
||||
patch("python.eval_warnings.main.query_ollama", return_value="response"),
|
||||
patch("python.eval_warnings.main.parse_changes", return_value=[{"file": "a.nix"}]),
|
||||
patch("python.eval_warnings.main.apply_changes", return_value=1),
|
||||
patch("python.eval_warnings.main.create_pr") as mock_pr,
|
||||
):
|
||||
main(
|
||||
run_id="123",
|
||||
repo="owner/repo",
|
||||
ollama_url="http://localhost:11434",
|
||||
run_url="http://example.com/run",
|
||||
)
|
||||
mock_pr.assert_called_once()
|
||||
@@ -1,248 +0,0 @@
|
||||
"""Tests for python/heater modules."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from python.heater.models import ActionResult, DeviceConfig, HeaterStatus
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
# --- models tests ---
|
||||
|
||||
|
||||
def test_device_config() -> None:
|
||||
"""Test DeviceConfig creation."""
|
||||
config = DeviceConfig(device_id="abc123", ip="192.168.1.1", local_key="key123")
|
||||
assert config.device_id == "abc123"
|
||||
assert config.ip == "192.168.1.1"
|
||||
assert config.local_key == "key123"
|
||||
assert config.version == 3.5
|
||||
|
||||
|
||||
def test_device_config_custom_version() -> None:
|
||||
"""Test DeviceConfig with custom version."""
|
||||
config = DeviceConfig(device_id="abc", ip="10.0.0.1", local_key="key", version=3.3)
|
||||
assert config.version == 3.3
|
||||
|
||||
|
||||
def test_heater_status_defaults() -> None:
|
||||
"""Test HeaterStatus default values."""
|
||||
status = HeaterStatus(power=True)
|
||||
assert status.power is True
|
||||
assert status.setpoint is None
|
||||
assert status.state is None
|
||||
assert status.error_code is None
|
||||
assert status.raw_dps == {}
|
||||
|
||||
|
||||
def test_heater_status_full() -> None:
|
||||
"""Test HeaterStatus with all fields."""
|
||||
status = HeaterStatus(
|
||||
power=True,
|
||||
setpoint=72,
|
||||
state="Heat",
|
||||
error_code=0,
|
||||
raw_dps={"1": True, "101": 72},
|
||||
)
|
||||
assert status.power is True
|
||||
assert status.setpoint == 72
|
||||
assert status.state == "Heat"
|
||||
|
||||
|
||||
def test_action_result_success() -> None:
|
||||
"""Test ActionResult success."""
|
||||
result = ActionResult(success=True, action="on", power=True)
|
||||
assert result.success is True
|
||||
assert result.action == "on"
|
||||
assert result.power is True
|
||||
assert result.error is None
|
||||
|
||||
|
||||
def test_action_result_failure() -> None:
|
||||
"""Test ActionResult failure."""
|
||||
result = ActionResult(success=False, action="on", error="Connection failed")
|
||||
assert result.success is False
|
||||
assert result.error == "Connection failed"
|
||||
|
||||
|
||||
# --- controller tests (with mocked tinytuya) ---
|
||||
|
||||
|
||||
def _get_controller_class() -> type:
|
||||
"""Import HeaterController with mocked tinytuya."""
|
||||
mock_tinytuya = MagicMock()
|
||||
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
|
||||
# Force reimport
|
||||
if "python.heater.controller" in sys.modules:
|
||||
del sys.modules["python.heater.controller"]
|
||||
from python.heater.controller import HeaterController
|
||||
|
||||
return HeaterController
|
||||
|
||||
|
||||
def test_heater_controller_status_success() -> None:
|
||||
"""Test HeaterController.status returns correct status."""
|
||||
mock_tinytuya = MagicMock()
|
||||
mock_device = MagicMock()
|
||||
mock_device.status.return_value = {"dps": {"1": True, "101": 72, "102": "Heat", "108": 0}}
|
||||
mock_tinytuya.Device.return_value = mock_device
|
||||
|
||||
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
|
||||
if "python.heater.controller" in sys.modules:
|
||||
del sys.modules["python.heater.controller"]
|
||||
from python.heater.controller import HeaterController
|
||||
|
||||
config = DeviceConfig(device_id="abc", ip="10.0.0.1", local_key="key")
|
||||
controller = HeaterController(config)
|
||||
status = controller.status()
|
||||
|
||||
assert status.power is True
|
||||
assert status.setpoint == 72
|
||||
assert status.state == "Heat"
|
||||
|
||||
|
||||
def test_heater_controller_status_error() -> None:
|
||||
"""Test HeaterController.status handles device error."""
|
||||
mock_tinytuya = MagicMock()
|
||||
mock_device = MagicMock()
|
||||
mock_device.status.return_value = {"Error": "Connection timeout"}
|
||||
mock_tinytuya.Device.return_value = mock_device
|
||||
|
||||
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
|
||||
if "python.heater.controller" in sys.modules:
|
||||
del sys.modules["python.heater.controller"]
|
||||
from python.heater.controller import HeaterController
|
||||
|
||||
config = DeviceConfig(device_id="abc", ip="10.0.0.1", local_key="key")
|
||||
controller = HeaterController(config)
|
||||
status = controller.status()
|
||||
|
||||
assert status.power is False
|
||||
|
||||
|
||||
def test_heater_controller_turn_on() -> None:
|
||||
"""Test HeaterController.turn_on."""
|
||||
mock_tinytuya = MagicMock()
|
||||
mock_device = MagicMock()
|
||||
mock_device.set_value.return_value = None
|
||||
mock_tinytuya.Device.return_value = mock_device
|
||||
|
||||
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
|
||||
if "python.heater.controller" in sys.modules:
|
||||
del sys.modules["python.heater.controller"]
|
||||
from python.heater.controller import HeaterController
|
||||
|
||||
config = DeviceConfig(device_id="abc", ip="10.0.0.1", local_key="key")
|
||||
controller = HeaterController(config)
|
||||
result = controller.turn_on()
|
||||
|
||||
assert result.success is True
|
||||
assert result.action == "on"
|
||||
assert result.power is True
|
||||
|
||||
|
||||
def test_heater_controller_turn_on_error() -> None:
|
||||
"""Test HeaterController.turn_on handles errors."""
|
||||
mock_tinytuya = MagicMock()
|
||||
mock_device = MagicMock()
|
||||
mock_device.set_value.side_effect = ConnectionError("timeout")
|
||||
mock_tinytuya.Device.return_value = mock_device
|
||||
|
||||
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
|
||||
if "python.heater.controller" in sys.modules:
|
||||
del sys.modules["python.heater.controller"]
|
||||
from python.heater.controller import HeaterController
|
||||
|
||||
config = DeviceConfig(device_id="abc", ip="10.0.0.1", local_key="key")
|
||||
controller = HeaterController(config)
|
||||
result = controller.turn_on()
|
||||
|
||||
assert result.success is False
|
||||
assert "timeout" in result.error
|
||||
|
||||
|
||||
def test_heater_controller_turn_off() -> None:
|
||||
"""Test HeaterController.turn_off."""
|
||||
mock_tinytuya = MagicMock()
|
||||
mock_device = MagicMock()
|
||||
mock_device.set_value.return_value = None
|
||||
mock_tinytuya.Device.return_value = mock_device
|
||||
|
||||
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
|
||||
if "python.heater.controller" in sys.modules:
|
||||
del sys.modules["python.heater.controller"]
|
||||
from python.heater.controller import HeaterController
|
||||
|
||||
config = DeviceConfig(device_id="abc", ip="10.0.0.1", local_key="key")
|
||||
controller = HeaterController(config)
|
||||
result = controller.turn_off()
|
||||
|
||||
assert result.success is True
|
||||
assert result.action == "off"
|
||||
assert result.power is False
|
||||
|
||||
|
||||
def test_heater_controller_turn_off_error() -> None:
|
||||
"""Test HeaterController.turn_off handles errors."""
|
||||
mock_tinytuya = MagicMock()
|
||||
mock_device = MagicMock()
|
||||
mock_device.set_value.side_effect = ConnectionError("timeout")
|
||||
mock_tinytuya.Device.return_value = mock_device
|
||||
|
||||
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
|
||||
if "python.heater.controller" in sys.modules:
|
||||
del sys.modules["python.heater.controller"]
|
||||
from python.heater.controller import HeaterController
|
||||
|
||||
config = DeviceConfig(device_id="abc", ip="10.0.0.1", local_key="key")
|
||||
controller = HeaterController(config)
|
||||
result = controller.turn_off()
|
||||
|
||||
assert result.success is False
|
||||
|
||||
|
||||
def test_heater_controller_toggle_on_to_off() -> None:
|
||||
"""Test HeaterController.toggle when heater is on."""
|
||||
mock_tinytuya = MagicMock()
|
||||
mock_device = MagicMock()
|
||||
mock_device.status.return_value = {"dps": {"1": True}}
|
||||
mock_device.set_value.return_value = None
|
||||
mock_tinytuya.Device.return_value = mock_device
|
||||
|
||||
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
|
||||
if "python.heater.controller" in sys.modules:
|
||||
del sys.modules["python.heater.controller"]
|
||||
from python.heater.controller import HeaterController
|
||||
|
||||
config = DeviceConfig(device_id="abc", ip="10.0.0.1", local_key="key")
|
||||
controller = HeaterController(config)
|
||||
result = controller.toggle()
|
||||
|
||||
assert result.success is True
|
||||
assert result.action == "off"
|
||||
|
||||
|
||||
def test_heater_controller_toggle_off_to_on() -> None:
|
||||
"""Test HeaterController.toggle when heater is off."""
|
||||
mock_tinytuya = MagicMock()
|
||||
mock_device = MagicMock()
|
||||
mock_device.status.return_value = {"dps": {"1": False}}
|
||||
mock_device.set_value.return_value = None
|
||||
mock_tinytuya.Device.return_value = mock_device
|
||||
|
||||
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
|
||||
if "python.heater.controller" in sys.modules:
|
||||
del sys.modules["python.heater.controller"]
|
||||
from python.heater.controller import HeaterController
|
||||
|
||||
config = DeviceConfig(device_id="abc", ip="10.0.0.1", local_key="key")
|
||||
controller = HeaterController(config)
|
||||
result = controller.toggle()
|
||||
|
||||
assert result.success is True
|
||||
assert result.action == "on"
|
||||
@@ -1,43 +0,0 @@
|
||||
"""Tests for python/heater/main.py."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from python.heater.models import ActionResult, DeviceConfig, HeaterStatus
|
||||
|
||||
|
||||
def test_create_app() -> None:
|
||||
"""Test create_app creates FastAPI app."""
|
||||
mock_tinytuya = MagicMock()
|
||||
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
|
||||
if "python.heater.controller" in sys.modules:
|
||||
del sys.modules["python.heater.controller"]
|
||||
if "python.heater.main" in sys.modules:
|
||||
del sys.modules["python.heater.main"]
|
||||
from python.heater.main import create_app
|
||||
|
||||
config = DeviceConfig(device_id="abc", ip="10.0.0.1", local_key="key")
|
||||
app = create_app(config)
|
||||
assert app is not None
|
||||
assert app.title == "Heater Control API"
|
||||
|
||||
|
||||
def test_serve_missing_params() -> None:
|
||||
"""Test serve raises with missing parameters."""
|
||||
import typer
|
||||
|
||||
mock_tinytuya = MagicMock()
|
||||
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
|
||||
if "python.heater.controller" in sys.modules:
|
||||
del sys.modules["python.heater.controller"]
|
||||
if "python.heater.main" in sys.modules:
|
||||
del sys.modules["python.heater.main"]
|
||||
from python.heater.main import serve
|
||||
|
||||
with patch("python.heater.main.configure_logger"):
|
||||
try:
|
||||
serve(host="0.0.0.0", port=8124, log_level="INFO")
|
||||
except (typer.Exit, SystemExit):
|
||||
pass
|
||||
@@ -1,165 +0,0 @@
|
||||
"""Extended tests for python/heater/main.py - FastAPI routes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from python.heater.models import ActionResult, DeviceConfig, HeaterStatus
|
||||
|
||||
|
||||
def test_heater_app_routes() -> None:
|
||||
"""Test heater app has expected routes."""
|
||||
mock_tinytuya = MagicMock()
|
||||
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
|
||||
if "python.heater.controller" in sys.modules:
|
||||
del sys.modules["python.heater.controller"]
|
||||
if "python.heater.main" in sys.modules:
|
||||
del sys.modules["python.heater.main"]
|
||||
from python.heater.main import create_app
|
||||
|
||||
config = DeviceConfig(device_id="abc", ip="10.0.0.1", local_key="key")
|
||||
app = create_app(config)
|
||||
|
||||
route_paths = [r.path for r in app.routes]
|
||||
assert "/status" in route_paths
|
||||
assert "/on" in route_paths
|
||||
assert "/off" in route_paths
|
||||
assert "/toggle" in route_paths
|
||||
|
||||
|
||||
def test_heater_get_status_route() -> None:
|
||||
"""Test /status route handler."""
|
||||
mock_tinytuya = MagicMock()
|
||||
mock_device = MagicMock()
|
||||
mock_device.status.return_value = {"dps": {"1": True, "101": 72}}
|
||||
mock_tinytuya.Device.return_value = mock_device
|
||||
|
||||
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
|
||||
if "python.heater.controller" in sys.modules:
|
||||
del sys.modules["python.heater.controller"]
|
||||
if "python.heater.main" in sys.modules:
|
||||
del sys.modules["python.heater.main"]
|
||||
from python.heater.main import create_app
|
||||
from python.heater.controller import HeaterController
|
||||
|
||||
config = DeviceConfig(device_id="abc", ip="10.0.0.1", local_key="key")
|
||||
app = create_app(config)
|
||||
|
||||
# Simulate lifespan by setting controller
|
||||
app.state.controller = HeaterController(config)
|
||||
|
||||
# Find and call the status handler
|
||||
for route in app.routes:
|
||||
if hasattr(route, "path") and route.path == "/status":
|
||||
result = route.endpoint()
|
||||
assert result.power is True
|
||||
break
|
||||
|
||||
|
||||
def test_heater_on_route() -> None:
|
||||
"""Test /on route handler."""
|
||||
mock_tinytuya = MagicMock()
|
||||
mock_device = MagicMock()
|
||||
mock_device.set_value.return_value = None
|
||||
mock_tinytuya.Device.return_value = mock_device
|
||||
|
||||
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
|
||||
if "python.heater.controller" in sys.modules:
|
||||
del sys.modules["python.heater.controller"]
|
||||
if "python.heater.main" in sys.modules:
|
||||
del sys.modules["python.heater.main"]
|
||||
from python.heater.main import create_app
|
||||
from python.heater.controller import HeaterController
|
||||
|
||||
config = DeviceConfig(device_id="abc", ip="10.0.0.1", local_key="key")
|
||||
app = create_app(config)
|
||||
app.state.controller = HeaterController(config)
|
||||
|
||||
for route in app.routes:
|
||||
if hasattr(route, "path") and route.path == "/on":
|
||||
result = route.endpoint()
|
||||
assert result.success is True
|
||||
break
|
||||
|
||||
|
||||
def test_heater_off_route() -> None:
|
||||
"""Test /off route handler."""
|
||||
mock_tinytuya = MagicMock()
|
||||
mock_device = MagicMock()
|
||||
mock_device.set_value.return_value = None
|
||||
mock_tinytuya.Device.return_value = mock_device
|
||||
|
||||
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
|
||||
if "python.heater.controller" in sys.modules:
|
||||
del sys.modules["python.heater.controller"]
|
||||
if "python.heater.main" in sys.modules:
|
||||
del sys.modules["python.heater.main"]
|
||||
from python.heater.main import create_app
|
||||
from python.heater.controller import HeaterController
|
||||
|
||||
config = DeviceConfig(device_id="abc", ip="10.0.0.1", local_key="key")
|
||||
app = create_app(config)
|
||||
app.state.controller = HeaterController(config)
|
||||
|
||||
for route in app.routes:
|
||||
if hasattr(route, "path") and route.path == "/off":
|
||||
result = route.endpoint()
|
||||
assert result.success is True
|
||||
break
|
||||
|
||||
|
||||
def test_heater_toggle_route() -> None:
|
||||
"""Test /toggle route handler."""
|
||||
mock_tinytuya = MagicMock()
|
||||
mock_device = MagicMock()
|
||||
mock_device.status.return_value = {"dps": {"1": True}}
|
||||
mock_device.set_value.return_value = None
|
||||
mock_tinytuya.Device.return_value = mock_device
|
||||
|
||||
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
|
||||
if "python.heater.controller" in sys.modules:
|
||||
del sys.modules["python.heater.controller"]
|
||||
if "python.heater.main" in sys.modules:
|
||||
del sys.modules["python.heater.main"]
|
||||
from python.heater.main import create_app
|
||||
from python.heater.controller import HeaterController
|
||||
|
||||
config = DeviceConfig(device_id="abc", ip="10.0.0.1", local_key="key")
|
||||
app = create_app(config)
|
||||
app.state.controller = HeaterController(config)
|
||||
|
||||
for route in app.routes:
|
||||
if hasattr(route, "path") and route.path == "/toggle":
|
||||
result = route.endpoint()
|
||||
assert result.success is True
|
||||
break
|
||||
|
||||
|
||||
def test_heater_on_route_failure() -> None:
|
||||
"""Test /on route raises HTTPException on failure."""
|
||||
mock_tinytuya = MagicMock()
|
||||
mock_device = MagicMock()
|
||||
mock_device.set_value.side_effect = ConnectionError("fail")
|
||||
mock_tinytuya.Device.return_value = mock_device
|
||||
|
||||
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
|
||||
if "python.heater.controller" in sys.modules:
|
||||
del sys.modules["python.heater.controller"]
|
||||
if "python.heater.main" in sys.modules:
|
||||
del sys.modules["python.heater.main"]
|
||||
from python.heater.main import create_app
|
||||
from python.heater.controller import HeaterController
|
||||
from fastapi import HTTPException
|
||||
|
||||
config = DeviceConfig(device_id="abc", ip="10.0.0.1", local_key="key")
|
||||
app = create_app(config)
|
||||
app.state.controller = HeaterController(config)
|
||||
|
||||
import pytest
|
||||
|
||||
for route in app.routes:
|
||||
if hasattr(route, "path") and route.path == "/on":
|
||||
with pytest.raises(HTTPException):
|
||||
route.endpoint()
|
||||
break
|
||||
@@ -1,103 +0,0 @@
|
||||
"""Tests for heater/main.py serve function and lifespan."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from click.exceptions import Exit
|
||||
|
||||
from python.heater.models import DeviceConfig
|
||||
|
||||
|
||||
def test_serve_missing_params() -> None:
|
||||
"""Test serve raises when device params are missing."""
|
||||
mock_tinytuya = MagicMock()
|
||||
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
|
||||
if "python.heater.controller" in sys.modules:
|
||||
del sys.modules["python.heater.controller"]
|
||||
if "python.heater.main" in sys.modules:
|
||||
del sys.modules["python.heater.main"]
|
||||
from python.heater.main import serve
|
||||
|
||||
with pytest.raises(Exit):
|
||||
serve(host="localhost", port=8124, log_level="INFO")
|
||||
|
||||
|
||||
def test_serve_with_params() -> None:
|
||||
"""Test serve starts uvicorn when params provided."""
|
||||
mock_tinytuya = MagicMock()
|
||||
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
|
||||
if "python.heater.controller" in sys.modules:
|
||||
del sys.modules["python.heater.controller"]
|
||||
if "python.heater.main" in sys.modules:
|
||||
del sys.modules["python.heater.main"]
|
||||
from python.heater.main import serve
|
||||
|
||||
with patch("python.heater.main.uvicorn.run") as mock_run:
|
||||
serve(
|
||||
host="localhost",
|
||||
port=8124,
|
||||
log_level="INFO",
|
||||
device_id="abc",
|
||||
device_ip="10.0.0.1",
|
||||
local_key="key123",
|
||||
)
|
||||
mock_run.assert_called_once()
|
||||
|
||||
|
||||
def test_heater_off_route_failure() -> None:
|
||||
"""Test /off route raises HTTPException on failure."""
|
||||
mock_tinytuya = MagicMock()
|
||||
mock_device = MagicMock()
|
||||
mock_device.set_value.side_effect = ConnectionError("fail")
|
||||
mock_tinytuya.Device.return_value = mock_device
|
||||
|
||||
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
|
||||
if "python.heater.controller" in sys.modules:
|
||||
del sys.modules["python.heater.controller"]
|
||||
if "python.heater.main" in sys.modules:
|
||||
del sys.modules["python.heater.main"]
|
||||
from python.heater.main import create_app
|
||||
from python.heater.controller import HeaterController
|
||||
from fastapi import HTTPException
|
||||
|
||||
config = DeviceConfig(device_id="abc", ip="10.0.0.1", local_key="key")
|
||||
app = create_app(config)
|
||||
app.state.controller = HeaterController(config)
|
||||
|
||||
for route in app.routes:
|
||||
if hasattr(route, "path") and route.path == "/off":
|
||||
with pytest.raises(HTTPException):
|
||||
route.endpoint()
|
||||
break
|
||||
|
||||
|
||||
def test_heater_toggle_route_failure() -> None:
|
||||
"""Test /toggle route raises HTTPException on failure."""
|
||||
mock_tinytuya = MagicMock()
|
||||
mock_device = MagicMock()
|
||||
# toggle calls status() first then set_value - make set_value fail
|
||||
mock_device.status.return_value = {"dps": {"1": True}}
|
||||
mock_device.set_value.side_effect = ConnectionError("fail")
|
||||
mock_tinytuya.Device.return_value = mock_device
|
||||
|
||||
with patch.dict(sys.modules, {"tinytuya": mock_tinytuya}):
|
||||
if "python.heater.controller" in sys.modules:
|
||||
del sys.modules["python.heater.controller"]
|
||||
if "python.heater.main" in sys.modules:
|
||||
del sys.modules["python.heater.main"]
|
||||
from python.heater.main import create_app
|
||||
from python.heater.controller import HeaterController
|
||||
from fastapi import HTTPException
|
||||
|
||||
config = DeviceConfig(device_id="abc", ip="10.0.0.1", local_key="key")
|
||||
app = create_app(config)
|
||||
app.state.controller = HeaterController(config)
|
||||
|
||||
for route in app.routes:
|
||||
if hasattr(route, "path") and route.path == "/toggle":
|
||||
with pytest.raises(HTTPException):
|
||||
route.endpoint()
|
||||
break
|
||||
@@ -1,191 +0,0 @@
|
||||
"""Tests for python/installer modules."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import curses
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from python.installer.tui import (
|
||||
Cursor,
|
||||
State,
|
||||
calculate_device_menu_padding,
|
||||
get_device,
|
||||
)
|
||||
|
||||
|
||||
# --- Cursor tests ---
|
||||
|
||||
|
||||
def test_cursor_init() -> None:
|
||||
"""Test Cursor initialization."""
|
||||
c = Cursor()
|
||||
assert c.get_x() == 0
|
||||
assert c.get_y() == 0
|
||||
assert c.height == 0
|
||||
assert c.width == 0
|
||||
|
||||
|
||||
def test_cursor_set_height_width() -> None:
|
||||
"""Test Cursor set_height and set_width."""
|
||||
c = Cursor()
|
||||
c.set_height(100)
|
||||
c.set_width(200)
|
||||
assert c.height == 100
|
||||
assert c.width == 200
|
||||
|
||||
|
||||
def test_cursor_bounce_check() -> None:
|
||||
"""Test Cursor bounce checks."""
|
||||
c = Cursor()
|
||||
c.set_height(10)
|
||||
c.set_width(20)
|
||||
|
||||
assert c.x_bounce_check(-1) == 0
|
||||
assert c.x_bounce_check(25) == 19
|
||||
assert c.x_bounce_check(5) == 5
|
||||
|
||||
assert c.y_bounce_check(-1) == 0
|
||||
assert c.y_bounce_check(15) == 9
|
||||
assert c.y_bounce_check(5) == 5
|
||||
|
||||
|
||||
def test_cursor_set_x_y() -> None:
|
||||
"""Test Cursor set_x and set_y."""
|
||||
c = Cursor()
|
||||
c.set_height(10)
|
||||
c.set_width(20)
|
||||
c.set_x(5)
|
||||
c.set_y(3)
|
||||
assert c.get_x() == 5
|
||||
assert c.get_y() == 3
|
||||
|
||||
|
||||
def test_cursor_set_x_y_bounds() -> None:
|
||||
"""Test Cursor set_x and set_y with bounds."""
|
||||
c = Cursor()
|
||||
c.set_height(10)
|
||||
c.set_width(20)
|
||||
c.set_x(-5)
|
||||
assert c.get_x() == 0
|
||||
c.set_y(100)
|
||||
assert c.get_y() == 9
|
||||
|
||||
|
||||
def test_cursor_move_up() -> None:
|
||||
"""Test Cursor move_up."""
|
||||
c = Cursor()
|
||||
c.set_height(10)
|
||||
c.set_width(20)
|
||||
c.set_y(5)
|
||||
c.move_up()
|
||||
assert c.get_y() == 4
|
||||
|
||||
|
||||
def test_cursor_move_down() -> None:
|
||||
"""Test Cursor move_down."""
|
||||
c = Cursor()
|
||||
c.set_height(10)
|
||||
c.set_width(20)
|
||||
c.set_y(5)
|
||||
c.move_down()
|
||||
assert c.get_y() == 6
|
||||
|
||||
|
||||
def test_cursor_move_left() -> None:
|
||||
"""Test Cursor move_left."""
|
||||
c = Cursor()
|
||||
c.set_height(10)
|
||||
c.set_width(20)
|
||||
c.set_x(5)
|
||||
c.move_left()
|
||||
assert c.get_x() == 4
|
||||
|
||||
|
||||
def test_cursor_move_right() -> None:
|
||||
"""Test Cursor move_right."""
|
||||
c = Cursor()
|
||||
c.set_height(10)
|
||||
c.set_width(20)
|
||||
c.set_x(5)
|
||||
c.move_right()
|
||||
assert c.get_x() == 6
|
||||
|
||||
|
||||
def test_cursor_navigation() -> None:
|
||||
"""Test Cursor navigation with arrow keys."""
|
||||
c = Cursor()
|
||||
c.set_height(10)
|
||||
c.set_width(20)
|
||||
c.set_x(5)
|
||||
c.set_y(5)
|
||||
|
||||
c.navigation(curses.KEY_UP)
|
||||
assert c.get_y() == 4
|
||||
c.navigation(curses.KEY_DOWN)
|
||||
assert c.get_y() == 5
|
||||
c.navigation(curses.KEY_LEFT)
|
||||
assert c.get_x() == 4
|
||||
c.navigation(curses.KEY_RIGHT)
|
||||
assert c.get_x() == 5
|
||||
|
||||
|
||||
def test_cursor_navigation_unknown_key() -> None:
|
||||
"""Test Cursor navigation with unknown key (no-op)."""
|
||||
c = Cursor()
|
||||
c.set_height(10)
|
||||
c.set_width(20)
|
||||
c.set_x(5)
|
||||
c.set_y(5)
|
||||
c.navigation(999) # Unknown key
|
||||
assert c.get_x() == 5
|
||||
assert c.get_y() == 5
|
||||
|
||||
|
||||
# --- State tests ---
|
||||
|
||||
|
||||
def test_state_init() -> None:
|
||||
"""Test State initialization."""
|
||||
s = State()
|
||||
assert s.key == 0
|
||||
assert s.swap_size == 0
|
||||
assert s.reserve_size == 0
|
||||
assert s.selected_device_ids == set()
|
||||
assert s.show_swap_input is False
|
||||
assert s.show_reserve_input is False
|
||||
|
||||
|
||||
def test_state_get_selected_devices() -> None:
|
||||
"""Test State.get_selected_devices."""
|
||||
s = State()
|
||||
s.selected_device_ids = {"/dev/sda", "/dev/sdb"}
|
||||
result = s.get_selected_devices()
|
||||
assert isinstance(result, tuple)
|
||||
assert set(result) == {"/dev/sda", "/dev/sdb"}
|
||||
|
||||
|
||||
# --- get_device tests ---
|
||||
|
||||
|
||||
def test_get_device() -> None:
|
||||
"""Test get_device parses device string."""
|
||||
raw = 'NAME="/dev/sda" SIZE="100G" TYPE="disk" MOUNTPOINTS=""'
|
||||
device = get_device(raw)
|
||||
assert device["name"] == "/dev/sda"
|
||||
assert device["size"] == "100G"
|
||||
assert device["type"] == "disk"
|
||||
|
||||
|
||||
# --- calculate_device_menu_padding ---
|
||||
|
||||
|
||||
def test_calculate_device_menu_padding() -> None:
|
||||
"""Test calculate_device_menu_padding."""
|
||||
devices = [
|
||||
{"name": "/dev/sda", "size": "100G"},
|
||||
{"name": "/dev/nvme0n1", "size": "500G"},
|
||||
]
|
||||
padding = calculate_device_menu_padding(devices, "name", 2)
|
||||
assert padding == len("/dev/nvme0n1") + 2
|
||||
@@ -1,168 +0,0 @@
|
||||
"""Extended tests for python/installer modules."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from python.installer.__main__ import (
|
||||
bash_wrapper,
|
||||
create_zfs_pool,
|
||||
get_cpu_manufacturer,
|
||||
partition_disk,
|
||||
)
|
||||
from python.installer.tui import (
|
||||
Cursor,
|
||||
State,
|
||||
bash_wrapper as tui_bash_wrapper,
|
||||
get_device,
|
||||
calculate_device_menu_padding,
|
||||
)
|
||||
|
||||
|
||||
# --- installer __main__ tests ---
|
||||
|
||||
|
||||
def test_installer_bash_wrapper_success() -> None:
|
||||
"""Test installer bash_wrapper on success."""
|
||||
result = bash_wrapper("echo hello")
|
||||
assert result.strip() == "hello"
|
||||
|
||||
|
||||
def test_installer_bash_wrapper_error() -> None:
|
||||
"""Test installer bash_wrapper raises on error."""
|
||||
with pytest.raises(RuntimeError, match="Failed to run command"):
|
||||
bash_wrapper("ls /nonexistent/path/that/does/not/exist")
|
||||
|
||||
|
||||
def test_partition_disk() -> None:
|
||||
"""Test partition_disk calls commands correctly."""
|
||||
with patch("python.installer.__main__.bash_wrapper") as mock_bash:
|
||||
partition_disk("/dev/sda", swap_size=8, reserve=0)
|
||||
assert mock_bash.call_count == 2
|
||||
|
||||
|
||||
def test_partition_disk_with_reserve() -> None:
|
||||
"""Test partition_disk with reserve space."""
|
||||
with patch("python.installer.__main__.bash_wrapper") as mock_bash:
|
||||
partition_disk("/dev/sda", swap_size=8, reserve=10)
|
||||
assert mock_bash.call_count == 2
|
||||
|
||||
|
||||
def test_partition_disk_minimum_swap() -> None:
|
||||
"""Test partition_disk enforces minimum swap size."""
|
||||
with patch("python.installer.__main__.bash_wrapper") as mock_bash:
|
||||
partition_disk("/dev/sda", swap_size=0, reserve=-1)
|
||||
# swap_size should be clamped to 1, reserve to 0
|
||||
assert mock_bash.call_count == 2
|
||||
|
||||
|
||||
def test_create_zfs_pool_single_disk() -> None:
|
||||
"""Test create_zfs_pool with single disk."""
|
||||
with patch("python.installer.__main__.bash_wrapper") as mock_bash:
|
||||
mock_bash.return_value = "NAME\nroot_pool\n"
|
||||
create_zfs_pool(["/dev/sda-part2"], "/mnt")
|
||||
assert mock_bash.call_count == 2
|
||||
|
||||
|
||||
def test_create_zfs_pool_mirror() -> None:
|
||||
"""Test create_zfs_pool with mirror disks."""
|
||||
with patch("python.installer.__main__.bash_wrapper") as mock_bash:
|
||||
mock_bash.return_value = "NAME\nroot_pool\n"
|
||||
create_zfs_pool(["/dev/sda-part2", "/dev/sdb-part2"], "/mnt")
|
||||
assert mock_bash.call_count == 2
|
||||
|
||||
|
||||
def test_create_zfs_pool_no_disks() -> None:
|
||||
"""Test create_zfs_pool raises with no disks."""
|
||||
with pytest.raises(ValueError, match="disks must be a tuple"):
|
||||
create_zfs_pool([], "/mnt")
|
||||
|
||||
|
||||
def test_get_cpu_manufacturer_amd() -> None:
|
||||
"""Test get_cpu_manufacturer with AMD CPU."""
|
||||
output = "vendor_id\t: AuthenticAMD\nmodel name\t: AMD Ryzen 9\n"
|
||||
with patch("python.installer.__main__.bash_wrapper", return_value=output):
|
||||
assert get_cpu_manufacturer() == "amd"
|
||||
|
||||
|
||||
def test_get_cpu_manufacturer_intel() -> None:
|
||||
"""Test get_cpu_manufacturer with Intel CPU."""
|
||||
output = "vendor_id\t: GenuineIntel\nmodel name\t: Intel Core i9\n"
|
||||
with patch("python.installer.__main__.bash_wrapper", return_value=output):
|
||||
assert get_cpu_manufacturer() == "intel"
|
||||
|
||||
|
||||
def test_get_cpu_manufacturer_unknown() -> None:
|
||||
"""Test get_cpu_manufacturer with unknown CPU raises."""
|
||||
output = "model name\t: Unknown CPU\n"
|
||||
with (
|
||||
patch("python.installer.__main__.bash_wrapper", return_value=output),
|
||||
pytest.raises(RuntimeError, match="Failed to get CPU manufacturer"),
|
||||
):
|
||||
get_cpu_manufacturer()
|
||||
|
||||
|
||||
# --- tui bash_wrapper tests ---
|
||||
|
||||
|
||||
def test_tui_bash_wrapper_success() -> None:
|
||||
"""Test tui bash_wrapper success."""
|
||||
result = tui_bash_wrapper("echo hello")
|
||||
assert result.strip() == "hello"
|
||||
|
||||
|
||||
def test_tui_bash_wrapper_error() -> None:
|
||||
"""Test tui bash_wrapper raises on error."""
|
||||
with pytest.raises(RuntimeError, match="Failed to run command"):
|
||||
tui_bash_wrapper("ls /nonexistent/path/that/does/not/exist")
|
||||
|
||||
|
||||
# --- Cursor boundary tests ---
|
||||
|
||||
|
||||
def test_cursor_move_at_boundaries() -> None:
|
||||
"""Test cursor doesn't go below 0."""
|
||||
c = Cursor()
|
||||
c.set_height(10)
|
||||
c.set_width(20)
|
||||
c.set_x(0)
|
||||
c.set_y(0)
|
||||
c.move_up()
|
||||
assert c.get_y() == 0
|
||||
c.move_left()
|
||||
assert c.get_x() == 0
|
||||
|
||||
|
||||
def test_cursor_move_at_max_boundaries() -> None:
|
||||
"""Test cursor doesn't exceed max."""
|
||||
c = Cursor()
|
||||
c.set_height(5)
|
||||
c.set_width(10)
|
||||
c.set_x(9)
|
||||
c.set_y(4)
|
||||
c.move_down()
|
||||
assert c.get_y() == 4
|
||||
c.move_right()
|
||||
assert c.get_x() == 9
|
||||
|
||||
|
||||
# --- get_device additional ---
|
||||
|
||||
|
||||
def test_get_device_with_mountpoint() -> None:
|
||||
"""Test get_device with mountpoint."""
|
||||
raw = 'NAME="/dev/sda1" SIZE="512M" TYPE="part" MOUNTPOINTS="/boot"'
|
||||
device = get_device(raw)
|
||||
assert device["mountpoints"] == "/boot"
|
||||
|
||||
|
||||
# --- State additional ---
|
||||
|
||||
|
||||
def test_state_selected_devices_empty() -> None:
|
||||
"""Test State get_selected_devices when empty."""
|
||||
s = State()
|
||||
result = s.get_selected_devices()
|
||||
assert result == ()
|
||||
@@ -1,50 +0,0 @@
|
||||
"""Extended tests for python/installer/__main__.py."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from python.installer.__main__ import (
|
||||
create_zfs_datasets,
|
||||
create_zfs_pool,
|
||||
get_boot_drive_id,
|
||||
partition_disk,
|
||||
)
|
||||
|
||||
|
||||
def test_create_zfs_datasets() -> None:
|
||||
"""Test create_zfs_datasets creates expected datasets."""
|
||||
with patch("python.installer.__main__.bash_wrapper") as mock_bash:
|
||||
mock_bash.return_value = "NAME\nroot_pool\nroot_pool/root\nroot_pool/home\nroot_pool/var\nroot_pool/nix\n"
|
||||
create_zfs_datasets()
|
||||
assert mock_bash.call_count == 5 # 4 create + 1 list
|
||||
|
||||
|
||||
def test_create_zfs_datasets_missing(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test create_zfs_datasets exits on missing datasets."""
|
||||
with (
|
||||
patch("python.installer.__main__.bash_wrapper") as mock_bash,
|
||||
pytest.raises(SystemExit),
|
||||
):
|
||||
mock_bash.return_value = "NAME\nroot_pool\n"
|
||||
create_zfs_datasets()
|
||||
|
||||
|
||||
def test_create_zfs_pool_failure(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test create_zfs_pool exits on failure."""
|
||||
with (
|
||||
patch("python.installer.__main__.bash_wrapper") as mock_bash,
|
||||
pytest.raises(SystemExit),
|
||||
):
|
||||
mock_bash.return_value = "NAME\n"
|
||||
create_zfs_pool(["/dev/sda-part2"], "/mnt")
|
||||
|
||||
|
||||
def test_get_boot_drive_id() -> None:
|
||||
"""Test get_boot_drive_id extracts UUID."""
|
||||
with patch("python.installer.__main__.bash_wrapper", return_value="UUID\nABCD-1234\n"):
|
||||
result = get_boot_drive_id("/dev/sda")
|
||||
assert result == "ABCD-1234"
|
||||
@@ -1,312 +0,0 @@
|
||||
"""Additional tests for python/installer/__main__.py covering missing lines."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from python.installer.__main__ import (
|
||||
create_nix_hardware_file,
|
||||
install_nixos,
|
||||
installer,
|
||||
main,
|
||||
)
|
||||
|
||||
|
||||
# --- create_nix_hardware_file (lines 167-218) ---
|
||||
|
||||
|
||||
def test_create_nix_hardware_file_no_encrypt() -> None:
|
||||
"""Test create_nix_hardware_file without encryption."""
|
||||
with (
|
||||
patch("python.installer.__main__.get_cpu_manufacturer", return_value="amd"),
|
||||
patch("python.installer.__main__.get_boot_drive_id", return_value="ABCD-1234"),
|
||||
patch("python.installer.__main__.getrandbits", return_value=0xDEADBEEF),
|
||||
patch("python.installer.__main__.Path") as mock_path,
|
||||
):
|
||||
create_nix_hardware_file("/mnt", ["/dev/sda"], encrypt=None)
|
||||
|
||||
mock_path.assert_called_once_with("/mnt/etc/nixos/hardware-configuration.nix")
|
||||
written_content = mock_path.return_value.write_text.call_args[0][0]
|
||||
assert "kvm-amd" in written_content
|
||||
assert "ABCD-1234" in written_content
|
||||
assert "deadbeef" in written_content
|
||||
assert "luks" not in written_content
|
||||
|
||||
|
||||
def test_create_nix_hardware_file_with_encrypt() -> None:
|
||||
"""Test create_nix_hardware_file with encryption enabled."""
|
||||
with (
|
||||
patch("python.installer.__main__.get_cpu_manufacturer", return_value="intel"),
|
||||
patch("python.installer.__main__.get_boot_drive_id", return_value="EFGH-5678"),
|
||||
patch("python.installer.__main__.getrandbits", return_value=0x12345678),
|
||||
patch("python.installer.__main__.Path") as mock_path,
|
||||
):
|
||||
create_nix_hardware_file("/mnt", ["/dev/sda"], encrypt="mykey")
|
||||
|
||||
written_content = mock_path.return_value.write_text.call_args[0][0]
|
||||
assert "kvm-intel" in written_content
|
||||
assert "EFGH-5678" in written_content
|
||||
assert "12345678" in written_content
|
||||
assert "luks" in written_content
|
||||
assert "luks-root-pool-sda-part2" in written_content
|
||||
assert "bypassWorkqueues" in written_content
|
||||
assert "allowDiscards" in written_content
|
||||
|
||||
|
||||
def test_create_nix_hardware_file_content_structure() -> None:
|
||||
"""Test create_nix_hardware_file generates correct Nix structure."""
|
||||
with (
|
||||
patch("python.installer.__main__.get_cpu_manufacturer", return_value="amd"),
|
||||
patch("python.installer.__main__.get_boot_drive_id", return_value="UUID-1234"),
|
||||
patch("python.installer.__main__.getrandbits", return_value=0xAABBCCDD),
|
||||
patch("python.installer.__main__.Path") as mock_path,
|
||||
):
|
||||
create_nix_hardware_file("/mnt", ["/dev/sda"], encrypt=None)
|
||||
|
||||
written_content = mock_path.return_value.write_text.call_args[0][0]
|
||||
assert "{ config, lib, modulesPath, ... }:" in written_content
|
||||
assert "boot =" in written_content
|
||||
assert "fileSystems" in written_content
|
||||
assert "root_pool/root" in written_content
|
||||
assert "root_pool/home" in written_content
|
||||
assert "root_pool/var" in written_content
|
||||
assert "root_pool/nix" in written_content
|
||||
assert "networking.hostId" in written_content
|
||||
assert "x86_64-linux" in written_content
|
||||
|
||||
|
||||
# --- install_nixos (lines 221-241) ---
|
||||
|
||||
|
||||
def test_install_nixos_single_disk() -> None:
|
||||
"""Test install_nixos mounts filesystems and runs nixos-install."""
|
||||
with (
|
||||
patch("python.installer.__main__.bash_wrapper") as mock_bash,
|
||||
patch("python.installer.__main__.run") as mock_run,
|
||||
patch("python.installer.__main__.create_nix_hardware_file") as mock_hw,
|
||||
):
|
||||
install_nixos("/mnt", ["/dev/sda"], encrypt=None)
|
||||
|
||||
# 4 mount commands + 1 mkfs.vfat + 1 boot mount + 1 nixos-generate-config = 7 bash_wrapper calls
|
||||
assert mock_bash.call_count == 7
|
||||
mock_hw.assert_called_once_with("/mnt", ["/dev/sda"], None)
|
||||
mock_run.assert_called_once_with(("nixos-install", "--root", "/mnt"), check=True)
|
||||
|
||||
|
||||
def test_install_nixos_multiple_disks() -> None:
|
||||
"""Test install_nixos formats all disk EFI partitions."""
|
||||
with (
|
||||
patch("python.installer.__main__.bash_wrapper") as mock_bash,
|
||||
patch("python.installer.__main__.run") as mock_run,
|
||||
patch("python.installer.__main__.create_nix_hardware_file") as mock_hw,
|
||||
):
|
||||
install_nixos("/mnt", ["/dev/sda", "/dev/sdb"], encrypt="key")
|
||||
|
||||
# 4 mount + 2 mkfs.vfat + 1 boot mount + 1 generate-config = 8
|
||||
assert mock_bash.call_count == 8
|
||||
# Check mkfs.vfat called for both disks
|
||||
bash_calls = [str(c) for c in mock_bash.call_args_list]
|
||||
assert any("mkfs.vfat" in c and "sda" in c for c in bash_calls)
|
||||
assert any("mkfs.vfat" in c and "sdb" in c for c in bash_calls)
|
||||
mock_hw.assert_called_once_with("/mnt", ["/dev/sda", "/dev/sdb"], "key")
|
||||
|
||||
|
||||
def test_install_nixos_mounts_zfs_datasets() -> None:
|
||||
"""Test install_nixos mounts all required ZFS datasets."""
|
||||
with (
|
||||
patch("python.installer.__main__.bash_wrapper") as mock_bash,
|
||||
patch("python.installer.__main__.run"),
|
||||
patch("python.installer.__main__.create_nix_hardware_file"),
|
||||
):
|
||||
install_nixos("/mnt", ["/dev/sda"], encrypt=None)
|
||||
|
||||
bash_calls = [str(c) for c in mock_bash.call_args_list]
|
||||
assert any("root_pool/root" in c for c in bash_calls)
|
||||
assert any("root_pool/home" in c for c in bash_calls)
|
||||
assert any("root_pool/var" in c for c in bash_calls)
|
||||
assert any("root_pool/nix" in c for c in bash_calls)
|
||||
|
||||
|
||||
# --- installer (lines 244-280) ---
|
||||
|
||||
|
||||
def test_installer_no_encrypt() -> None:
|
||||
"""Test installer flow without encryption."""
|
||||
with (
|
||||
patch("python.installer.__main__.partition_disk") as mock_partition,
|
||||
patch("python.installer.__main__.Popen") as mock_popen,
|
||||
patch("python.installer.__main__.Path") as mock_path,
|
||||
patch("python.installer.__main__.create_zfs_pool") as mock_pool,
|
||||
patch("python.installer.__main__.create_zfs_datasets") as mock_datasets,
|
||||
patch("python.installer.__main__.install_nixos") as mock_install,
|
||||
):
|
||||
installer(
|
||||
disks=("/dev/sda",),
|
||||
swap_size=8,
|
||||
reserve=0,
|
||||
encrypt_key=None,
|
||||
)
|
||||
|
||||
mock_partition.assert_called_once_with("/dev/sda", 8, 0)
|
||||
mock_pool.assert_called_once_with(["/dev/sda-part2"], "/tmp/nix_install")
|
||||
mock_datasets.assert_called_once()
|
||||
mock_install.assert_called_once_with("/tmp/nix_install", ("/dev/sda",), None)
|
||||
|
||||
|
||||
def test_installer_with_encrypt() -> None:
|
||||
"""Test installer flow with encryption enabled."""
|
||||
with (
|
||||
patch("python.installer.__main__.partition_disk") as mock_partition,
|
||||
patch("python.installer.__main__.Popen") as mock_popen,
|
||||
patch("python.installer.__main__.sleep") as mock_sleep,
|
||||
patch("python.installer.__main__.run") as mock_run,
|
||||
patch("python.installer.__main__.Path") as mock_path,
|
||||
patch("python.installer.__main__.create_zfs_pool") as mock_pool,
|
||||
patch("python.installer.__main__.create_zfs_datasets") as mock_datasets,
|
||||
patch("python.installer.__main__.install_nixos") as mock_install,
|
||||
):
|
||||
installer(
|
||||
disks=("/dev/sda",),
|
||||
swap_size=8,
|
||||
reserve=10,
|
||||
encrypt_key="secret",
|
||||
)
|
||||
|
||||
mock_partition.assert_called_once_with("/dev/sda", 8, 10)
|
||||
mock_sleep.assert_called_once_with(1)
|
||||
# cryptsetup luksFormat and luksOpen
|
||||
assert mock_run.call_count == 2
|
||||
mock_pool.assert_called_once_with(
|
||||
["/dev/mapper/luks-root-pool-sda-part2"],
|
||||
"/tmp/nix_install",
|
||||
)
|
||||
mock_datasets.assert_called_once()
|
||||
mock_install.assert_called_once_with("/tmp/nix_install", ("/dev/sda",), "secret")
|
||||
|
||||
|
||||
def test_installer_multiple_disks_no_encrypt() -> None:
|
||||
"""Test installer with multiple disks and no encryption."""
|
||||
with (
|
||||
patch("python.installer.__main__.partition_disk") as mock_partition,
|
||||
patch("python.installer.__main__.Popen") as mock_popen,
|
||||
patch("python.installer.__main__.Path") as mock_path,
|
||||
patch("python.installer.__main__.create_zfs_pool") as mock_pool,
|
||||
patch("python.installer.__main__.create_zfs_datasets") as mock_datasets,
|
||||
patch("python.installer.__main__.install_nixos") as mock_install,
|
||||
):
|
||||
installer(
|
||||
disks=("/dev/sda", "/dev/sdb"),
|
||||
swap_size=4,
|
||||
reserve=0,
|
||||
encrypt_key=None,
|
||||
)
|
||||
|
||||
assert mock_partition.call_count == 2
|
||||
mock_pool.assert_called_once_with(
|
||||
["/dev/sda-part2", "/dev/sdb-part2"],
|
||||
"/tmp/nix_install",
|
||||
)
|
||||
|
||||
|
||||
def test_installer_multiple_disks_with_encrypt() -> None:
|
||||
"""Test installer with multiple disks and encryption."""
|
||||
with (
|
||||
patch("python.installer.__main__.partition_disk") as mock_partition,
|
||||
patch("python.installer.__main__.Popen") as mock_popen,
|
||||
patch("python.installer.__main__.sleep") as mock_sleep,
|
||||
patch("python.installer.__main__.run") as mock_run,
|
||||
patch("python.installer.__main__.Path") as mock_path,
|
||||
patch("python.installer.__main__.create_zfs_pool") as mock_pool,
|
||||
patch("python.installer.__main__.create_zfs_datasets") as mock_datasets,
|
||||
patch("python.installer.__main__.install_nixos") as mock_install,
|
||||
):
|
||||
installer(
|
||||
disks=("/dev/sda", "/dev/sdb"),
|
||||
swap_size=4,
|
||||
reserve=2,
|
||||
encrypt_key="key123",
|
||||
)
|
||||
|
||||
assert mock_partition.call_count == 2
|
||||
assert mock_sleep.call_count == 2
|
||||
# 2 disks x 2 cryptsetup commands = 4
|
||||
assert mock_run.call_count == 4
|
||||
mock_pool.assert_called_once_with(
|
||||
["/dev/mapper/luks-root-pool-sda-part2", "/dev/mapper/luks-root-pool-sdb-part2"],
|
||||
"/tmp/nix_install",
|
||||
)
|
||||
|
||||
|
||||
# --- main (lines 283-299) ---
|
||||
|
||||
|
||||
def test_main_calls_installer() -> None:
|
||||
"""Test main function orchestrates TUI and installer."""
|
||||
mock_state = MagicMock()
|
||||
mock_state.selected_device_ids = {"/dev/disk/by-id/ata-DISK1"}
|
||||
mock_state.get_selected_devices.return_value = ("/dev/disk/by-id/ata-DISK1",)
|
||||
mock_state.swap_size = 8
|
||||
mock_state.reserve_size = 0
|
||||
|
||||
with (
|
||||
patch("python.installer.__main__.configure_logger"),
|
||||
patch("python.installer.__main__.curses.wrapper", return_value=mock_state),
|
||||
patch("python.installer.__main__.getenv", return_value=None),
|
||||
patch("python.installer.__main__.sleep"),
|
||||
patch("python.installer.__main__.installer") as mock_installer,
|
||||
):
|
||||
main()
|
||||
|
||||
mock_installer.assert_called_once_with(
|
||||
disks=("/dev/disk/by-id/ata-DISK1",),
|
||||
swap_size=8,
|
||||
reserve=0,
|
||||
encrypt_key=None,
|
||||
)
|
||||
|
||||
|
||||
def test_main_with_encrypt_key() -> None:
|
||||
"""Test main function passes encrypt key from environment."""
|
||||
mock_state = MagicMock()
|
||||
mock_state.selected_device_ids = {"/dev/disk/by-id/ata-DISK1"}
|
||||
mock_state.get_selected_devices.return_value = ("/dev/disk/by-id/ata-DISK1",)
|
||||
mock_state.swap_size = 16
|
||||
mock_state.reserve_size = 5
|
||||
|
||||
with (
|
||||
patch("python.installer.__main__.configure_logger"),
|
||||
patch("python.installer.__main__.curses.wrapper", return_value=mock_state),
|
||||
patch("python.installer.__main__.getenv", return_value="my_encrypt_key"),
|
||||
patch("python.installer.__main__.sleep"),
|
||||
patch("python.installer.__main__.installer") as mock_installer,
|
||||
):
|
||||
main()
|
||||
|
||||
mock_installer.assert_called_once_with(
|
||||
disks=("/dev/disk/by-id/ata-DISK1",),
|
||||
swap_size=16,
|
||||
reserve=5,
|
||||
encrypt_key="my_encrypt_key",
|
||||
)
|
||||
|
||||
|
||||
def test_main_calls_sleep() -> None:
|
||||
"""Test main function sleeps for 3 seconds before installing."""
|
||||
mock_state = MagicMock()
|
||||
mock_state.selected_device_ids = set()
|
||||
mock_state.get_selected_devices.return_value = ()
|
||||
mock_state.swap_size = 0
|
||||
mock_state.reserve_size = 0
|
||||
|
||||
with (
|
||||
patch("python.installer.__main__.configure_logger"),
|
||||
patch("python.installer.__main__.curses.wrapper", return_value=mock_state),
|
||||
patch("python.installer.__main__.getenv", return_value=None),
|
||||
patch("python.installer.__main__.sleep") as mock_sleep,
|
||||
patch("python.installer.__main__.installer"),
|
||||
):
|
||||
main()
|
||||
|
||||
mock_sleep.assert_called_once_with(3)
|
||||
@@ -1,70 +0,0 @@
|
||||
"""Extended tests for python/installer/tui.py."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from python.installer.tui import (
|
||||
Cursor,
|
||||
State,
|
||||
bash_wrapper,
|
||||
calculate_device_menu_padding,
|
||||
get_device,
|
||||
get_devices,
|
||||
status_bar,
|
||||
)
|
||||
|
||||
|
||||
def test_get_devices() -> None:
|
||||
"""Test get_devices parses lsblk output."""
|
||||
mock_output = (
|
||||
'NAME="/dev/sda" SIZE="100G" TYPE="disk" MOUNTPOINTS=""\n'
|
||||
'NAME="/dev/sda1" SIZE="512M" TYPE="part" MOUNTPOINTS="/boot"\n'
|
||||
)
|
||||
with patch("python.installer.tui.bash_wrapper", return_value=mock_output):
|
||||
devices = get_devices()
|
||||
assert len(devices) == 2
|
||||
assert devices[0]["name"] == "/dev/sda"
|
||||
assert devices[1]["name"] == "/dev/sda1"
|
||||
|
||||
|
||||
def test_calculate_device_menu_padding_with_padding() -> None:
|
||||
"""Test calculate_device_menu_padding with custom padding."""
|
||||
devices = [
|
||||
{"name": "abc", "size": "100G"},
|
||||
{"name": "abcdef", "size": "500G"},
|
||||
]
|
||||
result = calculate_device_menu_padding(devices, "name", 5)
|
||||
assert result == len("abcdef") + 5
|
||||
|
||||
|
||||
def test_calculate_device_menu_padding_zero() -> None:
|
||||
"""Test calculate_device_menu_padding with zero padding."""
|
||||
devices = [{"name": "abc"}]
|
||||
result = calculate_device_menu_padding(devices, "name", 0)
|
||||
assert result == 3
|
||||
|
||||
|
||||
def test_status_bar() -> None:
|
||||
"""Test status_bar renders without error."""
|
||||
import curses as _curses
|
||||
|
||||
mock_screen = MagicMock()
|
||||
cursor = Cursor()
|
||||
cursor.set_height(50)
|
||||
cursor.set_width(100)
|
||||
cursor.set_x(5)
|
||||
cursor.set_y(10)
|
||||
with patch.object(_curses, "color_pair", return_value=0), patch.object(_curses, "A_REVERSE", 0):
|
||||
status_bar(mock_screen, cursor, 100, 50)
|
||||
assert mock_screen.addstr.call_count > 0
|
||||
|
||||
|
||||
def test_get_device_various_formats() -> None:
|
||||
"""Test get_device with different formats."""
|
||||
raw = 'NAME="/dev/nvme0n1p1" SIZE="1T" TYPE="nvme" MOUNTPOINTS="/"'
|
||||
device = get_device(raw)
|
||||
assert device["name"] == "/dev/nvme0n1p1"
|
||||
assert device["size"] == "1T"
|
||||
assert device["type"] == "nvme"
|
||||
assert device["mountpoints"] == "/"
|
||||
@@ -1,515 +0,0 @@
|
||||
"""Additional tests for python/installer/tui.py covering missing lines."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import curses
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
|
||||
from python.installer.tui import (
|
||||
State,
|
||||
debug_menu,
|
||||
draw_device_ids,
|
||||
draw_device_menu,
|
||||
draw_menu,
|
||||
get_device_id_mapping,
|
||||
get_text_input,
|
||||
reserve_size_input,
|
||||
set_color,
|
||||
swap_size_input,
|
||||
)
|
||||
|
||||
|
||||
# --- set_color (lines 153-156) ---
|
||||
|
||||
|
||||
def test_set_color() -> None:
|
||||
"""Test set_color initializes curses colors."""
|
||||
with (
|
||||
patch("python.installer.tui.curses.start_color") as mock_start,
|
||||
patch("python.installer.tui.curses.use_default_colors") as mock_defaults,
|
||||
patch("python.installer.tui.curses.init_pair") as mock_init_pair,
|
||||
patch.object(curses, "COLORS", 8, create=True),
|
||||
):
|
||||
set_color()
|
||||
|
||||
mock_start.assert_called_once()
|
||||
mock_defaults.assert_called_once()
|
||||
assert mock_init_pair.call_count == 8
|
||||
mock_init_pair.assert_any_call(1, 0, -1)
|
||||
mock_init_pair.assert_any_call(8, 7, -1)
|
||||
|
||||
|
||||
# --- debug_menu (lines 166-175) ---
|
||||
|
||||
|
||||
def test_debug_menu_with_key_pressed() -> None:
|
||||
"""Test debug_menu when a key has been pressed."""
|
||||
mock_screen = MagicMock()
|
||||
mock_screen.getmaxyx.return_value = (40, 80)
|
||||
|
||||
with patch("python.installer.tui.curses.color_pair", return_value=0):
|
||||
debug_menu(mock_screen, ord("a"))
|
||||
|
||||
# Should show width/height, key pressed, and color blocks
|
||||
assert mock_screen.addstr.call_count >= 3
|
||||
|
||||
|
||||
def test_debug_menu_no_key_pressed() -> None:
|
||||
"""Test debug_menu when no key has been pressed (key=0)."""
|
||||
mock_screen = MagicMock()
|
||||
mock_screen.getmaxyx.return_value = (40, 80)
|
||||
|
||||
with patch("python.installer.tui.curses.color_pair", return_value=0):
|
||||
debug_menu(mock_screen, 0)
|
||||
|
||||
# Check that "No key press detected..." is displayed
|
||||
calls = [str(c) for c in mock_screen.addstr.call_args_list]
|
||||
assert any("No key press detected" in c for c in calls)
|
||||
|
||||
|
||||
# --- get_text_input (lines 190-208) ---
|
||||
|
||||
|
||||
def test_get_text_input_enter_key() -> None:
|
||||
"""Test get_text_input returns input when Enter is pressed."""
|
||||
mock_screen = MagicMock()
|
||||
mock_screen.getch.side_effect = [ord("h"), ord("i"), ord("\n")]
|
||||
|
||||
with patch("python.installer.tui.curses.echo"), patch("python.installer.tui.curses.noecho"):
|
||||
result = get_text_input(mock_screen, "Prompt: ", 5, 0)
|
||||
|
||||
assert result == "hi"
|
||||
|
||||
|
||||
def test_get_text_input_escape_key() -> None:
|
||||
"""Test get_text_input returns empty string when Escape is pressed."""
|
||||
mock_screen = MagicMock()
|
||||
mock_screen.getch.side_effect = [ord("h"), ord("i"), 27] # 27 = ESC
|
||||
|
||||
with patch("python.installer.tui.curses.echo"), patch("python.installer.tui.curses.noecho"):
|
||||
result = get_text_input(mock_screen, "Prompt: ", 5, 0)
|
||||
|
||||
assert result == ""
|
||||
|
||||
|
||||
def test_get_text_input_backspace() -> None:
|
||||
"""Test get_text_input handles backspace correctly."""
|
||||
mock_screen = MagicMock()
|
||||
mock_screen.getch.side_effect = [ord("h"), ord("i"), 127, ord("\n")] # 127 = backspace
|
||||
|
||||
with patch("python.installer.tui.curses.echo"), patch("python.installer.tui.curses.noecho"):
|
||||
result = get_text_input(mock_screen, "Prompt: ", 5, 0)
|
||||
|
||||
assert result == "h"
|
||||
|
||||
|
||||
def test_get_text_input_curses_backspace() -> None:
|
||||
"""Test get_text_input handles curses KEY_BACKSPACE."""
|
||||
mock_screen = MagicMock()
|
||||
mock_screen.getch.side_effect = [ord("a"), ord("b"), curses.KEY_BACKSPACE, ord("\n")]
|
||||
|
||||
with patch("python.installer.tui.curses.echo"), patch("python.installer.tui.curses.noecho"):
|
||||
result = get_text_input(mock_screen, "Prompt: ", 5, 0)
|
||||
|
||||
assert result == "a"
|
||||
|
||||
|
||||
# --- swap_size_input (lines 226-241) ---
|
||||
|
||||
|
||||
def test_swap_size_input_no_trigger() -> None:
|
||||
"""Test swap_size_input when not triggered (no enter on swap row)."""
|
||||
mock_screen = MagicMock()
|
||||
state = State()
|
||||
state.key = ord("a")
|
||||
|
||||
result = swap_size_input(mock_screen, state, swap_offset=5)
|
||||
|
||||
assert result.swap_size == 0
|
||||
assert result.show_swap_input is False
|
||||
|
||||
|
||||
def test_swap_size_input_enter_triggers_input() -> None:
|
||||
"""Test swap_size_input when Enter is pressed on the swap row."""
|
||||
mock_screen = MagicMock()
|
||||
state = State()
|
||||
state.cursor.set_height(20)
|
||||
state.cursor.set_width(80)
|
||||
state.cursor.set_y(5)
|
||||
state.key = ord("\n")
|
||||
|
||||
with patch("python.installer.tui.get_text_input", return_value="16"):
|
||||
result = swap_size_input(mock_screen, state, swap_offset=5)
|
||||
|
||||
assert result.swap_size == 16
|
||||
assert result.show_swap_input is False
|
||||
|
||||
|
||||
def test_swap_size_input_invalid_value() -> None:
|
||||
"""Test swap_size_input with invalid (non-integer) input."""
|
||||
mock_screen = MagicMock()
|
||||
state = State()
|
||||
state.cursor.set_height(20)
|
||||
state.cursor.set_width(80)
|
||||
state.cursor.set_y(5)
|
||||
state.key = ord("\n")
|
||||
|
||||
with patch("python.installer.tui.get_text_input", return_value="abc"):
|
||||
result = swap_size_input(mock_screen, state, swap_offset=5)
|
||||
|
||||
assert result.swap_size == 0
|
||||
assert result.show_swap_input is False
|
||||
# Should have shown "Invalid input" message and waited for a key
|
||||
mock_screen.getch.assert_called_once()
|
||||
|
||||
|
||||
def test_swap_size_input_already_showing() -> None:
|
||||
"""Test swap_size_input when show_swap_input is already True."""
|
||||
mock_screen = MagicMock()
|
||||
state = State()
|
||||
state.show_swap_input = True
|
||||
state.key = 0
|
||||
|
||||
with patch("python.installer.tui.get_text_input", return_value="8"):
|
||||
result = swap_size_input(mock_screen, state, swap_offset=5)
|
||||
|
||||
assert result.swap_size == 8
|
||||
assert result.show_swap_input is False
|
||||
|
||||
|
||||
# --- reserve_size_input (lines 259-274) ---
|
||||
|
||||
|
||||
def test_reserve_size_input_no_trigger() -> None:
|
||||
"""Test reserve_size_input when not triggered."""
|
||||
mock_screen = MagicMock()
|
||||
state = State()
|
||||
state.key = ord("a")
|
||||
|
||||
result = reserve_size_input(mock_screen, state, reserve_offset=6)
|
||||
|
||||
assert result.reserve_size == 0
|
||||
assert result.show_reserve_input is False
|
||||
|
||||
|
||||
def test_reserve_size_input_enter_triggers_input() -> None:
|
||||
"""Test reserve_size_input when Enter is pressed on the reserve row."""
|
||||
mock_screen = MagicMock()
|
||||
state = State()
|
||||
state.cursor.set_height(20)
|
||||
state.cursor.set_width(80)
|
||||
state.cursor.set_y(6)
|
||||
state.key = ord("\n")
|
||||
|
||||
with patch("python.installer.tui.get_text_input", return_value="32"):
|
||||
result = reserve_size_input(mock_screen, state, reserve_offset=6)
|
||||
|
||||
assert result.reserve_size == 32
|
||||
assert result.show_reserve_input is False
|
||||
|
||||
|
||||
def test_reserve_size_input_invalid_value() -> None:
|
||||
"""Test reserve_size_input with invalid (non-integer) input."""
|
||||
mock_screen = MagicMock()
|
||||
state = State()
|
||||
state.cursor.set_height(20)
|
||||
state.cursor.set_width(80)
|
||||
state.cursor.set_y(6)
|
||||
state.key = ord("\n")
|
||||
|
||||
with patch("python.installer.tui.get_text_input", return_value="xyz"):
|
||||
result = reserve_size_input(mock_screen, state, reserve_offset=6)
|
||||
|
||||
assert result.reserve_size == 0
|
||||
assert result.show_reserve_input is False
|
||||
mock_screen.getch.assert_called_once()
|
||||
|
||||
|
||||
def test_reserve_size_input_already_showing() -> None:
|
||||
"""Test reserve_size_input when show_reserve_input is already True."""
|
||||
mock_screen = MagicMock()
|
||||
state = State()
|
||||
state.show_reserve_input = True
|
||||
state.key = 0
|
||||
|
||||
with patch("python.installer.tui.get_text_input", return_value="10"):
|
||||
result = reserve_size_input(mock_screen, state, reserve_offset=6)
|
||||
|
||||
assert result.reserve_size == 10
|
||||
assert result.show_reserve_input is False
|
||||
|
||||
|
||||
# --- get_device_id_mapping (lines 308-316) ---
|
||||
|
||||
|
||||
def test_get_device_id_mapping() -> None:
|
||||
"""Test get_device_id_mapping returns correct mapping."""
|
||||
find_output = "/dev/disk/by-id/ata-DISK1\n/dev/disk/by-id/ata-DISK2\n"
|
||||
|
||||
def mock_bash(cmd: str) -> str:
|
||||
if cmd.startswith("find"):
|
||||
return find_output
|
||||
if "ata-DISK1" in cmd:
|
||||
return "/dev/sda\n"
|
||||
if "ata-DISK2" in cmd:
|
||||
return "/dev/sda\n"
|
||||
return ""
|
||||
|
||||
with patch("python.installer.tui.bash_wrapper", side_effect=mock_bash):
|
||||
result = get_device_id_mapping()
|
||||
|
||||
assert "/dev/sda" in result
|
||||
assert "/dev/disk/by-id/ata-DISK1" in result["/dev/sda"]
|
||||
assert "/dev/disk/by-id/ata-DISK2" in result["/dev/sda"]
|
||||
|
||||
|
||||
def test_get_device_id_mapping_multiple_devices() -> None:
|
||||
"""Test get_device_id_mapping with multiple different devices."""
|
||||
find_output = "/dev/disk/by-id/ata-DISK1\n/dev/disk/by-id/nvme-DISK2\n"
|
||||
|
||||
def mock_bash(cmd: str) -> str:
|
||||
if cmd.startswith("find"):
|
||||
return find_output
|
||||
if "ata-DISK1" in cmd:
|
||||
return "/dev/sda\n"
|
||||
if "nvme-DISK2" in cmd:
|
||||
return "/dev/nvme0n1\n"
|
||||
return ""
|
||||
|
||||
with patch("python.installer.tui.bash_wrapper", side_effect=mock_bash):
|
||||
result = get_device_id_mapping()
|
||||
|
||||
assert "/dev/sda" in result
|
||||
assert "/dev/nvme0n1" in result
|
||||
|
||||
|
||||
# --- draw_device_ids (lines 354-372) ---
|
||||
|
||||
|
||||
def test_draw_device_ids_no_selection() -> None:
|
||||
"""Test draw_device_ids without selecting any device."""
|
||||
mock_screen = MagicMock()
|
||||
state = State()
|
||||
state.cursor.set_height(40)
|
||||
state.cursor.set_width(80)
|
||||
state.key = 0
|
||||
device_ids = {"/dev/disk/by-id/ata-DISK1"}
|
||||
menu_width = list(range(0, 60))
|
||||
|
||||
with (
|
||||
patch("python.installer.tui.curses.A_BOLD", 1),
|
||||
patch("python.installer.tui.curses.color_pair", return_value=0),
|
||||
):
|
||||
result_state, result_row = draw_device_ids(state, 2, 0, mock_screen, menu_width, device_ids)
|
||||
|
||||
assert result_row == 3
|
||||
assert len(result_state.selected_device_ids) == 0
|
||||
|
||||
|
||||
def test_draw_device_ids_select_device() -> None:
|
||||
"""Test draw_device_ids selecting a device with space key."""
|
||||
mock_screen = MagicMock()
|
||||
state = State()
|
||||
state.cursor.set_height(40)
|
||||
state.cursor.set_width(80)
|
||||
state.cursor.set_y(3)
|
||||
state.cursor.set_x(0)
|
||||
state.key = ord(" ")
|
||||
device_ids = {"/dev/disk/by-id/ata-DISK1"}
|
||||
menu_width = list(range(0, 60))
|
||||
|
||||
with (
|
||||
patch("python.installer.tui.curses.A_BOLD", 1),
|
||||
patch("python.installer.tui.curses.color_pair", return_value=0),
|
||||
):
|
||||
result_state, result_row = draw_device_ids(state, 2, 0, mock_screen, menu_width, device_ids)
|
||||
|
||||
assert "/dev/disk/by-id/ata-DISK1" in result_state.selected_device_ids
|
||||
|
||||
|
||||
def test_draw_device_ids_deselect_device() -> None:
|
||||
"""Test draw_device_ids deselecting an already selected device."""
|
||||
mock_screen = MagicMock()
|
||||
state = State()
|
||||
state.cursor.set_height(40)
|
||||
state.cursor.set_width(80)
|
||||
state.cursor.set_y(3)
|
||||
state.cursor.set_x(0)
|
||||
state.key = ord(" ")
|
||||
state.selected_device_ids.add("/dev/disk/by-id/ata-DISK1")
|
||||
device_ids = {"/dev/disk/by-id/ata-DISK1"}
|
||||
menu_width = list(range(0, 60))
|
||||
|
||||
with (
|
||||
patch("python.installer.tui.curses.A_BOLD", 1),
|
||||
patch("python.installer.tui.curses.color_pair", return_value=0),
|
||||
):
|
||||
result_state, _ = draw_device_ids(state, 2, 0, mock_screen, menu_width, device_ids)
|
||||
|
||||
assert "/dev/disk/by-id/ata-DISK1" not in result_state.selected_device_ids
|
||||
|
||||
|
||||
def test_draw_device_ids_selected_device_color() -> None:
|
||||
"""Test draw_device_ids applies color to already selected devices."""
|
||||
mock_screen = MagicMock()
|
||||
state = State()
|
||||
state.cursor.set_height(40)
|
||||
state.cursor.set_width(80)
|
||||
state.key = 0
|
||||
state.selected_device_ids.add("/dev/disk/by-id/ata-DISK1")
|
||||
device_ids = {"/dev/disk/by-id/ata-DISK1"}
|
||||
menu_width = list(range(0, 60))
|
||||
|
||||
with (
|
||||
patch("python.installer.tui.curses.A_BOLD", 1),
|
||||
patch("python.installer.tui.curses.color_pair", return_value=7) as mock_color,
|
||||
):
|
||||
draw_device_ids(state, 2, 0, mock_screen, menu_width, device_ids)
|
||||
|
||||
mock_screen.attron.assert_any_call(7)
|
||||
|
||||
|
||||
# --- draw_device_menu (lines 396-434) ---
|
||||
|
||||
|
||||
def test_draw_device_menu() -> None:
|
||||
"""Test draw_device_menu renders devices and calls draw_device_ids."""
|
||||
mock_screen = MagicMock()
|
||||
state = State()
|
||||
state.cursor.set_height(40)
|
||||
state.cursor.set_width(80)
|
||||
state.key = 0
|
||||
|
||||
devices = [
|
||||
{"name": "/dev/sda", "size": "100G", "type": "disk", "mountpoints": ""},
|
||||
]
|
||||
device_id_mapping = {
|
||||
"/dev/sda": {"/dev/disk/by-id/ata-DISK1"},
|
||||
}
|
||||
|
||||
with (
|
||||
patch("python.installer.tui.curses.color_pair", return_value=0),
|
||||
patch("python.installer.tui.curses.A_BOLD", 1),
|
||||
):
|
||||
result_state, row_number = draw_device_menu(
|
||||
mock_screen, devices, device_id_mapping, state, menu_start_y=0, menu_start_x=0
|
||||
)
|
||||
|
||||
assert mock_screen.addstr.call_count > 0
|
||||
assert row_number > 0
|
||||
|
||||
|
||||
def test_draw_device_menu_multiple_devices() -> None:
|
||||
"""Test draw_device_menu with multiple devices."""
|
||||
mock_screen = MagicMock()
|
||||
state = State()
|
||||
state.cursor.set_height(40)
|
||||
state.cursor.set_width(80)
|
||||
state.key = 0
|
||||
|
||||
devices = [
|
||||
{"name": "/dev/sda", "size": "100G", "type": "disk", "mountpoints": ""},
|
||||
{"name": "/dev/sdb", "size": "200G", "type": "disk", "mountpoints": ""},
|
||||
]
|
||||
device_id_mapping = {
|
||||
"/dev/sda": {"/dev/disk/by-id/ata-DISK1"},
|
||||
"/dev/sdb": {"/dev/disk/by-id/ata-DISK2"},
|
||||
}
|
||||
|
||||
with (
|
||||
patch("python.installer.tui.curses.color_pair", return_value=0),
|
||||
patch("python.installer.tui.curses.A_BOLD", 1),
|
||||
):
|
||||
result_state, row_number = draw_device_menu(
|
||||
mock_screen, devices, device_id_mapping, state, menu_start_y=0, menu_start_x=0
|
||||
)
|
||||
|
||||
# 2 devices + 2 device ids = at least 4 rows past the header
|
||||
assert row_number >= 4
|
||||
|
||||
|
||||
def test_draw_device_menu_no_device_ids() -> None:
|
||||
"""Test draw_device_menu when a device has no IDs."""
|
||||
mock_screen = MagicMock()
|
||||
state = State()
|
||||
state.cursor.set_height(40)
|
||||
state.cursor.set_width(80)
|
||||
state.key = 0
|
||||
|
||||
devices = [
|
||||
{"name": "/dev/sda", "size": "100G", "type": "disk", "mountpoints": ""},
|
||||
]
|
||||
device_id_mapping: dict[str, set[str]] = {
|
||||
"/dev/sda": set(),
|
||||
}
|
||||
|
||||
with (
|
||||
patch("python.installer.tui.curses.color_pair", return_value=0),
|
||||
patch("python.installer.tui.curses.A_BOLD", 1),
|
||||
):
|
||||
result_state, row_number = draw_device_menu(
|
||||
mock_screen, devices, device_id_mapping, state, menu_start_y=0, menu_start_x=0
|
||||
)
|
||||
|
||||
# Should still work; row_number reflects only the device row (no id rows)
|
||||
assert row_number >= 2
|
||||
|
||||
|
||||
# --- draw_menu (lines 447-498) ---
|
||||
|
||||
|
||||
def test_draw_menu_quit_immediately() -> None:
|
||||
"""Test draw_menu exits when 'q' is pressed immediately."""
|
||||
mock_screen = MagicMock()
|
||||
mock_screen.getmaxyx.return_value = (40, 80)
|
||||
mock_screen.getch.return_value = ord("q")
|
||||
|
||||
devices = [
|
||||
{"name": "/dev/sda", "size": "100G", "type": "disk", "mountpoints": ""},
|
||||
]
|
||||
device_id_mapping = {"/dev/sda": {"/dev/disk/by-id/ata-DISK1"}}
|
||||
|
||||
with (
|
||||
patch("python.installer.tui.set_color"),
|
||||
patch("python.installer.tui.get_devices", return_value=devices),
|
||||
patch("python.installer.tui.get_device_id_mapping", return_value=device_id_mapping),
|
||||
patch("python.installer.tui.draw_device_menu", return_value=(State(), 5)),
|
||||
patch("python.installer.tui.swap_size_input"),
|
||||
patch("python.installer.tui.reserve_size_input"),
|
||||
patch("python.installer.tui.status_bar"),
|
||||
patch("python.installer.tui.debug_menu"),
|
||||
patch("python.installer.tui.curses.color_pair", return_value=0),
|
||||
):
|
||||
result = draw_menu(mock_screen)
|
||||
|
||||
assert isinstance(result, State)
|
||||
mock_screen.clear.assert_called()
|
||||
mock_screen.refresh.assert_called()
|
||||
|
||||
|
||||
def test_draw_menu_navigation_then_quit() -> None:
|
||||
"""Test draw_menu handles navigation keys before quitting."""
|
||||
mock_screen = MagicMock()
|
||||
mock_screen.getmaxyx.return_value = (40, 80)
|
||||
# Simulate pressing down arrow then 'q'
|
||||
mock_screen.getch.side_effect = [curses.KEY_DOWN, ord("q")]
|
||||
|
||||
devices = [
|
||||
{"name": "/dev/sda", "size": "100G", "type": "disk", "mountpoints": ""},
|
||||
]
|
||||
device_id_mapping = {"/dev/sda": set()}
|
||||
|
||||
with (
|
||||
patch("python.installer.tui.set_color"),
|
||||
patch("python.installer.tui.get_devices", return_value=devices),
|
||||
patch("python.installer.tui.get_device_id_mapping", return_value=device_id_mapping),
|
||||
patch("python.installer.tui.draw_device_menu", return_value=(State(), 5)),
|
||||
patch("python.installer.tui.swap_size_input"),
|
||||
patch("python.installer.tui.reserve_size_input"),
|
||||
patch("python.installer.tui.status_bar"),
|
||||
patch("python.installer.tui.debug_menu"),
|
||||
patch("python.installer.tui.curses.color_pair", return_value=0),
|
||||
):
|
||||
result = draw_menu(mock_screen)
|
||||
|
||||
assert isinstance(result, State)
|
||||
@@ -1,129 +0,0 @@
|
||||
"""Tests for python/orm modules."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from os import environ
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from python.orm.base import RichieBase, TableBase, get_connection_info, get_postgres_engine
|
||||
from python.orm.contact import ContactNeed, ContactRelationship, RelationshipType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
def test_richie_base_schema_name() -> None:
|
||||
"""Test RichieBase has correct schema name."""
|
||||
assert RichieBase.schema_name == "main"
|
||||
|
||||
|
||||
def test_richie_base_metadata_naming() -> None:
|
||||
"""Test RichieBase metadata has naming conventions."""
|
||||
assert RichieBase.metadata.schema == "main"
|
||||
naming = RichieBase.metadata.naming_convention
|
||||
assert naming is not None
|
||||
assert "ix" in naming
|
||||
assert "uq" in naming
|
||||
assert "ck" in naming
|
||||
assert "fk" in naming
|
||||
assert "pk" in naming
|
||||
|
||||
|
||||
def test_table_base_abstract() -> None:
|
||||
"""Test TableBase is abstract."""
|
||||
assert TableBase.__abstract__ is True
|
||||
|
||||
|
||||
def test_get_connection_info_success() -> None:
|
||||
"""Test get_connection_info with all env vars set."""
|
||||
env = {
|
||||
"POSTGRES_DB": "testdb",
|
||||
"POSTGRES_HOST": "localhost",
|
||||
"POSTGRES_PORT": "5432",
|
||||
"POSTGRES_USER": "testuser",
|
||||
"POSTGRES_PASSWORD": "testpass",
|
||||
}
|
||||
with patch.dict(environ, env, clear=False):
|
||||
result = get_connection_info()
|
||||
assert result == ("testdb", "localhost", "5432", "testuser", "testpass")
|
||||
|
||||
|
||||
def test_get_connection_info_no_password() -> None:
|
||||
"""Test get_connection_info with no password."""
|
||||
env = {
|
||||
"POSTGRES_DB": "testdb",
|
||||
"POSTGRES_HOST": "localhost",
|
||||
"POSTGRES_PORT": "5432",
|
||||
"POSTGRES_USER": "testuser",
|
||||
}
|
||||
# Clear password if set
|
||||
cleaned = {k: v for k, v in environ.items() if k != "POSTGRES_PASSWORD"}
|
||||
cleaned.update(env)
|
||||
with patch.dict(environ, cleaned, clear=True):
|
||||
result = get_connection_info()
|
||||
assert result == ("testdb", "localhost", "5432", "testuser", None)
|
||||
|
||||
|
||||
def test_get_connection_info_missing_vars() -> None:
|
||||
"""Test get_connection_info raises with missing env vars."""
|
||||
with patch.dict(environ, {}, clear=True), pytest.raises(ValueError, match="Missing environment variables"):
|
||||
get_connection_info()
|
||||
|
||||
|
||||
def test_get_postgres_engine() -> None:
|
||||
"""Test get_postgres_engine creates an engine."""
|
||||
env = {
|
||||
"POSTGRES_DB": "testdb",
|
||||
"POSTGRES_HOST": "localhost",
|
||||
"POSTGRES_PORT": "5432",
|
||||
"POSTGRES_USER": "testuser",
|
||||
"POSTGRES_PASSWORD": "testpass",
|
||||
}
|
||||
mock_engine = MagicMock()
|
||||
with patch.dict(environ, env, clear=False), patch("python.orm.base.create_engine", return_value=mock_engine):
|
||||
engine = get_postgres_engine()
|
||||
assert engine is mock_engine
|
||||
|
||||
|
||||
# --- Contact ORM tests ---
|
||||
|
||||
|
||||
def test_relationship_type_values() -> None:
|
||||
"""Test RelationshipType enum values."""
|
||||
assert RelationshipType.SPOUSE.value == "spouse"
|
||||
assert RelationshipType.OTHER.value == "other"
|
||||
|
||||
|
||||
def test_relationship_type_default_weight() -> None:
|
||||
"""Test RelationshipType default weights."""
|
||||
assert RelationshipType.SPOUSE.default_weight == 10
|
||||
assert RelationshipType.ACQUAINTANCE.default_weight == 3
|
||||
assert RelationshipType.OTHER.default_weight == 2
|
||||
assert RelationshipType.PARENT.default_weight == 9
|
||||
|
||||
|
||||
def test_relationship_type_display_name() -> None:
|
||||
"""Test RelationshipType display_name."""
|
||||
assert RelationshipType.BEST_FRIEND.display_name == "Best Friend"
|
||||
assert RelationshipType.AUNT_UNCLE.display_name == "Aunt Uncle"
|
||||
assert RelationshipType.SPOUSE.display_name == "Spouse"
|
||||
|
||||
|
||||
def test_all_relationship_types_have_weights() -> None:
|
||||
"""Test all relationship types have valid weights."""
|
||||
for rt in RelationshipType:
|
||||
weight = rt.default_weight
|
||||
assert 1 <= weight <= 10
|
||||
|
||||
|
||||
def test_contact_need_table_name() -> None:
|
||||
"""Test ContactNeed table name."""
|
||||
assert ContactNeed.__tablename__ == "contact_need"
|
||||
|
||||
|
||||
def test_contact_relationship_table_name() -> None:
|
||||
"""Test ContactRelationship table name."""
|
||||
assert ContactRelationship.__tablename__ == "contact_relationship"
|
||||
@@ -1,674 +0,0 @@
|
||||
"""Tests for python/splendor modules."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from unittest.mock import patch
|
||||
|
||||
from python.splendor.base import (
|
||||
BASE_COLORS,
|
||||
GEM_COLORS,
|
||||
Action,
|
||||
BuyCard,
|
||||
BuyCardReserved,
|
||||
Card,
|
||||
GameConfig,
|
||||
GameState,
|
||||
Noble,
|
||||
PlayerState,
|
||||
ReserveCard,
|
||||
TakeDifferent,
|
||||
TakeDouble,
|
||||
apply_action,
|
||||
apply_buy_card,
|
||||
apply_buy_card_reserved,
|
||||
apply_reserve_card,
|
||||
apply_take_different,
|
||||
apply_take_double,
|
||||
auto_discard_tokens,
|
||||
check_nobles_for_player,
|
||||
create_random_cards,
|
||||
create_random_cards_tier,
|
||||
create_random_nobles,
|
||||
enforce_token_limit,
|
||||
get_default_starting_tokens,
|
||||
get_legal_actions,
|
||||
load_cards,
|
||||
load_nobles,
|
||||
new_game,
|
||||
run_game,
|
||||
)
|
||||
from python.splendor.bot import (
|
||||
PersonalizedBot,
|
||||
PersonalizedBot2,
|
||||
RandomBot,
|
||||
buy_card,
|
||||
buy_card_reserved,
|
||||
can_bot_afford,
|
||||
check_cards_in_tier,
|
||||
take_tokens,
|
||||
)
|
||||
from python.splendor.public_state import (
|
||||
Observation,
|
||||
ObsCard,
|
||||
ObsNoble,
|
||||
ObsPlayer,
|
||||
to_observation,
|
||||
_encode_card,
|
||||
_encode_noble,
|
||||
_encode_player,
|
||||
)
|
||||
from python.splendor.sim import SimStrategy, simulate_step
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# --- Helper to create a simple game ---
|
||||
|
||||
|
||||
def _make_card(tier: int = 1, points: int = 0, color: str = "white", cost: dict | None = None) -> Card:
|
||||
if cost is None:
|
||||
cost = dict.fromkeys(GEM_COLORS, 0)
|
||||
return Card(tier=tier, points=points, color=color, cost=cost)
|
||||
|
||||
|
||||
def _make_noble(name: str = "Noble", points: int = 3, reqs: dict | None = None) -> Noble:
|
||||
if reqs is None:
|
||||
reqs = {"white": 3, "blue": 3, "green": 3}
|
||||
return Noble(name=name, points=points, requirements=reqs)
|
||||
|
||||
|
||||
def _make_game(num_players: int = 2) -> tuple[GameState, list[RandomBot]]:
|
||||
bots = [RandomBot(f"bot{i}") for i in range(num_players)]
|
||||
cards = create_random_cards()
|
||||
nobles = create_random_nobles()
|
||||
config = GameConfig(cards=cards, nobles=nobles)
|
||||
game = new_game(bots, config)
|
||||
return game, bots
|
||||
|
||||
|
||||
# --- PlayerState tests ---
|
||||
|
||||
|
||||
def test_player_state_defaults() -> None:
|
||||
"""Test PlayerState default values."""
|
||||
bot = RandomBot("test")
|
||||
p = PlayerState(strategy=bot)
|
||||
assert p.total_tokens() == 0
|
||||
assert p.score == 0
|
||||
assert p.card_score == 0
|
||||
assert p.noble_score == 0
|
||||
|
||||
|
||||
def test_player_add_card() -> None:
|
||||
"""Test adding a card to player."""
|
||||
bot = RandomBot("test")
|
||||
p = PlayerState(strategy=bot)
|
||||
card = _make_card(points=3)
|
||||
p.add_card(card)
|
||||
assert len(p.cards) == 1
|
||||
assert p.card_score == 3
|
||||
assert p.score == 3
|
||||
|
||||
|
||||
def test_player_add_noble() -> None:
|
||||
"""Test adding a noble to player."""
|
||||
bot = RandomBot("test")
|
||||
p = PlayerState(strategy=bot)
|
||||
noble = _make_noble(points=3)
|
||||
p.add_noble(noble)
|
||||
assert len(p.nobles) == 1
|
||||
assert p.noble_score == 3
|
||||
assert p.score == 3
|
||||
|
||||
|
||||
def test_player_can_afford_free_card() -> None:
|
||||
"""Test can_afford with a free card."""
|
||||
bot = RandomBot("test")
|
||||
p = PlayerState(strategy=bot)
|
||||
card = _make_card(cost=dict.fromkeys(GEM_COLORS, 0))
|
||||
assert p.can_afford(card) is True
|
||||
|
||||
|
||||
def test_player_can_afford_with_tokens() -> None:
|
||||
"""Test can_afford with tokens."""
|
||||
bot = RandomBot("test")
|
||||
p = PlayerState(strategy=bot)
|
||||
p.tokens["white"] = 3
|
||||
card = _make_card(cost={**dict.fromkeys(GEM_COLORS, 0), "white": 3})
|
||||
assert p.can_afford(card) is True
|
||||
|
||||
|
||||
def test_player_cannot_afford() -> None:
|
||||
"""Test can_afford returns False when not enough."""
|
||||
bot = RandomBot("test")
|
||||
p = PlayerState(strategy=bot)
|
||||
card = _make_card(cost={**dict.fromkeys(GEM_COLORS, 0), "white": 5})
|
||||
assert p.can_afford(card) is False
|
||||
|
||||
|
||||
def test_player_can_afford_with_gold() -> None:
|
||||
"""Test can_afford uses gold tokens."""
|
||||
bot = RandomBot("test")
|
||||
p = PlayerState(strategy=bot)
|
||||
p.tokens["gold"] = 3
|
||||
card = _make_card(cost={**dict.fromkeys(GEM_COLORS, 0), "white": 3})
|
||||
assert p.can_afford(card) is True
|
||||
|
||||
|
||||
def test_player_pay_for_card() -> None:
|
||||
"""Test pay_for_card transfers tokens."""
|
||||
bot = RandomBot("test")
|
||||
p = PlayerState(strategy=bot)
|
||||
p.tokens["white"] = 3
|
||||
card = _make_card(color="white", cost={**dict.fromkeys(GEM_COLORS, 0), "white": 2})
|
||||
payment = p.pay_for_card(card)
|
||||
assert payment["white"] == 2
|
||||
assert p.tokens["white"] == 1
|
||||
assert len(p.cards) == 1
|
||||
assert p.discounts["white"] == 1
|
||||
|
||||
|
||||
def test_player_pay_for_card_cannot_afford() -> None:
|
||||
"""Test pay_for_card raises when cannot afford."""
|
||||
bot = RandomBot("test")
|
||||
p = PlayerState(strategy=bot)
|
||||
card = _make_card(cost={**dict.fromkeys(GEM_COLORS, 0), "white": 5})
|
||||
with pytest.raises(ValueError, match="cannot afford"):
|
||||
p.pay_for_card(card)
|
||||
|
||||
|
||||
# --- GameState tests ---
|
||||
|
||||
|
||||
def test_get_default_starting_tokens() -> None:
|
||||
"""Test starting token counts."""
|
||||
tokens = get_default_starting_tokens(2)
|
||||
assert tokens["gold"] == 5
|
||||
assert tokens["white"] == 4 # (4-6+10)//2 = 4
|
||||
|
||||
tokens = get_default_starting_tokens(3)
|
||||
assert tokens["white"] == 5
|
||||
|
||||
tokens = get_default_starting_tokens(4)
|
||||
assert tokens["white"] == 7
|
||||
|
||||
|
||||
def test_new_game() -> None:
|
||||
"""Test new_game creates valid state."""
|
||||
game, _ = _make_game(2)
|
||||
assert len(game.players) == 2
|
||||
assert game.bank["gold"] == 5
|
||||
assert len(game.available_nobles) == 3 # 2 players + 1
|
||||
|
||||
|
||||
def test_game_next_player() -> None:
|
||||
"""Test next_player cycles."""
|
||||
game, _ = _make_game(2)
|
||||
assert game.current_player_index == 0
|
||||
game.next_player()
|
||||
assert game.current_player_index == 1
|
||||
game.next_player()
|
||||
assert game.current_player_index == 0
|
||||
|
||||
|
||||
def test_game_current_player() -> None:
|
||||
"""Test current_player property."""
|
||||
game, _ = _make_game(2)
|
||||
assert game.current_player is game.players[0]
|
||||
|
||||
|
||||
def test_game_check_winner_simple_no_winner() -> None:
|
||||
"""Test check_winner_simple with no winner."""
|
||||
game, _ = _make_game(2)
|
||||
assert game.check_winner_simple() is None
|
||||
|
||||
|
||||
def test_game_check_winner_simple_winner() -> None:
|
||||
"""Test check_winner_simple with winner."""
|
||||
game, _ = _make_game(2)
|
||||
# Give player enough points
|
||||
for _ in range(15):
|
||||
game.players[0].add_card(_make_card(points=1))
|
||||
winner = game.check_winner_simple()
|
||||
assert winner is game.players[0]
|
||||
assert game.finished is True
|
||||
|
||||
|
||||
def test_game_refill_table() -> None:
|
||||
"""Test refill_table fills from decks."""
|
||||
game, _ = _make_game(2)
|
||||
# Table should be filled initially
|
||||
for tier in (1, 2, 3):
|
||||
assert len(game.table_by_tier[tier]) <= game.config.table_cards_per_tier
|
||||
|
||||
|
||||
# --- Action tests ---
|
||||
|
||||
|
||||
def test_apply_take_different() -> None:
|
||||
"""Test take different colors."""
|
||||
game, bots = _make_game(2)
|
||||
strategy = bots[0]
|
||||
action = TakeDifferent(colors=["white", "blue", "green"])
|
||||
apply_take_different(game, strategy, action)
|
||||
p = game.players[0]
|
||||
assert p.tokens["white"] == 1
|
||||
assert p.tokens["blue"] == 1
|
||||
assert p.tokens["green"] == 1
|
||||
|
||||
|
||||
def test_apply_take_different_invalid() -> None:
|
||||
"""Test take different with too many colors is truncated."""
|
||||
game, bots = _make_game(2)
|
||||
strategy = bots[0]
|
||||
# 4 colors should be rejected
|
||||
action = TakeDifferent(colors=["white", "blue", "green", "red"])
|
||||
apply_take_different(game, strategy, action)
|
||||
|
||||
|
||||
def test_apply_take_double() -> None:
|
||||
"""Test take double."""
|
||||
game, bots = _make_game(2)
|
||||
strategy = bots[0]
|
||||
action = TakeDouble(color="white")
|
||||
apply_take_double(game, strategy, action)
|
||||
p = game.players[0]
|
||||
assert p.tokens["white"] == 2
|
||||
|
||||
|
||||
def test_apply_take_double_insufficient() -> None:
|
||||
"""Test take double fails when bank has insufficient."""
|
||||
game, bots = _make_game(2)
|
||||
strategy = bots[0]
|
||||
game.bank["white"] = 2 # Below minimum_tokens_to_buy_2
|
||||
action = TakeDouble(color="white")
|
||||
apply_take_double(game, strategy, action)
|
||||
p = game.players[0]
|
||||
assert p.tokens["white"] == 0 # No change
|
||||
|
||||
|
||||
def test_apply_buy_card() -> None:
|
||||
"""Test buy a card."""
|
||||
game, bots = _make_game(2)
|
||||
strategy = bots[0]
|
||||
# Give the player enough tokens
|
||||
game.players[0].tokens["white"] = 10
|
||||
game.players[0].tokens["blue"] = 10
|
||||
game.players[0].tokens["green"] = 10
|
||||
game.players[0].tokens["red"] = 10
|
||||
game.players[0].tokens["black"] = 10
|
||||
|
||||
if game.table_by_tier[1]:
|
||||
action = BuyCard(tier=1, index=0)
|
||||
apply_buy_card(game, strategy, action)
|
||||
|
||||
|
||||
def test_apply_buy_card_reserved() -> None:
|
||||
"""Test buy a reserved card."""
|
||||
game, bots = _make_game(2)
|
||||
strategy = bots[0]
|
||||
card = _make_card(cost=dict.fromkeys(GEM_COLORS, 0))
|
||||
game.players[0].reserved.append(card)
|
||||
|
||||
action = BuyCardReserved(index=0)
|
||||
apply_buy_card_reserved(game, strategy, action)
|
||||
assert len(game.players[0].reserved) == 0
|
||||
assert len(game.players[0].cards) == 1
|
||||
|
||||
|
||||
def test_apply_reserve_card_from_table() -> None:
|
||||
"""Test reserve a card from table."""
|
||||
game, bots = _make_game(2)
|
||||
strategy = bots[0]
|
||||
if game.table_by_tier[1]:
|
||||
action = ReserveCard(tier=1, index=0, from_deck=False)
|
||||
apply_reserve_card(game, strategy, action)
|
||||
assert len(game.players[0].reserved) == 1
|
||||
|
||||
|
||||
def test_apply_reserve_card_from_deck() -> None:
|
||||
"""Test reserve a card from deck."""
|
||||
game, bots = _make_game(2)
|
||||
strategy = bots[0]
|
||||
action = ReserveCard(tier=1, index=None, from_deck=True)
|
||||
apply_reserve_card(game, strategy, action)
|
||||
assert len(game.players[0].reserved) == 1
|
||||
|
||||
|
||||
def test_apply_reserve_card_limit() -> None:
|
||||
"""Test reserve limit."""
|
||||
game, bots = _make_game(2)
|
||||
strategy = bots[0]
|
||||
# Fill reserves
|
||||
for _ in range(3):
|
||||
game.players[0].reserved.append(_make_card())
|
||||
action = ReserveCard(tier=1, index=0, from_deck=False)
|
||||
apply_reserve_card(game, strategy, action)
|
||||
assert len(game.players[0].reserved) == 3 # No change
|
||||
|
||||
|
||||
def test_apply_action_unknown_type() -> None:
|
||||
"""Test apply_action with unknown action type."""
|
||||
|
||||
class FakeAction(Action):
|
||||
pass
|
||||
|
||||
game, bots = _make_game(2)
|
||||
with pytest.raises(ValueError, match="Unknown action type"):
|
||||
apply_action(game, bots[0], FakeAction())
|
||||
|
||||
|
||||
def test_apply_action_dispatches() -> None:
|
||||
"""Test apply_action dispatches to correct handler."""
|
||||
game, bots = _make_game(2)
|
||||
action = TakeDifferent(colors=["white"])
|
||||
apply_action(game, bots[0], action)
|
||||
|
||||
|
||||
# --- auto_discard_tokens ---
|
||||
|
||||
|
||||
def test_auto_discard_tokens() -> None:
|
||||
"""Test auto_discard_tokens."""
|
||||
bot = RandomBot("test")
|
||||
p = PlayerState(strategy=bot)
|
||||
p.tokens["white"] = 5
|
||||
p.tokens["blue"] = 3
|
||||
discards = auto_discard_tokens(p, 2)
|
||||
assert sum(discards.values()) == 2
|
||||
|
||||
|
||||
# --- enforce_token_limit ---
|
||||
|
||||
|
||||
def test_enforce_token_limit_under() -> None:
|
||||
"""Test enforce_token_limit when under limit."""
|
||||
game, bots = _make_game(2)
|
||||
p = game.players[0]
|
||||
p.tokens["white"] = 3
|
||||
enforce_token_limit(game, bots[0], p)
|
||||
assert p.tokens["white"] == 3 # No change
|
||||
|
||||
|
||||
def test_enforce_token_limit_over() -> None:
|
||||
"""Test enforce_token_limit when over limit."""
|
||||
game, bots = _make_game(2)
|
||||
p = game.players[0]
|
||||
for color in BASE_COLORS:
|
||||
p.tokens[color] = 5
|
||||
enforce_token_limit(game, bots[0], p)
|
||||
assert p.total_tokens() <= game.config.token_limit
|
||||
|
||||
|
||||
# --- check_nobles_for_player ---
|
||||
|
||||
|
||||
def test_check_nobles_no_qualification() -> None:
|
||||
"""Test check_nobles when player doesn't qualify."""
|
||||
game, bots = _make_game(2)
|
||||
check_nobles_for_player(game, bots[0], game.players[0])
|
||||
assert len(game.players[0].nobles) == 0
|
||||
|
||||
|
||||
def test_check_nobles_qualification() -> None:
|
||||
"""Test check_nobles when player qualifies."""
|
||||
game, bots = _make_game(2)
|
||||
p = game.players[0]
|
||||
# Give enough discounts to qualify for ALL nobles (ensures at least one match)
|
||||
for color in BASE_COLORS:
|
||||
p.discounts[color] = 10
|
||||
check_nobles_for_player(game, bots[0], p)
|
||||
assert len(p.nobles) >= 1
|
||||
|
||||
|
||||
# --- get_legal_actions ---
|
||||
|
||||
|
||||
def test_get_legal_actions() -> None:
|
||||
"""Test get_legal_actions returns valid actions."""
|
||||
game, _ = _make_game(2)
|
||||
actions = get_legal_actions(game)
|
||||
assert len(actions) > 0
|
||||
|
||||
|
||||
def test_get_legal_actions_explicit_player() -> None:
|
||||
"""Test get_legal_actions with explicit player."""
|
||||
game, _ = _make_game(2)
|
||||
actions = get_legal_actions(game, game.players[1])
|
||||
assert len(actions) > 0
|
||||
|
||||
|
||||
# --- create_random helpers ---
|
||||
|
||||
|
||||
def test_create_random_cards() -> None:
|
||||
"""Test create_random_cards."""
|
||||
random.seed(42)
|
||||
cards = create_random_cards()
|
||||
assert len(cards) > 0
|
||||
tiers = {c.tier for c in cards}
|
||||
assert tiers == {1, 2, 3}
|
||||
|
||||
|
||||
def test_create_random_cards_tier() -> None:
|
||||
"""Test create_random_cards_tier."""
|
||||
cards = create_random_cards_tier(1, 3, [0, 1], [0, 1])
|
||||
assert len(cards) == 15 # 5 colors * 3 per color
|
||||
|
||||
|
||||
def test_create_random_nobles() -> None:
|
||||
"""Test create_random_nobles."""
|
||||
nobles = create_random_nobles()
|
||||
assert len(nobles) == 8
|
||||
assert all(n.points == 3 for n in nobles)
|
||||
|
||||
|
||||
# --- load_cards / load_nobles ---
|
||||
|
||||
|
||||
def test_load_cards(tmp_path: Path) -> None:
|
||||
"""Test load_cards from file."""
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
cards_data = [
|
||||
{"tier": 1, "points": 0, "color": "white", "cost": {"white": 0, "blue": 1}},
|
||||
]
|
||||
file = tmp_path / "cards.json"
|
||||
file.write_text(json.dumps(cards_data))
|
||||
cards = load_cards(file)
|
||||
assert len(cards) == 1
|
||||
|
||||
|
||||
def test_load_nobles(tmp_path: Path) -> None:
|
||||
"""Test load_nobles from file."""
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
nobles_data = [
|
||||
{"name": "Noble 1", "points": 3, "requirements": {"white": 3, "blue": 3}},
|
||||
]
|
||||
file = tmp_path / "nobles.json"
|
||||
file.write_text(json.dumps(nobles_data))
|
||||
nobles = load_nobles(file)
|
||||
assert len(nobles) == 1
|
||||
|
||||
|
||||
# --- run_game ---
|
||||
|
||||
|
||||
def test_run_game() -> None:
|
||||
"""Test run_game completes."""
|
||||
random.seed(42)
|
||||
game, _ = _make_game(2)
|
||||
winner, turns = run_game(game)
|
||||
assert winner is not None
|
||||
assert turns > 0
|
||||
|
||||
|
||||
def test_run_game_concede() -> None:
|
||||
"""Test run_game handles player conceding."""
|
||||
|
||||
class ConcedingBot(RandomBot):
|
||||
def choose_action(self, game: GameState, player: PlayerState) -> Action | None:
|
||||
return None
|
||||
|
||||
bots = [ConcedingBot("bot1"), RandomBot("bot2")]
|
||||
cards = create_random_cards()
|
||||
nobles = create_random_nobles()
|
||||
config = GameConfig(cards=cards, nobles=nobles)
|
||||
game = new_game(bots, config)
|
||||
winner, turns = run_game(game)
|
||||
assert winner is not None
|
||||
|
||||
|
||||
# --- Bot tests ---
|
||||
|
||||
|
||||
def test_random_bot_choose_action() -> None:
|
||||
"""Test RandomBot.choose_action returns valid action."""
|
||||
random.seed(42)
|
||||
game, bots = _make_game(2)
|
||||
action = bots[0].choose_action(game, game.players[0])
|
||||
assert action is not None
|
||||
|
||||
|
||||
def test_personalized_bot_choose_action() -> None:
|
||||
"""Test PersonalizedBot.choose_action."""
|
||||
random.seed(42)
|
||||
bot = PersonalizedBot("pbot")
|
||||
game, _ = _make_game(2)
|
||||
game.players[0].strategy = bot
|
||||
action = bot.choose_action(game, game.players[0])
|
||||
assert action is not None
|
||||
|
||||
|
||||
def test_personalized_bot2_choose_action() -> None:
|
||||
"""Test PersonalizedBot2.choose_action."""
|
||||
random.seed(42)
|
||||
bot = PersonalizedBot2("pbot2")
|
||||
game, _ = _make_game(2)
|
||||
game.players[0].strategy = bot
|
||||
action = bot.choose_action(game, game.players[0])
|
||||
assert action is not None
|
||||
|
||||
|
||||
def test_can_bot_afford() -> None:
|
||||
"""Test can_bot_afford function."""
|
||||
bot = RandomBot("test")
|
||||
p = PlayerState(strategy=bot)
|
||||
card = _make_card(cost=dict.fromkeys(GEM_COLORS, 0))
|
||||
assert can_bot_afford(p, card) is True
|
||||
|
||||
|
||||
def test_check_cards_in_tier() -> None:
|
||||
"""Test check_cards_in_tier."""
|
||||
bot = RandomBot("test")
|
||||
p = PlayerState(strategy=bot)
|
||||
free_card = _make_card(cost=dict.fromkeys(GEM_COLORS, 0))
|
||||
expensive_card = _make_card(cost={**dict.fromkeys(GEM_COLORS, 0), "white": 10})
|
||||
result = check_cards_in_tier([free_card, expensive_card], p)
|
||||
assert result == [0]
|
||||
|
||||
|
||||
def test_buy_card_function() -> None:
|
||||
"""Test buy_card helper function."""
|
||||
game, _ = _make_game(2)
|
||||
p = game.players[0]
|
||||
# Give player enough tokens
|
||||
for c in BASE_COLORS:
|
||||
p.tokens[c] = 10
|
||||
result = buy_card(game, p)
|
||||
assert result is not None or True # May or may not find affordable card
|
||||
|
||||
|
||||
def test_buy_card_reserved_function() -> None:
|
||||
"""Test buy_card_reserved helper function."""
|
||||
bot = RandomBot("test")
|
||||
p = PlayerState(strategy=bot)
|
||||
# No reserved cards
|
||||
assert buy_card_reserved(p) is None
|
||||
|
||||
# With affordable reserved card
|
||||
card = _make_card(cost=dict.fromkeys(GEM_COLORS, 0))
|
||||
p.reserved.append(card)
|
||||
result = buy_card_reserved(p)
|
||||
assert isinstance(result, BuyCardReserved)
|
||||
|
||||
|
||||
def test_take_tokens_function() -> None:
|
||||
"""Test take_tokens helper function."""
|
||||
game, _ = _make_game(2)
|
||||
result = take_tokens(game)
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_take_tokens_empty_bank() -> None:
|
||||
"""Test take_tokens with empty bank."""
|
||||
game, _ = _make_game(2)
|
||||
for c in BASE_COLORS:
|
||||
game.bank[c] = 0
|
||||
result = take_tokens(game)
|
||||
assert result is None
|
||||
|
||||
|
||||
# --- public_state tests ---
|
||||
|
||||
|
||||
def test_encode_card() -> None:
|
||||
"""Test _encode_card."""
|
||||
card = _make_card(tier=1, points=2, color="blue", cost={"white": 1, "blue": 2})
|
||||
obs = _encode_card(card)
|
||||
assert isinstance(obs, ObsCard)
|
||||
assert obs.tier == 1
|
||||
assert obs.points == 2
|
||||
|
||||
|
||||
def test_encode_noble() -> None:
|
||||
"""Test _encode_noble."""
|
||||
noble = _make_noble(points=3, reqs={"white": 3, "blue": 3, "green": 3})
|
||||
obs = _encode_noble(noble)
|
||||
assert isinstance(obs, ObsNoble)
|
||||
assert obs.points == 3
|
||||
|
||||
|
||||
def test_encode_player() -> None:
|
||||
"""Test _encode_player."""
|
||||
bot = RandomBot("test")
|
||||
p = PlayerState(strategy=bot)
|
||||
obs = _encode_player(p)
|
||||
assert isinstance(obs, ObsPlayer)
|
||||
assert obs.score == 0
|
||||
|
||||
|
||||
def test_to_observation() -> None:
|
||||
"""Test to_observation creates full observation."""
|
||||
game, _ = _make_game(2)
|
||||
obs = to_observation(game)
|
||||
assert isinstance(obs, Observation)
|
||||
assert len(obs.players) == 2
|
||||
assert obs.current_player == 0
|
||||
|
||||
|
||||
# --- sim tests ---
|
||||
|
||||
|
||||
def test_sim_strategy_choose_action_raises() -> None:
|
||||
"""Test SimStrategy.choose_action raises."""
|
||||
sim = SimStrategy("sim")
|
||||
game, _ = _make_game(2)
|
||||
with pytest.raises(RuntimeError, match="should not be used"):
|
||||
sim.choose_action(game, game.players[0])
|
||||
|
||||
|
||||
def test_simulate_step() -> None:
|
||||
"""Test simulate_step returns deep copy."""
|
||||
random.seed(42)
|
||||
game, _ = _make_game(2)
|
||||
action = TakeDifferent(colors=["white", "blue", "green"])
|
||||
# SimStrategy() in source is missing name arg - patch it
|
||||
with patch("python.splendor.sim.SimStrategy", lambda: SimStrategy("sim")):
|
||||
next_state = simulate_step(game, action)
|
||||
assert next_state is not game
|
||||
assert next_state.current_player_index != game.current_player_index or len(game.players) == 1
|
||||
@@ -1,246 +0,0 @@
|
||||
"""Extra tests for splendor/base.py covering missed lines and branches."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
|
||||
from python.splendor.base import (
|
||||
BASE_COLORS,
|
||||
GEM_COLORS,
|
||||
BuyCard,
|
||||
BuyCardReserved,
|
||||
Card,
|
||||
GameConfig,
|
||||
Noble,
|
||||
ReserveCard,
|
||||
TakeDifferent,
|
||||
TakeDouble,
|
||||
apply_action,
|
||||
apply_buy_card,
|
||||
apply_buy_card_reserved,
|
||||
apply_reserve_card,
|
||||
apply_take_different,
|
||||
apply_take_double,
|
||||
auto_discard_tokens,
|
||||
check_nobles_for_player,
|
||||
create_random_cards,
|
||||
create_random_nobles,
|
||||
enforce_token_limit,
|
||||
get_legal_actions,
|
||||
new_game,
|
||||
run_game,
|
||||
)
|
||||
from python.splendor.bot import RandomBot
|
||||
|
||||
|
||||
def _make_game(num_players: int = 2):
|
||||
random.seed(42)
|
||||
bots = [RandomBot(f"bot{i}") for i in range(num_players)]
|
||||
cards = create_random_cards()
|
||||
nobles = create_random_nobles()
|
||||
config = GameConfig(cards=cards, nobles=nobles)
|
||||
game = new_game(bots, config)
|
||||
return game, bots
|
||||
|
||||
|
||||
def test_auto_discard_tokens_all_zero() -> None:
|
||||
"""Test auto_discard when all tokens are zero."""
|
||||
game, _ = _make_game()
|
||||
p = game.players[0]
|
||||
for c in GEM_COLORS:
|
||||
p.tokens[c] = 0
|
||||
result = auto_discard_tokens(p, 3)
|
||||
assert sum(result.values()) == 0 # Can't discard from empty
|
||||
|
||||
|
||||
def test_enforce_token_limit_with_fallback() -> None:
|
||||
"""Test enforce_token_limit uses auto_discard as fallback."""
|
||||
game, bots = _make_game()
|
||||
p = game.players[0]
|
||||
strategy = bots[0]
|
||||
# Give player many tokens to force discard
|
||||
for c in BASE_COLORS:
|
||||
p.tokens[c] = 5
|
||||
enforce_token_limit(game, strategy, p)
|
||||
assert p.total_tokens() <= game.config.token_limit
|
||||
|
||||
|
||||
def test_apply_take_different_invalid_color() -> None:
|
||||
"""Test take different with gold (non-base) color."""
|
||||
game, bots = _make_game()
|
||||
action = TakeDifferent(colors=["gold"])
|
||||
apply_take_different(game, bots[0], action)
|
||||
# Gold is not in BASE_COLORS, so no tokens should be taken
|
||||
|
||||
|
||||
def test_apply_take_double_invalid_color() -> None:
|
||||
"""Test take double with gold (non-base) color."""
|
||||
game, bots = _make_game()
|
||||
action = TakeDouble(color="gold")
|
||||
apply_take_double(game, bots[0], action)
|
||||
|
||||
|
||||
def test_apply_take_double_insufficient_bank() -> None:
|
||||
"""Test take double when bank has fewer than minimum."""
|
||||
game, bots = _make_game()
|
||||
game.bank["white"] = 2 # Below minimum_tokens_to_buy_2 (4)
|
||||
action = TakeDouble(color="white")
|
||||
apply_take_double(game, bots[0], action)
|
||||
|
||||
|
||||
def test_apply_buy_card_invalid_tier() -> None:
|
||||
"""Test buy card with invalid tier."""
|
||||
game, bots = _make_game()
|
||||
action = BuyCard(tier=99, index=0)
|
||||
apply_buy_card(game, bots[0], action)
|
||||
|
||||
|
||||
def test_apply_buy_card_invalid_index() -> None:
|
||||
"""Test buy card with out-of-range index."""
|
||||
game, bots = _make_game()
|
||||
action = BuyCard(tier=1, index=99)
|
||||
apply_buy_card(game, bots[0], action)
|
||||
|
||||
|
||||
def test_apply_buy_card_cannot_afford() -> None:
|
||||
"""Test buy card when player can't afford."""
|
||||
game, bots = _make_game()
|
||||
# Zero out all tokens
|
||||
for c in GEM_COLORS:
|
||||
game.players[0].tokens[c] = 0
|
||||
# Find an expensive card
|
||||
for tier, row in game.table_by_tier.items():
|
||||
for idx, card in enumerate(row):
|
||||
if any(v > 0 for v in card.cost.values()):
|
||||
action = BuyCard(tier=tier, index=idx)
|
||||
apply_buy_card(game, bots[0], action)
|
||||
return
|
||||
|
||||
|
||||
def test_apply_buy_card_reserved_invalid_index() -> None:
|
||||
"""Test buy reserved card with out-of-range index."""
|
||||
game, bots = _make_game()
|
||||
action = BuyCardReserved(index=99)
|
||||
apply_buy_card_reserved(game, bots[0], action)
|
||||
|
||||
|
||||
def test_apply_buy_card_reserved_cannot_afford() -> None:
|
||||
"""Test buy reserved card when can't afford."""
|
||||
game, bots = _make_game()
|
||||
expensive = Card(tier=3, points=5, color="white", cost={
|
||||
"white": 10, "blue": 10, "green": 10, "red": 10, "black": 10, "gold": 0
|
||||
})
|
||||
game.players[0].reserved.append(expensive)
|
||||
for c in GEM_COLORS:
|
||||
game.players[0].tokens[c] = 0
|
||||
action = BuyCardReserved(index=0)
|
||||
apply_buy_card_reserved(game, bots[0], action)
|
||||
|
||||
|
||||
def test_apply_reserve_card_at_limit() -> None:
|
||||
"""Test reserve card when at reserve limit."""
|
||||
game, bots = _make_game()
|
||||
p = game.players[0]
|
||||
# Fill up reserved slots
|
||||
for _ in range(game.config.reserve_limit):
|
||||
p.reserved.append(Card(tier=1, points=0, color="white", cost=dict.fromkeys(GEM_COLORS, 0)))
|
||||
action = ReserveCard(tier=1, index=0, from_deck=False)
|
||||
apply_reserve_card(game, bots[0], action)
|
||||
assert len(p.reserved) == game.config.reserve_limit
|
||||
|
||||
|
||||
def test_apply_reserve_card_invalid_tier() -> None:
|
||||
"""Test reserve face-up card with invalid tier."""
|
||||
game, bots = _make_game()
|
||||
action = ReserveCard(tier=99, index=0, from_deck=False)
|
||||
apply_reserve_card(game, bots[0], action)
|
||||
|
||||
|
||||
def test_apply_reserve_card_invalid_index() -> None:
|
||||
"""Test reserve face-up card with None index."""
|
||||
game, bots = _make_game()
|
||||
action = ReserveCard(tier=1, index=None, from_deck=False)
|
||||
apply_reserve_card(game, bots[0], action)
|
||||
|
||||
|
||||
def test_apply_reserve_card_from_empty_deck() -> None:
|
||||
"""Test reserve from deck when deck is empty."""
|
||||
game, bots = _make_game()
|
||||
game.decks_by_tier[1] = [] # Empty the deck
|
||||
action = ReserveCard(tier=1, index=None, from_deck=True)
|
||||
apply_reserve_card(game, bots[0], action)
|
||||
|
||||
|
||||
def test_apply_reserve_card_no_gold() -> None:
|
||||
"""Test reserve card when bank has no gold."""
|
||||
game, bots = _make_game()
|
||||
game.bank["gold"] = 0
|
||||
action = ReserveCard(tier=1, index=0, from_deck=True)
|
||||
reserved_before = len(game.players[0].reserved)
|
||||
apply_reserve_card(game, bots[0], action)
|
||||
if len(game.players[0].reserved) > reserved_before:
|
||||
assert game.players[0].tokens["gold"] == 0
|
||||
|
||||
|
||||
def test_check_nobles_multiple_candidates() -> None:
|
||||
"""Test check_nobles when player qualifies for multiple nobles."""
|
||||
game, bots = _make_game()
|
||||
p = game.players[0]
|
||||
# Give player huge discounts to qualify for everything
|
||||
for c in BASE_COLORS:
|
||||
p.discounts[c] = 20
|
||||
check_nobles_for_player(game, bots[0], p)
|
||||
|
||||
|
||||
def test_check_nobles_chosen_not_in_available() -> None:
|
||||
"""Test check_nobles when chosen noble is somehow not available."""
|
||||
game, bots = _make_game()
|
||||
p = game.players[0]
|
||||
for c in BASE_COLORS:
|
||||
p.discounts[c] = 20
|
||||
# This tests the normal path - chosen should be in available
|
||||
|
||||
|
||||
def test_run_game_turn_limit() -> None:
|
||||
"""Test run_game respects turn limit."""
|
||||
random.seed(99)
|
||||
bots = [RandomBot(f"bot{i}") for i in range(2)]
|
||||
cards = create_random_cards()
|
||||
nobles = create_random_nobles()
|
||||
config = GameConfig(cards=cards, nobles=nobles, turn_limit=5)
|
||||
game = new_game(bots, config)
|
||||
winner, turns = run_game(game)
|
||||
assert turns <= 5
|
||||
|
||||
|
||||
def test_run_game_action_none() -> None:
|
||||
"""Test run_game stops when strategy returns None."""
|
||||
from unittest.mock import MagicMock
|
||||
bots = [RandomBot(f"bot{i}") for i in range(2)]
|
||||
cards = create_random_cards()
|
||||
nobles = create_random_nobles()
|
||||
config = GameConfig(cards=cards, nobles=nobles)
|
||||
game = new_game(bots, config)
|
||||
# Make the first player's strategy return None
|
||||
game.players[0].strategy.choose_action = MagicMock(return_value=None)
|
||||
winner, turns = run_game(game)
|
||||
assert turns == 1
|
||||
|
||||
|
||||
def test_get_valid_actions_with_reserved() -> None:
|
||||
"""Test get_valid_actions includes BuyCardReserved when player has reserved cards."""
|
||||
game, _ = _make_game()
|
||||
p = game.players[0]
|
||||
# Give player a free reserved card
|
||||
free_card = Card(tier=1, points=0, color="white", cost=dict.fromkeys(GEM_COLORS, 0))
|
||||
p.reserved.append(free_card)
|
||||
actions = get_legal_actions(game)
|
||||
assert any(isinstance(a, BuyCardReserved) for a in actions)
|
||||
|
||||
|
||||
def test_get_legal_actions_reserve_from_deck() -> None:
|
||||
"""Test get_legal_actions includes ReserveCard from deck."""
|
||||
game, _ = _make_game()
|
||||
actions = get_legal_actions(game)
|
||||
assert any(isinstance(a, ReserveCard) and a.from_deck for a in actions)
|
||||
assert any(isinstance(a, ReserveCard) and not a.from_deck for a in actions)
|
||||
@@ -1,143 +0,0 @@
|
||||
"""Tests for PersonalizedBot3 and PersonalizedBot4 edge cases."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
|
||||
from python.splendor.base import (
|
||||
BASE_COLORS,
|
||||
GEM_COLORS,
|
||||
BuyCard,
|
||||
Card,
|
||||
GameConfig,
|
||||
GameState,
|
||||
PlayerState,
|
||||
ReserveCard,
|
||||
TakeDifferent,
|
||||
create_random_cards,
|
||||
create_random_nobles,
|
||||
new_game,
|
||||
run_game,
|
||||
)
|
||||
from python.splendor.bot import (
|
||||
PersonalizedBot2,
|
||||
PersonalizedBot3,
|
||||
PersonalizedBot4,
|
||||
RandomBot,
|
||||
)
|
||||
|
||||
|
||||
def _make_card(tier: int = 1, points: int = 0, color: str = "white", cost: dict | None = None) -> Card:
|
||||
if cost is None:
|
||||
cost = dict.fromkeys(GEM_COLORS, 0)
|
||||
return Card(tier=tier, points=points, color=color, cost=cost)
|
||||
|
||||
|
||||
def _make_game(bots: list) -> GameState:
|
||||
cards = create_random_cards()
|
||||
nobles = create_random_nobles()
|
||||
config = GameConfig(cards=cards, nobles=nobles, turn_limit=100)
|
||||
return new_game(bots, config)
|
||||
|
||||
|
||||
def test_personalized_bot3_reserves_from_deck() -> None:
|
||||
"""Test PersonalizedBot3 reserves from deck when no tokens."""
|
||||
random.seed(42)
|
||||
bot = PersonalizedBot3("pbot3")
|
||||
game = _make_game([bot, RandomBot("r")])
|
||||
p = game.players[0]
|
||||
p.strategy = bot
|
||||
|
||||
# Clear bank to force reserve
|
||||
for c in BASE_COLORS:
|
||||
game.bank[c] = 0
|
||||
# Clear table to prevent buys
|
||||
for tier in (1, 2, 3):
|
||||
game.table_by_tier[tier] = []
|
||||
|
||||
action = bot.choose_action(game, p)
|
||||
assert isinstance(action, (ReserveCard, TakeDifferent))
|
||||
|
||||
|
||||
def test_personalized_bot3_fallback_take_different() -> None:
|
||||
"""Test PersonalizedBot3 falls back to TakeDifferent."""
|
||||
random.seed(42)
|
||||
bot = PersonalizedBot3("pbot3")
|
||||
game = _make_game([bot, RandomBot("r")])
|
||||
p = game.players[0]
|
||||
p.strategy = bot
|
||||
|
||||
# Empty everything
|
||||
for c in BASE_COLORS:
|
||||
game.bank[c] = 0
|
||||
for tier in (1, 2, 3):
|
||||
game.table_by_tier[tier] = []
|
||||
game.decks_by_tier[tier] = []
|
||||
|
||||
action = bot.choose_action(game, p)
|
||||
assert isinstance(action, TakeDifferent)
|
||||
|
||||
|
||||
def test_personalized_bot4_reserves_from_deck() -> None:
|
||||
"""Test PersonalizedBot4 reserves from deck."""
|
||||
random.seed(42)
|
||||
bot = PersonalizedBot4("pbot4")
|
||||
game = _make_game([bot, RandomBot("r")])
|
||||
p = game.players[0]
|
||||
p.strategy = bot
|
||||
|
||||
for c in BASE_COLORS:
|
||||
game.bank[c] = 0
|
||||
for tier in (1, 2, 3):
|
||||
game.table_by_tier[tier] = []
|
||||
|
||||
action = bot.choose_action(game, p)
|
||||
assert isinstance(action, (ReserveCard, TakeDifferent))
|
||||
|
||||
|
||||
def test_personalized_bot4_fallback() -> None:
|
||||
"""Test PersonalizedBot4 fallback with empty everything."""
|
||||
random.seed(42)
|
||||
bot = PersonalizedBot4("pbot4")
|
||||
game = _make_game([bot, RandomBot("r")])
|
||||
p = game.players[0]
|
||||
p.strategy = bot
|
||||
|
||||
for c in BASE_COLORS:
|
||||
game.bank[c] = 0
|
||||
for tier in (1, 2, 3):
|
||||
game.table_by_tier[tier] = []
|
||||
game.decks_by_tier[tier] = []
|
||||
|
||||
action = bot.choose_action(game, p)
|
||||
assert isinstance(action, TakeDifferent)
|
||||
|
||||
|
||||
def test_personalized_bot2_fallback_empty_colors() -> None:
|
||||
"""Test PersonalizedBot2 with very few available colors."""
|
||||
random.seed(42)
|
||||
bot = PersonalizedBot2("pbot2")
|
||||
game = _make_game([bot, RandomBot("r")])
|
||||
p = game.players[0]
|
||||
p.strategy = bot
|
||||
|
||||
# No table cards, no affordable reserved
|
||||
for tier in (1, 2, 3):
|
||||
game.table_by_tier[tier] = []
|
||||
# Set exactly 2 colors
|
||||
for c in BASE_COLORS:
|
||||
game.bank[c] = 0
|
||||
game.bank["white"] = 1
|
||||
game.bank["blue"] = 1
|
||||
|
||||
action = bot.choose_action(game, p)
|
||||
assert action is not None
|
||||
|
||||
|
||||
def test_full_game_with_bot3_and_bot4() -> None:
|
||||
"""Test a full game with bot3 and bot4."""
|
||||
random.seed(42)
|
||||
bots = [PersonalizedBot3("b3"), PersonalizedBot4("b4")]
|
||||
game = _make_game(bots)
|
||||
winner, turns = run_game(game)
|
||||
assert winner is not None
|
||||
@@ -1,230 +0,0 @@
|
||||
"""Extended tests for python/splendor/bot.py to improve coverage."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
|
||||
from python.splendor.base import (
|
||||
BASE_COLORS,
|
||||
GEM_COLORS,
|
||||
BuyCard,
|
||||
BuyCardReserved,
|
||||
Card,
|
||||
GameConfig,
|
||||
GameState,
|
||||
PlayerState,
|
||||
ReserveCard,
|
||||
TakeDifferent,
|
||||
create_random_cards,
|
||||
create_random_nobles,
|
||||
new_game,
|
||||
run_game,
|
||||
)
|
||||
from python.splendor.bot import (
|
||||
PersonalizedBot,
|
||||
PersonalizedBot2,
|
||||
PersonalizedBot3,
|
||||
PersonalizedBot4,
|
||||
RandomBot,
|
||||
buy_card,
|
||||
buy_card_reserved,
|
||||
estimate_value_of_card,
|
||||
estimate_value_of_token,
|
||||
take_tokens,
|
||||
)
|
||||
|
||||
|
||||
def _make_card(tier: int = 1, points: int = 0, color: str = "white", cost: dict | None = None) -> Card:
|
||||
if cost is None:
|
||||
cost = dict.fromkeys(GEM_COLORS, 0)
|
||||
return Card(tier=tier, points=points, color=color, cost=cost)
|
||||
|
||||
|
||||
def _make_game(num_players: int = 2) -> tuple[GameState, list]:
|
||||
bots = [RandomBot(f"bot{i}") for i in range(num_players)]
|
||||
cards = create_random_cards()
|
||||
nobles = create_random_nobles()
|
||||
config = GameConfig(cards=cards, nobles=nobles)
|
||||
game = new_game(bots, config)
|
||||
return game, bots
|
||||
|
||||
|
||||
def test_random_bot_buys_affordable() -> None:
|
||||
"""Test RandomBot buys affordable cards."""
|
||||
random.seed(1)
|
||||
game, bots = _make_game(2)
|
||||
p = game.players[0]
|
||||
for c in BASE_COLORS:
|
||||
p.tokens[c] = 10
|
||||
# Should sometimes buy
|
||||
actions = [bots[0].choose_action(game, p) for _ in range(20)]
|
||||
buy_actions = [a for a in actions if isinstance(a, BuyCard)]
|
||||
assert len(buy_actions) > 0
|
||||
|
||||
|
||||
def test_random_bot_reserves() -> None:
|
||||
"""Test RandomBot reserves cards sometimes."""
|
||||
random.seed(3)
|
||||
game, bots = _make_game(2)
|
||||
actions = [bots[0].choose_action(game, game.players[0]) for _ in range(50)]
|
||||
reserve_actions = [a for a in actions if isinstance(a, ReserveCard)]
|
||||
assert len(reserve_actions) > 0
|
||||
|
||||
|
||||
def test_random_bot_choose_discard() -> None:
|
||||
"""Test RandomBot.choose_discard."""
|
||||
bot = RandomBot("test")
|
||||
p = PlayerState(strategy=bot)
|
||||
p.tokens["white"] = 5
|
||||
p.tokens["blue"] = 3
|
||||
discards = bot.choose_discard(None, p, 2)
|
||||
assert sum(discards.values()) == 2
|
||||
|
||||
|
||||
def test_personalized_bot_takes_different() -> None:
|
||||
"""Test PersonalizedBot takes different when no affordable cards."""
|
||||
random.seed(42)
|
||||
bot = PersonalizedBot("pbot")
|
||||
game, _ = _make_game(2)
|
||||
p = game.players[0]
|
||||
action = bot.choose_action(game, p)
|
||||
assert action is not None
|
||||
|
||||
|
||||
def test_personalized_bot_choose_discard() -> None:
|
||||
"""Test PersonalizedBot.choose_discard."""
|
||||
bot = PersonalizedBot("pbot")
|
||||
p = PlayerState(strategy=bot)
|
||||
p.tokens["white"] = 5
|
||||
discards = bot.choose_discard(None, p, 2)
|
||||
assert sum(discards.values()) == 2
|
||||
|
||||
|
||||
def test_personalized_bot2_buys_reserved() -> None:
|
||||
"""Test PersonalizedBot2 buys reserved cards."""
|
||||
random.seed(42)
|
||||
bot = PersonalizedBot2("pbot2")
|
||||
game, _ = _make_game(2)
|
||||
p = game.players[0]
|
||||
p.strategy = bot
|
||||
# Add affordable reserved card
|
||||
card = _make_card(cost=dict.fromkeys(GEM_COLORS, 0))
|
||||
p.reserved.append(card)
|
||||
# Clear table cards to force reserved buy
|
||||
for tier in (1, 2, 3):
|
||||
game.table_by_tier[tier] = []
|
||||
action = bot.choose_action(game, p)
|
||||
assert isinstance(action, BuyCardReserved)
|
||||
|
||||
|
||||
def test_personalized_bot2_reserves_from_deck() -> None:
|
||||
"""Test PersonalizedBot2 reserves from deck when few colors available."""
|
||||
random.seed(42)
|
||||
bot = PersonalizedBot2("pbot2")
|
||||
game, _ = _make_game(2)
|
||||
p = game.players[0]
|
||||
p.strategy = bot
|
||||
# Clear table and set only 2 bank colors
|
||||
for tier in (1, 2, 3):
|
||||
game.table_by_tier[tier] = []
|
||||
for c in BASE_COLORS:
|
||||
game.bank[c] = 0
|
||||
game.bank["white"] = 1
|
||||
game.bank["blue"] = 1
|
||||
action = bot.choose_action(game, p)
|
||||
assert isinstance(action, (ReserveCard, TakeDifferent))
|
||||
|
||||
|
||||
def test_personalized_bot2_choose_discard() -> None:
|
||||
"""Test PersonalizedBot2.choose_discard."""
|
||||
bot = PersonalizedBot2("pbot2")
|
||||
p = PlayerState(strategy=bot)
|
||||
p.tokens["red"] = 5
|
||||
discards = bot.choose_discard(None, p, 2)
|
||||
assert sum(discards.values()) == 2
|
||||
|
||||
|
||||
def test_personalized_bot3_choose_action() -> None:
|
||||
"""Test PersonalizedBot3.choose_action."""
|
||||
random.seed(42)
|
||||
bot = PersonalizedBot3("pbot3")
|
||||
game, _ = _make_game(2)
|
||||
p = game.players[0]
|
||||
p.strategy = bot
|
||||
action = bot.choose_action(game, p)
|
||||
assert action is not None
|
||||
|
||||
|
||||
def test_personalized_bot3_choose_discard() -> None:
|
||||
"""Test PersonalizedBot3.choose_discard."""
|
||||
bot = PersonalizedBot3("pbot3")
|
||||
p = PlayerState(strategy=bot)
|
||||
p.tokens["green"] = 5
|
||||
discards = bot.choose_discard(None, p, 2)
|
||||
assert sum(discards.values()) == 2
|
||||
|
||||
|
||||
def test_personalized_bot4_choose_action() -> None:
|
||||
"""Test PersonalizedBot4.choose_action."""
|
||||
random.seed(42)
|
||||
bot = PersonalizedBot4("pbot4")
|
||||
game, _ = _make_game(2)
|
||||
p = game.players[0]
|
||||
p.strategy = bot
|
||||
action = bot.choose_action(game, p)
|
||||
assert action is not None
|
||||
|
||||
|
||||
def test_personalized_bot4_filter_actions() -> None:
|
||||
"""Test PersonalizedBot4.filter_actions."""
|
||||
bot = PersonalizedBot4("pbot4")
|
||||
actions = [
|
||||
TakeDifferent(colors=["white", "blue", "green"]),
|
||||
TakeDifferent(colors=["white", "blue"]),
|
||||
BuyCard(tier=1, index=0),
|
||||
]
|
||||
filtered = bot.filter_actions(actions)
|
||||
# Should keep 3-color TakeDifferent and BuyCard, remove 2-color TakeDifferent
|
||||
assert len(filtered) == 2
|
||||
|
||||
|
||||
def test_personalized_bot4_choose_discard() -> None:
|
||||
"""Test PersonalizedBot4.choose_discard."""
|
||||
bot = PersonalizedBot4("pbot4")
|
||||
p = PlayerState(strategy=bot)
|
||||
p.tokens["black"] = 5
|
||||
discards = bot.choose_discard(None, p, 2)
|
||||
assert sum(discards.values()) == 2
|
||||
|
||||
|
||||
def test_estimate_value_of_card() -> None:
|
||||
"""Test estimate_value_of_card."""
|
||||
game, _ = _make_game(2)
|
||||
p = game.players[0]
|
||||
result = estimate_value_of_card(game, p, "white")
|
||||
assert isinstance(result, int)
|
||||
|
||||
|
||||
def test_estimate_value_of_token() -> None:
|
||||
"""Test estimate_value_of_token."""
|
||||
game, _ = _make_game(2)
|
||||
p = game.players[0]
|
||||
result = estimate_value_of_token(game, p, "white")
|
||||
assert isinstance(result, int)
|
||||
|
||||
|
||||
def test_full_game_with_personalized_bots() -> None:
|
||||
"""Test a full game with different bot types."""
|
||||
random.seed(42)
|
||||
bots = [
|
||||
RandomBot("random"),
|
||||
PersonalizedBot("p1"),
|
||||
PersonalizedBot2("p2"),
|
||||
]
|
||||
cards = create_random_cards()
|
||||
nobles = create_random_nobles()
|
||||
config = GameConfig(cards=cards, nobles=nobles, turn_limit=200)
|
||||
game = new_game(bots, config)
|
||||
winner, turns = run_game(game)
|
||||
assert winner is not None
|
||||
assert turns > 0
|
||||
@@ -1,156 +0,0 @@
|
||||
"""Tests for python/splendor/human.py - non-TUI parts."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from python.splendor.human import (
|
||||
COST_ABBR,
|
||||
COLOR_ABBR_TO_FULL,
|
||||
COLOR_STYLE,
|
||||
color_token,
|
||||
fmt_gem,
|
||||
fmt_number,
|
||||
format_card,
|
||||
format_cost,
|
||||
format_discounts,
|
||||
format_noble,
|
||||
format_tokens,
|
||||
parse_color_token,
|
||||
)
|
||||
from python.splendor.base import Card, GEM_COLORS, Noble
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# --- parse_color_token ---
|
||||
|
||||
|
||||
def test_parse_color_token_full_names() -> None:
|
||||
"""Test parsing full color names."""
|
||||
assert parse_color_token("white") == "white"
|
||||
assert parse_color_token("blue") == "blue"
|
||||
assert parse_color_token("green") == "green"
|
||||
assert parse_color_token("red") == "red"
|
||||
assert parse_color_token("black") == "black"
|
||||
|
||||
|
||||
def test_parse_color_token_abbreviations() -> None:
|
||||
"""Test parsing abbreviated color names."""
|
||||
assert parse_color_token("w") == "white"
|
||||
assert parse_color_token("b") == "blue"
|
||||
assert parse_color_token("g") == "green"
|
||||
assert parse_color_token("r") == "red"
|
||||
assert parse_color_token("k") == "black"
|
||||
assert parse_color_token("o") == "gold"
|
||||
|
||||
|
||||
def test_parse_color_token_case_insensitive() -> None:
|
||||
"""Test parsing is case insensitive."""
|
||||
assert parse_color_token("WHITE") == "white"
|
||||
assert parse_color_token("B") == "blue"
|
||||
|
||||
|
||||
def test_parse_color_token_unknown() -> None:
|
||||
"""Test parsing unknown color raises."""
|
||||
with pytest.raises(ValueError, match="Unknown color"):
|
||||
parse_color_token("purple")
|
||||
|
||||
|
||||
# --- format functions ---
|
||||
|
||||
|
||||
def test_format_cost() -> None:
|
||||
"""Test format_cost formats correctly."""
|
||||
cost = {"white": 2, "blue": 1, "green": 0, "red": 0, "black": 0, "gold": 0}
|
||||
result = format_cost(cost)
|
||||
assert "W:" in result
|
||||
assert "B:" in result
|
||||
|
||||
|
||||
def test_format_cost_empty() -> None:
|
||||
"""Test format_cost with all zeros."""
|
||||
cost = dict.fromkeys(GEM_COLORS, 0)
|
||||
result = format_cost(cost)
|
||||
assert result == "-"
|
||||
|
||||
|
||||
def test_format_card() -> None:
|
||||
"""Test format_card."""
|
||||
card = Card(tier=1, points=2, color="white", cost={"white": 0, "blue": 1, "green": 0, "red": 0, "black": 0, "gold": 0})
|
||||
result = format_card(card)
|
||||
assert "T1" in result
|
||||
assert "P2" in result
|
||||
|
||||
|
||||
def test_format_noble() -> None:
|
||||
"""Test format_noble."""
|
||||
noble = Noble(name="Noble 1", points=3, requirements={"white": 3, "blue": 3, "green": 3})
|
||||
result = format_noble(noble)
|
||||
assert "Noble 1" in result
|
||||
assert "+3" in result
|
||||
|
||||
|
||||
def test_format_tokens() -> None:
|
||||
"""Test format_tokens."""
|
||||
tokens = {"white": 2, "blue": 1, "green": 0, "red": 0, "black": 0, "gold": 0}
|
||||
result = format_tokens(tokens)
|
||||
assert "white:" in result
|
||||
|
||||
|
||||
def test_format_discounts() -> None:
|
||||
"""Test format_discounts."""
|
||||
discounts = {"white": 2, "blue": 1, "green": 0, "red": 0, "black": 0, "gold": 0}
|
||||
result = format_discounts(discounts)
|
||||
assert "W:" in result
|
||||
|
||||
|
||||
def test_format_discounts_empty() -> None:
|
||||
"""Test format_discounts with all zeros."""
|
||||
discounts = dict.fromkeys(GEM_COLORS, 0)
|
||||
result = format_discounts(discounts)
|
||||
assert result == "-"
|
||||
|
||||
|
||||
# --- formatting helpers ---
|
||||
|
||||
|
||||
def test_color_token() -> None:
|
||||
"""Test color_token."""
|
||||
result = color_token("white", 3)
|
||||
assert "white" in result
|
||||
assert "3" in result
|
||||
|
||||
|
||||
def test_fmt_gem() -> None:
|
||||
"""Test fmt_gem."""
|
||||
result = fmt_gem("blue")
|
||||
assert "blue" in result
|
||||
|
||||
|
||||
def test_fmt_number() -> None:
|
||||
"""Test fmt_number."""
|
||||
result = fmt_number(42)
|
||||
assert "42" in result
|
||||
|
||||
|
||||
# --- constants ---
|
||||
|
||||
|
||||
def test_cost_abbr_all_colors() -> None:
|
||||
"""Test COST_ABBR has all gem colors."""
|
||||
for color in GEM_COLORS:
|
||||
assert color in COST_ABBR
|
||||
|
||||
|
||||
def test_color_abbr_to_full() -> None:
|
||||
"""Test COLOR_ABBR_TO_FULL mappings."""
|
||||
assert COLOR_ABBR_TO_FULL["w"] == "white"
|
||||
assert COLOR_ABBR_TO_FULL["o"] == "gold"
|
||||
|
||||
|
||||
def test_color_style_all_colors() -> None:
|
||||
"""Test COLOR_STYLE has all gem colors."""
|
||||
for color in GEM_COLORS:
|
||||
assert color in COLOR_STYLE
|
||||
fg, bg = COLOR_STYLE[color]
|
||||
assert isinstance(fg, str)
|
||||
assert isinstance(bg, str)
|
||||
@@ -1,262 +0,0 @@
|
||||
"""Tests for splendor/human.py command handlers and TUI widgets."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from unittest.mock import MagicMock, patch, PropertyMock
|
||||
|
||||
from python.splendor.base import (
|
||||
BASE_COLORS,
|
||||
GEM_COLORS,
|
||||
BuyCard,
|
||||
BuyCardReserved,
|
||||
Card,
|
||||
GameConfig,
|
||||
GameState,
|
||||
Noble,
|
||||
PlayerState,
|
||||
ReserveCard,
|
||||
TakeDifferent,
|
||||
TakeDouble,
|
||||
create_random_cards,
|
||||
create_random_nobles,
|
||||
new_game,
|
||||
)
|
||||
from python.splendor.bot import RandomBot
|
||||
from python.splendor.human import (
|
||||
ActionApp,
|
||||
Board,
|
||||
DiscardApp,
|
||||
NobleChoiceApp,
|
||||
)
|
||||
|
||||
|
||||
def _make_game(num_players: int = 2):
|
||||
random.seed(42)
|
||||
bots = [RandomBot(f"bot{i}") for i in range(num_players)]
|
||||
cards = create_random_cards()
|
||||
nobles = create_random_nobles()
|
||||
config = GameConfig(cards=cards, nobles=nobles)
|
||||
game = new_game(bots, config)
|
||||
return game, bots
|
||||
|
||||
|
||||
# --- ActionApp command handlers ---
|
||||
|
||||
|
||||
def test_action_app_cmd_1_basic() -> None:
|
||||
"""Test _cmd_1 take different colors."""
|
||||
game, _ = _make_game()
|
||||
app = ActionApp(game, game.players[0])
|
||||
app._update_prompt = MagicMock()
|
||||
app.exit = MagicMock()
|
||||
result = app._cmd_1(["1", "white", "blue", "green"])
|
||||
assert result is None
|
||||
assert isinstance(app.result, TakeDifferent)
|
||||
assert app.result.colors == ["white", "blue", "green"]
|
||||
|
||||
|
||||
def test_action_app_cmd_1_abbreviations() -> None:
|
||||
"""Test _cmd_1 with abbreviated colors."""
|
||||
game, _ = _make_game()
|
||||
app = ActionApp(game, game.players[0])
|
||||
app.exit = MagicMock()
|
||||
result = app._cmd_1(["1", "w", "b", "g"])
|
||||
assert result is None
|
||||
assert isinstance(app.result, TakeDifferent)
|
||||
|
||||
|
||||
def test_action_app_cmd_1_no_colors() -> None:
|
||||
"""Test _cmd_1 with no colors."""
|
||||
game, _ = _make_game()
|
||||
app = ActionApp(game, game.players[0])
|
||||
result = app._cmd_1(["1"])
|
||||
assert result is not None # Error message
|
||||
|
||||
|
||||
def test_action_app_cmd_1_empty_bank() -> None:
|
||||
"""Test _cmd_1 with empty bank color."""
|
||||
game, _ = _make_game()
|
||||
game.bank["white"] = 0
|
||||
app = ActionApp(game, game.players[0])
|
||||
result = app._cmd_1(["1", "white"])
|
||||
assert result is not None # Error message
|
||||
|
||||
|
||||
def test_action_app_cmd_2() -> None:
|
||||
"""Test _cmd_2 take double."""
|
||||
game, _ = _make_game()
|
||||
app = ActionApp(game, game.players[0])
|
||||
app.exit = MagicMock()
|
||||
result = app._cmd_2(["2", "white"])
|
||||
assert result is None
|
||||
assert isinstance(app.result, TakeDouble)
|
||||
|
||||
|
||||
def test_action_app_cmd_2_no_color() -> None:
|
||||
"""Test _cmd_2 with no color."""
|
||||
game, _ = _make_game()
|
||||
app = ActionApp(game, game.players[0])
|
||||
result = app._cmd_2(["2"])
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_action_app_cmd_2_insufficient_bank() -> None:
|
||||
"""Test _cmd_2 with insufficient bank."""
|
||||
game, _ = _make_game()
|
||||
game.bank["white"] = 2
|
||||
app = ActionApp(game, game.players[0])
|
||||
result = app._cmd_2(["2", "white"])
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_action_app_cmd_3() -> None:
|
||||
"""Test _cmd_3 buy card."""
|
||||
game, _ = _make_game()
|
||||
app = ActionApp(game, game.players[0])
|
||||
app.exit = MagicMock()
|
||||
result = app._cmd_3(["3", "1", "0"])
|
||||
assert result is None
|
||||
assert isinstance(app.result, BuyCard)
|
||||
|
||||
|
||||
def test_action_app_cmd_3_no_args() -> None:
|
||||
"""Test _cmd_3 with insufficient args."""
|
||||
game, _ = _make_game()
|
||||
app = ActionApp(game, game.players[0])
|
||||
result = app._cmd_3(["3"])
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_action_app_cmd_4() -> None:
|
||||
"""Test _cmd_4 buy reserved card - source has bug passing tier= to BuyCardReserved."""
|
||||
game, _ = _make_game()
|
||||
card = Card(tier=1, points=0, color="white", cost=dict.fromkeys(GEM_COLORS, 0))
|
||||
game.players[0].reserved.append(card)
|
||||
app = ActionApp(game, game.players[0])
|
||||
app.exit = MagicMock()
|
||||
# BuyCardReserved doesn't accept tier=, so the source code has a bug here
|
||||
import pytest
|
||||
with pytest.raises(TypeError):
|
||||
app._cmd_4(["4", "0"])
|
||||
|
||||
|
||||
def test_action_app_cmd_4_no_args() -> None:
|
||||
"""Test _cmd_4 with no args."""
|
||||
game, _ = _make_game()
|
||||
app = ActionApp(game, game.players[0])
|
||||
result = app._cmd_4(["4"])
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_action_app_cmd_4_out_of_range() -> None:
|
||||
"""Test _cmd_4 with out of range index."""
|
||||
game, _ = _make_game()
|
||||
app = ActionApp(game, game.players[0])
|
||||
result = app._cmd_4(["4", "0"])
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_action_app_cmd_5() -> None:
|
||||
"""Test _cmd_5 reserve face-up card."""
|
||||
game, _ = _make_game()
|
||||
app = ActionApp(game, game.players[0])
|
||||
app.exit = MagicMock()
|
||||
result = app._cmd_5(["5", "1", "0"])
|
||||
assert result is None
|
||||
assert isinstance(app.result, ReserveCard)
|
||||
assert app.result.from_deck is False
|
||||
|
||||
|
||||
def test_action_app_cmd_5_no_args() -> None:
|
||||
"""Test _cmd_5 with no args."""
|
||||
game, _ = _make_game()
|
||||
app = ActionApp(game, game.players[0])
|
||||
result = app._cmd_5(["5"])
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_action_app_cmd_6() -> None:
|
||||
"""Test _cmd_6 reserve from deck."""
|
||||
game, _ = _make_game()
|
||||
app = ActionApp(game, game.players[0])
|
||||
app.exit = MagicMock()
|
||||
result = app._cmd_6(["6", "1"])
|
||||
assert result is None
|
||||
assert isinstance(app.result, ReserveCard)
|
||||
assert app.result.from_deck is True
|
||||
|
||||
|
||||
def test_action_app_cmd_6_no_args() -> None:
|
||||
"""Test _cmd_6 with no args."""
|
||||
game, _ = _make_game()
|
||||
app = ActionApp(game, game.players[0])
|
||||
result = app._cmd_6(["6"])
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_action_app_unknown_cmd() -> None:
|
||||
"""Test unknown command."""
|
||||
game, _ = _make_game()
|
||||
app = ActionApp(game, game.players[0])
|
||||
result = app._unknown_cmd(["99"])
|
||||
assert result == "Unknown command."
|
||||
|
||||
|
||||
# --- ActionApp init ---
|
||||
|
||||
|
||||
def test_action_app_init() -> None:
|
||||
"""Test ActionApp initialization."""
|
||||
game, _ = _make_game()
|
||||
app = ActionApp(game, game.players[0])
|
||||
assert app.result is None
|
||||
assert app.message == ""
|
||||
assert app.game is game
|
||||
assert app.player is game.players[0]
|
||||
|
||||
|
||||
# --- DiscardApp ---
|
||||
|
||||
|
||||
def test_discard_app_init() -> None:
|
||||
"""Test DiscardApp initialization."""
|
||||
game, _ = _make_game()
|
||||
app = DiscardApp(game, game.players[0])
|
||||
assert app.discards == dict.fromkeys(GEM_COLORS, 0)
|
||||
assert app.message == ""
|
||||
|
||||
|
||||
def test_discard_app_remaining_to_discard() -> None:
|
||||
"""Test DiscardApp._remaining_to_discard."""
|
||||
game, _ = _make_game()
|
||||
p = game.players[0]
|
||||
for c in BASE_COLORS:
|
||||
p.tokens[c] = 5
|
||||
app = DiscardApp(game, p)
|
||||
remaining = app._remaining_to_discard()
|
||||
assert remaining == p.total_tokens() - game.config.token_limit
|
||||
|
||||
|
||||
# --- NobleChoiceApp ---
|
||||
|
||||
|
||||
def test_noble_choice_app_init() -> None:
|
||||
"""Test NobleChoiceApp initialization."""
|
||||
game, _ = _make_game()
|
||||
nobles = game.available_nobles[:2]
|
||||
app = NobleChoiceApp(game, game.players[0], nobles)
|
||||
assert app.result is None
|
||||
assert app.nobles == nobles
|
||||
assert app.message == ""
|
||||
|
||||
|
||||
# --- Board ---
|
||||
|
||||
|
||||
def test_board_init() -> None:
|
||||
"""Test Board initialization."""
|
||||
game, _ = _make_game()
|
||||
board = Board(game, game.players[0])
|
||||
assert board.game is game
|
||||
assert board.me is game.players[0]
|
||||
@@ -1,54 +0,0 @@
|
||||
"""Tests for python/splendor/human.py TUI classes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
|
||||
from python.splendor.base import (
|
||||
GEM_COLORS,
|
||||
GameConfig,
|
||||
PlayerState,
|
||||
create_random_cards,
|
||||
create_random_nobles,
|
||||
new_game,
|
||||
)
|
||||
from python.splendor.bot import RandomBot
|
||||
from python.splendor.human import TuiHuman
|
||||
|
||||
|
||||
def _make_game(num_players: int = 2):
|
||||
bots = [RandomBot(f"bot{i}") for i in range(num_players)]
|
||||
cards = create_random_cards()
|
||||
nobles = create_random_nobles()
|
||||
config = GameConfig(cards=cards, nobles=nobles)
|
||||
game = new_game(bots, config)
|
||||
return game, bots
|
||||
|
||||
|
||||
def test_tui_human_choose_action_no_tty() -> None:
|
||||
"""Test TuiHuman returns None when not a TTY."""
|
||||
random.seed(42)
|
||||
game, _ = _make_game(2)
|
||||
human = TuiHuman("test")
|
||||
# In test environment, stdout is not a TTY
|
||||
result = human.choose_action(game, game.players[0])
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_tui_human_choose_discard_no_tty() -> None:
|
||||
"""Test TuiHuman returns empty discards when not a TTY."""
|
||||
random.seed(42)
|
||||
game, _ = _make_game(2)
|
||||
human = TuiHuman("test")
|
||||
result = human.choose_discard(game, game.players[0], 2)
|
||||
assert result == dict.fromkeys(GEM_COLORS, 0)
|
||||
|
||||
|
||||
def test_tui_human_choose_noble_no_tty() -> None:
|
||||
"""Test TuiHuman returns first noble when not a TTY."""
|
||||
random.seed(42)
|
||||
game, _ = _make_game(2)
|
||||
human = TuiHuman("test")
|
||||
nobles = game.available_nobles[:2]
|
||||
result = human.choose_noble(game, game.players[0], nobles)
|
||||
assert result == nobles[0]
|
||||
@@ -1,647 +0,0 @@
|
||||
"""Tests for splendor/human.py Textual widgets and TUI apps.
|
||||
|
||||
Covers Board (compose, on_mount, refresh_content, render methods),
|
||||
ActionApp/DiscardApp/NobleChoiceApp (compose, on_mount, _update_prompt,
|
||||
on_input_submitted), and TuiHuman tty paths.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
import sys
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from python.splendor.base import (
|
||||
BASE_COLORS,
|
||||
GEM_COLORS,
|
||||
Card,
|
||||
GameConfig,
|
||||
GameState,
|
||||
Noble,
|
||||
PlayerState,
|
||||
TakeDifferent,
|
||||
create_random_cards,
|
||||
create_random_nobles,
|
||||
new_game,
|
||||
)
|
||||
from python.splendor.bot import RandomBot
|
||||
from python.splendor.human import (
|
||||
ActionApp,
|
||||
Board,
|
||||
DiscardApp,
|
||||
NobleChoiceApp,
|
||||
TuiHuman,
|
||||
)
|
||||
|
||||
|
||||
def _make_game(num_players: int = 2):
|
||||
random.seed(42)
|
||||
bots = [RandomBot(f"bot{i}") for i in range(num_players)]
|
||||
cards = create_random_cards()
|
||||
nobles = create_random_nobles()
|
||||
config = GameConfig(cards=cards, nobles=nobles)
|
||||
game = new_game(bots, config)
|
||||
return game, bots
|
||||
|
||||
|
||||
def _patch_player_names(game: GameState) -> None:
|
||||
"""Add .name attribute to each PlayerState (delegates to strategy.name)."""
|
||||
for p in game.players:
|
||||
p.name = p.strategy.name # type: ignore[attr-defined]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Board widget tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_board_compose_and_mount() -> None:
|
||||
"""Board.compose yields expected widget tree; on_mount populates them."""
|
||||
game, _ = _make_game()
|
||||
_patch_player_names(game)
|
||||
app = ActionApp(game, game.players[0])
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
board = app.query_one(Board)
|
||||
assert board is not None
|
||||
|
||||
# Verify sub-widgets exist
|
||||
assert app.query_one("#bank_box") is not None
|
||||
assert app.query_one("#tier1_box") is not None
|
||||
assert app.query_one("#tier2_box") is not None
|
||||
assert app.query_one("#tier3_box") is not None
|
||||
assert app.query_one("#nobles_box") is not None
|
||||
assert app.query_one("#players_box") is not None
|
||||
|
||||
app.exit()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_board_render_bank() -> None:
|
||||
"""Board._render_bank writes bank info to bank_box."""
|
||||
game, _ = _make_game()
|
||||
_patch_player_names(game)
|
||||
app = ActionApp(game, game.players[0])
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
board = app.query_one(Board)
|
||||
board._render_bank()
|
||||
app.exit()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_board_render_tiers() -> None:
|
||||
"""Board._render_tiers populates tier boxes."""
|
||||
game, _ = _make_game()
|
||||
_patch_player_names(game)
|
||||
app = ActionApp(game, game.players[0])
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
board = app.query_one(Board)
|
||||
board._render_tiers()
|
||||
app.exit()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_board_render_tiers_empty() -> None:
|
||||
"""Board._render_tiers handles empty tiers."""
|
||||
game, _ = _make_game()
|
||||
_patch_player_names(game)
|
||||
for tier in game.table_by_tier:
|
||||
game.table_by_tier[tier] = []
|
||||
app = ActionApp(game, game.players[0])
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
board = app.query_one(Board)
|
||||
board._render_tiers()
|
||||
app.exit()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_board_render_nobles() -> None:
|
||||
"""Board._render_nobles shows noble info."""
|
||||
game, _ = _make_game()
|
||||
_patch_player_names(game)
|
||||
app = ActionApp(game, game.players[0])
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
board = app.query_one(Board)
|
||||
board._render_nobles()
|
||||
app.exit()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_board_render_nobles_empty() -> None:
|
||||
"""Board._render_nobles handles no nobles."""
|
||||
game, _ = _make_game()
|
||||
_patch_player_names(game)
|
||||
game.available_nobles = []
|
||||
app = ActionApp(game, game.players[0])
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
board = app.query_one(Board)
|
||||
board._render_nobles()
|
||||
app.exit()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_board_render_players() -> None:
|
||||
"""Board._render_players shows all player info."""
|
||||
game, _ = _make_game()
|
||||
_patch_player_names(game)
|
||||
app = ActionApp(game, game.players[0])
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
board = app.query_one(Board)
|
||||
board._render_players()
|
||||
app.exit()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_board_render_players_with_nobles_and_cards() -> None:
|
||||
"""Board._render_players handles players with nobles, cards, and reserved."""
|
||||
game, _ = _make_game()
|
||||
_patch_player_names(game)
|
||||
p = game.players[0]
|
||||
card = Card(
|
||||
tier=1, points=1, color="white", cost=dict.fromkeys(GEM_COLORS, 0),
|
||||
)
|
||||
p.cards.append(card)
|
||||
reserved = Card(
|
||||
tier=2, points=2, color="blue", cost=dict.fromkeys(GEM_COLORS, 0),
|
||||
)
|
||||
p.reserved.append(reserved)
|
||||
noble = Noble(
|
||||
name="TestNoble", points=3, requirements=dict.fromkeys(GEM_COLORS, 0),
|
||||
)
|
||||
p.nobles.append(noble)
|
||||
|
||||
app = ActionApp(game, p)
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
board = app.query_one(Board)
|
||||
board._render_players()
|
||||
app.exit()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_board_refresh_content() -> None:
|
||||
"""Board.refresh_content calls all render sub-methods."""
|
||||
game, _ = _make_game()
|
||||
_patch_player_names(game)
|
||||
app = ActionApp(game, game.players[0])
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
board = app.query_one(Board)
|
||||
board.refresh_content()
|
||||
app.exit()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ActionApp tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_app_compose_and_mount() -> None:
|
||||
"""ActionApp composes command_zone, board, footer and sets up prompt."""
|
||||
game, _ = _make_game()
|
||||
_patch_player_names(game)
|
||||
app = ActionApp(game, game.players[0])
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
from textual.widgets import Footer, Input, Static
|
||||
|
||||
assert app.query_one("#input_line", Input) is not None
|
||||
assert app.query_one("#prompt", Static) is not None
|
||||
assert app.query_one("#board", Board) is not None
|
||||
assert app.query_one(Footer) is not None
|
||||
|
||||
app.exit()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_app_update_prompt() -> None:
|
||||
"""ActionApp._update_prompt writes action menu to prompt widget."""
|
||||
game, _ = _make_game()
|
||||
_patch_player_names(game)
|
||||
app = ActionApp(game, game.players[0])
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
app._update_prompt()
|
||||
app.exit()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_app_update_prompt_with_message() -> None:
|
||||
"""ActionApp._update_prompt includes error message when set."""
|
||||
game, _ = _make_game()
|
||||
_patch_player_names(game)
|
||||
app = ActionApp(game, game.players[0])
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
app.message = "Some error occurred"
|
||||
app._update_prompt()
|
||||
app.exit()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_app_on_input_submitted_quit() -> None:
|
||||
"""ActionApp exits on 'q' input via pilot keyboard."""
|
||||
game, _ = _make_game()
|
||||
_patch_player_names(game)
|
||||
app = ActionApp(game, game.players[0])
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
await pilot.press("q", "enter")
|
||||
await pilot.pause()
|
||||
assert app.result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_app_on_input_submitted_quit_word() -> None:
|
||||
"""ActionApp exits on 'quit' input."""
|
||||
game, _ = _make_game()
|
||||
_patch_player_names(game)
|
||||
app = ActionApp(game, game.players[0])
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
await pilot.press("q", "u", "i", "t", "enter")
|
||||
await pilot.pause()
|
||||
assert app.result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_app_on_input_submitted_zero() -> None:
|
||||
"""ActionApp exits on '0' input."""
|
||||
game, _ = _make_game()
|
||||
_patch_player_names(game)
|
||||
app = ActionApp(game, game.players[0])
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
await pilot.press("0", "enter")
|
||||
await pilot.pause()
|
||||
assert app.result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_app_on_input_submitted_empty() -> None:
|
||||
"""ActionApp ignores empty input."""
|
||||
game, _ = _make_game()
|
||||
_patch_player_names(game)
|
||||
app = ActionApp(game, game.players[0])
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
await pilot.press("enter")
|
||||
await pilot.pause()
|
||||
assert app.result is None
|
||||
app.exit()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_app_on_input_submitted_valid_cmd() -> None:
|
||||
"""ActionApp processes valid command '1 w b g' and exits."""
|
||||
game, _ = _make_game()
|
||||
_patch_player_names(game)
|
||||
app = ActionApp(game, game.players[0])
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
for ch in "1 w b g":
|
||||
await pilot.press(ch)
|
||||
await pilot.press("enter")
|
||||
await pilot.pause()
|
||||
assert isinstance(app.result, TakeDifferent)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_app_on_input_submitted_error() -> None:
|
||||
"""ActionApp shows error message for bad command."""
|
||||
game, _ = _make_game()
|
||||
_patch_player_names(game)
|
||||
app = ActionApp(game, game.players[0])
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
for ch in "xyz":
|
||||
await pilot.press(ch)
|
||||
await pilot.press("enter")
|
||||
await pilot.pause()
|
||||
assert app.message == "Unknown command."
|
||||
app.exit()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_app_on_input_submitted_cmd_error() -> None:
|
||||
"""ActionApp shows error from a valid command number but bad args."""
|
||||
game, _ = _make_game()
|
||||
_patch_player_names(game)
|
||||
app = ActionApp(game, game.players[0])
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
await pilot.press("1", "enter")
|
||||
await pilot.pause()
|
||||
assert app.message != ""
|
||||
app.exit()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DiscardApp tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_discard_game(excess: int = 1):
|
||||
"""Create a game where player 0 has excess tokens over the limit."""
|
||||
game, _bots = _make_game()
|
||||
_patch_player_names(game)
|
||||
p = game.players[0]
|
||||
for c in GEM_COLORS:
|
||||
p.tokens[c] = 0
|
||||
p.tokens["white"] = game.config.token_limit + excess
|
||||
return game, p
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discard_app_compose_and_mount() -> None:
|
||||
"""DiscardApp composes header, command_zone, board, footer."""
|
||||
game, p = _make_discard_game(2)
|
||||
app = DiscardApp(game, p)
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
from textual.widgets import Footer, Header, Input, Static
|
||||
|
||||
assert app.query_one(Header) is not None
|
||||
assert app.query_one("#input_line", Input) is not None
|
||||
assert app.query_one("#prompt", Static) is not None
|
||||
assert app.query_one("#board", Board) is not None
|
||||
assert app.query_one(Footer) is not None
|
||||
|
||||
app.exit()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discard_app_update_prompt() -> None:
|
||||
"""DiscardApp._update_prompt shows remaining discards info."""
|
||||
game, p = _make_discard_game(2)
|
||||
app = DiscardApp(game, p)
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
app._update_prompt()
|
||||
app.exit()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discard_app_update_prompt_with_message() -> None:
|
||||
"""DiscardApp._update_prompt includes error message."""
|
||||
game, p = _make_discard_game(2)
|
||||
app = DiscardApp(game, p)
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
app.message = "No more blue tokens"
|
||||
app._update_prompt()
|
||||
app.exit()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discard_app_on_input_submitted_empty() -> None:
|
||||
"""DiscardApp ignores empty input."""
|
||||
game, p = _make_discard_game(2)
|
||||
app = DiscardApp(game, p)
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
await pilot.press("enter")
|
||||
await pilot.pause()
|
||||
assert all(v == 0 for v in app.discards.values())
|
||||
app.exit()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discard_app_on_input_submitted_unknown_color() -> None:
|
||||
"""DiscardApp shows error for unknown color."""
|
||||
game, p = _make_discard_game(2)
|
||||
app = DiscardApp(game, p)
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
for ch in "purple":
|
||||
await pilot.press(ch)
|
||||
await pilot.press("enter")
|
||||
await pilot.pause()
|
||||
assert "Unknown color" in app.message
|
||||
app.exit()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discard_app_on_input_submitted_no_tokens() -> None:
|
||||
"""DiscardApp shows error when no tokens of that color available."""
|
||||
game, p = _make_discard_game(2)
|
||||
app = DiscardApp(game, p)
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
for ch in "blue":
|
||||
await pilot.press(ch)
|
||||
await pilot.press("enter")
|
||||
await pilot.pause()
|
||||
assert "No more" in app.message
|
||||
app.exit()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discard_app_on_input_submitted_valid_finishes() -> None:
|
||||
"""DiscardApp increments discard and exits when done (excess=1)."""
|
||||
game, p = _make_discard_game(excess=1)
|
||||
app = DiscardApp(game, p)
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
await pilot.press("w", "enter")
|
||||
await pilot.pause()
|
||||
assert app.discards["white"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discard_app_on_input_submitted_not_done_yet() -> None:
|
||||
"""DiscardApp stays open when more discards still needed (excess=2)."""
|
||||
game, p = _make_discard_game(excess=2)
|
||||
app = DiscardApp(game, p)
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
await pilot.press("w", "enter")
|
||||
await pilot.pause()
|
||||
assert app.discards["white"] == 1
|
||||
assert app.message == ""
|
||||
|
||||
await pilot.press("w", "enter")
|
||||
await pilot.pause()
|
||||
assert app.discards["white"] == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# NobleChoiceApp tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_noble_choice_app_compose_and_mount() -> None:
|
||||
"""NobleChoiceApp composes header, command_zone, board, footer."""
|
||||
game, _ = _make_game()
|
||||
_patch_player_names(game)
|
||||
nobles = game.available_nobles[:2]
|
||||
app = NobleChoiceApp(game, game.players[0], nobles)
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
from textual.widgets import Footer, Header, Input, Static
|
||||
|
||||
assert app.query_one(Header) is not None
|
||||
assert app.query_one("#input_line", Input) is not None
|
||||
assert app.query_one("#prompt", Static) is not None
|
||||
assert app.query_one("#board", Board) is not None
|
||||
assert app.query_one(Footer) is not None
|
||||
|
||||
app.exit()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_noble_choice_app_update_prompt() -> None:
|
||||
"""NobleChoiceApp._update_prompt lists available nobles."""
|
||||
game, _ = _make_game()
|
||||
_patch_player_names(game)
|
||||
nobles = game.available_nobles[:2]
|
||||
app = NobleChoiceApp(game, game.players[0], nobles)
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
app._update_prompt()
|
||||
app.exit()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_noble_choice_app_update_prompt_with_message() -> None:
|
||||
"""NobleChoiceApp._update_prompt includes error message."""
|
||||
game, _ = _make_game()
|
||||
_patch_player_names(game)
|
||||
nobles = game.available_nobles[:2]
|
||||
app = NobleChoiceApp(game, game.players[0], nobles)
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
app.message = "Index out of range."
|
||||
app._update_prompt()
|
||||
app.exit()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_noble_choice_app_on_input_submitted_empty() -> None:
|
||||
"""NobleChoiceApp ignores empty input."""
|
||||
game, _ = _make_game()
|
||||
_patch_player_names(game)
|
||||
nobles = game.available_nobles[:2]
|
||||
app = NobleChoiceApp(game, game.players[0], nobles)
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
await pilot.press("enter")
|
||||
await pilot.pause()
|
||||
assert app.result is None
|
||||
app.exit()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_noble_choice_app_on_input_submitted_not_int() -> None:
|
||||
"""NobleChoiceApp shows error for non-integer input."""
|
||||
game, _ = _make_game()
|
||||
_patch_player_names(game)
|
||||
nobles = game.available_nobles[:2]
|
||||
app = NobleChoiceApp(game, game.players[0], nobles)
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
for ch in "abc":
|
||||
await pilot.press(ch)
|
||||
await pilot.press("enter")
|
||||
await pilot.pause()
|
||||
assert "valid integer" in app.message
|
||||
app.exit()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_noble_choice_app_on_input_submitted_out_of_range() -> None:
|
||||
"""NobleChoiceApp shows error for index out of range."""
|
||||
game, _ = _make_game()
|
||||
_patch_player_names(game)
|
||||
nobles = game.available_nobles[:2]
|
||||
app = NobleChoiceApp(game, game.players[0], nobles)
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
await pilot.press("9", "enter")
|
||||
await pilot.pause()
|
||||
assert "out of range" in app.message.lower()
|
||||
app.exit()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_noble_choice_app_on_input_submitted_valid() -> None:
|
||||
"""NobleChoiceApp selects noble and exits on valid index."""
|
||||
game, _ = _make_game()
|
||||
_patch_player_names(game)
|
||||
nobles = game.available_nobles[:2]
|
||||
app = NobleChoiceApp(game, game.players[0], nobles)
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
await pilot.press("0", "enter")
|
||||
await pilot.pause()
|
||||
assert app.result is nobles[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_noble_choice_app_on_input_submitted_second_noble() -> None:
|
||||
"""NobleChoiceApp selects second noble."""
|
||||
game, _ = _make_game()
|
||||
_patch_player_names(game)
|
||||
nobles = game.available_nobles[:2]
|
||||
app = NobleChoiceApp(game, game.players[0], nobles)
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
await pilot.press("1", "enter")
|
||||
await pilot.pause()
|
||||
assert app.result is nobles[1]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TuiHuman tty path tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_tui_human_choose_action_tty() -> None:
|
||||
"""TuiHuman.choose_action runs ActionApp when stdout is a tty."""
|
||||
random.seed(42)
|
||||
game, _ = _make_game()
|
||||
human = TuiHuman("test")
|
||||
|
||||
with patch.object(sys.stdout, "isatty", return_value=True):
|
||||
with patch.object(ActionApp, "run") as mock_run:
|
||||
result = human.choose_action(game, game.players[0])
|
||||
mock_run.assert_called_once()
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_tui_human_choose_discard_tty() -> None:
|
||||
"""TuiHuman.choose_discard runs DiscardApp when stdout is a tty."""
|
||||
random.seed(42)
|
||||
game, _ = _make_game()
|
||||
human = TuiHuman("test")
|
||||
|
||||
with patch.object(sys.stdout, "isatty", return_value=True):
|
||||
with patch.object(DiscardApp, "run") as mock_run:
|
||||
result = human.choose_discard(game, game.players[0], 2)
|
||||
mock_run.assert_called_once()
|
||||
assert result == dict.fromkeys(GEM_COLORS, 0)
|
||||
|
||||
|
||||
def test_tui_human_choose_noble_tty() -> None:
|
||||
"""TuiHuman.choose_noble runs NobleChoiceApp when stdout is a tty."""
|
||||
random.seed(42)
|
||||
game, _ = _make_game()
|
||||
nobles = game.available_nobles[:2]
|
||||
human = TuiHuman("test")
|
||||
|
||||
with patch.object(sys.stdout, "isatty", return_value=True):
|
||||
with patch.object(NobleChoiceApp, "run") as mock_run:
|
||||
result = human.choose_noble(game, game.players[0], nobles)
|
||||
mock_run.assert_called_once()
|
||||
assert result is None
|
||||
@@ -1,35 +0,0 @@
|
||||
"""Tests for python/splendor/main.py."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def test_splendor_main_import() -> None:
|
||||
"""Test that splendor main module can be imported."""
|
||||
from python.splendor.main import main
|
||||
assert callable(main)
|
||||
|
||||
|
||||
def test_splendor_main_calls_run_game() -> None:
|
||||
"""Test main creates human + bot and runs game."""
|
||||
# main() uses wrong signature for new_game (passes strings instead of strategies)
|
||||
# so we just verify it can be called with mocked internals
|
||||
with (
|
||||
patch("python.splendor.main.TuiHuman") as mock_tui,
|
||||
patch("python.splendor.main.RandomBot") as mock_bot,
|
||||
patch("python.splendor.main.new_game") as mock_new_game,
|
||||
patch("python.splendor.main.run_game") as mock_run_game,
|
||||
):
|
||||
mock_tui.return_value = MagicMock()
|
||||
mock_bot.return_value = MagicMock()
|
||||
mock_new_game.return_value = MagicMock()
|
||||
mock_run_game.return_value = (MagicMock(), 10)
|
||||
|
||||
from python.splendor.main import main
|
||||
main()
|
||||
|
||||
mock_new_game.assert_called_once()
|
||||
mock_run_game.assert_called_once()
|
||||
@@ -1,88 +0,0 @@
|
||||
"""Tests for python/splendor/simulat.py."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import random
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from python.splendor.base import load_cards, load_nobles
|
||||
from python.splendor.simulat import main
|
||||
|
||||
|
||||
def test_simulat_main(tmp_path: Path) -> None:
|
||||
"""Test simulat main function with mock game data."""
|
||||
random.seed(42)
|
||||
|
||||
# Create temporary game data
|
||||
cards_dir = tmp_path / "game_data" / "cards"
|
||||
nobles_dir = tmp_path / "game_data" / "nobles"
|
||||
cards_dir.mkdir(parents=True)
|
||||
nobles_dir.mkdir(parents=True)
|
||||
|
||||
cards = []
|
||||
for tier in (1, 2, 3):
|
||||
for color in ("white", "blue", "green", "red", "black"):
|
||||
cards.append({
|
||||
"tier": tier,
|
||||
"points": tier,
|
||||
"color": color,
|
||||
"cost": {"white": tier, "blue": 0, "green": 0, "red": 0, "black": 0, "gold": 0},
|
||||
})
|
||||
(cards_dir / "default.json").write_text(json.dumps(cards))
|
||||
|
||||
nobles = [
|
||||
{"name": f"Noble {i}", "points": 3, "requirements": {"white": 3, "blue": 3, "green": 3}}
|
||||
for i in range(5)
|
||||
]
|
||||
(nobles_dir / "default.json").write_text(json.dumps(nobles))
|
||||
|
||||
# Patch Path(__file__).parent to point to tmp_path
|
||||
fake_parent = tmp_path
|
||||
with patch("python.splendor.simulat.Path") as mock_path_cls:
|
||||
mock_path_cls.return_value.__truediv__ = Path.__truediv__
|
||||
mock_file = mock_path_cls().__truediv__("simulat.py")
|
||||
# Make Path(__file__).parent return tmp_path
|
||||
mock_path_cls.reset_mock()
|
||||
mock_path_instance = mock_path_cls.return_value
|
||||
mock_path_instance.parent = fake_parent
|
||||
|
||||
# Actually just patch load_cards and load_nobles
|
||||
cards_data = load_cards(cards_dir / "default.json")
|
||||
nobles_data = load_nobles(nobles_dir / "default.json")
|
||||
|
||||
with (
|
||||
patch("python.splendor.simulat.load_cards", return_value=cards_data),
|
||||
patch("python.splendor.simulat.load_nobles", return_value=nobles_data),
|
||||
):
|
||||
main()
|
||||
|
||||
|
||||
def test_load_cards_and_nobles(tmp_path: Path) -> None:
|
||||
"""Test that load_cards and load_nobles work correctly."""
|
||||
cards_dir = tmp_path / "cards"
|
||||
cards_dir.mkdir()
|
||||
|
||||
cards = [
|
||||
{
|
||||
"tier": 1,
|
||||
"points": 0,
|
||||
"color": "white",
|
||||
"cost": {"white": 1, "blue": 0, "green": 0, "red": 0, "black": 0, "gold": 0},
|
||||
}
|
||||
]
|
||||
cards_file = cards_dir / "default.json"
|
||||
cards_file.write_text(json.dumps(cards))
|
||||
loaded = load_cards(cards_file)
|
||||
assert len(loaded) == 1
|
||||
assert loaded[0].color == "white"
|
||||
|
||||
nobles_dir = tmp_path / "nobles"
|
||||
nobles_dir.mkdir()
|
||||
nobles = [{"name": "Noble A", "points": 3, "requirements": {"white": 3}}]
|
||||
nobles_file = nobles_dir / "default.json"
|
||||
nobles_file.write_text(json.dumps(nobles))
|
||||
loaded_nobles = load_nobles(nobles_file)
|
||||
assert len(loaded_nobles) == 1
|
||||
assert loaded_nobles[0].name == "Noble A"
|
||||
@@ -1,162 +0,0 @@
|
||||
"""Tests for python/stuff modules."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from python.stuff.capasitor import (
|
||||
calculate_capacitor_capacity,
|
||||
calculate_pack_capacity,
|
||||
calculate_pack_capacity2,
|
||||
)
|
||||
from python.stuff.voltage_drop import (
|
||||
Length,
|
||||
LengthUnit,
|
||||
MaterialType,
|
||||
Temperature,
|
||||
TemperatureUnit,
|
||||
calculate_awg_diameter_mm,
|
||||
calculate_resistance_per_meter,
|
||||
calculate_wire_area_m2,
|
||||
get_material_resistivity,
|
||||
max_wire_length,
|
||||
voltage_drop,
|
||||
)
|
||||
|
||||
|
||||
# --- capasitor tests ---
|
||||
|
||||
|
||||
def test_calculate_capacitor_capacity() -> None:
|
||||
"""Test capacitor capacity calculation."""
|
||||
result = calculate_capacitor_capacity(voltage=2.7, farads=500)
|
||||
assert isinstance(result, float)
|
||||
|
||||
|
||||
def test_calculate_pack_capacity() -> None:
|
||||
"""Test pack capacity calculation."""
|
||||
result = calculate_pack_capacity(cells=10, cell_voltage=2.7, farads=500)
|
||||
assert isinstance(result, float)
|
||||
|
||||
|
||||
def test_calculate_pack_capacity2() -> None:
|
||||
"""Test pack capacity2 calculation returns capacity and cost."""
|
||||
capacity, cost = calculate_pack_capacity2(cells=10, cell_voltage=2.7, farads=3000, cell_cost=11.60)
|
||||
assert isinstance(capacity, float)
|
||||
assert cost == 11.60 * 10
|
||||
|
||||
|
||||
# --- voltage_drop tests ---
|
||||
|
||||
|
||||
def test_temperature_celsius() -> None:
|
||||
"""Test Temperature with celsius."""
|
||||
t = Temperature(20.0, TemperatureUnit.CELSIUS)
|
||||
assert float(t) == 20.0
|
||||
|
||||
|
||||
def test_temperature_fahrenheit() -> None:
|
||||
"""Test Temperature with fahrenheit."""
|
||||
t = Temperature(100.0, TemperatureUnit.FAHRENHEIT)
|
||||
assert isinstance(float(t), float)
|
||||
|
||||
|
||||
def test_temperature_kelvin() -> None:
|
||||
"""Test Temperature with kelvin."""
|
||||
t = Temperature(300.0, TemperatureUnit.KELVIN)
|
||||
assert isinstance(float(t), float)
|
||||
|
||||
|
||||
def test_temperature_default_unit() -> None:
|
||||
"""Test Temperature defaults to celsius."""
|
||||
t = Temperature(25.0)
|
||||
assert float(t) == 25.0
|
||||
|
||||
|
||||
def test_length_meters() -> None:
|
||||
"""Test Length in meters."""
|
||||
length = Length(10.0, LengthUnit.METERS)
|
||||
assert float(length) == 10.0
|
||||
|
||||
|
||||
def test_length_feet() -> None:
|
||||
"""Test Length in feet."""
|
||||
length = Length(10.0, LengthUnit.FEET)
|
||||
assert abs(float(length) - 3.048) < 0.001
|
||||
|
||||
|
||||
def test_length_inches() -> None:
|
||||
"""Test Length in inches."""
|
||||
length = Length(100.0, LengthUnit.INCHES)
|
||||
assert abs(float(length) - 2.54) < 0.001
|
||||
|
||||
|
||||
def test_length_feet_method() -> None:
|
||||
"""Test Length.feet() conversion."""
|
||||
length = Length(1.0, LengthUnit.METERS)
|
||||
assert abs(length.feet() - 3.2808) < 0.001
|
||||
|
||||
|
||||
def test_get_material_resistivity_default_temp() -> None:
|
||||
"""Test material resistivity with default temperature."""
|
||||
r = get_material_resistivity(MaterialType.COPPER)
|
||||
assert r > 0
|
||||
|
||||
|
||||
def test_get_material_resistivity_with_temp() -> None:
|
||||
"""Test material resistivity with explicit temperature."""
|
||||
r = get_material_resistivity(MaterialType.ALUMINUM, Temperature(50.0))
|
||||
assert r > 0
|
||||
|
||||
|
||||
def test_get_material_resistivity_all_materials() -> None:
|
||||
"""Test resistivity for all materials."""
|
||||
for material in MaterialType:
|
||||
r = get_material_resistivity(material)
|
||||
assert r > 0
|
||||
|
||||
|
||||
def test_calculate_awg_diameter_mm() -> None:
|
||||
"""Test AWG diameter calculation."""
|
||||
d = calculate_awg_diameter_mm(10)
|
||||
assert d > 0
|
||||
|
||||
|
||||
def test_calculate_wire_area_m2() -> None:
|
||||
"""Test wire area calculation."""
|
||||
area = calculate_wire_area_m2(10)
|
||||
assert area > 0
|
||||
|
||||
|
||||
def test_calculate_resistance_per_meter() -> None:
|
||||
"""Test resistance per meter calculation."""
|
||||
r = calculate_resistance_per_meter(10)
|
||||
assert r > 0
|
||||
|
||||
|
||||
def test_voltage_drop_calculation() -> None:
|
||||
"""Test voltage drop calculation."""
|
||||
vd = voltage_drop(
|
||||
gauge=10,
|
||||
material=MaterialType.CCA,
|
||||
length=Length(20, LengthUnit.FEET),
|
||||
current_a=20,
|
||||
)
|
||||
assert vd > 0
|
||||
|
||||
|
||||
def test_max_wire_length_default_temp() -> None:
|
||||
"""Test max wire length with default temperature."""
|
||||
result = max_wire_length(gauge=10, material=MaterialType.CCA, current_amps=20)
|
||||
assert float(result) > 0
|
||||
assert result.feet() > 0
|
||||
|
||||
|
||||
def test_max_wire_length_with_temp() -> None:
|
||||
"""Test max wire length with explicit temperature."""
|
||||
result = max_wire_length(
|
||||
gauge=10,
|
||||
material=MaterialType.COPPER,
|
||||
current_amps=10,
|
||||
voltage_drop=0.5,
|
||||
temperature=Temperature(30.0),
|
||||
)
|
||||
assert float(result) > 0
|
||||
@@ -1,10 +0,0 @@
|
||||
"""Tests for capasitor main function."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from python.stuff.capasitor import main
|
||||
|
||||
|
||||
def test_capasitor_main(capsys: object) -> None:
|
||||
"""Test capasitor main function runs."""
|
||||
main()
|
||||
@@ -1,17 +0,0 @@
|
||||
"""Tests for python/stuff/thing.py."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from python.stuff.thing import caculat_batry_specs
|
||||
|
||||
|
||||
def test_caculat_batry_specs() -> None:
|
||||
"""Test battery specs calculation."""
|
||||
capacity, voltage = caculat_batry_specs(
|
||||
cell_amp_hour=300,
|
||||
cell_voltage=3.2,
|
||||
cells_per_pack=8,
|
||||
packs=2,
|
||||
)
|
||||
assert voltage == 3.2 * 8
|
||||
assert capacity == voltage * 300 * 2
|
||||
@@ -1,38 +0,0 @@
|
||||
"""Tests for python/testing/logging modules."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from python.testing.logging.bar import bar
|
||||
from python.testing.logging.configure_logger import configure_logger
|
||||
from python.testing.logging.foo import foo
|
||||
from python.testing.logging.main import main
|
||||
|
||||
|
||||
def test_bar() -> None:
|
||||
"""Test bar function."""
|
||||
bar()
|
||||
|
||||
|
||||
def test_configure_logger_default() -> None:
|
||||
"""Test configure_logger with default level."""
|
||||
configure_logger()
|
||||
|
||||
|
||||
def test_configure_logger_debug() -> None:
|
||||
"""Test configure_logger with debug level."""
|
||||
configure_logger("DEBUG")
|
||||
|
||||
|
||||
def test_configure_logger_with_test() -> None:
|
||||
"""Test configure_logger with test name."""
|
||||
configure_logger("INFO", "TEST")
|
||||
|
||||
|
||||
def test_foo() -> None:
|
||||
"""Test foo function."""
|
||||
foo()
|
||||
|
||||
|
||||
def test_main() -> None:
|
||||
"""Test main function."""
|
||||
main()
|
||||
@@ -1,265 +0,0 @@
|
||||
"""Tests for python/van_weather modules."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from python.van_weather.models import Config, DailyForecast, HourlyForecast, Weather
|
||||
from python.van_weather.main import (
|
||||
CONDITION_MAP,
|
||||
fetch_weather,
|
||||
get_ha_state,
|
||||
parse_daily_forecast,
|
||||
parse_hourly_forecast,
|
||||
post_to_ha,
|
||||
update_weather,
|
||||
_post_weather_data,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
# --- models tests ---
|
||||
|
||||
|
||||
def test_config() -> None:
|
||||
"""Test Config creation."""
|
||||
config = Config(ha_url="http://ha.local", ha_token="token123", pirate_weather_api_key="key123")
|
||||
assert config.ha_url == "http://ha.local"
|
||||
assert config.lat_entity == "sensor.gps_latitude"
|
||||
assert config.lon_entity == "sensor.gps_longitude"
|
||||
assert config.mask_decimals == 1
|
||||
|
||||
|
||||
def test_daily_forecast() -> None:
|
||||
"""Test DailyForecast creation and serialization."""
|
||||
dt = datetime(2024, 1, 1, tzinfo=UTC)
|
||||
forecast = DailyForecast(
|
||||
date_time=dt,
|
||||
condition="sunny",
|
||||
temperature=75.0,
|
||||
templow=55.0,
|
||||
precipitation_probability=0.1,
|
||||
)
|
||||
assert forecast.condition == "sunny"
|
||||
serialized = forecast.model_dump()
|
||||
assert serialized["date_time"] == dt.isoformat()
|
||||
|
||||
|
||||
def test_hourly_forecast() -> None:
|
||||
"""Test HourlyForecast creation and serialization."""
|
||||
dt = datetime(2024, 1, 1, 12, 0, tzinfo=UTC)
|
||||
forecast = HourlyForecast(
|
||||
date_time=dt,
|
||||
condition="cloudy",
|
||||
temperature=65.0,
|
||||
precipitation_probability=0.3,
|
||||
)
|
||||
assert forecast.temperature == 65.0
|
||||
serialized = forecast.model_dump()
|
||||
assert serialized["date_time"] == dt.isoformat()
|
||||
|
||||
|
||||
def test_weather_defaults() -> None:
|
||||
"""Test Weather default values."""
|
||||
weather = Weather()
|
||||
assert weather.temperature is None
|
||||
assert weather.daily_forecasts == []
|
||||
assert weather.hourly_forecasts == []
|
||||
|
||||
|
||||
def test_weather_full() -> None:
|
||||
"""Test Weather with all fields."""
|
||||
weather = Weather(
|
||||
temperature=72.0,
|
||||
feels_like=70.0,
|
||||
humidity=0.5,
|
||||
wind_speed=10.0,
|
||||
wind_bearing=180.0,
|
||||
condition="sunny",
|
||||
summary="Clear",
|
||||
pressure=1013.0,
|
||||
visibility=10.0,
|
||||
)
|
||||
assert weather.temperature == 72.0
|
||||
assert weather.condition == "sunny"
|
||||
|
||||
|
||||
# --- main tests ---
|
||||
|
||||
|
||||
def test_condition_map() -> None:
|
||||
"""Test CONDITION_MAP has expected entries."""
|
||||
assert CONDITION_MAP["clear-day"] == "sunny"
|
||||
assert CONDITION_MAP["rain"] == "rainy"
|
||||
assert CONDITION_MAP["snow"] == "snowy"
|
||||
|
||||
|
||||
def test_get_ha_state() -> None:
|
||||
"""Test get_ha_state."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"state": "45.123"}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
|
||||
with patch("python.van_weather.main.requests.get", return_value=mock_response) as mock_get:
|
||||
result = get_ha_state("http://ha.local", "token", "sensor.lat")
|
||||
|
||||
assert result == 45.123
|
||||
mock_get.assert_called_once()
|
||||
|
||||
|
||||
def test_parse_daily_forecast() -> None:
|
||||
"""Test parse_daily_forecast."""
|
||||
data = {
|
||||
"daily": {
|
||||
"data": [
|
||||
{
|
||||
"time": 1704067200,
|
||||
"icon": "clear-day",
|
||||
"temperatureHigh": 75.0,
|
||||
"temperatureLow": 55.0,
|
||||
"precipProbability": 0.1,
|
||||
},
|
||||
{
|
||||
"time": 1704153600,
|
||||
"icon": "rain",
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
result = parse_daily_forecast(data)
|
||||
assert len(result) == 2
|
||||
assert result[0].condition == "sunny"
|
||||
assert result[0].temperature == 75.0
|
||||
|
||||
|
||||
def test_parse_daily_forecast_empty() -> None:
|
||||
"""Test parse_daily_forecast with empty data."""
|
||||
result = parse_daily_forecast({})
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_parse_daily_forecast_no_timestamp() -> None:
|
||||
"""Test parse_daily_forecast skips entries without time."""
|
||||
data = {"daily": {"data": [{"icon": "clear-day"}]}}
|
||||
result = parse_daily_forecast(data)
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_parse_hourly_forecast() -> None:
|
||||
"""Test parse_hourly_forecast."""
|
||||
data = {
|
||||
"hourly": {
|
||||
"data": [
|
||||
{
|
||||
"time": 1704067200,
|
||||
"icon": "cloudy",
|
||||
"temperature": 65.0,
|
||||
"precipProbability": 0.3,
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
result = parse_hourly_forecast(data)
|
||||
assert len(result) == 1
|
||||
assert result[0].condition == "cloudy"
|
||||
|
||||
|
||||
def test_parse_hourly_forecast_empty() -> None:
|
||||
"""Test parse_hourly_forecast with empty data."""
|
||||
result = parse_hourly_forecast({})
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_parse_hourly_forecast_no_timestamp() -> None:
|
||||
"""Test parse_hourly_forecast skips entries without time."""
|
||||
data = {"hourly": {"data": [{"icon": "rain"}]}}
|
||||
result = parse_hourly_forecast(data)
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_fetch_weather() -> None:
|
||||
"""Test fetch_weather."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"currently": {
|
||||
"temperature": 72.0,
|
||||
"apparentTemperature": 70.0,
|
||||
"humidity": 0.5,
|
||||
"windSpeed": 10.0,
|
||||
"windBearing": 180.0,
|
||||
"icon": "clear-day",
|
||||
"summary": "Clear",
|
||||
"pressure": 1013.0,
|
||||
"visibility": 10.0,
|
||||
},
|
||||
"daily": {"data": []},
|
||||
"hourly": {"data": []},
|
||||
}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
|
||||
with patch("python.van_weather.main.requests.get", return_value=mock_response):
|
||||
weather = fetch_weather("apikey", 45.0, -122.0)
|
||||
|
||||
assert weather.temperature == 72.0
|
||||
assert weather.condition == "sunny"
|
||||
|
||||
|
||||
def test_post_weather_data() -> None:
|
||||
"""Test _post_weather_data."""
|
||||
weather = Weather(
|
||||
temperature=72.0,
|
||||
feels_like=70.0,
|
||||
humidity=0.5,
|
||||
wind_speed=10.0,
|
||||
wind_bearing=180.0,
|
||||
condition="sunny",
|
||||
pressure=1013.0,
|
||||
visibility=10.0,
|
||||
)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status.return_value = None
|
||||
|
||||
with patch("python.van_weather.main.requests.post", return_value=mock_response) as mock_post:
|
||||
_post_weather_data("http://ha.local", "token", weather)
|
||||
|
||||
assert mock_post.call_count > 0
|
||||
|
||||
|
||||
def test_post_to_ha_retry_on_failure() -> None:
|
||||
"""Test post_to_ha retries on failure."""
|
||||
weather = Weather(temperature=72.0)
|
||||
|
||||
import requests
|
||||
|
||||
with (
|
||||
patch("python.van_weather.main._post_weather_data", side_effect=requests.RequestException("fail")),
|
||||
patch("python.van_weather.main.time.sleep"),
|
||||
):
|
||||
post_to_ha("http://ha.local", "token", weather)
|
||||
|
||||
|
||||
def test_post_to_ha_success() -> None:
|
||||
"""Test post_to_ha calls _post_weather_data on each attempt."""
|
||||
weather = Weather(temperature=72.0)
|
||||
|
||||
with patch("python.van_weather.main._post_weather_data") as mock_post:
|
||||
post_to_ha("http://ha.local", "token", weather)
|
||||
# The function loops through all attempts even on success (no break)
|
||||
assert mock_post.call_count == 6
|
||||
|
||||
|
||||
def test_update_weather() -> None:
|
||||
"""Test update_weather orchestration."""
|
||||
config = Config(ha_url="http://ha.local", ha_token="token", pirate_weather_api_key="key")
|
||||
|
||||
with (
|
||||
patch("python.van_weather.main.get_ha_state", side_effect=[45.123, -122.456]),
|
||||
patch("python.van_weather.main.fetch_weather", return_value=Weather(temperature=72.0, condition="sunny")),
|
||||
patch("python.van_weather.main.post_to_ha"),
|
||||
):
|
||||
update_weather(config)
|
||||
@@ -1,28 +0,0 @@
|
||||
"""Tests for van_weather/main.py main() function."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from python.van_weather.main import main
|
||||
|
||||
|
||||
def test_van_weather_main() -> None:
|
||||
"""Test main sets up scheduler."""
|
||||
with (
|
||||
patch("python.van_weather.main.BlockingScheduler") as mock_sched_cls,
|
||||
patch("python.van_weather.main.configure_logger"),
|
||||
):
|
||||
mock_sched = MagicMock()
|
||||
mock_sched_cls.return_value = mock_sched
|
||||
|
||||
main(
|
||||
ha_url="http://ha.local",
|
||||
ha_token="token",
|
||||
api_key="key",
|
||||
interval=60,
|
||||
log_level="INFO",
|
||||
)
|
||||
|
||||
mock_sched.add_job.assert_called_once()
|
||||
mock_sched.start.assert_called_once()
|
||||
@@ -1,361 +0,0 @@
|
||||
"""Tests for python/zfs/dataset.py covering missing lines."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from python.zfs.dataset import Dataset, Snapshot, _zfs_list
|
||||
|
||||
DATASET = "python.zfs.dataset"
|
||||
|
||||
SAMPLE_SNAPSHOT_DATA = {
|
||||
"createtxg": "123",
|
||||
"properties": {
|
||||
"creation": {"value": "1620000000"},
|
||||
"defer_destroy": {"value": "off"},
|
||||
"guid": {"value": "456"},
|
||||
"objsetid": {"value": "789"},
|
||||
"referenced": {"value": "1024"},
|
||||
"used": {"value": "512"},
|
||||
"userrefs": {"value": "0"},
|
||||
"version": {"value": "1"},
|
||||
"written": {"value": "2048"},
|
||||
},
|
||||
"name": "pool/dataset@snap1",
|
||||
}
|
||||
|
||||
SAMPLE_DATASET_DATA = {
|
||||
"output_version": {"vers_major": 0, "vers_minor": 1, "command": "zfs list"},
|
||||
"datasets": {
|
||||
"pool/dataset": {
|
||||
"properties": {
|
||||
"aclinherit": {"value": "restricted"},
|
||||
"aclmode": {"value": "discard"},
|
||||
"acltype": {"value": "off"},
|
||||
"available": {"value": "1000000"},
|
||||
"canmount": {"value": "on"},
|
||||
"checksum": {"value": "on"},
|
||||
"clones": {"value": ""},
|
||||
"compression": {"value": "lz4"},
|
||||
"copies": {"value": "1"},
|
||||
"createtxg": {"value": "1234"},
|
||||
"creation": {"value": "1620000000"},
|
||||
"dedup": {"value": "off"},
|
||||
"devices": {"value": "on"},
|
||||
"encryption": {"value": "off"},
|
||||
"exec": {"value": "on"},
|
||||
"filesystem_limit": {"value": "none"},
|
||||
"guid": {"value": "5678"},
|
||||
"keystatus": {"value": "none"},
|
||||
"logbias": {"value": "latency"},
|
||||
"mlslabel": {"value": "none"},
|
||||
"mounted": {"value": "yes"},
|
||||
"mountpoint": {"value": "/pool/dataset"},
|
||||
"quota": {"value": "0"},
|
||||
"readonly": {"value": "off"},
|
||||
"recordsize": {"value": "131072"},
|
||||
"redundant_metadata": {"value": "all"},
|
||||
"referenced": {"value": "512000"},
|
||||
"refquota": {"value": "0"},
|
||||
"refreservation": {"value": "0"},
|
||||
"reservation": {"value": "0"},
|
||||
"setuid": {"value": "on"},
|
||||
"sharenfs": {"value": "off"},
|
||||
"snapdir": {"value": "hidden"},
|
||||
"snapshot_limit": {"value": "none"},
|
||||
"sync": {"value": "standard"},
|
||||
"used": {"value": "1024000"},
|
||||
"usedbychildren": {"value": "512000"},
|
||||
"usedbydataset": {"value": "256000"},
|
||||
"usedbysnapshots": {"value": "256000"},
|
||||
"version": {"value": "5"},
|
||||
"volmode": {"value": "default"},
|
||||
"volsize": {"value": "none"},
|
||||
"vscan": {"value": "off"},
|
||||
"written": {"value": "4096"},
|
||||
"xattr": {"value": "on"},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _make_dataset() -> Dataset:
|
||||
"""Create a Dataset instance with mocked _zfs_list."""
|
||||
with patch(f"{DATASET}._zfs_list", return_value=SAMPLE_DATASET_DATA):
|
||||
return Dataset("pool/dataset")
|
||||
|
||||
|
||||
# --- _zfs_list version check error (line 29) ---
|
||||
|
||||
|
||||
def test_zfs_list_returns_data_on_valid_version() -> None:
|
||||
"""Test _zfs_list returns parsed data when version is correct."""
|
||||
valid_data = {
|
||||
"output_version": {"vers_major": 0, "vers_minor": 1, "command": "zfs list"},
|
||||
"datasets": {},
|
||||
}
|
||||
with patch(f"{DATASET}.bash_wrapper", return_value=(json.dumps(valid_data), 0)):
|
||||
result = _zfs_list("zfs list pool -pHj -o all")
|
||||
assert result == valid_data
|
||||
|
||||
|
||||
def test_zfs_list_raises_on_wrong_vers_minor() -> None:
|
||||
"""Test _zfs_list raises RuntimeError when vers_minor is wrong."""
|
||||
bad_data = {
|
||||
"output_version": {"vers_major": 0, "vers_minor": 2, "command": "zfs list"},
|
||||
}
|
||||
with (
|
||||
patch(f"{DATASET}.bash_wrapper", return_value=(json.dumps(bad_data), 0)),
|
||||
pytest.raises(RuntimeError, match="Datasets are not in the correct format"),
|
||||
):
|
||||
_zfs_list("zfs list pool -pHj -o all")
|
||||
|
||||
|
||||
def test_zfs_list_raises_on_wrong_command() -> None:
|
||||
"""Test _zfs_list raises RuntimeError when command field is wrong."""
|
||||
bad_data = {
|
||||
"output_version": {"vers_major": 0, "vers_minor": 1, "command": "zpool list"},
|
||||
}
|
||||
with (
|
||||
patch(f"{DATASET}.bash_wrapper", return_value=(json.dumps(bad_data), 0)),
|
||||
pytest.raises(RuntimeError, match="Datasets are not in the correct format"),
|
||||
):
|
||||
_zfs_list("zfs list pool -pHj -o all")
|
||||
|
||||
|
||||
# --- Snapshot.__repr__() (line 52) ---
|
||||
|
||||
|
||||
def test_snapshot_repr() -> None:
|
||||
"""Test Snapshot __repr__ returns correct format."""
|
||||
snapshot = Snapshot(SAMPLE_SNAPSHOT_DATA)
|
||||
result = repr(snapshot)
|
||||
assert result == "name=snap1 used=512 refer=1024"
|
||||
|
||||
|
||||
def test_snapshot_repr_different_values() -> None:
|
||||
"""Test Snapshot __repr__ with different values."""
|
||||
data = {
|
||||
**SAMPLE_SNAPSHOT_DATA,
|
||||
"name": "pool/dataset@daily-2024-01-01",
|
||||
"properties": {
|
||||
**SAMPLE_SNAPSHOT_DATA["properties"],
|
||||
"used": {"value": "999"},
|
||||
"referenced": {"value": "5555"},
|
||||
},
|
||||
}
|
||||
snapshot = Snapshot(data)
|
||||
assert "daily-2024-01-01" in repr(snapshot)
|
||||
assert "999" in repr(snapshot)
|
||||
assert "5555" in repr(snapshot)
|
||||
|
||||
|
||||
# --- Dataset.get_snapshots() (lines 113-115) ---
|
||||
|
||||
|
||||
def test_dataset_get_snapshots() -> None:
|
||||
"""Test Dataset.get_snapshots returns list of Snapshot objects."""
|
||||
dataset = _make_dataset()
|
||||
|
||||
snapshot_list_data = {
|
||||
"output_version": {"vers_major": 0, "vers_minor": 1, "command": "zfs list"},
|
||||
"datasets": {
|
||||
"pool/dataset@snap1": SAMPLE_SNAPSHOT_DATA,
|
||||
"pool/dataset@snap2": {
|
||||
**SAMPLE_SNAPSHOT_DATA,
|
||||
"name": "pool/dataset@snap2",
|
||||
},
|
||||
},
|
||||
}
|
||||
with patch(f"{DATASET}._zfs_list", return_value=snapshot_list_data):
|
||||
snapshots = dataset.get_snapshots()
|
||||
|
||||
assert snapshots is not None
|
||||
assert len(snapshots) == 2
|
||||
assert all(isinstance(s, Snapshot) for s in snapshots)
|
||||
|
||||
|
||||
def test_dataset_get_snapshots_empty() -> None:
|
||||
"""Test Dataset.get_snapshots returns empty list when no snapshots."""
|
||||
dataset = _make_dataset()
|
||||
|
||||
snapshot_list_data = {
|
||||
"output_version": {"vers_major": 0, "vers_minor": 1, "command": "zfs list"},
|
||||
"datasets": {},
|
||||
}
|
||||
with patch(f"{DATASET}._zfs_list", return_value=snapshot_list_data):
|
||||
snapshots = dataset.get_snapshots()
|
||||
|
||||
assert snapshots == []
|
||||
|
||||
|
||||
# --- Dataset.create_snapshot() (lines 123-133) ---
|
||||
|
||||
|
||||
def test_dataset_create_snapshot_success() -> None:
|
||||
"""Test create_snapshot returns success message when return code is 0."""
|
||||
dataset = _make_dataset()
|
||||
|
||||
with patch(f"{DATASET}.bash_wrapper", return_value=("", 0)):
|
||||
result = dataset.create_snapshot("my-snap")
|
||||
|
||||
assert result == "snapshot created"
|
||||
|
||||
|
||||
def test_dataset_create_snapshot_already_exists() -> None:
|
||||
"""Test create_snapshot returns message when snapshot already exists."""
|
||||
dataset = _make_dataset()
|
||||
|
||||
snapshot_list_data = {
|
||||
"output_version": {"vers_major": 0, "vers_minor": 1, "command": "zfs list"},
|
||||
"datasets": {
|
||||
"pool/dataset@my-snap": SAMPLE_SNAPSHOT_DATA,
|
||||
},
|
||||
}
|
||||
|
||||
with (
|
||||
patch(f"{DATASET}.bash_wrapper", return_value=("dataset already exists", 1)),
|
||||
patch(f"{DATASET}._zfs_list", return_value=snapshot_list_data),
|
||||
):
|
||||
# The snapshot data has name "pool/dataset@snap1" which extracts to "snap1"
|
||||
# We need the snapshot name to match, so use "snap1"
|
||||
result = dataset.create_snapshot("snap1")
|
||||
|
||||
assert "already exists" in result
|
||||
|
||||
|
||||
def test_dataset_create_snapshot_failure() -> None:
|
||||
"""Test create_snapshot returns failure message on unknown error."""
|
||||
dataset = _make_dataset()
|
||||
|
||||
snapshot_list_data = {
|
||||
"output_version": {"vers_major": 0, "vers_minor": 1, "command": "zfs list"},
|
||||
"datasets": {},
|
||||
}
|
||||
|
||||
with (
|
||||
patch(f"{DATASET}.bash_wrapper", return_value=("some error", 1)),
|
||||
patch(f"{DATASET}._zfs_list", return_value=snapshot_list_data),
|
||||
):
|
||||
result = dataset.create_snapshot("new-snap")
|
||||
|
||||
assert "Failed to create snapshot" in result
|
||||
|
||||
|
||||
def test_dataset_create_snapshot_failure_no_snapshots() -> None:
|
||||
"""Test create_snapshot failure when get_snapshots returns empty list."""
|
||||
dataset = _make_dataset()
|
||||
|
||||
# get_snapshots returns empty list (falsy), so the if branch is skipped
|
||||
snapshot_list_data = {
|
||||
"output_version": {"vers_major": 0, "vers_minor": 1, "command": "zfs list"},
|
||||
"datasets": {},
|
||||
}
|
||||
|
||||
with (
|
||||
patch(f"{DATASET}.bash_wrapper", return_value=("error", 1)),
|
||||
patch(f"{DATASET}._zfs_list", return_value=snapshot_list_data),
|
||||
):
|
||||
result = dataset.create_snapshot("nonexistent")
|
||||
|
||||
assert "Failed to create snapshot" in result
|
||||
|
||||
|
||||
# --- Dataset.delete_snapshot() (lines 141-148) ---
|
||||
|
||||
|
||||
def test_dataset_delete_snapshot_success() -> None:
|
||||
"""Test delete_snapshot returns None on success."""
|
||||
dataset = _make_dataset()
|
||||
|
||||
with patch(f"{DATASET}.bash_wrapper", return_value=("", 0)):
|
||||
result = dataset.delete_snapshot("my-snap")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_dataset_delete_snapshot_dependent_clones() -> None:
|
||||
"""Test delete_snapshot returns message when snapshot has dependent clones."""
|
||||
dataset = _make_dataset()
|
||||
|
||||
error_msg = "cannot destroy 'pool/dataset@my-snap': snapshot has dependent clones"
|
||||
with patch(f"{DATASET}.bash_wrapper", return_value=(error_msg, 1)):
|
||||
result = dataset.delete_snapshot("my-snap")
|
||||
|
||||
assert result == "snapshot has dependent clones"
|
||||
|
||||
|
||||
def test_dataset_delete_snapshot_other_failure() -> None:
|
||||
"""Test delete_snapshot raises RuntimeError on other failures."""
|
||||
dataset = _make_dataset()
|
||||
|
||||
with (
|
||||
patch(f"{DATASET}.bash_wrapper", return_value=("some other error", 1)),
|
||||
pytest.raises(RuntimeError, match="Failed to delete snapshot"),
|
||||
):
|
||||
dataset.delete_snapshot("my-snap")
|
||||
|
||||
|
||||
# --- Dataset.__repr__() (line 152) ---
|
||||
|
||||
|
||||
def test_dataset_repr() -> None:
|
||||
"""Test Dataset __repr__ includes all attributes."""
|
||||
dataset = _make_dataset()
|
||||
result = repr(dataset)
|
||||
|
||||
expected_attrs = [
|
||||
"aclinherit",
|
||||
"aclmode",
|
||||
"acltype",
|
||||
"available",
|
||||
"canmount",
|
||||
"checksum",
|
||||
"clones",
|
||||
"compression",
|
||||
"copies",
|
||||
"createtxg",
|
||||
"creation",
|
||||
"dedup",
|
||||
"devices",
|
||||
"encryption",
|
||||
"exec",
|
||||
"filesystem_limit",
|
||||
"guid",
|
||||
"keystatus",
|
||||
"logbias",
|
||||
"mlslabel",
|
||||
"mounted",
|
||||
"mountpoint",
|
||||
"name",
|
||||
"quota",
|
||||
"readonly",
|
||||
"recordsize",
|
||||
"redundant_metadata",
|
||||
"referenced",
|
||||
"refquota",
|
||||
"refreservation",
|
||||
"reservation",
|
||||
"setuid",
|
||||
"sharenfs",
|
||||
"snapdir",
|
||||
"snapshot_limit",
|
||||
"sync",
|
||||
"used",
|
||||
"usedbychildren",
|
||||
"usedbydataset",
|
||||
"usedbysnapshots",
|
||||
"version",
|
||||
"volmode",
|
||||
"volsize",
|
||||
"vscan",
|
||||
"written",
|
||||
"xattr",
|
||||
]
|
||||
|
||||
for attr in expected_attrs:
|
||||
assert f"self.{attr}=" in result, f"Missing {attr} in repr"
|
||||
Reference in New Issue
Block a user