Files
dotfiles/python/ebook_search/guardrails.py
T

58 lines
2.0 KiB
Python

"""Serve-time output guardrails for retrieval confidence and answer citations."""
from __future__ import annotations

import re
from dataclasses import dataclass
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from python.ebook_search.config import EbookSearchConfig
    from python.ebook_search.search import SearchResult

# Matches a bracketed citation marker such as "[3]" and captures the digits
# between the brackets. In a str pattern ``\d`` also accepts non-ASCII
# decimal digits, which ``int()`` can still convert.
CITATION_RE: re.Pattern[str] = re.compile(r"\[(\d+)\]")
def retrieval_confidence(results: list[SearchResult]) -> float:
    """Score how confident retrieval is, judged by the first result only.

    Reciprocal-rank-fusion scores depend on rank alone and cannot be compared
    between queries. So the first non-``None`` of the top result's
    ``rerank_score`` and ``vector_score`` (vector cosine similarity) is used,
    falling back to its final ``score``. An empty result list yields ``0.0``.
    """
    if not results:
        return 0.0
    leader = results[0]
    # Check the preferred signals in priority order. Reading each attribute
    # inside the loop means a later signal is only looked at when every
    # earlier one is missing.
    for attribute in ("rerank_score", "vector_score"):
        signal = getattr(leader, attribute)
        if signal is not None:
            return signal
    return leader.score


def is_confident(results: list[SearchResult], config: EbookSearchConfig) -> bool:
    """Decide whether the retrieved results clear the confidence bar.

    The top-result signal from ``retrieval_confidence`` is compared against
    ``config.min_retrieval_confidence``. A signal exactly equal to the
    threshold counts as confident.
    """
    confidence = retrieval_confidence(results)
    threshold = config.min_retrieval_confidence
    return confidence >= threshold


@dataclass(
    init=True,
    repr=True,
    eq=True,
    order=False,
    unsafe_hash=False,
    frozen=True,
)
class CitationReport:
    """Immutable outcome of checking an answer's bracketed ``[n]`` markers."""

    # Marker numbers that point at a shown source (as filled in by
    # validate_citations: distinct, ascending).
    cited: tuple[int, ...]
    # Marker numbers outside the shown-source range (distinct, ascending).
    invalid: tuple[int, ...]
    # True when at least one marker points at a shown source.
    grounded: bool


def validate_citations(answer: str, result_count: int) -> CitationReport:
    """Check every bracketed ``[n]`` marker in *answer* against the shown sources.

    Markers are de-duplicated and examined in ascending order. A number in
    ``1..result_count`` counts as cited. Any other number, including ``0``,
    is reported as invalid. The report is grounded when at least one marker
    is cited.

    NOTE(review): each marker is parsed with ``int()``. On Python versions
    that enforce the int/str conversion length limit (3.11+, default 4300
    digits), a marker with more digits than that raises ``ValueError``.
    Consider guarding against this if answers can be adversarial.
    """
    unique_markers = {int(digits) for digits in CITATION_RE.findall(answer)}
    in_range = range(1, result_count + 1)
    cited: list[int] = []
    invalid: list[int] = []
    for marker in sorted(unique_markers):
        bucket = cited if marker in in_range else invalid
        bucket.append(marker)
    return CitationReport(
        cited=tuple(cited),
        invalid=tuple(invalid),
        grounded=len(cited) > 0,
    )