From 7010a4f3b90c524371eefc5cf5cc72cfa92f8f06 Mon Sep 17 00:00:00 2001
From: Richie Cahill
Date: Fri, 12 Jun 2026 02:45:44 -0400
Subject: [PATCH] built rag search setup

---
 python/ebook_search/search.py | 406 ++++++++++++++++++++++++++++++++++
 python/ebook_search/timing.py |  40 ++++
 2 files changed, 446 insertions(+)
 create mode 100644 python/ebook_search/search.py
 create mode 100644 python/ebook_search/timing.py

diff --git a/python/ebook_search/search.py b/python/ebook_search/search.py
new file mode 100644
index 0000000..e25c17d
--- /dev/null
+++ b/python/ebook_search/search.py
@@ -0,0 +1,406 @@
+"""Hybrid search orchestration."""
+
+from __future__ import annotations
+
+import logging
+import re
+from concurrent.futures import ThreadPoolExecutor
+from dataclasses import dataclass, replace
+from typing import TYPE_CHECKING
+
+from pgvector.sqlalchemy import Vector
+from sqlalchemy import literal, select
+from sqlalchemy.orm import Session
+
+from python.ebook_search.bm25_corpus import (
+    load_bm25_corpus,
+    score_bm25_corpus,
+)
+from python.ebook_search.embeddings import MODEL_DIMENSIONS, embed_query, get_embedding_table
+from python.ebook_search.rerank import rerank_chunks
+from python.ebook_search.timing import RuntimeStep, timed_result
+from python.orm.richie import (
+    EbookChapter,
+    EbookChunk,
+    EbookEmbeddingModel,
+    EbookSource,
+)
+
+if TYPE_CHECKING:
+    from collections.abc import Mapping
+
+    from sqlalchemy.engine import Engine
+
+    from python.ebook_search.config import EbookSearchConfig
+
+logger = logging.getLogger(__name__)
+# Number of lexical candidates requested from the BM25 scorer for every query.
+BM25_CANDIDATE_LIMIT = 120
+
+
+@dataclass(frozen=True)
+class SearchResult:
+    """One source chunk returned by search."""
+
+    chunk_id: int
+    text: str
+    source_title: str
+    # Score of the stage that built this result: ``1 - distance`` for vector rows, the BM25
+    # score for lexical rows, and the summed reciprocal ranks after fusion.
+    score: float = 0.0
+    vector_score: float | None = None
+    bm25_score: float | None = None
+    fused_score: float | None = None
+    rerank_score: float | None = None
+    source_author: str | None = None
+    chapter_title: str | None = None
+    page_label: str | None = None
+    rank_source: str = "Hybrid"
+
+
+@dataclass(frozen=True)
+class SearchResponse:
+    """Search output for the UI."""
+
+    query: str
+    results: list[SearchResult]
+    rank_label: str
+    timings: tuple[RuntimeStep, ...] = ()
+
+    @property
+    def total_runtime_ms(self) -> float:
+        """Return the summed duration of the timings flagged ``counts_toward_total``."""
+        return sum(step.duration_ms for step in self.timings if step.counts_toward_total)
+
+
+@dataclass(frozen=True)
+class RetrievalResponse:
+    """Parallel retrieval output for vector and BM25 candidates."""
+
+    vector_results: list[SearchResult]
+    lexical_results: list[SearchResult]
+    timings: tuple[RuntimeStep, ...]
+
+
+def search_ebooks(
+    engine: Engine,
+    query: str,
+    config: EbookSearchConfig,
+    *,
+    rerank: bool = False,
+) -> SearchResponse:
+    """Run hybrid vector/BM25 search and optional reranking.
+
+    Args:
+        engine: Database engine used by the vector search.
+        query: Raw user query; a blank query short-circuits to an empty response.
+        config: Search settings; ``config.rerank.enabled`` gates reranking.
+        rerank: Rerank the fused candidates, provided ``config.rerank.enabled`` is also true.
+
+    Returns:
+        A response holding at most ``config.top_k`` results and the recorded step timings.
+    """
+    if not query.strip():
+        logger.info("ebook_search_empty_query")
+        return SearchResponse(query=query, results=[], rank_label="Hybrid")
+
+    logger.info("ebook_search_start query_length=%s rerank=%s", len(query), rerank)
+    timings: list[RuntimeStep] = []
+    retrieval_query, timing = timed_result("Query preparation", retrieval_query_from_text, query)
+    timings.append(timing)
+    retrieval, timing = timed_result(
+        "Hybrid retrieval",
+        parallel_retrieval,
+        engine,
+        retrieval_query,
+        config,
+    )
+    # The per-branch timings are flagged ``counts_toward_total=False``; only the wrapper counts.
+    timings.extend(retrieval.timings)
+    timings.append(timing)
+    fused, timing = timed_result(
+        "Reciprocal rank fusion",
+        reciprocal_rank_fusion,
+        retrieval.vector_results,
+        retrieval.lexical_results,
+    )
+    timings.append(timing)
+    # Both branches below receive the original query, not the stop-word-stripped retrieval query.
+    if config.rerank.enabled and rerank:
+        response, timing = timed_result("Rerank", apply_rerank, query, fused, config)
+    else:
+        response, timing = timed_result("Rerank skipped", skip_rerank, query, fused, config)
+    timings.append(timing)
+    response = replace(response, timings=tuple(timings))
+    logger.info(
+        "ebook_search_complete vector_candidates=%s lexical_candidates=%s "
+        "fused_candidates=%s returned=%s rank_label=%s runtime_ms=%.1f",
+        len(retrieval.vector_results),
+        len(retrieval.lexical_results),
+        len(fused),
+        len(response.results),
+        response.rank_label,
+        response.total_runtime_ms,
+    )
+    return response
+
+
+def parallel_retrieval(engine: Engine, query: str, config: EbookSearchConfig) -> RetrievalResponse:
+    """Run vector and BM25 candidate retrieval concurrently on two worker threads."""
+    with ThreadPoolExecutor(max_workers=2, thread_name_prefix="ebook-search") as executor:
+        vector_future = executor.submit(
+            timed_result,
+            "Embedding + vector search",
+            vector_candidates,
+            engine,
+            query,
+            config,
+        )
+        bm25_future = executor.submit(
+            timed_result,
+            "BM25 search",
+            bm25_candidates,
+            query,
+            config,
+        )
+        vector_results, vector_timing = vector_future.result()
+        lexical_results, lexical_timing = bm25_future.result()
+
+    logger.info(
+        "ebook_parallel_retrieval_complete vector_candidates=%s lexical_candidates=%s",
+        len(vector_results),
+        len(lexical_results),
+    )
+    return RetrievalResponse(
+        vector_results=vector_results,
+        lexical_results=lexical_results,
+        # Flagged so ``total_runtime_ms`` skips them; the "Hybrid retrieval" step covers both.
+        timings=(
+            replace(vector_timing, counts_toward_total=False),
+            replace(lexical_timing, counts_toward_total=False),
+        ),
+    )
+
+
+def skip_rerank(
+    query: str,
+    candidates: list[SearchResult],
+    config: EbookSearchConfig,
+) -> SearchResponse:
+    """Return fused hybrid results without reranking."""
+    logger.info("ebook_rerank_skipped candidates=%s", len(candidates))
+    return SearchResponse(query=query, results=candidates[: config.top_k], rank_label="Hybrid")
+
+
+def apply_rerank(
+    query: str,
+    candidates: list[SearchResult],
+    config: EbookSearchConfig,
+) -> SearchResponse:
+    """Rerank already-fused hybrid candidates."""
+    reranked = rerank_chunks(query, candidates[: config.rerank.candidates], config.rerank)
+    logger.info(
+        "ebook_rerank_complete input_candidates=%s returned=%s",
+        min(len(candidates), config.rerank.candidates),
+        len(reranked),
+    )
+    return SearchResponse(
+        query=query,
+        results=[replace(result, rank_source="Hybrid + rerank") for result in reranked[: config.top_k]],
+        rank_label="Hybrid + rerank",
+    )
+
+
+def vector_candidates(engine: Engine, query: str, config: EbookSearchConfig) -> list[SearchResult]:
+    """Return pgvector cosine candidates for a normalized query.
+
+    Raises:
+        ValueError: If the configured embedding model has no database row, or if the row's
+            dimension differs from the ``MODEL_DIMENSIONS`` entry for that model.
+    """
+    with Session(engine) as session:
+        model = session.scalar(select(EbookEmbeddingModel).where(EbookEmbeddingModel.name == config.embedding_model))
+        if model is None:
+            msg = f"Embedding model is not registered: {config.embedding_model}"
+            raise ValueError(msg)
+
+        expected_dimension = MODEL_DIMENSIONS[config.embedding_model]
+        if model.dimension != expected_dimension:
+            msg = f"Model row dimension {model.dimension} does not match configured dimension {expected_dimension}"
+            raise ValueError(msg)
+
+        embedding = embed_query(query, config)
+        limit = max(config.rerank.candidates, config.top_k) * 4
+        embedding_table = get_embedding_table(model.dimension)
+
+        embedding_param = literal(embedding, type_=Vector(model.dimension))
+        # ``<=>`` is pgvector's cosine-distance operator, so ``1 - distance`` is the cosine similarity.
+        distance = embedding_table.embedding.op("<=>")(embedding_param)
+        score = (literal(1.0) - distance).label("score")
+        statement = (
+            select(
+                EbookChunk.id.label("chunk_id"),
+                EbookChunk.text.label("text"),
+                EbookSource.title.label("source_title"),
+                EbookSource.author.label("source_author"),
+                EbookChapter.title.label("chapter_title"),
+                EbookChunk.page_label.label("page_label"),
+                score,
+            )
+            .select_from(embedding_table)
+            .join(EbookChunk, EbookChunk.id == embedding_table.chunk_id)
+            .join(EbookSource, EbookSource.id == EbookChunk.source_id)
+            .outerjoin(EbookChapter, EbookChapter.id == EbookChunk.chapter_id)
+            .where(embedding_table.model_id == model.id)
+            .order_by(distance)
+            .limit(limit)
+        )
+        rows = session.execute(statement).mappings()
+        results = [search_result_from_row(row) for row in rows]
+        logger.info(
+            "ebook_vector_search_complete model=%s dimension=%s candidates=%s",
+            config.embedding_model,
+            model.dimension,
+            len(results),
+        )
+        return results
+
+
+def bm25_candidates(query: str, config: EbookSearchConfig) -> list[SearchResult]:
+    """Return BM25-ranked lexical candidates using the persisted corpus."""
+    corpus = load_bm25_corpus(config)
+    if not corpus.records:
+        logger.info("ebook_bm25_search_complete corpus=0 candidates=0")
+        return []
+
+    scored_records = score_bm25_corpus(query, corpus, limit=BM25_CANDIDATE_LIMIT)
+    results = [
+        replace(search_result_from_row(record), score=score, vector_score=None, bm25_score=score)
+        for record, score in scored_records
+    ]
+
+    # NOTE(review): assumes ``score_bm25_corpus`` returns records best-first - confirm.
+    max_score = results[0].bm25_score if results else 0.0
+    logger.info(
+        "ebook_bm25_search_complete corpus=%s candidates=%s max_score=%.6f",
+        len(corpus.records),
+        len(results),
+        max_score,
+    )
+    return results
+
+
+def reciprocal_rank_fusion(
+    vector_results: list[SearchResult],
+    lexical_results: list[SearchResult],
+    *,
+    rank_constant: int = 60,
+) -> list[SearchResult]:
+    """Fuse vector and lexical rankings with Reciprocal Rank Fusion.
+
+    Every list a chunk appears in contributes ``1 / (rank_constant + rank)`` to its score,
+    with ranks starting at 1; the fused results are returned highest score first.
+    """
+    by_chunk: dict[int, SearchResult] = {}
+    scores: dict[int, float] = {}
+    vector_scores: dict[int, float] = {}
+    bm25_scores: dict[int, float] = {}
+
+    for rank, result in enumerate(vector_results, start=1):
+        # ``setdefault`` keeps the first row seen, so the vector row wins for shared chunks.
+        by_chunk.setdefault(result.chunk_id, result)
+        vector_scores[result.chunk_id] = result.vector_score if result.vector_score is not None else result.score
+        scores[result.chunk_id] = scores.get(result.chunk_id, 0.0) + (1 / (rank_constant + rank))
+
+    for rank, result in enumerate(lexical_results, start=1):
+        by_chunk.setdefault(result.chunk_id, result)
+        bm25_scores[result.chunk_id] = result.bm25_score if result.bm25_score is not None else result.score
+        scores[result.chunk_id] = scores.get(result.chunk_id, 0.0) + (1 / (rank_constant + rank))
+
+    return sorted(
+        (
+            replace(
+                result,
+                score=scores[result.chunk_id],
+                vector_score=vector_scores.get(result.chunk_id),
+                bm25_score=bm25_scores.get(result.chunk_id),
+                fused_score=scores[result.chunk_id],
+                rank_source="Hybrid",
+            )
+            for result in by_chunk.values()
+        ),
+        key=lambda result: result.score,
+        reverse=True,
+    )
+
+
+def search_result_from_row(row: Mapping[str, object]) -> SearchResult:
+    """Convert a database row mapping into a search result."""
+    return SearchResult(
+        chunk_id=int(row["chunk_id"]),
+        text=str(row["text"]),
+        source_title=str(row["source_title"]),
+        source_author=optional_str(row["source_author"]),
+        chapter_title=optional_str(row["chapter_title"]),
+        page_label=optional_str(row["page_label"]),
+        # "score" is optional in the mapping; when present it fills both score fields.
+        score=float(row["score"]) if "score" in row else 0.0,
+        vector_score=float(row["score"]) if "score" in row else None,
+    )
+
+
+def optional_str(value: object) -> str | None:
+    """Convert nullable database values to optional strings."""
+    if value is None:
+        return None
+    return str(value)
+
+
+# ``\w`` is Unicode-aware for ``str`` patterns, so accented and non-Latin words stay whole; an
+# ASCII-only class would cut such words at every non-ASCII letter.
+TOKEN_RE = re.compile(r"\w+")
+
+
+def tokens(text_value: str) -> list[str]:
+    """Extract lowercase word tokens from a text value.
+
+    This is a simple approximation of the tokenization used by PostgreSQL's full-text search.
+    A token is a maximal run of word characters (letters and digits of any script, plus
+    underscores); everything else separates tokens.
+    """
+    return [match.group(0).lower() for match in TOKEN_RE.finditer(text_value)]
+
+
+QUERY_STOP_WORDS = {
+    "a",
+    "an",
+    "and",
+    "are",
+    "as",
+    "at",
+    "does",
+    "for",
+    "in",
+    "is",
+    "of",
+    "the",
+    "to",
+    "what",
+    "when",
+    "where",
+    "which",
+    "who",
+    "why",
+}
+
+
+def retrieval_query_from_text(query: str) -> str:
+    """Remove generic question words while preserving entity and series terms.
+
+    The surviving tokens are lowercased and joined with single spaces. If no token survives,
+    the original query is returned unchanged.
+    """
+    keywords = [token for token in tokens(query) if token not in QUERY_STOP_WORDS]
+    if not keywords:
+        return query
+    return " ".join(keywords)
diff --git a/python/ebook_search/timing.py b/python/ebook_search/timing.py
new file mode 100644
index 0000000..eb8e474
--- /dev/null
+++ b/python/ebook_search/timing.py
@@ -0,0 +1,40 @@
+"""Runtime timing helpers for EPUB search."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from time import perf_counter
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
+
+@dataclass(frozen=True)
+class RuntimeStep:
+    """Elapsed runtime for one named search step."""
+
+    name: str
+    duration_ms: float
+    # ``False`` marks a step that should be left out when step durations are summed.
+    counts_toward_total: bool = True
+
+
+def runtime_step_from_start(name: str, start_seconds: float) -> RuntimeStep:
+    """Create a runtime step, in milliseconds, from a prior perf_counter timestamp."""
+    return RuntimeStep(name=name, duration_ms=(perf_counter() - start_seconds) * 1000)
+
+
+def timed_result[T, **P](
+    name: str,
+    operation: Callable[P, T],
+    *args: P.args,
+    **kwargs: P.kwargs,
+) -> tuple[T, RuntimeStep]:
+    """Run an operation and return its result plus elapsed runtime.
+
+    An exception raised by ``operation`` propagates and no runtime step is produced.
+    """
+    start_seconds = perf_counter()
+    result = operation(*args, **kwargs)
+    return result, runtime_step_from_start(name, start_seconds)