238 lines
8.5 KiB
Python
238 lines
8.5 KiB
Python
"""Persisted BM25 corpus management."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import shutil
|
|
import tempfile
|
|
from dataclasses import dataclass
|
|
from datetime import UTC, datetime
|
|
from functools import cache
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING
|
|
|
|
import bm25s
|
|
from sqlalchemy import func, select, union_all
|
|
|
|
from python.orm.richie import EbookChapter, EbookChunk, EbookSource
|
|
|
|
if TYPE_CHECKING:
|
|
from sqlalchemy.orm import Session
|
|
|
|
from python.ebook_search.config import EbookSearchConfig
|
|
|
|
logger = logging.getLogger(__name__)

# Name of the JSON metadata file written into every index directory.
MANIFEST_NAME = "manifest.json"

# Files that must all be present for a non-empty persisted index to be considered usable.
# NOTE(review): these appear to be the artifacts bm25s writes for the index plus the saved
# corpus -- confirm against the installed bm25s version when upgrading it.
REQUIRED_INDEX_FILES = frozenset(
    (
        "corpus.jsonl",
        "data.csc.index.npy",
        "indices.csc.index.npy",
        "indptr.csc.index.npy",
        "params.index.json",
        "vocab.index.json",
    )
)
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class BM25Manifest:
|
|
"""Metadata describing a persisted BM25 corpus."""
|
|
|
|
created_at: datetime
|
|
db_updated_at: datetime | None
|
|
chunk_count: int
|
|
|
|
|
|
@dataclass(frozen=True)
class BM25Corpus:
    """In-memory view of a persisted BM25 index.

    An index built from zero chunks is represented with ``retriever=None`` and
    an empty ``records`` tuple.
    """

    # The loaded bm25s retriever, or None when the index holds no chunks.
    retriever: object | None
    # One dict per indexed chunk, as read back from the saved corpus.
    records: tuple[dict[str, object], ...]
    # Manifest read from the index directory.
    manifest: BM25Manifest
|
|
|
|
|
|
class BM25CorpusUnavailableError(RuntimeError):
    """Signal that no usable persisted BM25 corpus could be loaded from disk."""
|
|
|
|
|
|
def bm25_index_path(config: EbookSearchConfig) -> Path:
    """Resolve ``config.bm25_index_dir`` to a concrete index directory path.

    A leading ``~`` is expanded first. Absolute paths are returned as-is;
    relative paths are anchored at the current working directory.
    """
    configured = Path(config.bm25_index_dir).expanduser()
    return configured if configured.is_absolute() else Path.cwd() / configured
|
|
|
|
|
|
def ensure_bm25_corpus(session: Session, config: EbookSearchConfig) -> None:
    """Create or refresh the persisted BM25 corpus when it is missing or stale.

    The index is rebuilt when either condition holds:

    * the directory, its manifest or required files are missing; or
    * the database's latest ``updated`` timestamp (across sources, chapters and
      chunks) is newer than the ``db_updated_at`` snapshot recorded in the
      manifest when the index was built.

    Staleness compares a database timestamp with a database timestamp rather
    than with the manifest's ``created_at``, which comes from the application
    clock and is always timezone-aware UTC. This has three benefits:

    * app/DB clock skew cannot hide updates;
    * updates committed while a rebuild was running are caught;
    * no ``TypeError`` is raised when the database returns timezone-naive
      datetimes.

    NOTE: deleting a chunk does not raise ``MAX(updated)``, so pure deletions
    are not detected by this check.

    Args:
        session: Database session used to read timestamps and, if needed, rows.
        config: Search configuration providing the index directory.
    """
    index_path = bm25_index_path(config)
    manifest = read_bm25_manifest(index_path)
    db_updated_at = corpus_last_updated_at(session)

    if manifest is None or not bm25_index_exists(index_path, manifest):
        logger.info("ebook_bm25_index_missing path=%s", index_path)
        refresh_bm25_corpus(session, config, db_updated_at=db_updated_at)
        return

    if db_updated_at is not None and _bm25_manifest_is_stale(manifest, db_updated_at):
        logger.info(
            "ebook_bm25_index_stale path=%s created_at=%s db_updated_at=%s",
            index_path,
            manifest.created_at.isoformat(),
            db_updated_at.isoformat(),
        )
        refresh_bm25_corpus(session, config, db_updated_at=db_updated_at)
        return

    logger.info(
        "ebook_bm25_index_current path=%s chunks=%s created_at=%s",
        index_path,
        manifest.chunk_count,
        manifest.created_at.isoformat(),
    )


def _bm25_manifest_is_stale(manifest: BM25Manifest, db_updated_at: datetime) -> bool:
    """Return whether the database has changed since ``manifest``'s index was built."""
    if manifest.db_updated_at is None:
        # The index was built when no row carried a timestamp; any timestamp now is newer.
        return True
    return manifest.db_updated_at < db_updated_at
