diff --git a/python/ebook_search/bm25_corpus.py b/python/ebook_search/bm25_corpus.py new file mode 100644 index 0000000..79172a4 --- /dev/null +++ b/python/ebook_search/bm25_corpus.py @@ -0,0 +1,249 @@ +"""Persisted BM25 corpus management.""" + +from __future__ import annotations + +import json +import logging +import shutil +import tempfile +from dataclasses import dataclass +from datetime import UTC, datetime +from functools import cache +from pathlib import Path +from typing import TYPE_CHECKING + +import bm25s +from sqlalchemy import func, select, union_all + +from python.orm.richie import EbookChapter, EbookChunk, EbookSource + +if TYPE_CHECKING: + from sqlalchemy.orm import Session + + from python.ebook_search.config import EbookSearchConfig + +logger = logging.getLogger(__name__) +MANIFEST_NAME = "manifest.json" +REQUIRED_INDEX_FILES = frozenset( + { + "data.csc.index.npy", + "indices.csc.index.npy", + "indptr.csc.index.npy", + "params.index.json", + "vocab.index.json", + "corpus.jsonl", + } +) + + +@dataclass(frozen=True) +class BM25Manifest: + """Metadata describing a persisted BM25 corpus.""" + + created_at: datetime + db_updated_at: datetime | None + chunk_count: int + + +@dataclass(frozen=True) +class BM25Corpus: + """Loaded persisted BM25 corpus and retriever.""" + + retriever: object | None + records: tuple[dict[str, object], ...] + manifest: BM25Manifest + + +class BM25CorpusUnavailableError(RuntimeError): + """Raised when the persisted BM25 corpus cannot be loaded.""" + + +def bm25_index_path(config: EbookSearchConfig) -> Path: + """Return the configured BM25 index path relative to the current working directory.""" + path = Path(config.bm25_index_dir).expanduser() + if path.is_absolute(): + return path + return Path.cwd() / path + + +def ensure_bm25_corpus(session: Session, config: EbookSearchConfig) -> None: + """Create or refresh the persisted BM25 corpus when it is missing or stale.""" + index_path = bm25_index_path(config) + manifest = read_bm25_manifest(index_path) + db_updated_at = corpus_last_updated_at(session) + if not bm25_index_exists(index_path, manifest): + logger.info("ebook_bm25_index_missing path=%s", index_path) + refresh_bm25_corpus(session, config, db_updated_at=db_updated_at) + return + if db_updated_at is not None and manifest is not None and manifest.created_at < db_updated_at: + logger.info( + "ebook_bm25_index_stale path=%s created_at=%s db_updated_at=%s", + index_path, + manifest.created_at.isoformat(), + db_updated_at.isoformat(), + ) + refresh_bm25_corpus(session, config, db_updated_at=db_updated_at) + return + logger.info( + "ebook_bm25_index_current path=%s chunks=%s created_at=%s", + index_path, + manifest.chunk_count if manifest else 0, + manifest.created_at.isoformat() if manifest else None, + ) + + +def refresh_bm25_corpus( + session: Session, + config: EbookSearchConfig, + *, + db_updated_at: datetime | None = None, +) -> BM25Manifest: + """Rebuild and persist the BM25 corpus from the current database chunks.""" + index_path = bm25_index_path(config) + records = fetch_bm25_corpus_records(session) + manifest = BM25Manifest( + created_at=datetime.now(tz=UTC), + db_updated_at=db_updated_at if db_updated_at is not None else corpus_last_updated_at(session), + chunk_count=len(records), + ) + write_bm25_corpus(index_path, records, manifest) + logger.info( + "ebook_bm25_index_refreshed path=%s chunks=%s created_at=%s note=%s", + index_path, + manifest.chunk_count, + manifest.created_at.isoformat(), + "restart_service_to_use_refreshed_bm25_cache", + ) + return manifest + + +@cache +def load_bm25_corpus(config: EbookSearchConfig) -> BM25Corpus: + """Load the BM25 corpus into memory once per process. + + This cache intentionally does not notice later on-disk corpus refreshes. Restart the service after rebuilding the + BM25 corpus for searches to use the new index. + """ + index_path = bm25_index_path(config) + logger.info( + "ebook_bm25_corpus_cache_load path=%s note=%s", + index_path, + "restart_service_after_bm25_refresh", + ) + manifest = read_bm25_manifest(index_path) + if manifest is None or not bm25_index_exists(index_path, manifest): + msg = f"BM25 corpus is not available: {index_path}" + raise BM25CorpusUnavailableError(msg) + if manifest.chunk_count == 0: + return BM25Corpus(retriever=None, records=(), manifest=manifest) + + retriever = bm25s.BM25.load(index_path, load_corpus=True, mmap=True) + records = tuple(dict(record) for record in retriever.corpus) + return BM25Corpus(retriever=retriever, records=records, manifest=manifest) + + +def score_bm25_corpus(query: str, corpus: BM25Corpus, *, limit: int) -> list[tuple[dict[str, object], float]]: + """Score a query against a loaded BM25 corpus.""" + if corpus.retriever is None or not corpus.records: + return [] + k = min(limit, len(corpus.records)) + documents, scores = corpus.retriever.retrieve( + bm25s.tokenize(query, show_progress=False), + corpus=list(corpus.records), + k=k, + show_progress=False, + ) + results: list[tuple[dict[str, object], float]] = [] + for document, score in zip(documents[0], scores[0], strict=True): + score_value = float(score) + if score_value <= 0: + continue + results.append((dict(document), score_value)) + return results + + +def fetch_bm25_corpus_records(session: Session) -> list[dict[str, object]]: + """Fetch BM25 corpus records from the database.""" + statement = ( + select( + EbookChunk.id.label("chunk_id"), + EbookChunk.text.label("text"), + EbookSource.title.label("source_title"), + EbookSource.author.label("source_author"), + EbookChapter.title.label("chapter_title"), + EbookChunk.page_label.label("page_label"), + func.concat_ws( + " ", + EbookSource.title, + EbookSource.author, + EbookChapter.title, + EbookChunk.search_text, + ).label("bm25_text"), + ) + .select_from(EbookChunk) + .join(EbookSource, EbookSource.id == EbookChunk.source_id) + .outerjoin(EbookChapter, EbookChapter.id == EbookChunk.chapter_id) + .order_by(EbookChunk.id) + ) + return [dict(row) for row in session.execute(statement).mappings()] + + +def corpus_last_updated_at(session: Session) -> datetime | None: + """Return the latest source/chapter/chunk update timestamp relevant to BM25 text.""" + update_times = union_all( + select(func.max(EbookSource.updated).label("updated")), + select(func.max(EbookChapter.updated).label("updated")), + select(func.max(EbookChunk.updated).label("updated")), + ).subquery() + return session.scalar(select(func.max(update_times.c.updated))) + + +def write_bm25_corpus(index_path: Path, records: list[dict[str, object]], manifest: BM25Manifest) -> None: + """Write a BM25 corpus and manifest atomically.""" + index_path.parent.mkdir(parents=True, exist_ok=True) + temp_path = Path(tempfile.mkdtemp(prefix=f"{index_path.name}.", dir=index_path.parent)) + try: + if records: + retriever = bm25s.BM25() + texts = [str(record["bm25_text"]) for record in records] + retriever.index(bm25s.tokenize(texts, show_progress=False), show_progress=False) + retriever.save(temp_path, corpus=records, show_progress=False) + write_bm25_manifest(temp_path, manifest) + if index_path.exists(): + shutil.rmtree(index_path) + temp_path.rename(index_path) + except Exception: + shutil.rmtree(temp_path, ignore_errors=True) + raise + + +def read_bm25_manifest(index_path: Path) -> BM25Manifest | None: + """Read the BM25 manifest if it exists and is valid.""" + manifest_path = index_path / MANIFEST_NAME + if not manifest_path.exists(): + return None + body = json.loads(manifest_path.read_text(encoding="utf-8")) + return BM25Manifest( + created_at=datetime.fromisoformat(str(body["created_at"])), + db_updated_at=datetime.fromisoformat(str(body["db_updated_at"])) if body.get("db_updated_at") else None, + chunk_count=int(body["chunk_count"]), + ) + + +def write_bm25_manifest(index_path: Path, manifest: BM25Manifest) -> None: + """Write the BM25 manifest to an index directory.""" + body = { + "created_at": manifest.created_at.isoformat(), + "db_updated_at": manifest.db_updated_at.isoformat() if manifest.db_updated_at else None, + "chunk_count": manifest.chunk_count, + } + (index_path / MANIFEST_NAME).write_text(json.dumps(body, indent=2, sort_keys=True), encoding="utf-8") + + +def bm25_index_exists(index_path: Path, manifest: BM25Manifest | None) -> bool: + """Return whether a usable persisted BM25 index exists.""" + if manifest is None or not index_path.is_dir(): + return False + if manifest.chunk_count == 0: + return True + return all((index_path / file_name).exists() for file_name in REQUIRED_INDEX_FILES)