58 lines
2.0 KiB
Python
58 lines
2.0 KiB
Python
"""Serve-time output guardrails for retrieval confidence and answer citations."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import re
|
|
from dataclasses import dataclass
|
|
from typing import TYPE_CHECKING
|
|
|
|
if TYPE_CHECKING:
|
|
from python.ebook_search.config import EbookSearchConfig
|
|
from python.ebook_search.search import SearchResult
|
|
|
|
# Bracketed numeric citation marker such as "[3]"; capture group 1 holds the digit run.
CITATION_RE = re.compile(r"\[(\d+)\]", flags=0)
|
|
|
|
|
|
def retrieval_confidence(results: list[SearchResult]) -> float:
    """Pick the most interpretable relevance score of the first (top-ranked) result.

    Reciprocal-rank-fusion scores depend only on rank positions and cannot be
    compared across queries. The top result's ``rerank_score`` therefore wins when
    it is set, then its ``vector_score`` (cosine similarity), and only then its
    fused ``score``. An empty result list yields ``0.0``.
    """
    if not results:
        return 0.0

    head = results[0]

    rerank = head.rerank_score
    if rerank is not None:
        return rerank

    # Read lazily so ``vector_score`` is only touched when no rerank score exists.
    vector = head.vector_score
    return vector if vector is not None else head.score
|
|
|
|
|
|
def is_confident(results: list[SearchResult], config: EbookSearchConfig) -> bool:
    """Tell whether the top result is trustworthy enough to answer from.

    The value from ``retrieval_confidence(results)`` is compared against
    ``config.min_retrieval_confidence``. A score exactly equal to the threshold
    counts as confident.
    """
    score = retrieval_confidence(results)
    return score >= config.min_retrieval_confidence
|
|
|
|
|
|
@dataclass(eq=True, frozen=True)
class CitationReport:
    """Immutable result of checking an answer's ``[n]`` markers against the shown sources.

    Fields appear in constructor order: ``cited``, ``invalid``, ``grounded``.
    """

    # Markers that point at a shown source (as built by validate_citations:
    # distinct values, ascending).
    cited: tuple[int, ...]
    # Markers that fall outside the shown sources (same construction as ``cited``).
    invalid: tuple[int, ...]
    # Whether the answer cites at least one valid source.
    grounded: bool
|
|
|
|
|
|
def validate_citations(answer: str, result_count: int) -> CitationReport:
    """Validate bracketed citation markers against the number of shown sources.

    A marker is valid when it points to a returned source (``1..result_count``).
    ``grounded`` is true when the answer cites at least one valid source.

    Args:
        answer: Generated answer text to scan for ``[n]`` markers.
        result_count: Number of sources shown alongside the answer. Zero or a
            negative value leaves no valid marker.

    Returns:
        A ``CitationReport`` whose ``cited`` and ``invalid`` tuples hold the
        distinct markers in ascending order. A marker whose digit run is too long
        for ``int()`` to convert is ignored, because it cannot index a real source.
    """
    markers: set[int] = set()
    for match in CITATION_RE.finditer(answer):
        try:
            markers.add(int(match.group(1)))
        except ValueError:
            # int() rejects digit strings longer than the interpreter's
            # integer-string conversion limit. Skip the marker instead of letting
            # pathological model output crash the guardrail.
            continue

    ordered = sorted(markers)
    valid = range(1, result_count + 1)  # O(1) membership test for ints
    cited = tuple(marker for marker in ordered if marker in valid)
    invalid = tuple(marker for marker in ordered if marker not in valid)
    return CitationReport(cited=cited, invalid=invalid, grounded=bool(cited))
|