|
|
|
|
|
|
def refresh_bm25_corpus(
    session: Session,
    config: EbookSearchConfig,
    *,
    db_updated_at: datetime | None = None,
) -> BM25Manifest:
    """Rebuild and persist the BM25 corpus from the current database chunks.

    Args:
        session: Database session used to read chunk rows (and timestamps).
        config: Search configuration providing the index directory.
        db_updated_at: Latest database ``updated`` timestamp already observed by
            the caller. When None, it is queried here.

    Returns:
        The manifest written alongside the rebuilt index.
    """
    index_path = bm25_index_path(config)
    # Snapshot the DB timestamp *before* reading rows. The recorded value can then
    # only under-state what the index contains. A write that lands between the two
    # queries triggers another rebuild later rather than being marked as included.
    if db_updated_at is None:
        db_updated_at = corpus_last_updated_at(session)
    records = fetch_bm25_corpus_records(session)
    manifest = BM25Manifest(
        created_at=datetime.now(tz=UTC),
        db_updated_at=db_updated_at,
        chunk_count=len(records),
    )
    write_bm25_corpus(index_path, records, manifest)
    logger.info(
        "ebook_bm25_index_refreshed path=%s chunks=%s created_at=%s",
        index_path,
        manifest.chunk_count,
        manifest.created_at.isoformat(),
    )
    return manifest
|
|
|
|
|
|
@cache
def load_bm25_corpus(config: EbookSearchConfig) -> BM25Corpus:
    """Load the persisted BM25 corpus into memory, memoized per ``config``.

    Background refresh tasks clear this cache after rebuilding the on-disk corpus.

    Raises:
        BM25CorpusUnavailableError: if the manifest is missing/unreadable or the
            index directory is incomplete.
    """
    index_path = bm25_index_path(config)
    logger.info("ebook_bm25_corpus_cache_load path=%s", index_path)

    manifest = read_bm25_manifest(index_path)
    usable = manifest is not None and bm25_index_exists(index_path, manifest)
    if not usable:
        msg = f"BM25 corpus is not available: {index_path}"
        raise BM25CorpusUnavailableError(msg)

    # An index built from zero chunks has no bm25s artifacts to load.
    if not manifest.chunk_count:
        return BM25Corpus(retriever=None, records=(), manifest=manifest)

    retriever = bm25s.BM25.load(index_path, load_corpus=True, mmap=True)
    return BM25Corpus(
        retriever=retriever,
        records=tuple(map(dict, retriever.corpus)),
        manifest=manifest,
    )
