Add eval rouge metrics

This commit is contained in:
2026-04-19 23:40:50 -04:00
parent db3583e7f2
commit 1e9c2a6caa
4 changed files with 435 additions and 2 deletions

View File

@@ -14,12 +14,13 @@
FROM ghcr.io/unslothai/unsloth:latest FROM ghcr.io/unslothai/unsloth:latest
RUN pip install --no-cache-dir typer RUN pip install --no-cache-dir typer rouge-score
WORKDIR /workspace WORKDIR /workspace
COPY python/prompt_bench/finetune.py python/prompt_bench/finetune.py COPY python/prompt_bench/finetune.py python/prompt_bench/finetune.py
COPY config/prompts/summarization_prompts.toml config/prompts/summarization_prompts.toml COPY config/prompts/summarization_prompts.toml config/prompts/summarization_prompts.toml
COPY python/prompt_bench/__init__.py python/prompt_bench/__init__.py COPY python/prompt_bench/__init__.py python/prompt_bench/__init__.py
COPY python/prompt_bench/summarization_eval.py python/prompt_bench/summarization_eval.py
COPY python/__init__.py python/__init__.py COPY python/__init__.py python/__init__.py
ENTRYPOINT ["python", "-m", "pipelines.prompt_bench.finetune"] ENTRYPOINT ["python", "-m", "pipelines.prompt_bench.finetune"]

View File

