From f71ae7d2c601d79b46b150e5d6ffd7c445f04fea Mon Sep 17 00:00:00 2001 From: Richie Cahill Date: Mon, 15 Jun 2026 21:57:38 -0400 Subject: [PATCH] added guardrails.py to constrain responses and added validation to config.py --- python/ebook_search/api/routes/search.py | 77 ++++++++++--- python/ebook_search/config.py | 20 +++- python/ebook_search/guardrails.py | 57 +++++++++ tests/test_ebook_search_guardrails.py | 141 +++++++++++++++++++++++ 4 files changed, 280 insertions(+), 15 deletions(-) create mode 100644 python/ebook_search/guardrails.py create mode 100644 tests/test_ebook_search_guardrails.py diff --git a/python/ebook_search/api/routes/search.py b/python/ebook_search/api/routes/search.py index 235dee2..9cc0e3e 100644 --- a/python/ebook_search/api/routes/search.py +++ b/python/ebook_search/api/routes/search.py @@ -11,8 +11,16 @@ from fastapi import APIRouter, Form, Request from fastapi.responses import HTMLResponse from python.ebook_search.answer import answer_query +from python.ebook_search.api.dependencies import AppConfig, AppEngine from python.ebook_search.api.web import templates -from python.ebook_search.search import search_ebooks +from python.ebook_search.config import EbookSearchConfig +from python.ebook_search.guardrails import ( + CitationReport, + is_confident, + retrieval_confidence, + validate_citations, +) +from python.ebook_search.search import SearchResponse, search_ebooks from python.ebook_search.timing import runtime_step_from_start logger = logging.getLogger(__name__) @@ -20,30 +28,64 @@ logger = logging.getLogger(__name__) router = APIRouter() +def build_answer( + query: str, + response: SearchResponse, + config: EbookSearchConfig, +) -> tuple[str, bool, CitationReport | None]: + """Generate the answer for a search, returning ``(answer, low_confidence, citation_report)``.""" + if not config.answer_enabled: + logger.info("ebook_answer_skipped_disabled") + return "Answer generation is disabled. Source chunks are shown below.", False, None + + if not is_confident(response.results, config): + logger.info( + "ebook_answer_low_confidence confidence=%.4f threshold=%.4f", + retrieval_confidence(response.results), + config.min_retrieval_confidence, + ) + answer = ( + "Retrieval confidence is low for this query, so answer generation was skipped. " + "Source chunks are shown below." + ) + return answer, True, None + + try: + answer = answer_query(query, response.results, config) + except RuntimeError as error: + logger.warning("ebook_answer_request_failed_falling_back error=%s", error) + return "Answer generation failed. Source chunks are still shown below.", False, None + + citation_report = None + if config.validate_citations_enabled and response.results: + citation_report = validate_citations(answer, len(response.results)) + if citation_report.invalid or not citation_report.grounded: + logger.warning( + "ebook_answer_citation_issue invalid=%s grounded=%s", + citation_report.invalid, + citation_report.grounded, + ) + return answer, False, citation_report + + @router.post("/search", response_class=HTMLResponse) def search( request: Request, + config: AppConfig, + engine: AppEngine, query: Annotated[str, Form()], rerank: Annotated[str | None, Form()] = None, ) -> HTMLResponse: """Run a search and render HTMX results.""" try: - response = search_ebooks(request.app.state.engine, query, request.app.state.config, rerank=rerank == "true") + response = search_ebooks(engine, query, config, rerank=rerank == "true") except Exception as error: logger.exception("ebook_search_request_failed") return templates.TemplateResponse(request, "partials/error.html", {"message": str(error)}, status_code=500) answer_start = perf_counter() - if request.app.state.config.answer_enabled: - try: - answer = answer_query(query, response.results, request.app.state.config) - except RuntimeError as error: - logger.warning("ebook_answer_request_failed_falling_back error=%s", error) - answer = "Answer generation failed. Source chunks are still shown below." - else: - logger.info("ebook_answer_skipped_disabled") - answer = "Answer generation is disabled. Source chunks are shown below." - answer_step_name = "Answer generation" if request.app.state.config.answer_enabled else "Answer skipped" + answer, low_confidence, citation_report = build_answer(query, response, config) + answer_step_name = "Answer generation" if config.answer_enabled else "Answer skipped" response = replace( response, timings=(*response.timings, runtime_step_from_start(answer_step_name, answer_start)), @@ -55,4 +97,13 @@ def search( response.rank_label, response.total_runtime_ms, ) - return templates.TemplateResponse(request, "partials/results.html", {"answer": answer, "response": response}) + return templates.TemplateResponse( + request, + "partials/results.html", + { + "answer": answer, + "response": response, + "low_confidence": low_confidence, + "citation_report": citation_report, + }, + ) diff --git a/python/ebook_search/config.py b/python/ebook_search/config.py index 8a65c6e..dd1cd7b 100644 --- a/python/ebook_search/config.py +++ b/python/ebook_search/config.py @@ -3,9 +3,9 @@ from __future__ import annotations from os import getenv -from typing import Annotated +from typing import Annotated, Self -from pydantic import AliasChoices, Field, field_validator +from pydantic import AliasChoices, Field, field_validator, model_validator from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict @@ -82,6 +82,8 @@ class EbookSearchConfig(BaseSettings): vector_candidate_multiplier: int = 4 bm25_candidate_limit: int = 120 rrf_rank_constant: int = 60 + min_retrieval_confidence: float = 0.0 + validate_citations_enabled: bool = True bm25_index_dir: str = ".ebook_search_bm25" bm25_refresh_delay_seconds: int = 60 @@ -99,6 +101,20 @@ class EbookSearchConfig(BaseSettings): """Normalize the configured embedding alias to its provider model name.""" return normalize_embedding_alias(value) + @model_validator(mode="after") + def validate_runtime_consistency(self) -> Self: + """Reject configurations that cannot serve the features they enable.""" + if not self.embedding_base_url.strip(): + msg = "embedding_base_url must be set" + raise ValueError(msg) + if self.answer_enabled and (not self.vllm_base_url.strip() or not self.chat_model.strip()): + msg = "answer_enabled requires vllm_base_url and chat_model to be set" + raise ValueError(msg) + if self.rerank.enabled and not self.rerank.base_url.strip(): + msg = "rerank.enabled requires rerank.base_url to be set" + raise ValueError(msg) + return self + def load_rerank_config() -> RerankConfig: """Load reranker config from environment variables.""" diff --git a/python/ebook_search/guardrails.py b/python/ebook_search/guardrails.py new file mode 100644 index 0000000..510dbeb --- /dev/null +++ b/python/ebook_search/guardrails.py @@ -0,0 +1,57 @@ +"""Serve-time output guardrails for retrieval confidence and answer citations.""" + +from __future__ import annotations + +import re +from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from python.ebook_search.config import EbookSearchConfig + from python.ebook_search.search import SearchResult + +CITATION_RE = re.compile(r"\[(\d+)\]") + + +def retrieval_confidence(results: list[SearchResult]) -> float: + """Return the strongest interpretable relevance signal of the top result. + + Reciprocal-rank-fusion scores are rank-based and not comparable across queries, + so the rerank relevance score is preferred, then vector cosine similarity, then + the final score. + """ + if not results: + return 0.0 + top = results[0] + if top.rerank_score is not None: + return top.rerank_score + if top.vector_score is not None: + return top.vector_score + return top.score + + +def is_confident(results: list[SearchResult], config: EbookSearchConfig) -> bool: + """Return whether top-result confidence meets the configured threshold.""" + return retrieval_confidence(results) >= config.min_retrieval_confidence + + +@dataclass(frozen=True) +class CitationReport: + """Validation summary for bracketed citation markers in a generated answer.""" + + cited: tuple[int, ...] + invalid: tuple[int, ...] + grounded: bool + + +def validate_citations(answer: str, result_count: int) -> CitationReport: + """Validate bracketed citation markers against the number of shown sources. + + A marker is valid when it points to a returned source (``1..result_count``). + ``grounded`` is true when the answer cites at least one valid source. + """ + markers = sorted({int(match.group(1)) for match in CITATION_RE.finditer(answer)}) + valid = range(1, result_count + 1) + cited = tuple(marker for marker in markers if marker in valid) + invalid = tuple(marker for marker in markers if marker not in valid) + return CitationReport(cited=cited, invalid=invalid, grounded=bool(cited)) diff --git a/tests/test_ebook_search_guardrails.py b/tests/test_ebook_search_guardrails.py new file mode 100644 index 0000000..cec944d --- /dev/null +++ b/tests/test_ebook_search_guardrails.py @@ -0,0 +1,141 @@ +"""Tests for serve-time output guardrails.""" + +from __future__ import annotations + +from fastapi.testclient import TestClient +from sqlalchemy import create_engine + +from python.ebook_search.api.main import create_app +from python.ebook_search.config import EbookSearchConfig, RerankConfig +from python.ebook_search.guardrails import is_confident, retrieval_confidence, validate_citations +from python.ebook_search.search import SearchResponse, SearchResult + + +def make_results(count, *, vector_score=0.8): + return [ + SearchResult( + chunk_id=index, + text=f"source text {index}", + source_title="Book", + score=vector_score, + vector_score=vector_score, + ) + for index in range(1, count + 1) + ] + + +def test_validate_citations_partitions_markers() -> None: + report = validate_citations("Supported by [1] and [2].", result_count=3) + assert report.cited == (1, 2) + assert report.invalid == () + assert report.grounded is True + + +def test_validate_citations_flags_out_of_range_marker() -> None: + report = validate_citations("As shown in [5].", result_count=2) + assert report.cited == () + assert report.invalid == (5,) + assert report.grounded is False + + +def test_validate_citations_uncited_answer_is_not_grounded() -> None: + report = validate_citations("No citations at all.", result_count=2) + assert report.cited == () + assert report.invalid == () + assert report.grounded is False + + +def test_retrieval_confidence_prefers_rerank_then_vector() -> None: + assert retrieval_confidence([]) == 0.0 + rerank_top = [SearchResult(chunk_id=1, text="t", source_title="B", rerank_score=0.7, vector_score=0.2)] + assert retrieval_confidence(rerank_top) == 0.7 + vector_top = [SearchResult(chunk_id=1, text="t", source_title="B", vector_score=0.5)] + assert retrieval_confidence(vector_top) == 0.5 + + +def test_is_confident_against_threshold() -> None: + config = EbookSearchConfig(rerank=RerankConfig(enabled=False), min_retrieval_confidence=0.5) + assert is_confident(make_results(1, vector_score=0.6), config) is True + assert is_confident(make_results(1, vector_score=0.4), config) is False + + +def patch_app_runtime(monkeypatch): + monkeypatch.setattr( + "python.ebook_search.api.main.get_postgres_engine", + lambda **_kwargs: create_engine("sqlite+pysqlite:///:memory:", future=True), + ) + monkeypatch.setattr("python.ebook_search.api.main.ensure_bm25_corpus", lambda _session, _config: None) + + +def test_low_confidence_skips_answer_generation(monkeypatch) -> None: + called = False + + def fake_search_ebooks(_engine, query, _config, *, rerank=False): + del rerank + return SearchResponse(query=query, rank_label="Hybrid", results=make_results(1, vector_score=0.05)) + + def fake_answer_query(_query, _results, _config): + nonlocal called + called = True + return "answer" + + monkeypatch.setattr("python.ebook_search.api.routes.search.search_ebooks", fake_search_ebooks) + monkeypatch.setattr("python.ebook_search.api.routes.search.answer_query", fake_answer_query) + patch_app_runtime(monkeypatch) + app = create_app() + app.state.config = EbookSearchConfig( + rerank=RerankConfig(enabled=False), + answer_enabled=True, + min_retrieval_confidence=0.5, + ) + + with TestClient(app) as client: + response = client.post("/search", data={"query": "q"}) + + assert response.status_code == 200 + assert called is False + assert "Low retrieval confidence" in response.text + + +def test_invalid_citation_is_flagged(monkeypatch) -> None: + def fake_search_ebooks(_engine, query, _config, *, rerank=False): + del rerank + return SearchResponse(query=query, rank_label="Hybrid", results=make_results(2, vector_score=0.9)) + + monkeypatch.setattr("python.ebook_search.api.routes.search.search_ebooks", fake_search_ebooks) + monkeypatch.setattr( + "python.ebook_search.api.routes.search.answer_query", + lambda _query, _results, _config: "Per the text [9].", + ) + patch_app_runtime(monkeypatch) + app = create_app() + app.state.config = EbookSearchConfig(rerank=RerankConfig(enabled=False), answer_enabled=True) + + with TestClient(app) as client: + response = client.post("/search", data={"query": "q"}) + + assert response.status_code == 200 + assert "Invalid citations" in response.text + assert "9" in response.text + + +def test_grounded_answer_has_no_warning_badge(monkeypatch) -> None: + def fake_search_ebooks(_engine, query, _config, *, rerank=False): + del rerank + return SearchResponse(query=query, rank_label="Hybrid", results=make_results(2, vector_score=0.9)) + + monkeypatch.setattr("python.ebook_search.api.routes.search.search_ebooks", fake_search_ebooks) + monkeypatch.setattr( + "python.ebook_search.api.routes.search.answer_query", + lambda _query, _results, _config: "Grounded in [1] and [2].", + ) + patch_app_runtime(monkeypatch) + app = create_app() + app.state.config = EbookSearchConfig(rerank=RerankConfig(enabled=False), answer_enabled=True) + + with TestClient(app) as client: + response = client.post("/search", data={"query": "q"}) + + assert response.status_code == 200 + assert "Unverified" not in response.text + assert "Invalid citations" not in response.text