Files
dotfiles/tests/test_ebook_search_core.py
T
2026-06-13 22:41:09 -04:00

521 lines
19 KiB
Python

"""Tests for EPUB search core helpers."""
from __future__ import annotations
import logging
from dataclasses import replace
from datetime import UTC, datetime
from os import environ
from pathlib import Path
from threading import Event
from types import ModuleType
import pytest
from sqlalchemy import create_engine, select
from sqlalchemy.orm import sessionmaker
from python.ebook_search.answer import answer_query
from python.ebook_search.bm25_corpus import (
BM25Corpus,
BM25CorpusUnavailableError,
BM25Manifest,
ensure_bm25_corpus,
fetch_bm25_corpus_records,
load_bm25_corpus,
read_bm25_manifest,
score_bm25_corpus,
write_bm25_corpus,
)
from python.ebook_search.config import EbookSearchConfig, RerankConfig, load_config, normalize_embedding_model
from python.ebook_search.embeddings import MODEL_DIMENSIONS, ensure_embedding_models
from python.ebook_search.ingest import chunk_text, find_existing_source
from python.ebook_search.search import (
SearchResponse,
SearchResult,
bm25_candidates,
reciprocal_rank_fusion,
retrieval_query_from_text,
search_ebooks,
)
from python.ebook_search.timing import RuntimeStep
from python.orm.richie import EbookChapter, EbookChunk, EbookEmbeddingModel, EbookSource, RichieBase
def test_chunk_text_uses_overlap() -> None:
    """Chunking 100 tokens with a 20-token window and 5-token overlap."""
    words = [str(number) for number in range(100)]
    pieces = chunk_text(" ".join(words), chunk_tokens=20, overlap_tokens=5)

    assert len(pieces) > 1
    assert pieces[0].token_start == 0
    # 15 == chunk_tokens - overlap_tokens: the second chunk re-reads 5 tokens.
    assert pieces[1].token_start == 15
    assert all(piece.token_count <= 20 for piece in pieces)
def test_reciprocal_rank_fusion_combines_vector_and_bm25_rankings() -> None:
    """Fuse a vector ranking and a BM25 ranking that both contain chunk 2."""
    from_vector = [
        SearchResult(chunk_id=1, text="a", source_title="A", score=0.9, vector_score=0.9),
        SearchResult(chunk_id=2, text="b", source_title="B", score=0.8, vector_score=0.8),
    ]
    from_bm25 = [
        SearchResult(chunk_id=2, text="b", source_title="B", score=4.2, bm25_score=4.2),
        SearchResult(chunk_id=3, text="c", source_title="C", score=2.1, bm25_score=2.1),
    ]

    merged = reciprocal_rank_fusion(from_vector, from_bm25)

    # The chunk present in both rankings is expected first.
    assert [item.chunk_id for item in merged] == [2, 1, 3]
    best = merged[0]
    assert best.rank_source == "Hybrid"
    # Both per-retriever scores are carried onto the fused result.
    assert best.vector_score == 0.8
    assert best.bm25_score == 4.2
    assert best.fused_score == best.score
def test_find_existing_source_matches_path_or_hash() -> None:
    """A stored source is returned for a matching path or a matching hash."""
    engine = create_engine("sqlite+pysqlite:///:memory:", future=True)
    RichieBase.metadata.create_all(engine)
    open_session = sessionmaker(bind=engine, expire_on_commit=False, future=True)

    with open_session() as session:
        stored = EbookSource(
            title="Book",
            author=None,
            language=None,
            publisher=None,
            identifier=None,
            file_path="/old/book.epub",
            file_sha256="a" * 64,
            file_mtime=datetime.now(tz=UTC),
            file_size=10,
        )
        session.add(stored)
        session.commit()

        # Same path, different hash.
        by_path = find_existing_source(session, Path("/old/book.epub"), "b" * 64)
        assert by_path == stored

        # Different path, same hash.
        by_hash = find_existing_source(session, Path("/new/book.epub"), "a" * 64)
        assert by_hash == stored