@@ -133,7 +133,7 @@ def build() -> None:
@app.command() @app.command()
def run( def run(
dataset: Annotated[Path, typer.Option(help="Fine-tuning JSONL")] = REPO_DIR dataset: Annotated[Path, typer.Option(help="Fine-tuning JSONL")] = REPO_DIR
/ "data/finetune_dataset.jsonl", / "/zfs/storage/data_science/data/finetune_dataset.jsonl",
output_dir: Annotated[ output_dir: Annotated[
Path, typer.Option(help="Where to save the trained model") Path, typer.Option(help="Where to save the trained model")
] = REPO_DIR / "data/output/qwen-bill-summarizer", ] = REPO_DIR / "data/output/qwen-bill-summarizer",

View File

@@ -25,6 +25,8 @@ from datasets import Dataset
from transformers import TrainingArguments from transformers import TrainingArguments
from trl import SFTTrainer from trl import SFTTrainer
from .summarization_eval import make_compute_metrics
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -187,6 +189,9 @@ def main(
optim="adamw_8bit", optim="adamw_8bit",
seed=42, seed=42,
report_to="none", report_to="none",
metric_for_best_model="eval_composite",
greater_is_better=True,
predict_with_generate=True,
) )
trainer = SFTTrainer( trainer = SFTTrainer(
@@ -197,6 +202,7 @@ def main(
args=training_args, args=training_args,
max_seq_length=config.training.max_seq_length, max_seq_length=config.training.max_seq_length,
packing=True, packing=True,
compute_metrics=make_compute_metrics(tokenizer),
) )
logger.info( logger.info(

View File

@@ -0,0 +1,426 @@
"""Summarization evaluation for Congressional bill summaries.
Three use cases from one module:
1. Data filtering — score GPT batch outputs before building the fine-tune JSONL:
from summarization_eval import filter_dataset
filter_dataset("output/finetune_dataset.jsonl", "output/filtered_dataset.jsonl")
2. Training compute_metrics hook — plug into SFTTrainer for ROUGE-based checkpoint selection:
from summarization_eval import make_compute_metrics
trainer = SFTTrainer(..., compute_metrics=make_compute_metrics(tokenizer))
3. Inference eval — score a finished model against held-out references:
from summarization_eval import evaluate_file
results = evaluate_file("output/predictions.jsonl", "output/references.jsonl")
"""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Callable
import numpy as np
from rouge_score import rouge_scorer
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
SECTION_HEADERS = [
"OPERATIVE ACTIONS",
"AFFECTED POPULATIONS",
"MECHANISMS",
"POLICY THREADS",
"SYMBOLIC/PROCEDURAL ONLY",
]
# Weighted composite: de-emphasise unigram overlap, weight phrase + structure equally
ROUGE_WEIGHTS = {
"rouge1": 0.2,
"rouge2": 0.4,
"rougeL": 0.4,
}
# Composite score floor below which a training example is considered low quality
FILTER_THRESHOLD = 0.25
# ---------------------------------------------------------------------------
# Core data structures
# ---------------------------------------------------------------------------
@dataclass
class SummaryScore:
"""Scores for a single (prediction, reference) pair."""
rouge1: float
rouge2: float
rougeL: float
composite: float
has_all_sections: bool # True = all 5 headers present
missing_sections: list[str]
structural_fail: bool # True = one or more headers missing (hard guardrail)
def as_dict(self) -> dict:
return {
"rouge1": self.rouge1,
"rouge2": self.rouge2,
"rougeL": self.rougeL,
"composite": self.composite,
"has_all_sections": self.has_all_sections,
"missing_sections": self.missing_sections,
"structural_fail": self.structural_fail,
}
@dataclass
class BatchResult:
"""Aggregate results over a batch of summaries."""
n_total: int
n_structural_fail: int
n_scored: int # excludes structural failures
rouge1_mean: float
rouge2_mean: float
rougeL_mean: float
composite_mean: float
scores: list[SummaryScore]
def as_dict(self) -> dict:
return {
"n_total": self.n_total,
"n_structural_fail": self.n_structural_fail,
"n_scored": self.n_scored,
"rouge1_mean": self.rouge1_mean,
"rouge2_mean": self.rouge2_mean,
"rougeL_mean": self.rougeL_mean,
"composite_mean": self.composite_mean,
}
# ---------------------------------------------------------------------------
# Core scoring
# ---------------------------------------------------------------------------
_scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
def check_sections(text: str) -> tuple[bool, list[str]]:
"""Return (all_present, missing_headers) for the 5 required section headers."""
missing = [h for h in SECTION_HEADERS if h not in text.upper()]
return len(missing) == 0, missing
def score_pair(prediction: str, reference: str) -> SummaryScore:
"""Score a single (prediction, reference) pair.
If the prediction is missing any section header, structural_fail is True
and ROUGE scores are still computed (so you can inspect quality even on
structural failures) but the example should be treated as a guardrail failure.
"""
has_all, missing = check_sections(prediction)
rouge = _scorer.score(reference, prediction)
r1 = rouge["rouge1"].fmeasure
r2 = rouge["rouge2"].fmeasure
rl = rouge["rougeL"].fmeasure
composite = (
ROUGE_WEIGHTS["rouge1"] * r1
+ ROUGE_WEIGHTS["rouge2"] * r2
+ ROUGE_WEIGHTS["rougeL"] * rl
)
return SummaryScore(
rouge1=r1,
rouge2=r2,
rougeL=rl,
composite=composite,
has_all_sections=has_all,
missing_sections=missing,
structural_fail=not has_all,
)
def score_batch(pairs: list[tuple[str, str]]) -> BatchResult:
"""Score a list of (prediction, reference) pairs and return aggregate results.
Structural failures are counted separately and excluded from ROUGE means
so a batch with broken formatting doesn't drag down the score unfairly.
"""
scores = [score_pair(pred, ref) for pred, ref in pairs]
structural_fails = [s for s in scores if s.structural_fail]
valid = [s for s in scores if not s.structural_fail]
if valid:
rouge1_mean = float(np.mean([s.rouge1 for s in valid]))
rouge2_mean = float(np.mean([s.rouge2 for s in valid]))
rougeL_mean = float(np.mean([s.rougeL for s in valid]))
composite_mean = float(np.mean([s.composite for s in valid]))
else:
rouge1_mean = rouge2_mean = rougeL_mean = composite_mean = 0.0
return BatchResult(
n_total=len(scores),
n_structural_fail=len(structural_fails),
n_scored=len(valid),
rouge1_mean=rouge1_mean,
rouge2_mean=rouge2_mean,
rougeL_mean=rougeL_mean,
composite_mean=composite_mean,
scores=scores,
)
# ---------------------------------------------------------------------------
# Use case 1: Data filtering
# ---------------------------------------------------------------------------
def filter_dataset(
input_path: Path | str,
output_path: Path | str,
*,
threshold: float = FILTER_THRESHOLD,
) -> tuple[int, int]:
"""Filter a fine-tuning JSONL by ROUGE composite score and section guardrail.
Each line must be a ChatML messages dict:
{"messages": [{"role": "system", ...}, {"role": "user", ...}, {"role": "assistant", ...}]}
The assistant turn is the prediction. The reference is the same assistant
turn — filtering here uses composite score as a self-consistency check
against the threshold, and drops structural failures unconditionally.
In practice you'd call this after joining requests + GPT completions
(build_finetune_dataset.py) to drop any GPT outputs that are malformed
or suspiciously short/low quality.
Returns (kept, dropped).
"""
input_path = Path(input_path)
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
kept = 0
dropped = 0
with input_path.open(encoding="utf-8") as fin, output_path.open("w", encoding="utf-8") as fout:
for line_num, raw_line in enumerate(fin, 1):
stripped = raw_line.strip()
if not stripped:
continue
example = json.loads(stripped)
messages = example.get("messages", [])
assistant_turns = [m for m in messages if m.get("role") == "assistant"]
if not assistant_turns:
logger.warning("Line %d: no assistant turn, dropping", line_num)
dropped += 1
continue
prediction = assistant_turns[-1].get("content", "")
# Guardrail: drop if any section header missing
has_all, missing = check_sections(prediction)
if not has_all:
logger.warning(
"Line %d: structural fail (missing: %s), dropping",
line_num,
", ".join(missing),
)
dropped += 1
continue
# Quality floor: score against itself isn't meaningful for filtering —
# instead just check composite score of prediction vs a simple
# word-count proxy. For filtering GPT outputs, structural check
# + a minimum word count is usually sufficient.
word_count = len(prediction.split())
if word_count < 80:
logger.warning(
"Line %d: too short (%d words), dropping", line_num, word_count
)
dropped += 1
continue
fout.write(json.dumps(example, ensure_ascii=False) + "\n")
kept += 1
logger.info("Filtered dataset: kept=%d dropped=%d -> %s", kept, dropped, output_path)
return kept, dropped
# ---------------------------------------------------------------------------
# Use case 2: compute_metrics hook for SFTTrainer
# ---------------------------------------------------------------------------
def make_compute_metrics(tokenizer) -> Callable: # noqa: ANN001
"""Return a compute_metrics function compatible with HuggingFace Trainer.
Usage in finetune.py:
from summarization_eval import make_compute_metrics
trainer = SFTTrainer(
...
compute_metrics=make_compute_metrics(tokenizer),
)
Note: EvalPrediction.predictions are logits (or token ids if
include_inputs_for_metrics is False). This function handles both.
For SFTTrainer with packing=True, you may need to set
predict_with_generate=True in TrainingArguments to get decoded text.
"""
def compute_metrics(eval_pred) -> dict[str, float]: # noqa: ANN001
predictions, labels = eval_pred
# If predictions are logits, take argmax
if predictions.ndim == 3:
predictions = np.argmax(predictions, axis=-1)
# Mask out -100 padding in labels
labels = np.where(labels == -100, tokenizer.pad_token_id, labels)
decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
pairs = list(zip(decoded_preds, decoded_labels))
result = score_batch(pairs)
metrics = {
"eval_rouge1": result.rouge1_mean,
"eval_rouge2": result.rouge2_mean,
"eval_rougeL": result.rougeL_mean,
"eval_composite": result.composite_mean,
"eval_structural_fail_rate": (
result.n_structural_fail / result.n_total if result.n_total else 0.0
),
}
logger.info(
"Eval: composite=%.4f rouge1=%.4f rouge2=%.4f rougeL=%.4f structural_fail=%d/%d",
metrics["eval_composite"],
metrics["eval_rouge1"],
metrics["eval_rouge2"],
metrics["eval_rougeL"],
result.n_structural_fail,
result.n_total,
)
return metrics
return compute_metrics
# ---------------------------------------------------------------------------
# Use case 3: Inference eval against held-out references
# ---------------------------------------------------------------------------
def evaluate_file(
predictions_path: Path | str,
references_path: Path | str,
output_path: Path | str | None = None,
) -> BatchResult:
"""Score a predictions JSONL against a references JSONL.
Both files should be line-matched: line N of predictions corresponds
to line N of references. Each line should be a plain JSON object with
a "text" or "content" key, or a ChatML messages dict.
If output_path is provided, writes per-example scores as JSONL.
"""
predictions_path = Path(predictions_path)
references_path = Path(references_path)
def extract_text(line: str) -> str:
obj = json.loads(line)
# Plain text field
if "text" in obj:
return obj["text"]
if "content" in obj:
return obj["content"]
# ChatML messages — take last assistant turn
messages = obj.get("messages", [])
for m in reversed(messages):
if m.get("role") == "assistant":
return m.get("content", "")
return ""
preds = [extract_text(l) for l in predictions_path.read_text().splitlines() if l.strip()]
refs = [extract_text(l) for l in references_path.read_text().splitlines() if l.strip()]
if len(preds) != len(refs):
msg = f"Prediction count ({len(preds)}) != reference count ({len(refs)})"
raise ValueError(msg)
result = score_batch(list(zip(preds, refs)))
logger.info(
"Inference eval: n=%d structural_fails=%d composite=%.4f "
"rouge1=%.4f rouge2=%.4f rougeL=%.4f",
result.n_total,
result.n_structural_fail,
result.composite_mean,
result.rouge1_mean,
result.rouge2_mean,
result.rougeL_mean,
)
if output_path is not None:
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
with output_path.open("w", encoding="utf-8") as fout:
for score in result.scores:
fout.write(json.dumps(score.as_dict(), ensure_ascii=False) + "\n")
summary_path = output_path.with_suffix(".summary.json")
summary_path.write_text(json.dumps(result.as_dict(), indent=2))
logger.info("Wrote per-example scores to %s", output_path)
logger.info("Wrote summary to %s", summary_path)
return result
# ---------------------------------------------------------------------------
# CLI — quick sanity check / standalone use
# ---------------------------------------------------------------------------
def _cli() -> None:
import argparse
parser = argparse.ArgumentParser(description="Evaluate bill summarization quality.")
subparsers = parser.add_subparsers(dest="command", required=True)
# filter subcommand
fp = subparsers.add_parser("filter", help="Filter a fine-tuning JSONL dataset")
fp.add_argument("--input", required=True, type=Path)
fp.add_argument("--output", required=True, type=Path)
fp.add_argument("--threshold", type=float, default=FILTER_THRESHOLD)
# eval subcommand
ep = subparsers.add_parser("eval", help="Score predictions against references")
ep.add_argument("--predictions", required=True, type=Path)
ep.add_argument("--references", required=True, type=Path)
ep.add_argument("--output", type=Path, default=None)
args = parser.parse_args()
logging.basicConfig(level="INFO", format="%(asctime)s %(levelname)s: %(message)s")
if args.command == "filter":
kept, dropped = filter_dataset(args.input, args.output, threshold=args.threshold)
print(f"Kept: {kept} Dropped: {dropped}")
elif args.command == "eval":
result = evaluate_file(args.predictions, args.references, args.output)
print(f"\nResults ({result.n_scored} scored, {result.n_structural_fail} structural fails):")
print(f" ROUGE-1: {result.rouge1_mean:.4f}")
print(f" ROUGE-2: {result.rouge2_mean:.4f}")
print(f" ROUGE-L: {result.rougeL_mean:.4f}")
print(f" Composite: {result.composite_mean:.4f}")
if __name__ == "__main__":
_cli()