"""Hybrid search orchestration.""" from __future__ import annotations import logging import re from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, replace from typing import TYPE_CHECKING from pgvector.sqlalchemy import Vector from sqlalchemy import literal, select from sqlalchemy.orm import Session from python.ebook_search.bm25_corpus import ( load_bm25_corpus, score_bm25_corpus, ) from python.ebook_search.embeddings import MODEL_DIMENSIONS, embed_query, get_embedding_table from python.ebook_search.rerank import rerank_chunks from python.ebook_search.timing import RuntimeStep, timed_result from python.orm.richie import ( EbookChapter, EbookChunk, EbookEmbeddingModel, EbookSource, ) if TYPE_CHECKING: from collections.abc import Mapping from sqlalchemy.engine import Engine from python.ebook_search.config import EbookSearchConfig logger = logging.getLogger(__name__) BM25_CANDIDATE_LIMIT = 120 @dataclass(frozen=True) class SearchResult: """One source chunk returned by search.""" chunk_id: int text: str source_title: str score: float = 0.0 vector_score: float | None = None bm25_score: float | None = None fused_score: float | None = None rerank_score: float | None = None source_author: str | None = None chapter_title: str | None = None page_label: str | None = None rank_source: str = "Hybrid" @dataclass(frozen=True) class SearchResponse: """Search output for the UI.""" query: str results: list[SearchResult] rank_label: str timings: tuple[RuntimeStep, ...] = () @property def total_runtime_ms(self) -> float: """Return total measured runtime for the response.""" return sum(step.duration_ms for step in self.timings if step.counts_toward_total) @dataclass(frozen=True) class RetrievalResponse: """Parallel retrieval output for vector and BM25 candidates.""" vector_results: list[SearchResult] lexical_results: list[SearchResult] timings: tuple[RuntimeStep, ...] 
def search_ebooks(
    engine: Engine,
    query: str,
    config: EbookSearchConfig,
    *,
    rerank: bool = False,
) -> SearchResponse:
    """Run the hybrid vector/BM25 pipeline, optionally followed by reranking.

    Stages, each wrapped in ``timed_result``: BM25 query preparation,
    concurrent retrieval, reciprocal rank fusion, then either reranking or a
    plain truncation to ``config.top_k``.

    Args:
        engine: Database engine handed to the vector retrieval branch.
        query: Natural-language query as typed by the user.
        config: Search configuration (embedding model, ``top_k``, rerank settings).
        rerank: Caller's request to rerank; honoured only when
            ``config.rerank.enabled`` is also true.

    Returns:
        The final response with all stage timings attached. A query that is
        empty or whitespace-only short-circuits to an empty "Hybrid" response
        with no timings.
    """
    if not query.strip():
        logger.info("ebook_search_empty_query")
        return SearchResponse(query=query, results=[], rank_label="Hybrid")

    logger.info("ebook_search_start query_length=%s rerank=%s", len(query), rerank)

    bm25_query, prepare_step = timed_result("BM25 query preparation", retrieval_query_from_text, query)
    retrieval, retrieval_step = timed_result(
        "Hybrid retrieval",
        parallel_retrieval,
        engine,
        query,
        bm25_query,
        config,
    )
    fused, fusion_step = timed_result(
        "Reciprocal rank fusion",
        reciprocal_rank_fusion,
        retrieval.vector_results,
        retrieval.lexical_results,
    )

    # Both final stages share a signature, so pick the label/handler pair once.
    if config.rerank.enabled and rerank:
        final_label, final_stage = "Rerank", apply_rerank
    else:
        final_label, final_stage = "Rerank skipped", skip_rerank
    ranked, final_step = timed_result(final_label, final_stage, query, fused, config)

    # Per-branch retrieval timings are listed ahead of the enclosing
    # "Hybrid retrieval" step that contains them.
    steps = (prepare_step, *retrieval.timings, retrieval_step, fusion_step, final_step)
    response = replace(ranked, timings=steps)
    logger.info(
        "ebook_search_complete vector_candidates=%s lexical_candidates=%s "
        "fused_candidates=%s returned=%s rank_label=%s runtime_ms=%.1f",
        len(retrieval.vector_results),
        len(retrieval.lexical_results),
        len(fused),
        len(response.results),
        response.rank_label,
        response.total_runtime_ms,
    )
    return response


