Compare commits

...

8 Commits

Author SHA1 Message Date
Richie cd38a0f277 added vector_engine to fix name postgres name space issue
treefmt / nix fmt (pull_request) Successful in 6s
pytest / pytest (pull_request) Successful in 31s
build_systems / build-bob (pull_request) Successful in 47s
build_systems / build-brain (pull_request) Successful in 48s
build_systems / build-leviathan (pull_request) Successful in 56s
build_systems / build-rhapsody-in-green (pull_request) Successful in 1m0s
build_systems / build-jeeves (pull_request) Successful in 2m36s
2026-06-12 14:50:22 -04:00
Richie 61d86446ed reworked ebook_search routers 2026-06-12 14:46:00 -04:00
Richie e14c20010f made fastapi tools 2026-06-12 14:45:10 -04:00
Richie bd87dd2015 added proper cache invalidation to load_bm25_corpus 2026-06-12 13:47:43 -04:00
Richie 94493647a6 updated tests 2026-06-12 13:36:45 -04:00
Richie 70d65bbbe0 improved reranking weights 2026-06-12 13:36:34 -04:00
Richie 4dee9f76a7 fixed duplicat enrichment 2026-06-12 13:35:20 -04:00
Richie 0e874a3489 improved queary for vector search 2026-06-12 13:34:59 -04:00
19 changed files with 145 additions and 81 deletions
+1 -1
View File
@@ -9,9 +9,9 @@ import typer
import uvicorn import uvicorn
from fastapi import FastAPI from fastapi import FastAPI
from python.api.middleware import ZstdMiddleware
from python.api.routers import contact_router, views_router from python.api.routers import contact_router, views_router
from python.common import configure_logger from python.common import configure_logger
from python.fastapi_tools import ZstdMiddleware
from python.orm.common import get_postgres_engine from python.orm.common import get_postgres_engine
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
+1 -1
View File
@@ -9,7 +9,7 @@ from pydantic import BaseModel
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
from python.api.dependencies import DbSession from python.fastapi_tools.db import DbSession
from python.orm.richie.contact import Contact, ContactRelationship, Need, RelationshipType from python.orm.richie.contact import Contact, ContactRelationship, Need, RelationshipType
TEMPLATES_DIR = Path(__file__).parent.parent / "templates" TEMPLATES_DIR = Path(__file__).parent.parent / "templates"
+1 -1
View File
@@ -9,7 +9,7 @@ from fastapi.templating import Jinja2Templates
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session, selectinload from sqlalchemy.orm import Session, selectinload
from python.api.dependencies import DbSession from python.fastapi_tools.db import DbSession
from python.orm.richie.contact import Contact, ContactRelationship, Need, RelationshipType from python.orm.richie.contact import Contact, ContactRelationship, Need, RelationshipType
TEMPLATES_DIR = Path(__file__).parent.parent / "templates" TEMPLATES_DIR = Path(__file__).parent.parent / "templates"
+3 -1
View File
@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from python.ebook_search.bm25_corpus import refresh_bm25_corpus from python.ebook_search.bm25_corpus import load_bm25_corpus, refresh_bm25_corpus
if TYPE_CHECKING: if TYPE_CHECKING:
from fastapi import FastAPI from fastapi import FastAPI
@@ -56,3 +56,5 @@ def refresh_bm25_for_engine(engine: Engine, config: EbookSearchConfig) -> None:
"""Refresh the BM25 corpus using a SQLAlchemy engine.""" """Refresh the BM25 corpus using a SQLAlchemy engine."""
with Session(engine) as session: with Session(engine) as session:
refresh_bm25_corpus(session, config) refresh_bm25_corpus(session, config)
load_bm25_corpus.cache_clear()
logger.info("ebook_bm25_corpus_cache_cleared_after_refresh")
+7 -5
View File
@@ -14,7 +14,7 @@ from sqlalchemy.orm import Session
from python.common import configure_logger from python.common import configure_logger
from python.ebook_search.api.bm25_tasks import cancel_bm25_refresh from python.ebook_search.api.bm25_tasks import cancel_bm25_refresh
from python.ebook_search.api.routes import register_admin_routes, register_page_routes, register_search_routes from python.ebook_search.api.routes import admin_router, page_router, search_router
from python.ebook_search.api.web import STATIC_DIR from python.ebook_search.api.web import STATIC_DIR
from python.ebook_search.bm25_corpus import ensure_bm25_corpus from python.ebook_search.bm25_corpus import ensure_bm25_corpus
from python.ebook_search.config import load_config from python.ebook_search.config import load_config
@@ -31,7 +31,7 @@ logger = logging.getLogger(__name__)
async def lifespan(app: FastAPI) -> AsyncIterator[None]: async def lifespan(app: FastAPI) -> AsyncIterator[None]:
"""Manage application startup and shutdown resources.""" """Manage application startup and shutdown resources."""
logger.info("ebook_search_startup") logger.info("ebook_search_startup")
app.state.engine = get_postgres_engine(name="RICHIE") app.state.engine = get_postgres_engine(name="RICHIE", vector_engine=True)
with Session(app.state.engine) as session: with Session(app.state.engine) as session:
ensure_bm25_corpus(session, app.state.config) ensure_bm25_corpus(session, app.state.config)
try: try:
@@ -55,9 +55,11 @@ def create_app() -> FastAPI:
app.state.config.answer_enabled, app.state.config.answer_enabled,
len(app.state.config.library_paths), len(app.state.config.library_paths),
) )
register_page_routes(app)
register_search_routes(app) app.include_router(admin_router)
register_admin_routes(app) app.include_router(page_router)
app.include_router(search_router)
return app return app
+6 -11
View File
@@ -1,16 +1,11 @@
"""EPUB search web route modules.""" """EPUB search web route modules."""
from python.ebook_search.api.routes import admin, page, search from python.ebook_search.api.routes.admin import router as admin_router
from python.ebook_search.api.routes.page import router as page_router
register_admin_routes = admin.register_admin_routes from python.ebook_search.api.routes.search import router as search_router
register_page_routes = page.register_page_routes
register_search_routes = search.register_search_routes
__all__ = [ __all__ = [
"admin", "admin_router",
"page", "page_router",
"register_admin_routes", "search_router",
"register_page_routes",
"register_search_routes",
"search",
] ]
-9
View File
@@ -4,7 +4,6 @@ from __future__ import annotations
import logging import logging
from dataclasses import replace from dataclasses import replace
from typing import TYPE_CHECKING
from fastapi import APIRouter, Request from fastapi import APIRouter, Request
from fastapi.responses import HTMLResponse from fastapi.responses import HTMLResponse
@@ -15,20 +14,12 @@ from python.ebook_search.api.web import templates
from python.ebook_search.embeddings import embed_missing_chunks, embedding_model_stats from python.ebook_search.embeddings import embed_missing_chunks, embedding_model_stats
from python.ebook_search.ingest import ingest_configured_paths from python.ebook_search.ingest import ingest_configured_paths
if TYPE_CHECKING:
from fastapi import FastAPI
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter(prefix="/admin") router = APIRouter(prefix="/admin")
EMBED_ALL_BATCH_SIZE = 32 EMBED_ALL_BATCH_SIZE = 32
def register_admin_routes(app: FastAPI) -> None:
"""Register admin routes on the app."""
app.include_router(router)
@router.get("", response_class=HTMLResponse) @router.get("", response_class=HTMLResponse)
def admin(request: Request) -> HTMLResponse: def admin(request: Request) -> HTMLResponse:
"""Render the admin page.""" """Render the admin page."""
-9
View File
@@ -3,7 +3,6 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import TYPE_CHECKING
from fastapi import APIRouter, Request from fastapi import APIRouter, Request
from fastapi.responses import HTMLResponse from fastapi.responses import HTMLResponse
@@ -13,19 +12,11 @@ from sqlalchemy.orm import Session
from python.ebook_search.api.web import templates from python.ebook_search.api.web import templates
from python.orm.richie import EbookSource from python.orm.richie import EbookSource
if TYPE_CHECKING:
from fastapi import FastAPI
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
def register_page_routes(app: FastAPI) -> None:
"""Register page routes on the app."""
app.include_router(router)
@router.get("/", response_class=HTMLResponse) @router.get("/", response_class=HTMLResponse)
def index(request: Request) -> HTMLResponse: def index(request: Request) -> HTMLResponse:
"""Render the search page.""" """Render the search page."""
+1 -9
View File
@@ -5,7 +5,7 @@ from __future__ import annotations
import logging import logging
from dataclasses import replace from dataclasses import replace
from time import perf_counter from time import perf_counter
from typing import TYPE_CHECKING, Annotated from typing import Annotated
from fastapi import APIRouter, Form, Request from fastapi import APIRouter, Form, Request
from fastapi.responses import HTMLResponse from fastapi.responses import HTMLResponse
@@ -15,19 +15,11 @@ from python.ebook_search.api.web import templates
from python.ebook_search.search import search_ebooks from python.ebook_search.search import search_ebooks
from python.ebook_search.timing import runtime_step_from_start from python.ebook_search.timing import runtime_step_from_start
if TYPE_CHECKING:
from fastapi import FastAPI
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
def register_search_routes(app: FastAPI) -> None:
"""Register search routes on the app."""
app.include_router(router)
@router.post("/search", response_class=HTMLResponse) @router.post("/search", response_class=HTMLResponse)
def search( def search(
request: Request, request: Request,
+4 -16
View File
@@ -108,11 +108,10 @@ def refresh_bm25_corpus(
) )
write_bm25_corpus(index_path, records, manifest) write_bm25_corpus(index_path, records, manifest)
logger.info( logger.info(
"ebook_bm25_index_refreshed path=%s chunks=%s created_at=%s note=%s", "ebook_bm25_index_refreshed path=%s chunks=%s created_at=%s",
index_path, index_path,
manifest.chunk_count, manifest.chunk_count,
manifest.created_at.isoformat(), manifest.created_at.isoformat(),
"restart_service_to_use_refreshed_bm25_cache",
) )
return manifest return manifest
@@ -121,15 +120,10 @@ def refresh_bm25_corpus(
def load_bm25_corpus(config: EbookSearchConfig) -> BM25Corpus: def load_bm25_corpus(config: EbookSearchConfig) -> BM25Corpus:
"""Load the BM25 corpus into memory once per process. """Load the BM25 corpus into memory once per process.
This cache intentionally does not notice later on-disk corpus refreshes. Restart the service after rebuilding the Background refresh tasks clear this cache after rebuilding the on-disk corpus.
BM25 corpus for searches to use the new index.
""" """
index_path = bm25_index_path(config) index_path = bm25_index_path(config)
logger.info( logger.info("ebook_bm25_corpus_cache_load path=%s", index_path)
"ebook_bm25_corpus_cache_load path=%s note=%s",
index_path,
"restart_service_after_bm25_refresh",
)
manifest = read_bm25_manifest(index_path) manifest = read_bm25_manifest(index_path)
if manifest is None or not bm25_index_exists(index_path, manifest): if manifest is None or not bm25_index_exists(index_path, manifest):
msg = f"BM25 corpus is not available: {index_path}" msg = f"BM25 corpus is not available: {index_path}"
@@ -172,13 +166,7 @@ def fetch_bm25_corpus_records(session: Session) -> list[dict[str, object]]:
EbookSource.author.label("source_author"), EbookSource.author.label("source_author"),
EbookChapter.title.label("chapter_title"), EbookChapter.title.label("chapter_title"),
EbookChunk.page_label.label("page_label"), EbookChunk.page_label.label("page_label"),
func.concat_ws( EbookChunk.search_text.label("bm25_text"),
" ",
EbookSource.title,
EbookSource.author,
EbookChapter.title,
EbookChunk.search_text,
).label("bm25_text"),
) )
.select_from(EbookChunk) .select_from(EbookChunk)
.join(EbookSource, EbookSource.id == EbookChunk.source_id) .join(EbookSource, EbookSource.id == EbookChunk.source_id)
+3 -1
View File
@@ -13,6 +13,8 @@ if TYPE_CHECKING:
from python.ebook_search.search import SearchResult from python.ebook_search.search import SearchResult
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
RERANK_SCORE_WEIGHT = 0.7
HYBRID_SCORE_WEIGHT = 0.3
@dataclass(frozen=True) @dataclass(frozen=True)
@@ -110,7 +112,7 @@ def clamp_score(score: float) -> float:
def final_rerank_score(result: SearchResult, rerank_score: float, candidates: list[SearchResult]) -> float: def final_rerank_score(result: SearchResult, rerank_score: float, candidates: list[SearchResult]) -> float:
"""Combine rerank relevance with normalized hybrid retrieval evidence.""" """Combine rerank relevance with normalized hybrid retrieval evidence."""
return rerank_score * normalized_hybrid_score(result, candidates) return (RERANK_SCORE_WEIGHT * rerank_score) + (HYBRID_SCORE_WEIGHT * normalized_hybrid_score(result, candidates))
def normalized_hybrid_score(result: SearchResult, candidates: list[SearchResult]) -> float: def normalized_hybrid_score(result: SearchResult, candidates: list[SearchResult]) -> float:
+12 -6
View File
@@ -93,13 +93,14 @@ def search_ebooks(
logger.info("ebook_search_start query_length=%s rerank=%s", len(query), rerank) logger.info("ebook_search_start query_length=%s rerank=%s", len(query), rerank)
timings: list[RuntimeStep] = [] timings: list[RuntimeStep] = []
retrieval_query, timing = timed_result("Query preparation", retrieval_query_from_text, query) bm25_query, timing = timed_result("BM25 query preparation", retrieval_query_from_text, query)
timings.append(timing) timings.append(timing)
retrieval, timing = timed_result( retrieval, timing = timed_result(
"Hybrid retrieval", "Hybrid retrieval",
parallel_retrieval, parallel_retrieval,
engine, engine,
retrieval_query, query,
bm25_query,
config, config,
) )
timings.extend(retrieval.timings) timings.extend(retrieval.timings)
@@ -130,7 +131,12 @@ def search_ebooks(
return response return response
def parallel_retrieval(engine: Engine, query: str, config: EbookSearchConfig) -> RetrievalResponse: def parallel_retrieval(
engine: Engine,
vector_query: str,
bm25_query: str,
config: EbookSearchConfig,
) -> RetrievalResponse:
"""Run vector and BM25 candidate retrieval concurrently with separate database sessions.""" """Run vector and BM25 candidate retrieval concurrently with separate database sessions."""
with ThreadPoolExecutor(max_workers=2, thread_name_prefix="ebook-search") as executor: with ThreadPoolExecutor(max_workers=2, thread_name_prefix="ebook-search") as executor:
vector_future = executor.submit( vector_future = executor.submit(
@@ -138,14 +144,14 @@ def parallel_retrieval(engine: Engine, query: str, config: EbookSearchConfig) ->
"Embedding + vector search", "Embedding + vector search",
vector_candidates, vector_candidates,
engine, engine,
query, vector_query,
config, config,
) )
bm25_future = executor.submit( bm25_future = executor.submit(
timed_result, timed_result,
"BM25 search", "BM25 search",
bm25_candidates, bm25_candidates,
query, bm25_query,
config, config,
) )
vector_results, vector_timing = vector_future.result() vector_results, vector_timing = vector_future.result()
@@ -196,7 +202,7 @@ def apply_rerank(
def vector_candidates(engine: Engine, query: str, config: EbookSearchConfig) -> list[SearchResult]: def vector_candidates(engine: Engine, query: str, config: EbookSearchConfig) -> list[SearchResult]:
"""Return pgvector cosine candidates for a normalized query.""" """Return pgvector cosine candidates for a natural-language query."""
with Session(engine) as session: with Session(engine) as session:
model = session.scalar(select(EbookEmbeddingModel).where(EbookEmbeddingModel.name == config.embedding_model)) model = session.scalar(select(EbookEmbeddingModel).where(EbookEmbeddingModel.name == config.embedding_model))
if model is None: if model is None:
+6
View File
@@ -0,0 +1,6 @@
"""Reusable FastAPI tools."""
from python.fastapi_tools.db import DbSession, get_db
from python.fastapi_tools.zstd_middleware import ZstdMiddleware
__all__ = ["DbSession", "ZstdMiddleware", "get_db"]
@@ -1,4 +1,4 @@
"""Middleware for the FastAPI application.""" """Zstd response compression middleware."""
from compression import zstd from compression import zstd
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
+24 -2
View File
@@ -31,8 +31,24 @@ def get_connection_info(name: str) -> tuple[str, str, str, str, str | None]:
return cast("tuple[str, str, str, str, str | None]", (database, host, port, username, password)) return cast("tuple[str, str, str, str, str | None]", (database, host, port, username, password))
def get_postgres_engine(*, name: str = "POSTGRES", pool_pre_ping: bool = True) -> Engine: def get_postgres_engine(
"""Create a SQLAlchemy engine from environment variables.""" *,
name: str = "POSTGRES",
pool_pre_ping: bool = True,
vector_engine: bool = False,
) -> Engine:
"""Create a SQLAlchemy engine from environment variables.
Args:
name (str, optional): The name of the environment variable prefix. Defaults to "POSTGRES".
pool_pre_ping (bool, optional): Whether to ping the database before each connection. Defaults to True.
This fixes the issue of trying to use a conection that has timed out on the database side.
vector_engine (bool, optional): Whether to use the vector search schema. Defaults to False.
This updates the search path the incldued the vecore types and operators.
Returns:
Engine: The SQLAlchemy engine.
"""
database, host, port, username, password = get_connection_info(name) database, host, port, username, password = get_connection_info(name)
url = URL.create( url = URL.create(
@@ -44,8 +60,14 @@ def get_postgres_engine(*, name: str = "POSTGRES", pool_pre_ping: bool = True) -
database=database, database=database,
) )
connect_args = {}
# There more better way to do this is with separate PG account and a dedicated vector schema for the vector types
if vector_engine:
connect_args["options"] = "-csearch_path=main,public"
return create_engine( return create_engine(
url=url, url=url,
pool_pre_ping=pool_pre_ping, pool_pre_ping=pool_pre_ping,
pool_recycle=1800, pool_recycle=1800,
connect_args=connect_args,
) )
+46 -3
View File
@@ -19,6 +19,7 @@ from python.ebook_search.bm25_corpus import (
BM25CorpusUnavailableError, BM25CorpusUnavailableError,
BM25Manifest, BM25Manifest,
ensure_bm25_corpus, ensure_bm25_corpus,
fetch_bm25_corpus_records,
load_bm25_corpus, load_bm25_corpus,
) )
from python.ebook_search.config import EbookSearchConfig, RerankConfig, load_config, normalize_embedding_model from python.ebook_search.config import EbookSearchConfig, RerankConfig, load_config, normalize_embedding_model
@@ -33,7 +34,7 @@ from python.ebook_search.search import (
search_ebooks, search_ebooks,
) )
from python.ebook_search.timing import RuntimeStep from python.ebook_search.timing import RuntimeStep
from python.orm.richie import EbookEmbeddingModel, EbookSource, RichieBase from python.orm.richie import EbookChapter, EbookChunk, EbookEmbeddingModel, EbookSource, RichieBase
def test_chunk_text_uses_overlap() -> None: def test_chunk_text_uses_overlap() -> None:
@@ -86,6 +87,47 @@ def test_find_existing_source_matches_path_or_hash() -> None:
assert find_existing_source(session, Path("/new/book.epub"), "a" * 64) == source assert find_existing_source(session, Path("/new/book.epub"), "a" * 64) == source
def test_bm25_corpus_uses_existing_search_text_without_duplicate_metadata() -> None:
engine = create_engine("sqlite+pysqlite:///:memory:", future=True)
RichieBase.metadata.create_all(engine)
with sessionmaker(bind=engine, expire_on_commit=False, future=True)() as session:
source = EbookSource(
title="Book",
author="Author",
language=None,
publisher=None,
identifier=None,
file_path="/book.epub",
file_sha256="a" * 64,
file_mtime=datetime.now(tz=UTC),
file_size=10,
)
session.add(source)
session.flush()
chapter = EbookChapter(source_id=source.id, spine_index=0, title="Chapter", href=None)
session.add(chapter)
session.flush()
session.add(
EbookChunk(
id=1,
source_id=source.id,
chapter_id=chapter.id,
chunk_index=0,
text="content",
token_start=0,
token_count=1,
page_label=None,
content_sha256="b" * 64,
search_text="Book Author Chapter content",
)
)
session.commit()
records = fetch_bm25_corpus_records(session)
assert records[0]["bm25_text"] == "Book Author Chapter content"
def test_reciprocal_rank_fusion_marks_hybrid_source() -> None: def test_reciprocal_rank_fusion_marks_hybrid_source() -> None:
vector_results = [SearchResult(chunk_id=1, text="a", source_title="A")] vector_results = [SearchResult(chunk_id=1, text="a", source_title="A")]
lexical_results = [SearchResult(chunk_id=2, text="b", source_title="B")] lexical_results = [SearchResult(chunk_id=2, text="b", source_title="B")]
@@ -119,7 +161,7 @@ def test_search_ebooks_runs_vector_and_bm25_in_parallel(monkeypatch) -> None:
def fake_vector_candidates(received_engine, query, _config): def fake_vector_candidates(received_engine, query, _config):
"""Return vector candidates after confirming BM25 has started.""" """Return vector candidates after confirming BM25 has started."""
received_engines.append(received_engine) received_engines.append(received_engine)
assert query == "parallel" assert query == "what is parallel"
vector_started.set() vector_started.set()
assert bm25_started.wait(timeout=2) assert bm25_started.wait(timeout=2)
return [SearchResult(chunk_id=1, text="vector", source_title="Vector", vector_score=0.9)] return [SearchResult(chunk_id=1, text="vector", source_title="Vector", vector_score=0.9)]
@@ -135,13 +177,14 @@ def test_search_ebooks_runs_vector_and_bm25_in_parallel(monkeypatch) -> None:
monkeypatch.setattr("python.ebook_search.search.bm25_candidates", fake_bm25_candidates) monkeypatch.setattr("python.ebook_search.search.bm25_candidates", fake_bm25_candidates)
config = EbookSearchConfig(rerank=RerankConfig(enabled=False)) config = EbookSearchConfig(rerank=RerankConfig(enabled=False))
response = search_ebooks(engine, "parallel", config) response = search_ebooks(engine, "what is parallel", config)
timings = {step.name: step for step in response.timings} timings = {step.name: step for step in response.timings}
assert [result.chunk_id for result in response.results] == [1, 2] assert [result.chunk_id for result in response.results] == [1, 2]
assert timings["Embedding + vector search"].counts_toward_total is False assert timings["Embedding + vector search"].counts_toward_total is False
assert timings["BM25 search"].counts_toward_total is False assert timings["BM25 search"].counts_toward_total is False
assert timings["Hybrid retrieval"].counts_toward_total is True assert timings["Hybrid retrieval"].counts_toward_total is True
assert timings["BM25 query preparation"].counts_toward_total is True
assert received_engines == [engine] assert received_engines == [engine]
+5 -5
View File
@@ -75,7 +75,7 @@ def test_reranking_enabled_reorders_candidates(monkeypatch: pytest.MonkeyPatch)
results = rerank_chunks("query", candidates(), RerankConfig()) results = rerank_chunks("query", candidates(), RerankConfig())
assert [result.chunk_id for result in results] == [2, 1, 3] assert [result.chunk_id for result in results] == [2, 1, 3]
assert [round(result.score, 3) for result in results] == [0.45, 0.1, 0.0] assert [round(result.score, 3) for result in results] == [0.78, 0.37, 0.28]
assert [result.rerank_score for result in results] == [0.9, 0.1, 0.4] assert [result.rerank_score for result in results] == [0.9, 0.1, 0.4]
@@ -100,8 +100,8 @@ def test_reranking_cannot_ignore_hybrid_score(monkeypatch: pytest.MonkeyPatch) -
results = rerank_chunks("query", candidates, RerankConfig()) results = rerank_chunks("query", candidates, RerankConfig())
assert [result.chunk_id for result in results] == [1, 2] assert [result.chunk_id for result in results] == [1, 2]
assert results[0].score == 0.7 assert results[0].score == pytest.approx(0.79)
assert results[1].score == 0.0 assert results[1].score == 0.7
assert results[1].rerank_score == 1.0 assert results[1].rerank_score == 1.0
@@ -129,7 +129,7 @@ def test_malformed_vllm_rerank_json_does_not_crash_search(monkeypatch: pytest.Mo
results = rerank_chunks("query", candidates()[:1], RerankConfig()) results = rerank_chunks("query", candidates()[:1], RerankConfig())
assert results[0].score == 0.0 assert results[0].score == 0.3
def test_vllm_rerank_scores_are_clamped(monkeypatch: pytest.MonkeyPatch) -> None: def test_vllm_rerank_scores_are_clamped(monkeypatch: pytest.MonkeyPatch) -> None:
@@ -147,4 +147,4 @@ def test_vllm_rerank_scores_are_clamped(monkeypatch: pytest.MonkeyPatch) -> None
results = rerank_chunks("query", candidates()[:2], RerankConfig()) results = rerank_chunks("query", candidates()[:2], RerankConfig())
assert [result.rerank_score for result in results] == [0.0, 1.0] assert {result.chunk_id: result.rerank_score for result in results} == {1: 0.0, 2: 1.0}
+24
View File
@@ -5,6 +5,7 @@ from __future__ import annotations
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from sqlalchemy import create_engine from sqlalchemy import create_engine
from python.ebook_search.api.bm25_tasks import refresh_bm25_for_engine
from python.ebook_search.api.main import create_app from python.ebook_search.api.main import create_app
from python.ebook_search.config import EbookSearchConfig, RerankConfig from python.ebook_search.config import EbookSearchConfig, RerankConfig
from python.ebook_search.embeddings import EmbeddingModelStats from python.ebook_search.embeddings import EmbeddingModelStats
@@ -232,6 +233,29 @@ def test_ui_scan_schedules_bm25_refresh_after_database_change(monkeypatch) -> No
assert scheduled is True assert scheduled is True
def test_bm25_refresh_clears_loaded_corpus_cache(monkeypatch) -> None:
refreshed: list[object] = []
cache_cleared = False
def fake_refresh_bm25_corpus(session, config):
refreshed.append((session, config))
def fake_cache_clear():
nonlocal cache_cleared
cache_cleared = True
monkeypatch.setattr("python.ebook_search.api.bm25_tasks.refresh_bm25_corpus", fake_refresh_bm25_corpus)
monkeypatch.setattr("python.ebook_search.api.bm25_tasks.load_bm25_corpus.cache_clear", fake_cache_clear)
engine = create_engine("sqlite+pysqlite:///:memory:", future=True)
config = EbookSearchConfig(rerank=RerankConfig(enabled=False))
refresh_bm25_for_engine(engine, config)
assert len(refreshed) == 1
assert refreshed[0][1] == config
assert cache_cleared is True
def test_admin_page_shows_embedding_counts_by_model(monkeypatch) -> None: def test_admin_page_shows_embedding_counts_by_model(monkeypatch) -> None:
def fake_embedding_model_stats(_session): def fake_embedding_model_stats(_session):
return [ return [