added guardrails.py to constrain responses and added validation to config.py

This commit is contained in:
2026-06-15 21:57:38 -04:00
parent 2e68c83021
commit f71ae7d2c6
4 changed files with 280 additions and 15 deletions
+64 -13
View File
@@ -11,8 +11,16 @@ from fastapi import APIRouter, Form, Request
from fastapi.responses import HTMLResponse
from python.ebook_search.answer import answer_query
from python.ebook_search.api.dependencies import AppConfig, AppEngine
from python.ebook_search.api.web import templates
from python.ebook_search.search import search_ebooks
from python.ebook_search.config import EbookSearchConfig
from python.ebook_search.guardrails import (
CitationReport,
is_confident,
retrieval_confidence,
validate_citations,
)
from python.ebook_search.search import SearchResponse, search_ebooks
from python.ebook_search.timing import runtime_step_from_start
logger = logging.getLogger(__name__)
@@ -20,30 +28,64 @@ logger = logging.getLogger(__name__)
router = APIRouter()
def build_answer(
query: str,
response: SearchResponse,
config: EbookSearchConfig,
) -> tuple[str, bool, CitationReport | None]:
"""Generate the answer for a search, returning ``(answer, low_confidence, citation_report)``."""
if not config.answer_enabled:
logger.info("ebook_answer_skipped_disabled")
return "Answer generation is disabled. Source chunks are shown below.", False, None
if not is_confident(response.results, config):
logger.info(
"ebook_answer_low_confidence confidence=%.4f threshold=%.4f",
retrieval_confidence(response.results),
config.min_retrieval_confidence,
)
answer = (
"Retrieval confidence is low for this query, so answer generation was skipped. "
"Source chunks are shown below."
)
return answer, True, None
try:
answer = answer_query(query, response.results, config)
except RuntimeError as error:
logger.warning("ebook_answer_request_failed_falling_back error=%s", error)
return "Answer generation failed. Source chunks are still shown below.", False, None
citation_report = None
if config.validate_citations_enabled and response.results:
citation_report = validate_citations(answer, len(response.results))
if citation_report.invalid or not citation_report.grounded:
logger.warning(
"ebook_answer_citation_issue invalid=%s grounded=%s",
citation_report.invalid,
citation_report.grounded,
)
return answer, False, citation_report
@router.post("/search", response_class=HTMLResponse)
def search(
request: Request,
config: AppConfig,
engine: AppEngine,
query: Annotated[str, Form()],
rerank: Annotated[str | None, Form()] = None,
) -> HTMLResponse:
"""Run a search and render HTMX results."""
try:
response = search_ebooks(request.app.state.engine, query, request.app.state.config, rerank=rerank == "true")
response = search_ebooks(engine, query, config, rerank=rerank == "true")
except Exception as error:
logger.exception("ebook_search_request_failed")
return templates.TemplateResponse(request, "partials/error.html", {"message": str(error)}, status_code=500)
answer_start = perf_counter()
if request.app.state.config.answer_enabled:
try:
answer = answer_query(query, response.results, request.app.state.config)
except RuntimeError as error:
logger.warning("ebook_answer_request_failed_falling_back error=%s", error)
answer = "Answer generation failed. Source chunks are still shown below."
else:
logger.info("ebook_answer_skipped_disabled")
answer = "Answer generation is disabled. Source chunks are shown below."
answer_step_name = "Answer generation" if request.app.state.config.answer_enabled else "Answer skipped"
answer, low_confidence, citation_report = build_answer(query, response, config)
answer_step_name = "Answer generation" if config.answer_enabled else "Answer skipped"
response = replace(
response,
timings=(*response.timings, runtime_step_from_start(answer_step_name, answer_start)),
@@ -55,4 +97,13 @@ def search(
response.rank_label,
response.total_runtime_ms,
)
return templates.TemplateResponse(request, "partials/results.html", {"answer": answer, "response": response})
return templates.TemplateResponse(
request,
"partials/results.html",
{
"answer": answer,
"response": response,
"low_confidence": low_confidence,
"citation_report": citation_report,
},
)
+18 -2
View File
@@ -3,9 +3,9 @@
from __future__ import annotations
from os import getenv
from typing import Annotated
from typing import Annotated, Self
from pydantic import AliasChoices, Field, field_validator
from pydantic import AliasChoices, Field, field_validator, model_validator
from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict
@@ -82,6 +82,8 @@ class EbookSearchConfig(BaseSettings):
vector_candidate_multiplier: int = 4
bm25_candidate_limit: int = 120
rrf_rank_constant: int = 60
min_retrieval_confidence: float = 0.0
validate_citations_enabled: bool = True
bm25_index_dir: str = ".ebook_search_bm25"
bm25_refresh_delay_seconds: int = 60
@@ -99,6 +101,20 @@ class EbookSearchConfig(BaseSettings):
"""Normalize the configured embedding alias to its provider model name."""
return normalize_embedding_alias(value)
@model_validator(mode="after")
def validate_runtime_consistency(self) -> Self:
"""Reject configurations that cannot serve the features they enable."""
if not self.embedding_base_url.strip():
msg = "embedding_base_url must be set"
raise ValueError(msg)
if self.answer_enabled and (not self.vllm_base_url.strip() or not self.chat_model.strip()):
msg = "answer_enabled requires vllm_base_url and chat_model to be set"
raise ValueError(msg)
if self.rerank.enabled and not self.rerank.base_url.strip():
msg = "rerank.enabled requires rerank.base_url to be set"
raise ValueError(msg)
return self
def load_rerank_config() -> RerankConfig:
"""Load reranker config from environment variables."""
+57
View File
@@ -0,0 +1,57 @@
"""Serve-time output guardrails for retrieval confidence and answer citations."""
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from python.ebook_search.config import EbookSearchConfig
from python.ebook_search.search import SearchResult
CITATION_RE = re.compile(r"\[(\d+)\]")
def retrieval_confidence(results: list[SearchResult]) -> float:
"""Return the strongest interpretable relevance signal of the top result.
Reciprocal-rank-fusion scores are rank-based and not comparable across queries,
so the rerank relevance score is preferred, then vector cosine similarity, then
the final score.
"""
if not results:
return 0.0
top = results[0]
if top.rerank_score is not None:
return top.rerank_score
if top.vector_score is not None:
return top.vector_score
return top.score
def is_confident(results: list[SearchResult], config: EbookSearchConfig) -> bool:
"""Return whether top-result confidence meets the configured threshold."""
return retrieval_confidence(results) >= config.min_retrieval_confidence
@dataclass(frozen=True)
class CitationReport:
"""Validation summary for bracketed citation markers in a generated answer."""
cited: tuple[int, ...]
invalid: tuple[int, ...]
grounded: bool
def validate_citations(answer: str, result_count: int) -> CitationReport:
"""Validate bracketed citation markers against the number of shown sources.
A marker is valid when it points to a returned source (``1..result_count``).
``grounded`` is true when the answer cites at least one valid source.
"""
markers = sorted({int(match.group(1)) for match in CITATION_RE.finditer(answer)})
valid = range(1, result_count + 1)
cited = tuple(marker for marker in markers if marker in valid)
invalid = tuple(marker for marker in markers if marker not in valid)
return CitationReport(cited=cited, invalid=invalid, grounded=bool(cited))
+141
View File
@@ -0,0 +1,141 @@
"""Tests for serve-time output guardrails."""
from __future__ import annotations
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from python.ebook_search.api.main import create_app
from python.ebook_search.config import EbookSearchConfig, RerankConfig
from python.ebook_search.guardrails import is_confident, retrieval_confidence, validate_citations
from python.ebook_search.search import SearchResponse, SearchResult
def make_results(count, *, vector_score=0.8):
return [
SearchResult(
chunk_id=index,
text=f"source text {index}",
source_title="Book",
score=vector_score,
vector_score=vector_score,
)
for index in range(1, count + 1)
]
def test_validate_citations_partitions_markers() -> None:
report = validate_citations("Supported by [1] and [2].", result_count=3)
assert report.cited == (1, 2)
assert report.invalid == ()
assert report.grounded is True
def test_validate_citations_flags_out_of_range_marker() -> None:
report = validate_citations("As shown in [5].", result_count=2)
assert report.cited == ()
assert report.invalid == (5,)
assert report.grounded is False
def test_validate_citations_uncited_answer_is_not_grounded() -> None:
report = validate_citations("No citations at all.", result_count=2)
assert report.cited == ()
assert report.invalid == ()
assert report.grounded is False
def test_retrieval_confidence_prefers_rerank_then_vector() -> None:
assert retrieval_confidence([]) == 0.0
rerank_top = [SearchResult(chunk_id=1, text="t", source_title="B", rerank_score=0.7, vector_score=0.2)]
assert retrieval_confidence(rerank_top) == 0.7
vector_top = [SearchResult(chunk_id=1, text="t", source_title="B", vector_score=0.5)]
assert retrieval_confidence(vector_top) == 0.5
def test_is_confident_against_threshold() -> None:
config = EbookSearchConfig(rerank=RerankConfig(enabled=False), min_retrieval_confidence=0.5)
assert is_confident(make_results(1, vector_score=0.6), config) is True
assert is_confident(make_results(1, vector_score=0.4), config) is False
def patch_app_runtime(monkeypatch):
monkeypatch.setattr(
"python.ebook_search.api.main.get_postgres_engine",
lambda **_kwargs: create_engine("sqlite+pysqlite:///:memory:", future=True),
)
monkeypatch.setattr("python.ebook_search.api.main.ensure_bm25_corpus", lambda _session, _config: None)
def test_low_confidence_skips_answer_generation(monkeypatch) -> None:
called = False
def fake_search_ebooks(_engine, query, _config, *, rerank=False):
del rerank
return SearchResponse(query=query, rank_label="Hybrid", results=make_results(1, vector_score=0.05))
def fake_answer_query(_query, _results, _config):
nonlocal called
called = True
return "answer"
monkeypatch.setattr("python.ebook_search.api.routes.search.search_ebooks", fake_search_ebooks)
monkeypatch.setattr("python.ebook_search.api.routes.search.answer_query", fake_answer_query)
patch_app_runtime(monkeypatch)
app = create_app()
app.state.config = EbookSearchConfig(
rerank=RerankConfig(enabled=False),
answer_enabled=True,
min_retrieval_confidence=0.5,
)
with TestClient(app) as client:
response = client.post("/search", data={"query": "q"})
assert response.status_code == 200
assert called is False
assert "Low retrieval confidence" in response.text
def test_invalid_citation_is_flagged(monkeypatch) -> None:
def fake_search_ebooks(_engine, query, _config, *, rerank=False):
del rerank
return SearchResponse(query=query, rank_label="Hybrid", results=make_results(2, vector_score=0.9))
monkeypatch.setattr("python.ebook_search.api.routes.search.search_ebooks", fake_search_ebooks)
monkeypatch.setattr(
"python.ebook_search.api.routes.search.answer_query",
lambda _query, _results, _config: "Per the text [9].",
)
patch_app_runtime(monkeypatch)
app = create_app()
app.state.config = EbookSearchConfig(rerank=RerankConfig(enabled=False), answer_enabled=True)
with TestClient(app) as client:
response = client.post("/search", data={"query": "q"})
assert response.status_code == 200
assert "Invalid citations" in response.text
assert "9" in response.text
def test_grounded_answer_has_no_warning_badge(monkeypatch) -> None:
def fake_search_ebooks(_engine, query, _config, *, rerank=False):
del rerank
return SearchResponse(query=query, rank_label="Hybrid", results=make_results(2, vector_score=0.9))
monkeypatch.setattr("python.ebook_search.api.routes.search.search_ebooks", fake_search_ebooks)
monkeypatch.setattr(
"python.ebook_search.api.routes.search.answer_query",
lambda _query, _results, _config: "Grounded in [1] and [2].",
)
patch_app_runtime(monkeypatch)
app = create_app()
app.state.config = EbookSearchConfig(rerank=RerankConfig(enabled=False), answer_enabled=True)
with TestClient(app) as client:
response = client.post("/search", data={"query": "q"})
assert response.status_code == 200
assert "Unverified" not in response.text
assert "Invalid citations" not in response.text