def parallel_retrieval(
    engine: Engine,
    vector_query: str,
    bm25_query: str,
    config: EbookSearchConfig,
) -> RetrievalResponse:
    """Run vector and BM25 candidate retrieval concurrently on two worker threads.

    Only the vector branch receives the database engine; the BM25 branch is
    called with just the prepared query and the config.

    Args:
        engine: Database engine for the vector branch.
        vector_query: Query text to embed for the vector branch.
        bm25_query: Prepared keyword query for the BM25 branch.
        config: Search configuration shared by both branches.

    Returns:
        Both candidate lists plus one timing per branch. The branch timings
        are marked ``counts_toward_total=False``; the caller records a
        separate step around this whole function.
    """
    with ThreadPoolExecutor(max_workers=2, thread_name_prefix="ebook-search") as pool:
        pending_vector = pool.submit(
            timed_result,
            "Embedding + vector search",
            vector_candidates,
            engine,
            vector_query,
            config,
        )
        pending_lexical = pool.submit(
            timed_result,
            "BM25 search",
            bm25_candidates,
            bm25_query,
            config,
        )

    # Leaving the executor context waits for both workers, so these calls
    # return (or re-raise the worker's exception) without blocking.
    vector_hits, vector_step = pending_vector.result()
    lexical_hits, lexical_step = pending_lexical.result()

    logger.info(
        "ebook_parallel_retrieval_complete vector_candidates=%s lexical_candidates=%s",
        len(vector_hits),
        len(lexical_hits),
    )
    branch_steps = tuple(replace(step, counts_toward_total=False) for step in (vector_step, lexical_step))
    return RetrievalResponse(
        vector_results=vector_hits,
        lexical_results=lexical_hits,
        timings=branch_steps,
    )


def skip_rerank(
    query: str,
    candidates: list[SearchResult],
    config: EbookSearchConfig,
) -> SearchResponse:
    """Return the first ``config.top_k`` fused candidates unchanged, labelled "Hybrid"."""
    logger.info("ebook_rerank_skipped candidates=%s", len(candidates))
    top_hits = candidates[: config.top_k]
    return SearchResponse(query=query, results=top_hits, rank_label="Hybrid")


def apply_rerank(
    query: str,
    candidates: list[SearchResult],
    config: EbookSearchConfig,
) -> SearchResponse:
    """Rerank the leading fused candidates and keep the best ``config.top_k``.

    Only the first ``config.rerank.candidates`` fused results are handed to
    ``rerank_chunks``; every returned result is relabelled "Hybrid + rerank".
    """
    rerank_settings = config.rerank
    shortlist = candidates[: rerank_settings.candidates]
    reranked = rerank_chunks(query, shortlist, rerank_settings)
    logger.info(
        "ebook_rerank_complete input_candidates=%s returned=%s",
        min(len(candidates), config.rerank.candidates),
        len(reranked),
    )
    relabelled = [replace(hit, rank_source="Hybrid + rerank") for hit in reranked[: config.top_k]]
    return SearchResponse(query=query, results=relabelled, rank_label="Hybrid + rerank")


