Files
dotfiles/tests/test_ebook_search_guardrails.py
T

142 lines
5.3 KiB
Python

"""Tests for serve-time output guardrails."""
from __future__ import annotations
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from python.ebook_search.api.main import create_app
from python.ebook_search.config import EbookSearchConfig, RerankConfig
from python.ebook_search.guardrails import is_confident, retrieval_confidence, validate_citations
from python.ebook_search.search import SearchResponse, SearchResult
def make_results(count: int, *, vector_score: float = 0.8) -> list[SearchResult]:
    """Build ``count`` search results that all share the same vector score.

    Chunk ids run from ``1`` to ``count`` inclusive, so ``count == 0`` yields an
    empty list.

    Args:
        count: Number of results to build.
        vector_score: Value assigned to both ``score`` and ``vector_score`` of
            every result.

    Returns:
        ``SearchResult`` objects titled ``"Book"`` whose text is
        ``"source text <chunk_id>"``.
    """
    return [
        SearchResult(
            chunk_id=chunk_id,
            text=f"source text {chunk_id}",
            source_title="Book",
            score=vector_score,
            vector_score=vector_score,
        )
        for chunk_id in range(1, count + 1)
    ]
def test_validate_citations_partitions_markers() -> None:
    """In-range markers are reported as cited and make the answer grounded."""
    answer = "Supported by [1] and [2]."
    outcome = validate_citations(answer, result_count=3)

    assert outcome.cited == (1, 2)
    assert outcome.invalid == ()
    assert outcome.grounded is True
def test_validate_citations_flags_out_of_range_marker() -> None:
    """A marker beyond the number of results is invalid rather than cited."""
    answer = "As shown in [5]."
    outcome = validate_citations(answer, result_count=2)

    assert outcome.cited == ()
    assert outcome.invalid == (5,)
    assert outcome.grounded is False
def test_validate_citations_uncited_answer_is_not_grounded() -> None:
    """An answer containing no citation markers at all is not grounded."""
    answer = "No citations at all."
    outcome = validate_citations(answer, result_count=2)

    assert outcome.cited == ()
    assert outcome.invalid == ()
    assert outcome.grounded is False
def test_retrieval_confidence_prefers_rerank_then_vector() -> None:
    """No results score 0.0; otherwise the rerank score wins over the vector score."""
    assert retrieval_confidence([]) == 0.0

    reranked = SearchResult(chunk_id=1, text="t", source_title="B", rerank_score=0.7, vector_score=0.2)
    assert retrieval_confidence([reranked]) == 0.7

    vector_only = SearchResult(chunk_id=1, text="t", source_title="B", vector_score=0.5)
    assert retrieval_confidence([vector_only]) == 0.5
def test_is_confident_against_threshold() -> None:
    """Scores above the configured floor are confident; scores below it are not."""
    threshold_config = EbookSearchConfig(
        rerank=RerankConfig(enabled=False),
        min_retrieval_confidence=0.5,
    )
    cases = ((0.6, True), (0.4, False))
    for score, expected in cases:
        assert is_confident(make_results(1, vector_score=score), threshold_config) is expected
def patch_app_runtime(monkeypatch) -> None:
    """Stub out the app's database dependencies so ``create_app`` runs offline.

    ``get_postgres_engine`` is replaced by a factory that returns a fresh
    in-memory SQLite engine. ``ensure_bm25_corpus`` is replaced by a no-op.
    Both stubs accept any positional/keyword arguments, so they keep working
    if the call sites in ``python.ebook_search.api.main`` change argument style.

    Args:
        monkeypatch: pytest's ``monkeypatch`` fixture, used to install the stubs.
    """
    monkeypatch.setattr(
        "python.ebook_search.api.main.get_postgres_engine",
        lambda *_args, **_kwargs: create_engine("sqlite+pysqlite:///:memory:", future=True),
    )
    monkeypatch.setattr(
        "python.ebook_search.api.main.ensure_bm25_corpus",
        lambda *_args, **_kwargs: None,
    )
