improved BM25 write
This commit is contained in:
@@ -5,7 +5,6 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import shutil
|
||||
import tempfile
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from functools import cache
|
||||
@@ -59,13 +58,21 @@ class BM25CorpusUnavailableError(RuntimeError):
|
||||
|
||||
|
||||
def bm25_index_path(config: EbookSearchConfig) -> Path:
|
||||
"""Return the configured BM25 index path relative to the current working directory."""
|
||||
"""Return the configured BM25 index root path relative to the current working directory."""
|
||||
path = Path(config.bm25_index_dir).expanduser()
|
||||
if path.is_absolute():
|
||||
return path
|
||||
return Path.cwd() / path
|
||||
|
||||
|
||||
def get_current_bm25_index(index_path: Path) -> Path:
|
||||
"""Return the live BM25 index directory."""
|
||||
current_path = index_path / "current"
|
||||
if current_path.exists() or current_path.is_symlink():
|
||||
return current_path
|
||||
return index_path
|
||||
|
||||
|
||||
def ensure_bm25_corpus(session: Session, config: EbookSearchConfig) -> None:
|
||||
"""Create or refresh the persisted BM25 corpus when it is missing or stale."""
|
||||
index_path = bm25_index_path(config)
|
||||
@@ -100,13 +107,13 @@ def refresh_bm25_corpus(
|
||||
) -> BM25Manifest:
|
||||
"""Rebuild and persist the BM25 corpus from the current database chunks."""
|
||||
index_path = bm25_index_path(config)
|
||||
records = fetch_bm25_corpus_records(session)
|
||||
records, texts = fetch_bm25_corpus_records(session)
|
||||
manifest = BM25Manifest(
|
||||
created_at=datetime.now(tz=UTC),
|
||||
db_updated_at=db_updated_at if db_updated_at is not None else corpus_last_updated_at(session),
|
||||
chunk_count=len(records),
|
||||
)
|
||||
write_bm25_corpus(index_path, records, manifest)
|
||||
write_bm25_corpus(index_path, records, texts, manifest)
|
||||
logger.info(
|
||||
"ebook_bm25_index_refreshed path=%s chunks=%s created_at=%s",
|
||||
index_path,
|
||||
@@ -123,7 +130,8 @@ def load_bm25_corpus(config: EbookSearchConfig) -> BM25Corpus:
|
||||
Background refresh tasks clear this cache after rebuilding the on-disk corpus.
|
||||
"""
|
||||
index_path = bm25_index_path(config)
|
||||
logger.info("ebook_bm25_corpus_cache_load path=%s", index_path)
|
||||
active_index_path = get_current_bm25_index(index_path)
|
||||
logger.info("ebook_bm25_corpus_cache_load path=%s active_path=%s", index_path, active_index_path)
|
||||
manifest = read_bm25_manifest(index_path)
|
||||
if manifest is None or not bm25_index_exists(index_path, manifest):
|
||||
msg = f"BM25 corpus is not available: {index_path}"
|
||||
@@ -131,7 +139,7 @@ def load_bm25_corpus(config: EbookSearchConfig) -> BM25Corpus:
|
||||
if manifest.chunk_count == 0:
|
||||
return BM25Corpus(retriever=None, records=(), manifest=manifest)
|
||||
|
||||
retriever = bm25s.BM25.load(index_path, load_corpus=True, mmap=True)
|
||||
retriever = bm25s.BM25.load(active_index_path, load_corpus=True, mmap=True)
|
||||
records = tuple(dict(record) for record in retriever.corpus)
|
||||
return BM25Corpus(retriever=retriever, records=records, manifest=manifest)
|
||||
|
||||
@@ -156,8 +164,12 @@ def score_bm25_corpus(query: str, corpus: BM25Corpus, *, limit: int) -> list[tup
|
||||
return results
|
||||
|
||||
|
||||
def fetch_bm25_corpus_records(session: Session) -> list[dict[str, object]]:
|
||||
"""Fetch BM25 corpus records from the database."""
|
||||
def fetch_bm25_corpus_records(session: Session) -> tuple[list[dict[str, object]], list[str]]:
|
||||
"""Fetch persistable BM25 corpus records and their matching index texts from the database.
|
||||
|
||||
search_text is only needed to build the index, so it is returned separately instead of
|
||||
being persisted into the corpus records, which would double the corpus size.
|
||||
"""
|
||||
statement = (
|
||||
select(
|
||||
EbookChunk.id.label("chunk_id"),
|
||||
@@ -173,7 +185,13 @@ def fetch_bm25_corpus_records(session: Session) -> list[dict[str, object]]:
|
||||
.outerjoin(EbookChapter, EbookChapter.id == EbookChunk.chapter_id)
|
||||
.order_by(EbookChunk.id)
|
||||
)
|
||||
return [dict(row) for row in session.execute(statement).mappings()]
|
||||
records: list[dict[str, object]] = []
|
||||
texts: list[str] = []
|
||||
for row in session.execute(statement).mappings():
|
||||
record = dict(row)
|
||||
texts.append(str(record.pop("bm25_text")))
|
||||
records.append(record)
|
||||
return records, texts
|
||||
|
||||
|
||||
def corpus_last_updated_at(session: Session) -> datetime | None:
|
||||
@@ -186,28 +204,42 @@ def corpus_last_updated_at(session: Session) -> datetime | None:
|
||||
return session.scalar(select(func.max(update_times.c.updated)))
|
||||
|
||||
|
||||
def write_bm25_corpus(index_path: Path, records: list[dict[str, object]], manifest: BM25Manifest) -> None:
|
||||
"""Write a BM25 corpus and manifest atomically."""
|
||||
index_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
temp_path = Path(tempfile.mkdtemp(prefix=f"{index_path.name}.", dir=index_path.parent))
|
||||
def write_bm25_corpus(
|
||||
index_path: Path,
|
||||
records: list[dict[str, object]],
|
||||
texts: list[str],
|
||||
manifest: BM25Manifest,
|
||||
) -> None:
|
||||
"""Write a BM25 corpus generation and publish it through the current symlink."""
|
||||
index_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
generations_path = index_path / "generations"
|
||||
generations_path.mkdir(exist_ok=True)
|
||||
|
||||
generation_path = next_bm25_generation_path(generations_path, manifest.created_at)
|
||||
current_path = index_path / "current"
|
||||
next_current_path = index_path / f".current.{generation_path.name}.tmp"
|
||||
try:
|
||||
generation_path.mkdir()
|
||||
|
||||
# Empty corpora publish a manifest-only generation so startup succeeds before any chunks exist.
|
||||
if records:
|
||||
retriever = bm25s.BM25()
|
||||
texts = [str(record["bm25_text"]) for record in records]
|
||||
retriever.index(bm25s.tokenize(texts, show_progress=False), show_progress=False)
|
||||
retriever.save(temp_path, corpus=records, show_progress=False)
|
||||
write_bm25_manifest(temp_path, manifest)
|
||||
if index_path.exists():
|
||||
shutil.rmtree(index_path)
|
||||
temp_path.rename(index_path)
|
||||
retriever.save(generation_path, corpus=records, show_progress=False)
|
||||
write_bm25_manifest(generation_path, manifest)
|
||||
next_current_path.unlink(missing_ok=True)
|
||||
next_current_path.symlink_to(generation_path, target_is_directory=True)
|
||||
next_current_path.replace(current_path)
|
||||
except Exception:
|
||||
shutil.rmtree(temp_path, ignore_errors=True)
|
||||
next_current_path.unlink(missing_ok=True)
|
||||
shutil.rmtree(generation_path, ignore_errors=True)
|
||||
raise
|
||||
|
||||
|
||||
def read_bm25_manifest(index_path: Path) -> BM25Manifest | None:
|
||||
"""Read the BM25 manifest if it exists and is valid."""
|
||||
manifest_path = index_path / MANIFEST_NAME
|
||||
manifest_path = get_current_bm25_index(index_path) / MANIFEST_NAME
|
||||
if not manifest_path.exists():
|
||||
return None
|
||||
body = json.loads(manifest_path.read_text(encoding="utf-8"))
|
||||
@@ -230,8 +262,20 @@ def write_bm25_manifest(index_path: Path, manifest: BM25Manifest) -> None:
|
||||
|
||||
def bm25_index_exists(index_path: Path, manifest: BM25Manifest | None) -> bool:
|
||||
"""Return whether a usable persisted BM25 index exists."""
|
||||
if manifest is None or not index_path.is_dir():
|
||||
active_index_path = get_current_bm25_index(index_path)
|
||||
if manifest is None or not active_index_path.is_dir():
|
||||
return False
|
||||
if manifest.chunk_count == 0:
|
||||
return True
|
||||
return all((index_path / file_name).exists() for file_name in REQUIRED_INDEX_FILES)
|
||||
return all((active_index_path / file_name).exists() for file_name in REQUIRED_INDEX_FILES)
|
||||
|
||||
|
||||
def next_bm25_generation_path(generations_path: Path, created_at: datetime) -> Path:
|
||||
"""Return an unused dated BM25 generation path."""
|
||||
base_name = created_at.astimezone(UTC).strftime("%Y%m%dT%H%M%S.%fZ")
|
||||
generation_path = generations_path / base_name
|
||||
suffix = 1
|
||||
while generation_path.exists():
|
||||
generation_path = generations_path / f"{base_name}.{suffix}"
|
||||
suffix += 1
|
||||
return generation_path
|
||||
|
||||
Reference in New Issue
Block a user