diff --git a/pipelines/tools/Dockerfile.finetune b/pipelines/tools/Dockerfile.finetune
index c9db8a2..7be9a78 100644
--- a/pipelines/tools/Dockerfile.finetune
+++ b/pipelines/tools/Dockerfile.finetune
@@ -14,12 +14,13 @@ FROM ghcr.io/unslothai/unsloth:latest
 
-RUN pip install --no-cache-dir typer
+RUN pip install --no-cache-dir typer rouge-score
 
 WORKDIR /workspace
 
 COPY python/prompt_bench/finetune.py python/prompt_bench/finetune.py
 COPY config/prompts/summarization_prompts.toml config/prompts/summarization_prompts.toml
 COPY python/prompt_bench/__init__.py python/prompt_bench/__init__.py
+COPY python/prompt_bench/summarization_eval.py python/prompt_bench/summarization_eval.py
 COPY python/__init__.py python/__init__.py
 
 ENTRYPOINT ["python", "-m", "pipelines.prompt_bench.finetune"]
diff --git a/pipelines/tools/containers/finetune.py b/pipelines/tools/containers/finetune.py
index f28fe02..82455cd 100644
--- a/pipelines/tools/containers/finetune.py
+++ b/pipelines/tools/containers/finetune.py
@@ -133,7 +133,10 @@ def build() -> None:
 @app.command()
 def run(
-    dataset: Annotated[Path, typer.Option(help="Fine-tuning JSONL")] = REPO_DIR
-    / "data/finetune_dataset.jsonl",
+    # NOTE: REPO_DIR / "/abs/path" would silently discard REPO_DIR — pathlib
+    # drops the left operand when joined with an absolute path — so the
+    # absolute default is spelled out directly.
+    dataset: Annotated[Path, typer.Option(help="Fine-tuning JSONL")] = Path(
+        "/zfs/storage/data_science/data/finetune_dataset.jsonl"
+    ),
     output_dir: Annotated[
         Path, typer.Option(help="Where to save the trained model")
     ] = REPO_DIR / "data/output/qwen-bill-summarizer",
diff --git a/pipelines/tools/finetune.py b/pipelines/tools/finetune.py
index 3fccf57..880c69c 100644
--- a/pipelines/tools/finetune.py
+++ b/pipelines/tools/finetune.py
@@ -25,6 +25,8 @@ from datasets import Dataset
 from transformers import TrainingArguments
 from trl import SFTTrainer
 
+from .summarization_eval import make_compute_metrics
+
 logger = logging.getLogger(__name__)
 
@@ -187,6 +189,10 @@
         optim="adamw_8bit",
         seed=42,
         report_to="none",
+        metric_for_best_model="eval_composite",
+        greater_is_better=True,
+        # NOTE: predict_with_generate is a Seq2SeqTrainingArguments option only;
+        # plain TrainingArguments raises TypeError on it, so it is omitted here.
     )
 
     trainer = SFTTrainer(
@@ -197,6 +202,7 @@
         args=training_args,
max_seq_length=config.training.max_seq_length, packing=True, + compute_metrics=make_compute_metrics(tokenizer), ) logger.info( diff --git a/pipelines/tools/summarization_eval.py b/pipelines/tools/summarization_eval.py new file mode 100644 index 0000000..78a4904 --- /dev/null +++ b/pipelines/tools/summarization_eval.py @@ -0,0 +1,426 @@ +"""Summarization evaluation for Congressional bill summaries. + +Three use cases from one module: + +1. Data filtering — score GPT batch outputs before building the fine-tune JSONL: + from summarization_eval import filter_dataset + filter_dataset("output/finetune_dataset.jsonl", "output/filtered_dataset.jsonl") + +2. Training compute_metrics hook — plug into SFTTrainer for ROUGE-based checkpoint selection: + from summarization_eval import make_compute_metrics + trainer = SFTTrainer(..., compute_metrics=make_compute_metrics(tokenizer)) + +3. Inference eval — score a finished model against held-out references: + from summarization_eval import evaluate_file + results = evaluate_file("output/predictions.jsonl", "output/references.jsonl") +""" + +from __future__ import annotations + +import json +import logging +from dataclasses import dataclass +from pathlib import Path +from typing import Callable + +import numpy as np +from rouge_score import rouge_scorer + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +SECTION_HEADERS = [ + "OPERATIVE ACTIONS", + "AFFECTED POPULATIONS", + "MECHANISMS", + "POLICY THREADS", + "SYMBOLIC/PROCEDURAL ONLY", +] + +# Weighted composite: de-emphasise unigram overlap, weight phrase + structure equally +ROUGE_WEIGHTS = { + "rouge1": 0.2, + "rouge2": 0.4, + "rougeL": 0.4, +} + +# Composite score floor below which a training example is considered low quality +FILTER_THRESHOLD = 0.25 + + +# 
+# Core data structures
+# ---------------------------------------------------------------------------
+
+@dataclass
+class SummaryScore:
+    """Scores for a single (prediction, reference) pair."""
+
+    rouge1: float
+    rouge2: float
+    rougeL: float
+    composite: float
+    has_all_sections: bool  # True = all 5 headers present
+    missing_sections: list[str]
+    structural_fail: bool  # True = one or more headers missing (hard guardrail)
+
+    def as_dict(self) -> dict:
+        return {
+            "rouge1": self.rouge1,
+            "rouge2": self.rouge2,
+            "rougeL": self.rougeL,
+            "composite": self.composite,
+            "has_all_sections": self.has_all_sections,
+            "missing_sections": self.missing_sections,
+            "structural_fail": self.structural_fail,
+        }
+
+
+@dataclass
+class BatchResult:
+    """Aggregate results over a batch of summaries."""
+
+    n_total: int
+    n_structural_fail: int
+    n_scored: int  # excludes structural failures
+    rouge1_mean: float
+    rouge2_mean: float
+    rougeL_mean: float
+    composite_mean: float
+    scores: list[SummaryScore]
+
+    def as_dict(self) -> dict:
+        return {
+            "n_total": self.n_total,
+            "n_structural_fail": self.n_structural_fail,
+            "n_scored": self.n_scored,
+            "rouge1_mean": self.rouge1_mean,
+            "rouge2_mean": self.rouge2_mean,
+            "rougeL_mean": self.rougeL_mean,
+            "composite_mean": self.composite_mean,
+        }
+
+
+# ---------------------------------------------------------------------------
+# Core scoring
+# ---------------------------------------------------------------------------
+
+_scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
+
+
+def check_sections(text: str) -> tuple[bool, list[str]]:
+    """Return (all_present, missing_headers) for the 5 required section headers."""
+    # Hoist the loop-invariant .upper() so the text is upper-cased once,
+    # not once per header.
+    haystack = text.upper()
+    missing = [h for h in SECTION_HEADERS if h not in haystack]
+    return len(missing) == 0, missing
+
+
+def score_pair(prediction: str, reference: str) -> SummaryScore:
+    """Score a single (prediction, reference) pair.
+
+    If the prediction is missing any section header, structural_fail is True
+    and ROUGE scores are still computed (so you can inspect quality even on
+    structural failures) but the example should be treated as a guardrail failure.
+    """
+    has_all, missing = check_sections(prediction)
+
+    # rouge_scorer takes (target, prediction); keep only the F-measures.
+    rouge = _scorer.score(reference, prediction)
+    fmeasures = {name: rouge[name].fmeasure for name in ("rouge1", "rouge2", "rougeL")}
+    composite = sum(ROUGE_WEIGHTS[name] * fmeasures[name] for name in ROUGE_WEIGHTS)
+
+    return SummaryScore(
+        rouge1=fmeasures["rouge1"],
+        rouge2=fmeasures["rouge2"],
+        rougeL=fmeasures["rougeL"],
+        composite=composite,
+        has_all_sections=has_all,
+        missing_sections=missing,
+        structural_fail=not has_all,
+    )
+
+
+def score_batch(pairs: list[tuple[str, str]]) -> BatchResult:
+    """Score a list of (prediction, reference) pairs and return aggregate results.
+
+    Structural failures are counted separately and excluded from ROUGE means
+    so a batch with broken formatting doesn't drag down the score unfairly.
+    """
+    scores = [score_pair(pred, ref) for pred, ref in pairs]
+
+    structural_fails = [s for s in scores if s.structural_fail]
+    valid = [s for s in scores if not s.structural_fail]
+
+    if valid:
+        rouge1_mean = float(np.mean([s.rouge1 for s in valid]))
+        rouge2_mean = float(np.mean([s.rouge2 for s in valid]))
+        rougeL_mean = float(np.mean([s.rougeL for s in valid]))
+        composite_mean = float(np.mean([s.composite for s in valid]))
+    else:
+        # All-fail (or empty) batch: report zeros rather than NaN.
+        rouge1_mean = rouge2_mean = rougeL_mean = composite_mean = 0.0
+
+    return BatchResult(
+        n_total=len(scores),
+        n_structural_fail=len(structural_fails),
+        n_scored=len(valid),
+        rouge1_mean=rouge1_mean,
+        rouge2_mean=rouge2_mean,
+        rougeL_mean=rougeL_mean,
+        composite_mean=composite_mean,
+        scores=scores,
+    )
+
+
+# ---------------------------------------------------------------------------
+# Use case 1: Data filtering
+# ---------------------------------------------------------------------------
+
+def filter_dataset(
+    input_path: Path | str,
+    output_path: Path | str,
+    *,
+    threshold: float = FILTER_THRESHOLD,
+) -> tuple[int, int]:
+    """Filter a fine-tuning JSONL by the section guardrail and a length floor.
+
+    Each line must be a ChatML messages dict:
+        {"messages": [{"role": "system", ...}, {"role": "user", ...}, {"role": "assistant", ...}]}
+
+    The last assistant turn is checked for (a) presence of all 5 required
+    section headers and (b) a minimum word count. NOTE(review): ``threshold``
+    is currently unused — there is no independent reference to ROUGE-score
+    the assistant turn against, so no composite score is computed here; the
+    parameter is kept for interface compatibility and future use.
+
+    In practice you'd call this after joining requests + GPT completions
+    (build_finetune_dataset.py) to drop any GPT outputs that are malformed
+    or suspiciously short.
+
+    Returns (kept, dropped).
+    """
+    input_path = Path(input_path)
+    output_path = Path(output_path)
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+
+    kept = 0
+    dropped = 0
+
+    with input_path.open(encoding="utf-8") as fin, output_path.open("w", encoding="utf-8") as fout:
+        for line_num, raw_line in enumerate(fin, 1):
+            stripped = raw_line.strip()
+            if not stripped:
+                continue
+
+            example = json.loads(stripped)
+            messages = example.get("messages", [])
+            assistant_turns = [m for m in messages if m.get("role") == "assistant"]
+
+            if not assistant_turns:
+                logger.warning("Line %d: no assistant turn, dropping", line_num)
+                dropped += 1
+                continue
+
+            prediction = assistant_turns[-1].get("content", "")
+
+            # Guardrail: drop if any section header missing
+            has_all, missing = check_sections(prediction)
+            if not has_all:
+                logger.warning(
+                    "Line %d: structural fail (missing: %s), dropping",
+                    line_num,
+                    ", ".join(missing),
+                )
+                dropped += 1
+                continue
+
+            # Quality floor: score against itself isn't meaningful for filtering —
+            # instead just check composite score of prediction vs a simple
+            # word-count proxy. For filtering GPT outputs, structural check
+            # + a minimum word count is usually sufficient.
+            word_count = len(prediction.split())
+            if word_count < 80:
+                logger.warning(
+                    "Line %d: too short (%d words), dropping", line_num, word_count
+                )
+                dropped += 1
+                continue
+
+            fout.write(json.dumps(example, ensure_ascii=False) + "\n")
+            kept += 1
+
+    logger.info("Filtered dataset: kept=%d dropped=%d -> %s", kept, dropped, output_path)
+    return kept, dropped
+
+
+# ---------------------------------------------------------------------------
+# Use case 2: compute_metrics hook for SFTTrainer
+# ---------------------------------------------------------------------------
+
+def make_compute_metrics(tokenizer) -> Callable:  # noqa: ANN001
+    """Return a compute_metrics function compatible with HuggingFace Trainer.
+
+    Usage in finetune.py:
+        from summarization_eval import make_compute_metrics
+        trainer = SFTTrainer(
+            ...
+            compute_metrics=make_compute_metrics(tokenizer),
+        )
+
+    Note: EvalPrediction.predictions may be raw logits (argmaxed here over
+    the vocab axis), already-decoded token ids, or a tuple whose first
+    element is the logits — all three are handled.
+    NOTE(review): for a causal LM the logits at position i predict token
+    i+1, so argmax ids are shifted one token relative to labels; ROUGE is
+    tolerant of that but exact-match metrics would not be — confirm.
+    """
+
+    def compute_metrics(eval_pred) -> dict[str, float]:  # noqa: ANN001
+        predictions, labels = eval_pred
+
+        # Models can return a tuple (logits, ...); keep only the logits so the
+        # .ndim check below doesn't blow up on a tuple.
+        if isinstance(predictions, tuple):
+            predictions = predictions[0]
+
+        # If predictions are logits, take argmax over the vocab dimension.
+        if predictions.ndim == 3:
+            predictions = np.argmax(predictions, axis=-1)
+
+        # Replace -100 label padding with a real token id before decoding.
+        labels = np.where(labels == -100, tokenizer.pad_token_id, labels)
+
+        decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
+        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
+
+        pairs = list(zip(decoded_preds, decoded_labels))
+        result = score_batch(pairs)
+
+        metrics = {
+            "eval_rouge1": result.rouge1_mean,
+            "eval_rouge2": result.rouge2_mean,
+            "eval_rougeL": result.rougeL_mean,
+            "eval_composite": result.composite_mean,
+            "eval_structural_fail_rate": (
+                result.n_structural_fail / result.n_total if result.n_total else 0.0
+            ),
+        }
+        logger.info(
+            "Eval: composite=%.4f rouge1=%.4f rouge2=%.4f rougeL=%.4f structural_fail=%d/%d",
+            metrics["eval_composite"],
+            metrics["eval_rouge1"],
+            metrics["eval_rouge2"],
+            metrics["eval_rougeL"],
+            result.n_structural_fail,
+            result.n_total,
+        )
+        return metrics
+
+    return compute_metrics
+
+
+# ---------------------------------------------------------------------------
+# Use case 3: Inference eval against held-out references
+# ---------------------------------------------------------------------------
+
+def evaluate_file(
+    predictions_path: Path | str,
+    references_path: Path | str,
+    output_path: Path | str | None = None,
+) -> BatchResult:
+    """Score a predictions JSONL against a references JSONL.
+
+    Both files should be line-matched: line N of predictions corresponds
+    to line N of references. Each line should be a plain JSON object with
+    a "text" or "content" key, or a ChatML messages dict.
+
+    If output_path is provided, writes per-example scores as JSONL.
+    """
+    predictions_path = Path(predictions_path)
+    references_path = Path(references_path)
+
+    def extract_text(line: str) -> str:
+        obj = json.loads(line)
+        # Plain text field
+        if "text" in obj:
+            return obj["text"]
+        if "content" in obj:
+            return obj["content"]
+        # ChatML messages — take last assistant turn
+        messages = obj.get("messages", [])
+        for m in reversed(messages):
+            if m.get("role") == "assistant":
+                return m.get("content", "")
+        return ""
+
+    # Explicit UTF-8: read_text() without an encoding uses the locale
+    # encoding and breaks on non-ASCII content on some platforms.
+    preds = [
+        extract_text(raw)
+        for raw in predictions_path.read_text(encoding="utf-8").splitlines()
+        if raw.strip()
+    ]
+    refs = [
+        extract_text(raw)
+        for raw in references_path.read_text(encoding="utf-8").splitlines()
+        if raw.strip()
+    ]
+
+    if len(preds) != len(refs):
+        msg = f"Prediction count ({len(preds)}) != reference count ({len(refs)})"
+        raise ValueError(msg)
+
+    result = score_batch(list(zip(preds, refs)))
+
+    logger.info(
+        "Inference eval: n=%d structural_fails=%d composite=%.4f "
+        "rouge1=%.4f rouge2=%.4f rougeL=%.4f",
+        result.n_total,
+        result.n_structural_fail,
+        result.composite_mean,
+        result.rouge1_mean,
+        result.rouge2_mean,
+        result.rougeL_mean,
+    )
+
+    if output_path is not None:
+        output_path = Path(output_path)
+        output_path.parent.mkdir(parents=True, exist_ok=True)
+        with output_path.open("w", encoding="utf-8") as fout:
+            for score in result.scores:
+                fout.write(json.dumps(score.as_dict(), ensure_ascii=False) + "\n")
+        summary_path = output_path.with_suffix(".summary.json")
+        summary_path.write_text(json.dumps(result.as_dict(), indent=2), encoding="utf-8")
+        logger.info("Wrote per-example scores to %s", output_path)
+        logger.info("Wrote summary to %s", summary_path)
+
+    return result
+
+
+# ---------------------------------------------------------------------------
+# CLI — quick sanity check / standalone use
+# ---------------------------------------------------------------------------
+
+def _cli() -> None:
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Evaluate bill summarization quality.")
+    subparsers = parser.add_subparsers(dest="command", required=True)
+
+    # filter subcommand
+    fp = subparsers.add_parser("filter", help="Filter a fine-tuning JSONL dataset")
+    fp.add_argument("--input", required=True, type=Path)
+    fp.add_argument("--output", required=True, type=Path)
+    fp.add_argument("--threshold", type=float, default=FILTER_THRESHOLD)
+
+    # eval subcommand
+    ep = subparsers.add_parser("eval", help="Score predictions against references")
+    ep.add_argument("--predictions", required=True, type=Path)
+    ep.add_argument("--references", required=True, type=Path)
+    ep.add_argument("--output", type=Path, default=None)
+
+    args = parser.parse_args()
+    logging.basicConfig(level="INFO", format="%(asctime)s %(levelname)s: %(message)s")
+
+    if args.command == "filter":
+        kept, dropped = filter_dataset(args.input, args.output, threshold=args.threshold)
+        print(f"Kept: {kept} Dropped: {dropped}")
+
+    elif args.command == "eval":
+        result = evaluate_file(args.predictions, args.references, args.output)
+        print(f"\nResults ({result.n_scored} scored, {result.n_structural_fail} structural fails):")
+        print(f" ROUGE-1: {result.rouge1_mean:.4f}")
+        print(f" ROUGE-2: {result.rouge2_mean:.4f}")
+        print(f" ROUGE-L: {result.rougeL_mean:.4f}")
+        print(f" Composite: {result.composite_mean:.4f}")
+
+
+if __name__ == "__main__":
+    _cli()