def test_bm25_corpus_uses_existing_search_text_without_duplicate_metadata() -> None:
    """Corpus texts equal the chunk's ``search_text``; records have no ``bm25_text`` key."""
    engine = create_engine("sqlite+pysqlite:///:memory:", future=True)
    RichieBase.metadata.create_all(engine)
    open_session = sessionmaker(bind=engine, expire_on_commit=False, future=True)

    with open_session() as session:
        book = EbookSource(
            title="Book",
            author="Author",
            language=None,
            publisher=None,
            identifier=None,
            file_path="/book.epub",
            file_sha256="a" * 64,
            file_mtime=datetime.now(tz=UTC),
            file_size=10,
        )
        session.add(book)
        # Flush before reading book.id for the chapter's foreign key.
        session.flush()

        opening_chapter = EbookChapter(source_id=book.id, spine_index=0, title="Chapter", href=None)
        session.add(opening_chapter)
        session.flush()

        only_chunk = EbookChunk(
            id=1,
            source_id=book.id,
            chapter_id=opening_chapter.id,
            chunk_index=0,
            text="content",
            token_start=0,
            token_count=1,
            page_label=None,
            content_sha256="b" * 64,
            search_text="Book Author Chapter content",
        )
        session.add(only_chunk)
        session.commit()

        records, texts = fetch_bm25_corpus_records(session)

        assert texts == ["Book Author Chapter content"]
        first_record = records[0]
        assert first_record["chunk_id"] == 1
        assert "bm25_text" not in first_record
def test_reciprocal_rank_fusion_marks_hybrid_source() -> None:
    """Every fused result is labelled "Hybrid", even with disjoint inputs."""
    only_vector = [SearchResult(chunk_id=1, text="a", source_title="A")]
    only_bm25 = [SearchResult(chunk_id=2, text="b", source_title="B")]

    merged = reciprocal_rank_fusion(only_vector, only_bm25)
    labels = {item.rank_source for item in merged}

    assert labels == {"Hybrid"}
def test_search_response_sums_runtime_steps() -> None:
    """The total runtime skips steps flagged ``counts_toward_total=False``."""
    steps = (
        RuntimeStep(name="A", duration_ms=1.25),
        RuntimeStep(name="B", duration_ms=2.75),
        # Excluded from the total: 1.25 + 2.75 alone gives the expected 4.0.
        RuntimeStep(name="Parallel detail", duration_ms=10.0, counts_toward_total=False),
    )
    outcome = SearchResponse(query="query", results=[], rank_label="Hybrid", timings=steps)

    assert outcome.total_runtime_ms == 4.0
def test_search_ebooks_runs_vector_and_bm25_in_parallel(monkeypatch) -> None:
    """Both retrieval legs must be running at the same time.

    Each fake announces its own start, then waits up to two seconds for the
    other one's announcement, so the asserts inside the fakes fail when the
    two legs are executed one after the other.
    """
    engine = create_engine("sqlite+pysqlite:///:memory:", future=True)
    vector_running = Event()
    bm25_running = Event()
    seen_engines: list[object] = []

    def fake_vector_candidates(received_engine, query, _config):
        """Signal the vector leg, wait for the BM25 leg, then return one hit."""
        seen_engines.append(received_engine)
        assert query == "what is parallel"
        vector_running.set()
        assert bm25_running.wait(timeout=2)
        return [SearchResult(chunk_id=1, text="vector", source_title="Vector", vector_score=0.9)]

    def fake_bm25_candidates(query, _config):
        """Signal the BM25 leg, wait for the vector leg, then return one hit."""
        # The BM25 leg receives the prepared query, not the raw question.
        assert query == "parallel"
        bm25_running.set()
        assert vector_running.wait(timeout=2)
        return [SearchResult(chunk_id=2, text="bm25", source_title="BM25", bm25_score=2.0)]

    monkeypatch.setattr("python.ebook_search.search.vector_candidates", fake_vector_candidates)
    monkeypatch.setattr("python.ebook_search.search.bm25_candidates", fake_bm25_candidates)
    settings = EbookSearchConfig(rerank=RerankConfig(enabled=False))

    outcome = search_ebooks(engine, "what is parallel", settings)

    steps_by_name = {step.name: step for step in outcome.timings}
    assert [hit.chunk_id for hit in outcome.results] == [1, 2]
    # The per-leg timings are detail only; the enclosing steps carry the total.
    expected_flags = (
        ("Embedding + vector search", False),
        ("BM25 search", False),
        ("Hybrid retrieval", True),
        ("BM25 query preparation", True),
    )
    for step_name, counted in expected_flags:
        assert steps_by_name[step_name].counts_toward_total is counted
    assert seen_engines == [engine]
