dbc6b5b53b
Move ebook search tests into tests/ebook_search and standardize mocking on pytest-mock.
51 lines
2.1 KiB
Python
51 lines
2.1 KiB
Python
"""Tests for the ebook search RAG pipeline orchestration."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from threading import Event
|
|
from typing import TYPE_CHECKING
|
|
|
|
from sqlalchemy import create_engine
|
|
|
|
from python.ebook_search.config import EbookSearchConfig, RerankConfig
|
|
from python.ebook_search.search import SearchResult, search_ebooks
|
|
|
|
if TYPE_CHECKING:
|
|
from pytest_mock import MockerFixture
|
|
|
|
|
|
def test_search_ebooks_runs_vector_and_bm25_in_parallel(mocker: MockerFixture) -> None:
    """Vector and BM25 retrieval must overlap in time, and the timings must say so.

    Each fake retriever signals that it has started, then blocks (up to two
    seconds) until the *other* fake has signalled as well. If ``search_ebooks``
    ran the two retrievers one after the other, the first fake would time out
    and its assertion would fail, so a pass demonstrates concurrent execution.
    """
    sqlite_engine = create_engine("sqlite+pysqlite:///:memory:", future=True)
    vector_ready = Event()
    bm25_ready = Event()
    seen_engines: list[object] = []

    def rendezvous(mine: Event, theirs: Event) -> None:
        """Announce that this retriever has started, then wait for its counterpart."""
        mine.set()
        assert theirs.wait(timeout=2)

    def fake_vector_candidates(received_engine, query, _config):
        """Record the engine, check the raw query, then sync with the BM25 fake."""
        seen_engines.append(received_engine)
        assert query == "what is parallel"
        rendezvous(vector_ready, bm25_ready)
        return [
            SearchResult(chunk_id=1, text="vector", source_title="Vector", vector_score=0.9),
        ]

    def fake_bm25_candidates(query, _config):
        """Check the query BM25 receives, then sync with the vector fake."""
        # search_ebooks is expected to hand BM25 a transformed query
        # ("what is parallel" -> "parallel"), unlike the vector path.
        assert query == "parallel"
        rendezvous(bm25_ready, vector_ready)
        return [
            SearchResult(chunk_id=2, text="bm25", source_title="BM25", bm25_score=2.0),
        ]

    mocker.patch("python.ebook_search.search.vector_candidates", side_effect=fake_vector_candidates)
    mocker.patch("python.ebook_search.search.bm25_candidates", side_effect=fake_bm25_candidates)

    response = search_ebooks(
        sqlite_engine,
        "what is parallel",
        EbookSearchConfig(rerank=RerankConfig(enabled=False)),
    )

    steps_by_name = {step.name: step for step in response.timings}
    # The per-branch steps are expected NOT to count toward the total while the
    # enclosing "Hybrid retrieval" step does -- presumably so overlapping
    # wall-clock time is not double-counted.
    expected_counts = {
        "Embedding + vector search": False,
        "BM25 search": False,
        "Hybrid retrieval": True,
        "BM25 query preparation": True,
    }

    assert [hit.chunk_id for hit in response.results] == [1, 2]
    for step_name, counted in expected_counts.items():
        assert steps_by_name[step_name].counts_toward_total is counted
    # Only the vector path receives the engine; it must be the one passed in.
    assert seen_engines == [sqlite_engine]
|