added guardrails.py to constrain responses and added validation to config.py
This commit is contained in:
@@ -11,8 +11,16 @@ from fastapi import APIRouter, Form, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
from python.ebook_search.answer import answer_query
|
||||
from python.ebook_search.api.dependencies import AppConfig, AppEngine
|
||||
from python.ebook_search.api.web import templates
|
||||
from python.ebook_search.search import search_ebooks
|
||||
from python.ebook_search.config import EbookSearchConfig
|
||||
from python.ebook_search.guardrails import (
|
||||
CitationReport,
|
||||
is_confident,
|
||||
retrieval_confidence,
|
||||
validate_citations,
|
||||
)
|
||||
from python.ebook_search.search import SearchResponse, search_ebooks
|
||||
from python.ebook_search.timing import runtime_step_from_start
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -20,30 +28,64 @@ logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def build_answer(
    query: str,
    response: SearchResponse,
    config: EbookSearchConfig,
) -> tuple[str, bool, CitationReport | None]:
    """Produce the answer text for a search together with its guardrail outcome.

    The result is ``(answer, low_confidence, citation_report)``:

    * ``answer`` is either the generated answer or a fixed fallback message.
    * ``low_confidence`` is true only when generation was skipped because the
      retrieval confidence check failed.
    * ``citation_report`` is populated only when an answer was generated,
      citation validation is enabled, and there is at least one result.
    """
    # Guard 1: answer generation switched off in configuration.
    if not config.answer_enabled:
        logger.info("ebook_answer_skipped_disabled")
        return "Answer generation is disabled. Source chunks are shown below.", False, None

    hits = response.results

    # Guard 2: retrieval too weak to justify generating an answer.
    if not is_confident(hits, config):
        logger.info(
            "ebook_answer_low_confidence confidence=%.4f threshold=%.4f",
            retrieval_confidence(hits),
            config.min_retrieval_confidence,
        )
        skipped_message = (
            "Retrieval confidence is low for this query, so answer generation was skipped. "
            "Source chunks are shown below."
        )
        return skipped_message, True, None

    # Only RuntimeError is treated as a recoverable generation failure.
    try:
        generated = answer_query(query, hits, config)
    except RuntimeError as exc:
        logger.warning("ebook_answer_request_failed_falling_back error=%s", exc)
        return "Answer generation failed. Source chunks are still shown below.", False, None

    # Citation validation is optional and meaningless without any sources.
    if not (config.validate_citations_enabled and hits):
        return generated, False, None

    report = validate_citations(generated, len(hits))
    has_issue = bool(report.invalid) or not report.grounded
    if has_issue:
        # The answer is still returned as-is; the problem is only logged and reported.
        logger.warning(
            "ebook_answer_citation_issue invalid=%s grounded=%s",
            report.invalid,
            report.grounded,
        )
    return generated, False, report
