Compare commits

..

37 Commits

Author SHA1 Message Date
Richie 7740ebb594 opning ports for testing
pytest / pytest (pull_request) Failing after 26s
build_systems / build-bob (pull_request) Successful in 47s
build_systems / build-rhapsody-in-green (pull_request) Successful in 1m0s
treefmt / nix fmt (pull_request) Successful in 5s
build_systems / build-brain (pull_request) Successful in 46s
build_systems / build-leviathan (pull_request) Successful in 53s
build_systems / build-jeeves (pull_request) Successful in 2m37s
2026-06-13 22:17:46 -04:00
Richie 07a9adfdd5 added a index for the VEctor DB 2026-06-13 22:17:46 -04:00
Richie 74e4c2e921 improved BM25 write 2026-06-13 22:17:46 -04:00
Richie 70d24c2a85 added ZstdMiddleware to ebook_search 2026-06-13 22:17:46 -04:00
Richie 773e9f9d4a added vector_engine to fix name postgres name space issue 2026-06-13 22:17:46 -04:00
Richie 8b608f7aa0 reworked ebook_search routers 2026-06-13 22:17:46 -04:00
Richie be6d8c9db9 made fastapi tools 2026-06-13 22:17:46 -04:00
Richie 6bb6f935b1 added proper cache invalidation to load_bm25_corpus 2026-06-13 22:17:46 -04:00
Richie c4e8a395d2 updated tests 2026-06-13 22:17:46 -04:00
Richie d51ed42919 improved reranking weights 2026-06-13 22:17:46 -04:00
Richie 7466c7ed3a fixed duplicat enrichment 2026-06-13 22:17:46 -04:00
Richie e9b574aa58 improved queary for vector search 2026-06-13 22:17:46 -04:00
Richie bcd855cb88 cleaned up installer.py 2026-06-13 22:17:46 -04:00
Richie b976fbf13f add .ebook_search_bm25 to gitignore 2026-06-13 22:17:46 -04:00
Richie 5ba41feb2d updated python 2026-06-13 22:17:46 -04:00
Richie 3ebd9df21f setup tests 2026-06-13 22:17:46 -04:00
Richie 73177ef399 build api and frountend 2026-06-13 22:17:46 -04:00
Richie d81f5a0ec1 added answer.py and config 2026-06-13 22:17:46 -04:00
Richie 4dac3d1c60 added __init__ 2026-06-13 22:17:46 -04:00
Richie a802dbd2b3 made llm_interface.py 2026-06-13 22:17:46 -04:00
Richie 8e6a2809b0 added rerank 2026-06-13 22:17:46 -04:00
Richie c5293b0dcf built ingest 2026-06-13 22:17:46 -04:00
Richie d740b25b2c built rag search setup 2026-06-13 22:17:46 -04:00
Richie febb88dc77 set up embedding system 2026-06-13 22:17:46 -04:00
Richie b9949e8d72 built BM25 search foundation 2026-06-13 22:17:46 -04:00
Richie 345384e76f clean up 2026-06-13 22:17:46 -04:00
Richie ac899d5fca added ebook embedding to orm 2026-06-13 22:17:46 -04:00
Richie 76e206a727 removed hedgedoc 2026-06-13 22:17:46 -04:00
Richie 090e8dddca adding embedding Models to jeeves 2026-06-13 22:17:46 -04:00
Richie c8505d413c updated series_index to float and added UniqueConstraint to audiobook and audiobook_author 2026-06-13 22:17:46 -04:00
Richie ff685112a6 fixed omnibus for audio books 2026-06-13 22:17:46 -04:00
Richie e113dc3ef3 moved installer to python dir 2026-06-13 22:17:46 -04:00
Richie 340f37f114 deleted frontend dir 2026-06-13 22:17:46 -04:00
Richie 5611daab97 added llm_tool_calling.py 2026-06-13 22:17:46 -04:00
Richie f20bee82ec built workflow 2026-06-13 22:17:46 -04:00
Richie ef4e6f75a5 Add catalog.py for manually adding authors and series to the database. 2026-06-13 22:17:46 -04:00
Richie 1ab5d3d650 adding audiobook data to DB 2026-06-13 22:17:46 -04:00
81 changed files with 3420 additions and 1540 deletions
-1
View File
@@ -242,7 +242,6 @@
"referer", "referer",
"REFERERS", "REFERERS",
"relatime", "relatime",
"rerank",
"Rhosts", "Rhosts",
"ripgrep", "ripgrep",
"roboto", "roboto",
+1 -1
View File
@@ -55,7 +55,6 @@
polars polars
psycopg psycopg
pydantic pydantic
pydantic-settings
pyfakefs pyfakefs
pytest pytest
pytest-cov pytest-cov
@@ -65,6 +64,7 @@
ruff ruff
scalene scalene
sqlalchemy sqlalchemy
sqlalchemy
bm25s bm25s
tenacity tenacity
textual textual
+5 -8
View File
@@ -3,7 +3,7 @@ name = "system_tools"
version = "0.1.0" version = "0.1.0"
description = "" description = ""
authors = [{ name = "Richie Cahill", email = "richie@tmmworkshop.com" }] authors = [{ name = "Richie Cahill", email = "richie@tmmworkshop.com" }]
requires-python = "~=3.14.0" requires-python = "~=3.13.0"
readme = "README.md" readme = "README.md"
license = "MIT" license = "MIT"
# these dependencies are a best effort and aren't guaranteed to work # these dependencies are a best effort and aren't guaranteed to work
@@ -12,23 +12,20 @@ dependencies = [
"alembic", "alembic",
"apprise", "apprise",
"apscheduler", "apscheduler",
"fastapi",
"fastapi-cli",
"httpx", "httpx",
"python-multipart",
"polars", "polars",
"psycopg[binary]", "psycopg[binary]",
"pydantic", "pydantic",
"pydantic-settings", "pyyaml",
"python-multipart",
"sqlalchemy", "sqlalchemy",
"tenacity",
"tinytuya",
"typer", "typer",
"websockets", "websockets",
] ]
[project.scripts] [project.scripts]
database = "python.database_cli:app" database = "python.database_cli:app"
van-inventory = "python.van_inventory.main:serve"
whisper-transcribe = "python.tools.whisper.transcribe:main" whisper-transcribe = "python.tools.whisper.transcribe:main"
[dependency-groups] [dependency-groups]
@@ -44,7 +41,7 @@ dev = [
[tool.ruff] [tool.ruff]
target-version = "py314" target-version = "py313"
line-length = 120 line-length = 120
+2 -6
View File
@@ -1,10 +1,9 @@
"""FastAPI interface for Contact database.""" """FastAPI interface for Contact database."""
from __future__ import annotations
import logging import logging
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Annotated from typing import Annotated
import typer import typer
import uvicorn import uvicorn
@@ -15,9 +14,6 @@ from python.common import configure_logger
from python.fastapi_tools import ZstdMiddleware from python.fastapi_tools import ZstdMiddleware
from python.orm.common import get_postgres_engine from python.orm.common import get_postgres_engine
if TYPE_CHECKING:
from collections.abc import AsyncIterator
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
+1 -1
View File
@@ -9,7 +9,7 @@ from pydantic import BaseModel
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
from python.fastapi_tools.db import DbSession # noqa: TC001 this is a FastAPI needed at runtime from python.fastapi_tools.db import DbSession
from python.orm.richie.contact import Contact, ContactRelationship, Need, RelationshipType from python.orm.richie.contact import Contact, ContactRelationship, Need, RelationshipType
TEMPLATES_DIR = Path(__file__).parent.parent / "templates" TEMPLATES_DIR = Path(__file__).parent.parent / "templates"
+1 -1
View File
@@ -9,7 +9,7 @@ from fastapi.templating import Jinja2Templates
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session, selectinload from sqlalchemy.orm import Session, selectinload
from python.fastapi_tools.db import DbSession # noqa: TC001 this is a FastAPI needed at runtime from python.fastapi_tools.db import DbSession
from python.orm.richie.contact import Contact, ContactRelationship, Need, RelationshipType from python.orm.richie.contact import Contact, ContactRelationship, Need, RelationshipType
TEMPLATES_DIR = Path(__file__).parent.parent / "templates" TEMPLATES_DIR = Path(__file__).parent.parent / "templates"
+15 -3
View File
@@ -4,10 +4,12 @@ Usage:
database <db_name> <command> [args...] database <db_name> <command> [args...]
Examples: Examples:
database van_inventory upgrade head
database van_inventory downgrade head-1
database van_inventory revision --autogenerate -m "add meals table"
database van_inventory check
database richie check database richie check
database richie upgrade head database richie upgrade head
database richie downgrade head-1
database richie revision --autogenerate -m "add meals table"
""" """
from __future__ import annotations from __future__ import annotations
@@ -46,7 +48,10 @@ class DatabaseConfig:
def alembic_config(self) -> Config: def alembic_config(self) -> Config:
"""Build an alembic Config for this database.""" """Build an alembic Config for this database."""
cfg = Config() # Runtime import needed — Config is in TYPE_CHECKING for the return type annotation
from alembic.config import Config as AlembicConfig # noqa: PLC0415
cfg = AlembicConfig()
cfg.set_main_option("script_location", self.script_location) cfg.set_main_option("script_location", self.script_location)
cfg.set_main_option("file_template", self.file_template) cfg.set_main_option("file_template", self.file_template)
cfg.set_main_option("prepend_sys_path", ".") cfg.set_main_option("prepend_sys_path", ".")
@@ -71,6 +76,13 @@ DATABASES: dict[str, DatabaseConfig] = {
base_class_name="RichieBase", base_class_name="RichieBase",
models_module="python.orm.richie", models_module="python.orm.richie",
), ),
"van_inventory": DatabaseConfig(
env_prefix="VAN_INVENTORY",
version_location="python/alembic/van_inventory/versions",
base_module="python.orm.van_inventory.base",
base_class_name="VanInventoryBase",
models_module="python.orm.van_inventory.models",
),
} }
-24
View File
@@ -1,24 +0,0 @@
"""FastAPI dependencies for the EPUB search app."""
from __future__ import annotations
from typing import Annotated
from fastapi import Depends, Request
from sqlalchemy.engine import Engine
from python.ebook_search.config import EbookSearchConfig
def get_config(request: Request) -> EbookSearchConfig:
"""Get the loaded search config from app state."""
return request.app.state.config
def get_engine(request: Request) -> Engine:
"""Get the database engine from app state."""
return request.app.state.engine
AppConfig = Annotated[EbookSearchConfig, Depends(get_config)]
AppEngine = Annotated[Engine, Depends(get_engine)]
+11 -18
View File
@@ -14,7 +14,7 @@ from sqlalchemy.orm import Session
from python.common import configure_logger from python.common import configure_logger
from python.ebook_search.api.bm25_tasks import cancel_bm25_refresh from python.ebook_search.api.bm25_tasks import cancel_bm25_refresh
from python.ebook_search.api.routes import admin_router, health_router, page_router, search_router from python.ebook_search.api.routes import admin_router, page_router, search_router
from python.ebook_search.api.web import STATIC_DIR from python.ebook_search.api.web import STATIC_DIR
from python.ebook_search.bm25_corpus import ensure_bm25_corpus from python.ebook_search.bm25_corpus import ensure_bm25_corpus
from python.ebook_search.config import load_config from python.ebook_search.config import load_config
@@ -32,24 +32,9 @@ logger = logging.getLogger(__name__)
async def lifespan(app: FastAPI) -> AsyncIterator[None]: async def lifespan(app: FastAPI) -> AsyncIterator[None]:
"""Manage application startup and shutdown resources.""" """Manage application startup and shutdown resources."""
logger.info("ebook_search_startup") logger.info("ebook_search_startup")
config = load_config()
app.state.config = config
logger.info(
"ebook_search_config_loaded top_k=%s embedding_model=%s embedding_base_url=%s vllm_base_url=%s "
"rerank_enabled=%s answer_enabled=%s library_paths=%s",
config.top_k,
config.embedding_model,
config.embedding_base_url,
config.vllm_base_url,
config.rerank.enabled,
config.answer_enabled,
len(config.library_paths),
)
if not config.library_paths:
logger.warning("ebook_search_no_library_paths_configured")
app.state.engine = get_postgres_engine(name="RICHIE", vector_engine=True) app.state.engine = get_postgres_engine(name="RICHIE", vector_engine=True)
with Session(app.state.engine) as session: with Session(app.state.engine) as session:
ensure_bm25_corpus(session, config) ensure_bm25_corpus(session, app.state.config)
try: try:
yield yield
finally: finally:
@@ -63,9 +48,17 @@ def create_app() -> FastAPI:
app = FastAPI(title="EPUB Search", lifespan=lifespan) app = FastAPI(title="EPUB Search", lifespan=lifespan)
app.add_middleware(ZstdMiddleware) app.add_middleware(ZstdMiddleware)
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static") app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
app.state.config = load_config()
logger.info(
"ebook_search_config_loaded top_k=%s embedding_model=%s rerank_enabled=%s answer_enabled=%s library_paths=%s",
app.state.config.top_k,
app.state.config.embedding_model,
app.state.config.rerank.enabled,
app.state.config.answer_enabled,
len(app.state.config.library_paths),
)
app.include_router(admin_router) app.include_router(admin_router)
app.include_router(health_router)
app.include_router(page_router) app.include_router(page_router)
app.include_router(search_router) app.include_router(search_router)
@@ -1,13 +1,11 @@
"""EPUB search web route modules.""" """EPUB search web route modules."""
from python.ebook_search.api.routes.admin import router as admin_router from python.ebook_search.api.routes.admin import router as admin_router
from python.ebook_search.api.routes.health import router as health_router
from python.ebook_search.api.routes.page import router as page_router from python.ebook_search.api.routes.page import router as page_router
from python.ebook_search.api.routes.search import router as search_router from python.ebook_search.api.routes.search import router as search_router
__all__ = [ __all__ = [
"admin_router", "admin_router",
"health_router",
"page_router", "page_router",
"search_router", "search_router",
] ]
+16 -12
View File
@@ -3,37 +3,38 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from dataclasses import replace
from fastapi import APIRouter, Request from fastapi import APIRouter, Request
from fastapi.responses import HTMLResponse from fastapi.responses import HTMLResponse
from sqlalchemy.orm import Session
from python.ebook_search.api.bm25_tasks import schedule_bm25_refresh from python.ebook_search.api.bm25_tasks import schedule_bm25_refresh
from python.ebook_search.api.dependencies import (
AppConfig, # noqa: TC001 FastAPI resolves this annotated dependency at runtime
)
from python.ebook_search.api.web import templates from python.ebook_search.api.web import templates
from python.ebook_search.embeddings import embed_missing_chunks, embedding_model_stats from python.ebook_search.embeddings import embed_missing_chunks, embedding_model_stats
from python.ebook_search.ingest import ingest_configured_paths from python.ebook_search.ingest import ingest_configured_paths
from python.fastapi_tools import DbSession # noqa: TC001 FastAPI resolves this annotated dependency at runtime
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter(prefix="/admin") router = APIRouter(prefix="/admin")
EMBED_ALL_BATCH_SIZE = 32
@router.get("", response_class=HTMLResponse) @router.get("", response_class=HTMLResponse)
def admin(request: Request, config: AppConfig, session: DbSession) -> HTMLResponse: def admin(request: Request) -> HTMLResponse:
"""Render the admin page.""" """Render the admin page."""
with Session(request.app.state.engine) as session:
stats = embedding_model_stats(session) stats = embedding_model_stats(session)
logger.info("ebook_admin_page_loaded models=%s", len(stats)) logger.info("ebook_admin_page_loaded models=%s", len(stats))
return templates.TemplateResponse(request, "admin.html", {"config": config, "stats": stats}) return templates.TemplateResponse(request, "admin.html", {"config": request.app.state.config, "stats": stats})
@router.post("/scan", response_class=HTMLResponse) @router.post("/scan", response_class=HTMLResponse)
def scan_library(request: Request, config: AppConfig, session: DbSession) -> HTMLResponse: def scan_library(request: Request) -> HTMLResponse:
"""Scan configured library paths for EPUB changes.""" """Scan configured library paths for EPUB changes."""
try: try:
count = ingest_configured_paths(session, config) with Session(request.app.state.engine) as session:
count = ingest_configured_paths(session, request.app.state.config)
session.commit() session.commit()
except Exception as error: except Exception as error:
logger.exception("ebook_admin_scan_failed") logger.exception("ebook_admin_scan_failed")
@@ -46,10 +47,11 @@ def scan_library(request: Request, config: AppConfig, session: DbSession) -> HTM
@router.post("/embed-missing", response_class=HTMLResponse) @router.post("/embed-missing", response_class=HTMLResponse)
def embed_missing(request: Request, config: AppConfig, session: DbSession) -> HTMLResponse: def embed_missing(request: Request) -> HTMLResponse:
"""Embed chunks missing vectors for the configured model.""" """Embed chunks missing vectors for the configured model."""
try: try:
count = embed_missing_chunks(session, config) with Session(request.app.state.engine) as session:
count = embed_missing_chunks(session, request.app.state.config)
session.commit() session.commit()
except Exception as error: except Exception as error:
logger.exception("ebook_admin_embed_missing_failed") logger.exception("ebook_admin_embed_missing_failed")
@@ -64,11 +66,13 @@ def embed_missing(request: Request, config: AppConfig, session: DbSession) -> HT
@router.post("/embed-all", response_class=HTMLResponse) @router.post("/embed-all", response_class=HTMLResponse)
def embed_all(request: Request, config: AppConfig, session: DbSession) -> HTMLResponse: def embed_all(request: Request) -> HTMLResponse:
"""Embed all chunks missing vectors in fixed-size batches.""" """Embed all chunks missing vectors in fixed-size batches."""
total = 0 total = 0
batches = 0 batches = 0
config = replace(request.app.state.config, embedding_batch_size=EMBED_ALL_BATCH_SIZE)
try: try:
with Session(request.app.state.engine) as session:
while True: while True:
count = embed_missing_chunks(session, config) count = embed_missing_chunks(session, config)
if count == 0: if count == 0:
@@ -99,5 +103,5 @@ def embed_all(request: Request, config: AppConfig, session: DbSession) -> HTMLRe
return templates.TemplateResponse( return templates.TemplateResponse(
request, request,
"partials/admin_status.html", "partials/admin_status.html",
{"message": f"Embedded {total} chunks in {batches} batches of {config.embedding_batch_size}"}, {"message": f"Embedded {total} chunks in {batches} batches of {EMBED_ALL_BATCH_SIZE}"},
) )
-97
View File
@@ -1,97 +0,0 @@
"""Liveness and readiness routes for the EPUB search service."""
from __future__ import annotations
import logging
from http import HTTPStatus
from typing import TYPE_CHECKING
from fastapi import APIRouter
from fastapi.responses import JSONResponse
from sqlalchemy import literal, select
from sqlalchemy.exc import SQLAlchemyError
from python.ebook_search.api.dependencies import (
AppConfig, # noqa: TC001 FastAPI resolves this annotated dependency at runtime
)
from python.ebook_search.bm25_corpus import bm25_index_exists, bm25_index_path, read_bm25_manifest
from python.ebook_search.llm_interface import check_chat_endpoint, check_embedding_endpoint
from python.fastapi_tools import DbSession # noqa: TC001 FastAPI resolves this annotated dependency at runtime
if TYPE_CHECKING:
from sqlalchemy.orm import Session
from python.ebook_search.config import EbookSearchConfig
logger = logging.getLogger(__name__)
router = APIRouter()
@router.get("/health")
def health() -> dict[str, str]:
"""Liveness probe that returns ok without touching dependencies."""
return {"status": "ok"}
@router.get("/ready")
def ready(config: AppConfig, session: DbSession) -> JSONResponse:
"""Readiness probe reporting database, embedding endpoint, and BM25 index status."""
database_ok = check_database(session)
embedding_ok = check_embedding_endpoint(config)
chat_status = chat_endpoint_status(config)
bm25_status = check_bm25_status(config)
checks = {
"database": "ok" if database_ok else "fail",
"embedding": "ok" if embedding_ok else "fail",
"chat": chat_status,
"bm25": bm25_status,
}
if not database_ok:
status = "unavailable"
status_code = HTTPStatus.SERVICE_UNAVAILABLE
elif not embedding_ok or chat_status == "fail" or bm25_status == "missing":
status = "degraded"
status_code = HTTPStatus.OK
else:
status = "ready"
status_code = HTTPStatus.OK
logger.info(
"ebook_ready_check status=%s database=%s embedding=%s chat=%s bm25=%s",
status,
database_ok,
embedding_ok,
chat_status,
bm25_status,
)
return JSONResponse(content={"status": status, "checks": checks}, status_code=status_code)
def chat_endpoint_status(config: EbookSearchConfig) -> str:
"""Return the answering chat endpoint status, or disabled when answers are off."""
if not config.answer_enabled:
return "disabled"
return "ok" if check_chat_endpoint(config) else "fail"
def check_database(session: Session) -> bool:
"""Return whether the database answers a trivial query."""
try:
session.execute(select(literal(1)))
except SQLAlchemyError as error:
logger.warning("ebook_ready_database_unavailable error=%s", error)
return False
return True
def check_bm25_status(config: EbookSearchConfig) -> str:
"""Return the persisted BM25 index status without loading it into memory."""
index_path = bm25_index_path(config)
manifest = read_bm25_manifest(index_path)
if manifest is None or not bm25_index_exists(index_path, manifest):
return "missing"
if manifest.chunk_count == 0:
return "empty"
return "ok"
+7 -8
View File
@@ -7,12 +7,9 @@ import logging
from fastapi import APIRouter, Request from fastapi import APIRouter, Request
from fastapi.responses import HTMLResponse from fastapi.responses import HTMLResponse
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session
from python.ebook_search.api.dependencies import (
AppConfig, # noqa: TC001 FastAPI resolves this annotated dependency at runtime
)
from python.ebook_search.api.web import templates from python.ebook_search.api.web import templates
from python.fastapi_tools import DbSession # noqa: TC001 FastAPI resolves this annotated dependency at runtime
from python.orm.richie import EbookSource from python.orm.richie import EbookSource
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -21,22 +18,24 @@ router = APIRouter()
@router.get("/", response_class=HTMLResponse) @router.get("/", response_class=HTMLResponse)
def index(request: Request, config: AppConfig) -> HTMLResponse: def index(request: Request) -> HTMLResponse:
"""Render the search page.""" """Render the search page."""
return templates.TemplateResponse(request, "search.html", {"config": config}) return templates.TemplateResponse(request, "search.html", {"config": request.app.state.config})
@router.get("/books", response_class=HTMLResponse) @router.get("/books", response_class=HTMLResponse)
def books(request: Request, session: DbSession) -> HTMLResponse: def books(request: Request) -> HTMLResponse:
"""Render the indexed books page.""" """Render the indexed books page."""
with Session(request.app.state.engine) as session:
sources = list(session.scalars(select(EbookSource).order_by(EbookSource.title)).all()) sources = list(session.scalars(select(EbookSource).order_by(EbookSource.title)).all())
logger.info("ebook_books_page_loaded count=%s", len(sources)) logger.info("ebook_books_page_loaded count=%s", len(sources))
return templates.TemplateResponse(request, "books.html", {"sources": sources}) return templates.TemplateResponse(request, "books.html", {"sources": sources})
@router.get("/books/{source_id}", response_class=HTMLResponse) @router.get("/books/{source_id}", response_class=HTMLResponse)
def book_detail(source_id: int, request: Request, session: DbSession) -> HTMLResponse: def book_detail(source_id: int, request: Request) -> HTMLResponse:
"""Render details for one indexed book.""" """Render details for one indexed book."""
with Session(request.app.state.engine) as session:
source = session.get(EbookSource, source_id) source = session.get(EbookSource, source_id)
if source is not None: if source is not None:
chapter_count = len(source.chapters) chapter_count = len(source.chapters)
+14 -72
View File
@@ -5,112 +5,54 @@ from __future__ import annotations
import logging import logging
from dataclasses import replace from dataclasses import replace
from time import perf_counter from time import perf_counter
from typing import TYPE_CHECKING, Annotated from typing import Annotated
from fastapi import APIRouter, Form, Request from fastapi import APIRouter, Form, Request
from fastapi.responses import HTMLResponse from fastapi.responses import HTMLResponse
from python.ebook_search.answer import answer_query from python.ebook_search.answer import answer_query
from python.ebook_search.api.dependencies import ( # noqa: TC001 FastAPI resolves these annotated dependencies at runtime
AppConfig,
AppEngine,
)
from python.ebook_search.api.web import templates from python.ebook_search.api.web import templates
from python.ebook_search.guardrails import ( from python.ebook_search.search import search_ebooks
CitationReport,
is_confident,
retrieval_confidence,
validate_citations,
)
from python.ebook_search.search import SearchResponse, search_ebooks
from python.ebook_search.timing import runtime_step_from_start from python.ebook_search.timing import runtime_step_from_start
if TYPE_CHECKING:
from python.ebook_search.config import EbookSearchConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
def build_answer(
query: str,
response: SearchResponse,
config: EbookSearchConfig,
) -> tuple[str, bool, CitationReport | None]:
"""Generate the answer for a search, returning ``(answer, low_confidence, citation_report)``."""
if not config.answer_enabled:
logger.info("ebook_answer_skipped_disabled")
return "Answer generation is disabled. Source chunks are shown below.", False, None
if not is_confident(response.results, config):
logger.info(
"ebook_answer_low_confidence confidence=%.4f threshold=%.4f",
retrieval_confidence(response.results),
config.min_retrieval_confidence,
)
answer = (
"Retrieval confidence is low for this query, so answer generation was skipped. "
"Source chunks are shown below."
)
return answer, True, None
try:
answer = answer_query(query, response.results, config)
except RuntimeError as error:
logger.warning("ebook_answer_request_failed_falling_back error=%s", error)
return "Answer generation failed. Source chunks are still shown below.", False, None
citation_report = None
if config.validate_citations_enabled and response.results:
citation_report = validate_citations(answer, len(response.results))
if citation_report.invalid or not citation_report.grounded:
logger.warning(
"ebook_answer_citation_issue invalid=%s grounded=%s",
citation_report.invalid,
citation_report.grounded,
)
return answer, False, citation_report
@router.post("/search", response_class=HTMLResponse) @router.post("/search", response_class=HTMLResponse)
def search( def search(
request: Request, request: Request,
config: AppConfig,
engine: AppEngine,
query: Annotated[str, Form()], query: Annotated[str, Form()],
rerank: Annotated[str | None, Form()] = None, rerank: Annotated[str | None, Form()] = None,
) -> HTMLResponse: ) -> HTMLResponse:
"""Run a search and render HTMX results.""" """Run a search and render HTMX results."""
try: try:
response = search_ebooks(engine, query, config, rerank=rerank == "true") response = search_ebooks(request.app.state.engine, query, request.app.state.config, rerank=rerank == "true")
except Exception as error: except Exception as error:
logger.exception("ebook_search_request_failed") logger.exception("ebook_search_request_failed")
return templates.TemplateResponse(request, "partials/error.html", {"message": str(error)}, status_code=500) return templates.TemplateResponse(request, "partials/error.html", {"message": str(error)}, status_code=500)
answer_start = perf_counter() answer_start = perf_counter()
answer, low_confidence, citation_report = build_answer(query, response, config) if request.app.state.config.answer_enabled:
answer_step_name = "Answer generation" if config.answer_enabled else "Answer skipped" try:
answer = answer_query(query, response.results, request.app.state.config)
except RuntimeError as error:
logger.warning("ebook_answer_request_failed_falling_back error=%s", error)
answer = "Answer generation failed. Source chunks are still shown below."
else:
logger.info("ebook_answer_skipped_disabled")
answer = "Answer generation is disabled. Source chunks are shown below."
answer_step_name = "Answer generation" if request.app.state.config.answer_enabled else "Answer skipped"
response = replace( response = replace(
response, response,
timings=(*response.timings, runtime_step_from_start(answer_step_name, answer_start)), timings=(*response.timings, runtime_step_from_start(answer_step_name, answer_start)),
) )
for step in response.timings:
logger.info("ebook_search_timing step=%r runtime_ms=%.1f", step.name, step.duration_ms)
logger.info( logger.info(
"ebook_search_request_complete results=%s rank_label=%s runtime_ms=%.1f", "ebook_search_request_complete results=%s rank_label=%s runtime_ms=%.1f",
len(response.results), len(response.results),
response.rank_label, response.rank_label,
response.total_runtime_ms, response.total_runtime_ms,
) )
return templates.TemplateResponse( return templates.TemplateResponse(request, "partials/results.html", {"answer": answer, "response": response})
request,
"partials/results.html",
{
"answer": answer,
"response": response,
"low_confidence": low_confidence,
"citation_report": citation_report,
},
)
-9
View File
@@ -138,12 +138,3 @@ th {
color: #9f1d20; color: #9f1d20;
font-weight: 700; font-weight: 700;
} }
.notice {
margin: 8px 0;
padding: 8px 12px;
border-left: 4px solid #c8881d;
background: #fcf3e2;
color: #6b4a06;
font-weight: 600;
}
@@ -23,16 +23,6 @@
{% endif %} {% endif %}
<section class="answer"> <section class="answer">
<h2>Answer</h2> <h2>Answer</h2>
{% if low_confidence|default(false) %}
<p class="notice">Low retrieval confidence — answer generation was skipped.</p>
{% endif %}
{% set report = citation_report|default(none) %}
{% if report is not none and not report.grounded %}
<p class="notice">Unverified — no source citations were found in this answer.</p>
{% endif %}
{% if report is not none and report.invalid %}
<p class="notice">Invalid citations: {{ report.invalid|join(", ") }} (no matching source).</p>
{% endif %}
<p>{{ answer }}</p> <p>{{ answer }}</p>
</section> </section>
{% if response.results %} {% if response.results %}
+88 -97
View File
@@ -2,15 +2,88 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass
from os import getenv from os import getenv
from typing import Annotated, Self
from pydantic import AliasChoices, Field, field_validator, model_validator
from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict
def normalize_embedding_alias(model: str) -> str: def getenv_bool(name: str, *, default: bool) -> bool:
"""Normalize a supported embedding alias to its provider model name.""" """Read a boolean environment variable with a default fallback."""
value = getenv(name)
if value is None:
return default
return value.strip().lower() in {"1", "true", "yes", "on"}
def getenv_int(name: str, *, default: int) -> int:
"""Read an integer environment variable with a default fallback."""
value = getenv(name)
if value is None or not value.strip():
return default
return int(value)
@dataclass(frozen=True)
class RerankConfig:
"""vLLM reranker settings."""
enabled: bool = False
base_url: str = "http://192.168.90.25:8001"
model: str = "qwen3-reranker-06b"
candidates: int = 24
timeout_seconds: float = 30.0
@dataclass(frozen=True)
class EbookSearchConfig:
"""Runtime settings for EPUB search."""
rerank: RerankConfig
top_k: int = 12
library_paths: tuple[str, ...] = ()
vllm_base_url: str = "https://ollama.com/v1"
vllm_api_key: str = "not-needed"
chat_model: str = "deepseek-v4-flash"
answer_enabled: bool = True
embedding_base_url: str = "http://192.168.90.25:8000/v1"
embedding_api_key: str = "not-needed"
embedding_model: str = "qwen3-embedding-0.6b"
embedding_batch_size: int = 32
bm25_index_dir: str = ".ebook_search_bm25"
bm25_refresh_delay_seconds: int = 60
def load_rerank_config() -> RerankConfig:
"""Load reranker config from environment variables."""
return RerankConfig(
enabled=getenv_bool("EBOOK_SEARCH_RERANK_ENABLED", default=False),
base_url=getenv("EBOOK_SEARCH_RERANK_BASE_URL", "http://192.168.90.25:8001"),
model=getenv("EBOOK_SEARCH_RERANK_MODEL", "qwen3-reranker-06b"),
candidates=getenv_int("EBOOK_SEARCH_RERANK_CANDIDATES", default=24),
timeout_seconds=float(getenv_int("EBOOK_SEARCH_RERANK_TIMEOUT_SECONDS", default=30)),
)
def load_config() -> EbookSearchConfig:
"""Load EPUB search config from environment variables."""
return EbookSearchConfig(
rerank=load_rerank_config(),
top_k=getenv_int("EBOOK_SEARCH_TOP_K", default=12),
library_paths=library_paths_from_env(),
vllm_base_url=getenv("EBOOK_SEARCH_VLLM_BASE_URL", "https://ollama.com/v1"),
vllm_api_key=getenv("EBOOK_SEARCH_VLLM_API_KEY") or getenv("OLLAMA_API_KEY") or "not-needed",
chat_model=getenv("EBOOK_SEARCH_CHAT_MODEL", "deepseek-v4-flash"),
answer_enabled=getenv_bool("EBOOK_SEARCH_ANSWER_ENABLED", default=True),
embedding_base_url=getenv("EBOOK_SEARCH_EMBEDDING_BASE_URL", "http://192.168.90.25:8000/v1"),
embedding_api_key=getenv("EBOOK_SEARCH_EMBEDDING_API_KEY", "not-needed"),
embedding_model=normalize_embedding_model(),
embedding_batch_size=getenv_int("EBOOK_SEARCH_EMBEDDING_BATCH_SIZE", default=32),
bm25_index_dir=getenv("EBOOK_SEARCH_BM25_INDEX_DIR", ".ebook_search_bm25"),
bm25_refresh_delay_seconds=getenv_int("EBOOK_SEARCH_BM25_REFRESH_DELAY_SECONDS", default=60),
)
def normalize_embedding_model(default: str = "qwen3-embedding-0.6b") -> str:
"""Normalize supported embedding aliases to provider model names."""
aliases = { aliases = {
"Qwen3-Embedding-0.6B": "qwen3-embedding-0.6b", "Qwen3-Embedding-0.6B": "qwen3-embedding-0.6b",
"Qwen3-Embedding-4B": "qwen3-embedding-4b", "Qwen3-Embedding-4B": "qwen3-embedding-4b",
@@ -25,102 +98,20 @@ def normalize_embedding_alias(model: str) -> str:
"qwen3-embedding-4b": "qwen3-embedding-4b", "qwen3-embedding-4b": "qwen3-embedding-4b",
"qwen3-embedding-8b": "qwen3-embedding-8b", "qwen3-embedding-8b": "qwen3-embedding-8b",
} }
model = getenv("EBOOK_SEARCH_EMBEDDING_MODEL", default)
standard_model = aliases.get(model) standard_model = aliases.get(model)
if standard_model is None: if standard_model is None:
error = f"Embedding model {model} is not supported. Supported models are {aliases.keys()}" error = f"Embedding model {model} is not supported. Supported models are {aliases.keys()}"
raise ValueError(error) raise ValueError(error)
return standard_model return standard_model
def normalize_embedding_model(default: str = "qwen3-embedding-0.6b") -> str: def library_paths_from_env() -> tuple[str, ...]:
"""Normalize the configured embedding alias to its provider model name.""" """Read configured EPUB library paths from the environment."""
return normalize_embedding_alias(getenv("EBOOK_SEARCH_EMBEDDING_MODEL", default)) value = getenv("EBOOK_SEARCH_LIBRARY_PATHS")
if value is None:
return ()
class RerankConfig(BaseSettings):
"""vLLM reranker settings."""
model_config = SettingsConfigDict(env_prefix="EBOOK_SEARCH_RERANK_", frozen=True, protected_namespaces=())
enabled: bool = True
base_url: str = "http://192.168.90.25:8001"
model: str = "qwen3-reranker-06b"
candidates: int = 24
timeout_seconds: float = 30.0
score_weight: float = 0.7
hybrid_weight: float = 0.3
class EbookSearchConfig(BaseSettings):
"""Runtime settings for EPUB search."""
model_config = SettingsConfigDict(
env_prefix="EBOOK_SEARCH_",
frozen=True,
populate_by_name=True,
protected_namespaces=(),
)
rerank: RerankConfig = Field(default_factory=RerankConfig)
top_k: int = 12
library_paths: Annotated[tuple[str, ...], NoDecode] = ()
chunk_tokens: int = 700
chunk_overlap: int = 100
vllm_base_url: str = "https://ollama.com/v1"
vllm_api_key: str = Field(
default="not-needed",
validation_alias=AliasChoices("EBOOK_SEARCH_VLLM_API_KEY", "OLLAMA_API_KEY"),
)
chat_model: str = "deepseek-v4-flash"
answer_enabled: bool = True
embedding_base_url: str = "http://192.168.90.25:8000/v1"
embedding_api_key: str = "not-needed"
embedding_model: str = "qwen3-embedding-0.6b"
embedding_batch_size: int = 32
embedding_timeout_seconds: float = 60.0
chat_timeout_seconds: float = 60.0
vector_candidate_multiplier: int = 4
bm25_candidate_limit: int = 120
rrf_rank_constant: int = 60
min_retrieval_confidence: float = 0.0
validate_citations_enabled: bool = True
bm25_index_dir: str = ".ebook_search_bm25"
bm25_refresh_delay_seconds: int = 60
@field_validator("library_paths", mode="before")
@classmethod
def split_library_paths(cls, value: object) -> object:
"""Split a colon-separated library path string into a tuple of paths."""
if isinstance(value, str):
return tuple(path for path in value.split(":") if path) return tuple(path for path in value.split(":") if path)
return value
@field_validator("embedding_model")
@classmethod
def normalize_embedding(cls, value: str) -> str:
"""Normalize the configured embedding alias to its provider model name."""
return normalize_embedding_alias(value)
@model_validator(mode="after")
def validate_runtime_consistency(self) -> Self:
"""Reject configurations that cannot serve the features they enable."""
if not self.embedding_base_url.strip():
msg = "embedding_base_url must be set"
raise ValueError(msg)
if self.answer_enabled and (not self.vllm_base_url.strip() or not self.chat_model.strip()):
msg = "answer_enabled requires vllm_base_url and chat_model to be set"
raise ValueError(msg)
if self.rerank.enabled and not self.rerank.base_url.strip():
msg = "rerank.enabled requires rerank.base_url to be set"
raise ValueError(msg)
return self
def load_rerank_config() -> RerankConfig:
"""Load reranker config from environment variables."""
return RerankConfig()
def load_config() -> EbookSearchConfig:
"""Load EPUB search config from environment variables."""
return EbookSearchConfig()
-1
View File
@@ -1 +0,0 @@
"""Offline evaluation tooling for the ebook search pipeline."""
@@ -1,71 +0,0 @@
{"query": "Who is Damien Montgomery and how does he become a Jump Mage?", "answer": null, "answerable": true, "relevant_sources": ["Starship's Mage"]}
{"query": "What is a Rune Wright and why is Damien so rare?", "answer": null, "answerable": true, "relevant_sources": ["Starship's Mage"]}
{"query": "How does jump magic let starships travel faster than light?", "answer": null, "answerable": true, "relevant_sources": ["Starship's Mage"]}
{"query": "What is the role of the Mage-King of Mars in the Protectorate?", "answer": null, "answerable": true, "relevant_sources": ["Starship's Mage"]}
{"query": "What happened aboard the Blue Jay in the first Starship's Mage book?", "answer": null, "answerable": true, "relevant_sources": ["Starship's Mage"]}
{"query": "Who is Captain David Rice?", "answer": null, "answerable": true, "relevant_sources": ["Starship's Mage"]}
{"query": "How are amplifiers and simulacrums used to power a ship's jump?", "answer": null, "answerable": true, "relevant_sources": ["Starship's Mage"]}
{"query": "What duties does a Hand of the Mage-King carry out?", "answer": null, "answerable": true, "relevant_sources": ["Starship's Mage"]}
{"query": "Explain the structure of the Royal Martian Navy.", "answer": null, "answerable": true, "relevant_sources": ["Starship's Mage"]}
{"query": "How do mages carve runes to enchant a starship?", "answer": null, "answerable": true, "relevant_sources": ["Starship's Mage"]}
{"query": "What threat do the Legatan rebels pose to the Protectorate?", "answer": null, "answerable": true, "relevant_sources": ["Starship's Mage"]}
{"query": "How does Damien handle his first command?", "answer": null, "answerable": true, "relevant_sources": ["Starship's Mage"]}
{"query": "What is the significance of the simulacrum on a jump ship?", "answer": null, "answerable": true, "relevant_sources": ["Starship's Mage"]}
{"query": "Describe a mage duel in the Starship's Mage series.", "answer": null, "answerable": true, "relevant_sources": ["Starship's Mage"]}
{"query": "What moral conflicts does Damien face as a Hand of the Mage-King?", "answer": null, "answerable": true, "relevant_sources": ["Starship's Mage"]}
{"query": "How does the Protectorate keep peace among its member worlds?", "answer": null, "answerable": true, "relevant_sources": ["Starship's Mage"]}
{"query": "Who is the Keeper of Oaths and how does Damien work with them?", "answer": null, "answerable": true, "relevant_sources": ["Starship's Mage"]}
{"query": "What event is known as the Onset and how does it change the world?", "answer": null, "answerable": true, "relevant_sources": ["The Onset"]}
{"query": "Who is the main character at the start of the Onset?", "answer": null, "answerable": true, "relevant_sources": ["The Onset"]}
{"query": "How do survivors adapt after the Onset begins?", "answer": null, "answerable": true, "relevant_sources": ["The Onset"]}
{"query": "What new abilities emerge during the Onset?", "answer": null, "answerable": true, "relevant_sources": ["The Onset"]}
{"query": "Describe the primary antagonist in the Onset series.", "answer": null, "answerable": true, "relevant_sources": ["The Onset"]}
{"query": "How does society collapse and reorganize after the Onset?", "answer": null, "answerable": true, "relevant_sources": ["The Onset"]}
{"query": "What factions form in the aftermath of the Onset?", "answer": null, "answerable": true, "relevant_sources": ["The Onset"]}
{"query": "How does the protagonist gain power throughout the Onset?", "answer": null, "answerable": true, "relevant_sources": ["The Onset"]}
{"query": "What is the cause or origin of the Onset?", "answer": null, "answerable": true, "relevant_sources": ["The Onset"]}
{"query": "Describe an early survival challenge faced after the Onset.", "answer": null, "answerable": true, "relevant_sources": ["The Onset"]}
{"query": "How do the characters defend their stronghold during the Onset?", "answer": null, "answerable": true, "relevant_sources": ["The Onset"]}
{"query": "What relationships drive the protagonist's choices in the Onset?", "answer": null, "answerable": true, "relevant_sources": ["The Onset"]}
{"query": "How does the Onset escalate by the end of the first book?", "answer": null, "answerable": true, "relevant_sources": ["The Onset"]}
{"query": "What mysteries about the Onset remain unresolved?", "answer": null, "answerable": true, "relevant_sources": ["The Onset"]}
{"query": "How do the rules of the world change once the Onset takes hold?", "answer": null, "answerable": true, "relevant_sources": ["The Onset"]}
{"query": "What weapons or tactics work best against the threats of the Onset?", "answer": null, "answerable": true, "relevant_sources": ["The Onset"]}
{"query": "How does Bob Johansson become a von Neumann probe?", "answer": null, "answerable": true, "relevant_sources": ["We Are Legion (We Are Bob)"]}
{"query": "What is a replicant and why do Bob's copies have different personalities?", "answer": null, "answerable": true, "relevant_sources": ["We Are Legion (We Are Bob)"]}
{"query": "Who are Riker, Homer, and Bill among the Bob clones?", "answer": null, "answerable": true, "relevant_sources": ["We Are Legion (We Are Bob)"]}
{"query": "What is GUPPI and how does Bob use it?", "answer": null, "answerable": true, "relevant_sources": ["We Are Legion (We Are Bob)"]}
{"query": "Describe the threat posed by the Others.", "answer": null, "answerable": true, "relevant_sources": ["We Are Legion (We Are Bob)"]}
{"query": "How does Bob protect and uplift the Deltans?", "answer": null, "answerable": true, "relevant_sources": ["We Are Legion (We Are Bob)"]}
{"query": "Why do the replicants drift apart in personality over time?", "answer": null, "answerable": true, "relevant_sources": ["We Are Legion (We Are Bob)"]}
{"query": "What is the role of FAITH and the Brazilian Empire on Earth?", "answer": null, "answerable": true, "relevant_sources": ["We Are Legion (We Are Bob)"]}
{"query": "How does subspace communication work for the Bobs?", "answer": null, "answerable": true, "relevant_sources": ["We Are Legion (We Are Bob)"]}
{"query": "What happens to Bender after he goes missing?", "answer": null, "answerable": true, "relevant_sources": ["We Are Legion (We Are Bob)"]}
{"query": "How do the Bobs build self-replicating probes across the galaxy?", "answer": null, "answerable": true, "relevant_sources": ["We Are Legion (We Are Bob)"]}
{"query": "How does Bob evacuate humanity after Earth becomes uninhabitable?", "answer": null, "answerable": true, "relevant_sources": ["We Are Legion (We Are Bob)"]}
{"query": "Describe the conflict between different factions of Bobs.", "answer": null, "answerable": true, "relevant_sources": ["We Are Legion (We Are Bob)"]}
{"query": "What ethical dilemmas does Bob face when interfering with primitive species?", "answer": null, "answerable": true, "relevant_sources": ["We Are Legion (We Are Bob)"]}
{"query": "How does the original Bob differ from later generations of clones?", "answer": null, "answerable": true, "relevant_sources": ["We Are Legion (We Are Bob)"]}
{"query": "How do the Bobs defeat the Others' system-harvesting fleets?", "answer": null, "answerable": true, "relevant_sources": ["We Are Legion (We Are Bob)"]}
{"query": "What role does Howard play in the human colonies?", "answer": null, "answerable": true, "relevant_sources": ["We Are Legion (We Are Bob)"]}
// querys not it the dataset
{"query": "How does Frodo destroy the One Ring in The Lord of the Rings?", "answer": null, "answerable": false, "relevant_sources": []}
{"query": "Who killed Dumbledore in Harry Potter and the Half-Blood Prince?", "answer": null, "answerable": false, "relevant_sources": []}
{"query": "What house does Tyrion Lannister belong to in A Game of Thrones?", "answer": null, "answerable": false, "relevant_sources": []}
{"query": "How does Paul Atreides control the spice on Arrakis in Dune?", "answer": null, "answerable": false, "relevant_sources": []}
{"query": "What does the green light at the end of the dock mean in The Great Gatsby?", "answer": null, "answerable": false, "relevant_sources": []}
{"query": "Why does Hester Prynne wear a scarlet letter?", "answer": null, "answerable": false, "relevant_sources": []}
{"query": "What does the white whale represent in Moby-Dick?", "answer": null, "answerable": false, "relevant_sources": []}
{"query": "How does Elizabeth Bennet's view of Mr. Darcy change in Pride and Prejudice?", "answer": null, "answerable": false, "relevant_sources": []}
{"query": "What crime does Raskolnikov commit in Crime and Punishment?", "answer": null, "answerable": false, "relevant_sources": []}
{"query": "How does Katniss volunteer for the Hunger Games?", "answer": null, "answerable": false, "relevant_sources": []}
{"query": "What is Winston Smith's job in Nineteen Eighty-Four?", "answer": null, "answerable": false, "relevant_sources": []}
{"query": "Who is Atticus Finch defending in To Kill a Mockingbird?", "answer": null, "answerable": false, "relevant_sources": []}
{"query": "What is the capital of Australia?", "answer": null, "answerable": false, "relevant_sources": []}
{"query": "How do I bake a sourdough loaf from scratch?", "answer": null, "answerable": false, "relevant_sources": []}
{"query": "Explain how photosynthesis converts sunlight into energy.", "answer": null, "answerable": false, "relevant_sources": []}
{"query": "What were the main causes of World War I?", "answer": null, "answerable": false, "relevant_sources": []}
{"query": "How does compound interest work?", "answer": null, "answerable": false, "relevant_sources": []}
{"query": "How do I change a flat tire on a car?", "answer": null, "answerable": false, "relevant_sources": []}
{"query": "What is the boiling point of water at sea level?", "answer": null, "answerable": false, "relevant_sources": []}
{"query": "What is the recommended daily intake of vitamin D?", "answer": null, "answerable": false, "relevant_sources": []}
-47
View File
@@ -1,47 +0,0 @@
"""Shared query set loading for evaluation and load testing.
Each JSONL record has a ``query`` and an optional reference ``answer``. ``answerable``
marks whether the query should be answerable from the library (false for out-of-corpus
"garbage" queries used to test the refusal path). Relevance for retrieval metrics is
labeled at source (book) granularity in ``relevant_sources``; source titles must match
``ebook_source.title`` values for the indexed corpus.
"""
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
DEFAULT_QUERIES_PATH = Path(__file__).parent / "data" / "queries.jsonl"
@dataclass(frozen=True)
class GoldQuery:
"""One labeled query shared by the eval and load-test tools."""
query: str
answer: str | None
answerable: bool
relevant_sources: tuple[str, ...]
relevant_substrings: tuple[str, ...]
def load_gold_queries(path: Path = DEFAULT_QUERIES_PATH) -> list[GoldQuery]:
"""Load labeled queries from a JSONL file. Blank lines and ``//`` comment lines are skipped."""
queries: list[GoldQuery] = []
for line in path.read_text(encoding="utf-8").splitlines():
stripped = line.strip()
if not stripped or stripped.startswith("//"):
continue
record = json.loads(stripped)
queries.append(
GoldQuery(
query=str(record["query"]),
answer=record.get("answer"),
answerable=bool(record.get("answerable", True)),
relevant_sources=tuple(record.get("relevant_sources", ())),
relevant_substrings=tuple(record.get("relevant_substrings", ())),
)
)
return queries
-57
View File
@@ -1,57 +0,0 @@
"""Serve-time output guardrails for retrieval confidence and answer citations."""
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from python.ebook_search.config import EbookSearchConfig
from python.ebook_search.search import SearchResult
CITATION_RE = re.compile(r"\[(\d+)\]")
def retrieval_confidence(results: list[SearchResult]) -> float:
"""Return the strongest interpretable relevance signal of the top result.
Reciprocal-rank-fusion scores are rank-based and not comparable across queries,
so the rerank relevance score is preferred, then vector cosine similarity, then
the final score.
"""
if not results:
return 0.0
top = results[0]
if top.rerank_score is not None:
return top.rerank_score
if top.vector_score is not None:
return top.vector_score
return top.score
def is_confident(results: list[SearchResult], config: EbookSearchConfig) -> bool:
"""Return whether top-result confidence meets the configured threshold."""
return retrieval_confidence(results) >= config.min_retrieval_confidence
@dataclass(frozen=True)
class CitationReport:
"""Validation summary for bracketed citation markers in a generated answer."""
cited: tuple[int, ...]
invalid: tuple[int, ...]
grounded: bool
def validate_citations(answer: str, result_count: int) -> CitationReport:
"""Validate bracketed citation markers against the number of shown sources.
A marker is valid when it points to a returned source (``1..result_count``).
``grounded`` is true when the answer cites at least one valid source.
"""
markers = sorted({int(match.group(1)) for match in CITATION_RE.finditer(answer)})
valid = range(1, result_count + 1)
cited = tuple(marker for marker in markers if marker in valid)
invalid = tuple(marker for marker in markers if marker not in valid)
return CitationReport(cited=cited, invalid=invalid, grounded=bool(cited))
+5 -10
View File
@@ -79,17 +79,17 @@ def ingest_configured_paths(session: Session, config: EbookSearchConfig) -> int:
path = Path(library_path).expanduser() path = Path(library_path).expanduser()
logger.info("ebook_ingest_path_start path=%s", path) logger.info("ebook_ingest_path_start path=%s", path)
if path.is_file() and path.suffix.lower() == ".epub": if path.is_file() and path.suffix.lower() == ".epub":
count += int(ingest_file(session, path, config)) count += int(ingest_file(session, path))
elif path.is_dir(): elif path.is_dir():
for epub_path in sorted(path.rglob("*.epub")): for epub_path in sorted(path.rglob("*.epub")):
count += int(ingest_file(session, epub_path, config)) count += int(ingest_file(session, epub_path))
else: else:
logger.warning("ebook_ingest_path_missing path=%s", path) logger.warning("ebook_ingest_path_missing path=%s", path)
logger.info("ebook_ingest_paths_complete changed_files=%s configured_paths=%s", count, len(config.library_paths)) logger.info("ebook_ingest_paths_complete changed_files=%s configured_paths=%s", count, len(config.library_paths))
return count return count
def ingest_file(session: Session, path: Path, config: EbookSearchConfig) -> bool: def ingest_file(session: Session, path: Path) -> bool:
"""Ingest one EPUB file. Return True when the database changed.""" """Ingest one EPUB file. Return True when the database changed."""
resolved_path = path.expanduser().resolve() resolved_path = path.expanduser().resolve()
logger.info("ebook_ingest_file_start path=%s", resolved_path) logger.info("ebook_ingest_file_start path=%s", resolved_path)
@@ -134,7 +134,7 @@ def ingest_file(session: Session, path: Path, config: EbookSearchConfig) -> bool
) )
session.add(chapter) session.add(chapter)
session.flush() session.flush()
chunk_index = add_chapter_chunks(session, source, chapter, parsed_chapter, chunk_index, config) chunk_index = add_chapter_chunks(session, source, chapter, parsed_chapter, chunk_index)
session.flush() session.flush()
logger.info( logger.info(
@@ -160,15 +160,10 @@ def add_chapter_chunks(
chapter: EbookChapter, chapter: EbookChapter,
parsed_chapter: ParsedChapter, parsed_chapter: ParsedChapter,
chunk_index: int, chunk_index: int,
config: EbookSearchConfig,
) -> int: ) -> int:
"""Add chunk rows for one parsed chapter and return the next chunk index.""" """Add chunk rows for one parsed chapter and return the next chunk index."""
page_label = parsed_chapter.page_labels[0] if parsed_chapter.page_labels else None page_label = parsed_chapter.page_labels[0] if parsed_chapter.page_labels else None
for text_chunk in chunk_text( for text_chunk in chunk_text(parsed_chapter.text):
parsed_chapter.text,
chunk_tokens=config.chunk_tokens,
overlap_tokens=config.chunk_overlap,
):
session.add( session.add(
EbookChunk( EbookChunk(
source_id=source.id, source_id=source.id,
+2 -32
View File
@@ -29,7 +29,7 @@ def request_embeddings(texts: Sequence[str], config: EbookSearchConfig) -> list[
f"{config.embedding_base_url.rstrip('/')}/embeddings", f"{config.embedding_base_url.rstrip('/')}/embeddings",
headers=auth_headers(config.embedding_api_key), headers=auth_headers(config.embedding_api_key),
json={"model": config.embedding_model, "input": list(texts)}, json={"model": config.embedding_model, "input": list(texts)},
timeout=config.embedding_timeout_seconds, timeout=60,
) )
response.raise_for_status() response.raise_for_status()
return embedding_vectors_from_response(response.json()) return embedding_vectors_from_response(response.json())
@@ -44,36 +44,6 @@ def request_embeddings(texts: Sequence[str], config: EbookSearchConfig) -> list[
raise RuntimeError(msg) from error raise RuntimeError(msg) from error
def check_embedding_endpoint(config: EbookSearchConfig, *, timeout_seconds: float = 5.0) -> bool:
"""Return whether the configured embedding endpoint answers a model listing."""
try:
response = httpx.get(
f"{config.embedding_base_url.rstrip('/')}/models",
headers=auth_headers(config.embedding_api_key),
timeout=timeout_seconds,
)
response.raise_for_status()
except httpx.HTTPError as error:
logger.warning("ebook_embedding_endpoint_unreachable base_url=%s error=%s", config.embedding_base_url, error)
return False
return True
def check_chat_endpoint(config: EbookSearchConfig, *, timeout_seconds: float = 5.0) -> bool:
"""Return whether the configured chat (answering) endpoint answers a model listing."""
try:
response = httpx.get(
f"{config.vllm_base_url.rstrip('/')}/models",
headers=auth_headers(config.vllm_api_key),
timeout=timeout_seconds,
)
response.raise_for_status()
except httpx.HTTPError as error:
logger.warning("ebook_chat_endpoint_unreachable base_url=%s error=%s", config.vllm_base_url, error)
return False
return True
def embedding_vectors_from_response(body: object) -> list[list[float]]: def embedding_vectors_from_response(body: object) -> list[list[float]]:
"""Extract embedding vectors from an OpenAI-compatible embedding response.""" """Extract embedding vectors from an OpenAI-compatible embedding response."""
if not isinstance(body, dict): if not isinstance(body, dict):
@@ -136,7 +106,7 @@ def request_chat_completion(
"messages": list(messages), "messages": list(messages),
"temperature": 0, "temperature": 0,
}, },
timeout=config.chat_timeout_seconds, timeout=60,
) )
response.raise_for_status() response.raise_for_status()
return chat_content_from_response(response.json()) return chat_content_from_response(response.json())
-218
View File
@@ -1,218 +0,0 @@
"""Load test for the EPUB search service.
Drives ``POST /search`` on a running server at a configurable concurrency and reports
latency percentiles, throughput, and HTTP status distribution. Queries are drawn from
the shared JSONL set (see ``eval/data/queries.jsonl``) that the eval also uses, so load
and evaluation exercise the same questions. Answer generation and reranking happen
server-side, so this exercises the full retrieval pipeline.
"""
from __future__ import annotations
import asyncio
import logging
import math
import random
import statistics
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Annotated
import httpx
import typer
from python.common import configure_logger
from python.ebook_search.eval.dataset import DEFAULT_QUERIES_PATH, load_gold_queries
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class RequestResult:
"""Outcome of a single search request."""
status_code: int
latency_ms: float
ok: bool
@dataclass(frozen=True)
class LoadSummary:
"""Aggregate results of a load test run."""
total: int
successes: int
failures: int
wall_seconds: float
throughput_rps: float
latency_p50_ms: float
latency_p90_ms: float
latency_p95_ms: float
latency_p99_ms: float
latency_mean_ms: float
latency_max_ms: float
status_counts: dict[int, int]
def load_queries(queries_file: str | None) -> list[str]:
"""Return the query strings from the shared JSONL set (or a custom JSONL file)."""
path = Path(queries_file) if queries_file else DEFAULT_QUERIES_PATH
queries = [gold.query for gold in load_gold_queries(path)]
if not queries:
msg = f"No queries found in {path}"
raise typer.BadParameter(msg)
return queries
def pick_query(queries: list[str]) -> str:
"""Return a uniformly random query from the pool (not a security context)."""
return random.choice(queries) # noqa: S311 load-test query sampling is not security-sensitive
def percentile(values_sorted: list[float], pct: float) -> float:
"""Return the linearly-interpolated percentile of a sorted list."""
if not values_sorted:
return 0.0
rank = (pct / 100) * (len(values_sorted) - 1)
low = math.floor(rank)
high = math.ceil(rank)
if low == high:
return values_sorted[low]
return values_sorted[low] + (values_sorted[high] - values_sorted[low]) * (rank - low)
def summarize(results: list[RequestResult], wall_seconds: float) -> LoadSummary:
"""Aggregate per-request results into a load summary."""
latencies = sorted(result.latency_ms for result in results)
successes = sum(1 for result in results if result.ok)
status_counts: dict[int, int] = {}
for result in results:
status_counts[result.status_code] = status_counts.get(result.status_code, 0) + 1
return LoadSummary(
total=len(results),
successes=successes,
failures=len(results) - successes,
wall_seconds=wall_seconds,
throughput_rps=len(results) / wall_seconds if wall_seconds > 0 else 0.0,
latency_p50_ms=percentile(latencies, 50),
latency_p90_ms=percentile(latencies, 90),
latency_p95_ms=percentile(latencies, 95),
latency_p99_ms=percentile(latencies, 99),
latency_mean_ms=statistics.fmean(latencies) if latencies else 0.0,
latency_max_ms=latencies[-1] if latencies else 0.0,
status_counts=status_counts,
)
async def send_search(client: httpx.AsyncClient, query: str, *, rerank: bool) -> RequestResult:
"""Send one search request and record its status and latency."""
data = {"query": query, "rerank": "true"} if rerank else {"query": query}
start = time.perf_counter()
try:
response = await client.post("/search", data=data)
except httpx.HTTPError as error:
logger.warning("ebook_loadtest_request_failed error=%s", error)
return RequestResult(status_code=0, latency_ms=(time.perf_counter() - start) * 1000, ok=False)
return RequestResult(
status_code=response.status_code,
latency_ms=(time.perf_counter() - start) * 1000,
ok=response.is_success,
)
async def worker(
client: httpx.AsyncClient,
queue: asyncio.Queue[str],
results: list[RequestResult],
*,
rerank: bool,
) -> None:
"""Pull queries off the queue and send requests until it is empty."""
while True:
try:
query = queue.get_nowait()
except asyncio.QueueEmpty:
return
results.append(await send_search(client, query, rerank=rerank))
async def run_load(
*,
base_url: str,
queries: list[str],
request_count: int,
concurrency: int,
rerank: bool,
warmup: int,
timeout_seconds: float,
) -> LoadSummary:
"""Run the load test and return its aggregate summary."""
limits = httpx.Limits(max_connections=concurrency, max_keepalive_connections=concurrency)
async with httpx.AsyncClient(base_url=base_url, timeout=timeout_seconds, limits=limits) as client:
for _ in range(warmup):
await send_search(client, pick_query(queries), rerank=rerank)
queue: asyncio.Queue[str] = asyncio.Queue()
for _ in range(request_count):
queue.put_nowait(pick_query(queries))
results: list[RequestResult] = []
start = time.perf_counter()
workers = [asyncio.create_task(worker(client, queue, results, rerank=rerank)) for _ in range(concurrency)]
await asyncio.gather(*workers)
wall_seconds = time.perf_counter() - start
return summarize(results, wall_seconds)
def print_summary(summary: LoadSummary) -> None:
"""Print the load summary to stdout."""
typer.echo(f"requests={summary.total} successes={summary.successes} failures={summary.failures}")
typer.echo(f"wall={summary.wall_seconds:.2f}s throughput={summary.throughput_rps:.1f} req/s")
typer.echo(
f"latency_ms p50={summary.latency_p50_ms:.1f} p90={summary.latency_p90_ms:.1f} "
f"p95={summary.latency_p95_ms:.1f} p99={summary.latency_p99_ms:.1f} "
f"mean={summary.latency_mean_ms:.1f} max={summary.latency_max_ms:.1f}"
)
status_summary = " ".join(f"{code}={count}" for code, count in sorted(summary.status_counts.items()))
typer.echo(f"status {status_summary}")
def main(
*,
base_url: Annotated[str, typer.Option(help="Base URL of the running service")] = "http://127.0.0.1:8070",
request_count: Annotated[int, typer.Option("--requests", help="Total requests to send")] = 200,
concurrency: Annotated[int, typer.Option(help="Concurrent in-flight requests")] = 10,
rerank: Annotated[bool, typer.Option(help="Request server-side reranking")] = False,
warmup: Annotated[int, typer.Option(help="Warmup requests, not measured")] = 5,
timeout_seconds: Annotated[float, typer.Option("--timeout", help="Per-request timeout seconds")] = 120.0,
queries_file: Annotated[str | None, typer.Option(help="Query JSONL file (defaults to the shared set)")] = None,
log_level: Annotated[str, typer.Option(help="Log level")] = "WARNING",
) -> None:
"""Load test the search endpoint and report latency and throughput."""
configure_logger(log_level)
queries = load_queries(queries_file)
logger.info(
"ebook_loadtest_start base_url=%s requests=%s concurrency=%s rerank=%s queries=%s",
base_url,
request_count,
concurrency,
rerank,
len(queries),
)
summary = asyncio.run(
run_load(
base_url=base_url,
queries=queries,
request_count=request_count,
concurrency=concurrency,
rerank=rerank,
warmup=warmup,
timeout_seconds=timeout_seconds,
)
)
print_summary(summary)
if __name__ == "__main__":
typer.run(main)
+5 -8
View File
@@ -13,6 +13,8 @@ if TYPE_CHECKING:
from python.ebook_search.search import SearchResult from python.ebook_search.search import SearchResult
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
RERANK_SCORE_WEIGHT = 0.7
HYBRID_SCORE_WEIGHT = 0.3
@dataclass(frozen=True) @dataclass(frozen=True)
@@ -39,7 +41,7 @@ def rerank_chunks(query: str, candidates: list[SearchResult], config: RerankConf
( (
replace( replace(
result, result,
score=final_rerank_score(result, scores[result.chunk_id].score, candidates, config), score=final_rerank_score(result, scores[result.chunk_id].score, candidates),
rerank_score=scores[result.chunk_id].score, rerank_score=scores[result.chunk_id].score,
) )
for result in candidates for result in candidates
@@ -108,14 +110,9 @@ def clamp_score(score: float) -> float:
return min(max(score, 0.0), 1.0) return min(max(score, 0.0), 1.0)
def final_rerank_score( def final_rerank_score(result: SearchResult, rerank_score: float, candidates: list[SearchResult]) -> float:
result: SearchResult,
rerank_score: float,
candidates: list[SearchResult],
config: RerankConfig,
) -> float:
"""Combine rerank relevance with normalized hybrid retrieval evidence.""" """Combine rerank relevance with normalized hybrid retrieval evidence."""
return (config.score_weight * rerank_score) + (config.hybrid_weight * normalized_hybrid_score(result, candidates)) return (RERANK_SCORE_WEIGHT * rerank_score) + (HYBRID_SCORE_WEIGHT * normalized_hybrid_score(result, candidates))
def normalized_hybrid_score(result: SearchResult, candidates: list[SearchResult]) -> float: def normalized_hybrid_score(result: SearchResult, candidates: list[SearchResult]) -> float:
+15 -18
View File
@@ -4,7 +4,6 @@ from __future__ import annotations
import logging import logging
import re import re
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, replace from dataclasses import dataclass, replace
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@@ -14,7 +13,6 @@ from sqlalchemy import literal, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from python.ebook_search.bm25_corpus import ( from python.ebook_search.bm25_corpus import (
BM25CorpusUnavailableError,
load_bm25_corpus, load_bm25_corpus,
score_bm25_corpus, score_bm25_corpus,
) )
@@ -36,6 +34,7 @@ if TYPE_CHECKING:
from python.ebook_search.config import EbookSearchConfig from python.ebook_search.config import EbookSearchConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
BM25_CANDIDATE_LIMIT = 120
@dataclass(frozen=True) @dataclass(frozen=True)
@@ -94,11 +93,14 @@ def search_ebooks(
logger.info("ebook_search_start query_length=%s rerank=%s", len(query), rerank) logger.info("ebook_search_start query_length=%s rerank=%s", len(query), rerank)
timings: list[RuntimeStep] = [] timings: list[RuntimeStep] = []
bm25_query, timing = timed_result("BM25 query preparation", retrieval_query_from_text, query)
timings.append(timing)
retrieval, timing = timed_result( retrieval, timing = timed_result(
"Hybrid retrieval", "Hybrid retrieval",
parallel_retrieval, parallel_retrieval,
engine, engine,
query, query,
bm25_query,
config, config,
) )
timings.extend(retrieval.timings) timings.extend(retrieval.timings)
@@ -108,7 +110,6 @@ def search_ebooks(
reciprocal_rank_fusion, reciprocal_rank_fusion,
retrieval.vector_results, retrieval.vector_results,
retrieval.lexical_results, retrieval.lexical_results,
rank_constant=config.rrf_rank_constant,
) )
timings.append(timing) timings.append(timing)
if config.rerank.enabled and rerank: if config.rerank.enabled and rerank:
@@ -132,7 +133,8 @@ def search_ebooks(
def parallel_retrieval( def parallel_retrieval(
engine: Engine, engine: Engine,
query: str, vector_query: str,
bm25_query: str,
config: EbookSearchConfig, config: EbookSearchConfig,
) -> RetrievalResponse: ) -> RetrievalResponse:
"""Run vector and BM25 candidate retrieval concurrently with separate database sessions.""" """Run vector and BM25 candidate retrieval concurrently with separate database sessions."""
@@ -142,14 +144,14 @@ def parallel_retrieval(
"Embedding + vector search", "Embedding + vector search",
vector_candidates, vector_candidates,
engine, engine,
query, vector_query,
config, config,
) )
bm25_future = executor.submit( bm25_future = executor.submit(
timed_result, timed_result,
"BM25 search", "BM25 search",
bm25_candidates, bm25_candidates,
query, bm25_query,
config, config,
) )
vector_results, vector_timing = vector_future.result() vector_results, vector_timing = vector_future.result()
@@ -213,7 +215,7 @@ def vector_candidates(engine: Engine, query: str, config: EbookSearchConfig) ->
raise ValueError(msg) raise ValueError(msg)
embedding = embed_query(query, config) embedding = embed_query(query, config)
limit = max(config.rerank.candidates, config.top_k) * config.vector_candidate_multiplier limit = max(config.rerank.candidates, config.top_k) * 4
embedding_table = get_embedding_table(model.dimension) embedding_table = get_embedding_table(model.dimension)
embedding_param = literal(embedding, type_=Vector(model.dimension)) embedding_param = literal(embedding, type_=Vector(model.dimension))
@@ -250,18 +252,12 @@ def vector_candidates(engine: Engine, query: str, config: EbookSearchConfig) ->
def bm25_candidates(query: str, config: EbookSearchConfig) -> list[SearchResult]: def bm25_candidates(query: str, config: EbookSearchConfig) -> list[SearchResult]:
"""Return BM25-ranked lexical candidates using the persisted corpus.""" """Return BM25-ranked lexical candidates using the persisted corpus."""
try:
corpus = load_bm25_corpus(config) corpus = load_bm25_corpus(config)
except BM25CorpusUnavailableError as error:
logger.warning("ebook_bm25_index_unavailable_skipping error=%s", error)
return []
if not corpus.records: if not corpus.records:
logger.info("ebook_bm25_search_complete corpus=0 candidates=0") logger.info("ebook_bm25_search_complete corpus=0 candidates=0")
return [] return []
bm25_query = retrieval_query_from_text(query) scored_records = score_bm25_corpus(query, corpus, limit=BM25_CANDIDATE_LIMIT)
scored_records = score_bm25_corpus(bm25_query, corpus, limit=config.bm25_candidate_limit)
results = [ results = [
replace(search_result_from_row(record), score=score, vector_score=None, bm25_score=score) replace(search_result_from_row(record), score=score, vector_score=None, bm25_score=score)
for record, score in scored_records for record, score in scored_records
@@ -280,23 +276,24 @@ def bm25_candidates(query: str, config: EbookSearchConfig) -> list[SearchResult]
def reciprocal_rank_fusion( def reciprocal_rank_fusion(
vector_results: list[SearchResult], vector_results: list[SearchResult],
lexical_results: list[SearchResult], lexical_results: list[SearchResult],
rank_constant: int, *,
rank_constant: int = 60,
) -> list[SearchResult]: ) -> list[SearchResult]:
"""Fuse vector and lexical rankings with Reciprocal Rank Fusion.""" """Fuse vector and lexical rankings with Reciprocal Rank Fusion."""
by_chunk: dict[int, SearchResult] = {} by_chunk: dict[int, SearchResult] = {}
scores: defaultdict[int, float] = defaultdict(float) scores: dict[int, float] = {}
vector_scores: dict[int, float] = {} vector_scores: dict[int, float] = {}
bm25_scores: dict[int, float] = {} bm25_scores: dict[int, float] = {}
for rank, result in enumerate(vector_results, start=1): for rank, result in enumerate(vector_results, start=1):
by_chunk.setdefault(result.chunk_id, result) by_chunk.setdefault(result.chunk_id, result)
vector_scores[result.chunk_id] = result.vector_score if result.vector_score is not None else result.score vector_scores[result.chunk_id] = result.vector_score if result.vector_score is not None else result.score
scores[result.chunk_id] += 1 / (rank_constant + rank) scores[result.chunk_id] = scores.get(result.chunk_id, 0.0) + (1 / (rank_constant + rank))
for rank, result in enumerate(lexical_results, start=1): for rank, result in enumerate(lexical_results, start=1):
by_chunk.setdefault(result.chunk_id, result) by_chunk.setdefault(result.chunk_id, result)
bm25_scores[result.chunk_id] = result.bm25_score if result.bm25_score is not None else result.score bm25_scores[result.chunk_id] = result.bm25_score if result.bm25_score is not None else result.score
scores[result.chunk_id] += 1 / (rank_constant + rank) scores[result.chunk_id] = scores.get(result.chunk_id, 0.0) + (1 / (rank_constant + rank))
return sorted( return sorted(
( (
+2 -6
View File
@@ -1,15 +1,11 @@
"""FastAPI dependencies.""" """FastAPI dependencies."""
from __future__ import annotations from collections.abc import Iterator
from typing import Annotated
from typing import TYPE_CHECKING, Annotated
from fastapi import Depends, Request from fastapi import Depends, Request
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
if TYPE_CHECKING:
from collections.abc import Iterator
def get_db(request: Request) -> Iterator[Session]: def get_db(request: Request) -> Iterator[Session]:
"""Get database session from app state.""" """Get database session from app state."""
+1 -5
View File
@@ -1,13 +1,9 @@
"""Zstd response compression middleware.""" """Zstd response compression middleware."""
from compression import zstd from compression import zstd
from typing import TYPE_CHECKING
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import Response
if TYPE_CHECKING:
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import Response
MINIMUM_RESPONSE_SIZE = 500 MINIMUM_RESPONSE_SIZE = 500
+2 -6
View File
@@ -1,10 +1,9 @@
"""FastAPI heater control service.""" """FastAPI heater control service."""
from __future__ import annotations
import logging import logging
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Annotated from typing import Annotated
import typer import typer
import uvicorn import uvicorn
@@ -14,9 +13,6 @@ from python.common import configure_logger
from python.heater.controller import HeaterController from python.heater.controller import HeaterController
from python.heater.models import ActionResult, DeviceConfig, HeaterStatus from python.heater.models import ActionResult, DeviceConfig, HeaterStatus
if TYPE_CHECKING:
from collections.abc import AsyncIterator
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
-1
View File
@@ -262,7 +262,6 @@ def installer(
): ):
run(command, check=True, stdin=test.stdout) run(command, check=True, stdin=test.stdout)
# Fixed mount point for the new system; the installer runs as root on a fresh disk
mnt_dir = "/tmp/nix_install" # noqa: S108 mnt_dir = "/tmp/nix_install" # noqa: S108
Path(mnt_dir).mkdir(parents=True, exist_ok=True) Path(mnt_dir).mkdir(parents=True, exist_ok=True)
+2
View File
@@ -1,7 +1,9 @@
"""ORM package exports.""" """ORM package exports."""
from python.orm.richie.base import RichieBase from python.orm.richie.base import RichieBase
from python.orm.van_inventory.base import VanInventoryBase
__all__ = [ __all__ = [
"RichieBase", "RichieBase",
"VanInventoryBase",
] ]
+4 -4
View File
@@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from sqlalchemy import ForeignKey, UniqueConstraint from sqlalchemy import ForeignKey, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
from python.orm.richie.base import TableBase from python.orm.richie.base import TableBase
@@ -14,7 +14,7 @@ class AudiobookAuthor(TableBase):
__tablename__ = "audiobook_author" __tablename__ = "audiobook_author"
__table_args__ = (UniqueConstraint("name"),) __table_args__ = (UniqueConstraint("name"),)
name: Mapped[str] name: Mapped[str] = mapped_column(String, unique=True)
books: Mapped[list[Audiobook]] = relationship("Audiobook", back_populates="author") books: Mapped[list[Audiobook]] = relationship("Audiobook", back_populates="author")
series: Mapped[list[AudiobookSeries]] = relationship("AudiobookSeries", back_populates="author") series: Mapped[list[AudiobookSeries]] = relationship("AudiobookSeries", back_populates="author")
@@ -26,7 +26,7 @@ class AudiobookSeries(TableBase):
__tablename__ = "audiobook_series" __tablename__ = "audiobook_series"
__table_args__ = (UniqueConstraint("author_id", "name"),) __table_args__ = (UniqueConstraint("author_id", "name"),)
name: Mapped[str] name: Mapped[str] = mapped_column(String)
author_id: Mapped[int] = mapped_column(ForeignKey("main.audiobook_author.id", ondelete="CASCADE")) author_id: Mapped[int] = mapped_column(ForeignKey("main.audiobook_author.id", ondelete="CASCADE"))
author: Mapped[AudiobookAuthor] = relationship("AudiobookAuthor", back_populates="series") author: Mapped[AudiobookAuthor] = relationship("AudiobookAuthor", back_populates="series")
@@ -46,7 +46,7 @@ class Audiobook(TableBase):
), ),
) )
title: Mapped[str] title: Mapped[str] = mapped_column(String)
author_id: Mapped[int] = mapped_column(ForeignKey("main.audiobook_author.id", ondelete="CASCADE")) author_id: Mapped[int] = mapped_column(ForeignKey("main.audiobook_author.id", ondelete="CASCADE"))
series_id: Mapped[int | None] = mapped_column(ForeignKey("main.audiobook_series.id", ondelete="SET NULL")) series_id: Mapped[int | None] = mapped_column(ForeignKey("main.audiobook_series.id", ondelete="SET NULL"))
series_index: Mapped[float] = mapped_column(default=0.0) series_index: Mapped[float] = mapped_column(default=0.0)
+1
View File
@@ -0,0 +1 @@
"""Van inventory database ORM exports."""
+39
View File
@@ -0,0 +1,39 @@
"""Van inventory database ORM base."""
from __future__ import annotations
from datetime import datetime
from sqlalchemy import DateTime, MetaData, func
from sqlalchemy.ext.declarative import AbstractConcreteBase
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from python.orm.common import NAMING_CONVENTION
class VanInventoryBase(DeclarativeBase):
"""Base class for van_inventory database ORM models."""
schema_name = "main"
metadata = MetaData(
schema=schema_name,
naming_convention=NAMING_CONVENTION,
)
class VanTableBase(AbstractConcreteBase, VanInventoryBase):
"""Abstract concrete base for van_inventory tables with IDs and timestamps."""
__abstract__ = True
id: Mapped[int] = mapped_column(primary_key=True)
created: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
)
updated: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
)
+46
View File
@@ -0,0 +1,46 @@
"""Van inventory ORM models."""
from __future__ import annotations
from sqlalchemy import ForeignKey, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from python.orm.van_inventory.base import VanTableBase
class Item(VanTableBase):
"""A food item in the van."""
__tablename__ = "items"
name: Mapped[str] = mapped_column(unique=True)
quantity: Mapped[float] = mapped_column(default=0)
unit: Mapped[str]
category: Mapped[str | None]
meal_ingredients: Mapped[list[MealIngredient]] = relationship(back_populates="item")
class Meal(VanTableBase):
"""A meal that can be made from items in the van."""
__tablename__ = "meals"
name: Mapped[str] = mapped_column(unique=True)
instructions: Mapped[str | None]
ingredients: Mapped[list[MealIngredient]] = relationship(back_populates="meal")
class MealIngredient(VanTableBase):
"""Links a meal to the items it requires, with quantities."""
__tablename__ = "meal_ingredients"
__table_args__ = (UniqueConstraint("meal_id", "item_id"),)
meal_id: Mapped[int] = mapped_column(ForeignKey("meals.id"))
item_id: Mapped[int] = mapped_column(ForeignKey("items.id"))
quantity_needed: Mapped[float]
meal: Mapped[Meal] = relationship(back_populates="ingredients")
item: Mapped[Item] = relationship(back_populates="meal_ingredients")
+1
View File
@@ -0,0 +1 @@
game_data/
+1
View File
@@ -0,0 +1 @@
"""init."""
+675
View File
@@ -0,0 +1,675 @@
"""Base logic for the Splendor game."""
from __future__ import annotations
import itertools
import json
import random
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Literal, Protocol
if TYPE_CHECKING:
from collections.abc import Sequence
from pathlib import Path
GemColor = Literal["white", "blue", "green", "red", "black", "gold"]
GEM_COLORS: tuple[GemColor, ...] = (
"white",
"blue",
"green",
"red",
"black",
"gold",
)
BASE_COLORS: tuple[GemColor, ...] = (
"white",
"blue",
"green",
"red",
"black",
)
GEM_ORDER: list[GemColor] = list(GEM_COLORS)
GEM_INDEX: dict[GemColor, int] = {c: i for i, c in enumerate(GEM_ORDER)}
BASE_INDEX: dict[GemColor, int] = {c: i for i, c in enumerate(BASE_COLORS)}
@dataclass(frozen=True)
class Card:
"""Development card: gives points + a permanent gem discount."""
tier: int
points: int
color: GemColor
cost: dict[GemColor, int]
@dataclass(frozen=True)
class Noble:
"""Noble tile: gives points if you have enough bonuses."""
name: str
points: int
requirements: dict[GemColor, int]
@dataclass
class PlayerState:
"""State of a player in the game."""
strategy: Strategy
tokens: dict[GemColor, int] = field(default_factory=lambda: dict.fromkeys(GEM_COLORS, 0))
discounts: dict[GemColor, int] = field(default_factory=lambda: dict.fromkeys(GEM_COLORS, 0))
cards: list[Card] = field(default_factory=list)
reserved: list[Card] = field(default_factory=list)
nobles: list[Noble] = field(default_factory=list)
card_score: int = 0
noble_score: int = 0
def total_tokens(self) -> int:
"""Total tokens in player's bank."""
return sum(self.tokens.values())
def add_noble(self, noble: Noble) -> None:
"""Add a noble to the player."""
self.nobles.append(noble)
self.noble_score = sum(noble.points for noble in self.nobles)
def add_card(self, card: Card) -> None:
"""Add a card to the player."""
self.cards.append(card)
self.card_score = sum(card.points for card in self.cards)
@property
def score(self) -> int:
"""Total points in player's cards + nobles."""
return self.card_score + self.noble_score
def can_afford(self, card: Card) -> bool:
"""Check if player can afford card, using discounts + gold."""
missing = 0
gold = self.tokens["gold"]
for color, cost in card.cost.items():
missing += max(0, cost - self.discounts.get(color, 0) - self.tokens.get(color, 0))
if missing > gold:
return False
return True
def pay_for_card(self, card: Card) -> dict[GemColor, int]:
"""Pay tokens for card, move card to tableau, return payment for bank."""
if not self.can_afford(card):
msg = f"cannot afford card {card}"
raise ValueError(msg)
payment: dict[GemColor, int] = dict.fromkeys(GEM_COLORS, 0)
gold_available = self.tokens["gold"]
for color in BASE_COLORS:
cost = card.cost.get(color, 0)
effective_cost = max(0, cost - self.discounts.get(color, 0))
use = min(self.tokens[color], effective_cost)
self.tokens[color] -= use
payment[color] += use
remaining = effective_cost - use
if remaining > 0:
use_gold = min(gold_available, remaining)
gold_available -= use_gold
self.tokens["gold"] -= use_gold
payment["gold"] += use_gold
self.add_card(card)
self.discounts[card.color] += 1
return payment
def get_default_starting_tokens(player_count: int) -> dict[GemColor, int]:
"""get_default_starting_tokens."""
token_count = (player_count * player_count - 3 * player_count + 10) // 2
return {
"white": token_count,
"blue": token_count,
"green": token_count,
"red": token_count,
"black": token_count,
"gold": 5,
}
@dataclass
class GameConfig:
"""Game configuration: gems, bank, cards, nobles, etc."""
win_score: int = 15
table_cards_per_tier: int = 4
reserve_limit: int = 3
token_limit: int = 10
turn_limit: int = 1000
minimum_tokens_to_buy_2: int = 4
max_token_take: int = 3
cards: list[Card] = field(default_factory=list)
nobles: list[Noble] = field(default_factory=list)
class GameState:
"""Game state: players, bank, decks, table, available nobles, etc."""
def __init__(
self,
config: GameConfig,
players: list[PlayerState],
bank: dict[GemColor, int],
decks_by_tier: dict[int, list[Card]],
table_by_tier: dict[int, list[Card]],
available_nobles: list[Noble],
) -> None:
"""Game state."""
self.config = config
self.players = players
self.bank = bank
self.decks_by_tier = decks_by_tier
self.table_by_tier = table_by_tier
self.available_nobles = available_nobles
self.noble_min_requirements = 0
self.get_noble_min_requirements()
self.current_player_index = 0
self.finished = False
def get_noble_min_requirements(self) -> None:
"""Find the minimum requirement for all available nobles."""
test = 0
for noble in self.available_nobles:
test = max(test, min(foo for foo in noble.requirements.values()))
self.noble_min_requirements = test
def next_player(self) -> None:
"""Advance to the next player."""
self.current_player_index = (self.current_player_index + 1) % len(self.players)
@property
def current_player(self) -> PlayerState:
"""Current player."""
return self.players[self.current_player_index]
def refill_table(self) -> None:
"""Refill face-up cards from decks."""
for tier, deck in self.decks_by_tier.items():
table = self.table_by_tier[tier]
while len(table) < self.config.table_cards_per_tier and deck:
table.append(deck.pop())
def check_winner_simple(self) -> PlayerState | None:
"""Simplified: end immediately when someone hits win_score."""
eligible = [player for player in self.players if player.score >= self.config.win_score]
if not eligible:
return None
eligible.sort(
key=lambda p: (p.score, -len(p.cards)),
reverse=True,
)
self.finished = True
return eligible[0]
class Action:
"""Marker protocol for actions."""
@dataclass
class TakeDifferent(Action):
"""Take up to 3 different gem colors."""
colors: list[GemColor]
@dataclass
class TakeDouble(Action):
"""Take two of the same color."""
color: GemColor
@dataclass
class BuyCard(Action):
"""Buy a face-up card."""
tier: int
index: int
@dataclass
class BuyCardReserved(Action):
"""Buy a face-up card."""
index: int
@dataclass
class ReserveCard(Action):
"""Reserve a face-up card."""
tier: int
index: int | None = None
from_deck: bool = False
class Strategy(Protocol):
"""Implement this to make a bot or human controller."""
def __init__(self, name: str) -> None:
"""Initialize a strategy."""
self.name = name
def choose_action(self, game: GameState, player: PlayerState) -> Action | None:
"""Return an Action, or None to concede/end."""
raise NotImplementedError
def choose_discard(
self,
game: GameState, # noqa: ARG002
player: PlayerState,
excess: int,
) -> dict[GemColor, int]:
"""Called if player has more than token_limit tokens after an action.
Default: naive auto-discard.
"""
return auto_discard_tokens(player, excess)
def choose_noble(
self,
game: GameState, # noqa: ARG002
player: PlayerState, # noqa: ARG002
nobles: list[Noble],
) -> Noble:
"""Called if player qualifies for multiple nobles. Default: first."""
return nobles[0]
def auto_discard_tokens(player: PlayerState, excess: int) -> dict[GemColor, int]:
"""Very dumb discard logic: discard from colors you have the most of."""
to_discard: dict[GemColor, int] = dict.fromkeys(GEM_COLORS, 0)
remaining = excess
while remaining > 0:
color = max(player.tokens, key=lambda c: player.tokens[c])
if player.tokens[color] == 0:
break
player.tokens[color] -= 1
to_discard[color] += 1
remaining -= 1
return to_discard
def enforce_token_limit(
game: GameState,
strategy: Strategy,
player: PlayerState,
) -> None:
"""If player has more than token_limit tokens, force discards."""
limit = game.config.token_limit
total = player.total_tokens()
if total <= limit:
return
excess = total - limit
discards = strategy.choose_discard(game, player, excess)
for color, amount in discards.items():
available = player.tokens[color]
to_remove = min(amount, available)
if to_remove <= 0:
continue
player.tokens[color] -= to_remove
game.bank[color] += to_remove
remaining = player.total_tokens() - limit
if remaining > 0:
auto = auto_discard_tokens(player, remaining)
for color, amount in auto.items():
game.bank[color] += amount
def _check_nobles_for_player(player: PlayerState, noble: Noble) -> bool:
# this rule is slower
for color, cost in noble.requirements.items(): # noqa: SIM110
if player.discounts[color] < cost:
return False
return True
def check_nobles_for_player(
game: GameState,
strategy: Strategy,
player: PlayerState,
) -> None:
"""Award at most one noble to player if they qualify."""
if game.noble_min_requirements > max(player.discounts.values()):
return
candidates = [noble for noble in game.available_nobles if _check_nobles_for_player(player, noble)]
if not candidates:
return
chosen = candidates[0] if len(candidates) == 1 else strategy.choose_noble(game, player, candidates)
if chosen not in game.available_nobles:
return
game.available_nobles.remove(chosen)
game.get_noble_min_requirements()
player.add_noble(chosen)
def apply_take_different(game: GameState, strategy: Strategy, action: TakeDifferent) -> None:
"""Mutate game state according to action."""
player = game.current_player
colors = [color for color in action.colors if color in BASE_COLORS and game.bank[color] > 0]
if not (1 <= len(colors) <= game.config.max_token_take):
return
for color in colors:
game.bank[color] -= 1
player.tokens[color] += 1
enforce_token_limit(game, strategy, player)
def apply_take_double(game: GameState, strategy: Strategy, action: TakeDouble) -> None:
"""Mutate game state according to action."""
player = game.current_player
color = action.color
if color not in BASE_COLORS:
return
if game.bank[color] < game.config.minimum_tokens_to_buy_2:
return
game.bank[color] -= 2
player.tokens[color] += 2
enforce_token_limit(game, strategy, player)
def apply_buy_card(game: GameState, _strategy: Strategy, action: BuyCard) -> None:
"""Mutate game state according to action."""
player = game.current_player
row = game.table_by_tier.get(action.tier)
if row is None or not (0 <= action.index < len(row)):
return
card = row[action.index]
if not player.can_afford(card):
return
row.pop(action.index)
payment = player.pay_for_card(card)
for color, amount in payment.items():
game.bank[color] += amount
game.refill_table()
def apply_buy_card_reserved(game: GameState, _strategy: Strategy, action: BuyCardReserved) -> None:
"""Mutate game state according to action."""
player = game.current_player
if not (0 <= action.index < len(player.reserved)):
return
card = player.reserved[action.index]
if not player.can_afford(card):
return
player.reserved.pop(action.index)
payment = player.pay_for_card(card)
for color, amount in payment.items():
game.bank[color] += amount
def apply_reserve_card(game: GameState, strategy: Strategy, action: ReserveCard) -> None:
"""Mutate game state according to action."""
player = game.current_player
if len(player.reserved) >= game.config.reserve_limit:
return
card: Card | None = None
if action.from_deck:
deck = game.decks_by_tier.get(action.tier)
if deck:
card = deck.pop()
else:
row = game.table_by_tier.get(action.tier)
if row is None:
return
if action.index is None or not (0 <= action.index < len(row)):
return
card = row.pop(action.index)
game.refill_table()
if card is None:
return
player.reserved.append(card)
if game.bank["gold"] > 0:
game.bank["gold"] -= 1
player.tokens["gold"] += 1
enforce_token_limit(game, strategy, player)
def apply_action(game: GameState, strategy: Strategy, action: Action) -> None:
"""Mutate game state according to action."""
actions = {
TakeDifferent: apply_take_different,
TakeDouble: apply_take_double,
BuyCard: apply_buy_card,
ReserveCard: apply_reserve_card,
BuyCardReserved: apply_buy_card_reserved,
}
action_func = actions.get(type(action))
if action_func is None:
msg = f"Unknown action type: {type(action)}"
raise ValueError(msg)
action_func(game, strategy, action)
# not sure how to simplify this yet
def get_legal_actions( # noqa: C901
game: GameState,
player: PlayerState | None = None,
) -> list[Action]:
"""Enumerate all syntactically legal actions for the given player.
This enforces:
- token-taking rules
- reserve limits
- affordability for buys
"""
if player is None:
player = game.players[game.current_player_index]
actions: list[Action] = []
colors_available = [c for c in BASE_COLORS if game.bank[c] > 0]
for r in (1, 2, 3):
actions.extend(TakeDifferent(colors=list(combo)) for combo in itertools.combinations(colors_available, r))
actions.extend(
TakeDouble(color=color) for color in BASE_COLORS if game.bank[color] >= game.config.minimum_tokens_to_buy_2
)
for tier, row in game.table_by_tier.items():
for idx, card in enumerate(row):
if player.can_afford(card):
actions.append(BuyCard(tier=tier, index=idx))
for idx, card in enumerate(player.reserved):
if player.can_afford(card):
actions.append(BuyCardReserved(index=idx))
if len(player.reserved) < game.config.reserve_limit:
for tier, row in game.table_by_tier.items():
for idx, _ in enumerate(row):
actions.append(
ReserveCard(tier=tier, index=idx, from_deck=False),
)
for tier, deck in game.decks_by_tier.items():
if deck:
actions.append(
ReserveCard(tier=tier, index=None, from_deck=True),
)
return actions
def create_random_cards_tier(
tier: int,
card_count: int,
cost_choices: list[int],
point_choices: list[int],
) -> list[Card]:
"""Create a random set of cards for a given tier."""
cards: list[Card] = []
for color in BASE_COLORS:
for _ in range(card_count):
cost = dict.fromkeys(GEM_COLORS, 0)
for c in BASE_COLORS:
if c == color:
continue
cost[c] = random.choice(cost_choices)
points = random.choice(point_choices)
cards.append(Card(tier=tier, points=points, color=color, cost=cost))
return cards
def create_random_cards() -> list[Card]:
"""Generate a generic but Splendor-ish set of cards.
This is not the official deck, but structured similarly enough for play.
"""
cards: list[Card] = []
cards.extend(
create_random_cards_tier(
tier=1,
card_count=5,
cost_choices=[0, 1, 1, 2],
point_choices=[0, 0, 1],
)
)
cards.extend(
create_random_cards_tier(
tier=2,
card_count=4,
cost_choices=[2, 3, 4],
point_choices=[1, 2, 2, 3],
)
)
cards.extend(
create_random_cards_tier(
tier=3,
card_count=3,
cost_choices=[4, 5, 6],
point_choices=[3, 4, 5],
)
)
random.shuffle(cards)
return cards
def create_random_nobles() -> list[Noble]:
"""A small set of noble tiles, roughly Splendor-ish."""
nobles: list[Noble] = []
base_requirements: list[dict[GemColor, int]] = [
{"white": 3, "blue": 3, "green": 3},
{"blue": 3, "green": 3, "red": 3},
{"green": 3, "red": 3, "black": 3},
{"red": 3, "black": 3, "white": 3},
{"black": 3, "white": 3, "blue": 3},
{"white": 4, "blue": 4},
{"green": 4, "red": 4},
{"blue": 4, "black": 4},
]
for idx, req in enumerate(base_requirements, start=1):
nobles.append(
Noble(
name=f"Noble {idx}",
points=3,
requirements=dict(req.items()),
),
)
return nobles
def load_nobles(file: Path) -> list[Noble]:
"""Load nobles from a file."""
nobles = json.loads(file.read_text())
return [Noble(**noble) for noble in nobles]
def load_cards(file: Path) -> list[Card]:
"""Load cards from a file."""
cards = json.loads(file.read_text())
return [Card(**card) for card in cards]
def new_game(
strategies: Sequence[Strategy],
config: GameConfig,
) -> GameState:
"""Create a new game state from a config + list of players."""
num_players = len(strategies)
bank = get_default_starting_tokens(num_players)
decks_by_tier: dict[int, list[Card]] = {1: [], 2: [], 3: []}
for card in config.cards:
decks_by_tier.setdefault(card.tier, []).append(card)
for deck in decks_by_tier.values():
random.shuffle(deck)
table_by_tier: dict[int, list[Card]] = {1: [], 2: [], 3: []}
players = [PlayerState(strategy=strategy) for strategy in strategies]
nobles = list(config.nobles)
random.shuffle(nobles)
nobles = nobles[: num_players + 1]
game = GameState(
config=config,
players=players,
bank=bank,
decks_by_tier=decks_by_tier,
table_by_tier=table_by_tier,
available_nobles=nobles,
)
game.refill_table()
return game
def run_game(game: GameState) -> tuple[PlayerState, int]:
"""Run a full game loop until someone wins or a player returns None."""
turn_count = 0
while not game.finished:
turn_count += 1
player = game.current_player
strategy = player.strategy
action = strategy.choose_action(game, player)
if action is None:
game.finished = True
break
apply_action(game, strategy, action)
check_nobles_for_player(game, strategy, player)
winner = game.check_winner_simple()
if winner is not None:
return winner, turn_count
game.next_player()
if turn_count >= game.config.turn_limit:
break
fallback = max(game.players, key=lambda player: player.score)
return fallback, turn_count
+288
View File
@@ -0,0 +1,288 @@
"""Bot for Splendor game."""
from __future__ import annotations
import random
from .base import (
BASE_COLORS,
Action,
BuyCard,
BuyCardReserved,
Card,
GameState,
GemColor,
PlayerState,
ReserveCard,
Strategy,
TakeDifferent,
TakeDouble,
auto_discard_tokens,
get_legal_actions,
)
def can_bot_afford(player: PlayerState, card: Card) -> bool:
"""Check if player can afford card, using discounts + gold."""
missing = 0
gold = player.tokens["gold"]
for color, cost in card.cost.items():
missing += max(0, cost - player.discounts.get(color, 0) - player.tokens.get(color, 0))
if missing > gold:
return False
return True
class RandomBot(Strategy):
"""Dumb bot that follows rules but doesn't think."""
def __init__(self, name: str) -> None:
"""Initialize the bot."""
super().__init__(name=name)
def choose_action(self, game: GameState, player: PlayerState) -> Action | None:
"""Choose an action for the current player."""
affordable: list[tuple[int, int]] = []
for tier, row in game.table_by_tier.items():
for idx, card in enumerate(row):
if can_bot_afford(player, card):
affordable.append((tier, idx))
if affordable and random.random() < 0.5:
tier, idx = random.choice(affordable)
return BuyCard(tier=tier, index=idx)
if random.random() < 0.2:
tier = random.choice([1, 2, 3])
row = game.table_by_tier.get(tier, [])
if row:
idx = random.randrange(len(row))
return ReserveCard(tier=tier, index=idx, from_deck=False)
if random.random() < 0.5:
colors_for_double = [c for c in BASE_COLORS if game.bank[c] >= 4]
if colors_for_double:
return TakeDouble(color=random.choice(colors_for_double))
colors_for_diff = [c for c in BASE_COLORS if game.bank[c] > 0]
random.shuffle(colors_for_diff)
return TakeDifferent(colors=colors_for_diff[:3])
def choose_discard(
self,
game: GameState, # noqa: ARG002
player: PlayerState,
excess: int,
) -> dict[GemColor, int]:
"""Choose how many tokens to discard."""
return auto_discard_tokens(player, excess)
def check_cards_in_tier(row: list[Card], player: PlayerState) -> list[int]:
"""Check if player can afford card, using discounts + gold."""
return [index for index, card in enumerate(row) if can_bot_afford(player, card)]
class PersonalizedBot(Strategy):
"""PersonalizedBot."""
"""Dumb bot that follows rules but doesn't think."""
def __init__(self, name: str) -> None:
"""Initialize the bot."""
super().__init__(name=name)
def choose_action(self, game: GameState, player: PlayerState) -> Action | None:
"""Choose an action for the current player."""
for tier in (1, 2, 3):
row = game.table_by_tier[tier]
if affordable := check_cards_in_tier(row, player):
index = random.choice(affordable)
return BuyCard(tier=tier, index=index)
colors_for_diff = [c for c in BASE_COLORS if game.bank[c] > 0]
random.shuffle(colors_for_diff)
return TakeDifferent(colors=colors_for_diff[:3])
def choose_discard(
self,
game: GameState, # noqa: ARG002
player: PlayerState,
excess: int,
) -> dict[GemColor, int]:
"""Choose how many tokens to discard."""
return auto_discard_tokens(player, excess)
class PersonalizedBot2(Strategy):
"""PersonalizedBot2."""
"""Dumb bot that follows rules but doesn't think."""
def __init__(self, name: str) -> None:
"""Initialize the bot."""
super().__init__(name=name)
def choose_action(self, game: GameState, player: PlayerState) -> Action | None:
"""Choose an action for the current player."""
tiers = (1, 2, 3)
for tier in tiers:
row = game.table_by_tier[tier]
if affordable := check_cards_in_tier(row, player):
index = random.choice(affordable)
return BuyCard(tier=tier, index=index)
if affordable := check_cards_in_tier(player.reserved, player):
index = random.choice(affordable)
return BuyCardReserved(index=index)
colors_for_diff = [c for c in BASE_COLORS if game.bank[c] > 0]
if len(colors_for_diff) >= 3:
random.shuffle(colors_for_diff)
return TakeDifferent(colors=colors_for_diff[:3])
for tier in tiers:
len_deck = len(game.decks_by_tier[tier])
if len_deck:
return ReserveCard(tier=tier, index=None, from_deck=True)
return TakeDifferent(colors=colors_for_diff[:3])
def choose_discard(
self,
game: GameState, # noqa: ARG002
player: PlayerState,
excess: int,
) -> dict[GemColor, int]:
"""Choose how many tokens to discard."""
return auto_discard_tokens(player, excess)
def buy_card_reserved(player: PlayerState) -> Action | None:
"""Buy a card reserved."""
if affordable := check_cards_in_tier(player.reserved, player):
index = random.choice(affordable)
return BuyCardReserved(index=index)
return None
def buy_card(game: GameState, player: PlayerState) -> Action | None:
"""Buy a card."""
for tier in (1, 2, 3):
row = game.table_by_tier[tier]
if affordable := check_cards_in_tier(row, player):
index = random.choice(affordable)
return BuyCard(tier=tier, index=index)
return None
def take_tokens(game: GameState) -> Action | None:
"""Take tokens."""
colors_for_diff = [color for color in BASE_COLORS if game.bank[color] > 0]
if len(colors_for_diff) >= 3:
random.shuffle(colors_for_diff)
return TakeDifferent(colors=colors_for_diff[: game.config.max_token_take])
return None
class PersonalizedBot3(Strategy):
"""PersonalizedBot3."""
"""Dumb bot that follows rules but doesn't think."""
def __init__(self, name: str) -> None:
"""Initialize the bot."""
super().__init__(name=name)
def choose_action(self, game: GameState, player: PlayerState) -> Action | None:
"""Choose an action for the current player."""
print(len(get_legal_actions(game, player)))
print(get_legal_actions(game, player))
if action := buy_card_reserved(player):
return action
if action := buy_card(game, player):
return action
colors_for_diff = [color for color in BASE_COLORS if game.bank[color] > 0]
if len(colors_for_diff) >= 3:
random.shuffle(colors_for_diff)
return TakeDifferent(colors=colors_for_diff[:3])
for tier in (1, 2, 3):
len_deck = len(game.decks_by_tier[tier])
if len_deck:
return ReserveCard(tier=tier, index=None, from_deck=True)
return TakeDifferent(colors=colors_for_diff[:3])
def choose_discard(
self,
game: GameState, # noqa: ARG002
player: PlayerState,
excess: int,
) -> dict[GemColor, int]:
"""Choose how many tokens to discard."""
return auto_discard_tokens(player, excess)
def estimate_value_of_card(game: GameState, player: PlayerState, color: GemColor) -> int:
"""Estimate value of a color in the player's bank."""
return game.bank[color] - player.discounts.get(color, 0)
def estimate_value_of_token(game: GameState, player: PlayerState, color: GemColor) -> int:
"""Estimate value of a color in the player's bank."""
return game.bank[color] - player.discounts.get(color, 0)
class PersonalizedBot4(Strategy):
"""PersonalizedBot4."""
def __init__(self, name: str) -> None:
"""Initialize the bot."""
super().__init__(name=name)
def filter_actions(self, actions: list[Action]) -> list[Action]:
"""Filter actions to only take different."""
return [
action
for action in actions
if (isinstance(action, TakeDifferent) and len(action.colors) == 3) or not isinstance(action, TakeDifferent)
]
def choose_action(self, game: GameState, player: PlayerState) -> Action | None:
"""Choose an action for the current player."""
legal_actions = get_legal_actions(game, player)
print(len(legal_actions))
good_actions = self.filter_actions(legal_actions)
print(len(good_actions))
print(good_actions)
print(len(get_legal_actions(game, player)))
if action := buy_card_reserved(player):
return action
if action := buy_card(game, player):
return action
colors_for_diff = [color for color in BASE_COLORS if game.bank[color] > 0]
if len(colors_for_diff) >= 3:
random.shuffle(colors_for_diff)
return TakeDifferent(colors=colors_for_diff[:3])
for tier in (1, 2, 3):
len_deck = len(game.decks_by_tier[tier])
if len_deck:
return ReserveCard(tier=tier, index=None, from_deck=True)
return TakeDifferent(colors=colors_for_diff[:3])
def choose_discard(
self,
game: GameState, # noqa: ARG002
player: PlayerState,
excess: int,
) -> dict[GemColor, int]:
"""Choose how many tokens to discard."""
return auto_discard_tokens(player, excess)
+724
View File
@@ -0,0 +1,724 @@
"""Splendor game."""
from __future__ import annotations
import sys
from typing import TYPE_CHECKING, Any
from textual.app import App, ComposeResult
from textual.containers import Horizontal, Vertical
from textual.widget import Widget
from textual.widgets import Footer, Header, Input, Static
from .base import (
BASE_COLORS,
GEM_COLORS,
Action,
BuyCard,
BuyCardReserved,
Card,
GameState,
GemColor,
Noble,
PlayerState,
ReserveCard,
Strategy,
TakeDifferent,
TakeDouble,
)
if TYPE_CHECKING:
from collections.abc import Mapping
# Abbreviations used when rendering costs
COST_ABBR: dict[GemColor, str] = {
"white": "W",
"blue": "B",
"green": "G",
"red": "R",
"black": "K",
"gold": "O",
}
# Abbreviations players can type on the command line
COLOR_ABBR_TO_FULL: dict[str, GemColor] = {
"w": "white",
"b": "blue",
"g": "green",
"r": "red",
"k": "black",
"o": "gold",
}
def parse_color_token(raw: str) -> GemColor:
"""Convert user input into a GemColor.
Supports:
- full names: white, blue, green, red, black, gold
- abbreviations: w, b, g, r, k, o
"""
key = raw.lower()
# full color names first
if key in BASE_COLORS:
return key # type: ignore[return-value]
# abbreviations
if key in COLOR_ABBR_TO_FULL:
return COLOR_ABBR_TO_FULL[key]
error = f"Unknown color: {raw}"
raise ValueError(error)
def format_cost(cost: Mapping[GemColor, int]) -> str:
"""Format a cost/requirements dict as colored tokens like 'B:2, R:1'.
Uses `color_token` internally so colors are guaranteed to match your bank.
"""
parts: list[str] = []
for color in GEM_COLORS:
n = cost.get(color, 0)
if not n:
continue
# color_token gives us e.g. "[blue]blue: 3[/]"
token = color_token(color, n)
# Turn the leading color name into the abbreviation (blue: 3 → B:3)
# We only replace the first occurrence.
full = f"{color}:"
abbr = f"{COST_ABBR[color]}:"
token = token.replace(full, abbr, 1)
parts.append(token)
return ", ".join(parts) if parts else "-"
def format_card(card: Card) -> str:
"""Readable card line using dataclass fields instead of __str__."""
color_abbr = COST_ABBR[card.color]
header = f"T{card.tier} {color_abbr} P{card.points}"
cost_str = format_cost(card.cost)
return f"{header} ({cost_str})"
def format_noble(noble: Noble) -> str:
"""Readable noble line using dataclass fields instead of __str__."""
cost_str = format_cost(noble.requirements)
return f"{noble.name} +{noble.points} ({cost_str})"
def format_tokens(tokens: Mapping[GemColor, int]) -> str:
"""Colored 'color: n' list for a token dict."""
return " ".join(color_token(c, tokens.get(c, 0)) for c in GEM_COLORS)
def format_discounts(discounts: Mapping[GemColor, int]) -> str:
"""Colored discounts, skipping zeros."""
parts: list[str] = []
for c in GEM_COLORS:
n = discounts.get(c, 0)
if not n:
continue
abbr = COST_ABBR[c]
fg, bg = COLOR_STYLE[c]
parts.append(f"[{fg} on {bg}]{abbr}:{n}[/{fg} on {bg}]")
return ", ".join(parts) if parts else "-"
COLOR_STYLE: dict[GemColor, tuple[str, str]] = {
"white": ("black", "white"), # fg, bg
"blue": ("bright_white", "blue"),
"green": ("bright_white", "sea_green4"),
"red": ("white", "red3"),
"black": ("white", "grey0"),
"gold": ("black", "yellow3"),
}
def fmt_gem(color: GemColor) -> str:
"""Render gem name with fg/bg matching real token color."""
fg, bg = COLOR_STYLE[color]
return f"[{fg} on {bg}] {color} [/{fg} on {bg}]"
def fmt_number(value: int) -> str:
"""Return a Rich-markup colored 'value' string."""
return f"[bold cyan]{value}[/]"
def color_token(name: GemColor, amount: int) -> str:
"""Return a Rich-markup colored 'name: n' string."""
# Map Splendor colors -> terminal colors
color_map: Mapping[GemColor, str] = {
"white": "white",
"blue": "blue",
"green": "green",
"red": "red",
"black": "grey70", # 'black' is unreadable on dark backgrounds
"gold": "yellow",
}
style = color_map.get(name, "white")
return f"[{style}]{name}: {amount}[/]"
class Board(Widget):
"""Big board widget with the layout you sketched."""
def __init__(self, game: GameState, me: PlayerState, **kwargs: Any) -> None: # noqa: ANN401
"""Initialize the board widget."""
super().__init__(**kwargs)
self.game = game
self.me = me
def compose(self) -> ComposeResult:
"""Compose the board widget."""
# Structure:
# ┌ bank row
# ├ middle row (tiers | nobles)
# └ players row
with Vertical(id="board_root"):
yield Static(id="bank_box")
with Horizontal(id="middle_row"):
with Vertical(id="tiers_box"):
yield Static(id="tier1_box")
yield Static(id="tier2_box")
yield Static(id="tier3_box")
yield Static(id="nobles_box")
yield Static(id="players_box")
def on_mount(self) -> None:
"""Refresh the board content."""
self.refresh_content()
def refresh_content(self) -> None:
"""Refresh the board content."""
self._render_bank()
self._render_tiers()
self._render_nobles()
self._render_players()
# --- sections ----------------------------------------------------
def _render_bank(self) -> None:
bank = self.game.bank
parts: list[str] = ["[b]Bank:[/b]"]
# One line, all tokens colored
parts.append(format_tokens(bank))
self.query_one("#bank_box", Static).update("\n".join(parts))
def _render_tiers(self) -> None:
for tier in (1, 2, 3):
box = self.query_one(f"#tier{tier}_box", Static)
cards: list[Card] = self.game.table_by_tier.get(tier, [])
lines: list[str] = [f"[b]Tier {tier} cards:[/b]"]
if not cards:
lines.append(" (none)")
else:
for idx, card in enumerate(cards):
lines.append(f" [{idx}] {format_card(card)}")
box.update("\n".join(lines))
def _render_nobles(self) -> None:
nobles_box = self.query_one("#nobles_box", Static)
lines: list[str] = ["[b]Nobles[/b]"]
if not self.game.available_nobles:
lines.append(" (none)")
else:
lines.extend(" - " + format_noble(noble) for noble in self.game.available_nobles)
nobles_box.update("\n".join(lines))
def _render_players(self) -> None:
players_box = self.query_one("#players_box", Static)
lines: list[str] = ["[b]Players:[/b]", ""]
for player in self.game.players:
mark = "*" if player is self.me else " "
token_str = format_tokens(player.tokens)
discount_str = format_discounts(player.discounts)
lines.append(
f"{mark} {player.name:10} Score={player.score:2d} Discounts={discount_str}",
)
lines.append(f" Tokens: {token_str}")
if player.nobles:
noble_names = ", ".join(n.name for n in player.nobles)
lines.append(f" Nobles: {noble_names}")
# Optional: show counts of cards / reserved
if player.cards:
lines.append(f" Cards: {len(player.cards)}")
if player.reserved:
lines.append(f" Reserved: {len(player.reserved)}")
lines.append("")
players_box.update("\n".join(lines))
class ActionApp(App[None]):
"""Textual app that asks for a single action command and returns an Action."""
CSS = """
Screen {
/* 3 rows: command zone, board, footer */
layout: grid;
grid-size: 1 3;
grid-rows: auto 1fr auto;
}
/* Top area with input + instructions */
#command_zone {
grid-columns: 1;
grid-rows: 1;
padding: 1 1;
}
/* Board sits in the middle row and can grow */
#board {
grid-columns: 1;
grid-rows: 2;
padding: 0 1 1 1;
}
Footer {
grid-columns: 1;
grid-rows: 3;
}
Input {
border: round $accent;
}
/* === Board layout === */
#board_root {
/* outer frame around the whole board area */
border: heavy white;
padding: 0 1;
}
/* Bank row: full width */
#bank_box {
border: heavy white;
padding: 0 1;
}
/* Middle row: tiers (left) + nobles (right) */
#middle_row {
layout: horizontal;
}
#tiers_box {
border: heavy white;
padding: 0 1;
width: 70%;
}
#tier1_box,
#tier2_box,
#tier3_box {
border-bottom: heavy white;
padding: 0 0 1 0;
margin-bottom: 1;
}
#nobles_box {
border: heavy white;
padding: 0 1;
width: 30%;
}
/* Players row: full width at bottom */
#players_box {
border: heavy white;
padding: 0 1;
}
"""
def __init__(self, game: GameState, player: PlayerState) -> None:
"""Initialize the action app."""
super().__init__()
self.game = game
self.player = player
self.result: Action | None = None
self.message: str = ""
def compose(self) -> ComposeResult:
"""Compose the action app."""
# Row 1: input + Actions text
with Vertical(id="command_zone"):
yield Input(
placeholder="Enter command, e.g. '1 white blue red' or '1 w b r' or 'q'",
id="input_line",
)
yield Static("", id="prompt")
# Row 2: board
yield Board(self.game, self.player, id="board")
# Row 3: footer
yield Footer()
def on_mount(self) -> None:
"""Mount the action app."""
self._update_prompt()
self.query_one(Input).focus()
def _update_prompt(self) -> None:
lines: list[str] = []
lines.append("[bold underline]Actions:[/]")
lines.append(
" [bold green]1[/] <colors...> - Take up to 3 different gem colors "
"(e.g. [cyan]1 white blue red[/] or [cyan]1 w b r[/])",
)
lines.append(
f" [bold green]2[/] <color> - Take 2 of the same color (needs {fmt_number(4)} in bank, "
"e.g. [cyan]2 blue[/] or [cyan]2 b[/])",
)
lines.append(
" [bold green]3[/] <tier> <idx> - Buy a face-up card (e.g. [cyan]3 1 0[/] for tier 1, index 0)",
)
lines.append(" [bold green]4[/] <idx> - Buy a reserved card")
lines.append(" [bold green]5[/] <tier> <idx> - Reserve a face-up card")
lines.append(" [bold green]6[/] <tier> - Reserve top card of a deck")
lines.append(" [bold red]q[/] - Quit game")
if self.message:
lines.append("")
lines.append(f"[bold red]Message:[/] {self.message}")
self.query_one("#prompt", Static).update("\n".join(lines))
def _cmd_1(self, parts: list[str]) -> str | None:
"""Take up to 3 different gem colors: 1 white blue red OR 1 w b r."""
color_names = parts[1:]
if not color_names:
return "Need at least one color (full name or abbreviation)."
colors: list[GemColor] = []
for name in color_names:
color = parse_color_token(name)
if self.game.bank[color] <= 0:
return f"No tokens left for color: {color}"
colors.append(color)
self.result = TakeDifferent(colors=colors[:3])
self.exit()
return None
def _cmd_2(self, parts: list[str]) -> str | None:
"""Take two of the same color."""
if len(parts) < 2:
return "Usage: 2 <color>"
color = parse_color_token(parts[1])
if self.game.bank[color] < self.game.config.minimum_tokens_to_buy_2:
return "Bank must have at least 4 of that color."
self.result = TakeDouble(color=color)
self.exit()
return None
def _cmd_3(self, parts: list[str]) -> str | None:
"""Buy face-up card."""
if len(parts) < 3:
return "Usage: 3 <tier> <index>"
tier = int(parts[1])
idx = int(parts[2])
self.result = BuyCard(tier=tier, index=idx)
self.exit()
return None
def _cmd_4(self, parts: list[str]) -> str | None:
"""Buy reserved card."""
if len(parts) < 2:
return "Usage: 4 <reserved_index>"
idx = int(parts[1])
if not (0 <= idx < len(self.player.reserved)):
return "Reserved index out of range."
self.result = BuyCardReserved(tier=0, index=idx)
self.exit()
return None
def _cmd_5(self, parts: list[str]) -> str | None:
"""Reserve face-up card."""
if len(parts) < 3:
return "Usage: 5 <tier> <index>"
tier = int(parts[1])
idx = int(parts[2])
self.result = ReserveCard(tier=tier, index=idx, from_deck=False)
self.exit()
return None
def _cmd_6(self, parts: list[str]) -> str | None:
"""Reserve top of deck."""
if len(parts) < 2:
return "Usage: 6 <tier>"
tier = int(parts[1])
self.result = ReserveCard(tier=tier, index=None, from_deck=True)
self.exit()
return None
def _unknown_cmd(self, _parts: list[str]) -> str:
return "Unknown command."
def on_input_submitted(self, event: Input.Submitted) -> None:
"""Handle user input."""
text = (event.value or "").strip()
event.input.value = ""
if not text:
return
if text.lower() in {"q", "quit", "0"}:
self.result = None
self.exit()
return
parts = text.split()
cmds = {
"1": self._cmd_1,
"2": self._cmd_2,
"3": self._cmd_3,
"4": self._cmd_4,
"5": self._cmd_5,
"6": self._cmd_6,
}
cmd = parts[0]
error = cmds.get(cmd, self._unknown_cmd)(parts)
if error:
self.message = error
self._update_prompt()
return
class DiscardApp(App[None]):
"""Textual app to choose discards when over token limit."""
CSS = """
Screen {
layout: vertical;
}
#command_zone {
padding: 1 1;
}
#board {
padding: 0 1 1 1;
}
Input {
border: round $accent;
}
"""
def __init__(self, game: GameState, player: PlayerState) -> None:
"""Initialize the discard app."""
super().__init__()
self.game = game
self.player = player
self.discards: dict[GemColor, int] = dict.fromkeys(GEM_COLORS, 0)
self.message: str = ""
def compose(self) -> ComposeResult: # type: ignore[override]
"""Compose the discard app."""
yield Header(show_clock=False)
with Vertical(id="command_zone"):
yield Input(
placeholder="Enter color to discard, e.g. 'blue' or 'b'",
id="input_line",
)
yield Static("", id="prompt")
# Board directly under the command zone
yield Board(self.game, self.player, id="board")
yield Footer()
def on_mount(self) -> None: # type: ignore[override]
"""Mount the discard app."""
self._update_prompt()
self.query_one(Input).focus()
def _remaining_to_discard(self) -> int:
return self.player.total_tokens() - sum(self.discards.values()) - self.game.config.token_limit
def _update_prompt(self) -> None:
remaining = max(self._remaining_to_discard(), 0)
lines: list[str] = []
lines.append(
"You must discard "
f"{fmt_number(remaining)} token(s) "
f"to get down to {fmt_number(self.game.config.token_limit)}.",
)
disc_str = ", ".join(f"{fmt_gem(c)}={fmt_number(self.discards[c])}" for c in GEM_COLORS)
lines.append(f"Current planned discards: {{ {disc_str} }}")
lines.append(
"Type a color name or abbreviation (e.g. 'blue' or 'b') to discard one token.",
)
if self.message:
lines.append("")
lines.append(f"[bold red]Message:[/] {self.message}")
self.query_one("#prompt", Static).update("\n".join(lines))
def on_input_submitted(self, event: Input.Submitted) -> None: # type: ignore[override]
"""Handle user input."""
raw = (event.value or "").strip()
event.input.value = ""
if not raw:
return
try:
color = parse_color_token(raw)
except ValueError:
self.message = f"Unknown color: {raw}"
self._update_prompt()
return
available = self.player.tokens[color] - self.discards[color]
if available <= 0:
self.message = f"No more {color} tokens available to discard."
self._update_prompt()
return
self.discards[color] += 1
if self._remaining_to_discard() <= 0:
self.exit()
return
self.message = ""
self._update_prompt()
# ---------------------------------------------------------------------------
# Noble choice app
# ---------------------------------------------------------------------------
class NobleChoiceApp(App[None]):
"""Textual app to choose one noble."""
CSS = """
Screen {
layout: vertical;
}
#command_zone {
padding: 1 1;
}
#board {
padding: 0 1 1 1;
}
Input {
border: round $accent;
}
"""
def __init__(
self,
game: GameState,
player: PlayerState,
nobles: list[Noble],
) -> None:
"""Initialize the noble choice app."""
super().__init__()
self.game = game
self.player = player
self.nobles = nobles
self.result: Noble | None = None
self.message: str = ""
def compose(self) -> ComposeResult: # type: ignore[override]
"""Compose the noble choice app."""
yield Header(show_clock=False)
with Vertical(id="command_zone"):
yield Input(
placeholder="Enter noble index, e.g. '0'",
id="input_line",
)
yield Static("", id="prompt")
# Board directly under the command zone
yield Board(self.game, self.player, id="board")
yield Footer()
def on_mount(self) -> None: # type: ignore[override]
"""Mount the noble choice app."""
self._update_prompt()
self.query_one(Input).focus()
def _update_prompt(self) -> None:
lines: list[str] = []
lines.append("[bold underline]You qualify for nobles:[/]")
for i, noble in enumerate(self.nobles):
lines.append(f" [bright_cyan]{i})[/] {format_noble(noble)}")
lines.append("Enter the index of the noble you want.")
if self.message:
lines.append("")
lines.append(f"[bold red]Message:[/] {self.message}")
self.query_one("#prompt", Static).update("\n".join(lines))
def on_input_submitted(self, event: Input.Submitted) -> None: # type: ignore[override]
"""Handle user input."""
raw = (event.value or "").strip()
event.input.value = ""
if not raw:
return
try:
idx = int(raw)
except ValueError:
self.message = "Please enter a valid integer index."
self._update_prompt()
return
if not (0 <= idx < len(self.nobles)):
self.message = "Index out of range."
self._update_prompt()
return
self.result = self.nobles[idx]
self.exit()
class TuiHuman(Strategy):
"""Textual-based human player Strategy with colorful board."""
def choose_action(
self,
game: GameState,
player: PlayerState,
) -> Action | None:
"""Choose an action for the player."""
if not sys.stdout.isatty():
return None
app = ActionApp(game, player)
app.run()
return app.result
def choose_discard(
self,
game: GameState,
player: PlayerState,
excess: int, # noqa: ARG002
) -> dict[GemColor, int]:
"""Choose tokens to discard."""
if not sys.stdout.isatty():
return dict.fromkeys(GEM_COLORS, 0)
app = DiscardApp(game, player)
app.run()
return app.discards
def choose_noble(
self,
game: GameState,
player: PlayerState,
nobles: list[Noble],
) -> Noble:
"""Choose a noble for the player."""
if not sys.stdout.isatty():
return nobles[0]
app = NobleChoiceApp(game, player, nobles)
app.run()
return app.result
+19
View File
@@ -0,0 +1,19 @@
"""Main entry point for Splendor game."""
from __future__ import annotations
from .base import new_game, run_game
from .bot import RandomBot
from .human import TuiHuman
def main() -> None:
"""Main entry point."""
human = TuiHuman()
bot = RandomBot()
game_state = new_game(["You", "Bot A"])
run_game(game_state, [human, bot])
if __name__ == "__main__":
main()
+111
View File
@@ -0,0 +1,111 @@
"""Public state for RL/search."""
from __future__ import annotations
from dataclasses import dataclass
from .base import (
BASE_COLORS,
BASE_INDEX,
GEM_ORDER,
Card,
GameState,
Noble,
PlayerState,
)
@dataclass(frozen=True)
class ObsCard:
"""Numeric-ish card view for RL/search."""
tier: int
points: int
color_index: int
cost: list[int]
@dataclass(frozen=True)
class ObsNoble:
"""Numeric-ish noble view for RL/search."""
points: int
requirements: list[int]
@dataclass(frozen=True)
class ObsPlayer:
"""Numeric-ish player view for RL/search."""
tokens: list[int]
discounts: list[int]
score: int
cards: list[ObsCard]
reserved: list[ObsCard]
nobles: list[ObsNoble]
@dataclass(frozen=True)
class Observation:
"""Full public state for RL/search."""
current_player: int
bank: list[int]
players: list[ObsPlayer]
table_by_tier: dict[int, list[ObsCard]]
decks_remaining: dict[int, int]
available_nobles: list[ObsNoble]
def _encode_card(card: Card) -> ObsCard:
color_index = BASE_INDEX.get(card.color, -1)
cost_vec = [card.cost.get(c, 0) for c in BASE_COLORS]
return ObsCard(
tier=card.tier,
points=card.points,
color_index=color_index,
cost=cost_vec,
)
def _encode_noble(noble: Noble) -> ObsNoble:
req_vec = [noble.requirements.get(c, 0) for c in BASE_COLORS]
return ObsNoble(
points=noble.points,
requirements=req_vec,
)
def _encode_player(player: PlayerState) -> ObsPlayer:
tokens_vec = [player.tokens[c] for c in GEM_ORDER]
discounts_vec = [player.discounts[c] for c in GEM_ORDER]
cards_enc = [_encode_card(c) for c in player.cards]
reserved_enc = [_encode_card(c) for c in player.reserved]
nobles_enc = [_encode_noble(n) for n in player.nobles]
return ObsPlayer(
tokens=tokens_vec,
discounts=discounts_vec,
score=player.score,
cards=cards_enc,
reserved=reserved_enc,
nobles=nobles_enc,
)
def to_observation(game: GameState) -> Observation:
"""Create a structured observation of the full public state."""
bank_vec = [game.bank[c] for c in GEM_ORDER]
players_enc = [_encode_player(p) for p in game.players]
table_enc: dict[int, list[ObsCard]] = {
tier: [_encode_card(c) for c in row] for tier, row in game.table_by_tier.items()
}
decks_remaining = {tier: len(deck) for tier, deck in game.decks_by_tier.items()}
nobles_enc = [_encode_noble(n) for n in game.available_nobles]
return Observation(
current_player=game.current_player_index,
bank=bank_vec,
players=players_enc,
table_by_tier=table_enc,
decks_remaining=decks_remaining,
available_nobles=nobles_enc,
)
+36
View File
@@ -0,0 +1,36 @@
"""Simulate a step in the game."""
from __future__ import annotations
import copy
from .base import Action, GameState, PlayerState, apply_action, check_nobles_for_player
from .bot import RandomBot
class SimStrategy(RandomBot):
"""Strategy used in simulate_step.
We never call choose_action here (caller chooses actions),
but we reuse discard/noble-selection logic.
"""
def choose_action(self, game: GameState, player: PlayerState) -> Action | None: # noqa: ARG002
"""Choose an action for the current player."""
msg = "SimStrategy.choose_action should not be used in simulate_step"
raise RuntimeError(msg)
def simulate_step(game: GameState, action: Action) -> GameState:
"""Return a deep-copied next state after applying action for the current player.
Useful for tree search / MCTS:
next_state = simulate_step(state, action)
"""
next_state = copy.deepcopy(game)
sim_strategy = SimStrategy()
apply_action(next_state, sim_strategy, action)
check_nobles_for_player(next_state, sim_strategy, next_state.current_player)
next_state.next_player()
return next_state
+50
View File
@@ -0,0 +1,50 @@
"""Simulator for Splendor game."""
from __future__ import annotations
from collections import defaultdict
from pathlib import Path
from statistics import mean
from .base import GameConfig, load_cards, load_nobles, new_game, run_game
from .bot import PersonalizedBot4, RandomBot
def main() -> None:
"""Main entry point."""
turn_limit = 1000
good_games = 0
games = 1
winners: dict[str, list] = defaultdict(list)
game_data = Path(__file__).parent / "game_data"
cards = load_cards(game_data / "cards/default.json")
nobles = load_nobles(game_data / "nobles/default.json")
for _ in range(games):
bot_a = RandomBot("bot_a")
bot_b = RandomBot("bot_b")
bot_c = RandomBot("bot_c")
bot_d = PersonalizedBot4("my_bot")
config = GameConfig(
cards=cards,
nobles=nobles,
turn_limit=turn_limit,
)
players = (bot_a, bot_b, bot_c, bot_d)
game_state = new_game(players, config)
winner, turns = run_game(game_state)
if turns < turn_limit:
good_games += 1
winners[winner.strategy.name].append(turns)
print(
f"out of {games} {turn_limit} turn games with {len(players)}"
f"random bots there where {good_games} games where a bot won"
)
for name, turns in winners.items():
print(f"{name} won {len(turns)} games in {mean(turns):.2f} turns")
if __name__ == "__main__":
main()
+1 -1
View File
@@ -4,7 +4,7 @@ import logging
import sys import sys
import tomllib import tomllib
from os import environ from os import environ
from pathlib import Path # noqa: TC003 This is required for the typer CLI from pathlib import Path
from socket import gethostname from socket import gethostname
import typer import typer
+1 -1
View File
@@ -451,7 +451,7 @@ def convert_aax_file_with_agent(aax_file: Path, config: ConversionConfig) -> Non
destination.parent.mkdir(parents=True, exist_ok=True) destination.parent.mkdir(parents=True, exist_ok=True)
try: try:
temp_file.replace(destination) temp_file.replace(destination)
except OSError as error: except Exception as error: # noqa: BLE001
write_review_file( write_review_file(
destination=destination, destination=destination,
ffprobe_metadata=ffprobe_metadata, ffprobe_metadata=ffprobe_metadata,
+1
View File
@@ -169,6 +169,7 @@ def csv_id(row: dict[str, str | None], csv_path: Path, row_number: int) -> int |
except ValueError as error: except ValueError as error:
msg = f"{csv_path}:{row_number}: id must be an integer: {value}" msg = f"{csv_path}:{row_number}: id must be an integer: {value}"
raise CatalogImportError(msg) from error raise CatalogImportError(msg) from error
return None
if __name__ == "__main__": if __name__ == "__main__":
@@ -1,3 +1,4 @@
# ruff: noqa: LOG015, E501, D102, D103, D107 These need the be fixed
"""Install NixOS on a ZFS pool.""" """Install NixOS on a ZFS pool."""
from __future__ import annotations from __future__ import annotations
@@ -16,9 +17,6 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Sequence from collections.abc import Sequence
logger = logging.getLogger(__name__)
ESCAPE_KEY = 27
def configure_logger(level: str = "INFO") -> None: def configure_logger(level: str = "INFO") -> None:
"""Configure the logger. """Configure the logger.
@@ -44,7 +42,7 @@ def bash_wrapper(command: str) -> str:
Tuple[str, int]: A tuple containing the output of the command (stdout) as a string, Tuple[str, int]: A tuple containing the output of the command (stdout) as a string,
the error output (stderr) as a string (optional), and the return code as an integer. the error output (stderr) as a string (optional), and the return code as an integer.
""" """
logger.debug(f"running {command=}") logging.debug(f"running {command=}")
# This is a acceptable risk # This is a acceptable risk
process = Popen(command.split(), stdout=PIPE, stderr=PIPE) process = Popen(command.split(), stdout=PIPE, stderr=PIPE)
output, _ = process.communicate() output, _ = process.communicate()
@@ -65,7 +63,7 @@ def partition_disk(disk: str, swap_size: int, reserve: int = 0) -> None:
reserve (int, optional): The size of the reserve partition in GB. Defaults to 0. reserve (int, optional): The size of the reserve partition in GB. Defaults to 0.
minimum value is 0. minimum value is 0.
""" """
logger.info(f"partitioning {disk=}") logging.info(f"partitioning {disk=}")
swap_size = max(swap_size, 1) swap_size = max(swap_size, 1)
reserve = max(reserve, 0) reserve = max(reserve, 0)
@@ -73,16 +71,16 @@ def partition_disk(disk: str, swap_size: int, reserve: int = 0) -> None:
if reserve > 0: if reserve > 0:
msg = f"Creating swap partition on {disk=} with size {swap_size=}GiB and reserve {reserve=}GiB" msg = f"Creating swap partition on {disk=} with size {swap_size=}GiB and reserve {reserve=}GiB"
logger.info(msg) logging.info(msg)
swap_start = swap_size + reserve swap_start = swap_size + reserve
swap_partition = f"mkpart swap -{swap_start}GiB -{reserve}GiB " swap_partition = f"mkpart swap -{swap_start}GiB -{reserve}GiB "
else: else:
logger.info(f"Creating swap partition on {disk=} with size {swap_size=}GiB") logging.info(f"Creating swap partition on {disk=} with size {swap_size=}GiB")
swap_start = swap_size swap_start = swap_size
swap_partition = f"mkpart swap -{swap_start}GiB 100% " swap_partition = f"mkpart swap -{swap_start}GiB 100% "
logger.debug(f"{swap_partition=}") logging.debug(f"{swap_partition=}")
create_partitions = ( create_partitions = (
f"parted --script --align=optimal {disk} -- " f"parted --script --align=optimal {disk} -- "
@@ -94,7 +92,7 @@ def partition_disk(disk: str, swap_size: int, reserve: int = 0) -> None:
) )
bash_wrapper(create_partitions) bash_wrapper(create_partitions)
logger.info(f"{disk=} successfully partitioned") logging.info(f"{disk=} successfully partitioned")
def create_zfs_pool(pool_disks: Sequence[str], mnt_dir: str) -> None: def create_zfs_pool(pool_disks: Sequence[str], mnt_dir: str) -> None:
@@ -133,7 +131,7 @@ def create_zfs_pool(pool_disks: Sequence[str], mnt_dir: str) -> None:
bash_wrapper(zpool_create) bash_wrapper(zpool_create)
zpools = bash_wrapper("zpool list -o name") zpools = bash_wrapper("zpool list -o name")
if "root_pool" not in zpools.splitlines(): if "root_pool" not in zpools.splitlines():
logger.critical("Failed to create root_pool") logging.critical("Failed to create root_pool")
sys.exit(1) sys.exit(1)
@@ -153,7 +151,7 @@ def create_zfs_datasets() -> None:
} }
missing_datasets = expected_datasets.difference(datasets.splitlines()) missing_datasets = expected_datasets.difference(datasets.splitlines())
if missing_datasets: if missing_datasets:
logger.critical(f"Failed to create pools {missing_datasets}") logging.critical(f"Failed to create pools {missing_datasets}")
sys.exit(1) sys.exit(1)
@@ -166,7 +164,8 @@ def get_cpu_manufacturer() -> str:
for line in output.splitlines(): for line in output.splitlines():
if "vendor_id" in line: if "vendor_id" in line:
return id_vendor[line.split(": ")[1].strip()] return id_vendor[line.split(": ")[1].strip()]
error = "Failed to get CPU manufacturer"
error = "CPU manufacturer not found"
raise RuntimeError(error) raise RuntimeError(error)
@@ -201,15 +200,7 @@ def create_nix_hardware_file(mnt_dir: str, disks: Sequence[str], *, encrypt: boo
' imports = [ (modulesPath + "/installer/scan/not-detected.nix") ];\n\n' ' imports = [ (modulesPath + "/installer/scan/not-detected.nix") ];\n\n'
" boot = {\n" " boot = {\n"
" initrd = {\n" " initrd = {\n"
" availableKernelModules = [ \n" ' availableKernelModules = [ \n "ahci"\n "ehci_pci"\n "nvme"\n "sd_mod"\n "usb_storage"\n "usbhid"\n "xhci_pci"\n ];\n'
' "ahci"\n'
' "ehci_pci"\n'
' "nvme"\n'
' "sd_mod"\n'
' "usb_storage"\n'
' "usbhid"\n'
' "xhci_pci"\n'
" ];\n"
" kernelModules = [ ];\n" " kernelModules = [ ];\n"
f" {devices}" f" {devices}"
" };\n" " };\n"
@@ -223,18 +214,11 @@ def create_nix_hardware_file(mnt_dir: str, disks: Sequence[str], *, encrypt: boo
' "/nix" = {\n device = "root_pool/nix";\n fsType = "zfs";\n };\n\n' ' "/nix" = {\n device = "root_pool/nix";\n fsType = "zfs";\n };\n\n'
' "/boot" = {\n' ' "/boot" = {\n'
f' device = "/dev/disk/by-uuid/{get_boot_drive_id(disks[0])}";\n' f' device = "/dev/disk/by-uuid/{get_boot_drive_id(disks[0])}";\n'
' fsType = "vfat";\n' ' fsType = "vfat";\n options = [\n "fmask=0077"\n "dmask=0077"\n ];\n };\n };\n\n'
" options = [\n"
' "fmask=0077"\n'
' "dmask=0077"\n'
" ];\n"
" };\n"
" };\n\n"
" swapDevices = [ ];\n\n" " swapDevices = [ ];\n\n"
" networking.useDHCP = lib.mkDefault true;\n\n" " networking.useDHCP = lib.mkDefault true;\n\n"
' nixpkgs.hostPlatform = lib.mkDefault "x86_64-linux";\n' ' nixpkgs.hostPlatform = lib.mkDefault "x86_64-linux";\n'
f" hardware.cpu.{cpu_manufacturer}.updateMicrocode = lib.mkDefault " f" hardware.cpu.{cpu_manufacturer}.updateMicrocode = lib.mkDefault config.hardware.enableRedistributableFirmware;\n"
"config.hardware.enableRedistributableFirmware;\n"
f' networking.hostId = "{host_id}";\n' f' networking.hostId = "{host_id}";\n'
"}\n" "}\n"
) )
@@ -272,32 +256,19 @@ def installer(
encrypt_key: str | None, encrypt_key: str | None,
) -> None: ) -> None:
"""Main.""" """Main."""
logger.info("Starting installation") logging.info("Starting installation")
for disk in disks: for disk in disks:
partition_disk(disk, swap_size, reserve) partition_disk(disk, swap_size, reserve)
if encrypt_key: if encrypt_key:
sleep(1) sleep(1)
key_input = encrypt_key.encode() for command in (
run( f'printf "{encrypt_key}" | cryptsetup luksFormat --type luks2 {disk}-part2 -',
("cryptsetup", "luksFormat", "--type", "luks2", f"{disk}-part2", "-"), f'printf "{encrypt_key}" | cryptsetup luksOpen {disk}-part2 luks-root-pool-{disk.split("/")[-1]}-part2 -',
input=key_input, ):
check=True, run(command, shell=True, check=True) # noqa: S602
)
run(
(
"cryptsetup",
"luksOpen",
f"{disk}-part2",
f"luks-root-pool-{disk.split('/')[-1]}-part2",
"-",
),
input=key_input,
check=True,
)
# Fixed mount point for the new system; the installer runs as root on a fresh disk
mnt_dir = "/tmp/nix_install" # noqa: S108 mnt_dir = "/tmp/nix_install" # noqa: S108
Path(mnt_dir).mkdir(parents=True, exist_ok=True) Path(mnt_dir).mkdir(parents=True, exist_ok=True)
@@ -311,73 +282,59 @@ def installer(
create_zfs_datasets() create_zfs_datasets()
install_nixos(mnt_dir, disks, encrypt=bool(encrypt_key)) install_nixos(mnt_dir, disks, encrypt=encrypt_key)
logger.info("Installation complete") logging.info("Installation complete")
class Cursor: class Cursor:
"""Track cursor position and constrain movement to screen bounds.""" """Cursor class to store the cursor position."""
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize cursor position and screen dimensions."""
self.x_position = 0 self.x_position = 0
self.y_position = 0 self.y_position = 0
self.height = 0 self.height = 0
self.width = 0 self.width = 0
def set_height(self, height: int) -> None: def set_height(self, height: int) -> None:
"""Set the maximum screen height."""
self.height = height self.height = height
def set_width(self, width: int) -> None: def set_width(self, width: int) -> None:
"""Set the maximum screen width."""
self.width = width self.width = width
def x_bounce_check(self, cursor: int) -> int: def x_bounce_check(self, cursor: int) -> int:
"""Clamp an x position to the screen width."""
cursor = max(0, cursor) cursor = max(0, cursor)
return min(self.width - 1, cursor) return min(self.width - 1, cursor)
def y_bounce_check(self, cursor: int) -> int: def y_bounce_check(self, cursor: int) -> int:
"""Clamp a y position to the screen height."""
cursor = max(0, cursor) cursor = max(0, cursor)
return min(self.height - 1, cursor) return min(self.height - 1, cursor)
def set_x(self, x: int) -> None: def set_x(self, x: int) -> None:
"""Set the cursor x position."""
self.x_position = self.x_bounce_check(x) self.x_position = self.x_bounce_check(x)
def set_y(self, y: int) -> None: def set_y(self, y: int) -> None:
"""Set the cursor y position."""
self.y_position = self.y_bounce_check(y) self.y_position = self.y_bounce_check(y)
def get_x(self) -> int: def get_x(self) -> int:
"""Get the cursor x position."""
return self.x_position return self.x_position
def get_y(self) -> int: def get_y(self) -> int:
"""Get the cursor y position."""
return self.y_position return self.y_position
def move_up(self) -> None: def move_up(self) -> None:
"""Move the cursor up one row."""
self.set_y(self.y_position - 1) self.set_y(self.y_position - 1)
def move_down(self) -> None: def move_down(self) -> None:
"""Move the cursor down one row."""
self.set_y(self.y_position + 1) self.set_y(self.y_position + 1)
def move_left(self) -> None: def move_left(self) -> None:
"""Move the cursor left one column."""
self.set_x(self.x_position - 1) self.set_x(self.x_position - 1)
def move_right(self) -> None: def move_right(self) -> None:
"""Move the cursor right one column."""
self.set_x(self.x_position + 1) self.set_x(self.x_position + 1)
def navigation(self, key: int) -> None: def navigation(self, key: int) -> None:
"""Move the cursor for a curses navigation key."""
action = { action = {
curses.KEY_DOWN: self.move_down, curses.KEY_DOWN: self.move_down,
curses.KEY_UP: self.move_up, curses.KEY_UP: self.move_up,
@@ -392,7 +349,6 @@ class State:
"""State class to store the state of the program.""" """State class to store the state of the program."""
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize installer menu state."""
self.key = 0 self.key = 0
self.cursor = Cursor() self.cursor = Cursor()
@@ -410,7 +366,6 @@ class State:
def get_device(raw_device: str) -> dict[str, str]: def get_device(raw_device: str) -> dict[str, str]:
"""Parse an lsblk key-value device row."""
raw_device_components = raw_device.split(" ") raw_device_components = raw_device.split(" ")
return {thing.split("=")[0].lower(): thing.split("=")[1].strip('"') for thing in raw_device_components} return {thing.split("=")[0].lower(): thing.split("=")[1].strip('"') for thing in raw_device_components}
@@ -440,7 +395,6 @@ def get_device_id_mapping() -> dict[str, set[str]]:
def calculate_device_menu_padding(devices: list[dict[str, str]], column: str, padding: int = 0) -> int: def calculate_device_menu_padding(devices: list[dict[str, str]], column: str, padding: int = 0) -> int:
"""Calculate the width needed for a device menu column."""
return max(len(device[column]) for device in devices) + padding return max(len(device[column]) for device in devices) + padding
@@ -452,7 +406,6 @@ def draw_device_ids(
menu_width: list[int], menu_width: list[int],
device_ids: set[str], device_ids: set[str],
) -> tuple[State, int]: ) -> tuple[State, int]:
"""Draw selectable device IDs for a device row."""
for device_id in sorted(device_ids): for device_id in sorted(device_ids):
row_number = row_number + 1 row_number = row_number + 1
if row_number == state.cursor.get_y() and state.cursor.get_x() in menu_width: if row_number == state.cursor.get_y() and state.cursor.get_x() in menu_width:
@@ -481,7 +434,7 @@ def draw_device_menu(
state: State, state: State,
menu_start_y: int = 0, menu_start_y: int = 0,
menu_start_x: int = 0, menu_start_x: int = 0,
) -> tuple[State, int]: ) -> State:
"""Draw the device menu and handle user input. """Draw the device menu and handle user input.
Args: Args:
@@ -537,7 +490,6 @@ def draw_device_menu(
def debug_menu(std_screen: curses.window, key: int) -> None: def debug_menu(std_screen: curses.window, key: int) -> None:
"""Draw debug information for the current curses screen."""
height, width = std_screen.getmaxyx() height, width = std_screen.getmaxyx()
width_height = f"Width: {width}, Height: {height}" width_height = f"Width: {width}, Height: {height}"
std_screen.addstr(height - 4, 0, width_height, curses.color_pair(5)) std_screen.addstr(height - 4, 0, width_height, curses.color_pair(5))
@@ -557,7 +509,6 @@ def status_bar(
width: int, width: int,
height: int, height: int,
) -> None: ) -> None:
"""Draw the footer status bar."""
std_screen.attron(curses.A_REVERSE) std_screen.attron(curses.A_REVERSE)
std_screen.attron(curses.color_pair(3)) std_screen.attron(curses.color_pair(3))
@@ -570,7 +521,6 @@ def status_bar(
def set_color() -> None: def set_color() -> None:
"""Initialize curses color pairs."""
curses.start_color() curses.start_color()
curses.use_default_colors() curses.use_default_colors()
for i in range(curses.COLORS): for i in range(curses.COLORS):
@@ -578,7 +528,6 @@ def set_color() -> None:
def get_text_input(std_screen: curses.window, prompt: str, y: int, x: int) -> str: def get_text_input(std_screen: curses.window, prompt: str, y: int, x: int) -> str:
"""Read text input from a curses screen."""
curses.echo() curses.echo()
std_screen.addstr(y, x, prompt) std_screen.addstr(y, x, prompt)
input_str = "" input_str = ""
@@ -586,7 +535,7 @@ def get_text_input(std_screen: curses.window, prompt: str, y: int, x: int) -> st
key = std_screen.getch() key = std_screen.getch()
if key == ord("\n"): if key == ord("\n"):
break break
if key == ESCAPE_KEY: if key == 27: # ESC key # noqa: PLR2004
input_str = "" input_str = ""
break break
if key in (curses.KEY_BACKSPACE, ord("\b"), 127): if key in (curses.KEY_BACKSPACE, ord("\b"), 127):
@@ -604,7 +553,6 @@ def swap_size_input(
state: State, state: State,
swap_offset: int, swap_offset: int,
) -> State: ) -> State:
"""Handle swap size input."""
swap_size_text = "Swap size (GB): " swap_size_text = "Swap size (GB): "
std_screen.addstr(swap_offset, 0, f"{swap_size_text}{state.swap_size}") std_screen.addstr(swap_offset, 0, f"{swap_size_text}{state.swap_size}")
if state.key == ord("\n") and state.cursor.get_y() == swap_offset: if state.key == ord("\n") and state.cursor.get_y() == swap_offset:
@@ -628,7 +576,6 @@ def reserve_size_input(
state: State, state: State,
reserve_offset: int, reserve_offset: int,
) -> State: ) -> State:
"""Handle reserve size input."""
reserve_size_text = "reserve size (GB): " reserve_size_text = "reserve size (GB): "
std_screen.addstr(reserve_offset, 0, f"{reserve_size_text}{state.reserve_size}") std_screen.addstr(reserve_offset, 0, f"{reserve_size_text}{state.reserve_size}")
if state.key == ord("\n") and state.cursor.get_y() == reserve_offset: if state.key == ord("\n") and state.cursor.get_y() == reserve_offset:
@@ -652,7 +599,6 @@ def draw_menu(std_screen: curses.window) -> State:
Args: Args:
std_screen (curses.window): the curses window to draw on std_screen (curses.window): the curses window to draw on
Returns: Returns:
State: the state object State: the state object
""" """
@@ -712,18 +658,17 @@ def draw_menu(std_screen: curses.window) -> State:
def main() -> None: def main() -> None:
"""Run the installer menu and start installation."""
configure_logger("DEBUG") configure_logger("DEBUG")
state = curses.wrapper(draw_menu) state = curses.wrapper(draw_menu)
encrypt_key = getenv("ENCRYPT_KEY") encrypt_key = getenv("ENCRYPT_KEY")
logger.info("installing_nixos") logging.info("installing_nixos")
logger.info(f"disks: {state.selected_device_ids}") logging.info(f"disks: {state.selected_device_ids}")
logger.info(f"swap_size: {state.swap_size}") logging.info(f"swap_size: {state.swap_size}")
logger.info(f"reserve: {state.reserve_size}") logging.info(f"reserve: {state.reserve_size}")
logger.info(f"encrypted: {bool(encrypt_key)}") logging.info(f"encrypted: {bool(encrypt_key)}")
sleep(3) sleep(3)
+1
View File
@@ -0,0 +1 @@
"""Van inventory FastAPI application."""
+16
View File
@@ -0,0 +1,16 @@
"""FastAPI dependencies for van inventory."""
from collections.abc import Iterator
from typing import Annotated
from fastapi import Depends, Request
from sqlalchemy.orm import Session
def get_db(request: Request) -> Iterator[Session]:
"""Get database session from app state."""
with Session(request.app.state.engine) as session:
yield session
DbSession = Annotated[Session, Depends(get_db)]
+56
View File
@@ -0,0 +1,56 @@
"""FastAPI app for van inventory."""
from __future__ import annotations
import logging
from contextlib import asynccontextmanager
from pathlib import Path
from typing import TYPE_CHECKING, Annotated
import typer
import uvicorn
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from python.common import configure_logger
from python.orm.common import get_postgres_engine
from python.van_inventory.routers import api_router, frontend_router
STATIC_DIR = Path(__file__).resolve().parent / "static"
if TYPE_CHECKING:
from collections.abc import AsyncIterator
logger = logging.getLogger(__name__)
def create_app() -> FastAPI:
"""Create and configure the FastAPI application."""
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
app.state.engine = get_postgres_engine(name="VAN_INVENTORY")
yield
app.state.engine.dispose()
app = FastAPI(title="Van Inventory", lifespan=lifespan)
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
app.include_router(api_router)
app.include_router(frontend_router)
return app
def serve(
# Intentionally binds all interfaces — this is a LAN-only van server
host: Annotated[str, typer.Option("--host", "-h", help="Host to bind to")] = "0.0.0.0", # noqa: S104
port: Annotated[int, typer.Option("--port", "-p", help="Port to bind to")] = 8001,
log_level: Annotated[str, typer.Option("--log-level", "-l", help="Log level")] = "INFO",
) -> None:
"""Start the Van Inventory server."""
configure_logger(log_level)
app = create_app()
uvicorn.run(app, host=host, port=port)
if __name__ == "__main__":
typer.run(serve)
+6
View File
@@ -0,0 +1,6 @@
"""Van inventory API routers."""
from python.van_inventory.routers.api import router as api_router
from python.van_inventory.routers.frontend import router as frontend_router
__all__ = ["api_router", "frontend_router"]
+314
View File
@@ -0,0 +1,314 @@
"""Van inventory API router."""
from __future__ import annotations
from typing import TYPE_CHECKING
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from python.orm.van_inventory.models import Item, Meal, MealIngredient
if TYPE_CHECKING:
from python.van_inventory.dependencies import DbSession
# --- Schemas ---
class ItemCreate(BaseModel):
"""Schema for creating an item."""
name: str
quantity: float = Field(default=0, ge=0)
unit: str
category: str | None = None
class ItemUpdate(BaseModel):
"""Schema for updating an item."""
name: str | None = None
quantity: float | None = Field(default=None, ge=0)
unit: str | None = None
category: str | None = None
class ItemResponse(BaseModel):
"""Schema for item response."""
id: int
name: str
quantity: float
unit: str
category: str | None
model_config = {"from_attributes": True}
class IngredientCreate(BaseModel):
"""Schema for adding an ingredient to a meal."""
item_id: int
quantity_needed: float = Field(gt=0)
class MealCreate(BaseModel):
"""Schema for creating a meal."""
name: str
instructions: str | None = None
ingredients: list[IngredientCreate] = []
class MealUpdate(BaseModel):
"""Schema for updating a meal."""
name: str | None = None
instructions: str | None = None
class IngredientResponse(BaseModel):
"""Schema for ingredient response."""
item_id: int
item_name: str
quantity_needed: float
unit: str
model_config = {"from_attributes": True}
class MealResponse(BaseModel):
"""Schema for meal response."""
id: int
name: str
instructions: str | None
ingredients: list[IngredientResponse] = []
model_config = {"from_attributes": True}
@classmethod
def from_meal(cls, meal: Meal) -> MealResponse:
"""Build a MealResponse from an ORM Meal with loaded ingredients."""
return cls(
id=meal.id,
name=meal.name,
instructions=meal.instructions,
ingredients=[
IngredientResponse(
item_id=mi.item_id,
item_name=mi.item.name,
quantity_needed=mi.quantity_needed,
unit=mi.item.unit,
)
for mi in meal.ingredients
],
)
class ShoppingItem(BaseModel):
"""An item needed for a meal that is short on stock."""
item_name: str
unit: str
needed: float
have: float
short: float
class MealAvailability(BaseModel):
"""Availability status for a meal."""
meal_id: int
meal_name: str
can_make: bool
missing: list[ShoppingItem] = []
# --- Routes ---
router = APIRouter(prefix="/api", tags=["van_inventory"])
# Items
@router.post("/items", response_model=ItemResponse)
def create_item(item: ItemCreate, db: DbSession) -> Item:
"""Create a new inventory item."""
db_item = Item(**item.model_dump())
db.add(db_item)
db.commit()
db.refresh(db_item)
return db_item
@router.get("/items", response_model=list[ItemResponse])
def list_items(db: DbSession) -> list[Item]:
"""List all inventory items."""
return list(db.scalars(select(Item).order_by(Item.name)).all())
@router.get("/items/{item_id}", response_model=ItemResponse)
def get_item(item_id: int, db: DbSession) -> Item:
"""Get an item by ID."""
item = db.get(Item, item_id)
if not item:
raise HTTPException(status_code=404, detail="Item not found")
return item
@router.patch("/items/{item_id}", response_model=ItemResponse)
def update_item(item_id: int, item: ItemUpdate, db: DbSession) -> Item:
"""Update an item by ID."""
db_item = db.get(Item, item_id)
if not db_item:
raise HTTPException(status_code=404, detail="Item not found")
for key, value in item.model_dump(exclude_unset=True).items():
setattr(db_item, key, value)
db.commit()
db.refresh(db_item)
return db_item
@router.delete("/items/{item_id}")
def delete_item(item_id: int, db: DbSession) -> dict[str, bool]:
"""Delete an item by ID."""
item = db.get(Item, item_id)
if not item:
raise HTTPException(status_code=404, detail="Item not found")
db.delete(item)
db.commit()
return {"deleted": True}
# Meals
@router.post("/meals", response_model=MealResponse)
def create_meal(meal: MealCreate, db: DbSession) -> MealResponse:
"""Create a new meal with optional ingredients."""
for ing in meal.ingredients:
if not db.get(Item, ing.item_id):
raise HTTPException(status_code=422, detail=f"Item {ing.item_id} not found")
db_meal = Meal(name=meal.name, instructions=meal.instructions)
db.add(db_meal)
db.flush()
for ing in meal.ingredients:
db.add(MealIngredient(meal_id=db_meal.id, item_id=ing.item_id, quantity_needed=ing.quantity_needed))
db.commit()
db_meal = db.scalar(
select(Meal)
.where(Meal.id == db_meal.id)
.options(selectinload(Meal.ingredients).selectinload(MealIngredient.item))
)
return MealResponse.from_meal(db_meal)
@router.get("/meals", response_model=list[MealResponse])
def list_meals(db: DbSession) -> list[MealResponse]:
"""List all meals with ingredients."""
meals = list(
db.scalars(
select(Meal).options(selectinload(Meal.ingredients).selectinload(MealIngredient.item)).order_by(Meal.name)
).all()
)
return [MealResponse.from_meal(m) for m in meals]
@router.get("/meals/availability", response_model=list[MealAvailability])
def check_all_meals(db: DbSession) -> list[MealAvailability]:
"""Check which meals can be made with current inventory."""
meals = list(
db.scalars(select(Meal).options(selectinload(Meal.ingredients).selectinload(MealIngredient.item))).all()
)
return [_check_meal(m) for m in meals]
@router.get("/meals/{meal_id}", response_model=MealResponse)
def get_meal(meal_id: int, db: DbSession) -> MealResponse:
"""Get a meal by ID with ingredients."""
meal = db.scalar(
select(Meal).where(Meal.id == meal_id).options(selectinload(Meal.ingredients).selectinload(MealIngredient.item))
)
if not meal:
raise HTTPException(status_code=404, detail="Meal not found")
return MealResponse.from_meal(meal)
@router.delete("/meals/{meal_id}")
def delete_meal(meal_id: int, db: DbSession) -> dict[str, bool]:
"""Delete a meal by ID."""
meal = db.get(Meal, meal_id)
if not meal:
raise HTTPException(status_code=404, detail="Meal not found")
db.delete(meal)
db.commit()
return {"deleted": True}
@router.post("/meals/{meal_id}/ingredients", response_model=MealResponse)
def add_ingredient(meal_id: int, ingredient: IngredientCreate, db: DbSession) -> MealResponse:
"""Add an ingredient to a meal."""
meal = db.get(Meal, meal_id)
if not meal:
raise HTTPException(status_code=404, detail="Meal not found")
if not db.get(Item, ingredient.item_id):
raise HTTPException(status_code=422, detail="Item not found")
existing = db.scalar(
select(MealIngredient).where(MealIngredient.meal_id == meal_id, MealIngredient.item_id == ingredient.item_id)
)
if existing:
raise HTTPException(status_code=409, detail="Ingredient already exists for this meal")
db.add(MealIngredient(meal_id=meal_id, item_id=ingredient.item_id, quantity_needed=ingredient.quantity_needed))
db.commit()
meal = db.scalar(
select(Meal).where(Meal.id == meal_id).options(selectinload(Meal.ingredients).selectinload(MealIngredient.item))
)
return MealResponse.from_meal(meal)
@router.delete("/meals/{meal_id}/ingredients/{item_id}")
def remove_ingredient(meal_id: int, item_id: int, db: DbSession) -> dict[str, bool]:
"""Remove an ingredient from a meal."""
mi = db.scalar(select(MealIngredient).where(MealIngredient.meal_id == meal_id, MealIngredient.item_id == item_id))
if not mi:
raise HTTPException(status_code=404, detail="Ingredient not found")
db.delete(mi)
db.commit()
return {"deleted": True}
@router.get("/meals/{meal_id}/availability", response_model=MealAvailability)
def check_meal(meal_id: int, db: DbSession) -> MealAvailability:
"""Check if a specific meal can be made and what's missing."""
meal = db.scalar(
select(Meal).where(Meal.id == meal_id).options(selectinload(Meal.ingredients).selectinload(MealIngredient.item))
)
if not meal:
raise HTTPException(status_code=404, detail="Meal not found")
return _check_meal(meal)
def _check_meal(meal: Meal) -> MealAvailability:
missing = [
ShoppingItem(
item_name=mi.item.name,
unit=mi.item.unit,
needed=mi.quantity_needed,
have=mi.item.quantity,
short=mi.quantity_needed - mi.item.quantity,
)
for mi in meal.ingredients
if mi.item.quantity < mi.quantity_needed
]
return MealAvailability(
meal_id=meal.id,
meal_name=meal.name,
can_make=len(missing) == 0,
missing=missing,
)
+198
View File
@@ -0,0 +1,198 @@
"""HTMX frontend routes for van inventory."""
from __future__ import annotations
from pathlib import Path
from typing import Annotated
from fastapi import APIRouter, Form, HTTPException, Request
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from python.orm.van_inventory.models import Item, Meal, MealIngredient
# FastAPI needs DbSession at runtime to resolve the Depends() annotation
from python.van_inventory.dependencies import DbSession # noqa: TC001
from python.van_inventory.routers.api import _check_meal
TEMPLATE_DIR = Path(__file__).resolve().parent.parent / "templates"
templates = Jinja2Templates(directory=TEMPLATE_DIR)
router = APIRouter(tags=["frontend"])
# --- Items ---
@router.get("/", response_class=HTMLResponse)
def items_page(request: Request, db: DbSession) -> HTMLResponse:
"""Render the inventory page."""
items = list(db.scalars(select(Item).order_by(Item.name)).all())
return templates.TemplateResponse(request, "items.html", {"items": items})
@router.post("/items", response_class=HTMLResponse)
def htmx_create_item(
request: Request,
db: DbSession,
name: Annotated[str, Form()],
quantity: Annotated[float, Form()] = 0,
unit: Annotated[str, Form()] = "",
category: Annotated[str | None, Form()] = None,
) -> HTMLResponse:
"""Create an item and return updated item rows."""
if quantity < 0:
raise HTTPException(status_code=422, detail="Quantity must not be negative")
db.add(Item(name=name, quantity=quantity, unit=unit, category=category or None))
db.commit()
items = list(db.scalars(select(Item).order_by(Item.name)).all())
return templates.TemplateResponse(request, "partials/item_rows.html", {"items": items})
@router.patch("/items/{item_id}", response_class=HTMLResponse)
def htmx_update_item(
request: Request,
item_id: int,
db: DbSession,
quantity: Annotated[float, Form()],
) -> HTMLResponse:
"""Update an item's quantity and return updated item rows."""
if quantity < 0:
raise HTTPException(status_code=422, detail="Quantity must not be negative")
item = db.get(Item, item_id)
if item:
item.quantity = quantity
db.commit()
items = list(db.scalars(select(Item).order_by(Item.name)).all())
return templates.TemplateResponse(request, "partials/item_rows.html", {"items": items})
@router.delete("/items/{item_id}", response_class=HTMLResponse)
def htmx_delete_item(request: Request, item_id: int, db: DbSession) -> HTMLResponse:
"""Delete an item and return updated item rows."""
item = db.get(Item, item_id)
if item:
db.delete(item)
db.commit()
items = list(db.scalars(select(Item).order_by(Item.name)).all())
return templates.TemplateResponse(request, "partials/item_rows.html", {"items": items})
# --- Meals ---
def _load_meals(db: DbSession) -> list[Meal]:
return list(
db.scalars(
select(Meal).options(selectinload(Meal.ingredients).selectinload(MealIngredient.item)).order_by(Meal.name)
).all()
)
@router.get("/meals", response_class=HTMLResponse)
def meals_page(request: Request, db: DbSession) -> HTMLResponse:
"""Render the meals page."""
meals = _load_meals(db)
return templates.TemplateResponse(request, "meals.html", {"meals": meals})
@router.post("/meals", response_class=HTMLResponse)
def htmx_create_meal(
request: Request,
db: DbSession,
name: Annotated[str, Form()],
instructions: Annotated[str | None, Form()] = None,
) -> HTMLResponse:
"""Create a meal and return updated meal rows."""
db.add(Meal(name=name, instructions=instructions or None))
db.commit()
meals = _load_meals(db)
return templates.TemplateResponse(request, "partials/meal_rows.html", {"meals": meals})
@router.delete("/meals/{meal_id}", response_class=HTMLResponse)
def htmx_delete_meal(request: Request, meal_id: int, db: DbSession) -> HTMLResponse:
"""Delete a meal and return updated meal rows."""
meal = db.get(Meal, meal_id)
if meal:
db.delete(meal)
db.commit()
meals = _load_meals(db)
return templates.TemplateResponse(request, "partials/meal_rows.html", {"meals": meals})
# --- Meal detail ---
def _load_meal(db: DbSession, meal_id: int) -> Meal | None:
return db.scalar(
select(Meal).where(Meal.id == meal_id).options(selectinload(Meal.ingredients).selectinload(MealIngredient.item))
)
@router.get("/meals/{meal_id}", response_class=HTMLResponse)
def meal_detail_page(request: Request, meal_id: int, db: DbSession) -> HTMLResponse:
"""Render the meal detail page."""
meal = _load_meal(db, meal_id)
if not meal:
raise HTTPException(status_code=404, detail="Meal not found")
items = list(db.scalars(select(Item).order_by(Item.name)).all())
return templates.TemplateResponse(request, "meal_detail.html", {"meal": meal, "items": items})
@router.post("/meals/{meal_id}/ingredients", response_class=HTMLResponse)
def htmx_add_ingredient(
request: Request,
meal_id: int,
db: DbSession,
item_id: Annotated[int, Form()],
quantity_needed: Annotated[float, Form()],
) -> HTMLResponse:
"""Add an ingredient to a meal and return updated ingredient rows."""
if quantity_needed <= 0:
raise HTTPException(status_code=422, detail="Quantity must be positive")
meal = db.get(Meal, meal_id)
if not meal:
raise HTTPException(status_code=404, detail="Meal not found")
if not db.get(Item, item_id):
raise HTTPException(status_code=422, detail="Item not found")
existing = db.scalar(
select(MealIngredient).where(MealIngredient.meal_id == meal_id, MealIngredient.item_id == item_id)
)
if existing:
raise HTTPException(status_code=409, detail="Ingredient already exists for this meal")
db.add(MealIngredient(meal_id=meal_id, item_id=item_id, quantity_needed=quantity_needed))
db.commit()
meal = _load_meal(db, meal_id)
return templates.TemplateResponse(request, "partials/ingredient_rows.html", {"meal": meal})
@router.delete("/meals/{meal_id}/ingredients/{item_id}", response_class=HTMLResponse)
def htmx_remove_ingredient(
request: Request,
meal_id: int,
item_id: int,
db: DbSession,
) -> HTMLResponse:
"""Remove an ingredient from a meal and return updated ingredient rows."""
mi = db.scalar(select(MealIngredient).where(MealIngredient.meal_id == meal_id, MealIngredient.item_id == item_id))
if mi:
db.delete(mi)
db.commit()
meal = _load_meal(db, meal_id)
return templates.TemplateResponse(request, "partials/ingredient_rows.html", {"meal": meal})
# --- Availability ---
@router.get("/availability", response_class=HTMLResponse)
def availability_page(request: Request, db: DbSession) -> HTMLResponse:
"""Render the meal availability page."""
meals = list(
db.scalars(select(Meal).options(selectinload(Meal.ingredients).selectinload(MealIngredient.item))).all()
)
availability = [_check_meal(m) for m in meals]
return templates.TemplateResponse(request, "availability.html", {"availability": availability})
+212
View File
@@ -0,0 +1,212 @@
:root {
--neon-pink: #ff2a6d;
--neon-cyan: #05d9e8;
--neon-yellow: #f9f002;
--neon-purple: #d300c5;
--bg-dark: #0a0a0f;
--bg-panel: #0d0d1a;
--bg-input: #111128;
--border: #1a1a3e;
--text: #c0c0d0;
--text-dim: #8e8ea0;
}
* { box-sizing: border-box; margin: 0; padding: 0; }
body {
font-family: 'Share Tech Mono', monospace;
max-width: 900px;
margin: 0 auto;
padding: 1rem;
background: var(--bg-dark);
color: var(--text);
position: relative;
}
/* Scanline overlay */
body::before {
content: '';
position: fixed;
top: 0; left: 0; right: 0; bottom: 0;
background: repeating-linear-gradient(
0deg,
transparent,
transparent 2px,
rgba(0, 0, 0, 0.08) 2px,
rgba(0, 0, 0, 0.08) 4px
);
pointer-events: none;
z-index: 9999;
}
h1, h2, h3 {
font-family: 'Orbitron', sans-serif;
margin-bottom: 0.5rem;
color: var(--neon-cyan);
text-shadow: 0 0 10px rgba(5, 217, 232, 0.5), 0 0 40px rgba(5, 217, 232, 0.2);
text-transform: uppercase;
letter-spacing: 2px;
}
a { color: var(--neon-pink); text-decoration: none; transition: all 0.2s; }
a:hover {
text-shadow: 0 0 8px rgba(255, 42, 109, 0.8), 0 0 20px rgba(255, 42, 109, 0.4);
}
nav {
display: flex;
gap: 1.5rem;
padding: 1rem 0;
border-bottom: 1px solid var(--border);
margin-bottom: 1.5rem;
position: relative;
}
nav::after {
content: '';
position: absolute;
bottom: -1px;
left: 0;
right: 0;
height: 1px;
background: linear-gradient(90deg, var(--neon-pink), var(--neon-cyan), var(--neon-purple));
opacity: 0.6;
}
nav a {
font-family: 'Orbitron', sans-serif;
font-weight: 700;
font-size: 0.85rem;
letter-spacing: 1px;
text-transform: uppercase;
padding: 0.3rem 0;
border-bottom: 2px solid transparent;
transition: all 0.2s;
}
nav a:hover {
border-bottom-color: var(--neon-pink);
text-shadow: 0 0 8px rgba(255, 42, 109, 0.8);
}
table {
width: 100%;
border-collapse: collapse;
margin: 1rem 0;
border: 1px solid var(--border);
}
th, td {
text-align: left;
padding: 0.6rem 0.75rem;
border-bottom: 1px solid var(--border);
}
th {
font-family: 'Orbitron', sans-serif;
color: var(--neon-cyan);
font-size: 0.7rem;
text-transform: uppercase;
letter-spacing: 2px;
background: var(--bg-panel);
border-bottom: 1px solid var(--neon-cyan);
text-shadow: 0 0 6px rgba(5, 217, 232, 0.3);
}
tr:hover td {
background: rgba(5, 217, 232, 0.03);
}
form {
display: flex;
flex-wrap: wrap;
gap: 0.5rem;
align-items: end;
margin: 1rem 0;
padding: 1rem;
border: 1px solid var(--border);
background: var(--bg-panel);
}
input, select {
padding: 0.5rem 0.6rem;
border: 1px solid var(--border);
border-radius: 2px;
background: var(--bg-input);
color: var(--neon-cyan);
font-family: 'Share Tech Mono', monospace;
transition: all 0.2s;
}
input:focus, select:focus {
outline: none;
border-color: var(--neon-cyan);
box-shadow: 0 0 8px rgba(5, 217, 232, 0.3), inset 0 0 8px rgba(5, 217, 232, 0.05);
}
button {
padding: 0.5rem 1.2rem;
border: 1px solid var(--neon-pink);
border-radius: 2px;
background: transparent;
color: var(--neon-pink);
cursor: pointer;
font-family: 'Orbitron', sans-serif;
font-weight: 700;
font-size: 0.7rem;
letter-spacing: 1px;
text-transform: uppercase;
transition: all 0.2s;
}
button:hover {
background: var(--neon-pink);
color: var(--bg-dark);
box-shadow: 0 0 15px rgba(255, 42, 109, 0.5), 0 0 30px rgba(255, 42, 109, 0.2);
}
button.danger {
border-color: var(--text-dim);
color: var(--text-dim);
}
button.danger:hover {
border-color: var(--neon-pink);
background: var(--neon-pink);
color: var(--bg-dark);
box-shadow: 0 0 15px rgba(255, 42, 109, 0.5);
}
.badge {
display: inline-block;
padding: 0.2rem 0.6rem;
border-radius: 2px;
font-family: 'Orbitron', sans-serif;
font-size: 0.65rem;
font-weight: 700;
letter-spacing: 1px;
text-transform: uppercase;
}
.badge.yes {
background: rgba(5, 217, 232, 0.1);
color: var(--neon-cyan);
border: 1px solid var(--neon-cyan);
text-shadow: 0 0 6px rgba(5, 217, 232, 0.5);
}
.badge.no {
background: rgba(255, 42, 109, 0.1);
color: var(--neon-pink);
border: 1px solid var(--neon-pink);
text-shadow: 0 0 6px rgba(255, 42, 109, 0.5);
}
.missing-list { font-size: 0.85rem; color: var(--text-dim); }
label {
font-size: 0.75rem;
color: var(--text-dim);
display: flex;
flex-direction: column;
gap: 0.2rem;
text-transform: uppercase;
letter-spacing: 1px;
}
.flash {
padding: 0.5rem 1rem;
margin: 0.5rem 0;
border-radius: 2px;
background: rgba(5, 217, 232, 0.1);
color: var(--neon-cyan);
border: 1px solid var(--neon-cyan);
}
@@ -0,0 +1,30 @@
{% extends "base.html" %}
{% block title %}What Can I Make? - Van{% endblock %}
{% block content %}
<h1>What Can I Make?</h1>
<table>
<thead>
<tr><th>Meal</th><th>Status</th><th>Missing</th></tr>
</thead>
<tbody>
{% for meal in availability %}
<tr>
<td><a href="/meals/{{ meal.meal_id }}">{{ meal.meal_name }}</a></td>
<td>
{% if meal.can_make %}
<span class="badge yes">Ready</span>
{% else %}
<span class="badge no">Missing items</span>
{% endif %}
</td>
<td class="missing-list">
{% for m in meal.missing %}
{{ m.item_name }}: need {{ m.short }} more {{ m.unit }}{% if not loop.last %}, {% endif %}
{% endfor %}
</td>
</tr>
{% endfor %}
</tbody>
</table>
{% endblock %}
+20
View File
@@ -0,0 +1,20 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>{% block title %}Van Inventory{% endblock %}</title>
<script src="https://unpkg.com/htmx.org@2.0.4"></script>
<link rel="preconnect" href="https://fonts.googleapis.com">
<link href="https://fonts.googleapis.com/css2?family=Orbitron:wght@400;700;900&family=Share+Tech+Mono&display=swap" rel="stylesheet">
<link rel="stylesheet" href="/static/style.css">
</head>
<body>
<nav>
<a href="/">Inventory</a>
<a href="/meals">Meals</a>
<a href="/availability">What Can I Make?</a>
</nav>
{% block content %}{% endblock %}
</body>
</html>
+17
View File
@@ -0,0 +1,17 @@
{% extends "base.html" %}
{% block title %}Inventory - Van{% endblock %}
{% block content %}
<h1>Van Inventory</h1>
<form hx-post="/items" hx-target="#item-list" hx-swap="innerHTML" hx-on::after-request="if(event.detail.successful) this.reset()">
<label>Name <input type="text" name="name" required></label>
<label>Qty <input type="number" name="quantity" step="any" value="0" min="0" required></label>
<label>Unit <input type="text" name="unit" required placeholder="lbs, cans, etc"></label>
<label>Category <input type="text" name="category" placeholder="optional"></label>
<button type="submit">Add Item</button>
</form>
<div id="item-list">
{% include "partials/item_rows.html" %}
</div>
{% endblock %}
@@ -0,0 +1,24 @@
{% extends "base.html" %}
{% block title %}{{ meal.name }} - Van{% endblock %}
{% block content %}
<h1>{{ meal.name }}</h1>
{% if meal.instructions %}<p>{{ meal.instructions }}</p>{% endif %}
<h2>Ingredients</h2>
<form hx-post="/meals/{{ meal.id }}/ingredients" hx-target="#ingredient-list" hx-swap="innerHTML" hx-on::after-request="if(event.detail.successful) this.reset()">
<label>Item
<select name="item_id" required>
<option value="">--</option>
{% for item in items %}
<option value="{{ item.id }}">{{ item.name }} ({{ item.unit }})</option>
{% endfor %}
</select>
</label>
<label>Qty needed <input type="number" name="quantity_needed" step="any" min="0.01" required></label>
<button type="submit">Add</button>
</form>
<div id="ingredient-list">
{% include "partials/ingredient_rows.html" %}
</div>
{% endblock %}
+15
View File
@@ -0,0 +1,15 @@
{% extends "base.html" %}
{% block title %}Meals - Van{% endblock %}
{% block content %}
<h1>Meals</h1>
<form hx-post="/meals" hx-target="#meal-list" hx-swap="innerHTML" hx-on::after-request="if(event.detail.successful) this.reset()">
<label>Name <input type="text" name="name" required></label>
<label>Instructions <input type="text" name="instructions" placeholder="optional"></label>
<button type="submit">Add Meal</button>
</form>
<div id="meal-list">
{% include "partials/meal_rows.html" %}
</div>
{% endblock %}
@@ -0,0 +1,16 @@
<table>
<thead>
<tr><th>Item</th><th>Needed</th><th>Have</th><th>Unit</th><th></th></tr>
</thead>
<tbody>
{% for mi in meal.ingredients %}
<tr>
<td>{{ mi.item.name }}</td>
<td>{{ mi.quantity_needed }}</td>
<td>{{ mi.item.quantity }}</td>
<td>{{ mi.item.unit }}</td>
<td><button class="danger" hx-delete="/meals/{{ meal.id }}/ingredients/{{ mi.item_id }}" hx-target="#ingredient-list" hx-swap="innerHTML" hx-confirm="Remove {{ mi.item.name }}?">X</button></td>
</tr>
{% endfor %}
</tbody>
</table>
@@ -0,0 +1,21 @@
<table>
<thead>
<tr><th>Name</th><th>Qty</th><th>Unit</th><th>Category</th><th></th></tr>
</thead>
<tbody>
{% for item in items %}
<tr>
<td>{{ item.name }}</td>
<td>
<form hx-patch="/items/{{ item.id }}" hx-target="#item-list" hx-swap="innerHTML" style="display:inline; margin:0;">
<input type="number" name="quantity" value="{{ item.quantity }}" step="any" min="0" style="width:5rem">
<button type="submit" style="padding:0.2rem 0.5rem; font-size:0.8rem;">Update</button>
</form>
</td>
<td>{{ item.unit }}</td>
<td>{{ item.category or "" }}</td>
<td><button class="danger" hx-delete="/items/{{ item.id }}" hx-target="#item-list" hx-swap="innerHTML" hx-confirm="Delete {{ item.name }}?">X</button></td>
</tr>
{% endfor %}
</tbody>
</table>
@@ -0,0 +1,15 @@
<table>
<thead>
<tr><th>Name</th><th>Ingredients</th><th>Instructions</th><th></th></tr>
</thead>
<tbody>
{% for meal in meals %}
<tr>
<td><a href="/meals/{{ meal.id }}">{{ meal.name }}</a></td>
<td>{{ meal.ingredients | length }}</td>
<td>{{ (meal.instructions or "")[:50] }}</td>
<td><button class="danger" hx-delete="/meals/{{ meal.id }}" hx-target="#meal-list" hx-swap="innerHTML" hx-confirm="Delete {{ meal.name }}?">X</button></td>
</tr>
{% endfor %}
</tbody>
</table>
+1 -1
View File
@@ -257,7 +257,7 @@ def update_weather(config: Config) -> None:
logger.info(f"Masked location: {masked_lat}, {masked_lon}") logger.info(f"Masked location: {masked_lat}, {masked_lon}")
weather = fetch_weather(config.pirate_weather_api_key, masked_lat, masked_lon) weather = fetch_weather(config.pirate_weather_api_key, lat, lon)
logger.info(f"Weather: {weather.temperature}°F, {weather.condition}") logger.info(f"Weather: {weather.temperature}°F, {weather.condition}")
post_to_ha(config.ha_url, config.ha_token, weather) post_to_ha(config.ha_url, config.ha_token, weather)
+1 -3
View File
@@ -1,8 +1,6 @@
"""Models for van weather service.""" """Models for van weather service."""
from __future__ import annotations from datetime import datetime
from datetime import datetime # noqa: TC003 This is required for pydantic
from pydantic import BaseModel, field_serializer from pydantic import BaseModel, field_serializer
+2 -2
View File
@@ -108,7 +108,7 @@ class Dataset:
self.written = int(properties["written"]["value"]) self.written = int(properties["written"]["value"])
self.xattr = properties["xattr"]["value"] self.xattr = properties["xattr"]["value"]
def get_snapshots(self) -> list[Snapshot]: def get_snapshots(self) -> list[Snapshot] | None:
"""Get all snapshots from zfs and process then is test dicts of sets.""" """Get all snapshots from zfs and process then is test dicts of sets."""
snapshots_data = _zfs_list(f"zfs list -t snapshot -pHj {self.name} -o all") snapshots_data = _zfs_list(f"zfs list -t snapshot -pHj {self.name} -o all")
@@ -125,7 +125,7 @@ class Dataset:
if return_code == 0: if return_code == 0:
return "snapshot created" return "snapshot created"
snapshots = self.get_snapshots() if snapshots := self.get_snapshots():
snapshot_names = {snapshot.name for snapshot in snapshots} snapshot_names = {snapshot.name for snapshot in snapshots}
if snapshot_name in snapshot_names: if snapshot_name in snapshot_names:
return f"Snapshot {snapshot_name} already exists for {self.name}" return f"Snapshot {snapshot_name} already exists for {self.name}"
+50
View File
@@ -0,0 +1,50 @@
{
pkgs,
inputs,
...
}:
{
networking.firewall.allowedTCPPorts = [ 8001 ];
users = {
users.vaninventory = {
isSystemUser = true;
group = "vaninventory";
};
groups.vaninventory = { };
};
systemd.services.van_inventory = {
description = "Van Inventory API";
after = [
"network.target"
"postgresql.service"
];
requires = [ "postgresql.service" ];
wantedBy = [ "multi-user.target" ];
environment = {
PYTHONPATH = "${inputs.self}/";
VAN_INVENTORY_DB = "vaninventory";
VAN_INVENTORY_USER = "vaninventory";
VAN_INVENTORY_HOST = "/run/postgresql";
VAN_INVENTORY_PORT = "5432";
};
serviceConfig = {
Type = "simple";
User = "vaninventory";
Group = "vaninventory";
ExecStart = "${pkgs.my_python}/bin/python -m python.van_inventory.main --host 0.0.0.0 --port 8001";
Restart = "on-failure";
RestartSec = "5s";
StandardOutput = "journal";
StandardError = "journal";
NoNewPrivileges = true;
ProtectSystem = "strict";
ProtectHome = "read-only";
PrivateTmp = true;
ReadOnlyPaths = [ "${inputs.self}" ];
};
};
}
+1 -1
View File
@@ -47,7 +47,7 @@
}; };
networks = { networks = {
"10-Primary" = { "10-Primary" = {
matchConfig.Name = "enp97s0f1"; matchConfig.Name = "enp97s0";
address = [ "192.168.99.14/24" ]; address = [ "192.168.99.14/24" ];
dns = [ dns = [
"192.168.99.1" "192.168.99.1"
-1
View File
@@ -1 +0,0 @@
"""Focused ebook search tests."""
-18
View File
@@ -1,18 +0,0 @@
"""Tests for the shared query/gold set loader."""
from __future__ import annotations
from python.ebook_search.eval.dataset import load_gold_queries
def test_default_query_set_counts() -> None:
queries = load_gold_queries()
answerable = [query for query in queries if query.answerable]
assert len(queries) == 70
assert len(answerable) == 50
assert len(queries) - len(answerable) == 20
assert all(query.query for query in queries)
# Answerable queries carry at least one source; garbage queries carry none.
assert all(query.relevant_sources for query in answerable)
assert all(not query.relevant_sources for query in queries if not query.answerable)
-147
View File
@@ -1,147 +0,0 @@
"""Tests for serve-time output guardrails."""
from __future__ import annotations
from typing import TYPE_CHECKING
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from python.ebook_search.api.main import create_app
from python.ebook_search.config import EbookSearchConfig, RerankConfig
from python.ebook_search.guardrails import is_confident, retrieval_confidence, validate_citations
from python.ebook_search.search import SearchResponse, SearchResult
if TYPE_CHECKING:
from pytest_mock import MockerFixture
def make_results(count, *, vector_score=0.8):
return [
SearchResult(
chunk_id=index,
text=f"source text {index}",
source_title="Book",
score=vector_score,
vector_score=vector_score,
)
for index in range(1, count + 1)
]
def test_validate_citations_partitions_markers() -> None:
report = validate_citations("Supported by [1] and [2].", result_count=3)
assert report.cited == (1, 2)
assert report.invalid == ()
assert report.grounded is True
def test_validate_citations_flags_out_of_range_marker() -> None:
report = validate_citations("As shown in [5].", result_count=2)
assert report.cited == ()
assert report.invalid == (5,)
assert report.grounded is False
def test_validate_citations_uncited_answer_is_not_grounded() -> None:
report = validate_citations("No citations at all.", result_count=2)
assert report.cited == ()
assert report.invalid == ()
assert report.grounded is False
def test_retrieval_confidence_prefers_rerank_then_vector() -> None:
assert retrieval_confidence([]) == 0.0
rerank_top = [SearchResult(chunk_id=1, text="t", source_title="B", rerank_score=0.7, vector_score=0.2)]
assert retrieval_confidence(rerank_top) == 0.7
vector_top = [SearchResult(chunk_id=1, text="t", source_title="B", vector_score=0.5)]
assert retrieval_confidence(vector_top) == 0.5
def test_is_confident_against_threshold() -> None:
config = EbookSearchConfig(rerank=RerankConfig(enabled=False), min_retrieval_confidence=0.5)
assert is_confident(make_results(1, vector_score=0.6), config) is True
assert is_confident(make_results(1, vector_score=0.4), config) is False
def patch_app_runtime(mocker: MockerFixture):
mocker.patch(
"python.ebook_search.api.main.get_postgres_engine",
side_effect=lambda **_kwargs: create_engine("sqlite+pysqlite:///:memory:", future=True),
)
mocker.patch("python.ebook_search.api.main.ensure_bm25_corpus", side_effect=lambda _session, _config: None)
def test_low_confidence_skips_answer_generation(mocker: MockerFixture) -> None:
called = False
def fake_search_ebooks(_engine, query, _config, *, rerank=False):
del rerank
return SearchResponse(query=query, rank_label="Hybrid", results=make_results(1, vector_score=0.05))
def fake_answer_query(_query, _results, _config):
nonlocal called
called = True
return "answer"
config = EbookSearchConfig(
rerank=RerankConfig(enabled=False),
answer_enabled=True,
min_retrieval_confidence=0.5,
)
mocker.patch("python.ebook_search.api.routes.search.search_ebooks", side_effect=fake_search_ebooks)
mocker.patch("python.ebook_search.api.routes.search.answer_query", side_effect=fake_answer_query)
mocker.patch("python.ebook_search.api.main.load_config", side_effect=lambda: config)
patch_app_runtime(mocker)
app = create_app()
with TestClient(app) as client:
response = client.post("/search", data={"query": "q"})
assert response.status_code == 200
assert called is False
assert "Low retrieval confidence" in response.text
def test_invalid_citation_is_flagged(mocker: MockerFixture) -> None:
def fake_search_ebooks(_engine, query, _config, *, rerank=False):
del rerank
return SearchResponse(query=query, rank_label="Hybrid", results=make_results(2, vector_score=0.9))
mocker.patch("python.ebook_search.api.routes.search.search_ebooks", side_effect=fake_search_ebooks)
mocker.patch(
"python.ebook_search.api.routes.search.answer_query",
side_effect=lambda _query, _results, _config: "Per the text [9].",
)
patch_app_runtime(mocker)
app = create_app()
app.state.config = EbookSearchConfig(rerank=RerankConfig(enabled=False), answer_enabled=True)
with TestClient(app) as client:
response = client.post("/search", data={"query": "q"})
assert response.status_code == 200
assert "Invalid citations" in response.text
assert "9" in response.text
def test_grounded_answer_has_no_warning_badge(mocker: MockerFixture) -> None:
def fake_search_ebooks(_engine, query, _config, *, rerank=False):
del rerank
return SearchResponse(query=query, rank_label="Hybrid", results=make_results(2, vector_score=0.9))
mocker.patch("python.ebook_search.api.routes.search.search_ebooks", side_effect=fake_search_ebooks)
mocker.patch(
"python.ebook_search.api.routes.search.answer_query",
side_effect=lambda _query, _results, _config: "Grounded in [1] and [2].",
)
patch_app_runtime(mocker)
app = create_app()
app.state.config = EbookSearchConfig(rerank=RerankConfig(enabled=False), answer_enabled=True)
with TestClient(app) as client:
response = client.post("/search", data={"query": "q"})
assert response.status_code == 200
assert "Unverified" not in response.text
assert "Invalid citations" not in response.text
-122
View File
@@ -1,122 +0,0 @@
"""Tests for EPUB search health and readiness routes."""
from __future__ import annotations
from typing import TYPE_CHECKING
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from python.ebook_search.api.main import create_app
from python.ebook_search.config import EbookSearchConfig, RerankConfig
HEALTH_MODULE = "python.ebook_search.api.routes.health"
if TYPE_CHECKING:
from pytest_mock import MockerFixture
def fake_get_postgres_engine(**_kwargs):
"""Return an in-memory engine for route tests."""
return create_engine("sqlite+pysqlite:///:memory:", future=True)
def patch_app_runtime(mocker: MockerFixture):
mocker.patch("python.ebook_search.api.main.get_postgres_engine", side_effect=fake_get_postgres_engine)
mocker.patch("python.ebook_search.api.main.ensure_bm25_corpus", side_effect=lambda _session, _config: None)
def patch_dependencies(mocker: MockerFixture, *, database=True, embedding=True, chat=True, bm25="ok"):
mocker.patch(f"{HEALTH_MODULE}.check_database", side_effect=lambda _session: database)
mocker.patch(f"{HEALTH_MODULE}.check_embedding_endpoint", side_effect=lambda _config: embedding)
mocker.patch(f"{HEALTH_MODULE}.check_chat_endpoint", side_effect=lambda _config: chat)
mocker.patch(f"{HEALTH_MODULE}.check_bm25_status", side_effect=lambda _config: bm25)
def build_client(mocker: MockerFixture, config=None):
resolved = config or EbookSearchConfig(rerank=RerankConfig(enabled=False))
mocker.patch("python.ebook_search.api.main.load_config", side_effect=lambda: resolved)
patch_app_runtime(mocker)
app = create_app()
return TestClient(app)
def test_health_returns_ok(mocker: MockerFixture) -> None:
with build_client(mocker) as client:
response = client.get("/health")
assert response.status_code == 200
assert response.json() == {"status": "ok"}
def test_ready_all_dependencies_ok(mocker: MockerFixture) -> None:
patch_dependencies(mocker)
with build_client(mocker) as client:
response = client.get("/ready")
assert response.status_code == 200
body = response.json()
assert body["status"] == "ready"
assert body["checks"] == {"database": "ok", "embedding": "ok", "chat": "ok", "bm25": "ok"}
def test_ready_embedding_down_is_degraded(mocker: MockerFixture) -> None:
patch_dependencies(mocker, embedding=False)
with build_client(mocker) as client:
response = client.get("/ready")
assert response.status_code == 200
body = response.json()
assert body["status"] == "degraded"
assert body["checks"]["embedding"] == "fail"
def test_ready_chat_down_is_degraded(mocker: MockerFixture) -> None:
patch_dependencies(mocker, chat=False)
with build_client(mocker) as client:
response = client.get("/ready")
assert response.status_code == 200
body = response.json()
assert body["status"] == "degraded"
assert body["checks"]["chat"] == "fail"
def test_ready_chat_disabled_when_answers_off(mocker: MockerFixture) -> None:
patch_dependencies(mocker)
config = EbookSearchConfig(rerank=RerankConfig(enabled=False), answer_enabled=False)
with build_client(mocker, config) as client:
response = client.get("/ready")
assert response.status_code == 200
body = response.json()
assert body["status"] == "ready"
assert body["checks"]["chat"] == "disabled"
def test_ready_database_down_is_unavailable(mocker: MockerFixture) -> None:
patch_dependencies(mocker, database=False)
with build_client(mocker) as client:
response = client.get("/ready")
assert response.status_code == 503
body = response.json()
assert body["status"] == "unavailable"
assert body["checks"]["database"] == "fail"
def test_ready_bm25_missing_is_degraded(mocker: MockerFixture) -> None:
patch_dependencies(mocker, bm25="missing")
with build_client(mocker) as client:
response = client.get("/ready")
assert response.status_code == 200
body = response.json()
assert body["status"] == "degraded"
assert body["checks"]["bm25"] == "missing"
-79
View File
@@ -1,79 +0,0 @@
"""Tests for the load-test runner and its statistics helpers."""
from __future__ import annotations
import asyncio
from typing import TYPE_CHECKING
import pytest
from python.ebook_search.loadtest import RequestResult, load_queries, percentile, run_load, summarize
if TYPE_CHECKING:
from pytest_mock import MockerFixture
def test_load_queries_reads_shared_set() -> None:
queries = load_queries(None)
assert len(queries) == 70
assert all(isinstance(query, str) and query for query in queries)
def test_percentile_interpolates() -> None:
values = [10.0, 20.0, 30.0, 40.0]
assert percentile(values, 50) == pytest.approx(25.0)
assert percentile(values, 90) == pytest.approx(37.0)
assert percentile(values, 0) == 10.0
assert percentile(values, 100) == 40.0
assert percentile([], 95) == 0.0
def test_summarize_counts_and_throughput() -> None:
results = [
RequestResult(status_code=200, latency_ms=10.0, ok=True),
RequestResult(status_code=200, latency_ms=20.0, ok=True),
RequestResult(status_code=200, latency_ms=30.0, ok=True),
RequestResult(status_code=500, latency_ms=40.0, ok=False),
]
summary = summarize(results, wall_seconds=2.0)
assert summary.total == 4
assert summary.successes == 3
assert summary.failures == 1
assert summary.throughput_rps == pytest.approx(2.0)
assert summary.latency_max_ms == 40.0
assert summary.status_counts == {200: 3, 500: 1}
def test_summarize_handles_empty() -> None:
summary = summarize([], wall_seconds=0.0)
assert summary.total == 0
assert summary.throughput_rps == 0.0
assert summary.latency_p95_ms == 0.0
def test_run_load_aggregates_mocked_responses(mocker: MockerFixture) -> None:
response = mocker.Mock(status_code=200, is_success=True)
client = mocker.MagicMock()
client.__aenter__.return_value = client
client.post = mocker.AsyncMock(return_value=response)
mocker.patch("python.ebook_search.loadtest.httpx.AsyncClient", return_value=client)
summary = asyncio.run(
run_load(
base_url="http://test",
queries=["q1", "q2"],
request_count=4,
concurrency=2,
rerank=False,
warmup=1,
timeout_seconds=1.0,
)
)
assert summary.total == 4
assert summary.successes == 4
assert summary.failures == 0
assert summary.status_counts == {200: 4}
# 1 warmup request (not measured) plus 4 measured requests.
assert client.post.await_count == 5
-49
View File
@@ -1,49 +0,0 @@
"""Tests for the ebook search RAG pipeline orchestration."""
from __future__ import annotations
from threading import Event
from typing import TYPE_CHECKING
from sqlalchemy import create_engine
from python.ebook_search.config import EbookSearchConfig, RerankConfig
from python.ebook_search.search import SearchResult, search_ebooks
if TYPE_CHECKING:
from pytest_mock import MockerFixture
def test_search_ebooks_runs_vector_and_bm25_in_parallel(mocker: MockerFixture) -> None:
engine = create_engine("sqlite+pysqlite:///:memory:", future=True)
vector_started = Event()
bm25_started = Event()
received_engines: list[object] = []
def fake_vector_candidates(received_engine, query, _config):
"""Return vector candidates after confirming BM25 has started."""
received_engines.append(received_engine)
assert query == "what is parallel"
vector_started.set()
assert bm25_started.wait(timeout=2)
return [SearchResult(chunk_id=1, text="vector", source_title="Vector", vector_score=0.9)]
def fake_bm25_candidates(query, _config):
"""Return BM25 candidates after confirming vector search has started."""
assert query == "what is parallel"
bm25_started.set()
assert vector_started.wait(timeout=2)
return [SearchResult(chunk_id=2, text="bm25", source_title="BM25", bm25_score=2.0)]
mocker.patch("python.ebook_search.search.vector_candidates", side_effect=fake_vector_candidates)
mocker.patch("python.ebook_search.search.bm25_candidates", side_effect=fake_bm25_candidates)
config = EbookSearchConfig(rerank=RerankConfig(enabled=False))
response = search_ebooks(engine, "what is parallel", config)
timings = {step.name: step for step in response.timings}
assert [result.chunk_id for result in response.results] == [1, 2]
assert timings["Embedding + vector search"].counts_toward_total is False
assert timings["BM25 search"].counts_toward_total is False
assert timings["Hybrid retrieval"].counts_toward_total is True
assert received_engines == [engine]
@@ -3,11 +3,12 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from dataclasses import replace
from datetime import UTC, datetime from datetime import UTC, datetime
from os import environ from os import environ
from pathlib import Path from pathlib import Path
from threading import Event
from types import ModuleType from types import ModuleType
from typing import TYPE_CHECKING
import pytest import pytest
from sqlalchemy import create_engine, select from sqlalchemy import create_engine, select
@@ -34,6 +35,7 @@ from python.ebook_search.search import (
bm25_candidates, bm25_candidates,
reciprocal_rank_fusion, reciprocal_rank_fusion,
retrieval_query_from_text, retrieval_query_from_text,
search_ebooks,
) )
from python.ebook_search.timing import RuntimeStep from python.ebook_search.timing import RuntimeStep
from python.orm.richie import ( from python.orm.richie import (
@@ -45,9 +47,6 @@ from python.orm.richie import (
RichieBase, RichieBase,
) )
if TYPE_CHECKING:
from pytest_mock import MockerFixture
def test_chunk_text_uses_overlap() -> None: def test_chunk_text_uses_overlap() -> None:
chunks = chunk_text(" ".join(str(index) for index in range(100)), chunk_tokens=20, overlap_tokens=5) chunks = chunk_text(" ".join(str(index) for index in range(100)), chunk_tokens=20, overlap_tokens=5)
@@ -68,7 +67,7 @@ def test_reciprocal_rank_fusion_combines_vector_and_bm25_rankings() -> None:
SearchResult(chunk_id=3, text="c", source_title="C", score=2.1, bm25_score=2.1), SearchResult(chunk_id=3, text="c", source_title="C", score=2.1, bm25_score=2.1),
] ]
fused = reciprocal_rank_fusion(vector_results, lexical_results, rank_constant=60) fused = reciprocal_rank_fusion(vector_results, lexical_results)
assert [result.chunk_id for result in fused] == [2, 1, 3] assert [result.chunk_id for result in fused] == [2, 1, 3]
assert fused[0].rank_source == "Hybrid" assert fused[0].rank_source == "Hybrid"
@@ -146,7 +145,7 @@ def test_reciprocal_rank_fusion_marks_hybrid_source() -> None:
vector_results = [SearchResult(chunk_id=1, text="a", source_title="A")] vector_results = [SearchResult(chunk_id=1, text="a", source_title="A")]
lexical_results = [SearchResult(chunk_id=2, text="b", source_title="B")] lexical_results = [SearchResult(chunk_id=2, text="b", source_title="B")]
fused = reciprocal_rank_fusion(vector_results, lexical_results, rank_constant=60) fused = reciprocal_rank_fusion(vector_results, lexical_results)
assert {result.rank_source for result in fused} == {"Hybrid"} assert {result.rank_source for result in fused} == {"Hybrid"}
@@ -166,13 +165,49 @@ def test_search_response_sums_runtime_steps() -> None:
assert response.total_runtime_ms == 4.0 assert response.total_runtime_ms == 4.0
def test_search_ebooks_runs_vector_and_bm25_in_parallel(monkeypatch) -> None:
engine = create_engine("sqlite+pysqlite:///:memory:", future=True)
vector_started = Event()
bm25_started = Event()
received_engines: list[object] = []
def fake_vector_candidates(received_engine, query, _config):
"""Return vector candidates after confirming BM25 has started."""
received_engines.append(received_engine)
assert query == "what is parallel"
vector_started.set()
assert bm25_started.wait(timeout=2)
return [SearchResult(chunk_id=1, text="vector", source_title="Vector", vector_score=0.9)]
def fake_bm25_candidates(query, _config):
"""Return BM25 candidates after confirming vector search has started."""
assert query == "parallel"
bm25_started.set()
assert vector_started.wait(timeout=2)
return [SearchResult(chunk_id=2, text="bm25", source_title="BM25", bm25_score=2.0)]
monkeypatch.setattr("python.ebook_search.search.vector_candidates", fake_vector_candidates)
monkeypatch.setattr("python.ebook_search.search.bm25_candidates", fake_bm25_candidates)
config = EbookSearchConfig(rerank=RerankConfig(enabled=False))
response = search_ebooks(engine, "what is parallel", config)
timings = {step.name: step for step in response.timings}
assert [result.chunk_id for result in response.results] == [1, 2]
assert timings["Embedding + vector search"].counts_toward_total is False
assert timings["BM25 search"].counts_toward_total is False
assert timings["Hybrid retrieval"].counts_toward_total is True
assert timings["BM25 query preparation"].counts_toward_total is True
assert received_engines == [engine]
def test_retrieval_query_keeps_entity_and_series_terms() -> None: def test_retrieval_query_keeps_entity_and_series_terms() -> None:
assert retrieval_query_from_text("what does Damien Montgomery stand for in starship mage") == ( assert retrieval_query_from_text("what does Damien Montgomery stand for in starship mage") == (
"damien montgomery stand starship mage" "damien montgomery stand starship mage"
) )
def test_bm25_candidates_scores_whole_corpus(mocker: MockerFixture) -> None: def test_bm25_candidates_scores_whole_corpus(monkeypatch) -> None:
record = { record = {
"chunk_id": 2, "chunk_id": 2,
"text": "high", "text": "high",
@@ -192,8 +227,8 @@ def test_bm25_candidates_scores_whole_corpus(mocker: MockerFixture) -> None:
captured["limit"] = limit captured["limit"] = limit
return [(record, 1.5)] return [(record, 1.5)]
mocker.patch("python.ebook_search.search.load_bm25_corpus", side_effect=lambda _config: corpus) monkeypatch.setattr("python.ebook_search.search.load_bm25_corpus", lambda _config: corpus)
mocker.patch("python.ebook_search.search.score_bm25_corpus", side_effect=fake_score_bm25_corpus) monkeypatch.setattr("python.ebook_search.search.score_bm25_corpus", fake_score_bm25_corpus)
config = EbookSearchConfig(rerank=RerankConfig(enabled=False)) config = EbookSearchConfig(rerank=RerankConfig(enabled=False))
results = bm25_candidates("high", config) results = bm25_candidates("high", config)
@@ -205,11 +240,11 @@ def test_bm25_candidates_scores_whole_corpus(mocker: MockerFixture) -> None:
assert [result.bm25_score for result in results] == [1.5] assert [result.bm25_score for result in results] == [1.5]
def test_bm25_candidates_returns_empty_when_corpus_is_unavailable(mocker: MockerFixture, caplog) -> None: def test_bm25_candidates_returns_empty_when_corpus_is_unavailable(monkeypatch, caplog) -> None:
def fake_load_bm25_corpus(_config): def fake_load_bm25_corpus(_config):
raise BM25CorpusUnavailableError raise BM25CorpusUnavailableError
mocker.patch("python.ebook_search.search.load_bm25_corpus", side_effect=fake_load_bm25_corpus) monkeypatch.setattr("python.ebook_search.search.load_bm25_corpus", fake_load_bm25_corpus)
config = EbookSearchConfig(rerank=RerankConfig(enabled=False)) config = EbookSearchConfig(rerank=RerankConfig(enabled=False))
with caplog.at_level(logging.WARNING): with caplog.at_level(logging.WARNING):
@@ -245,7 +280,7 @@ def test_write_bm25_corpus_publishes_dated_generation(tmp_path) -> None:
assert read_bm25_manifest(index_path) == manifest assert read_bm25_manifest(index_path) == manifest
def test_write_bm25_corpus_keeps_current_generation_when_publish_fails(mocker: MockerFixture, tmp_path) -> None: def test_write_bm25_corpus_keeps_current_generation_when_publish_fails(monkeypatch, tmp_path) -> None:
index_path = tmp_path / "bm25" index_path = tmp_path / "bm25"
index_path.mkdir() index_path.mkdir()
generations_path = index_path / "generations" generations_path = index_path / "generations"
@@ -263,7 +298,7 @@ def test_write_bm25_corpus_keeps_current_generation_when_publish_fails(mocker: M
raise OSError(msg) raise OSError(msg)
return original_replace(self, target) return original_replace(self, target)
mocker.patch.object(Path, "replace", fail_current_replace) monkeypatch.setattr(Path, "replace", fail_current_replace)
manifest = BM25Manifest( manifest = BM25Manifest(
created_at=datetime(2026, 6, 12, 1, 2, 3, 456789, tzinfo=UTC), created_at=datetime(2026, 6, 12, 1, 2, 3, 456789, tzinfo=UTC),
db_updated_at=None, db_updated_at=None,
@@ -307,7 +342,7 @@ def test_load_bm25_corpus_uses_current_generation(tmp_path) -> None:
assert score_bm25_corpus("cached", corpus, limit=10) assert score_bm25_corpus("cached", corpus, limit=10)
def test_load_bm25_corpus_caches_disk_load(mocker: MockerFixture, tmp_path) -> None: def test_load_bm25_corpus_caches_disk_load(monkeypatch, tmp_path) -> None:
load_bm25_corpus.cache_clear() load_bm25_corpus.cache_clear()
manifest = BM25Manifest(created_at=datetime.now(tz=UTC), db_updated_at=None, chunk_count=1) manifest = BM25Manifest(created_at=datetime.now(tz=UTC), db_updated_at=None, chunk_count=1)
record = { record = {
@@ -340,9 +375,9 @@ def test_load_bm25_corpus_caches_disk_load(mocker: MockerFixture, tmp_path) -> N
fake_bm25s = ModuleType("bm25s") fake_bm25s = ModuleType("bm25s")
fake_bm25s.BM25 = FakeBM25 fake_bm25s.BM25 = FakeBM25
mocker.patch("python.ebook_search.bm25_corpus.read_bm25_manifest", side_effect=lambda _path: manifest) monkeypatch.setattr("python.ebook_search.bm25_corpus.read_bm25_manifest", lambda _path: manifest)
mocker.patch("python.ebook_search.bm25_corpus.bm25_index_exists", side_effect=lambda _path, _manifest: True) monkeypatch.setattr("python.ebook_search.bm25_corpus.bm25_index_exists", lambda _path, _manifest: True)
mocker.patch("python.ebook_search.bm25_corpus.bm25s", fake_bm25s) monkeypatch.setattr("python.ebook_search.bm25_corpus.bm25s", fake_bm25s)
config = EbookSearchConfig(rerank=RerankConfig(enabled=False), bm25_index_dir=str(tmp_path)) config = EbookSearchConfig(rerank=RerankConfig(enabled=False), bm25_index_dir=str(tmp_path))
try: try:
@@ -357,10 +392,10 @@ def test_load_bm25_corpus_caches_disk_load(mocker: MockerFixture, tmp_path) -> N
assert load_count == 1 assert load_count == 1
def test_load_bm25_corpus_raises_when_index_is_missing(mocker: MockerFixture, tmp_path) -> None: def test_load_bm25_corpus_raises_when_index_is_missing(monkeypatch, tmp_path) -> None:
load_bm25_corpus.cache_clear() load_bm25_corpus.cache_clear()
mocker.patch("python.ebook_search.bm25_corpus.read_bm25_manifest", side_effect=lambda _path: None) monkeypatch.setattr("python.ebook_search.bm25_corpus.read_bm25_manifest", lambda _path: None)
mocker.patch("python.ebook_search.bm25_corpus.bm25_index_exists", side_effect=lambda _path, _manifest: False) monkeypatch.setattr("python.ebook_search.bm25_corpus.bm25_index_exists", lambda _path, _manifest: False)
config = EbookSearchConfig(rerank=RerankConfig(enabled=False), bm25_index_dir=str(tmp_path)) config = EbookSearchConfig(rerank=RerankConfig(enabled=False), bm25_index_dir=str(tmp_path))
try: try:
@@ -370,16 +405,16 @@ def test_load_bm25_corpus_raises_when_index_is_missing(mocker: MockerFixture, tm
load_bm25_corpus.cache_clear() load_bm25_corpus.cache_clear()
def test_ensure_bm25_corpus_refreshes_missing_index(mocker: MockerFixture) -> None: def test_ensure_bm25_corpus_refreshes_missing_index(monkeypatch) -> None:
refreshed: list[object] = [] refreshed: list[object] = []
db_updated_at = datetime.now(tz=UTC) db_updated_at = datetime.now(tz=UTC)
mocker.patch("python.ebook_search.bm25_corpus.read_bm25_manifest", side_effect=lambda _path: None) monkeypatch.setattr("python.ebook_search.bm25_corpus.read_bm25_manifest", lambda _path: None)
mocker.patch("python.ebook_search.bm25_corpus.bm25_index_exists", side_effect=lambda _path, _manifest: False) monkeypatch.setattr("python.ebook_search.bm25_corpus.bm25_index_exists", lambda _path, _manifest: False)
mocker.patch("python.ebook_search.bm25_corpus.corpus_last_updated_at", side_effect=lambda _session: db_updated_at) monkeypatch.setattr("python.ebook_search.bm25_corpus.corpus_last_updated_at", lambda _session: db_updated_at)
mocker.patch( monkeypatch.setattr(
"python.ebook_search.bm25_corpus.refresh_bm25_corpus", "python.ebook_search.bm25_corpus.refresh_bm25_corpus",
side_effect=lambda session, config, *, db_updated_at: refreshed.append((session, config, db_updated_at)), lambda session, config, *, db_updated_at: refreshed.append((session, config, db_updated_at)),
) )
config = EbookSearchConfig(rerank=RerankConfig(enabled=False)) config = EbookSearchConfig(rerank=RerankConfig(enabled=False))
@@ -390,18 +425,18 @@ def test_ensure_bm25_corpus_refreshes_missing_index(mocker: MockerFixture) -> No
assert refreshed == [(session, config, db_updated_at)] assert refreshed == [(session, config, db_updated_at)]
def test_ensure_bm25_corpus_refreshes_stale_index(mocker: MockerFixture) -> None: def test_ensure_bm25_corpus_refreshes_stale_index(monkeypatch) -> None:
refreshed: list[object] = [] refreshed: list[object] = []
created_at = datetime(2026, 1, 1, tzinfo=UTC) created_at = datetime(2026, 1, 1, tzinfo=UTC)
db_updated_at = datetime(2026, 1, 2, tzinfo=UTC) db_updated_at = datetime(2026, 1, 2, tzinfo=UTC)
manifest = BM25Manifest(created_at=created_at, db_updated_at=created_at, chunk_count=10) manifest = BM25Manifest(created_at=created_at, db_updated_at=created_at, chunk_count=10)
mocker.patch("python.ebook_search.bm25_corpus.read_bm25_manifest", side_effect=lambda _path: manifest) monkeypatch.setattr("python.ebook_search.bm25_corpus.read_bm25_manifest", lambda _path: manifest)
mocker.patch("python.ebook_search.bm25_corpus.bm25_index_exists", side_effect=lambda _path, _manifest: True) monkeypatch.setattr("python.ebook_search.bm25_corpus.bm25_index_exists", lambda _path, _manifest: True)
mocker.patch("python.ebook_search.bm25_corpus.corpus_last_updated_at", side_effect=lambda _session: db_updated_at) monkeypatch.setattr("python.ebook_search.bm25_corpus.corpus_last_updated_at", lambda _session: db_updated_at)
mocker.patch( monkeypatch.setattr(
"python.ebook_search.bm25_corpus.refresh_bm25_corpus", "python.ebook_search.bm25_corpus.refresh_bm25_corpus",
side_effect=lambda session, config, *, db_updated_at: refreshed.append((session, config, db_updated_at)), lambda session, config, *, db_updated_at: refreshed.append((session, config, db_updated_at)),
) )
config = EbookSearchConfig(rerank=RerankConfig(enabled=False)) config = EbookSearchConfig(rerank=RerankConfig(enabled=False))
@@ -445,9 +480,7 @@ def test_1024_embedding_table_has_cosine_hnsw_index() -> None:
assert index.dialect_options["postgresql"]["ops"] == {"embedding": "vector_cosine_ops"} assert index.dialect_options["postgresql"]["ops"] == {"embedding": "vector_cosine_ops"}
def test_embedding_model_aliases_normalize_to_provider_names(mocker: MockerFixture) -> None: def test_embedding_model_aliases_normalize_to_provider_names() -> None:
mocker.patch.dict(environ, {}, clear=False)
assert normalize_embedding_model() == "qwen3-embedding-0.6b" assert normalize_embedding_model() == "qwen3-embedding-0.6b"
environ["EBOOK_SEARCH_EMBEDDING_MODEL"] = "qwen3-embedding-0.6b" environ["EBOOK_SEARCH_EMBEDDING_MODEL"] = "qwen3-embedding-0.6b"
@@ -467,19 +500,17 @@ def test_embedding_model_aliases_normalize_to_provider_names(mocker: MockerFixtu
assert normalize_embedding_model() == "qwen3-embedding-8b" assert normalize_embedding_model() == "qwen3-embedding-8b"
def test_answer_generation_is_enabled_by_default(mocker: MockerFixture) -> None: def test_answer_generation_is_enabled_by_default(monkeypatch) -> None:
mocker.patch.dict(environ, {}, clear=False) monkeypatch.delenv("EBOOK_SEARCH_ANSWER_ENABLED", raising=False)
environ.pop("EBOOK_SEARCH_ANSWER_ENABLED", None)
config = load_config() config = load_config()
assert config.answer_enabled is True assert config.answer_enabled is True
def test_chat_defaults_use_ollama_cloud(mocker: MockerFixture) -> None: def test_chat_defaults_use_ollama_cloud(monkeypatch) -> None:
mocker.patch.dict(environ, {}, clear=False) monkeypatch.delenv("EBOOK_SEARCH_VLLM_BASE_URL", raising=False)
environ.pop("EBOOK_SEARCH_VLLM_BASE_URL", None) monkeypatch.delenv("EBOOK_SEARCH_CHAT_MODEL", raising=False)
environ.pop("EBOOK_SEARCH_CHAT_MODEL", None)
config = load_config() config = load_config()
@@ -487,9 +518,9 @@ def test_chat_defaults_use_ollama_cloud(mocker: MockerFixture) -> None:
assert config.chat_model == "deepseek-v4-flash" assert config.chat_model == "deepseek-v4-flash"
def test_chat_api_key_falls_back_to_ollama_api_key(mocker: MockerFixture) -> None: def test_chat_api_key_falls_back_to_ollama_api_key(monkeypatch) -> None:
mocker.patch.dict(environ, {"OLLAMA_API_KEY": "ollama-key"}, clear=False) monkeypatch.delenv("EBOOK_SEARCH_VLLM_API_KEY", raising=False)
environ.pop("EBOOK_SEARCH_VLLM_API_KEY", None) monkeypatch.setenv("OLLAMA_API_KEY", "ollama-key")
config = load_config() config = load_config()
@@ -497,7 +528,7 @@ def test_chat_api_key_falls_back_to_ollama_api_key(mocker: MockerFixture) -> Non
def test_answer_query_does_not_call_model_when_disabled() -> None: def test_answer_query_does_not_call_model_when_disabled() -> None:
config = load_config().model_copy(update={"answer_enabled": False}) config = replace(load_config(), answer_enabled=False)
result = SearchResult(chunk_id=1, text="source text", source_title="Book") result = SearchResult(chunk_id=1, text="source text", source_title="Book")
answer = answer_query("question", [result], config) answer = answer_query("question", [result], config)
@@ -2,8 +2,6 @@
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING
import httpx import httpx
import pytest import pytest
@@ -12,11 +10,8 @@ from python.ebook_search.config import EbookSearchConfig, RerankConfig
from python.ebook_search.embeddings import embed_texts from python.ebook_search.embeddings import embed_texts
from python.ebook_search.search import SearchResult from python.ebook_search.search import SearchResult
if TYPE_CHECKING:
from pytest_mock import MockerFixture
def test_answer_query_uses_httpx_chat_completions(monkeypatch) -> None:
def test_answer_query_uses_httpx_chat_completions(mocker: MockerFixture) -> None:
captured: dict[str, object] = {} captured: dict[str, object] = {}
def fake_post(url: str, **kwargs: object) -> httpx.Response: def fake_post(url: str, **kwargs: object) -> httpx.Response:
@@ -28,7 +23,7 @@ def test_answer_query_uses_httpx_chat_completions(mocker: MockerFixture) -> None
request=httpx.Request("POST", url), request=httpx.Request("POST", url),
) )
mocker.patch.object(httpx, "post", side_effect=fake_post) monkeypatch.setattr(httpx, "post", fake_post)
config = EbookSearchConfig( config = EbookSearchConfig(
rerank=RerankConfig(enabled=False), rerank=RerankConfig(enabled=False),
vllm_base_url="https://ollama.com/v1", vllm_base_url="https://ollama.com/v1",
@@ -48,7 +43,7 @@ def test_answer_query_uses_httpx_chat_completions(mocker: MockerFixture) -> None
assert payload["model"] == "deepseek-v4-flash" assert payload["model"] == "deepseek-v4-flash"
def test_embed_texts_uses_httpx_embeddings(mocker: MockerFixture) -> None: def test_embed_texts_uses_httpx_embeddings(monkeypatch) -> None:
captured: dict[str, object] = {} captured: dict[str, object] = {}
vector = [0.0] * 1024 vector = [0.0] * 1024
@@ -61,7 +56,7 @@ def test_embed_texts_uses_httpx_embeddings(mocker: MockerFixture) -> None:
request=httpx.Request("POST", url), request=httpx.Request("POST", url),
) )
mocker.patch.object(httpx, "post", side_effect=fake_post) monkeypatch.setattr(httpx, "post", fake_post)
config = EbookSearchConfig( config = EbookSearchConfig(
rerank=RerankConfig(enabled=False), rerank=RerankConfig(enabled=False),
embedding_base_url="http://bob:8000/v1", embedding_base_url="http://bob:8000/v1",
@@ -78,11 +73,11 @@ def test_embed_texts_uses_httpx_embeddings(mocker: MockerFixture) -> None:
assert kwargs["json"] == {"model": "qwen3-embedding-0.6b", "input": ["hello"]} assert kwargs["json"] == {"model": "qwen3-embedding-0.6b", "input": ["hello"]}
def test_embed_texts_rejects_bad_response_shape(mocker: MockerFixture) -> None: def test_embed_texts_rejects_bad_response_shape(monkeypatch) -> None:
def fake_post(url: str, **_kwargs: object) -> httpx.Response: def fake_post(url: str, **_kwargs: object) -> httpx.Response:
return httpx.Response(200, json={"data": [{}]}, request=httpx.Request("POST", url)) return httpx.Response(200, json={"data": [{}]}, request=httpx.Request("POST", url))
mocker.patch.object(httpx, "post", side_effect=fake_post) monkeypatch.setattr(httpx, "post", fake_post)
config = EbookSearchConfig(rerank=RerankConfig(enabled=False)) config = EbookSearchConfig(rerank=RerankConfig(enabled=False))
with pytest.raises(RuntimeError, match="Embedding request failed"): with pytest.raises(RuntimeError, match="Embedding request failed"):
@@ -2,9 +2,6 @@
from __future__ import annotations from __future__ import annotations
from os import environ
from typing import TYPE_CHECKING
import httpx import httpx
import pytest import pytest
@@ -12,9 +9,6 @@ from python.ebook_search.config import EbookSearchConfig, RerankConfig, load_rer
from python.ebook_search.rerank import rerank_chunks from python.ebook_search.rerank import rerank_chunks
from python.ebook_search.search import SearchResult, apply_rerank, skip_rerank from python.ebook_search.search import SearchResult, apply_rerank, skip_rerank
if TYPE_CHECKING:
from pytest_mock import MockerFixture
def candidates() -> list[SearchResult]: def candidates() -> list[SearchResult]:
return [ return [
@@ -33,17 +27,16 @@ def rerank_response(payload: dict[str, object] | None = None, *, content: bytes
) )
def test_config_defaults_enable_reranking(mocker: MockerFixture) -> None: def test_config_defaults_keep_reranking_optional(monkeypatch: pytest.MonkeyPatch) -> None:
mocker.patch.dict(environ, {}, clear=False) monkeypatch.delenv("EBOOK_SEARCH_RERANK_ENABLED", raising=False)
environ.pop("EBOOK_SEARCH_RERANK_ENABLED", None) monkeypatch.delenv("EBOOK_SEARCH_RERANK_BASE_URL", raising=False)
environ.pop("EBOOK_SEARCH_RERANK_BASE_URL", None) monkeypatch.delenv("EBOOK_SEARCH_RERANK_MODEL", raising=False)
environ.pop("EBOOK_SEARCH_RERANK_MODEL", None) monkeypatch.delenv("EBOOK_SEARCH_RERANK_CANDIDATES", raising=False)
environ.pop("EBOOK_SEARCH_RERANK_CANDIDATES", None) monkeypatch.delenv("EBOOK_SEARCH_RERANK_TIMEOUT_SECONDS", raising=False)
environ.pop("EBOOK_SEARCH_RERANK_TIMEOUT_SECONDS", None)
config = load_rerank_config() config = load_rerank_config()
assert config.enabled is True assert config.enabled is False
assert config.base_url == "http://192.168.90.25:8001" assert config.base_url == "http://192.168.90.25:8001"
assert config.model == "qwen3-reranker-06b" assert config.model == "qwen3-reranker-06b"
assert config.candidates == 24 assert config.candidates == 24
@@ -59,7 +52,7 @@ def test_reranking_disabled_returns_original_fused_order() -> None:
assert [result.chunk_id for result in response.results] == [1, 2] assert [result.chunk_id for result in response.results] == [1, 2]
def test_reranking_enabled_reorders_candidates(mocker: MockerFixture) -> None: def test_reranking_enabled_reorders_candidates(monkeypatch: pytest.MonkeyPatch) -> None:
def fake_post(_url: str, *, json: dict[str, object], timeout: float) -> httpx.Response: def fake_post(_url: str, *, json: dict[str, object], timeout: float) -> httpx.Response:
assert timeout == 30 assert timeout == 30
assert json == { assert json == {
@@ -77,7 +70,7 @@ def test_reranking_enabled_reorders_candidates(mocker: MockerFixture) -> None:
} }
) )
mocker.patch.object(httpx, "post", side_effect=fake_post) monkeypatch.setattr(httpx, "post", fake_post)
results = rerank_chunks("query", candidates(), RerankConfig()) results = rerank_chunks("query", candidates(), RerankConfig())
@@ -86,7 +79,7 @@ def test_reranking_enabled_reorders_candidates(mocker: MockerFixture) -> None:
assert [result.rerank_score for result in results] == [0.9, 0.1, 0.4] assert [result.rerank_score for result in results] == [0.9, 0.1, 0.4]
def test_reranking_cannot_ignore_hybrid_score(mocker: MockerFixture) -> None: def test_reranking_cannot_ignore_hybrid_score(monkeypatch: pytest.MonkeyPatch) -> None:
candidates = [ candidates = [
SearchResult(chunk_id=1, text="strong hybrid", source_title="A", score=1.0), SearchResult(chunk_id=1, text="strong hybrid", source_title="A", score=1.0),
SearchResult(chunk_id=2, text="weak hybrid", source_title="B", score=0.1), SearchResult(chunk_id=2, text="weak hybrid", source_title="B", score=0.1),
@@ -102,7 +95,7 @@ def test_reranking_cannot_ignore_hybrid_score(mocker: MockerFixture) -> None:
} }
) )
mocker.patch.object(httpx, "post", side_effect=fake_post) monkeypatch.setattr(httpx, "post", fake_post)
results = rerank_chunks("query", candidates, RerankConfig()) results = rerank_chunks("query", candidates, RerankConfig())
@@ -112,7 +105,7 @@ def test_reranking_cannot_ignore_hybrid_score(mocker: MockerFixture) -> None:
assert results[1].rerank_score == 1.0 assert results[1].rerank_score == 1.0
def test_vllm_rerank_timeout_raises(mocker: MockerFixture) -> None: def test_vllm_rerank_timeout_raises(monkeypatch: pytest.MonkeyPatch) -> None:
def fake_rerank_chunks( def fake_rerank_chunks(
_query: str, _query: str,
_candidates: list[SearchResult], _candidates: list[SearchResult],
@@ -121,25 +114,25 @@ def test_vllm_rerank_timeout_raises(mocker: MockerFixture) -> None:
message = "timeout" message = "timeout"
raise httpx.TimeoutException(message) raise httpx.TimeoutException(message)
mocker.patch("python.ebook_search.search.rerank_chunks", side_effect=fake_rerank_chunks) monkeypatch.setattr("python.ebook_search.search.rerank_chunks", fake_rerank_chunks)
config = EbookSearchConfig(rerank=RerankConfig(enabled=True), top_k=2) config = EbookSearchConfig(rerank=RerankConfig(enabled=True), top_k=2)
with pytest.raises(httpx.TimeoutException, match="timeout"): with pytest.raises(httpx.TimeoutException, match="timeout"):
apply_rerank("query", candidates(), config) apply_rerank("query", candidates(), config)
def test_malformed_vllm_rerank_json_does_not_crash_search(mocker: MockerFixture) -> None: def test_malformed_vllm_rerank_json_does_not_crash_search(monkeypatch: pytest.MonkeyPatch) -> None:
def fake_post(_url: str, **_kwargs: object) -> httpx.Response: def fake_post(_url: str, **_kwargs: object) -> httpx.Response:
return rerank_response(content=b"not-json") return rerank_response(content=b"not-json")
mocker.patch.object(httpx, "post", side_effect=fake_post) monkeypatch.setattr(httpx, "post", fake_post)
results = rerank_chunks("query", candidates()[:1], RerankConfig()) results = rerank_chunks("query", candidates()[:1], RerankConfig())
assert results[0].score == 0.3 assert results[0].score == 0.3
def test_vllm_rerank_scores_are_clamped(mocker: MockerFixture) -> None: def test_vllm_rerank_scores_are_clamped(monkeypatch: pytest.MonkeyPatch) -> None:
def fake_post(_url: str, **_kwargs: object) -> httpx.Response: def fake_post(_url: str, **_kwargs: object) -> httpx.Response:
return rerank_response( return rerank_response(
{ {
@@ -150,7 +143,7 @@ def test_vllm_rerank_scores_are_clamped(mocker: MockerFixture) -> None:
} }
) )
mocker.patch.object(httpx, "post", side_effect=fake_post) monkeypatch.setattr(httpx, "post", fake_post)
results = rerank_chunks("query", candidates()[:2], RerankConfig()) results = rerank_chunks("query", candidates()[:2], RerankConfig())
@@ -3,8 +3,6 @@
from __future__ import annotations from __future__ import annotations
from compression import zstd from compression import zstd
from typing import TYPE_CHECKING
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from sqlalchemy import create_engine from sqlalchemy import create_engine
@@ -15,14 +13,11 @@ from python.ebook_search.embeddings import EmbeddingModelStats
from python.ebook_search.search import SearchResponse, SearchResult from python.ebook_search.search import SearchResponse, SearchResult
from python.ebook_search.timing import RuntimeStep from python.ebook_search.timing import RuntimeStep
if TYPE_CHECKING:
from pytest_mock import MockerFixture
def patch_app_runtime(monkeypatch):
def patch_app_runtime(mocker: MockerFixture):
"""Patch app startup dependencies used by UI route tests.""" """Patch app startup dependencies used by UI route tests."""
mocker.patch("python.ebook_search.api.main.get_postgres_engine", side_effect=fake_get_postgres_engine) monkeypatch.setattr("python.ebook_search.api.main.get_postgres_engine", fake_get_postgres_engine)
mocker.patch("python.ebook_search.api.main.ensure_bm25_corpus", side_effect=lambda _session, _config: None) monkeypatch.setattr("python.ebook_search.api.main.ensure_bm25_corpus", lambda _session, _config: None)
def fake_get_postgres_engine(**_kwargs): def fake_get_postgres_engine(**_kwargs):
@@ -30,8 +25,8 @@ def fake_get_postgres_engine(**_kwargs):
return create_engine("sqlite+pysqlite:///:memory:", future=True) return create_engine("sqlite+pysqlite:///:memory:", future=True)
def test_search_page_uses_zstd_when_requested(mocker: MockerFixture) -> None: def test_search_page_uses_zstd_when_requested(monkeypatch) -> None:
patch_app_runtime(mocker) patch_app_runtime(monkeypatch)
app = create_app() app = create_app()
app.state.config = EbookSearchConfig(rerank=RerankConfig(enabled=False)) app.state.config = EbookSearchConfig(rerank=RerankConfig(enabled=False))
@@ -43,7 +38,7 @@ def test_search_page_uses_zstd_when_requested(mocker: MockerFixture) -> None:
assert b"EPUB Search" in zstd.decompress(response.content) assert b"EPUB Search" in zstd.decompress(response.content)
def test_ui_form_passes_rerank_flag_to_search_handler(mocker: MockerFixture) -> None: def test_ui_form_passes_rerank_flag_to_search_handler(monkeypatch) -> None:
captured: dict[str, object] = {} captured: dict[str, object] = {}
def fake_search_ebooks(_engine, query, config, *, rerank=False): def fake_search_ebooks(_engine, query, config, *, rerank=False):
@@ -52,12 +47,12 @@ def test_ui_form_passes_rerank_flag_to_search_handler(mocker: MockerFixture) ->
captured["config"] = config captured["config"] = config
return SearchResponse(query=query, results=[], rank_label="Hybrid + rerank") return SearchResponse(query=query, results=[], rank_label="Hybrid + rerank")
mocker.patch("python.ebook_search.api.routes.search.search_ebooks", side_effect=fake_search_ebooks) monkeypatch.setattr("python.ebook_search.api.routes.search.search_ebooks", fake_search_ebooks)
mocker.patch( monkeypatch.setattr(
"python.ebook_search.api.routes.search.answer_query", "python.ebook_search.api.routes.search.answer_query",
side_effect=lambda _query, _results, _config: "answer", lambda _query, _results, _config: "answer",
) )
patch_app_runtime(mocker) patch_app_runtime(monkeypatch)
app = create_app() app = create_app()
app.state.config = EbookSearchConfig(rerank=RerankConfig(enabled=False), top_k=12, answer_enabled=True) app.state.config = EbookSearchConfig(rerank=RerankConfig(enabled=False), top_k=12, answer_enabled=True)
@@ -70,14 +65,14 @@ def test_ui_form_passes_rerank_flag_to_search_handler(mocker: MockerFixture) ->
assert captured["rerank"] is True assert captured["rerank"] is True
def test_ui_search_failure_returns_visible_error(mocker: MockerFixture) -> None: def test_ui_search_failure_returns_visible_error(monkeypatch) -> None:
def fake_search_ebooks(_engine, _query, _config, *, rerank=False): def fake_search_ebooks(_engine, _query, _config, *, rerank=False):
del rerank del rerank
msg = "search exploded" msg = "search exploded"
raise RuntimeError(msg) raise RuntimeError(msg)
mocker.patch("python.ebook_search.api.routes.search.search_ebooks", side_effect=fake_search_ebooks) monkeypatch.setattr("python.ebook_search.api.routes.search.search_ebooks", fake_search_ebooks)
patch_app_runtime(mocker) patch_app_runtime(monkeypatch)
app = create_app() app = create_app()
app.state.config = EbookSearchConfig(rerank=RerankConfig(enabled=False), top_k=12) app.state.config = EbookSearchConfig(rerank=RerankConfig(enabled=False), top_k=12)
@@ -88,7 +83,7 @@ def test_ui_search_failure_returns_visible_error(mocker: MockerFixture) -> None:
assert "search exploded" in response.text assert "search exploded" in response.text
def test_ui_answer_failure_still_returns_sources(mocker: MockerFixture) -> None: def test_ui_answer_failure_still_returns_sources(monkeypatch) -> None:
def fake_search_ebooks(_engine, query, _config, *, rerank=False): def fake_search_ebooks(_engine, query, _config, *, rerank=False):
del rerank del rerank
return SearchResponse(query=query, results=[], rank_label="Hybrid") return SearchResponse(query=query, results=[], rank_label="Hybrid")
@@ -97,9 +92,9 @@ def test_ui_answer_failure_still_returns_sources(mocker: MockerFixture) -> None:
msg = "answer exploded" msg = "answer exploded"
raise RuntimeError(msg) raise RuntimeError(msg)
mocker.patch("python.ebook_search.api.routes.search.search_ebooks", side_effect=fake_search_ebooks) monkeypatch.setattr("python.ebook_search.api.routes.search.search_ebooks", fake_search_ebooks)
mocker.patch("python.ebook_search.api.routes.search.answer_query", side_effect=fake_answer_query) monkeypatch.setattr("python.ebook_search.api.routes.search.answer_query", fake_answer_query)
patch_app_runtime(mocker) patch_app_runtime(monkeypatch)
app = create_app() app = create_app()
app.state.config = EbookSearchConfig(rerank=RerankConfig(enabled=False), top_k=12, answer_enabled=True) app.state.config = EbookSearchConfig(rerank=RerankConfig(enabled=False), top_k=12, answer_enabled=True)
@@ -110,7 +105,7 @@ def test_ui_answer_failure_still_returns_sources(mocker: MockerFixture) -> None:
assert "Answer generation failed" in response.text assert "Answer generation failed" in response.text
def test_ui_skips_answer_when_disabled(mocker: MockerFixture) -> None: def test_ui_skips_answer_when_disabled(monkeypatch) -> None:
called = False called = False
def fake_search_ebooks(_engine, query, _config, *, rerank=False): def fake_search_ebooks(_engine, query, _config, *, rerank=False):
@@ -122,12 +117,11 @@ def test_ui_skips_answer_when_disabled(mocker: MockerFixture) -> None:
called = True called = True
return "answer" return "answer"
config = EbookSearchConfig(rerank=RerankConfig(enabled=False), answer_enabled=False) monkeypatch.setattr("python.ebook_search.api.routes.search.search_ebooks", fake_search_ebooks)
mocker.patch("python.ebook_search.api.routes.search.search_ebooks", side_effect=fake_search_ebooks) monkeypatch.setattr("python.ebook_search.api.routes.search.answer_query", fake_answer_query)
mocker.patch("python.ebook_search.api.routes.search.answer_query", side_effect=fake_answer_query) patch_app_runtime(monkeypatch)
mocker.patch("python.ebook_search.api.main.load_config", side_effect=lambda: config)
patch_app_runtime(mocker)
app = create_app() app = create_app()
app.state.config = EbookSearchConfig(rerank=RerankConfig(enabled=False), answer_enabled=False)
with TestClient(app) as client: with TestClient(app) as client:
response = client.post("/search", data={"query": "where is the quote?"}) response = client.post("/search", data={"query": "where is the quote?"})
@@ -137,7 +131,7 @@ def test_ui_skips_answer_when_disabled(mocker: MockerFixture) -> None:
assert "Answer generation is disabled" in response.text assert "Answer generation is disabled" in response.text
def test_ui_shows_component_scores(mocker: MockerFixture) -> None: def test_ui_shows_component_scores(monkeypatch) -> None:
def fake_search_ebooks(_engine, query, _config, *, rerank=False): def fake_search_ebooks(_engine, query, _config, *, rerank=False):
del rerank del rerank
return SearchResponse( return SearchResponse(
@@ -157,12 +151,12 @@ def test_ui_shows_component_scores(mocker: MockerFixture) -> None:
], ],
) )
mocker.patch("python.ebook_search.api.routes.search.search_ebooks", side_effect=fake_search_ebooks) monkeypatch.setattr("python.ebook_search.api.routes.search.search_ebooks", fake_search_ebooks)
mocker.patch( monkeypatch.setattr(
"python.ebook_search.api.routes.search.answer_query", "python.ebook_search.api.routes.search.answer_query",
side_effect=lambda _query, _results, _config: "answer", lambda _query, _results, _config: "answer",
) )
patch_app_runtime(mocker) patch_app_runtime(monkeypatch)
app = create_app() app = create_app()
app.state.config = EbookSearchConfig(rerank=RerankConfig(enabled=False), answer_enabled=True) app.state.config = EbookSearchConfig(rerank=RerankConfig(enabled=False), answer_enabled=True)
@@ -176,7 +170,7 @@ def test_ui_shows_component_scores(mocker: MockerFixture) -> None:
assert "RRF" in response.text assert "RRF" in response.text
def test_ui_shows_search_runtime_chart(mocker: MockerFixture) -> None: def test_ui_shows_search_runtime_chart(monkeypatch) -> None:
def fake_search_ebooks(_engine, query, _config, *, rerank=False): def fake_search_ebooks(_engine, query, _config, *, rerank=False):
del rerank del rerank
return SearchResponse( return SearchResponse(
@@ -189,12 +183,12 @@ def test_ui_shows_search_runtime_chart(mocker: MockerFixture) -> None:
), ),
) )
mocker.patch("python.ebook_search.api.routes.search.search_ebooks", side_effect=fake_search_ebooks) monkeypatch.setattr("python.ebook_search.api.routes.search.search_ebooks", fake_search_ebooks)
mocker.patch( monkeypatch.setattr(
"python.ebook_search.api.routes.search.answer_query", "python.ebook_search.api.routes.search.answer_query",
side_effect=lambda _query, _results, _config: "answer", lambda _query, _results, _config: "answer",
) )
patch_app_runtime(mocker) patch_app_runtime(monkeypatch)
app = create_app() app = create_app()
app.state.config = EbookSearchConfig(rerank=RerankConfig(enabled=False), answer_enabled=True) app.state.config = EbookSearchConfig(rerank=RerankConfig(enabled=False), answer_enabled=True)
@@ -210,7 +204,7 @@ def test_ui_shows_search_runtime_chart(mocker: MockerFixture) -> None:
assert "ms left" in response.text assert "ms left" in response.text
def test_ui_embed_all_batches_until_complete(mocker: MockerFixture) -> None: def test_ui_embed_all_batches_until_complete(monkeypatch) -> None:
counts = iter([32, 32, 5, 0]) counts = iter([32, 32, 5, 0])
batch_sizes: list[int] = [] batch_sizes: list[int] = []
@@ -218,8 +212,8 @@ def test_ui_embed_all_batches_until_complete(mocker: MockerFixture) -> None:
batch_sizes.append(config.embedding_batch_size) batch_sizes.append(config.embedding_batch_size)
return next(counts) return next(counts)
mocker.patch("python.ebook_search.api.routes.admin.embed_missing_chunks", side_effect=fake_embed_missing_chunks) monkeypatch.setattr("python.ebook_search.api.routes.admin.embed_missing_chunks", fake_embed_missing_chunks)
patch_app_runtime(mocker) patch_app_runtime(monkeypatch)
app = create_app() app = create_app()
with TestClient(app) as client: with TestClient(app) as client:
@@ -230,7 +224,7 @@ def test_ui_embed_all_batches_until_complete(mocker: MockerFixture) -> None:
assert batch_sizes == [32, 32, 32, 32] assert batch_sizes == [32, 32, 32, 32]
def test_ui_scan_schedules_bm25_refresh_after_database_change(mocker: MockerFixture) -> None: def test_ui_scan_schedules_bm25_refresh_after_database_change(monkeypatch) -> None:
scheduled = False scheduled = False
def fake_ingest_configured_paths(_session, _config): def fake_ingest_configured_paths(_session, _config):
@@ -240,12 +234,9 @@ def test_ui_scan_schedules_bm25_refresh_after_database_change(mocker: MockerFixt
nonlocal scheduled nonlocal scheduled
scheduled = True scheduled = True
mocker.patch( monkeypatch.setattr("python.ebook_search.api.routes.admin.ingest_configured_paths", fake_ingest_configured_paths)
"python.ebook_search.api.routes.admin.ingest_configured_paths", monkeypatch.setattr("python.ebook_search.api.routes.admin.schedule_bm25_refresh", fake_schedule_bm25_refresh)
side_effect=fake_ingest_configured_paths, patch_app_runtime(monkeypatch)
)
mocker.patch("python.ebook_search.api.routes.admin.schedule_bm25_refresh", side_effect=fake_schedule_bm25_refresh)
patch_app_runtime(mocker)
app = create_app() app = create_app()
with TestClient(app) as client: with TestClient(app) as client:
@@ -256,7 +247,7 @@ def test_ui_scan_schedules_bm25_refresh_after_database_change(mocker: MockerFixt
assert scheduled is True assert scheduled is True
def test_bm25_refresh_clears_loaded_corpus_cache(mocker: MockerFixture) -> None: def test_bm25_refresh_clears_loaded_corpus_cache(monkeypatch) -> None:
refreshed: list[object] = [] refreshed: list[object] = []
cache_cleared = False cache_cleared = False
@@ -267,8 +258,8 @@ def test_bm25_refresh_clears_loaded_corpus_cache(mocker: MockerFixture) -> None:
nonlocal cache_cleared nonlocal cache_cleared
cache_cleared = True cache_cleared = True
mocker.patch("python.ebook_search.api.bm25_tasks.refresh_bm25_corpus", side_effect=fake_refresh_bm25_corpus) monkeypatch.setattr("python.ebook_search.api.bm25_tasks.refresh_bm25_corpus", fake_refresh_bm25_corpus)
mocker.patch("python.ebook_search.api.bm25_tasks.load_bm25_corpus.cache_clear", side_effect=fake_cache_clear) monkeypatch.setattr("python.ebook_search.api.bm25_tasks.load_bm25_corpus.cache_clear", fake_cache_clear)
engine = create_engine("sqlite+pysqlite:///:memory:", future=True) engine = create_engine("sqlite+pysqlite:///:memory:", future=True)
config = EbookSearchConfig(rerank=RerankConfig(enabled=False)) config = EbookSearchConfig(rerank=RerankConfig(enabled=False))
@@ -279,7 +270,7 @@ def test_bm25_refresh_clears_loaded_corpus_cache(mocker: MockerFixture) -> None:
assert cache_cleared is True assert cache_cleared is True
def test_admin_page_shows_embedding_counts_by_model(mocker: MockerFixture) -> None: def test_admin_page_shows_embedding_counts_by_model(monkeypatch) -> None:
def fake_embedding_model_stats(_session): def fake_embedding_model_stats(_session):
return [ return [
EmbeddingModelStats( EmbeddingModelStats(
@@ -296,8 +287,8 @@ def test_admin_page_shows_embedding_counts_by_model(mocker: MockerFixture) -> No
), ),
] ]
mocker.patch("python.ebook_search.api.routes.admin.embedding_model_stats", side_effect=fake_embedding_model_stats) monkeypatch.setattr("python.ebook_search.api.routes.admin.embedding_model_stats", fake_embedding_model_stats)
patch_app_runtime(mocker) patch_app_runtime(monkeypatch)
app = create_app() app = create_app()
with TestClient(app) as client: with TestClient(app) as client:
+4 -4
View File
@@ -19,7 +19,7 @@ if TYPE_CHECKING:
class MockFuture(Future): class MockFuture(Future):
"""MockFuture.""" """MockFuture."""
def __init__(self, result: Any) -> None: def __init__(self, result: Any) -> None: # noqa: ANN401
"""Init.""" """Init."""
super().__init__() super().__init__()
self._result = result self._result = result
@@ -31,7 +31,7 @@ class MockFuture(Future):
logging.debug(f"{timeout}=") logging.debug(f"{timeout}=")
return self._exception return self._exception
def result(self, timeout: float | None = None) -> Any: def result(self, timeout: float | None = None) -> Any: # noqa: ANN401
"""Result.""" """Result."""
logging.debug(f"{timeout}=") logging.debug(f"{timeout}=")
return self._result return self._result
@@ -40,11 +40,11 @@ class MockFuture(Future):
class MockPoolExecutor(ThreadPoolExecutor): class MockPoolExecutor(ThreadPoolExecutor):
"""MockPoolExecutor.""" """MockPoolExecutor."""
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: ANN401
"""Initializes a new ThreadPoolExecutor instance.""" """Initializes a new ThreadPoolExecutor instance."""
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def submit(self, fn: Callable[..., Any], /, *args: Any, **kwargs: Any) -> Future: def submit(self, fn: Callable[..., Any], /, *args: Any, **kwargs: Any) -> Future: # noqa: ANN401
"""Submits a callable to be executed with the given arguments. """Submits a callable to be executed with the given arguments.
Args: Args:
+3 -6
View File
@@ -21,7 +21,7 @@ def test_validate_system(mocker: MockerFixture, fs: FakeFilesystem) -> None:
"""test_validate_system.""" """test_validate_system."""
fs.create_file( fs.create_file(
"/mock_snapshot_config.toml", "/mock_snapshot_config.toml",
contents='zpools = ["root_pool", "storage", "media"]\nservices = ["docker"]\n', contents='zpool = ["root_pool", "storage", "media"]\nservices = ["docker"]\n',
) )
mocker.patch(f"{VALIDATE_SYSTEM}.systemd_tests", return_value=None) mocker.patch(f"{VALIDATE_SYSTEM}.systemd_tests", return_value=None)
@@ -33,10 +33,9 @@ def test_validate_system_errors(mocker: MockerFixture, fs: FakeFilesystem) -> No
"""test_validate_system_errors.""" """test_validate_system_errors."""
fs.create_file( fs.create_file(
"/mock_snapshot_config.toml", "/mock_snapshot_config.toml",
contents='zpools = ["root_pool", "storage", "media"]\nservices = ["docker"]\n', contents='zpool = ["root_pool", "storage", "media"]\nservices = ["docker"]\n',
) )
mocker.patch(f"{VALIDATE_SYSTEM}.signal_alert")
mocker.patch(f"{VALIDATE_SYSTEM}.systemd_tests", return_value=["systemd_tests error"]) mocker.patch(f"{VALIDATE_SYSTEM}.systemd_tests", return_value=["systemd_tests error"])
mocker.patch(f"{VALIDATE_SYSTEM}.zpool_tests", return_value=["zpool_tests error"]) mocker.patch(f"{VALIDATE_SYSTEM}.zpool_tests", return_value=["zpool_tests error"])
@@ -50,11 +49,9 @@ def test_validate_system_execution(mocker: MockerFixture, fs: FakeFilesystem) ->
"""test_validate_system_execution.""" """test_validate_system_execution."""
fs.create_file( fs.create_file(
"/mock_snapshot_config.toml", "/mock_snapshot_config.toml",
contents='zpools = ["root_pool", "storage", "media"]\nservices = ["docker"]\n', contents='zpool = ["root_pool", "storage", "media"]\nservices = ["docker"]\n',
) )
mocker.patch(f"{VALIDATE_SYSTEM}.signal_alert")
mocker.patch(f"{VALIDATE_SYSTEM}.systemd_tests", return_value=None)
mocker.patch(f"{VALIDATE_SYSTEM}.zpool_tests", side_effect=RuntimeError("zpool_tests error")) mocker.patch(f"{VALIDATE_SYSTEM}.zpool_tests", side_effect=RuntimeError("zpool_tests error"))
with pytest.raises(SystemExit) as exception_info: with pytest.raises(SystemExit) as exception_info:
@@ -80,10 +80,8 @@
"fastapi", "fastapi",
"Michal", "Michal",
"Nornsight", "Nornsight",
"pydantic",
"sandboxing", "sandboxing",
"syncthing", "syncthing",
"vllm",
], ],
// nix // nix