142 lines
5.3 KiB
Python
142 lines
5.3 KiB
Python
"""Tests for serve-time output guardrails."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from fastapi.testclient import TestClient
|
|
from sqlalchemy import create_engine
|
|
|
|
from python.ebook_search.api.main import create_app
|
|
from python.ebook_search.config import EbookSearchConfig, RerankConfig
|
|
from python.ebook_search.guardrails import is_confident, retrieval_confidence, validate_citations
|
|
from python.ebook_search.search import SearchResponse, SearchResult
|
|
|
|
|
|
def make_results(count: int, *, vector_score: float = 0.8) -> list[SearchResult]:
    """Build ``count`` synthetic search hits with chunk ids numbered from 1.

    Every hit carries ``source_title="Book"`` and uses ``vector_score`` for both
    its ``score`` and ``vector_score`` fields, so the hits differ only in their
    chunk id and text.
    """

    def build(chunk_number: int) -> SearchResult:
        # Text mirrors the chunk id so individual hits are distinguishable.
        return SearchResult(
            chunk_id=chunk_number,
            text=f"source text {chunk_number}",
            source_title="Book",
            score=vector_score,
            vector_score=vector_score,
        )

    return [build(chunk_number) for chunk_number in range(1, count + 1)]
|
|
|
|
|
|
def test_validate_citations_partitions_markers() -> None:
    """Markers within the result range are reported as cited and the answer is grounded."""
    answer = "Supported by [1] and [2]."
    citations = validate_citations(answer, result_count=3)

    assert citations.cited == (1, 2)
    assert citations.invalid == ()
    assert citations.grounded is True
|
|
|
|
|
|
def test_validate_citations_flags_out_of_range_marker() -> None:
    """A marker beyond ``result_count`` lands in ``invalid`` and the answer is not grounded."""
    answer = "As shown in [5]."
    # Only two results exist, so marker 5 cannot refer to any of them.
    citations = validate_citations(answer, result_count=2)

    assert citations.cited == ()
    assert citations.invalid == (5,)
    assert citations.grounded is False
|
|
|
|
|
|
def test_validate_citations_uncited_answer_is_not_grounded() -> None:
    """An answer with no markers at all is neither cited nor invalid, and not grounded."""
    answer = "No citations at all."
    citations = validate_citations(answer, result_count=2)

    assert citations.cited == ()
    assert citations.invalid == ()
    assert citations.grounded is False
|
|
|
|
|
|
def test_retrieval_confidence_prefers_rerank_then_vector() -> None:
    """Confidence is 0.0 for no hits, the rerank score when present, else the vector score."""
    assert retrieval_confidence([]) == 0.0

    # Rerank score (0.7) must win over the lower vector score (0.2).
    reranked_hit = SearchResult(chunk_id=1, text="t", source_title="B", rerank_score=0.7, vector_score=0.2)
    assert retrieval_confidence([reranked_hit]) == 0.7

    # Without a rerank score, the vector score is used.
    vector_only_hit = SearchResult(chunk_id=1, text="t", source_title="B", vector_score=0.5)
    assert retrieval_confidence([vector_only_hit]) == 0.5
|
|
|
|
|
|
def test_is_confident_against_threshold() -> None:
    """Scores above ``min_retrieval_confidence`` are confident; scores below it are not."""
    threshold_config = EbookSearchConfig(rerank=RerankConfig(enabled=False), min_retrieval_confidence=0.5)

    for top_score, expected in ((0.6, True), (0.4, False)):
        hits = make_results(1, vector_score=top_score)
        assert is_confident(hits, threshold_config) is expected
|
|
|
|
|
|
def patch_app_runtime(monkeypatch) -> None:
    """Stub the app module's database hooks for tests.

    Replaces ``get_postgres_engine`` with a factory that returns a fresh
    in-memory SQLite engine, and turns ``ensure_bm25_corpus`` into a no-op.
    """

    def in_memory_engine(**_kwargs):
        # Keyword arguments meant for the real Postgres factory are ignored.
        return create_engine("sqlite+pysqlite:///:memory:", future=True)

    def skip_bm25_corpus(_session, _config):
        return None

    monkeypatch.setattr("python.ebook_search.api.main.get_postgres_engine", in_memory_engine)
    monkeypatch.setattr("python.ebook_search.api.main.ensure_bm25_corpus", skip_bm25_corpus)
|
|
|
|
|
|
def test_low_confidence_skips_answer_generation(monkeypatch) -> None:
    """A weak top hit must suppress answer generation and show the low-confidence notice."""
    answer_calls: list[tuple] = []

    def fake_search_ebooks(_engine, query, _config, *, rerank=False):
        del rerank
        # 0.05 is far below the 0.5 threshold configured on the app below.
        weak_hits = make_results(1, vector_score=0.05)
        return SearchResponse(query=query, rank_label="Hybrid", results=weak_hits)

    def fake_answer_query(_query, _results, _config):
        # Record the call; the test asserts this never happens.
        answer_calls.append((_query, _results, _config))
        return "answer"

    monkeypatch.setattr("python.ebook_search.api.routes.search.search_ebooks", fake_search_ebooks)
    monkeypatch.setattr("python.ebook_search.api.routes.search.answer_query", fake_answer_query)
    patch_app_runtime(monkeypatch)

    app = create_app()
    # Override the config only after create_app() has built the app.
    app.state.config = EbookSearchConfig(
        rerank=RerankConfig(enabled=False),
        answer_enabled=True,
        min_retrieval_confidence=0.5,
    )

    with TestClient(app) as client:
        response = client.post("/search", data={"query": "q"})

    assert response.status_code == 200
    assert not answer_calls
    assert "Low retrieval confidence" in response.text
|
|
|
|
|
|
def test_invalid_citation_is_flagged(monkeypatch) -> None:
    """An answer citing a marker past the result count is rendered with an invalid-citation flag."""

    def fake_search_ebooks(_engine, query, _config, *, rerank=False):
        del rerank
        strong_hits = make_results(2, vector_score=0.9)
        return SearchResponse(query=query, rank_label="Hybrid", results=strong_hits)

    def fake_answer_query(_query, _results, _config):
        # Only two results exist, so marker [9] is out of range.
        return "Per the text [9]."

    monkeypatch.setattr("python.ebook_search.api.routes.search.search_ebooks", fake_search_ebooks)
    monkeypatch.setattr("python.ebook_search.api.routes.search.answer_query", fake_answer_query)
    patch_app_runtime(monkeypatch)

    app = create_app()
    app.state.config = EbookSearchConfig(rerank=RerankConfig(enabled=False), answer_enabled=True)

    with TestClient(app) as client:
        response = client.post("/search", data={"query": "q"})

    assert response.status_code == 200
    assert "Invalid citations" in response.text
    # NOTE(review): this check is weak. If the template renders the answer text
    # ("[9]") or the 0.9 scores, "9" is present regardless of the invalid-citation
    # list. Consider asserting the template's exact invalid-citation markup instead.
    assert "9" in response.text
|
|
|
|
|
|
def test_grounded_answer_has_no_warning_badge(monkeypatch) -> None:
    """An answer citing only in-range markers renders without any warning badge."""

    def fake_search_ebooks(_engine, query, _config, *, rerank=False):
        del rerank
        strong_hits = make_results(2, vector_score=0.9)
        return SearchResponse(query=query, rank_label="Hybrid", results=strong_hits)

    def fake_answer_query(_query, _results, _config):
        # Both markers refer to the two results returned above.
        return "Grounded in [1] and [2]."

    monkeypatch.setattr("python.ebook_search.api.routes.search.search_ebooks", fake_search_ebooks)
    monkeypatch.setattr("python.ebook_search.api.routes.search.answer_query", fake_answer_query)
    patch_app_runtime(monkeypatch)

    app = create_app()
    app.state.config = EbookSearchConfig(rerank=RerankConfig(enabled=False), answer_enabled=True)

    with TestClient(app) as client:
        response = client.post("/search", data={"query": "q"})

    assert response.status_code == 200
    for warning in ("Unverified", "Invalid citations"):
        assert warning not in response.text
|