Added guardrails.py to constrain responses and added validation to config.py
This commit is contained in:
@@ -11,8 +11,16 @@ from fastapi import APIRouter, Form, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
from python.ebook_search.answer import answer_query
|
||||
from python.ebook_search.api.dependencies import AppConfig, AppEngine
|
||||
from python.ebook_search.api.web import templates
|
||||
from python.ebook_search.search import search_ebooks
|
||||
from python.ebook_search.config import EbookSearchConfig
|
||||
from python.ebook_search.guardrails import (
|
||||
CitationReport,
|
||||
is_confident,
|
||||
retrieval_confidence,
|
||||
validate_citations,
|
||||
)
|
||||
from python.ebook_search.search import SearchResponse, search_ebooks
|
||||
from python.ebook_search.timing import runtime_step_from_start
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -20,30 +28,64 @@ logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def build_answer(
    query: str,
    response: SearchResponse,
    config: EbookSearchConfig,
) -> tuple[str, bool, CitationReport | None]:
    """Produce the answer text for a finished search.

    Returns an ``(answer, low_confidence, citation_report)`` triple:

    * answer generation disabled -> fixed notice, ``False``, ``None``
    * retrieval confidence below the configured threshold -> fixed notice,
      ``True``, ``None`` (``answer_query`` is not called)
    * ``answer_query`` raises ``RuntimeError`` -> fallback notice, ``False``, ``None``
    * otherwise -> the generated answer, ``False``, and a citation report when
      citation validation is enabled and there is at least one result
    """
    if not config.answer_enabled:
        logger.info("ebook_answer_skipped_disabled")
        return "Answer generation is disabled. Source chunks are shown below.", False, None

    hits = response.results

    if not is_confident(hits, config):
        logger.info(
            "ebook_answer_low_confidence confidence=%.4f threshold=%.4f",
            retrieval_confidence(hits),
            config.min_retrieval_confidence,
        )
        skipped_notice = (
            "Retrieval confidence is low for this query, so answer generation was skipped. "
            "Source chunks are shown below."
        )
        return skipped_notice, True, None

    try:
        generated = answer_query(query, hits, config)
    except RuntimeError as error:
        # Generation failure is non-fatal: the caller still gets a displayable message.
        logger.warning("ebook_answer_request_failed_falling_back error=%s", error)
        return "Answer generation failed. Source chunks are still shown below.", False, None

    # Citation checks only make sense when enabled and when sources were returned.
    if not (config.validate_citations_enabled and hits):
        return generated, False, None

    report = validate_citations(generated, len(hits))
    if report.invalid or not report.grounded:
        logger.warning(
            "ebook_answer_citation_issue invalid=%s grounded=%s",
            report.invalid,
            report.grounded,
        )
    return generated, False, report