def test_retrieval_query_keeps_entity_and_series_terms() -> None:
    """Question words are dropped; the name and series terms stay, lowercased."""
    question = "what does Damien Montgomery stand for in starship mage"
    expected = "damien montgomery stand starship mage"

    assert retrieval_query_from_text(question) == expected
def test_bm25_candidates_scores_whole_corpus(monkeypatch) -> None:
    """``bm25_candidates`` scores the loaded corpus and maps hits to results."""
    row = {
        "chunk_id": 2,
        "text": "high",
        "source_title": "B",
        "source_author": None,
        "chapter_title": None,
        "page_label": None,
        "bm25_text": "high",
    }
    corpus = BM25Corpus(
        retriever=object(),
        records=(row,),
        manifest=BM25Manifest(created_at=datetime.now(tz=UTC), db_updated_at=None, chunk_count=1),
    )
    seen: dict[str, object] = {}

    def fake_score_bm25_corpus(query, saved_corpus, *, limit):
        """Record the scoring call and return a single scored record."""
        seen.update(query=query, corpus=saved_corpus, limit=limit)
        return [(row, 1.5)]

    monkeypatch.setattr("python.ebook_search.search.load_bm25_corpus", lambda _config: corpus)
    monkeypatch.setattr("python.ebook_search.search.score_bm25_corpus", fake_score_bm25_corpus)
    settings = EbookSearchConfig(rerank=RerankConfig(enabled=False))

    hits = bm25_candidates("high", settings)

    assert seen["query"] == "high"
    assert seen["corpus"] == corpus
    assert seen["limit"] == 120
    assert [hit.chunk_id for hit in hits] == [2]
    assert [hit.bm25_score for hit in hits] == [1.5]
def test_bm25_candidates_returns_empty_when_corpus_is_unavailable(monkeypatch, caplog) -> None:
    """An unavailable corpus yields no candidates and a logged warning."""

    def fake_load_bm25_corpus(_config):
        """Behave as if no BM25 index can be loaded."""
        raise BM25CorpusUnavailableError

    monkeypatch.setattr("python.ebook_search.search.load_bm25_corpus", fake_load_bm25_corpus)
    settings = EbookSearchConfig(rerank=RerankConfig(enabled=False))

    with caplog.at_level(logging.WARNING):
        hits = bm25_candidates("high", settings)

    assert hits == []
    assert "ebook_bm25_index_unavailable_skipping" in caplog.text
def test_write_bm25_corpus_publishes_dated_generation(tmp_path) -> None:
    """Writing creates a generation named after ``created_at`` and repoints ``current``.

    The previously published generation must be left in place.
    """
    index_path = tmp_path / "bm25"
    generations_path = index_path / "generations"
    old_generation = generations_path / "20260101T000000.000000Z"
    # Creates index_path and generations_path along the way.
    old_generation.mkdir(parents=True)
    (old_generation / "sentinel").write_text("old", encoding="utf-8")

    current_path = index_path / "current"
    current_path.symlink_to(Path("generations") / old_generation.name, target_is_directory=True)

    manifest = BM25Manifest(
        created_at=datetime(2026, 6, 12, 1, 2, 3, 456789, tzinfo=UTC),
        db_updated_at=None,
        chunk_count=0,
    )
    new_generation = generations_path / "20260612T010203.456789Z"

    write_bm25_corpus(index_path, [], [], manifest)

    assert current_path.is_symlink()
    assert current_path.readlink() == new_generation
    assert old_generation.is_dir()
    assert (old_generation / "sentinel").read_text(encoding="utf-8") == "old"
    assert new_generation.is_dir()
    assert read_bm25_manifest(index_path) == manifest
