improved BM25 write

This commit is contained in:
2026-06-13 17:05:32 -04:00
parent 70d24c2a85
commit 74e4c2e921
2 changed files with 169 additions and 28 deletions
+67 -23
View File
@@ -5,7 +5,6 @@ from __future__ import annotations
import json
import logging
import shutil
import tempfile
from dataclasses import dataclass
from datetime import UTC, datetime
from functools import cache
@@ -59,13 +58,21 @@ class BM25CorpusUnavailableError(RuntimeError):
def bm25_index_path(config: EbookSearchConfig) -> Path:
"""Return the configured BM25 index path relative to the current working directory."""
"""Return the configured BM25 index root path relative to the current working directory."""
path = Path(config.bm25_index_dir).expanduser()
if path.is_absolute():
return path
return Path.cwd() / path
def get_current_bm25_index(index_path: Path) -> Path:
"""Return the live BM25 index directory."""
current_path = index_path / "current"
if current_path.exists() or current_path.is_symlink():
return current_path
return index_path
def ensure_bm25_corpus(session: Session, config: EbookSearchConfig) -> None:
"""Create or refresh the persisted BM25 corpus when it is missing or stale."""
index_path = bm25_index_path(config)
@@ -100,13 +107,13 @@ def refresh_bm25_corpus(
) -> BM25Manifest:
"""Rebuild and persist the BM25 corpus from the current database chunks."""
index_path = bm25_index_path(config)
records = fetch_bm25_corpus_records(session)
records, texts = fetch_bm25_corpus_records(session)
manifest = BM25Manifest(
created_at=datetime.now(tz=UTC),
db_updated_at=db_updated_at if db_updated_at is not None else corpus_last_updated_at(session),
chunk_count=len(records),
)
write_bm25_corpus(index_path, records, manifest)
write_bm25_corpus(index_path, records, texts, manifest)
logger.info(
"ebook_bm25_index_refreshed path=%s chunks=%s created_at=%s",
index_path,
@@ -123,7 +130,8 @@ def load_bm25_corpus(config: EbookSearchConfig) -> BM25Corpus:
Background refresh tasks clear this cache after rebuilding the on-disk corpus.
"""
index_path = bm25_index_path(config)
logger.info("ebook_bm25_corpus_cache_load path=%s", index_path)
active_index_path = get_current_bm25_index(index_path)
logger.info("ebook_bm25_corpus_cache_load path=%s active_path=%s", index_path, active_index_path)
manifest = read_bm25_manifest(index_path)
if manifest is None or not bm25_index_exists(index_path, manifest):
msg = f"BM25 corpus is not available: {index_path}"
@@ -131,7 +139,7 @@ def load_bm25_corpus(config: EbookSearchConfig) -> BM25Corpus:
if manifest.chunk_count == 0:
return BM25Corpus(retriever=None, records=(), manifest=manifest)
retriever = bm25s.BM25.load(index_path, load_corpus=True, mmap=True)
retriever = bm25s.BM25.load(active_index_path, load_corpus=True, mmap=True)
records = tuple(dict(record) for record in retriever.corpus)
return BM25Corpus(retriever=retriever, records=records, manifest=manifest)
@@ -156,8 +164,12 @@ def score_bm25_corpus(query: str, corpus: BM25Corpus, *, limit: int) -> list[tup
return results
def fetch_bm25_corpus_records(session: Session) -> list[dict[str, object]]:
"""Fetch BM25 corpus records from the database."""
def fetch_bm25_corpus_records(session: Session) -> tuple[list[dict[str, object]], list[str]]:
"""Fetch persistable BM25 corpus records and their matching index texts from the database.
search_text is only needed to build the index, so it is returned separately instead of
being persisted into the corpus records, which would double the corpus size.
"""
statement = (
select(
EbookChunk.id.label("chunk_id"),
@@ -173,7 +185,13 @@ def fetch_bm25_corpus_records(session: Session) -> list[dict[str, object]]:
.outerjoin(EbookChapter, EbookChapter.id == EbookChunk.chapter_id)
.order_by(EbookChunk.id)
)
return [dict(row) for row in session.execute(statement).mappings()]
records: list[dict[str, object]] = []
texts: list[str] = []
for row in session.execute(statement).mappings():
record = dict(row)
texts.append(str(record.pop("bm25_text")))
records.append(record)
return records, texts
def corpus_last_updated_at(session: Session) -> datetime | None:
@@ -186,28 +204,42 @@ def corpus_last_updated_at(session: Session) -> datetime | None:
return session.scalar(select(func.max(update_times.c.updated)))
def write_bm25_corpus(index_path: Path, records: list[dict[str, object]], manifest: BM25Manifest) -> None:
"""Write a BM25 corpus and manifest atomically."""
index_path.parent.mkdir(parents=True, exist_ok=True)
temp_path = Path(tempfile.mkdtemp(prefix=f"{index_path.name}.", dir=index_path.parent))
def write_bm25_corpus(
index_path: Path,
records: list[dict[str, object]],
texts: list[str],
manifest: BM25Manifest,
) -> None:
"""Write a BM25 corpus generation and publish it through the current symlink."""
index_path.mkdir(parents=True, exist_ok=True)
generations_path = index_path / "generations"
generations_path.mkdir(exist_ok=True)
generation_path = next_bm25_generation_path(generations_path, manifest.created_at)
current_path = index_path / "current"
next_current_path = index_path / f".current.{generation_path.name}.tmp"
try:
generation_path.mkdir()
# Empty corpora publish a manifest-only generation so startup succeeds before any chunks exist.
if records:
retriever = bm25s.BM25()
texts = [str(record["bm25_text"]) for record in records]
retriever.index(bm25s.tokenize(texts, show_progress=False), show_progress=False)
retriever.save(temp_path, corpus=records, show_progress=False)
write_bm25_manifest(temp_path, manifest)
if index_path.exists():
shutil.rmtree(index_path)
temp_path.rename(index_path)
retriever.save(generation_path, corpus=records, show_progress=False)
write_bm25_manifest(generation_path, manifest)
next_current_path.unlink(missing_ok=True)
next_current_path.symlink_to(generation_path, target_is_directory=True)
next_current_path.replace(current_path)
except Exception:
shutil.rmtree(temp_path, ignore_errors=True)
next_current_path.unlink(missing_ok=True)
shutil.rmtree(generation_path, ignore_errors=True)
raise
def read_bm25_manifest(index_path: Path) -> BM25Manifest | None:
"""Read the BM25 manifest if it exists and is valid."""
manifest_path = index_path / MANIFEST_NAME
manifest_path = get_current_bm25_index(index_path) / MANIFEST_NAME
if not manifest_path.exists():
return None
body = json.loads(manifest_path.read_text(encoding="utf-8"))
@@ -230,8 +262,20 @@ def write_bm25_manifest(index_path: Path, manifest: BM25Manifest) -> None:
def bm25_index_exists(index_path: Path, manifest: BM25Manifest | None) -> bool:
"""Return whether a usable persisted BM25 index exists."""
if manifest is None or not index_path.is_dir():
active_index_path = get_current_bm25_index(index_path)
if manifest is None or not active_index_path.is_dir():
return False
if manifest.chunk_count == 0:
return True
return all((index_path / file_name).exists() for file_name in REQUIRED_INDEX_FILES)
return all((active_index_path / file_name).exists() for file_name in REQUIRED_INDEX_FILES)
def next_bm25_generation_path(generations_path: Path, created_at: datetime) -> Path:
"""Return an unused dated BM25 generation path."""
base_name = created_at.astimezone(UTC).strftime("%Y%m%dT%H%M%S.%fZ")
generation_path = generations_path / base_name
suffix = 1
while generation_path.exists():
generation_path = generations_path / f"{base_name}.{suffix}"
suffix += 1
return generation_path