def vector_candidates(engine: Engine, query: str, config: EbookSearchConfig) -> list[SearchResult]:
    """Fetch the nearest chunks to the query embedding using pgvector's ``<=>`` operator.

    Args:
        engine: Database engine used to open a session for this lookup.
        query: Natural-language query to embed.
        config: Supplies the embedding model name and the sizing knobs.

    Returns:
        Results ordered by ascending distance, each scored as ``1.0 - distance``.

    Raises:
        ValueError: If the configured model has no registry row, or its stored
            dimension differs from ``MODEL_DIMENSIONS`` for that model.
    """
    with Session(engine) as session:
        model_name = config.embedding_model
        model = session.scalar(select(EbookEmbeddingModel).where(EbookEmbeddingModel.name == model_name))
        if model is None:
            msg = f"Embedding model is not registered: {model_name}"
            raise ValueError(msg)

        expected_dimension = MODEL_DIMENSIONS[model_name]
        if model.dimension != expected_dimension:
            msg = f"Model row dimension {model.dimension} does not match configured dimension {expected_dimension}"
            raise ValueError(msg)

        query_embedding = embed_query(query, config)
        # Over-fetch: four times the larger of the rerank pool and the page size.
        row_limit = max(config.rerank.candidates, config.top_k) * 4
        table = get_embedding_table(model.dimension)
        bound_embedding = literal(query_embedding, type_=Vector(model.dimension))
        distance = table.embedding.op("<=>")(bound_embedding)
        similarity = (literal(1.0) - distance).label("score")

        # Labels must match the keys read by search_result_from_row.
        columns = (
            EbookChunk.id.label("chunk_id"),
            EbookChunk.text.label("text"),
            EbookSource.title.label("source_title"),
            EbookSource.author.label("source_author"),
            EbookChapter.title.label("chapter_title"),
            EbookChunk.page_label.label("page_label"),
            similarity,
        )
        statement = (
            select(*columns)
            .select_from(table)
            .join(EbookChunk, EbookChunk.id == table.chunk_id)
            .join(EbookSource, EbookSource.id == EbookChunk.source_id)
            # Outer join: a chunk without a chapter row is still returned.
            .outerjoin(EbookChapter, EbookChapter.id == EbookChunk.chapter_id)
            .where(table.model_id == model.id)
            .order_by(distance)
            .limit(row_limit)
        )
        results = [search_result_from_row(row) for row in session.execute(statement).mappings()]

        logger.info(
            "ebook_vector_search_complete model=%s dimension=%s candidates=%s",
            model_name,
            model.dimension,
            len(results),
        )
        return results


def bm25_candidates(query: str, config: EbookSearchConfig) -> list[SearchResult]:
    """Score the persisted BM25 corpus against the query and return the top candidates.

    At most ``BM25_CANDIDATE_LIMIT`` records are requested from the scorer.
    An empty corpus yields an empty list. Each result carries the BM25 score
    as both ``score`` and ``bm25_score``, with ``vector_score`` cleared.
    """
    corpus = load_bm25_corpus(config)
    if not corpus.records:
        logger.info("ebook_bm25_search_complete corpus=0 candidates=0")
        return []

    scored = score_bm25_corpus(query, corpus, limit=BM25_CANDIDATE_LIMIT)
    results = [
        replace(search_result_from_row(record), score=lexical_score, vector_score=None, bm25_score=lexical_score)
        for record, lexical_score in scored
    ]
    # Logs the first result's score as the maximum; assumes the scorer
    # returns records best-first (not visible from this module).
    top_score = results[0].bm25_score if results else 0.0
    logger.info(
        "ebook_bm25_search_complete corpus=%s candidates=%s max_score=%.6f",
        len(corpus.records),
        len(results),
        top_score,
    )
    return results