def test_write_bm25_corpus_keeps_current_generation_when_publish_fails(monkeypatch, tmp_path) -> None:
    """A failed swap of ``current`` keeps the old generation and leaves no new one."""
    index_path = tmp_path / "bm25"
    generations_path = index_path / "generations"
    old_generation = generations_path / "20260101T000000.000000Z"
    # Creates index_path and generations_path along the way.
    old_generation.mkdir(parents=True)
    (old_generation / "sentinel").write_text("old", encoding="utf-8")

    current_path = index_path / "current"
    old_target = Path("generations") / old_generation.name
    current_path.symlink_to(old_target, target_is_directory=True)

    original_replace = Path.replace

    def fail_current_replace(self, target):
        """Fail only the rename of a ``.current.*`` temp entry onto ``current``."""
        publishes_current = (
            self.parent == index_path and self.name.startswith(".current.") and target == current_path
        )
        if not publishes_current:
            return original_replace(self, target)
        msg = "current publish failed"
        raise OSError(msg)

    monkeypatch.setattr(Path, "replace", fail_current_replace)
    manifest = BM25Manifest(
        created_at=datetime(2026, 6, 12, 1, 2, 3, 456789, tzinfo=UTC),
        db_updated_at=None,
        chunk_count=0,
    )

    with pytest.raises(OSError, match="current publish failed"):
        write_bm25_corpus(index_path, [], [], manifest)

    assert current_path.readlink() == old_target
    assert (old_generation / "sentinel").read_text(encoding="utf-8") == "old"
    # The generation that could not be published must not be left behind.
    assert not (generations_path / "20260612T010203.456789Z").exists()
def test_load_bm25_corpus_uses_current_generation(tmp_path) -> None:
    """A corpus written to disk can be loaded back and scored."""
    # Start from an empty loader cache.
    load_bm25_corpus.cache_clear()

    index_path = tmp_path / "bm25"
    manifest = BM25Manifest(
        created_at=datetime(2026, 6, 12, 1, 2, 3, 456789, tzinfo=UTC),
        db_updated_at=None,
        chunk_count=1,
    )
    row = {
        "chunk_id": 2,
        "text": "cached",
        "source_title": "B",
        "source_author": None,
        "chapter_title": None,
        "page_label": None,
    }
    write_bm25_corpus(index_path, [row], ["cached phrase"], manifest)
    settings = EbookSearchConfig(rerank=RerankConfig(enabled=False), bm25_index_dir=str(index_path))

    try:
        loaded = load_bm25_corpus(settings)
    finally:
        # Clear the cache again whether or not the load succeeded.
        load_bm25_corpus.cache_clear()

    assert loaded.manifest == manifest
    assert loaded.records[0]["chunk_id"] == 2
    assert score_bm25_corpus("cached", loaded, limit=10)
def test_load_bm25_corpus_caches_disk_load(monkeypatch, tmp_path) -> None:
    """Two loads with the same config call ``BM25.load`` once and share one object."""
    load_bm25_corpus.cache_clear()
    manifest = BM25Manifest(created_at=datetime.now(tz=UTC), db_updated_at=None, chunk_count=1)
    row = {
        "chunk_id": 2,
        "text": "cached",
        "source_title": "B",
        "source_author": None,
        "chapter_title": None,
        "page_label": None,
        "bm25_text": "cached",
    }
    # One entry is appended per BM25.load call.
    load_calls: list[object] = []

    class FakeRetriever:
        """Fake persisted BM25 retriever exposing the stored records."""

        corpus = (row,)

    class FakeBM25:
        """Fake BM25 class whose ``load`` calls are recorded."""

        @staticmethod
        def load(index_path, *, load_corpus, mmap):
            load_calls.append(index_path)
            assert index_path == tmp_path
            assert load_corpus is True
            assert mmap is True
            return FakeRetriever()

    fake_bm25s = ModuleType("bm25s")
    fake_bm25s.BM25 = FakeBM25
    monkeypatch.setattr("python.ebook_search.bm25_corpus.read_bm25_manifest", lambda _path: manifest)
    monkeypatch.setattr("python.ebook_search.bm25_corpus.bm25_index_exists", lambda _path, _manifest: True)
    monkeypatch.setattr("python.ebook_search.bm25_corpus.bm25s", fake_bm25s)
    settings = EbookSearchConfig(rerank=RerankConfig(enabled=False), bm25_index_dir=str(tmp_path))

    try:
        first = load_bm25_corpus(settings)
        second = load_bm25_corpus(settings)
    finally:
        # Clear the cache so the fake-backed corpus does not outlive this test.
        load_bm25_corpus.cache_clear()

    assert first is second
    assert first is not None
    assert first.records == (row,)
    assert len(load_calls) == 1