|
||||
|
||||
|
||||
@router.post("/search", response_class=HTMLResponse)
|
||||
def search(
|
||||
request: Request,
|
||||
config: AppConfig,
|
||||
engine: AppEngine,
|
||||
query: Annotated[str, Form()],
|
||||
rerank: Annotated[str | None, Form()] = None,
|
||||
) -> HTMLResponse:
|
||||
"""Run a search and render HTMX results."""
|
||||
try:
|
||||
response = search_ebooks(request.app.state.engine, query, request.app.state.config, rerank=rerank == "true")
|
||||
response = search_ebooks(engine, query, config, rerank=rerank == "true")
|
||||
except Exception as error:
|
||||
logger.exception("ebook_search_request_failed")
|
||||
return templates.TemplateResponse(request, "partials/error.html", {"message": str(error)}, status_code=500)
|
||||
|
||||
answer_start = perf_counter()
|
||||
if request.app.state.config.answer_enabled:
|
||||
try:
|
||||
answer = answer_query(query, response.results, request.app.state.config)
|
||||
except RuntimeError as error:
|
||||
logger.warning("ebook_answer_request_failed_falling_back error=%s", error)
|
||||
answer = "Answer generation failed. Source chunks are still shown below."
|
||||
else:
|
||||
logger.info("ebook_answer_skipped_disabled")
|
||||
answer = "Answer generation is disabled. Source chunks are shown below."
|
||||
answer_step_name = "Answer generation" if request.app.state.config.answer_enabled else "Answer skipped"
|
||||
answer, low_confidence, citation_report = build_answer(query, response, config)
|
||||
answer_step_name = "Answer generation" if config.answer_enabled else "Answer skipped"
|
||||
response = replace(
|
||||
response,
|
||||
timings=(*response.timings, runtime_step_from_start(answer_step_name, answer_start)),
|
||||
@@ -55,4 +97,13 @@ def search(
|
||||
response.rank_label,
|
||||
response.total_runtime_ms,
|
||||
)
|
||||
return templates.TemplateResponse(request, "partials/results.html", {"answer": answer, "response": response})
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"partials/results.html",
|
||||
{
|
||||
"answer": answer,
|
||||
"response": response,
|
||||
"low_confidence": low_confidence,
|
||||
"citation_report": citation_report,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from os import getenv
|
||||
from typing import Annotated
|
||||
from typing import Annotated, Self
|
||||
|
||||
from pydantic import AliasChoices, Field, field_validator
|
||||
from pydantic import AliasChoices, Field, field_validator, model_validator
|
||||
from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict
|
||||
|
||||
|
||||
@@ -82,6 +82,8 @@ class EbookSearchConfig(BaseSettings):
|
||||
vector_candidate_multiplier: int = 4
|
||||
bm25_candidate_limit: int = 120
|
||||
rrf_rank_constant: int = 60
|
||||
min_retrieval_confidence: float = 0.0
|
||||
validate_citations_enabled: bool = True
|
||||
bm25_index_dir: str = ".ebook_search_bm25"
|
||||
bm25_refresh_delay_seconds: int = 60
|
||||
|
||||
@@ -99,6 +101,20 @@ class EbookSearchConfig(BaseSettings):
|
||||
"""Normalize the configured embedding alias to its provider model name."""
|
||||
return normalize_embedding_alias(value)
|
||||
|
||||
@model_validator(mode="after")
def validate_runtime_consistency(self) -> Self:
    """Fail fast when an enabled feature is missing the settings it depends on.

    Raises:
        ValueError: if ``embedding_base_url`` is blank; if ``answer_enabled`` is
            set while ``vllm_base_url`` or ``chat_model`` is blank; or if
            ``rerank.enabled`` is set while ``rerank.base_url`` is blank.
            "Blank" means empty after stripping whitespace.
    """
    if not self.embedding_base_url.strip():
        msg = "embedding_base_url must be set"
        raise ValueError(msg)

    if self.answer_enabled:
        answer_backend_ready = self.vllm_base_url.strip() and self.chat_model.strip()
        if not answer_backend_ready:
            msg = "answer_enabled requires vllm_base_url and chat_model to be set"
            raise ValueError(msg)

    if self.rerank.enabled:
        if not self.rerank.base_url.strip():
            msg = "rerank.enabled requires rerank.base_url to be set"
            raise ValueError(msg)

    return self
|
||||
|
||||
|
||||
def load_rerank_config() -> RerankConfig:
|
||||
"""Load reranker config from environment variables."""
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
"""Serve-time output guardrails for retrieval confidence and answer citations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from python.ebook_search.config import EbookSearchConfig
|
||||
from python.ebook_search.search import SearchResult
|
||||
|
||||
# Matches a bracketed run of digits such as ``[3]``; group 1 captures the digits.
CITATION_RE = re.compile(r"\[(\d+)\]")
|
||||
|
||||
|
||||
def retrieval_confidence(results: list[SearchResult]) -> float:
    """Return a confidence value taken from the first (top-ranked) result.

    The first available signal wins, in this order: ``rerank_score``,
    ``vector_score``, then ``score``. A signal counts as available when it is
    not ``None`` (so ``0.0`` is a real value). An empty result list yields ``0.0``.
    """
    if not results:
        return 0.0

    best = results[0]

    rerank = best.rerank_score
    if rerank is not None:
        return rerank

    vector = best.vector_score
    return best.score if vector is None else vector
|
||||
|
||||
|
||||
def is_confident(results: list[SearchResult], config: EbookSearchConfig) -> bool:
    """Tell whether the top result's confidence reaches ``config.min_retrieval_confidence``.

    The comparison is inclusive: a confidence equal to the threshold passes.
    """
    confidence = retrieval_confidence(results)
    threshold = config.min_retrieval_confidence
    return confidence >= threshold
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class CitationReport:
    """Immutable result of checking the ``[n]`` citation markers in an answer."""

    # Distinct marker numbers that fall within the returned sources.
    cited: tuple[int, ...]
    # Distinct marker numbers that do not correspond to any returned source.
    invalid: tuple[int, ...]
    # Whether the answer cited at least one valid source.
    grounded: bool
|
||||
|
||||
|
||||
def validate_citations(answer: str, result_count: int) -> CitationReport:
    """Check every ``[n]`` marker in ``answer`` against the sources that were shown.

    Markers are de-duplicated and reported in ascending order. A marker is
    valid when ``1 <= n <= result_count``; anything else is reported as
    invalid. The report is ``grounded`` when at least one valid marker exists.
    """
    distinct_numbers = {int(digits) for digits in CITATION_RE.findall(answer)}
    valid = range(1, result_count + 1)

    in_range: list[int] = []
    out_of_range: list[int] = []
    for number in sorted(distinct_numbers):
        bucket = in_range if number in valid else out_of_range
        bucket.append(number)

    return CitationReport(
        cited=tuple(in_range),
        invalid=tuple(out_of_range),
        grounded=bool(in_range),
    )
|
||||
@@ -0,0 +1,141 @@
|
||||
"""Tests for serve-time output guardrails."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
from python.ebook_search.api.main import create_app
|
||||
from python.ebook_search.config import EbookSearchConfig, RerankConfig
|
||||
from python.ebook_search.guardrails import is_confident, retrieval_confidence, validate_citations
|
||||
from python.ebook_search.search import SearchResponse, SearchResult
|
||||
|
||||
|
||||
def make_results(count, *, vector_score=0.8):
    """Build ``count`` search results with chunk ids ``1..count``.

    Every result uses ``vector_score`` for both its ``score`` and its
    ``vector_score``.
    """
    shared_fields = {
        "source_title": "Book",
        "score": vector_score,
        "vector_score": vector_score,
    }
    return [
        SearchResult(chunk_id=number, text=f"source text {number}", **shared_fields)
        for number in range(1, count + 1)
    ]
|
||||
|
||||
|
||||
def test_validate_citations_partitions_markers() -> None:
    """Markers within the result count are all cited and the answer is grounded."""
    outcome = validate_citations("Supported by [1] and [2].", result_count=3)

    assert outcome.cited == (1, 2)
    assert outcome.invalid == ()
    assert outcome.grounded is True
|
||||
|
||||
|
||||
def test_validate_citations_flags_out_of_range_marker() -> None:
    """A marker beyond the result count is invalid and leaves the answer ungrounded."""
    outcome = validate_citations("As shown in [5].", result_count=2)

    assert outcome.cited == ()
    assert outcome.invalid == (5,)
    assert outcome.grounded is False
|
||||
|
||||
|
||||
def test_validate_citations_uncited_answer_is_not_grounded() -> None:
    """An answer without any markers has nothing cited, nothing invalid, and is ungrounded."""
    outcome = validate_citations("No citations at all.", result_count=2)

    assert outcome.cited == ()
    assert outcome.invalid == ()
    assert outcome.grounded is False
|
||||
|
||||
|
||||
def test_retrieval_confidence_prefers_rerank_then_vector() -> None:
    """Empty input gives 0.0; rerank score beats vector score; vector score is the fallback."""
    assert retrieval_confidence([]) == 0.0

    with_rerank = SearchResult(chunk_id=1, text="t", source_title="B", rerank_score=0.7, vector_score=0.2)
    assert retrieval_confidence([with_rerank]) == 0.7

    vector_only = SearchResult(chunk_id=1, text="t", source_title="B", vector_score=0.5)
    assert retrieval_confidence([vector_only]) == 0.5
|
||||
|
||||
|
||||
def test_is_confident_against_threshold() -> None:
    """A top score above the 0.5 threshold is confident; one below it is not."""
    settings = EbookSearchConfig(rerank=RerankConfig(enabled=False), min_retrieval_confidence=0.5)

    above_threshold = make_results(1, vector_score=0.6)
    assert is_confident(above_threshold, settings) is True

    below_threshold = make_results(1, vector_score=0.4)
    assert is_confident(below_threshold, settings) is False
|
||||
|
||||
|
||||
def patch_app_runtime(monkeypatch):
    """Stub the two ``api.main`` hooks patched by the endpoint tests.

    ``get_postgres_engine`` is replaced by a factory returning an in-memory
    SQLite engine, and ``ensure_bm25_corpus`` by a no-op.
    """

    def in_memory_engine(**_kwargs):
        return create_engine("sqlite+pysqlite:///:memory:", future=True)

    def skip_bm25_corpus(_session, _config):
        return None

    monkeypatch.setattr("python.ebook_search.api.main.get_postgres_engine", in_memory_engine)
    monkeypatch.setattr("python.ebook_search.api.main.ensure_bm25_corpus", skip_bm25_corpus)
|
||||
|
||||
|
||||
def test_low_confidence_skips_answer_generation(monkeypatch) -> None:
    """A top score below the threshold must not call ``answer_query`` and must show the low-confidence text."""
    answer_calls = []

    def fake_search_ebooks(_engine, query, _config, *, rerank=False):
        del rerank
        weak_results = make_results(1, vector_score=0.05)
        return SearchResponse(query=query, rank_label="Hybrid", results=weak_results)

    def fake_answer_query(_query, _results, _config):
        # Record the call so the test can prove generation was skipped.
        answer_calls.append(_query)
        return "answer"

    monkeypatch.setattr("python.ebook_search.api.routes.search.search_ebooks", fake_search_ebooks)
    monkeypatch.setattr("python.ebook_search.api.routes.search.answer_query", fake_answer_query)
    patch_app_runtime(monkeypatch)

    app = create_app()
    app.state.config = EbookSearchConfig(
        rerank=RerankConfig(enabled=False),
        answer_enabled=True,
        min_retrieval_confidence=0.5,
    )

    with TestClient(app) as client:
        response = client.post("/search", data={"query": "q"})

    assert response.status_code == 200
    assert not answer_calls
    assert "Low retrieval confidence" in response.text
|
||||
|
||||
|
||||
def test_invalid_citation_is_flagged(monkeypatch) -> None:
    """An answer citing ``[9]`` with only two sources must render the invalid-citation text."""

    def fake_search_ebooks(_engine, query, _config, *, rerank=False):
        del rerank
        strong_results = make_results(2, vector_score=0.9)
        return SearchResponse(query=query, rank_label="Hybrid", results=strong_results)

    def fake_answer_query(_query, _results, _config):
        return "Per the text [9]."

    monkeypatch.setattr("python.ebook_search.api.routes.search.search_ebooks", fake_search_ebooks)
    monkeypatch.setattr("python.ebook_search.api.routes.search.answer_query", fake_answer_query)
    patch_app_runtime(monkeypatch)

    app = create_app()
    app.state.config = EbookSearchConfig(rerank=RerankConfig(enabled=False), answer_enabled=True)

    with TestClient(app) as client:
        response = client.post("/search", data={"query": "q"})

    assert response.status_code == 200
    page = response.text
    assert "Invalid citations" in page
    assert "9" in page
|
||||
|
||||
|
||||
def test_grounded_answer_has_no_warning_badge(monkeypatch) -> None:
    """An answer citing only returned sources must render neither warning text."""

    def fake_search_ebooks(_engine, query, _config, *, rerank=False):
        del rerank
        strong_results = make_results(2, vector_score=0.9)
        return SearchResponse(query=query, rank_label="Hybrid", results=strong_results)

    def fake_answer_query(_query, _results, _config):
        return "Grounded in [1] and [2]."

    monkeypatch.setattr("python.ebook_search.api.routes.search.search_ebooks", fake_search_ebooks)
    monkeypatch.setattr("python.ebook_search.api.routes.search.answer_query", fake_answer_query)
    patch_app_runtime(monkeypatch)

    app = create_app()
    app.state.config = EbookSearchConfig(rerank=RerankConfig(enabled=False), answer_enabled=True)

    with TestClient(app) as client:
        response = client.post("/search", data={"query": "q"})

    assert response.status_code == 200
    page = response.text
    assert "Unverified" not in page
    assert "Invalid citations" not in page
|
||||
Reference in New Issue
Block a user