def reciprocal_rank_fusion(
    vector_results: list[SearchResult],
    lexical_results: list[SearchResult],
    *,
    rank_constant: int = 60,
) -> list[SearchResult]:
    """Merge the vector and lexical rankings with Reciprocal Rank Fusion.

    Every appearance of a chunk at 1-based position ``rank`` in either list
    adds ``1 / (rank_constant + rank)`` to that chunk's fused score.

    Args:
        vector_results: Vector candidates, best first.
        lexical_results: BM25 candidates, best first.
        rank_constant: Added to the rank in the denominator of each contribution.

    Returns:
        One result per distinct chunk, sorted by fused score descending. Ties
        keep first-seen order (vector list first). Each result has ``score``
        and ``fused_score`` set to the fused value, and ``vector_score`` /
        ``bm25_score`` set to that chunk's raw score from the respective list,
        or ``None`` if it was absent from that list.
    """
    first_seen: dict[int, SearchResult] = {}
    fused_scores: dict[int, float] = {}
    vector_scores: dict[int, float] = {}
    bm25_scores: dict[int, float] = {}

    # Each ranking is paired with the dict collecting its raw scores and the
    # attribute holding that ranking's own score on a result.
    rankings = (
        (vector_results, vector_scores, "vector_score"),
        (lexical_results, bm25_scores, "bm25_score"),
    )
    for ranking, raw_scores, score_field in rankings:
        for position, candidate in enumerate(ranking, start=1):
            chunk_id = candidate.chunk_id
            if chunk_id not in first_seen:
                first_seen[chunk_id] = candidate
            own_score = getattr(candidate, score_field)
            # Fall back to the generic score when the stage-specific one is unset.
            raw_scores[chunk_id] = candidate.score if own_score is None else own_score
            fused_scores[chunk_id] = fused_scores.get(chunk_id, 0.0) + (1 / (rank_constant + position))

    fused = [
        replace(
            candidate,
            score=fused_scores[chunk_id],
            vector_score=vector_scores.get(chunk_id),
            bm25_score=bm25_scores.get(chunk_id),
            fused_score=fused_scores[chunk_id],
            rank_source="Hybrid",
        )
        for chunk_id, candidate in first_seen.items()
    ]
    fused.sort(key=lambda item: item.score, reverse=True)
    return fused


def search_result_from_row(row: Mapping[str, object]) -> SearchResult:
    """Build a ``SearchResult`` from a row-like mapping.

    The keys ``chunk_id``, ``text``, ``source_title``, ``source_author``,
    ``chapter_title`` and ``page_label`` are required. ``score`` is optional:
    when present it fills both ``score`` and ``vector_score``; when absent
    they are ``0.0`` and ``None``.
    """
    chunk_id = int(row["chunk_id"])
    text = str(row["text"])
    source_title = str(row["source_title"])
    source_author = optional_str(row["source_author"])
    chapter_title = optional_str(row["chapter_title"])
    page_label = optional_str(row["page_label"])
    similarity = float(row["score"]) if "score" in row else None
    return SearchResult(
        chunk_id=chunk_id,
        text=text,
        source_title=source_title,
        source_author=source_author,
        chapter_title=chapter_title,
        page_label=page_label,
        score=0.0 if similarity is None else similarity,
        vector_score=similarity,
    )


def optional_str(value: object) -> str | None:
    """Return ``str(value)``, passing ``None`` through unchanged."""
    return None if value is None else str(value)


TOKEN_RE = re.compile(r"[A-Za-z0-9_]+")


def tokens(text_value: str) -> list[str]:
    """Split text into lowercase tokens.

    A token is a maximal run of ASCII letters, digits, or underscores; every
    other character acts as a separator. Intended as a simple approximation
    of PostgreSQL full-text tokenization, adequate for BM25 candidate
    retrieval.
    """
    return [word.lower() for word in TOKEN_RE.findall(text_value)]


QUERY_STOP_WORDS = {
    "a",
    "an",
    "and",
    "are",
    "as",
    "at",
    "does",
    "for",
    "in",
    "is",
    "of",
    "the",
    "to",
    "what",
    "when",
    "where",
    "which",
    "who",
    "why",
}


def retrieval_query_from_text(query: str) -> str:
    """Reduce a question to its non-stop-word tokens, joined by single spaces.

    If every token is a stop word, or the query has no tokens at all, the
    original query string is returned untouched.
    """
    keywords = [word for word in tokens(query) if word not in QUERY_STOP_WORDS]
    return " ".join(keywords) if keywords else query