|
||||
|
||||
|
||||
@router.post("/search", response_class=HTMLResponse)
|
||||
def search(
|
||||
request: Request,
|
||||
config: AppConfig,
|
||||
engine: AppEngine,
|
||||
query: Annotated[str, Form()],
|
||||
rerank: Annotated[str | None, Form()] = None,
|
||||
) -> HTMLResponse:
|
||||
"""Run a search and render HTMX results."""
|
||||
try:
|
||||
response = search_ebooks(request.app.state.engine, query, request.app.state.config, rerank=rerank == "true")
|
||||
response = search_ebooks(engine, query, config, rerank=rerank == "true")
|
||||
except Exception as error:
|
||||
logger.exception("ebook_search_request_failed")
|
||||
return templates.TemplateResponse(request, "partials/error.html", {"message": str(error)}, status_code=500)
|
||||
|
||||
answer_start = perf_counter()
|
||||
if request.app.state.config.answer_enabled:
|
||||
try:
|
||||
answer = answer_query(query, response.results, request.app.state.config)
|
||||
except RuntimeError as error:
|
||||
logger.warning("ebook_answer_request_failed_falling_back error=%s", error)
|
||||
answer = "Answer generation failed. Source chunks are still shown below."
|
||||
else:
|
||||
logger.info("ebook_answer_skipped_disabled")
|
||||
answer = "Answer generation is disabled. Source chunks are shown below."
|
||||
answer_step_name = "Answer generation" if request.app.state.config.answer_enabled else "Answer skipped"
|
||||
answer, low_confidence, citation_report = build_answer(query, response, config)
|
||||
answer_step_name = "Answer generation" if config.answer_enabled else "Answer skipped"
|
||||
response = replace(
|
||||
response,
|
||||
timings=(*response.timings, runtime_step_from_start(answer_step_name, answer_start)),
|
||||
@@ -55,4 +97,13 @@ def search(
|
||||
response.rank_label,
|
||||
response.total_runtime_ms,
|
||||
)
|
||||
return templates.TemplateResponse(request, "partials/results.html", {"answer": answer, "response": response})
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"partials/results.html",
|
||||
{
|
||||
"answer": answer,
|
||||
"response": response,
|
||||
"low_confidence": low_confidence,
|
||||
"citation_report": citation_report,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from os import getenv
|
||||
from typing import Annotated
|
||||
from typing import Annotated, Self
|
||||
|
||||
from pydantic import AliasChoices, Field, field_validator
|
||||
from pydantic import AliasChoices, Field, field_validator, model_validator
|
||||
from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict
|
||||
|
||||
|
||||
@@ -82,6 +82,8 @@ class EbookSearchConfig(BaseSettings):
|
||||
vector_candidate_multiplier: int = 4
|
||||
bm25_candidate_limit: int = 120
|
||||
rrf_rank_constant: int = 60
|
||||
min_retrieval_confidence: float = 0.0
|
||||
validate_citations_enabled: bool = True
|
||||
bm25_index_dir: str = ".ebook_search_bm25"
|
||||
bm25_refresh_delay_seconds: int = 60
|
||||
|
||||
@@ -99,6 +101,20 @@ class EbookSearchConfig(BaseSettings):
|
||||
"""Normalize the configured embedding alias to its provider model name."""
|
||||
return normalize_embedding_alias(value)
|
||||
|
||||
@model_validator(mode="after")
def validate_runtime_consistency(self) -> Self:
    """Fail fast when an enabled feature lacks the settings it depends on.

    Checks run in a fixed order and the first violation is the one raised:

    1. ``embedding_base_url`` must not be blank.
    2. With ``answer_enabled``, both ``vllm_base_url`` and ``chat_model``
       must not be blank.
    3. With ``rerank.enabled``, ``rerank.base_url`` must not be blank.

    A value counts as blank when it is empty after ``str.strip()``.

    Returns:
        The same instance, unchanged, when every check passes.

    Raises:
        ValueError: On the first failed check.
    """
    embedding_url = self.embedding_base_url.strip()
    if not embedding_url:
        msg = "embedding_base_url must be set"
        raise ValueError(msg)

    # The chat settings are only inspected when answers are enabled.
    if self.answer_enabled and not (self.vllm_base_url.strip() and self.chat_model.strip()):
        msg = "answer_enabled requires vllm_base_url and chat_model to be set"
        raise ValueError(msg)

    reranker = self.rerank
    if reranker.enabled and not reranker.base_url.strip():
        msg = "rerank.enabled requires rerank.base_url to be set"
        raise ValueError(msg)

    return self
|
||||
|
||||
|
||||
def load_rerank_config() -> RerankConfig:
|
||||
"""Load reranker config from environment variables."""
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
"""Serve-time output guardrails for retrieval confidence and answer citations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from python.ebook_search.config import EbookSearchConfig
|
||||
from python.ebook_search.search import SearchResult
|
||||
|
||||
CITATION_RE = re.compile(r"\[(\d+)\]")
|
||||
|
||||
|
||||
def retrieval_confidence(results: list[SearchResult]) -> float:
    """Return the most interpretable relevance score of the first result.

    Only ``results[0]`` is inspected. Its scores are tried in this order, and
    the first one that is not ``None`` is returned:

    1. ``rerank_score``
    2. ``vector_score``
    3. ``score`` (the final score, used as the last resort)

    An empty result list yields ``0.0``. A score of exactly ``0.0`` is a real
    value and is returned, not skipped.
    """
    if not results:
        return 0.0

    head = results[0]
    # Explicit ``is not None`` checks so a legitimate 0.0 score is not bypassed.
    preferred = head.rerank_score if head.rerank_score is not None else head.vector_score
    return head.score if preferred is None else preferred
|
||||
|
||||
|
||||
def is_confident(results: list[SearchResult], config: EbookSearchConfig) -> bool:
    """Tell whether retrieval is strong enough to pass the confidence gate.

    The value from ``retrieval_confidence(results)`` is compared against
    ``config.min_retrieval_confidence``; meeting the threshold exactly counts
    as confident.
    """
    observed = retrieval_confidence(results)
    return observed >= config.min_retrieval_confidence
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class CitationReport:
    """Validation summary for bracketed citation markers in a generated answer."""

    # Distinct markers that point at a returned source, in ascending order.
    cited: tuple[int, ...]
    # Distinct markers outside ``1..result_count``, in ascending order.
    invalid: tuple[int, ...]
    # True when at least one valid source is cited.
    grounded: bool


# Matches a single marker (``[3]``) or a comma-separated group (``[1, 3]``).
# The module-level single-marker pattern misses grouped citations, which made
# correctly cited answers look ungrounded, so this broader pattern is used here.
_CITATION_GROUP_RE = re.compile(r"\[(\d+(?:\s*,\s*\d+)*)\]")


def validate_citations(answer: str, result_count: int) -> CitationReport:
    """Validate bracketed citation markers against the number of shown sources.

    Both single markers (``[2]``) and comma-separated groups (``[1, 3]``) are
    recognised. A marker is valid when it points to a returned source
    (``1..result_count``); anything else, including ``0``, is invalid.
    Duplicate markers are reported once.

    Args:
        answer: Generated answer text to scan for citation markers.
        result_count: Number of sources shown to the user. A value of zero or
            less makes every marker invalid.

    Returns:
        A ``CitationReport`` whose ``cited`` and ``invalid`` tuples are sorted
        ascending and whose ``grounded`` flag is true when ``cited`` is
        non-empty.
    """
    markers: set[int] = set()
    for match in _CITATION_GROUP_RE.finditer(answer):
        # ``int`` tolerates the whitespace left around each comma-separated part.
        markers.update(int(part) for part in match.group(1).split(","))

    valid = range(1, result_count + 1)
    ordered = sorted(markers)
    cited = tuple(marker for marker in ordered if marker in valid)
    invalid = tuple(marker for marker in ordered if marker not in valid)
    return CitationReport(cited=cited, invalid=invalid, grounded=bool(cited))
|
||||
Reference in New Issue
Block a user