def test_load_bm25_corpus_raises_when_index_is_missing(monkeypatch, tmp_path) -> None:
    """Loading without a manifest or index raises ``BM25CorpusUnavailableError``."""

    def no_manifest(_path):
        """Report that no manifest was found."""
        return None

    def no_index(_path, _manifest):
        """Report that no index files exist."""
        return False

    load_bm25_corpus.cache_clear()
    monkeypatch.setattr("python.ebook_search.bm25_corpus.read_bm25_manifest", no_manifest)
    monkeypatch.setattr("python.ebook_search.bm25_corpus.bm25_index_exists", no_index)
    settings = EbookSearchConfig(rerank=RerankConfig(enabled=False), bm25_index_dir=str(tmp_path))

    try:
        with pytest.raises(BM25CorpusUnavailableError, match="BM25 corpus is not available"):
            load_bm25_corpus(settings)
    finally:
        load_bm25_corpus.cache_clear()
def test_ensure_bm25_corpus_refreshes_missing_index(monkeypatch) -> None:
    """With no manifest and no index on disk, a refresh is requested once."""
    refresh_calls: list[object] = []
    latest_update = datetime.now(tz=UTC)

    def fake_refresh(session, config, *, db_updated_at):
        """Record the refresh request instead of rebuilding anything."""
        refresh_calls.append((session, config, db_updated_at))

    module = "python.ebook_search.bm25_corpus"
    monkeypatch.setattr(f"{module}.read_bm25_manifest", lambda _path: None)
    monkeypatch.setattr(f"{module}.bm25_index_exists", lambda _path, _manifest: False)
    monkeypatch.setattr(f"{module}.corpus_last_updated_at", lambda _session: latest_update)
    monkeypatch.setattr(f"{module}.refresh_bm25_corpus", fake_refresh)

    settings = EbookSearchConfig(rerank=RerankConfig(enabled=False))
    # A plain object is enough here because every collaborator that would use the session is patched.
    session = object()

    ensure_bm25_corpus(session, settings)

    assert refresh_calls == [(session, settings, latest_update)]
def test_ensure_bm25_corpus_refreshes_stale_index(monkeypatch) -> None:
    """An index whose manifest predates the database's last update is refreshed."""
    refresh_calls: list[object] = []
    built_at = datetime(2026, 1, 1, tzinfo=UTC)
    # One day newer than the timestamp stored in the manifest.
    latest_update = datetime(2026, 1, 2, tzinfo=UTC)
    stale_manifest = BM25Manifest(created_at=built_at, db_updated_at=built_at, chunk_count=10)

    def fake_refresh(session, config, *, db_updated_at):
        """Record the refresh request instead of rebuilding anything."""
        refresh_calls.append((session, config, db_updated_at))

    module = "python.ebook_search.bm25_corpus"
    monkeypatch.setattr(f"{module}.read_bm25_manifest", lambda _path: stale_manifest)
    monkeypatch.setattr(f"{module}.bm25_index_exists", lambda _path, _manifest: True)
    monkeypatch.setattr(f"{module}.corpus_last_updated_at", lambda _session: latest_update)
    monkeypatch.setattr(f"{module}.refresh_bm25_corpus", fake_refresh)

    settings = EbookSearchConfig(rerank=RerankConfig(enabled=False))
    session = object()

    ensure_bm25_corpus(session, settings)

    assert refresh_calls == [(session, settings, latest_update)]
def test_supported_embedding_models_match_service_names() -> None:
    """Pin the supported model names and their embedding dimensions."""
    expected_dimensions = {
        "qwen3-embedding-0.6b": 1024,
        "qwen3-embedding-4b": 2560,
        "qwen3-embedding-8b": 4096,
    }

    assert MODEL_DIMENSIONS == expected_dimensions
