|
|
|
|
|
|
|
|
|
"""Summarization evaluation for Congressional bill summaries.
|
|
|
|
|
|
|
|
|
|
Three use cases from one module:
|
|
|
|
|
|
|
|
|
|
1. Data filtering — score GPT batch outputs before building the fine-tune JSONL:
|
|
|
|
|
from summarization_eval import filter_dataset
|
|
|
|
|
filter_dataset("output/finetune_dataset.jsonl", "output/filtered_dataset.jsonl")
|
|
|
|
|
|
|
|
|
|
2. Training compute_metrics hook — plug into SFTTrainer for ROUGE-based checkpoint selection:
|
|
|
|
|
from summarization_eval import make_compute_metrics
|
|
|
|
|
trainer = SFTTrainer(..., compute_metrics=make_compute_metrics(tokenizer))
|
|
|
|
|
|
|
|
|
|
3. Inference eval — score a finished model against held-out references:
|
|
|
|
|
from summarization_eval import evaluate_file
|
|
|
|
|
results = evaluate_file("output/predictions.jsonl", "output/references.jsonl")
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
import json
|
|
|
|
|
import logging
|
|
|
|
|
from dataclasses import dataclass
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
from typing import Callable
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
from rouge_score import rouge_scorer
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
# Constants
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
# The five section headers every generated summary must contain.
# Stored uppercase; check_sections() compares against text.upper(), so the
# check is effectively case-insensitive.
SECTION_HEADERS = [
    "OPERATIVE ACTIONS",
    "AFFECTED POPULATIONS",
    "MECHANISMS",
    "POLICY THREADS",
    "SYMBOLIC/PROCEDURAL ONLY",
]

# Weighted composite: de-emphasise unigram overlap, weight phrase + structure equally.
# Weights sum to 1.0 so the composite stays on the same [0, 1] scale as ROUGE.
ROUGE_WEIGHTS = {
    "rouge1": 0.2,
    "rouge2": 0.4,
    "rougeL": 0.4,
}

# Composite score floor below which a training example is considered low quality
FILTER_THRESHOLD = 0.25
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
# Core data structures
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
@dataclass
class SummaryScore:
    """Per-example metrics for one (prediction, reference) pair."""

    rouge1: float
    rouge2: float
    rougeL: float
    composite: float
    has_all_sections: bool  # True = all 5 headers present
    missing_sections: list[str]
    structural_fail: bool  # True = one or more headers missing (hard guardrail)

    def as_dict(self) -> dict:
        """Serialise every field into a plain, JSON-friendly dict."""
        # __dataclass_fields__ preserves declaration order, so the resulting
        # dict has the same key order as an explicit literal would.
        return {name: getattr(self, name) for name in self.__dataclass_fields__}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class BatchResult:
    """Aggregate metrics computed over a batch of summaries."""

    n_total: int
    n_structural_fail: int
    n_scored: int  # excludes structural failures
    rouge1_mean: float
    rouge2_mean: float
    rougeL_mean: float
    composite_mean: float
    scores: list[SummaryScore]

    def as_dict(self) -> dict:
        """Return the aggregate counters and means as a plain dict.

        The per-example ``scores`` list is deliberately omitted to keep the
        serialised summary compact.
        """
        summary_fields = (
            "n_total",
            "n_structural_fail",
            "n_scored",
            "rouge1_mean",
            "rouge2_mean",
            "rougeL_mean",
            "composite_mean",
        )
        return {name: getattr(self, name) for name in summary_fields}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
# Core scoring
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
# Shared scorer instance, constructed once at import time and reused by every
# score_pair() call. Stemming makes "amends"/"amend" count as overlap.
_scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_sections(text: str) -> tuple[bool, list[str]]:
    """Check *text* for the 5 required section headers.

    Returns a tuple ``(all_present, missing_headers)`` where
    ``missing_headers`` lists, in canonical order, every header absent
    from the upper-cased text.
    """
    upper = text.upper()  # hoisted so we upper-case once, not per header
    missing: list[str] = []
    for header in SECTION_HEADERS:
        if header not in upper:
            missing.append(header)
    return not missing, missing
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def score_pair(prediction: str, reference: str) -> SummaryScore:
    """Compute structural and ROUGE scores for one (prediction, reference) pair.

    A prediction missing any required section header is flagged with
    structural_fail=True; its ROUGE scores are still computed so quality can
    be inspected even for guardrail failures, but callers should treat the
    example as failed.
    """
    has_all, missing = check_sections(prediction)

    measures = _scorer.score(reference, prediction)
    fscores = {key: measures[key].fmeasure for key in ("rouge1", "rouge2", "rougeL")}

    # Weighted sum over ROUGE_WEIGHTS in insertion order (rouge1, rouge2,
    # rougeL) — same accumulation order as an explicit three-term sum.
    composite = sum(ROUGE_WEIGHTS[key] * fscores[key] for key in ROUGE_WEIGHTS)

    return SummaryScore(
        rouge1=fscores["rouge1"],
        rouge2=fscores["rouge2"],
        rougeL=fscores["rougeL"],
        composite=composite,
        has_all_sections=has_all,
        missing_sections=missing,
        structural_fail=not has_all,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def score_batch(pairs: list[tuple[str, str]]) -> BatchResult:
    """Score (prediction, reference) pairs and aggregate the results.

    Examples that fail the structural guardrail are tallied separately and
    excluded from the ROUGE means, so a batch with broken formatting does
    not unfairly drag down the quality numbers.
    """
    scores = [score_pair(prediction, reference) for prediction, reference in pairs]

    valid = [s for s in scores if not s.structural_fail]
    n_failed = len(scores) - len(valid)

    metric_names = ("rouge1", "rouge2", "rougeL", "composite")
    if valid:
        means = {
            name: float(np.mean([getattr(s, name) for s in valid]))
            for name in metric_names
        }
    else:
        # Nothing survived the guardrail — report zeros rather than NaN.
        means = dict.fromkeys(metric_names, 0.0)

    return BatchResult(
        n_total=len(scores),
        n_structural_fail=n_failed,
        n_scored=len(valid),
        rouge1_mean=means["rouge1"],
        rouge2_mean=means["rouge2"],
        rougeL_mean=means["rougeL"],
        composite_mean=means["composite"],
        scores=scores,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
# Use case 1: Data filtering
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
def filter_dataset(
    input_path: Path | str,
    output_path: Path | str,
    *,
    threshold: float = FILTER_THRESHOLD,
    min_words: int = 80,
) -> tuple[int, int]:
    """Filter a fine-tuning JSONL by the section guardrail and a length floor.

    Each line must be a ChatML messages dict:
        {"messages": [{"role": "system", ...}, {"role": "user", ...}, {"role": "assistant", ...}]}

    The last assistant turn is treated as the prediction. An example is
    dropped when:
      * the line is not valid JSON,
      * it has no assistant turn,
      * any required section header is missing (hard structural guardrail), or
      * the prediction is shorter than ``min_words`` words.

    NOTE(review): ``threshold`` is accepted for interface compatibility but is
    not currently applied — there is no independent reference text to
    ROUGE-score the prediction against at this stage, so the structural check
    plus the word-count floor stand in for a quality floor.

    In practice you'd call this after joining requests + GPT completions
    (build_finetune_dataset.py) to drop any GPT outputs that are malformed
    or suspiciously short/low quality.

    Args:
        input_path: Source JSONL of ChatML examples.
        output_path: Destination JSONL; parent directories are created.
        threshold: Reserved composite-score floor (currently unused, see note).
        min_words: Minimum whitespace-token count for the assistant turn.

    Returns:
        (kept, dropped) counts.
    """
    input_path = Path(input_path)
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    kept = 0
    dropped = 0

    with input_path.open(encoding="utf-8") as fin, output_path.open("w", encoding="utf-8") as fout:
        for line_num, raw_line in enumerate(fin, 1):
            stripped = raw_line.strip()
            if not stripped:
                continue

            # Robustness: one malformed batch-API line shouldn't kill the
            # whole filtering run — log it and drop it like any bad example.
            try:
                example = json.loads(stripped)
            except json.JSONDecodeError:
                logger.warning("Line %d: invalid JSON, dropping", line_num)
                dropped += 1
                continue

            messages = example.get("messages", [])
            assistant_turns = [m for m in messages if m.get("role") == "assistant"]

            if not assistant_turns:
                logger.warning("Line %d: no assistant turn, dropping", line_num)
                dropped += 1
                continue

            prediction = assistant_turns[-1].get("content", "")

            # Guardrail: drop if any section header missing
            has_all, missing = check_sections(prediction)
            if not has_all:
                logger.warning(
                    "Line %d: structural fail (missing: %s), dropping",
                    line_num,
                    ", ".join(missing),
                )
                dropped += 1
                continue

            # Length floor: structurally valid but very short outputs are
            # usually truncated or degenerate completions.
            word_count = len(prediction.split())
            if word_count < min_words:
                logger.warning(
                    "Line %d: too short (%d words), dropping", line_num, word_count
                )
                dropped += 1
                continue

            fout.write(json.dumps(example, ensure_ascii=False) + "\n")
            kept += 1

    logger.info("Filtered dataset: kept=%d dropped=%d -> %s", kept, dropped, output_path)
    return kept, dropped
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
# Use case 2: compute_metrics hook for SFTTrainer
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
def make_compute_metrics(tokenizer) -> Callable:  # noqa: ANN001
    """Return a compute_metrics function compatible with HuggingFace Trainer.

    Usage in finetune.py:
        from summarization_eval import make_compute_metrics
        trainer = SFTTrainer(
            ...
            compute_metrics=make_compute_metrics(tokenizer),
        )

    Note: EvalPrediction.predictions are logits (or token ids if
    include_inputs_for_metrics is False). This function handles both.
    For SFTTrainer with packing=True, you may need to set
    predict_with_generate=True in TrainingArguments to get decoded text.

    The returned closure decodes predictions and labels with *tokenizer*,
    scores them via score_batch, and returns a flat dict of float metrics
    (keys prefixed "eval_") suitable for checkpoint selection.
    """

    def compute_metrics(eval_pred) -> dict[str, float]:  # noqa: ANN001
        # EvalPrediction unpacks as (predictions, labels).
        predictions, labels = eval_pred

        # If predictions are logits, take argmax
        # NOTE(review): assumes `predictions` is a single ndarray — some model
        # configs yield a tuple of arrays here; confirm against the trainer setup.
        # NOTE(review): for causal LMs the logits at position i predict token
        # i+1, so an un-shifted argmax is slightly misaligned with labels —
        # verify whether upstream preprocessing accounts for this.
        if predictions.ndim == 3:
            predictions = np.argmax(predictions, axis=-1)

        # Mask out -100 padding in labels (HF's ignore_index) so batch_decode
        # doesn't choke on negative ids.
        labels = np.where(labels == -100, tokenizer.pad_token_id, labels)

        decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # Reuse the module's shared scoring path for consistency with the
        # standalone eval.
        pairs = list(zip(decoded_preds, decoded_labels))
        result = score_batch(pairs)

        metrics = {
            "eval_rouge1": result.rouge1_mean,
            "eval_rouge2": result.rouge2_mean,
            "eval_rougeL": result.rougeL_mean,
            "eval_composite": result.composite_mean,
            "eval_structural_fail_rate": (
                result.n_structural_fail / result.n_total if result.n_total else 0.0
            ),
        }
        logger.info(
            "Eval: composite=%.4f rouge1=%.4f rouge2=%.4f rougeL=%.4f structural_fail=%d/%d",
            metrics["eval_composite"],
            metrics["eval_rouge1"],
            metrics["eval_rouge2"],
            metrics["eval_rougeL"],
            result.n_structural_fail,
            result.n_total,
        )
        return metrics

    return compute_metrics
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
# Use case 3: Inference eval against held-out references
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
def evaluate_file(
    predictions_path: Path | str,
    references_path: Path | str,
    output_path: Path | str | None = None,
) -> BatchResult:
    """Score a predictions JSONL against a references JSONL.

    Both files should be line-matched: line N of predictions corresponds
    to line N of references. Each line should be a plain JSON object with
    a "text" or "content" key, or a ChatML messages dict.

    If output_path is provided, writes per-example scores as JSONL plus an
    aggregate summary next to it (same stem, ``.summary.json`` suffix).

    Raises:
        ValueError: if the two files contain different numbers of records.
    """
    predictions_path = Path(predictions_path)
    references_path = Path(references_path)

    def extract_text(line: str) -> str:
        """Pull the summary text out of one JSONL record."""
        obj = json.loads(line)
        # Plain text field
        if "text" in obj:
            return obj["text"]
        if "content" in obj:
            return obj["content"]
        # ChatML messages — take last assistant turn
        messages = obj.get("messages", [])
        for m in reversed(messages):
            if m.get("role") == "assistant":
                return m.get("content", "")
        return ""

    # Fix: read explicitly as UTF-8 so behaviour doesn't depend on the
    # platform's locale encoding (matches every other I/O call in this module).
    preds = [
        extract_text(line)
        for line in predictions_path.read_text(encoding="utf-8").splitlines()
        if line.strip()
    ]
    refs = [
        extract_text(line)
        for line in references_path.read_text(encoding="utf-8").splitlines()
        if line.strip()
    ]

    if len(preds) != len(refs):
        msg = f"Prediction count ({len(preds)}) != reference count ({len(refs)})"
        raise ValueError(msg)

    result = score_batch(list(zip(preds, refs)))

    logger.info(
        "Inference eval: n=%d structural_fails=%d composite=%.4f "
        "rouge1=%.4f rouge2=%.4f rougeL=%.4f",
        result.n_total,
        result.n_structural_fail,
        result.composite_mean,
        result.rouge1_mean,
        result.rouge2_mean,
        result.rougeL_mean,
    )

    if output_path is not None:
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with output_path.open("w", encoding="utf-8") as fout:
            for score in result.scores:
                fout.write(json.dumps(score.as_dict(), ensure_ascii=False) + "\n")
        summary_path = output_path.with_suffix(".summary.json")
        # Fix: explicit encoding here too (write_text defaults to the locale).
        summary_path.write_text(json.dumps(result.as_dict(), indent=2), encoding="utf-8")
        logger.info("Wrote per-example scores to %s", output_path)
        logger.info("Wrote summary to %s", summary_path)

    return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
# CLI — quick sanity check / standalone use
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
def _cli() -> None:
    """Command-line entry point exposing ``filter`` and ``eval`` subcommands."""
    import argparse

    parser = argparse.ArgumentParser(description="Evaluate bill summarization quality.")
    subparsers = parser.add_subparsers(dest="command", required=True)

    # "filter" — use case 1: prune a fine-tuning dataset in place.
    filter_parser = subparsers.add_parser("filter", help="Filter a fine-tuning JSONL dataset")
    filter_parser.add_argument("--input", required=True, type=Path)
    filter_parser.add_argument("--output", required=True, type=Path)
    filter_parser.add_argument("--threshold", type=float, default=FILTER_THRESHOLD)

    # "eval" — use case 3: score finished model outputs against references.
    eval_parser = subparsers.add_parser("eval", help="Score predictions against references")
    eval_parser.add_argument("--predictions", required=True, type=Path)
    eval_parser.add_argument("--references", required=True, type=Path)
    eval_parser.add_argument("--output", type=Path, default=None)

    args = parser.parse_args()
    logging.basicConfig(level="INFO", format="%(asctime)s %(levelname)s: %(message)s")

    if args.command == "filter":
        kept, dropped = filter_dataset(args.input, args.output, threshold=args.threshold)
        print(f"Kept: {kept} Dropped: {dropped}")

    elif args.command == "eval":
        result = evaluate_file(args.predictions, args.references, args.output)
        print(f"\nResults ({result.n_scored} scored, {result.n_structural_fail} structural fails):")
        print(f" ROUGE-1: {result.rouge1_mean:.4f}")
        print(f" ROUGE-2: {result.rouge2_mean:.4f}")
        print(f" ROUGE-L: {result.rougeL_mean:.4f}")
        print(f" Composite: {result.composite_mean:.4f}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Standalone use: ``python summarization_eval.py filter|eval ...``
if __name__ == "__main__":
    _cli()
|