"""Summarization evaluation for Congressional bill summaries. Three use cases from one module: 1. Data filtering — score GPT batch outputs before building the fine-tune JSONL: from summarization_eval import filter_dataset filter_dataset("output/finetune_dataset.jsonl", "output/filtered_dataset.jsonl") 2. Training compute_metrics hook — plug into SFTTrainer for ROUGE-based checkpoint selection: from summarization_eval import make_compute_metrics trainer = SFTTrainer(..., compute_metrics=make_compute_metrics(tokenizer)) 3. Inference eval — score a finished model against held-out references: from summarization_eval import evaluate_file results = evaluate_file("output/predictions.jsonl", "output/references.jsonl") """ from __future__ import annotations import json import logging from dataclasses import dataclass from pathlib import Path from typing import Callable import numpy as np from rouge_score import rouge_scorer logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Constants # --------------------------------------------------------------------------- SECTION_HEADERS = [ "OPERATIVE ACTIONS", "AFFECTED POPULATIONS", "MECHANISMS", "POLICY THREADS", "SYMBOLIC/PROCEDURAL ONLY", ] # Weighted composite: de-emphasise unigram overlap, weight phrase + structure equally ROUGE_WEIGHTS = { "rouge1": 0.2, "rouge2": 0.4, "rougeL": 0.4, } # Composite score floor below which a training example is considered low quality FILTER_THRESHOLD = 0.25 # --------------------------------------------------------------------------- # Core data structures # --------------------------------------------------------------------------- @dataclass class SummaryScore: """Scores for a single (prediction, reference) pair.""" rouge1: float rouge2: float rougeL: float composite: float has_all_sections: bool # True = all 5 headers present missing_sections: list[str] structural_fail: bool # True = one or more headers missing (hard guardrail) def as_dict(self) -> dict: return { "rouge1": self.rouge1, "rouge2": self.rouge2, "rougeL": self.rougeL, "composite": self.composite, "has_all_sections": self.has_all_sections, "missing_sections": self.missing_sections, "structural_fail": self.structural_fail, } @dataclass class BatchResult: """Aggregate results over a batch of summaries.""" n_total: int n_structural_fail: int n_scored: int # excludes structural failures rouge1_mean: float rouge2_mean: float rougeL_mean: float composite_mean: float scores: list[SummaryScore] def as_dict(self) -> dict: return { "n_total": self.n_total, "n_structural_fail": self.n_structural_fail, "n_scored": self.n_scored, "rouge1_mean": self.rouge1_mean, "rouge2_mean": self.rouge2_mean, "rougeL_mean": self.rougeL_mean, "composite_mean": self.composite_mean, } # --------------------------------------------------------------------------- # Core scoring # --------------------------------------------------------------------------- _scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True) def check_sections(text: str) -> tuple[bool, list[str]]: """Return (all_present, missing_headers) for the 5 required section headers.""" missing = [h for h in SECTION_HEADERS if h not in text.upper()] return len(missing) == 0, missing def score_pair(prediction: str, reference: str) -> SummaryScore: """Score a single (prediction, reference) pair. If the prediction is missing any section header, structural_fail is True and ROUGE scores are still computed (so you can inspect quality even on structural failures) but the example should be treated as a guardrail failure. """ has_all, missing = check_sections(prediction) rouge = _scorer.score(reference, prediction) r1 = rouge["rouge1"].fmeasure r2 = rouge["rouge2"].fmeasure rl = rouge["rougeL"].fmeasure composite = ( ROUGE_WEIGHTS["rouge1"] * r1 + ROUGE_WEIGHTS["rouge2"] * r2 + ROUGE_WEIGHTS["rougeL"] * rl ) return SummaryScore( rouge1=r1, rouge2=r2, rougeL=rl, composite=composite, has_all_sections=has_all, missing_sections=missing, structural_fail=not has_all, ) def score_batch(pairs: list[tuple[str, str]]) -> BatchResult: """Score a list of (prediction, reference) pairs and return aggregate results. Structural failures are counted separately and excluded from ROUGE means so a batch with broken formatting doesn't drag down the score unfairly. """ scores = [score_pair(pred, ref) for pred, ref in pairs] structural_fails = [s for s in scores if s.structural_fail] valid = [s for s in scores if not s.structural_fail] if valid: rouge1_mean = float(np.mean([s.rouge1 for s in valid])) rouge2_mean = float(np.mean([s.rouge2 for s in valid])) rougeL_mean = float(np.mean([s.rougeL for s in valid])) composite_mean = float(np.mean([s.composite for s in valid])) else: rouge1_mean = rouge2_mean = rougeL_mean = composite_mean = 0.0 return BatchResult( n_total=len(scores), n_structural_fail=len(structural_fails), n_scored=len(valid), rouge1_mean=rouge1_mean, rouge2_mean=rouge2_mean, rougeL_mean=rougeL_mean, composite_mean=composite_mean, scores=scores, ) # --------------------------------------------------------------------------- # Use case 1: Data filtering # --------------------------------------------------------------------------- def filter_dataset( input_path: Path | str, output_path: Path | str, *, threshold: float = FILTER_THRESHOLD, ) -> tuple[int, int]: """Filter a fine-tuning JSONL by ROUGE composite score and section guardrail. Each line must be a ChatML messages dict: {"messages": [{"role": "system", ...}, {"role": "user", ...}, {"role": "assistant", ...}]} The assistant turn is the prediction. The reference is the same assistant turn — filtering here uses composite score as a self-consistency check against the threshold, and drops structural failures unconditionally. In practice you'd call this after joining requests + GPT completions (build_finetune_dataset.py) to drop any GPT outputs that are malformed or suspiciously short/low quality. Returns (kept, dropped). """ input_path = Path(input_path) output_path = Path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) kept = 0 dropped = 0 with input_path.open(encoding="utf-8") as fin, output_path.open("w", encoding="utf-8") as fout: for line_num, raw_line in enumerate(fin, 1): stripped = raw_line.strip() if not stripped: continue example = json.loads(stripped) messages = example.get("messages", []) assistant_turns = [m for m in messages if m.get("role") == "assistant"] if not assistant_turns: logger.warning("Line %d: no assistant turn, dropping", line_num) dropped += 1 continue prediction = assistant_turns[-1].get("content", "") # Guardrail: drop if any section header missing has_all, missing = check_sections(prediction) if not has_all: logger.warning( "Line %d: structural fail (missing: %s), dropping", line_num, ", ".join(missing), ) dropped += 1 continue # Quality floor: score against itself isn't meaningful for filtering — # instead just check composite score of prediction vs a simple # word-count proxy. For filtering GPT outputs, structural check # + a minimum word count is usually sufficient. word_count = len(prediction.split()) if word_count < 80: logger.warning( "Line %d: too short (%d words), dropping", line_num, word_count ) dropped += 1 continue fout.write(json.dumps(example, ensure_ascii=False) + "\n") kept += 1 logger.info("Filtered dataset: kept=%d dropped=%d -> %s", kept, dropped, output_path) return kept, dropped # --------------------------------------------------------------------------- # Use case 2: compute_metrics hook for SFTTrainer # --------------------------------------------------------------------------- def make_compute_metrics(tokenizer) -> Callable: # noqa: ANN001 """Return a compute_metrics function compatible with HuggingFace Trainer. Usage in finetune.py: from summarization_eval import make_compute_metrics trainer = SFTTrainer( ... compute_metrics=make_compute_metrics(tokenizer), ) Note: EvalPrediction.predictions are logits (or token ids if include_inputs_for_metrics is False). This function handles both. For SFTTrainer with packing=True, you may need to set predict_with_generate=True in TrainingArguments to get decoded text. """ def compute_metrics(eval_pred) -> dict[str, float]: # noqa: ANN001 predictions, labels = eval_pred # If predictions are logits, take argmax if predictions.ndim == 3: predictions = np.argmax(predictions, axis=-1) # Mask out -100 padding in labels labels = np.where(labels == -100, tokenizer.pad_token_id, labels) decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True) decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) pairs = list(zip(decoded_preds, decoded_labels)) result = score_batch(pairs) metrics = { "eval_rouge1": result.rouge1_mean, "eval_rouge2": result.rouge2_mean, "eval_rougeL": result.rougeL_mean, "eval_composite": result.composite_mean, "eval_structural_fail_rate": ( result.n_structural_fail / result.n_total if result.n_total else 0.0 ), } logger.info( "Eval: composite=%.4f rouge1=%.4f rouge2=%.4f rougeL=%.4f structural_fail=%d/%d", metrics["eval_composite"], metrics["eval_rouge1"], metrics["eval_rouge2"], metrics["eval_rougeL"], result.n_structural_fail, result.n_total, ) return metrics return compute_metrics # --------------------------------------------------------------------------- # Use case 3: Inference eval against held-out references # --------------------------------------------------------------------------- def evaluate_file( predictions_path: Path | str, references_path: Path | str, output_path: Path | str | None = None, ) -> BatchResult: """Score a predictions JSONL against a references JSONL. Both files should be line-matched: line N of predictions corresponds to line N of references. Each line should be a plain JSON object with a "text" or "content" key, or a ChatML messages dict. If output_path is provided, writes per-example scores as JSONL. """ predictions_path = Path(predictions_path) references_path = Path(references_path) def extract_text(line: str) -> str: obj = json.loads(line) # Plain text field if "text" in obj: return obj["text"] if "content" in obj: return obj["content"] # ChatML messages — take last assistant turn messages = obj.get("messages", []) for m in reversed(messages): if m.get("role") == "assistant": return m.get("content", "") return "" preds = [extract_text(l) for l in predictions_path.read_text().splitlines() if l.strip()] refs = [extract_text(l) for l in references_path.read_text().splitlines() if l.strip()] if len(preds) != len(refs): msg = f"Prediction count ({len(preds)}) != reference count ({len(refs)})" raise ValueError(msg) result = score_batch(list(zip(preds, refs))) logger.info( "Inference eval: n=%d structural_fails=%d composite=%.4f " "rouge1=%.4f rouge2=%.4f rougeL=%.4f", result.n_total, result.n_structural_fail, result.composite_mean, result.rouge1_mean, result.rouge2_mean, result.rougeL_mean, ) if output_path is not None: output_path = Path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) with output_path.open("w", encoding="utf-8") as fout: for score in result.scores: fout.write(json.dumps(score.as_dict(), ensure_ascii=False) + "\n") summary_path = output_path.with_suffix(".summary.json") summary_path.write_text(json.dumps(result.as_dict(), indent=2)) logger.info("Wrote per-example scores to %s", output_path) logger.info("Wrote summary to %s", summary_path) return result # --------------------------------------------------------------------------- # CLI — quick sanity check / standalone use # --------------------------------------------------------------------------- def _cli() -> None: import argparse parser = argparse.ArgumentParser(description="Evaluate bill summarization quality.") subparsers = parser.add_subparsers(dest="command", required=True) # filter subcommand fp = subparsers.add_parser("filter", help="Filter a fine-tuning JSONL dataset") fp.add_argument("--input", required=True, type=Path) fp.add_argument("--output", required=True, type=Path) fp.add_argument("--threshold", type=float, default=FILTER_THRESHOLD) # eval subcommand ep = subparsers.add_parser("eval", help="Score predictions against references") ep.add_argument("--predictions", required=True, type=Path) ep.add_argument("--references", required=True, type=Path) ep.add_argument("--output", type=Path, default=None) args = parser.parse_args() logging.basicConfig(level="INFO", format="%(asctime)s %(levelname)s: %(message)s") if args.command == "filter": kept, dropped = filter_dataset(args.input, args.output, threshold=args.threshold) print(f"Kept: {kept} Dropped: {dropped}") elif args.command == "eval": result = evaluate_file(args.predictions, args.references, args.output) print(f"\nResults ({result.n_scored} scored, {result.n_structural_fail} structural fails):") print(f" ROUGE-1: {result.rouge1_mean:.4f}") print(f" ROUGE-2: {result.rouge2_mean:.4f}") print(f" ROUGE-L: {result.rougeL_mean:.4f}") print(f" Composite: {result.composite_mean:.4f}") if __name__ == "__main__": _cli()