def test_ensure_embedding_models_registers_service_names() -> None:
    """``ensure_embedding_models`` stores one row per supported model."""
    engine = create_engine("sqlite+pysqlite:///:memory:", future=True)
    RichieBase.metadata.create_all(engine)
    open_session = sessionmaker(bind=engine, expire_on_commit=False, future=True)

    with open_session() as session:
        ensure_embedding_models(session)
        session.commit()

        # Order by name so the comparison does not depend on insertion order.
        by_name = select(EbookEmbeddingModel).order_by(EbookEmbeddingModel.name)
        stored = list(session.scalars(by_name))

    name_and_dimension = [(row.name, row.dimension) for row in stored]
    assert name_and_dimension == [
        ("qwen3-embedding-0.6b", 1024),
        ("qwen3-embedding-4b", 2560),
        ("qwen3-embedding-8b", 4096),
    ]
def test_embedding_model_aliases_normalize_to_provider_names(monkeypatch) -> None:
    """Model aliases taken from the environment normalize to provider names.

    ``EBOOK_SEARCH_EMBEDDING_MODEL`` is managed through ``monkeypatch`` so it
    is unset for the default-case assertion and restored when the test ends.
    Assigning ``os.environ`` directly would leave the last value in place for
    every test that runs afterwards in the same process.
    """
    variable = "EBOOK_SEARCH_EMBEDDING_MODEL"

    # Default case: a value inherited from the caller's environment must not
    # influence the result.
    monkeypatch.delenv(variable, raising=False)
    assert normalize_embedding_model() == "qwen3-embedding-0.6b"

    # (value placed in the environment, expected normalized name)
    aliases = (
        ("qwen3-embedding-0.6b", "qwen3-embedding-0.6b"),
        ("Qwen3-Embedding-0.6B", "qwen3-embedding-0.6b"),
        ("Qwen/Qwen3-Embedding-4B", "qwen3-embedding-4b"),
        ("qwen3-embedding:8b", "qwen3-embedding-8b"),
        ("qwen3-embedding-8b", "qwen3-embedding-8b"),
    )
    for raw_value, expected in aliases:
        monkeypatch.setenv(variable, raw_value)
        # The raw value is the failure message, identifying the alias that broke.
        assert normalize_embedding_model() == expected, raw_value
def test_answer_generation_is_enabled_by_default(monkeypatch) -> None:
    """With ``EBOOK_SEARCH_ANSWER_ENABLED`` unset, the loaded config enables answers."""
    monkeypatch.delenv("EBOOK_SEARCH_ANSWER_ENABLED", raising=False)

    assert load_config().answer_enabled is True
def test_chat_defaults_use_ollama_cloud(monkeypatch) -> None:
    """Without overrides, the chat base URL and model fall back to the defaults."""
    for variable in ("EBOOK_SEARCH_VLLM_BASE_URL", "EBOOK_SEARCH_CHAT_MODEL"):
        monkeypatch.delenv(variable, raising=False)

    defaults = load_config()

    assert defaults.vllm_base_url == "https://ollama.com/v1"
    assert defaults.chat_model == "deepseek-v4-flash"
def test_chat_api_key_falls_back_to_ollama_api_key(monkeypatch) -> None:
    """``OLLAMA_API_KEY`` is used when ``EBOOK_SEARCH_VLLM_API_KEY`` is unset."""
    monkeypatch.delenv("EBOOK_SEARCH_VLLM_API_KEY", raising=False)
    monkeypatch.setenv("OLLAMA_API_KEY", "ollama-key")

    loaded = load_config()

    assert loaded.vllm_api_key == "ollama-key"
def test_answer_query_does_not_call_model_when_disabled() -> None:
    """With answers disabled, ``answer_query`` returns the disabled notice."""
    disabled = replace(load_config(), answer_enabled=False)
    hit = SearchResult(chunk_id=1, text="source text", source_title="Book")

    reply = answer_query("question", [hit], disabled)

    assert "Answer generation is disabled" in reply