|
|
|
|
|
|
def score_bm25_corpus(query: str, corpus: BM25Corpus, *, limit: int) -> list[tuple[dict[str, object], float]]:
    """Score a query against a loaded BM25 corpus.

    Args:
        query: Free-text query, tokenized with ``bm25s.tokenize``. This is the same
            default tokenizer ``write_bm25_corpus`` uses when building the index.
        corpus: Corpus returned by ``load_bm25_corpus``.
        limit: Maximum number of hits to return. Values <= 0 yield no results.

    Returns:
        Up to ``limit`` ``(record, score)`` pairs in the order bm25s returns them.
        Pairs with non-positive scores are dropped, and each record is a fresh
        ``dict`` copy.
    """
    # Guard k <= 0 here rather than passing a non-positive k through to bm25s.
    if limit <= 0 or corpus.retriever is None or not corpus.records:
        return []
    # Never ask for more hits than there are indexed documents.
    k = min(limit, len(corpus.records))
    documents, scores = corpus.retriever.retrieve(
        bm25s.tokenize(query, show_progress=False),
        corpus=list(corpus.records),
        k=k,
        show_progress=False,
    )
    results: list[tuple[dict[str, object], float]] = []
    # Row 0: a single query was submitted.
    for document, score in zip(documents[0], scores[0], strict=True):
        score_value = float(score)
        # Presumably a non-positive score means no query term matched; skip it.
        if score_value <= 0:
            continue
        results.append((dict(document), score_value))
    return results
|
|
|
|
|
|
def fetch_bm25_corpus_records(session: Session) -> list[dict[str, object]]:
    """Load every chunk that has a source as a plain dict, ordered by chunk id.

    Each dict carries these keys:

    * ``chunk_id``, ``text``, ``page_label``;
    * ``source_title``, ``source_author``;
    * ``chapter_title``;
    * ``bm25_text`` (the chunk's ``search_text``, which is what gets indexed).

    Chunks without a chapter are kept through an outer join, so their
    ``chapter_title`` is None.
    """
    columns = (
        EbookChunk.id.label("chunk_id"),
        EbookChunk.text.label("text"),
        EbookSource.title.label("source_title"),
        EbookSource.author.label("source_author"),
        EbookChapter.title.label("chapter_title"),
        EbookChunk.page_label.label("page_label"),
        EbookChunk.search_text.label("bm25_text"),
    )
    query = select(*columns).select_from(EbookChunk)
    query = query.join(EbookSource, EbookSource.id == EbookChunk.source_id)
    query = query.outerjoin(EbookChapter, EbookChapter.id == EbookChunk.chapter_id)
    query = query.order_by(EbookChunk.id)

    result = session.execute(query)
    return [dict(mapping) for mapping in result.mappings()]
|
|
|
|
|
|
def corpus_last_updated_at(session: Session) -> datetime | None:
    """Return the newest ``updated`` timestamp across sources, chapters and chunks.

    Each table contributes its own ``MAX(updated)`` and the outer ``MAX`` picks
    the latest. SQL ``MAX`` ignores NULLs, so the result is None only when no
    table has a non-NULL ``updated`` value.
    """
    per_table_max = [
        select(func.max(model.updated).label("updated"))
        for model in (EbookSource, EbookChapter, EbookChunk)
    ]
    latest = union_all(*per_table_max).subquery()
    return session.scalar(select(func.max(latest.c.updated)))
|
|
|
|
|
|
def write_bm25_corpus(index_path: Path, records: list[dict[str, object]], manifest: BM25Manifest) -> None:
    """Build a BM25 corpus in a scratch directory, then swap it into ``index_path``.

    Steps:

    1. The new index (when there are records) and its manifest are fully written
       to a unique sibling temporary directory.
    2. Any previous index is moved aside rather than deleted up front.
    3. The new directory is renamed into place. If that rename fails, the
       previous index is restored instead of leaving no index at all.
    4. The old copy is removed only after the new one is in place.

    Args:
        index_path: Final index directory.
        records: Rows from ``fetch_bm25_corpus_records``; ``bm25_text`` is indexed.
        manifest: Metadata written as the index manifest.

    Raises:
        Any error from bm25s or the filesystem. The scratch directory is removed
        before re-raising.
    """
    index_path.parent.mkdir(parents=True, exist_ok=True)
    temp_path = Path(tempfile.mkdtemp(prefix=f"{index_path.name}.", dir=index_path.parent))
    # Unique because temp_path is unique; holds the previous index during the swap.
    backup_path = temp_path.with_name(f"{temp_path.name}.old")
    try:
        if records:
            retriever = bm25s.BM25()
            # NOTE(review): str(None) would index the literal "None" if search_text is nullable -- confirm.
            texts = [str(record["bm25_text"]) for record in records]
            retriever.index(bm25s.tokenize(texts, show_progress=False), show_progress=False)
            retriever.save(temp_path, corpus=records, show_progress=False)
        write_bm25_manifest(temp_path, manifest)
        _swap_in_directory(temp_path, index_path, backup_path)
    except Exception:
        shutil.rmtree(temp_path, ignore_errors=True)
        raise
    # The new index is live; discard the previous one (no-op if there was none).
    shutil.rmtree(backup_path, ignore_errors=True)


