"""Persisted BM25 corpus management.""" from __future__ import annotations import json import logging import shutil import tempfile from dataclasses import dataclass from datetime import UTC, datetime from functools import cache from pathlib import Path from typing import TYPE_CHECKING import bm25s from sqlalchemy import func, select, union_all from python.orm.richie import EbookChapter, EbookChunk, EbookSource if TYPE_CHECKING: from sqlalchemy.orm import Session from python.ebook_search.config import EbookSearchConfig logger = logging.getLogger(__name__) MANIFEST_NAME = "manifest.json" REQUIRED_INDEX_FILES = frozenset( { "data.csc.index.npy", "indices.csc.index.npy", "indptr.csc.index.npy", "params.index.json", "vocab.index.json", "corpus.jsonl", } ) @dataclass(frozen=True) class BM25Manifest: """Metadata describing a persisted BM25 corpus.""" created_at: datetime db_updated_at: datetime | None chunk_count: int @dataclass(frozen=True) class BM25Corpus: """Loaded persisted BM25 corpus and retriever.""" retriever: object | None records: tuple[dict[str, object], ...] manifest: BM25Manifest class BM25CorpusUnavailableError(RuntimeError): """Raised when the persisted BM25 corpus cannot be loaded.""" def bm25_index_path(config: EbookSearchConfig) -> Path: """Return the configured BM25 index path relative to the current working directory.""" path = Path(config.bm25_index_dir).expanduser() if path.is_absolute(): return path return Path.cwd() / path def ensure_bm25_corpus(session: Session, config: EbookSearchConfig) -> None: """Create or refresh the persisted BM25 corpus when it is missing or stale.""" index_path = bm25_index_path(config) manifest = read_bm25_manifest(index_path) db_updated_at = corpus_last_updated_at(session) if not bm25_index_exists(index_path, manifest): logger.info("ebook_bm25_index_missing path=%s", index_path) refresh_bm25_corpus(session, config, db_updated_at=db_updated_at) return if db_updated_at is not None and manifest is not None and manifest.created_at < db_updated_at: logger.info( "ebook_bm25_index_stale path=%s created_at=%s db_updated_at=%s", index_path, manifest.created_at.isoformat(), db_updated_at.isoformat(), ) refresh_bm25_corpus(session, config, db_updated_at=db_updated_at) return logger.info( "ebook_bm25_index_current path=%s chunks=%s created_at=%s", index_path, manifest.chunk_count if manifest else 0, manifest.created_at.isoformat() if manifest else None, ) def refresh_bm25_corpus( session: Session, config: EbookSearchConfig, *, db_updated_at: datetime | None = None, ) -> BM25Manifest: """Rebuild and persist the BM25 corpus from the current database chunks.""" index_path = bm25_index_path(config) records = fetch_bm25_corpus_records(session) manifest = BM25Manifest( created_at=datetime.now(tz=UTC), db_updated_at=db_updated_at if db_updated_at is not None else corpus_last_updated_at(session), chunk_count=len(records), ) write_bm25_corpus(index_path, records, manifest) logger.info( "ebook_bm25_index_refreshed path=%s chunks=%s created_at=%s note=%s", index_path, manifest.chunk_count, manifest.created_at.isoformat(), "restart_service_to_use_refreshed_bm25_cache", ) return manifest @cache def load_bm25_corpus(config: EbookSearchConfig) -> BM25Corpus: """Load the BM25 corpus into memory once per process. This cache intentionally does not notice later on-disk corpus refreshes. Restart the service after rebuilding the BM25 corpus for searches to use the new index. """ index_path = bm25_index_path(config) logger.info( "ebook_bm25_corpus_cache_load path=%s note=%s", index_path, "restart_service_after_bm25_refresh", ) manifest = read_bm25_manifest(index_path) if manifest is None or not bm25_index_exists(index_path, manifest): msg = f"BM25 corpus is not available: {index_path}" raise BM25CorpusUnavailableError(msg) if manifest.chunk_count == 0: return BM25Corpus(retriever=None, records=(), manifest=manifest) retriever = bm25s.BM25.load(index_path, load_corpus=True, mmap=True) records = tuple(dict(record) for record in retriever.corpus) return BM25Corpus(retriever=retriever, records=records, manifest=manifest) def score_bm25_corpus(query: str, corpus: BM25Corpus, *, limit: int) -> list[tuple[dict[str, object], float]]: """Score a query against a loaded BM25 corpus.""" if corpus.retriever is None or not corpus.records: return [] k = min(limit, len(corpus.records)) documents, scores = corpus.retriever.retrieve( bm25s.tokenize(query, show_progress=False), corpus=list(corpus.records), k=k, show_progress=False, ) results: list[tuple[dict[str, object], float]] = [] for document, score in zip(documents[0], scores[0], strict=True): score_value = float(score) if score_value <= 0: continue results.append((dict(document), score_value)) return results def fetch_bm25_corpus_records(session: Session) -> list[dict[str, object]]: """Fetch BM25 corpus records from the database.""" statement = ( select( EbookChunk.id.label("chunk_id"), EbookChunk.text.label("text"), EbookSource.title.label("source_title"), EbookSource.author.label("source_author"), EbookChapter.title.label("chapter_title"), EbookChunk.page_label.label("page_label"), EbookChunk.search_text.label("bm25_text"), ) .select_from(EbookChunk) .join(EbookSource, EbookSource.id == EbookChunk.source_id) .outerjoin(EbookChapter, EbookChapter.id == EbookChunk.chapter_id) .order_by(EbookChunk.id) ) return [dict(row) for row in session.execute(statement).mappings()] def corpus_last_updated_at(session: Session) -> datetime | None: """Return the latest source/chapter/chunk update timestamp relevant to BM25 text.""" update_times = union_all( select(func.max(EbookSource.updated).label("updated")), select(func.max(EbookChapter.updated).label("updated")), select(func.max(EbookChunk.updated).label("updated")), ).subquery() return session.scalar(select(func.max(update_times.c.updated))) def write_bm25_corpus(index_path: Path, records: list[dict[str, object]], manifest: BM25Manifest) -> None: """Write a BM25 corpus and manifest atomically.""" index_path.parent.mkdir(parents=True, exist_ok=True) temp_path = Path(tempfile.mkdtemp(prefix=f"{index_path.name}.", dir=index_path.parent)) try: if records: retriever = bm25s.BM25() texts = [str(record["bm25_text"]) for record in records] retriever.index(bm25s.tokenize(texts, show_progress=False), show_progress=False) retriever.save(temp_path, corpus=records, show_progress=False) write_bm25_manifest(temp_path, manifest) if index_path.exists(): shutil.rmtree(index_path) temp_path.rename(index_path) except Exception: shutil.rmtree(temp_path, ignore_errors=True) raise def read_bm25_manifest(index_path: Path) -> BM25Manifest | None: """Read the BM25 manifest if it exists and is valid.""" manifest_path = index_path / MANIFEST_NAME if not manifest_path.exists(): return None body = json.loads(manifest_path.read_text(encoding="utf-8")) return BM25Manifest( created_at=datetime.fromisoformat(str(body["created_at"])), db_updated_at=datetime.fromisoformat(str(body["db_updated_at"])) if body.get("db_updated_at") else None, chunk_count=int(body["chunk_count"]), ) def write_bm25_manifest(index_path: Path, manifest: BM25Manifest) -> None: """Write the BM25 manifest to an index directory.""" body = { "created_at": manifest.created_at.isoformat(), "db_updated_at": manifest.db_updated_at.isoformat() if manifest.db_updated_at else None, "chunk_count": manifest.chunk_count, } (index_path / MANIFEST_NAME).write_text(json.dumps(body, indent=2, sort_keys=True), encoding="utf-8") def bm25_index_exists(index_path: Path, manifest: BM25Manifest | None) -> bool: """Return whether a usable persisted BM25 index exists.""" if manifest is None or not index_path.is_dir(): return False if manifest.chunk_count == 0: return True return all((index_path / file_name).exists() for file_name in REQUIRED_INDEX_FILES)