Compare commits
3 Commits
db3583e7f2
...
matt_ds
| Author | SHA1 | Date | |
|---|---|---|---|
| e0f88c126e | |||
| 716bed5300 | |||
| 1e9c2a6caa |
@@ -1,7 +1,7 @@
|
|||||||
# Unsloth fine-tuning container for Qwen 3.5 4B on RTX 3090.
|
# Unsloth fine-tuning container for Qwen 3.5 4B on RTX 3090.
|
||||||
#
|
#
|
||||||
# Build:
|
# Build:
|
||||||
# docker build -f python/prompt_bench/Dockerfile.finetune -t bill-finetune .
|
# docker build -f pipelines/pipelines/tools/Dockerfile.finetune -t bill-finetune .
|
||||||
#
|
#
|
||||||
# Run:
|
# Run:
|
||||||
# docker run --rm --device=nvidia.com/gpu=all --ipc=host \
|
# docker run --rm --device=nvidia.com/gpu=all --ipc=host \
|
||||||
@@ -14,12 +14,13 @@
|
|||||||
|
|
||||||
FROM ghcr.io/unslothai/unsloth:latest
|
FROM ghcr.io/unslothai/unsloth:latest
|
||||||
|
|
||||||
RUN pip install --no-cache-dir typer
|
RUN pip install --no-cache-dir typer rouge-score
|
||||||
|
|
||||||
WORKDIR /workspace
|
WORKDIR /workspace
|
||||||
COPY python/prompt_bench/finetune.py python/prompt_bench/finetune.py
|
COPY pipelines/tools/__init__.py pipelines/tools/__init__.py
|
||||||
COPY config/prompts/summarization_prompts.toml config/prompts/summarization_prompts.toml
|
COPY pipelines/tools/finetune.py pipelines/tools/finetune.py
|
||||||
COPY python/prompt_bench/__init__.py python/prompt_bench/__init__.py
|
COPY pipelines/tools/summarization_eval.py pipelines/tools/summarization_eval.py
|
||||||
COPY python/__init__.py python/__init__.py
|
COPY summarization_prompts.toml config/prompts/summarization_prompts.toml
|
||||||
|
COPY config.toml pipelines/tools/config.toml
|
||||||
|
|
||||||
ENTRYPOINT ["python", "-m", "pipelines.prompt_bench.finetune"]
|
ENTRYPOINT ["python", "-m", "pipelines.tools.finetune"]
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ def build_image() -> None:
|
|||||||
"docker",
|
"docker",
|
||||||
"build",
|
"build",
|
||||||
"-f",
|
"-f",
|
||||||
str(REPO_DIR / "python/prompt_bench/Dockerfile.finetune"),
|
str(REPO_DIR / "pipelines/pipelines/tools/Dockerfile.finetune"),
|
||||||
"-t",
|
"-t",
|
||||||
FINETUNE_IMAGE,
|
FINETUNE_IMAGE,
|
||||||
".",
|
".",
|
||||||
@@ -133,7 +133,7 @@ def build() -> None:
|
|||||||
@app.command()
|
@app.command()
|
||||||
def run(
|
def run(
|
||||||
dataset: Annotated[Path, typer.Option(help="Fine-tuning JSONL")] = REPO_DIR
|
dataset: Annotated[Path, typer.Option(help="Fine-tuning JSONL")] = REPO_DIR
|
||||||
/ "data/finetune_dataset.jsonl",
|
/ "/zfs/storage/data_science/data/finetune_dataset.jsonl",
|
||||||
output_dir: Annotated[
|
output_dir: Annotated[
|
||||||
Path, typer.Option(help="Where to save the trained model")
|
Path, typer.Option(help="Where to save the trained model")
|
||||||
] = REPO_DIR / "data/output/qwen-bill-summarizer",
|
] = REPO_DIR / "data/output/qwen-bill-summarizer",
|
||||||
|
|||||||
@@ -25,6 +25,8 @@ from datasets import Dataset
|
|||||||
from transformers import TrainingArguments
|
from transformers import TrainingArguments
|
||||||
from trl import SFTTrainer
|
from trl import SFTTrainer
|
||||||
|
|
||||||
|
from .summarization_eval import make_compute_metrics
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -187,6 +189,9 @@ def main(
|
|||||||
optim="adamw_8bit",
|
optim="adamw_8bit",
|
||||||
seed=42,
|
seed=42,
|
||||||
report_to="none",
|
report_to="none",
|
||||||
|
metric_for_best_model="eval_composite",
|
||||||
|
greater_is_better=True,
|
||||||
|
predict_with_generate=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer = SFTTrainer(
|
trainer = SFTTrainer(
|
||||||
@@ -197,6 +202,7 @@ def main(
|
|||||||
args=training_args,
|
args=training_args,
|
||||||
max_seq_length=config.training.max_seq_length,
|
max_seq_length=config.training.max_seq_length,
|
||||||
packing=True,
|
packing=True,
|
||||||
|
compute_metrics=make_compute_metrics(tokenizer),
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|||||||
426
pipelines/tools/summarization_eval.py
Normal file
426
pipelines/tools/summarization_eval.py
Normal file
@@ -0,0 +1,426 @@
|
|||||||
|
"""Summarization evaluation for Congressional bill summaries.
|
||||||
|
|
||||||
|
Three use cases from one module:
|
||||||
|
|
||||||
|
1. Data filtering — score GPT batch outputs before building the fine-tune JSONL:
|
||||||
|
from summarization_eval import filter_dataset
|
||||||
|
filter_dataset("output/finetune_dataset.jsonl", "output/filtered_dataset.jsonl")
|
||||||
|
|
||||||
|
2. Training compute_metrics hook — plug into SFTTrainer for ROUGE-based checkpoint selection:
|
||||||
|
from summarization_eval import make_compute_metrics
|
||||||
|
trainer = SFTTrainer(..., compute_metrics=make_compute_metrics(tokenizer))
|
||||||
|
|
||||||
|
3. Inference eval — score a finished model against held-out references:
|
||||||
|
from summarization_eval import evaluate_file
|
||||||
|
results = evaluate_file("output/predictions.jsonl", "output/references.jsonl")
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from rouge_score import rouge_scorer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# The five section headers every generated bill summary must contain.
# check_sections() matches these against text.upper(), i.e. case-insensitively.
SECTION_HEADERS = [
    "OPERATIVE ACTIONS",
    "AFFECTED POPULATIONS",
    "MECHANISMS",
    "POLICY THREADS",
    "SYMBOLIC/PROCEDURAL ONLY",
]

# Weighted composite: de-emphasise unigram overlap, weight phrase + structure equally.
# Weights sum to 1.0 so the composite stays on the same [0, 1] scale as ROUGE F1.
ROUGE_WEIGHTS = {
    "rouge1": 0.2,
    "rouge2": 0.4,
    "rougeL": 0.4,
}

# Composite score floor below which a training example is considered low quality.
# Default for filter_dataset's threshold parameter and the CLI --threshold flag.
FILTER_THRESHOLD = 0.25
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Core data structures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@dataclass
class SummaryScore:
    """Per-example evaluation result for one (prediction, reference) pair."""

    rouge1: float
    rouge2: float
    rougeL: float
    composite: float
    has_all_sections: bool  # True when every one of the 5 required headers was found
    missing_sections: list[str]
    structural_fail: bool  # hard guardrail: True when one or more headers missing

    def as_dict(self) -> dict:
        """Return every field as a plain dict (field declaration order preserved)."""
        return dict(vars(self))
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class BatchResult:
    """Aggregate evaluation results over a batch of summaries."""

    n_total: int
    n_structural_fail: int
    n_scored: int  # structural failures are excluded from this count
    rouge1_mean: float
    rouge2_mean: float
    rougeL_mean: float
    composite_mean: float
    scores: list[SummaryScore]

    def as_dict(self) -> dict:
        """Return the aggregate numbers only — per-example `scores` are omitted."""
        summary_fields = (
            "n_total",
            "n_structural_fail",
            "n_scored",
            "rouge1_mean",
            "rouge2_mean",
            "rougeL_mean",
            "composite_mean",
        )
        return {name: getattr(self, name) for name in summary_fields}
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
# Core scoring
# ---------------------------------------------------------------------------

# Single module-level scorer built once at import time and reused by every
# score_pair() call. Stemming is enabled so minor inflection differences
# ("amends" vs "amend") still count as overlap.
_scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
|
|
||||||
|
|
||||||
|
def check_sections(text: str) -> tuple[bool, list[str]]:
    """Return (all_present, missing_headers) for the 5 required section headers.

    Matching is case-insensitive: the text is upper-cased once and each
    header (already uppercase in SECTION_HEADERS) is tested by substring.
    """
    haystack = text.upper()
    missing = [header for header in SECTION_HEADERS if header not in haystack]
    return not missing, missing
||||||
|
|
||||||
|
|
||||||
|
def score_pair(prediction: str, reference: str) -> SummaryScore:
    """Score one (prediction, reference) pair.

    ROUGE F-measures are always computed, even when the prediction is missing
    a section header — that keeps low-level quality inspectable on broken
    outputs — but structural_fail marks the example as a guardrail failure
    and callers should treat it accordingly.
    """
    all_present, missing = check_sections(prediction)

    measures = _scorer.score(reference, prediction)
    fscores = {name: measures[name].fmeasure for name in ("rouge1", "rouge2", "rougeL")}
    # Weighted sum over the same three metrics; ROUGE_WEIGHTS iterates in
    # declaration order (rouge1, rouge2, rougeL), matching the explicit sum.
    composite = sum(ROUGE_WEIGHTS[name] * fscores[name] for name in ROUGE_WEIGHTS)

    return SummaryScore(
        rouge1=fscores["rouge1"],
        rouge2=fscores["rouge2"],
        rougeL=fscores["rougeL"],
        composite=composite,
        has_all_sections=all_present,
        missing_sections=missing,
        structural_fail=not all_present,
    )
||||||
|
|
||||||
|
|
||||||
|
def score_batch(pairs: list[tuple[str, str]]) -> BatchResult:
    """Score a list of (prediction, reference) pairs and aggregate the results.

    Structural failures are tallied separately and left out of the ROUGE
    means, so a batch with broken formatting doesn't unfairly drag down the
    quality numbers. An empty (or all-failing) batch reports 0.0 means.
    """
    scores = [score_pair(prediction, reference) for prediction, reference in pairs]
    valid = [s for s in scores if not s.structural_fail]

    def _mean(attr: str) -> float:
        # 0.0 by definition when nothing survived the structural guardrail.
        if not valid:
            return 0.0
        return float(np.mean([getattr(s, attr) for s in valid]))

    return BatchResult(
        n_total=len(scores),
        n_structural_fail=len(scores) - len(valid),
        n_scored=len(valid),
        rouge1_mean=_mean("rouge1"),
        rouge2_mean=_mean("rouge2"),
        rougeL_mean=_mean("rougeL"),
        composite_mean=_mean("composite"),
        scores=scores,
    )
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Use case 1: Data filtering
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def filter_dataset(
    input_path: Path | str,
    output_path: Path | str,
    *,
    threshold: float = FILTER_THRESHOLD,
    min_words: int = 80,
) -> tuple[int, int]:
    """Filter a fine-tuning JSONL by section guardrail and minimum length.

    Each line must be a ChatML messages dict:
        {"messages": [{"role": "system", ...}, {"role": "user", ...}, {"role": "assistant", ...}]}

    The last assistant turn is the candidate summary. A line is dropped when:
      * it is not valid JSON (a single corrupt line must not abort the run),
      * it has no assistant turn,
      * the assistant text is missing any required section header, or
      * the assistant text is shorter than ``min_words`` words.

    NOTE(review): ``threshold`` is accepted for interface stability but is not
    currently applied — the assistant turn is the only text available, and a
    ROUGE score of a text against itself is always 1.0, so there is nothing
    meaningful to compare against. The effective filter is the structural
    check plus the word-count floor; wire ``threshold`` in once independent
    references are available.

    In practice this runs after joining requests + GPT completions
    (build_finetune_dataset.py) to drop malformed or suspiciously short outputs.

    Returns:
        (kept, dropped) line counts.
    """
    input_path = Path(input_path)
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    kept = 0
    dropped = 0

    with input_path.open(encoding="utf-8") as fin, output_path.open("w", encoding="utf-8") as fout:
        for line_num, raw_line in enumerate(fin, 1):
            stripped = raw_line.strip()
            if not stripped:
                continue

            # Robustness: skip (and count) corrupt lines instead of crashing.
            try:
                example = json.loads(stripped)
            except json.JSONDecodeError:
                logger.warning("Line %d: invalid JSON, dropping", line_num)
                dropped += 1
                continue

            messages = example.get("messages", [])
            assistant_turns = [m for m in messages if m.get("role") == "assistant"]

            if not assistant_turns:
                logger.warning("Line %d: no assistant turn, dropping", line_num)
                dropped += 1
                continue

            prediction = assistant_turns[-1].get("content", "")

            # Guardrail: drop if any section header missing
            has_all, missing = check_sections(prediction)
            if not has_all:
                logger.warning(
                    "Line %d: structural fail (missing: %s), dropping",
                    line_num,
                    ", ".join(missing),
                )
                dropped += 1
                continue

            # Quality floor: a structurally valid summary below the word floor
            # is almost certainly truncated or degenerate.
            word_count = len(prediction.split())
            if word_count < min_words:
                logger.warning(
                    "Line %d: too short (%d words), dropping", line_num, word_count
                )
                dropped += 1
                continue

            fout.write(json.dumps(example, ensure_ascii=False) + "\n")
            kept += 1

    logger.info("Filtered dataset: kept=%d dropped=%d -> %s", kept, dropped, output_path)
    return kept, dropped
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Use case 2: compute_metrics hook for SFTTrainer
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def make_compute_metrics(tokenizer) -> Callable:  # noqa: ANN001
    """Build a ``compute_metrics`` callback for a HuggingFace Trainer.

    Usage in finetune.py:
        from summarization_eval import make_compute_metrics
        trainer = SFTTrainer(
            ...
            compute_metrics=make_compute_metrics(tokenizer),
        )

    Args:
        tokenizer: must expose ``batch_decode`` and ``pad_token_id``; used to
            turn predicted/label token ids back into text.

    Returns:
        A function mapping an EvalPrediction (predictions, labels) to a dict
        of float metrics (ROUGE means, composite, structural-fail rate).

    NOTE(review): when predictions arrive as raw LM logits, the argmax below
    yields per-position next-token guesses, which in causal-LM training are
    shifted by one relative to ``labels`` — ROUGE computed this way is an
    approximation. Decoded generations give the faithful metric; note that
    ``predict_with_generate`` is a ``Seq2SeqTrainingArguments`` field, not
    plain ``TrainingArguments`` — confirm the trainer config in finetune.py.
    """

    def compute_metrics(eval_pred) -> dict[str, float]:  # noqa: ANN001
        predictions, labels = eval_pred

        # If predictions are logits (batch, seq, vocab), reduce to token ids.
        if predictions.ndim == 3:
            predictions = np.argmax(predictions, axis=-1)

        # Labels use -100 for masked positions; replace with pad_token_id so
        # batch_decode can handle them (skip_special_tokens then drops them).
        labels = np.where(labels == -100, tokenizer.pad_token_id, labels)

        decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # (prediction, reference) ordering matches score_batch's contract.
        pairs = list(zip(decoded_preds, decoded_labels))
        result = score_batch(pairs)

        # "eval_" prefix makes these selectable via metric_for_best_model.
        metrics = {
            "eval_rouge1": result.rouge1_mean,
            "eval_rouge2": result.rouge2_mean,
            "eval_rougeL": result.rougeL_mean,
            "eval_composite": result.composite_mean,
            "eval_structural_fail_rate": (
                result.n_structural_fail / result.n_total if result.n_total else 0.0
            ),
        }
        logger.info(
            "Eval: composite=%.4f rouge1=%.4f rouge2=%.4f rougeL=%.4f structural_fail=%d/%d",
            metrics["eval_composite"],
            metrics["eval_rouge1"],
            metrics["eval_rouge2"],
            metrics["eval_rougeL"],
            result.n_structural_fail,
            result.n_total,
        )
        return metrics

    return compute_metrics
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Use case 3: Inference eval against held-out references
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def evaluate_file(
    predictions_path: Path | str,
    references_path: Path | str,
    output_path: Path | str | None = None,
) -> BatchResult:
    """Score a predictions JSONL against a references JSONL.

    Both files must be line-matched: line N of predictions corresponds to
    line N of references. Each line is a JSON object with a "text" or
    "content" key, or a ChatML messages dict (last assistant turn is used).

    Args:
        predictions_path: JSONL of model outputs.
        references_path: JSONL of gold summaries.
        output_path: optional destination for per-example scores (JSONL); a
            sibling ``.summary.json`` with the aggregates is written next to it.

    Returns:
        BatchResult with per-example and aggregate scores.

    Raises:
        ValueError: when the two files have different (non-blank) line counts.
    """
    predictions_path = Path(predictions_path)
    references_path = Path(references_path)

    def extract_text(line: str) -> str:
        """Pull the summary text out of one JSONL line."""
        obj = json.loads(line)
        # Plain text field
        if "text" in obj:
            return obj["text"]
        if "content" in obj:
            return obj["content"]
        # ChatML messages — take last assistant turn
        messages = obj.get("messages", [])
        for m in reversed(messages):
            if m.get("role") == "assistant":
                return m.get("content", "")
        return ""

    # Read explicitly as UTF-8 — matches the writers in this module instead of
    # depending on the platform's default locale encoding.
    preds = [
        extract_text(line)
        for line in predictions_path.read_text(encoding="utf-8").splitlines()
        if line.strip()
    ]
    refs = [
        extract_text(line)
        for line in references_path.read_text(encoding="utf-8").splitlines()
        if line.strip()
    ]

    if len(preds) != len(refs):
        msg = f"Prediction count ({len(preds)}) != reference count ({len(refs)})"
        raise ValueError(msg)

    result = score_batch(list(zip(preds, refs)))

    logger.info(
        "Inference eval: n=%d structural_fails=%d composite=%.4f "
        "rouge1=%.4f rouge2=%.4f rougeL=%.4f",
        result.n_total,
        result.n_structural_fail,
        result.composite_mean,
        result.rouge1_mean,
        result.rouge2_mean,
        result.rougeL_mean,
    )

    if output_path is not None:
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with output_path.open("w", encoding="utf-8") as fout:
            for score in result.scores:
                fout.write(json.dumps(score.as_dict(), ensure_ascii=False) + "\n")
        summary_path = output_path.with_suffix(".summary.json")
        summary_path.write_text(json.dumps(result.as_dict(), indent=2), encoding="utf-8")
        logger.info("Wrote per-example scores to %s", output_path)
        logger.info("Wrote summary to %s", summary_path)

    return result
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# CLI — quick sanity check / standalone use
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _cli() -> None:
    """Argparse entry point for standalone use: ``filter`` and ``eval`` subcommands."""
    # Imported lazily so library imports of this module don't pay for argparse.
    import argparse

    parser = argparse.ArgumentParser(description="Evaluate bill summarization quality.")
    subparsers = parser.add_subparsers(dest="command", required=True)

    # filter subcommand — guardrail + length filtering of a fine-tune dataset
    fp = subparsers.add_parser("filter", help="Filter a fine-tuning JSONL dataset")
    fp.add_argument("--input", required=True, type=Path)
    fp.add_argument("--output", required=True, type=Path)
    fp.add_argument("--threshold", type=float, default=FILTER_THRESHOLD)

    # eval subcommand — score line-matched predictions vs references
    ep = subparsers.add_parser("eval", help="Score predictions against references")
    ep.add_argument("--predictions", required=True, type=Path)
    ep.add_argument("--references", required=True, type=Path)
    ep.add_argument("--output", type=Path, default=None)

    args = parser.parse_args()
    logging.basicConfig(level="INFO", format="%(asctime)s %(levelname)s: %(message)s")

    if args.command == "filter":
        kept, dropped = filter_dataset(args.input, args.output, threshold=args.threshold)
        print(f"Kept: {kept} Dropped: {dropped}")

    elif args.command == "eval":
        result = evaluate_file(args.predictions, args.references, args.output)
        print(f"\nResults ({result.n_scored} scored, {result.n_structural_fail} structural fails):")
        print(f"  ROUGE-1: {result.rouge1_mean:.4f}")
        print(f"  ROUGE-2: {result.rouge2_mean:.4f}")
        print(f"  ROUGE-L: {result.rougeL_mean:.4f}")
        print(f"  Composite: {result.composite_mean:.4f}")
||||||
|
|
||||||
|
|
||||||
|
# Allow `python summarization_eval.py filter|eval ...` standalone invocation.
if __name__ == "__main__":
    _cli()
|
||||||
Reference in New Issue
Block a user