def _swap_in_directory(new_path: Path, target_path: Path, backup_path: Path) -> None:
    """Rename ``new_path`` to ``target_path``, parking any existing target at ``backup_path``."""
    had_previous = target_path.exists()
    if had_previous:
        target_path.rename(backup_path)
    try:
        new_path.rename(target_path)
    except OSError:
        # Put the previous index back so readers are not left without one.
        if had_previous:
            backup_path.rename(target_path)
        raise
|
|
|
|
|
|
def read_bm25_manifest(index_path: Path) -> BM25Manifest | None:
    """Read the BM25 manifest if it exists and is valid.

    Returns None in any of these cases:

    * the manifest file is absent or unreadable;
    * it is not valid JSON, or not a JSON object;
    * ``created_at`` or ``chunk_count`` is missing or malformed.

    A damaged index is therefore treated like a missing one:
    ``ensure_bm25_corpus`` rebuilds it, and ``load_bm25_corpus`` raises
    ``BM25CorpusUnavailableError`` instead of a raw parsing error.
    """
    manifest_path = index_path / MANIFEST_NAME
    if not manifest_path.exists():
        return None
    try:
        body = json.loads(manifest_path.read_text(encoding="utf-8"))
        # Indexing a non-object JSON value (list, str, number, null) raises TypeError here.
        created_at = datetime.fromisoformat(str(body["created_at"]))
        raw_db_updated_at = body.get("db_updated_at")
        db_updated_at = datetime.fromisoformat(str(raw_db_updated_at)) if raw_db_updated_at else None
        chunk_count = int(body["chunk_count"])
    except (OSError, ValueError, KeyError, TypeError, AttributeError) as exc:
        # ValueError also covers json.JSONDecodeError and UnicodeDecodeError.
        logger.warning("ebook_bm25_manifest_invalid path=%s error=%s", manifest_path, exc)
        return None
    return BM25Manifest(created_at=created_at, db_updated_at=db_updated_at, chunk_count=chunk_count)
|
|
|
|
|
|
def write_bm25_manifest(index_path: Path, manifest: BM25Manifest) -> None:
    """Serialize ``manifest`` as pretty-printed, key-sorted JSON into ``index_path``.

    Timestamps are stored in ISO-8601 form; ``db_updated_at`` becomes null when absent.
    """
    db_updated = manifest.db_updated_at
    payload = {
        "chunk_count": manifest.chunk_count,
        "created_at": manifest.created_at.isoformat(),
        "db_updated_at": None if not db_updated else db_updated.isoformat(),
    }
    serialized = json.dumps(payload, indent=2, sort_keys=True)
    (index_path / MANIFEST_NAME).write_text(serialized, encoding="utf-8")
|
|
|
|
|
|
def bm25_index_exists(index_path: Path, manifest: BM25Manifest | None) -> bool:
    """Report whether ``index_path`` holds a usable persisted BM25 index.

    A manifest and an existing directory are always required. An index built
    from zero chunks needs nothing more; otherwise every name in
    ``REQUIRED_INDEX_FILES`` must exist inside the directory.
    """
    if manifest is None:
        return False
    if not index_path.is_dir():
        return False
    if manifest.chunk_count == 0:
        return True
    for file_name in REQUIRED_INDEX_FILES:
        if not (index_path / file_name).exists():
            return False
    return True
|