def test_low_confidence_skips_answer_generation(monkeypatch) -> None:
    """A weak top hit must skip answer generation and surface a warning instead."""
    answer_calls = []

    def fake_search_ebooks(_engine, query, _config, *, rerank=False):
        del rerank  # accepted only to match the real signature
        weak_hits = make_results(1, vector_score=0.05)
        return SearchResponse(query=query, rank_label="Hybrid", results=weak_hits)

    def fake_answer_query(_query, _results, _config):
        answer_calls.append((_query, _results, _config))
        return "answer"

    monkeypatch.setattr("python.ebook_search.api.routes.search.search_ebooks", fake_search_ebooks)
    monkeypatch.setattr("python.ebook_search.api.routes.search.answer_query", fake_answer_query)
    patch_app_runtime(monkeypatch)

    app = create_app()
    # 0.05 from the fake search is well below this 0.5 floor.
    app.state.config = EbookSearchConfig(
        rerank=RerankConfig(enabled=False),
        answer_enabled=True,
        min_retrieval_confidence=0.5,
    )

    with TestClient(app) as client:
        response = client.post("/search", data={"query": "q"})

    assert response.status_code == 200
    assert answer_calls == []
    assert "Low retrieval confidence" in response.text
def test_invalid_citation_is_flagged(monkeypatch) -> None:
    """An answer citing a non-existent result is shown with an invalid-citation warning."""

    def fake_search_ebooks(_engine, query, _config, *, rerank=False):
        del rerank  # accepted only to match the real signature
        return SearchResponse(
            query=query,
            rank_label="Hybrid",
            results=make_results(2, vector_score=0.9),
        )

    def fake_answer_query(_query, _results, _config):
        # Only two results exist, so marker [9] is out of range.
        return "Per the text [9]."

    monkeypatch.setattr("python.ebook_search.api.routes.search.search_ebooks", fake_search_ebooks)
    monkeypatch.setattr("python.ebook_search.api.routes.search.answer_query", fake_answer_query)
    patch_app_runtime(monkeypatch)

    app = create_app()
    app.state.config = EbookSearchConfig(rerank=RerankConfig(enabled=False), answer_enabled=True)

    with TestClient(app) as client:
        response = client.post("/search", data={"query": "q"})

    assert response.status_code == 200
    assert "Invalid citations" in response.text
    # NOTE(review): weak check. If the answer body is rendered, "9" is already
    # present via "[9]", so this likely does not prove the invalid marker is
    # listed. Consider tightening it against the template's exact markup.
    assert "9" in response.text
def test_grounded_answer_has_no_warning_badge(monkeypatch) -> None:
    """A fully cited answer renders without any grounding warning badges.

    The positive check on the answer text guards against a vacuous pass.
    Without it, the negative assertions would also succeed if the answer
    were never rendered at all (e.g. suppressed by the confidence gate).
    """

    def fake_search_ebooks(_engine, query, _config, *, rerank=False):
        del rerank  # accepted only to match the real signature
        return SearchResponse(query=query, rank_label="Hybrid", results=make_results(2, vector_score=0.9))

    def fake_answer_query(_query, _results, _config):
        # Cites exactly the two results returned above.
        return "Grounded in [1] and [2]."

    monkeypatch.setattr("python.ebook_search.api.routes.search.search_ebooks", fake_search_ebooks)
    monkeypatch.setattr("python.ebook_search.api.routes.search.answer_query", fake_answer_query)
    patch_app_runtime(monkeypatch)

    app = create_app()
    app.state.config = EbookSearchConfig(rerank=RerankConfig(enabled=False), answer_enabled=True)

    with TestClient(app) as client:
        response = client.post("/search", data={"query": "q"})

    assert response.status_code == 200
    assert "Grounded in" in response.text
    assert "Unverified" not in response.text
    assert "Invalid citations" not in response.text