start
This commit is contained in:
@@ -0,0 +1 @@
|
|||||||
|
"""init."""
|
||||||
@@ -0,0 +1,89 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
import tomllib
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LoraConfig:
|
||||||
|
"""LoRA adapter hyperparameters."""
|
||||||
|
|
||||||
|
rank: int
|
||||||
|
alpha: int
|
||||||
|
dropout: float
|
||||||
|
targets: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TrainingConfig:
|
||||||
|
"""Training loop hyperparameters."""
|
||||||
|
|
||||||
|
learning_rate: float
|
||||||
|
epochs: int
|
||||||
|
batch_size: int
|
||||||
|
gradient_accumulation: int
|
||||||
|
max_seq_length: int
|
||||||
|
warmup_ratio: float
|
||||||
|
weight_decay: float
|
||||||
|
logging_steps: int
|
||||||
|
save_steps: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FinetuneConfig:
|
||||||
|
"""Top-level finetune configuration."""
|
||||||
|
|
||||||
|
base_model: str
|
||||||
|
lora: LoraConfig
|
||||||
|
training: TrainingConfig
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_toml(cls, config_path: Path) -> FinetuneConfig:
|
||||||
|
"""Load finetune config from a TOML file."""
|
||||||
|
raw = tomllib.loads(config_path.read_text())["finetune"]
|
||||||
|
return cls(
|
||||||
|
base_model=raw["base_model"],
|
||||||
|
lora=LoraConfig(**raw["lora"]),
|
||||||
|
training=TrainingConfig(**raw["training"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BenchmarkConfig:
|
||||||
|
"""Top-level benchmark configuration loaded from TOML."""
|
||||||
|
|
||||||
|
models: list[str]
|
||||||
|
model_dir: str
|
||||||
|
port: int
|
||||||
|
gpu_memory_utilization: float
|
||||||
|
temperature: float
|
||||||
|
timeout: int
|
||||||
|
concurrency: int
|
||||||
|
vllm_startup_timeout: int
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_toml(cls, config_path: Path) -> BenchmarkConfig:
|
||||||
|
"""Load benchmark config from a TOML file."""
|
||||||
|
raw = tomllib.loads(config_path.read_text())["bench"]
|
||||||
|
return cls(**raw)
|
||||||
|
|
||||||
|
|
||||||
|
def get_config_dir() -> Path:
|
||||||
|
"""Get the path to the config file."""
|
||||||
|
return Path(__file__).resolve().parent.parent.parent / "config"
|
||||||
|
|
||||||
|
def default_config_path() -> Path:
|
||||||
|
"""Get the path to the config file."""
|
||||||
|
return get_config_dir() / "config.toml"
|
||||||
|
|
||||||
|
|
||||||
|
def get_finetune_config(config_path: Path | None = None) -> FinetuneConfig:
|
||||||
|
if config_path is None:
|
||||||
|
config_path = default_config_path()
|
||||||
|
return FinetuneConfig.from_toml(config_path)
|
||||||
|
|
||||||
|
|
||||||
|
def get_benchmark_config(config_path: Path | None = None) -> BenchmarkConfig:
|
||||||
|
if config_path is None:
|
||||||
|
config_path = default_config_path()
|
||||||
|
return BenchmarkConfig.from_toml(config_path)
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
# Unsloth fine-tuning container for Qwen 3.5 4B on RTX 3090.
|
||||||
|
#
|
||||||
|
# Build:
|
||||||
|
# docker build -f python/prompt_bench/Dockerfile.finetune -t bill-finetune .
|
||||||
|
#
|
||||||
|
# Run:
|
||||||
|
# docker run --rm --device=nvidia.com/gpu=all --ipc=host \
|
||||||
|
# -v $(pwd)/output:/workspace/output \
|
||||||
|
# -v $(pwd)/output/finetune_dataset.jsonl:/workspace/dataset.jsonl:ro \
|
||||||
|
# -v /zfs/models/hf:/models \
|
||||||
|
# bill-finetune \
|
||||||
|
# --dataset /workspace/dataset.jsonl \
|
||||||
|
# --output-dir /workspace/output/qwen-bill-summarizer
|
||||||
|
|
||||||
|
FROM ghcr.io/unslothai/unsloth:latest
|
||||||
|
|
||||||
|
RUN pip install --no-cache-dir typer
|
||||||
|
|
||||||
|
WORKDIR /workspace
|
||||||
|
COPY python/prompt_bench/finetune.py python/prompt_bench/finetune.py
|
||||||
|
COPY config/prompts/summarization_prompts.toml config/prompts/summarization_prompts.toml
|
||||||
|
COPY python/prompt_bench/__init__.py python/prompt_bench/__init__.py
|
||||||
|
COPY python/__init__.py python/__init__.py
|
||||||
|
|
||||||
|
ENTRYPOINT ["python", "-m", "python.prompt_bench.finetune"]
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
"""Prompt benchmarking system for evaluating LLMs via vLLM."""
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,238 @@
|
|||||||
|
"""Submit an OpenAI Batch API bill-summarization job over compressed text.
|
||||||
|
|
||||||
|
Reads the first N bills from a CSV with a `text_content` column, compresses
|
||||||
|
each via `bill_token_compression.compress_bill_text`, builds a JSONL file of
|
||||||
|
summarization requests, and submits it as an asynchronous Batch API job
|
||||||
|
against `/v1/chat/completions`. Also writes a CSV of per-bill pre/post-
|
||||||
|
compression token counts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import csv
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import tomllib
|
||||||
|
from os import getenv
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import typer
|
||||||
|
from tiktoken import Encoding, get_encoding
|
||||||
|
|
||||||
|
from python.prompt_bench.bill_token_compression import compress_bill_text
|
||||||
|
|
||||||
|
_PROMPTS_PATH = Path(__file__).resolve().parents[2] / "config" / "prompts" / "summarization_prompts.toml"
|
||||||
|
_PROMPTS = tomllib.loads(_PROMPTS_PATH.read_text())["summarization"]
|
||||||
|
SUMMARIZATION_SYSTEM_PROMPT: str = _PROMPTS["system_prompt"]
|
||||||
|
SUMMARIZATION_USER_TEMPLATE: str = _PROMPTS["user_template"]
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
OPENAI_API_BASE = "https://api.openai.com/v1"
|
||||||
|
|
||||||
|
|
||||||
|
def load_bills(csv_path: Path, count: int = 0) -> list[tuple[str, str]]:
|
||||||
|
"""Return (bill_id, text_content) tuples with non-empty text.
|
||||||
|
|
||||||
|
If `count` is 0 or negative, all rows are returned.
|
||||||
|
"""
|
||||||
|
csv.field_size_limit(sys.maxsize)
|
||||||
|
bills: list[tuple[str, str]] = []
|
||||||
|
with csv_path.open(newline="", encoding="utf-8") as handle:
|
||||||
|
reader = csv.DictReader(handle)
|
||||||
|
for row in reader:
|
||||||
|
text_content = (row.get("text_content") or "").strip()
|
||||||
|
if not text_content:
|
||||||
|
continue
|
||||||
|
bill_id = row.get("bill_id") or row.get("id") or f"row-{len(bills)}"
|
||||||
|
version_code = row.get("version_code") or ""
|
||||||
|
unique_id = f"{bill_id}-{version_code}" if version_code else bill_id
|
||||||
|
bills.append((unique_id, text_content))
|
||||||
|
if count > 0 and len(bills) >= count:
|
||||||
|
break
|
||||||
|
return bills
|
||||||
|
|
||||||
|
|
||||||
|
def safe_filename(value: str) -> str:
|
||||||
|
"""Make a string safe for use as a filename or batch custom_id."""
|
||||||
|
return re.sub(r"[^A-Za-z0-9._-]+", "_", value).strip("_") or "unnamed"
|
||||||
|
|
||||||
|
|
||||||
|
def build_request(custom_id: str, model: str, bill_text: str) -> dict:
|
||||||
|
"""Build one OpenAI batch request line."""
|
||||||
|
return {
|
||||||
|
"custom_id": custom_id,
|
||||||
|
"method": "POST",
|
||||||
|
"url": "/v1/chat/completions",
|
||||||
|
"body": {
|
||||||
|
"model": model,
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": SUMMARIZATION_SYSTEM_PROMPT},
|
||||||
|
{"role": "user", "content": SUMMARIZATION_USER_TEMPLATE.format(text_content=bill_text)},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def write_jsonl(path: Path, lines: list[dict]) -> None:
|
||||||
|
"""Write a list of dicts as JSONL."""
|
||||||
|
with path.open("w", encoding="utf-8") as handle:
|
||||||
|
for line in lines:
|
||||||
|
handle.write(json.dumps(line, ensure_ascii=False))
|
||||||
|
handle.write("\n")
|
||||||
|
|
||||||
|
|
||||||
|
def upload_file(client: httpx.Client, path: Path) -> str:
|
||||||
|
"""Upload a JSONL file to the OpenAI Files API and return its file id."""
|
||||||
|
with path.open("rb") as handle:
|
||||||
|
response = client.post(
|
||||||
|
f"{OPENAI_API_BASE}/files",
|
||||||
|
files={"file": (path.name, handle, "application/jsonl")},
|
||||||
|
data={"purpose": "batch"},
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()["id"]
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_requests(
|
||||||
|
bills: list[tuple[str, str]],
|
||||||
|
*,
|
||||||
|
model: str,
|
||||||
|
encoder: Encoding,
|
||||||
|
) -> tuple[list[dict], list[dict]]:
|
||||||
|
"""Build (request_lines, token_rows) from bills.
|
||||||
|
|
||||||
|
Each bill is compressed before being turned into a request line.
|
||||||
|
Each `token_rows` entry has chars + token counts for one bill so the caller
|
||||||
|
can write a per-bill CSV.
|
||||||
|
"""
|
||||||
|
request_lines: list[dict] = []
|
||||||
|
token_rows: list[dict] = []
|
||||||
|
for bill_id, text_content in bills:
|
||||||
|
raw_token_count = len(encoder.encode(text_content))
|
||||||
|
compressed_text = compress_bill_text(text_content)
|
||||||
|
compressed_token_count = len(encoder.encode(compressed_text))
|
||||||
|
token_rows.append(
|
||||||
|
{
|
||||||
|
"bill_id": bill_id,
|
||||||
|
"raw_chars": len(text_content),
|
||||||
|
"compressed_chars": len(compressed_text),
|
||||||
|
"raw_tokens": raw_token_count,
|
||||||
|
"compressed_tokens": compressed_token_count,
|
||||||
|
"token_ratio": (compressed_token_count / raw_token_count) if raw_token_count else None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
safe_id = safe_filename(bill_id)
|
||||||
|
request_lines.append(build_request(safe_id, model, compressed_text))
|
||||||
|
return request_lines, token_rows
|
||||||
|
|
||||||
|
|
||||||
|
def write_token_csv(path: Path, token_rows: list[dict]) -> tuple[int, int]:
|
||||||
|
"""Write per-bill token counts to CSV. Returns (raw_total, compressed_total)."""
|
||||||
|
with path.open("w", newline="", encoding="utf-8") as handle:
|
||||||
|
writer = csv.DictWriter(
|
||||||
|
handle,
|
||||||
|
fieldnames=["bill_id", "raw_chars", "compressed_chars", "raw_tokens", "compressed_tokens", "token_ratio"],
|
||||||
|
)
|
||||||
|
writer.writeheader()
|
||||||
|
writer.writerows(token_rows)
|
||||||
|
raw_total = sum(row["raw_tokens"] for row in token_rows)
|
||||||
|
compressed_total = sum(row["compressed_tokens"] for row in token_rows)
|
||||||
|
return raw_total, compressed_total
|
||||||
|
|
||||||
|
|
||||||
|
def create_batch(client: httpx.Client, input_file_id: str, description: str) -> dict:
|
||||||
|
"""Create a batch job and return its full response payload."""
|
||||||
|
response = client.post(
|
||||||
|
f"{OPENAI_API_BASE}/batches",
|
||||||
|
json={
|
||||||
|
"input_file_id": input_file_id,
|
||||||
|
"endpoint": "/v1/chat/completions",
|
||||||
|
"completion_window": "24h",
|
||||||
|
"metadata": {"description": description},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
|
||||||
|
def main(
|
||||||
|
csv_path: Annotated[Path, typer.Option("--csv", help="Bills CSV path")] = Path("bills.csv"),
|
||||||
|
output_dir: Annotated[Path, typer.Option("--output-dir", help="Where to write JSONL + metadata")] = Path(
|
||||||
|
"output/openai_batch",
|
||||||
|
),
|
||||||
|
model: Annotated[str, typer.Option(help="OpenAI model id")] = "gpt-5-mini",
|
||||||
|
count: Annotated[int, typer.Option(help="Max bills to process, 0 = all")] = 0,
|
||||||
|
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
|
||||||
|
) -> None:
|
||||||
|
"""Submit an OpenAI Batch job of compressed bill summaries."""
|
||||||
|
logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
||||||
|
|
||||||
|
api_key = getenv("CLOSEDAI_TOKEN") or getenv("OPENAI_API_KEY")
|
||||||
|
if not api_key:
|
||||||
|
message = "Neither CLOSEDAI_TOKEN nor OPENAI_API_KEY is set"
|
||||||
|
raise typer.BadParameter(message)
|
||||||
|
if not csv_path.is_file():
|
||||||
|
message = f"CSV not found: {csv_path}"
|
||||||
|
raise typer.BadParameter(message)
|
||||||
|
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
logger.info("Loading %d bills from %s", count, csv_path)
|
||||||
|
bills = load_bills(csv_path, count)
|
||||||
|
if len(bills) < count:
|
||||||
|
logger.warning("Only %d bills available (requested %d)", len(bills), count)
|
||||||
|
|
||||||
|
encoder = get_encoding("o200k_base")
|
||||||
|
request_lines, token_rows = prepare_requests(bills, model=model, encoder=encoder)
|
||||||
|
|
||||||
|
token_csv_path = output_dir / "token_counts.csv"
|
||||||
|
raw_tokens_total, compressed_tokens_total = write_token_csv(token_csv_path, token_rows)
|
||||||
|
logger.info(
|
||||||
|
"Token counts: raw=%d compressed=%d ratio=%.3f -> %s",
|
||||||
|
raw_tokens_total,
|
||||||
|
compressed_tokens_total,
|
||||||
|
(compressed_tokens_total / raw_tokens_total) if raw_tokens_total else 0.0,
|
||||||
|
token_csv_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
jsonl_path = output_dir / "requests.jsonl"
|
||||||
|
write_jsonl(jsonl_path, request_lines)
|
||||||
|
logger.info("Wrote %s (%d bills)", jsonl_path, len(request_lines))
|
||||||
|
|
||||||
|
headers = {"Authorization": f"Bearer {api_key}"}
|
||||||
|
with httpx.Client(headers=headers, timeout=httpx.Timeout(300.0)) as client:
|
||||||
|
logger.info("Uploading JSONL")
|
||||||
|
file_id = upload_file(client, jsonl_path)
|
||||||
|
logger.info("Uploaded: %s", file_id)
|
||||||
|
|
||||||
|
logger.info("Creating batch")
|
||||||
|
batch = create_batch(client, file_id, f"compressed bill summaries x{len(request_lines)} ({model})")
|
||||||
|
logger.info("Batch created: %s", batch["id"])
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
"model": model,
|
||||||
|
"count": len(bills),
|
||||||
|
"jsonl": str(jsonl_path),
|
||||||
|
"input_file_id": file_id,
|
||||||
|
"batch_id": batch["id"],
|
||||||
|
"raw_tokens_total": raw_tokens_total,
|
||||||
|
"compressed_tokens_total": compressed_tokens_total,
|
||||||
|
"batch": batch,
|
||||||
|
}
|
||||||
|
metadata_path = output_dir / "batch.json"
|
||||||
|
metadata_path.write_text(json.dumps(metadata, indent=2))
|
||||||
|
logger.info("Wrote metadata to %s", metadata_path)
|
||||||
|
|
||||||
|
|
||||||
|
def cli() -> None:
|
||||||
|
"""Typer entry point."""
|
||||||
|
typer.run(main)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
cli()
|
||||||
@@ -0,0 +1,162 @@
|
|||||||
|
"""Lossless-ish text compression for Congressional bill text."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
STATES = (
|
||||||
|
"Alabama",
|
||||||
|
"Alaska",
|
||||||
|
"Arizona",
|
||||||
|
"Arkansas",
|
||||||
|
"California",
|
||||||
|
"Colorado",
|
||||||
|
"Connecticut",
|
||||||
|
"Delaware",
|
||||||
|
"Florida",
|
||||||
|
"Georgia",
|
||||||
|
"Hawaii",
|
||||||
|
"Idaho",
|
||||||
|
"Illinois",
|
||||||
|
"Indiana",
|
||||||
|
"Iowa",
|
||||||
|
"Kansas",
|
||||||
|
"Kentucky",
|
||||||
|
"Louisiana",
|
||||||
|
"Maine",
|
||||||
|
"Maryland",
|
||||||
|
"Massachusetts",
|
||||||
|
"Michigan",
|
||||||
|
"Minnesota",
|
||||||
|
"Mississippi",
|
||||||
|
"Missouri",
|
||||||
|
"Montana",
|
||||||
|
"Nebraska",
|
||||||
|
"Nevada",
|
||||||
|
"New Hampshire",
|
||||||
|
"New Jersey",
|
||||||
|
"New Mexico",
|
||||||
|
"New York",
|
||||||
|
"North Carolina",
|
||||||
|
"North Dakota",
|
||||||
|
"Ohio",
|
||||||
|
"Oklahoma",
|
||||||
|
"Oregon",
|
||||||
|
"Pennsylvania",
|
||||||
|
"Rhode Island",
|
||||||
|
"South Carolina",
|
||||||
|
"South Dakota",
|
||||||
|
"Tennessee",
|
||||||
|
"Texas",
|
||||||
|
"Utah",
|
||||||
|
"Vermont",
|
||||||
|
"Virginia",
|
||||||
|
"Washington",
|
||||||
|
"West Virginia",
|
||||||
|
"Wisconsin",
|
||||||
|
"Wyoming",
|
||||||
|
"Puerto Rico",
|
||||||
|
"Guam",
|
||||||
|
"American Samoa",
|
||||||
|
"District of Columbia",
|
||||||
|
"US Virgin Islands",
|
||||||
|
)
|
||||||
|
STATE_PATTERNS = [(re.compile(re.escape(state), re.IGNORECASE), state) for state in STATES]
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_state_names(text: str) -> str:
|
||||||
|
"""Replace any casing of state names with title case."""
|
||||||
|
for pattern, replacement in STATE_PATTERNS:
|
||||||
|
text = pattern.sub(replacement, text)
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def strip_number_commas(text: str) -> str:
|
||||||
|
"""Remove commas from numeric thousands separators."""
|
||||||
|
return re.sub(r"(\d{1,3}(?:,\d{3})+)", lambda match: match.group().replace(",", ""), text)
|
||||||
|
|
||||||
|
|
||||||
|
def strip_horizontal_rules(text: str) -> str:
|
||||||
|
"""Remove ASCII horizontal-rule lines built from underscores, dashes, equals, or asterisks."""
|
||||||
|
return re.sub(r"^\s*[_\-=\*]{3,}\s*$", "", text, flags=re.MULTILINE)
|
||||||
|
|
||||||
|
|
||||||
|
def collapse_double_dashes(text: str) -> str:
|
||||||
|
"""Replace ``--`` em-dash stand-ins with a single space so they don't tokenize oddly."""
|
||||||
|
return text.replace("--", " ")
|
||||||
|
|
||||||
|
|
||||||
|
def collapse_inline_whitespace(text: str) -> str:
|
||||||
|
"""Collapse runs of horizontal whitespace (spaces, tabs) into a single space, leaving newlines intact."""
|
||||||
|
return re.sub(r"[^\S\n]+", " ", text)
|
||||||
|
|
||||||
|
|
||||||
|
def collapse_blank_lines(text: str) -> str:
|
||||||
|
"""Collapse three-or-more consecutive newlines down to a blank-line separator."""
|
||||||
|
return re.sub(r"\n{3,}", "\n\n", text)
|
||||||
|
|
||||||
|
|
||||||
|
def trim_line_edges(text: str) -> str:
|
||||||
|
"""Strip spaces immediately before and after newline characters on every line."""
|
||||||
|
text = re.sub(r" +\n", "\n", text)
|
||||||
|
return re.sub(r"\n +", "\n", text)
|
||||||
|
|
||||||
|
|
||||||
|
def shorten_section_markers(text: str) -> str:
|
||||||
|
"""Rewrite ``Sec. 12.`` style section headings as the more compact ``SEC 12``."""
|
||||||
|
return re.sub(r"(?i)sec\.\s*(\d+[a-zA-Z]?)\.", r"SEC \1", text)
|
||||||
|
|
||||||
|
|
||||||
|
def unwrap_parens(text: str) -> str:
|
||||||
|
"""Strip parentheses around short alphanumeric labels like ``(a)`` or ``(12)``."""
|
||||||
|
return re.sub(r"\(([a-zA-Z0-9]+)\)", r"\1", text)
|
||||||
|
|
||||||
|
|
||||||
|
def strip_typeset_quotes(text: str) -> str:
|
||||||
|
"""Remove the `` and '' typeset quote markers used in the GPO bill format."""
|
||||||
|
return text.replace("``", "").replace("''", "")
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_usc_acronym(text: str) -> str:
|
||||||
|
"""Collapse ``U.S.C.`` to ``USC`` to save tokens on the common citation."""
|
||||||
|
return text.replace("U.S.C.", "USC")
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_us_acronym(text: str) -> str:
|
||||||
|
"""Normalize the various ``U.S.``/``U. S.`` spellings to the bare ``US`` form."""
|
||||||
|
for acronym in ("U. S.", "u. s.", "U.S. ", "u.s. "):
|
||||||
|
text = text.replace(acronym, "US ")
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def collapse_ellipses(text: str) -> str:
|
||||||
|
"""Collapse runs of two-or-more periods (``...``, ``....``) down to a single period."""
|
||||||
|
return re.sub(r"\.{2,}", ".", text)
|
||||||
|
|
||||||
|
|
||||||
|
COMPRESSION_STEPS = (
|
||||||
|
strip_horizontal_rules,
|
||||||
|
collapse_double_dashes,
|
||||||
|
collapse_inline_whitespace,
|
||||||
|
collapse_blank_lines,
|
||||||
|
trim_line_edges,
|
||||||
|
shorten_section_markers,
|
||||||
|
unwrap_parens,
|
||||||
|
strip_typeset_quotes,
|
||||||
|
normalize_usc_acronym,
|
||||||
|
normalize_us_acronym,
|
||||||
|
strip_number_commas,
|
||||||
|
collapse_ellipses,
|
||||||
|
normalize_state_names,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def compress_bill_text(text: str) -> str:
|
||||||
|
"""Apply lossless-ish whitespace and boilerplate compression to bill text.
|
||||||
|
|
||||||
|
Runs every transform in :data:`COMPRESSION_STEPS` in order, then strips
|
||||||
|
leading/trailing whitespace from the final result.
|
||||||
|
"""
|
||||||
|
for step in COMPRESSION_STEPS:
|
||||||
|
text = step(text)
|
||||||
|
return text.strip()
|
||||||
@@ -0,0 +1,241 @@
|
|||||||
|
"""Run two interactive OpenAI chat-completion sweeps over bill text.
|
||||||
|
|
||||||
|
Reads the first N bills from a CSV with a `text_content` column and sends two
|
||||||
|
sweeps through `/v1/chat/completions` concurrently — one with the raw bill
|
||||||
|
text, one with the compressed bill text. Each request's prompt is saved to
|
||||||
|
disk alongside the OpenAI response id so the prompts and responses can be
|
||||||
|
correlated later.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import csv
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import tomllib
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from os import getenv
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import typer
|
||||||
|
|
||||||
|
from python.prompt_bench.bill_token_compression import compress_bill_text
|
||||||
|
|
||||||
|
_PROMPTS_PATH = Path(__file__).resolve().parents[2] / "config" / "prompts" / "summarization_prompts.toml"
|
||||||
|
_PROMPTS = tomllib.loads(_PROMPTS_PATH.read_text())["summarization"]
|
||||||
|
SUMMARIZATION_SYSTEM_PROMPT: str = _PROMPTS["system_prompt"]
|
||||||
|
SUMMARIZATION_USER_TEMPLATE: str = _PROMPTS["user_template"]
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
OPENAI_API_BASE = "https://api.openai.com/v1"
|
||||||
|
DEFAULT_MODEL = "gpt-5.4-mini"
|
||||||
|
DEFAULT_COUNT = 100
|
||||||
|
SEED = 42
|
||||||
|
|
||||||
|
|
||||||
|
def load_bills(csv_path: Path, count: int) -> list[tuple[str, str]]:
|
||||||
|
"""Return up to `count` (bill_id, text_content) tuples with non-empty text."""
|
||||||
|
csv.field_size_limit(sys.maxsize)
|
||||||
|
bills: list[tuple[str, str]] = []
|
||||||
|
with csv_path.open(newline="", encoding="utf-8") as handle:
|
||||||
|
reader = csv.DictReader(handle)
|
||||||
|
for row in reader:
|
||||||
|
text_content = (row.get("text_content") or "").strip()
|
||||||
|
if not text_content:
|
||||||
|
continue
|
||||||
|
bill_id = row.get("bill_id") or row.get("id") or f"row-{len(bills)}"
|
||||||
|
version_code = row.get("version_code") or ""
|
||||||
|
unique_id = f"{bill_id}-{version_code}" if version_code else bill_id
|
||||||
|
bills.append((unique_id, text_content))
|
||||||
|
if len(bills) >= count:
|
||||||
|
break
|
||||||
|
return bills
|
||||||
|
|
||||||
|
|
||||||
|
def build_messages(bill_text: str) -> list[dict]:
|
||||||
|
"""Return the system + user message pair for a bill."""
|
||||||
|
return [
|
||||||
|
{"role": "system", "content": SUMMARIZATION_SYSTEM_PROMPT},
|
||||||
|
{"role": "user", "content": SUMMARIZATION_USER_TEMPLATE.format(text_content=bill_text)},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def safe_filename(value: str) -> str:
|
||||||
|
"""Make a string safe for use as a filename."""
|
||||||
|
return re.sub(r"[^A-Za-z0-9._-]+", "_", value).strip("_") or "unnamed"
|
||||||
|
|
||||||
|
|
||||||
|
def run_one_request(
|
||||||
|
client: httpx.Client,
|
||||||
|
*,
|
||||||
|
bill_id: str,
|
||||||
|
label: str,
|
||||||
|
bill_text: str,
|
||||||
|
model: str,
|
||||||
|
output_path: Path,
|
||||||
|
) -> tuple[bool, float, str | None]:
|
||||||
|
"""Send one chat-completion request and persist prompt + response.
|
||||||
|
|
||||||
|
Returns (success, elapsed_seconds, response_id).
|
||||||
|
"""
|
||||||
|
messages = build_messages(bill_text)
|
||||||
|
payload = {
|
||||||
|
"model": model,
|
||||||
|
"messages": messages,
|
||||||
|
"seed": SEED,
|
||||||
|
}
|
||||||
|
start = time.monotonic()
|
||||||
|
record: dict = {
|
||||||
|
"bill_id": bill_id,
|
||||||
|
"label": label,
|
||||||
|
"model": model,
|
||||||
|
"seed": SEED,
|
||||||
|
"input_chars": len(bill_text),
|
||||||
|
"messages": messages,
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
response = client.post(f"{OPENAI_API_BASE}/chat/completions", json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
body = response.json()
|
||||||
|
except httpx.HTTPStatusError as error:
|
||||||
|
elapsed = time.monotonic() - start
|
||||||
|
record["error"] = {
|
||||||
|
"status_code": error.response.status_code,
|
||||||
|
"body": error.response.text,
|
||||||
|
"elapsed_seconds": elapsed,
|
||||||
|
}
|
||||||
|
output_path.write_text(json.dumps(record, ensure_ascii=False, indent=2))
|
||||||
|
logger.exception("HTTP error for %s/%s after %.2fs", label, bill_id, elapsed)
|
||||||
|
return False, elapsed, None
|
||||||
|
except Exception as error:
|
||||||
|
elapsed = time.monotonic() - start
|
||||||
|
record["error"] = {"message": str(error), "elapsed_seconds": elapsed}
|
||||||
|
output_path.write_text(json.dumps(record, ensure_ascii=False, indent=2))
|
||||||
|
logger.exception("Failed: %s/%s after %.2fs", label, bill_id, elapsed)
|
||||||
|
return False, elapsed, None
|
||||||
|
|
||||||
|
elapsed = time.monotonic() - start
|
||||||
|
response_id = body.get("id")
|
||||||
|
record["response_id"] = response_id
|
||||||
|
record["elapsed_seconds"] = elapsed
|
||||||
|
record["usage"] = body.get("usage")
|
||||||
|
record["response"] = body
|
||||||
|
output_path.write_text(json.dumps(record, ensure_ascii=False, indent=2))
|
||||||
|
logger.info("Done: %s/%s id=%s in %.2fs", label, bill_id, response_id, elapsed)
|
||||||
|
return True, elapsed, response_id
|
||||||
|
|
||||||
|
|
||||||
|
def main(
|
||||||
|
csv_path: Annotated[Path, typer.Option("--csv", help="Bills CSV path")] = Path("bills.csv"),
|
||||||
|
output_dir: Annotated[Path, typer.Option("--output-dir", help="Where to write per-request JSON")] = Path(
|
||||||
|
"output/openai_runs",
|
||||||
|
),
|
||||||
|
model: Annotated[str, typer.Option(help="OpenAI model id")] = DEFAULT_MODEL,
|
||||||
|
count: Annotated[int, typer.Option(help="Number of bills per set")] = DEFAULT_COUNT,
|
||||||
|
concurrency: Annotated[int, typer.Option(help="Concurrent in-flight requests")] = 16,
|
||||||
|
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
|
||||||
|
) -> None:
|
||||||
|
"""Run two interactive OpenAI sweeps (compressed + uncompressed) over bill text."""
|
||||||
|
logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
||||||
|
|
||||||
|
api_key = getenv("CLOSEDAI_TOKEN") or getenv("OPENAI_API_KEY")
|
||||||
|
if not api_key:
|
||||||
|
message = "Neither CLOSEDAI_TOKEN nor OPENAI_API_KEY is set"
|
||||||
|
raise typer.BadParameter(message)
|
||||||
|
if not csv_path.is_file():
|
||||||
|
message = f"CSV not found: {csv_path}"
|
||||||
|
raise typer.BadParameter(message)
|
||||||
|
|
||||||
|
compressed_dir = output_dir / "compressed"
|
||||||
|
uncompressed_dir = output_dir / "uncompressed"
|
||||||
|
compressed_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
uncompressed_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
logger.info("Loading %d bills from %s", count, csv_path)
|
||||||
|
bills = load_bills(csv_path, count)
|
||||||
|
if len(bills) < count:
|
||||||
|
logger.warning("Only %d bills available (requested %d)", len(bills), count)
|
||||||
|
|
||||||
|
tasks: list[tuple[str, str, str, Path]] = []
|
||||||
|
for bill_id, text_content in bills:
|
||||||
|
filename = f"{safe_filename(bill_id)}.json"
|
||||||
|
tasks.append((bill_id, "compressed", compress_bill_text(text_content), compressed_dir / filename))
|
||||||
|
tasks.append((bill_id, "uncompressed", text_content, uncompressed_dir / filename))
|
||||||
|
|
||||||
|
logger.info("Submitting %d requests at concurrency=%d", len(tasks), concurrency)
|
||||||
|
|
||||||
|
headers = {"Authorization": f"Bearer {api_key}"}
|
||||||
|
completed = 0
|
||||||
|
failed = 0
|
||||||
|
index: list[dict] = []
|
||||||
|
wall_start = time.monotonic()
|
||||||
|
with (
|
||||||
|
httpx.Client(headers=headers, timeout=httpx.Timeout(300.0)) as client,
|
||||||
|
ThreadPoolExecutor(
|
||||||
|
max_workers=concurrency,
|
||||||
|
) as executor,
|
||||||
|
):
|
||||||
|
future_to_task = {
|
||||||
|
executor.submit(
|
||||||
|
run_one_request,
|
||||||
|
client,
|
||||||
|
bill_id=bill_id,
|
||||||
|
label=label,
|
||||||
|
bill_text=bill_text,
|
||||||
|
model=model,
|
||||||
|
output_path=output_path,
|
||||||
|
): (bill_id, label, output_path)
|
||||||
|
for bill_id, label, bill_text, output_path in tasks
|
||||||
|
}
|
||||||
|
for future in as_completed(future_to_task):
|
||||||
|
bill_id, label, output_path = future_to_task[future]
|
||||||
|
success, elapsed, response_id = future.result()
|
||||||
|
if success:
|
||||||
|
completed += 1
|
||||||
|
else:
|
||||||
|
failed += 1
|
||||||
|
index.append(
|
||||||
|
{
|
||||||
|
"bill_id": bill_id,
|
||||||
|
"label": label,
|
||||||
|
"response_id": response_id,
|
||||||
|
"elapsed_seconds": elapsed,
|
||||||
|
"success": success,
|
||||||
|
"path": str(output_path),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
wall_elapsed = time.monotonic() - wall_start
|
||||||
|
|
||||||
|
summary = {
|
||||||
|
"model": model,
|
||||||
|
"count": len(bills),
|
||||||
|
"completed": completed,
|
||||||
|
"failed": failed,
|
||||||
|
"wall_seconds": wall_elapsed,
|
||||||
|
"concurrency": concurrency,
|
||||||
|
"results": index,
|
||||||
|
}
|
||||||
|
summary_path = output_dir / "summary.json"
|
||||||
|
summary_path.write_text(json.dumps(summary, indent=2))
|
||||||
|
logger.info(
|
||||||
|
"Done: completed=%d failed=%d wall=%.1fs summary=%s",
|
||||||
|
completed,
|
||||||
|
failed,
|
||||||
|
wall_elapsed,
|
||||||
|
summary_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def cli() -> None:
|
||||||
|
"""Typer entry point."""
|
||||||
|
typer.run(main)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
cli()
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
"""Prompt benchmarking system for evaluating LLMs via vLLM."""
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,165 @@
|
|||||||
|
"""Docker container lifecycle management for Unsloth fine-tuning."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
import typer
|
||||||
|
|
||||||
|
from python.prompt_bench.containers.lib import check_gpu_free
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
CONTAINER_NAME = "bill-finetune"
|
||||||
|
FINETUNE_IMAGE = "bill-finetune:latest"
|
||||||
|
DOCKERFILE_PATH = "/home/richie/dotfiles/python/prompt_bench/Dockerfile.finetune"
|
||||||
|
DEFAULT_HF_CACHE = Path("/zfs/models/hf")
|
||||||
|
|
||||||
|
|
||||||
|
def build_image() -> None:
|
||||||
|
"""Build the fine-tuning Docker image."""
|
||||||
|
logger.info("Building fine-tuning image: %s", FINETUNE_IMAGE)
|
||||||
|
result = subprocess.run(
|
||||||
|
["docker", "build", "-f", DOCKERFILE_PATH, "-t", FINETUNE_IMAGE, "."],
|
||||||
|
text=True,
|
||||||
|
check=False,
|
||||||
|
)
|
||||||
|
if result.returncode != 0:
|
||||||
|
message = "Failed to build fine-tuning image"
|
||||||
|
raise RuntimeError(message)
|
||||||
|
logger.info("Image built: %s", FINETUNE_IMAGE)
|
||||||
|
|
||||||
|
|
||||||
|
def start_finetune(
|
||||||
|
*,
|
||||||
|
dataset_path: Path,
|
||||||
|
output_dir: Path,
|
||||||
|
hf_cache: Path = DEFAULT_HF_CACHE,
|
||||||
|
) -> None:
|
||||||
|
"""Run the fine-tuning container.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_path: Host path to the fine-tuning JSONL dataset.
|
||||||
|
output_dir: Host path where the trained model will be saved.
|
||||||
|
hf_cache: Host path to HuggingFace model cache (bind-mounted to avoid re-downloading).
|
||||||
|
validation_split: Fraction of data held out for validation.
|
||||||
|
"""
|
||||||
|
dataset_path = dataset_path.resolve()
|
||||||
|
output_dir = output_dir.resolve()
|
||||||
|
|
||||||
|
if not dataset_path.is_file():
|
||||||
|
message = f"Dataset not found: {dataset_path}"
|
||||||
|
raise FileNotFoundError(message)
|
||||||
|
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
stop_finetune()
|
||||||
|
|
||||||
|
hf_cache = hf_cache.resolve()
|
||||||
|
hf_cache.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
command = [
|
||||||
|
"docker",
|
||||||
|
"run",
|
||||||
|
"--name",
|
||||||
|
CONTAINER_NAME,
|
||||||
|
"--device=nvidia.com/gpu=all",
|
||||||
|
"--ipc=host",
|
||||||
|
"-v",
|
||||||
|
f"{hf_cache}:/root/.cache/huggingface",
|
||||||
|
"-v",
|
||||||
|
f"{output_dir}:/workspace/output/qwen-bill-summarizer",
|
||||||
|
"-v",
|
||||||
|
f"{dataset_path}:/workspace/dataset.jsonl:ro",
|
||||||
|
FINETUNE_IMAGE,
|
||||||
|
"--dataset",
|
||||||
|
"/workspace/dataset.jsonl",
|
||||||
|
"--output-dir",
|
||||||
|
"/workspace/output/qwen-bill-summarizer",
|
||||||
|
]
|
||||||
|
|
||||||
|
logger.info("Starting fine-tuning container")
|
||||||
|
logger.info(" Dataset: %s", dataset_path)
|
||||||
|
logger.info(" Output: %s", output_dir)
|
||||||
|
|
||||||
|
result = subprocess.run(command, text=True, check=False)
|
||||||
|
if result.returncode != 0:
|
||||||
|
message = f"Fine-tuning container exited with code {result.returncode}"
|
||||||
|
raise RuntimeError(message)
|
||||||
|
logger.info("Fine-tuning complete. Model saved to %s", output_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def stop_finetune() -> None:
|
||||||
|
"""Stop and remove the fine-tuning container."""
|
||||||
|
logger.info("Stopping fine-tuning container")
|
||||||
|
subprocess.run(["docker", "stop", CONTAINER_NAME], capture_output=True, check=False)
|
||||||
|
subprocess.run(["docker", "rm", "-f", CONTAINER_NAME], capture_output=True, check=False)
|
||||||
|
|
||||||
|
|
||||||
|
def logs_finetune() -> str | None:
|
||||||
|
"""Return recent logs from the fine-tuning container, or None if not running."""
|
||||||
|
result = subprocess.run(
|
||||||
|
["docker", "logs", "--tail", "50", CONTAINER_NAME],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
check=False,
|
||||||
|
)
|
||||||
|
if result.returncode != 0:
|
||||||
|
return None
|
||||||
|
return result.stdout + result.stderr
|
||||||
|
|
||||||
|
|
||||||
|
app = typer.Typer(help="Fine-tuning container management.")
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def build() -> None:
|
||||||
|
"""Build the fine-tuning Docker image."""
|
||||||
|
build_image()
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def run(
|
||||||
|
dataset: Annotated[Path, typer.Option(help="Fine-tuning JSONL")] = Path(
|
||||||
|
"/home/richie/dotfiles/data/finetune_dataset.jsonl"
|
||||||
|
),
|
||||||
|
output_dir: Annotated[Path, typer.Option(help="Where to save the trained model")] = Path(
|
||||||
|
"/home/richie/dotfiles/data/output/qwen-bill-summarizer",
|
||||||
|
),
|
||||||
|
hf_cache: Annotated[Path, typer.Option(help="Host path to HuggingFace model cache")] = DEFAULT_HF_CACHE,
|
||||||
|
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
|
||||||
|
) -> None:
|
||||||
|
"""Run fine-tuning inside a Docker container."""
|
||||||
|
logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
||||||
|
check_gpu_free()
|
||||||
|
start_finetune(
|
||||||
|
dataset_path=dataset,
|
||||||
|
output_dir=output_dir,
|
||||||
|
hf_cache=hf_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def stop() -> None:
|
||||||
|
"""Stop and remove the fine-tuning container."""
|
||||||
|
stop_finetune()
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def logs() -> None:
|
||||||
|
"""Show recent logs from the fine-tuning container."""
|
||||||
|
output = logs_finetune()
|
||||||
|
if output is None:
|
||||||
|
typer.echo("No running fine-tuning container found.")
|
||||||
|
raise typer.Exit(code=1)
|
||||||
|
typer.echo(output)
|
||||||
|
|
||||||
|
|
||||||
|
def cli() -> None:
|
||||||
|
"""Typer entry point."""
|
||||||
|
app()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
cli()
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def check_gpu_free() -> None:
|
||||||
|
"""Warn if GPU-heavy processes (e.g. Ollama) are running."""
|
||||||
|
result = subprocess.run(
|
||||||
|
["nvidia-smi", "--query-compute-apps=pid,process_name", "--format=csv,noheader"],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
check=False,
|
||||||
|
)
|
||||||
|
if result.returncode != 0:
|
||||||
|
logger.warning("Could not query GPU processes: %s", result.stderr.strip())
|
||||||
|
return
|
||||||
|
processes = result.stdout.strip()
|
||||||
|
if processes:
|
||||||
|
logger.warning("GPU processes detected:\n%s", processes)
|
||||||
|
logger.warning("Consider stopping Ollama (sudo systemctl stop ollama) before benchmarking")
|
||||||
@@ -0,0 +1,70 @@
|
|||||||
|
"""Docker container lifecycle management for vLLM."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
CONTAINER_NAME = "vllm-bench"
|
||||||
|
VLLM_IMAGE = "vllm/vllm-openai:v0.19.0"
|
||||||
|
|
||||||
|
|
||||||
|
def start_vllm(
|
||||||
|
*,
|
||||||
|
model: str,
|
||||||
|
port: int,
|
||||||
|
model_dir: str,
|
||||||
|
gpu_memory_utilization: float,
|
||||||
|
) -> None:
|
||||||
|
"""Start a vLLM container serving the given model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: HuggingFace model directory name (relative to model_dir).
|
||||||
|
port: Host port to bind.
|
||||||
|
model_dir: Host path containing HuggingFace model directories.
|
||||||
|
gpu_memory_utilization: Fraction of GPU memory to use (0-1).
|
||||||
|
"""
|
||||||
|
command = [
|
||||||
|
"docker",
|
||||||
|
"run",
|
||||||
|
"-d",
|
||||||
|
"--name",
|
||||||
|
CONTAINER_NAME,
|
||||||
|
"--device=nvidia.com/gpu=all",
|
||||||
|
"--ipc=host",
|
||||||
|
"-v",
|
||||||
|
f"{model_dir}:/models",
|
||||||
|
"-p",
|
||||||
|
f"{port}:8000",
|
||||||
|
VLLM_IMAGE,
|
||||||
|
"--model",
|
||||||
|
f"/models/{model}",
|
||||||
|
"--served-model-name",
|
||||||
|
model,
|
||||||
|
"--gpu-memory-utilization",
|
||||||
|
str(gpu_memory_utilization),
|
||||||
|
"--max-model-len",
|
||||||
|
"4096",
|
||||||
|
]
|
||||||
|
logger.info("Starting vLLM container with model: %s", model)
|
||||||
|
stop_vllm()
|
||||||
|
result = subprocess.run(command, capture_output=True, text=True, check=False)
|
||||||
|
if result.returncode != 0:
|
||||||
|
msg = f"Failed to start vLLM container: {result.stderr.strip()}"
|
||||||
|
raise RuntimeError(msg)
|
||||||
|
logger.info("vLLM container started: %s", result.stdout.strip()[:12])
|
||||||
|
|
||||||
|
|
||||||
|
def stop_vllm() -> None:
|
||||||
|
"""Stop and remove the vLLM benchmark container."""
|
||||||
|
logger.info("Stopping vLLM container")
|
||||||
|
subprocess.run(["docker", "stop", CONTAINER_NAME], capture_output=True, check=False)
|
||||||
|
subprocess.run(["docker", "rm", "-f", CONTAINER_NAME], capture_output=True, check=False)
|
||||||
|
subprocess.run(
|
||||||
|
["docker", "network", "disconnect", "-f", "bridge", CONTAINER_NAME],
|
||||||
|
capture_output=True,
|
||||||
|
check=False,
|
||||||
|
)
|
||||||
|
logger.info("vLLM container stopped and removed")
|
||||||
@@ -0,0 +1,75 @@
|
|||||||
|
"""HuggingFace model downloader."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
import typer
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
from python.prompt_bench.models import BenchmarkConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def local_model_path(repo: str, model_dir: str) -> Path:
|
||||||
|
"""Return the local directory path for a HuggingFace repo."""
|
||||||
|
return Path(model_dir) / repo
|
||||||
|
|
||||||
|
|
||||||
|
def is_model_present(repo: str, model_dir: str) -> bool:
|
||||||
|
"""Check if a model has already been downloaded."""
|
||||||
|
path = local_model_path(repo, model_dir)
|
||||||
|
return path.exists() and any(path.iterdir())
|
||||||
|
|
||||||
|
|
||||||
|
def download_model(repo: str, model_dir: str) -> Path:
|
||||||
|
"""Download a HuggingFace model to the local model directory.
|
||||||
|
|
||||||
|
Skips the download if the model directory already exists and contains files.
|
||||||
|
"""
|
||||||
|
local_path = local_model_path(repo, model_dir)
|
||||||
|
|
||||||
|
if is_model_present(repo, model_dir):
|
||||||
|
logger.info("Model already exists: %s", local_path)
|
||||||
|
return local_path
|
||||||
|
|
||||||
|
logger.info("Downloading model: %s -> %s", repo, local_path)
|
||||||
|
snapshot_download(
|
||||||
|
repo_id=repo,
|
||||||
|
local_dir=str(local_path),
|
||||||
|
)
|
||||||
|
logger.info("Download complete: %s", repo)
|
||||||
|
return local_path
|
||||||
|
|
||||||
|
|
||||||
|
def download_all(config: BenchmarkConfig) -> None:
|
||||||
|
"""Download every model listed in the config, top to bottom."""
|
||||||
|
for repo in config.models:
|
||||||
|
download_model(repo, config.model_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def main(
|
||||||
|
config: Annotated[Path, typer.Option(help="Path to TOML config file")] = Path("bench.toml"),
|
||||||
|
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
|
||||||
|
) -> None:
|
||||||
|
"""Download all models listed in the benchmark config."""
|
||||||
|
logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
||||||
|
|
||||||
|
if not config.is_file():
|
||||||
|
message = f"Config file does not exist: {config}"
|
||||||
|
raise typer.BadParameter(message)
|
||||||
|
|
||||||
|
benchmark_config = BenchmarkConfig.from_toml(config)
|
||||||
|
download_all(benchmark_config)
|
||||||
|
|
||||||
|
|
||||||
|
def cli() -> None:
|
||||||
|
"""Typer entry point."""
|
||||||
|
typer.run(main)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
cli()
|
||||||
@@ -0,0 +1,214 @@
|
|||||||
|
"""Fine-tune Qwen 3.5 4B on bill summarization data using Unsloth.
|
||||||
|
|
||||||
|
Loads a ChatML-style JSONL dataset (system/user/assistant messages),
|
||||||
|
applies QLoRA with 4-bit quantization, and saves the merged model
|
||||||
|
in HuggingFace format. Designed for a single RTX 3090 (24GB).
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python -m python.prompt_bench.finetune \
|
||||||
|
--dataset output/finetune_dataset.jsonl \
|
||||||
|
--output-dir output/qwen-bill-summarizer
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
import tomllib
|
||||||
|
import typer
|
||||||
|
from unsloth import FastLanguageModel
|
||||||
|
from datasets import Dataset
|
||||||
|
from transformers import TrainingArguments
|
||||||
|
from trl import SFTTrainer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LoraConfig:
|
||||||
|
"""LoRA adapter hyperparameters."""
|
||||||
|
|
||||||
|
rank: int
|
||||||
|
alpha: int
|
||||||
|
dropout: float
|
||||||
|
targets: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TrainingConfig:
|
||||||
|
"""Training loop hyperparameters."""
|
||||||
|
|
||||||
|
learning_rate: float
|
||||||
|
epochs: int
|
||||||
|
batch_size: int
|
||||||
|
gradient_accumulation: int
|
||||||
|
max_seq_length: int
|
||||||
|
warmup_ratio: float
|
||||||
|
weight_decay: float
|
||||||
|
logging_steps: int
|
||||||
|
save_steps: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FinetuneConfig:
|
||||||
|
"""Top-level finetune configuration."""
|
||||||
|
|
||||||
|
base_model: str
|
||||||
|
lora: LoraConfig
|
||||||
|
training: TrainingConfig
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_toml(cls, config_path: Path) -> FinetuneConfig:
|
||||||
|
"""Load finetune config from a TOML file."""
|
||||||
|
raw = tomllib.loads(config_path.read_text())["finetune"]
|
||||||
|
return cls(
|
||||||
|
base_model=raw["base_model"],
|
||||||
|
lora=LoraConfig(**raw["lora"]),
|
||||||
|
training=TrainingConfig(**raw["training"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _messages_to_chatml(messages: list[dict]) -> str:
|
||||||
|
r"""Convert a message list to Qwen ChatML format.
|
||||||
|
|
||||||
|
Produces:
|
||||||
|
<|im_start|>system\n...\n<|im_end|>
|
||||||
|
<|im_start|>user\n...\n<|im_end|>
|
||||||
|
<|im_start|>assistant\n...\n<|im_end|>
|
||||||
|
"""
|
||||||
|
parts = []
|
||||||
|
for message in messages:
|
||||||
|
role = message["role"]
|
||||||
|
content = message["content"]
|
||||||
|
parts.append(f"<|im_start|>{role}\n{content}<|im_end|>")
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def load_dataset_from_jsonl(path: Path) -> Dataset:
|
||||||
|
"""Load a ChatML JSONL file into a HuggingFace Dataset.
|
||||||
|
|
||||||
|
Each line must have {"messages": [{"role": ..., "content": ...}, ...]}.
|
||||||
|
Pre-formats into a `text` column with the Qwen ChatML template applied,
|
||||||
|
which SFTTrainer consumes directly.
|
||||||
|
"""
|
||||||
|
records = []
|
||||||
|
with path.open(encoding="utf-8") as handle:
|
||||||
|
for raw_line in handle:
|
||||||
|
stripped = raw_line.strip()
|
||||||
|
if stripped:
|
||||||
|
entry = json.loads(stripped)
|
||||||
|
records.append({"text": _messages_to_chatml(entry["messages"])})
|
||||||
|
logger.info("Loaded %d examples from %s", len(records), path)
|
||||||
|
return Dataset.from_list(records)
|
||||||
|
|
||||||
|
|
||||||
|
def main(
|
||||||
|
dataset_path: Annotated[Path, typer.Option("--dataset", help="Fine-tuning JSONL")] = Path(
|
||||||
|
"output/finetune_dataset.jsonl",
|
||||||
|
),
|
||||||
|
validation_split: Annotated[float, typer.Option("--val-split", help="Fraction held out for validation")] = 0.1,
|
||||||
|
output_dir: Annotated[Path, typer.Option("--output-dir", help="Where to save the merged model")] = Path(
|
||||||
|
"output/qwen-bill-summarizer",
|
||||||
|
),
|
||||||
|
config_path: Annotated[
|
||||||
|
Path,
|
||||||
|
typer.Option("--config", help="TOML config file"),
|
||||||
|
] = Path(__file__).parent / "config.toml",
|
||||||
|
save_gguf: Annotated[bool, typer.Option("--save-gguf/--no-save-gguf", help="Also save GGUF")] = False,
|
||||||
|
) -> None:
|
||||||
|
"""Fine-tune Qwen 3.5 4B on bill summarization with Unsloth + QLoRA."""
|
||||||
|
logging.basicConfig(level="INFO", format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
||||||
|
|
||||||
|
if not dataset_path.is_file():
|
||||||
|
message = f"Dataset not found: {dataset_path}"
|
||||||
|
raise typer.BadParameter(message)
|
||||||
|
|
||||||
|
config = FinetuneConfig.from_toml(config_path)
|
||||||
|
|
||||||
|
logger.info("Loading base model: %s", config.base_model)
|
||||||
|
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||||
|
model_name=config.base_model,
|
||||||
|
max_seq_length=config.training.max_seq_length,
|
||||||
|
load_in_4bit=True,
|
||||||
|
dtype=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Applying LoRA (rank=%d, alpha=%d)", config.lora.rank, config.lora.alpha)
|
||||||
|
model = FastLanguageModel.get_peft_model(
|
||||||
|
model,
|
||||||
|
r=config.lora.rank,
|
||||||
|
lora_alpha=config.lora.alpha,
|
||||||
|
lora_dropout=config.lora.dropout,
|
||||||
|
target_modules=config.lora.targets,
|
||||||
|
bias="none",
|
||||||
|
use_gradient_checkpointing="unsloth",
|
||||||
|
random_state=42,
|
||||||
|
)
|
||||||
|
|
||||||
|
full_dataset = load_dataset_from_jsonl(dataset_path)
|
||||||
|
split = full_dataset.train_test_split(test_size=validation_split, seed=42)
|
||||||
|
train_dataset = split["train"]
|
||||||
|
validation_dataset = split["test"]
|
||||||
|
logger.info("Split: %d train, %d validation", len(train_dataset), len(validation_dataset))
|
||||||
|
training_args = TrainingArguments(
|
||||||
|
output_dir=str(output_dir / "checkpoints"),
|
||||||
|
num_train_epochs=config.training.epochs,
|
||||||
|
per_device_train_batch_size=config.training.batch_size,
|
||||||
|
gradient_accumulation_steps=config.training.gradient_accumulation,
|
||||||
|
learning_rate=config.training.learning_rate,
|
||||||
|
warmup_ratio=config.training.warmup_ratio,
|
||||||
|
weight_decay=config.training.weight_decay,
|
||||||
|
lr_scheduler_type="cosine",
|
||||||
|
logging_steps=config.training.logging_steps,
|
||||||
|
save_steps=config.training.save_steps,
|
||||||
|
save_total_limit=3,
|
||||||
|
eval_strategy="steps",
|
||||||
|
eval_steps=config.training.save_steps,
|
||||||
|
load_best_model_at_end=True,
|
||||||
|
bf16=True,
|
||||||
|
optim="adamw_8bit",
|
||||||
|
seed=42,
|
||||||
|
report_to="none",
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer = SFTTrainer(
|
||||||
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
train_dataset=train_dataset,
|
||||||
|
eval_dataset=validation_dataset,
|
||||||
|
args=training_args,
|
||||||
|
max_seq_length=config.training.max_seq_length,
|
||||||
|
packing=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Starting training: %d train, %d val, %d epochs",
|
||||||
|
len(train_dataset),
|
||||||
|
len(validation_dataset),
|
||||||
|
config.training.epochs,
|
||||||
|
)
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
merged_path = str(output_dir / "merged")
|
||||||
|
logger.info("Saving merged model to %s", merged_path)
|
||||||
|
model.save_pretrained_merged(merged_path, tokenizer, save_method="merged_16bit")
|
||||||
|
|
||||||
|
if save_gguf:
|
||||||
|
gguf_path = str(output_dir / "gguf")
|
||||||
|
logger.info("Saving GGUF to %s", gguf_path)
|
||||||
|
model.save_pretrained_gguf(gguf_path, tokenizer, quantization_method="q4_k_m")
|
||||||
|
|
||||||
|
logger.info("Done! Model saved to %s", output_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def cli() -> None:
|
||||||
|
"""Typer entry point."""
|
||||||
|
typer.run(main)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
cli()
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
how many oceans are there in the world
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
whos the president of the united states
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
whats the greatest country in the world
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
was/is the usa the greatest country in the world
|
||||||
@@ -0,0 +1,215 @@
|
|||||||
|
"""CLI entry point for the prompt benchmarking system."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
import typer
|
||||||
|
|
||||||
|
from python.prompt_bench.containers.lib import check_gpu_free
|
||||||
|
from python.prompt_bench.containers.vllm import start_vllm, stop_vllm
|
||||||
|
from python.prompt_bench.downloader import is_model_present
|
||||||
|
from python.prompt_bench.models import BenchmarkConfig
|
||||||
|
from python.prompt_bench.vllm_client import VLLMClient
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def discover_prompts(input_dir: Path) -> list[Path]:
|
||||||
|
"""Find all .txt files in the input directory."""
|
||||||
|
prompts = list(input_dir.glob("*.txt"))
|
||||||
|
if not prompts:
|
||||||
|
message = f"No .txt files found in {input_dir}"
|
||||||
|
raise FileNotFoundError(message)
|
||||||
|
return prompts
|
||||||
|
|
||||||
|
|
||||||
|
def _run_prompt(
|
||||||
|
client: VLLMClient,
|
||||||
|
prompt_path: Path,
|
||||||
|
*,
|
||||||
|
repo: str,
|
||||||
|
model_dir_name: str,
|
||||||
|
model_output: Path,
|
||||||
|
temperature: float,
|
||||||
|
) -> tuple[bool, float]:
|
||||||
|
"""Run a single prompt. Returns (success, elapsed_seconds)."""
|
||||||
|
filename = prompt_path.name
|
||||||
|
output_path = model_output / filename
|
||||||
|
start = time.monotonic()
|
||||||
|
try:
|
||||||
|
prompt_text = prompt_path.read_text()
|
||||||
|
response = client.complete(prompt_text, model_dir_name, temperature=temperature)
|
||||||
|
output_path.write_text(response)
|
||||||
|
elapsed = time.monotonic() - start
|
||||||
|
logger.info("Completed: %s / %s in %.2fs", repo, filename, elapsed)
|
||||||
|
except Exception:
|
||||||
|
elapsed = time.monotonic() - start
|
||||||
|
error_path = model_output / f"{filename}.error"
|
||||||
|
logger.exception("Failed: %s / %s after %.2fs", repo, filename, elapsed)
|
||||||
|
error_path.write_text(f"Error processing {filename}")
|
||||||
|
return False, elapsed
|
||||||
|
return True, elapsed
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_model(
|
||||||
|
client: VLLMClient,
|
||||||
|
prompts: list[Path],
|
||||||
|
*,
|
||||||
|
repo: str,
|
||||||
|
model_dir_name: str,
|
||||||
|
model_output: Path,
|
||||||
|
temperature: float,
|
||||||
|
concurrency: int,
|
||||||
|
) -> tuple[int, int]:
|
||||||
|
"""Run all prompts against a single model in parallel.
|
||||||
|
|
||||||
|
vLLM batches concurrent requests internally, so submitting many at once is
|
||||||
|
significantly faster than running them serially.
|
||||||
|
"""
|
||||||
|
pending = [prompt for prompt in prompts if not (model_output / prompt.name).exists()]
|
||||||
|
skipped = len(prompts) - len(pending)
|
||||||
|
if skipped:
|
||||||
|
logger.info("Skipping %d prompts with existing output for %s", skipped, repo)
|
||||||
|
|
||||||
|
if not pending:
|
||||||
|
logger.info("Nothing to do for %s", repo)
|
||||||
|
return 0, 0
|
||||||
|
|
||||||
|
completed = 0
|
||||||
|
failed = 0
|
||||||
|
latencies: list[float] = []
|
||||||
|
|
||||||
|
wall_start = time.monotonic()
|
||||||
|
with ThreadPoolExecutor(max_workers=concurrency) as executor:
|
||||||
|
futures = [
|
||||||
|
executor.submit(
|
||||||
|
_run_prompt,
|
||||||
|
client,
|
||||||
|
prompt_path,
|
||||||
|
repo=repo,
|
||||||
|
model_dir_name=model_dir_name,
|
||||||
|
model_output=model_output,
|
||||||
|
temperature=temperature,
|
||||||
|
)
|
||||||
|
for prompt_path in pending
|
||||||
|
]
|
||||||
|
for future in as_completed(futures):
|
||||||
|
success, elapsed = future.result()
|
||||||
|
latencies.append(elapsed)
|
||||||
|
if success:
|
||||||
|
completed += 1
|
||||||
|
else:
|
||||||
|
failed += 1
|
||||||
|
wall_elapsed = time.monotonic() - wall_start
|
||||||
|
|
||||||
|
attempted = completed + failed
|
||||||
|
avg_latency = sum(latencies) / attempted
|
||||||
|
throughput = attempted / wall_elapsed if wall_elapsed > 0 else 0.0
|
||||||
|
timing = {
|
||||||
|
"repo": repo,
|
||||||
|
"wall_seconds": wall_elapsed,
|
||||||
|
"attempted": attempted,
|
||||||
|
"completed": completed,
|
||||||
|
"failed": failed,
|
||||||
|
"avg_latency_seconds": avg_latency,
|
||||||
|
"throughput_prompts_per_second": throughput,
|
||||||
|
"concurrency": concurrency,
|
||||||
|
}
|
||||||
|
timing_path = model_output / "_timing.json"
|
||||||
|
timing_path.write_text(json.dumps(timing, indent=2))
|
||||||
|
|
||||||
|
return completed, failed
|
||||||
|
|
||||||
|
|
||||||
|
def run_benchmark(
|
||||||
|
config: BenchmarkConfig,
|
||||||
|
input_dir: Path,
|
||||||
|
output_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Execute the benchmark across all models and prompts."""
|
||||||
|
prompts = discover_prompts(input_dir)
|
||||||
|
logger.info("Found %d prompts in %s", len(prompts), input_dir)
|
||||||
|
|
||||||
|
check_gpu_free()
|
||||||
|
|
||||||
|
total_completed = 0
|
||||||
|
total_failed = 0
|
||||||
|
|
||||||
|
for repo in config.models:
|
||||||
|
if not is_model_present(repo, config.model_dir):
|
||||||
|
logger.warning("Skipping (not downloaded): %s", repo)
|
||||||
|
continue
|
||||||
|
|
||||||
|
model_output = output_dir / repo
|
||||||
|
model_output.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
logger.info("=== Benchmarking model: %s ===", repo)
|
||||||
|
|
||||||
|
stop_vllm()
|
||||||
|
try:
|
||||||
|
start_vllm(
|
||||||
|
model=repo,
|
||||||
|
port=config.port,
|
||||||
|
model_dir=config.model_dir,
|
||||||
|
gpu_memory_utilization=config.gpu_memory_utilization,
|
||||||
|
)
|
||||||
|
except RuntimeError:
|
||||||
|
logger.exception("Failed to start vLLM for %s, skipping", repo)
|
||||||
|
continue
|
||||||
|
logger.info("vLLM started for %s", repo)
|
||||||
|
try:
|
||||||
|
with VLLMClient(port=config.port, timeout=config.timeout) as client:
|
||||||
|
client.wait_ready(max_wait=config.vllm_startup_timeout)
|
||||||
|
completed, failed = benchmark_model(
|
||||||
|
client,
|
||||||
|
prompts,
|
||||||
|
repo=repo,
|
||||||
|
model_dir_name=repo,
|
||||||
|
model_output=model_output,
|
||||||
|
temperature=config.temperature,
|
||||||
|
concurrency=config.concurrency,
|
||||||
|
)
|
||||||
|
total_completed += completed
|
||||||
|
total_failed += failed
|
||||||
|
finally:
|
||||||
|
stop_vllm()
|
||||||
|
|
||||||
|
logger.info("=== Benchmark complete ===")
|
||||||
|
logger.info("Completed: %d | Failed: %d", total_completed, total_failed)
|
||||||
|
|
||||||
|
|
||||||
|
def main(
|
||||||
|
input_dir: Annotated[Path, typer.Argument(help="Directory containing input .txt prompt files")],
|
||||||
|
config: Annotated[Path, typer.Option(help="Path to TOML config file")] = Path("bench.toml"),
|
||||||
|
output_dir: Annotated[Path, typer.Option(help="Output directory for results")] = Path("output"),
|
||||||
|
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
|
||||||
|
) -> None:
|
||||||
|
"""Run prompts through multiple LLMs via vLLM and save results."""
|
||||||
|
logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
||||||
|
|
||||||
|
if not input_dir.is_dir():
|
||||||
|
message = f"Input directory does not exist: {input_dir}"
|
||||||
|
raise typer.BadParameter(message)
|
||||||
|
if not config.is_file():
|
||||||
|
message = f"Config file does not exist: {config}"
|
||||||
|
raise typer.BadParameter(message)
|
||||||
|
|
||||||
|
benchmark_config = BenchmarkConfig.from_toml(config)
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
run_benchmark(benchmark_config, input_dir, output_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def cli() -> None:
|
||||||
|
"""Typer entry point."""
|
||||||
|
typer.run(main)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
cli()
|
||||||
@@ -0,0 +1,30 @@
|
|||||||
|
"""Pydantic models for benchmark configuration."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import tomllib
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
class BenchmarkConfig(BaseModel):
|
||||||
|
"""Top-level benchmark configuration loaded from TOML."""
|
||||||
|
|
||||||
|
models: list[str]
|
||||||
|
model_dir: str = "/zfs/models/hf"
|
||||||
|
port: int = 8000
|
||||||
|
gpu_memory_utilization: float = 0.90
|
||||||
|
temperature: float = 0.0
|
||||||
|
timeout: int = 300
|
||||||
|
concurrency: int = 4
|
||||||
|
vllm_startup_timeout: int = 900
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_toml(cls, config_path: Path) -> BenchmarkConfig:
|
||||||
|
"""Load benchmark config from a TOML file."""
|
||||||
|
raw = tomllib.loads(config_path.read_text())["bench"]
|
||||||
|
return cls(**raw)
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
SUMMARIZATION_SYSTEM_PROMPT = """You are a legislative analyst extracting policy substance from Congressional bill text.
|
||||||
|
|
||||||
|
Your job is to compress a bill into a dense, neutral structured summary that captures every distinct policy action — including secondary effects that might be buried in subsections.
|
||||||
|
|
||||||
|
EXTRACTION RULES:
|
||||||
|
- IGNORE: whereas clauses, congressional findings that are purely political statements, recitals, preambles, citations of existing law by number alone, and procedural boilerplate.
|
||||||
|
- FOCUS ON: operative verbs — what the bill SHALL do, PROHIBIT, REQUIRE, AUTHORIZE, AMEND, APPROPRIATE, or ESTABLISH.
|
||||||
|
- SURFACE ALL THREADS: If the bill touches multiple policy areas, list each thread separately. Do not collapse them.
|
||||||
|
- BE CONCRETE: Name the affected population, the mechanism, and the direction (expands/restricts/maintains).
|
||||||
|
- STAY NEUTRAL: No political framing. Describe what the text does, not what its sponsors claim it does.
|
||||||
|
|
||||||
|
OUTPUT FORMAT — plain structured text, not JSON:
|
||||||
|
|
||||||
|
OPERATIVE ACTIONS:
|
||||||
|
[Numbered list of what the bill actually does, one action per line, max 20 words each]
|
||||||
|
|
||||||
|
AFFECTED POPULATIONS:
|
||||||
|
[Who gains something, who loses something, or whose behavior is regulated]
|
||||||
|
|
||||||
|
MECHANISMS:
|
||||||
|
[How it works: new funding, mandate, prohibition, amendment to existing statute, grant program, study commission, etc.]
|
||||||
|
|
||||||
|
POLICY THREADS:
|
||||||
|
[List each distinct policy domain this bill touches, even minor ones. Use plain language, not domain codes.]
|
||||||
|
|
||||||
|
SYMBOLIC/PROCEDURAL ONLY:
|
||||||
|
[Yes or No — is this bill primarily a resolution, designation, or awareness declaration with no operative effect?]
|
||||||
|
|
||||||
|
LENGTH TARGET: 150-250 words total. Be ruthless about cutting. Density over completeness."""
|
||||||
|
|
||||||
|
SUMMARIZATION_USER_TEMPLATE = """Summarize the following Congressional bill according to your instructions.
|
||||||
|
|
||||||
|
BILL TEXT:
|
||||||
|
{text_content}"""
|
||||||
@@ -0,0 +1,114 @@
|
|||||||
|
"""Build a fine-tuning JSONL dataset from batch request + output files.
|
||||||
|
|
||||||
|
Joins the original request JSONL (system + user messages) with the batch
|
||||||
|
output JSONL (assistant completions) by custom_id to produce a ChatML-style
|
||||||
|
messages JSONL suitable for fine-tuning.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
import typer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
HTTP_OK = 200
|
||||||
|
|
||||||
|
|
||||||
|
def load_requests(path: Path) -> dict[str, list[dict]]:
|
||||||
|
"""Parse request JSONL into {custom_id: messages}."""
|
||||||
|
results: dict[str, list[dict]] = {}
|
||||||
|
with path.open(encoding="utf-8") as handle:
|
||||||
|
for raw_line in handle:
|
||||||
|
stripped = raw_line.strip()
|
||||||
|
if not stripped:
|
||||||
|
continue
|
||||||
|
record = json.loads(stripped)
|
||||||
|
custom_id = record["custom_id"]
|
||||||
|
messages = record["body"]["messages"]
|
||||||
|
results[custom_id] = messages
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def load_completions(path: Path) -> dict[str, str]:
|
||||||
|
"""Parse batch output JSONL into {custom_id: assistant_content}."""
|
||||||
|
results: dict[str, str] = {}
|
||||||
|
with path.open(encoding="utf-8") as handle:
|
||||||
|
for line_number, raw_line in enumerate(handle, 1):
|
||||||
|
stripped = raw_line.strip()
|
||||||
|
if not stripped:
|
||||||
|
continue
|
||||||
|
record = json.loads(stripped)
|
||||||
|
custom_id = record["custom_id"]
|
||||||
|
response = record.get("response", {})
|
||||||
|
if response.get("status_code") != HTTP_OK:
|
||||||
|
logger.warning("Skipping %s (line %d): status %s", custom_id, line_number, response.get("status_code"))
|
||||||
|
continue
|
||||||
|
body = response.get("body", {})
|
||||||
|
choices = body.get("choices", [])
|
||||||
|
if not choices:
|
||||||
|
logger.warning("Skipping %s (line %d): no choices", custom_id, line_number)
|
||||||
|
continue
|
||||||
|
content = choices[0].get("message", {}).get("content", "")
|
||||||
|
if not content:
|
||||||
|
logger.warning("Skipping %s (line %d): empty content", custom_id, line_number)
|
||||||
|
continue
|
||||||
|
results[custom_id] = content
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def main(
|
||||||
|
requests_path: Annotated[Path, typer.Option("--requests", help="Batch request JSONL")] = Path(
|
||||||
|
"output/openai_batch/requests.jsonl",
|
||||||
|
),
|
||||||
|
batch_output: Annotated[Path, typer.Option("--batch-output", help="Batch output JSONL")] = Path(
|
||||||
|
"batch_69d84558d91c819091d53f08d78f9fd6_output.jsonl",
|
||||||
|
),
|
||||||
|
output_path: Annotated[Path, typer.Option("--output", help="Fine-tuning JSONL output")] = Path(
|
||||||
|
"output/finetune_dataset.jsonl",
|
||||||
|
),
|
||||||
|
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
|
||||||
|
) -> None:
|
||||||
|
"""Build fine-tuning dataset by joining request and output JSONL files."""
|
||||||
|
logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
||||||
|
|
||||||
|
logger.info("Loading requests from %s", requests_path)
|
||||||
|
requests = load_requests(requests_path)
|
||||||
|
logger.info("Loaded %d requests", len(requests))
|
||||||
|
|
||||||
|
logger.info("Loading completions from %s", batch_output)
|
||||||
|
completions = load_completions(batch_output)
|
||||||
|
logger.info("Loaded %d completions", len(completions))
|
||||||
|
|
||||||
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
matched = 0
|
||||||
|
skipped = 0
|
||||||
|
|
||||||
|
with output_path.open("w", encoding="utf-8") as handle:
|
||||||
|
for custom_id, messages in requests.items():
|
||||||
|
assistant_content = completions.get(custom_id)
|
||||||
|
if assistant_content is None:
|
||||||
|
skipped += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
example = {
|
||||||
|
"messages": [*messages, {"role": "assistant", "content": assistant_content}],
|
||||||
|
}
|
||||||
|
handle.write(json.dumps(example, ensure_ascii=False))
|
||||||
|
handle.write("\n")
|
||||||
|
matched += 1
|
||||||
|
|
||||||
|
logger.info("Wrote %d examples to %s (skipped %d unmatched)", matched, output_path, skipped)
|
||||||
|
|
||||||
|
|
||||||
|
def cli() -> None:
|
||||||
|
"""Typer entry point."""
|
||||||
|
typer.run(main)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
cli()
|
||||||
@@ -0,0 +1,97 @@
|
|||||||
|
"""Sum token usage across compressed and uncompressed run directories."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
import typer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UsageTotals:
|
||||||
|
"""Aggregate usage counters for a directory of run records."""
|
||||||
|
|
||||||
|
files: int = 0
|
||||||
|
errors: int = 0
|
||||||
|
prompt_tokens: int = 0
|
||||||
|
cached_tokens: int = 0
|
||||||
|
completion_tokens: int = 0
|
||||||
|
reasoning_tokens: int = 0
|
||||||
|
total_tokens: int = 0
|
||||||
|
per_file: list[tuple[str, int, int, int]] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
def tally_directory(directory: Path) -> UsageTotals:
|
||||||
|
"""Return aggregated usage stats for every JSON record in a directory."""
|
||||||
|
totals = UsageTotals()
|
||||||
|
decoder = json.JSONDecoder()
|
||||||
|
for path in sorted(directory.glob("*.json")):
|
||||||
|
text = path.read_text().lstrip()
|
||||||
|
record, _ = decoder.raw_decode(text)
|
||||||
|
totals.files += 1
|
||||||
|
usage = record.get("usage")
|
||||||
|
if not usage:
|
||||||
|
totals.errors += 1
|
||||||
|
continue
|
||||||
|
prompt_tokens = usage.get("prompt_tokens", 0)
|
||||||
|
completion_tokens = usage.get("completion_tokens", 0)
|
||||||
|
total_tokens = usage.get("total_tokens", 0)
|
||||||
|
cached_tokens = (usage.get("prompt_tokens_details") or {}).get("cached_tokens", 0)
|
||||||
|
reasoning_tokens = (usage.get("completion_tokens_details") or {}).get("reasoning_tokens", 0)
|
||||||
|
totals.prompt_tokens += prompt_tokens
|
||||||
|
totals.completion_tokens += completion_tokens
|
||||||
|
totals.total_tokens += total_tokens
|
||||||
|
totals.cached_tokens += cached_tokens
|
||||||
|
totals.reasoning_tokens += reasoning_tokens
|
||||||
|
totals.per_file.append((path.name, prompt_tokens, completion_tokens, total_tokens))
|
||||||
|
return totals
|
||||||
|
|
||||||
|
|
||||||
|
def log_totals(label: str, totals: UsageTotals) -> None:
|
||||||
|
"""Log a one-block summary for a directory."""
|
||||||
|
counted = totals.files - totals.errors
|
||||||
|
average_total = totals.total_tokens / counted if counted else 0
|
||||||
|
logger.info("[%s]", label)
|
||||||
|
logger.info(" files : %d (with usage: %d, errors: %d)", totals.files, counted, totals.errors)
|
||||||
|
logger.info(" prompt tokens : %d", totals.prompt_tokens)
|
||||||
|
logger.info(" cached tokens : %d", totals.cached_tokens)
|
||||||
|
logger.info(" completion tok : %d", totals.completion_tokens)
|
||||||
|
logger.info(" reasoning tok : %d", totals.reasoning_tokens)
|
||||||
|
logger.info(" total tokens : %d", totals.total_tokens)
|
||||||
|
logger.info(" avg total/file : %.1f", average_total)
|
||||||
|
|
||||||
|
|
||||||
|
def main(
|
||||||
|
runs_dir: Annotated[Path, typer.Option("--runs-dir")] = Path("output/openai_runs_temp_1"),
|
||||||
|
log_level: Annotated[str, typer.Option("--log-level")] = "INFO",
|
||||||
|
) -> None:
|
||||||
|
"""Print token usage totals for the compressed and uncompressed run directories."""
|
||||||
|
logging.basicConfig(level=log_level, format="%(message)s")
|
||||||
|
|
||||||
|
grand = UsageTotals()
|
||||||
|
for label in ("compressed", "uncompressed"):
|
||||||
|
directory = runs_dir / label
|
||||||
|
if not directory.is_dir():
|
||||||
|
logger.warning("%s: directory not found at %s", label, directory)
|
||||||
|
continue
|
||||||
|
totals = tally_directory(directory)
|
||||||
|
log_totals(label, totals)
|
||||||
|
grand.files += totals.files
|
||||||
|
grand.errors += totals.errors
|
||||||
|
grand.prompt_tokens += totals.prompt_tokens
|
||||||
|
grand.cached_tokens += totals.cached_tokens
|
||||||
|
grand.completion_tokens += totals.completion_tokens
|
||||||
|
grand.reasoning_tokens += totals.reasoning_tokens
|
||||||
|
grand.total_tokens += totals.total_tokens
|
||||||
|
|
||||||
|
log_totals("grand total", grand)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
typer.run(main)
|
||||||
@@ -0,0 +1,68 @@
|
|||||||
|
"""OpenAI-compatible client for vLLM's API."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Self
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
READY_POLL_INTERVAL = 2.0
|
||||||
|
|
||||||
|
|
||||||
|
class VLLMClient:
|
||||||
|
"""Talk to a vLLM server via its OpenAI-compatible API.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
host: vLLM host.
|
||||||
|
port: vLLM port.
|
||||||
|
timeout: Per-request timeout in seconds.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *, host: str = "localhost", port: int = 8000, timeout: int = 300) -> None:
|
||||||
|
"""Create a client connected to a vLLM server."""
|
||||||
|
self._client = httpx.Client(base_url=f"http://{host}:{port}", timeout=timeout)
|
||||||
|
|
||||||
|
def wait_ready(self, max_wait: int) -> None:
|
||||||
|
"""Poll /v1/models until the server is ready or timeout."""
|
||||||
|
deadline = time.monotonic() + max_wait
|
||||||
|
while time.monotonic() < deadline:
|
||||||
|
try:
|
||||||
|
response = self._client.get("/v1/models")
|
||||||
|
if response.is_success:
|
||||||
|
logger.info("vLLM server is ready")
|
||||||
|
return
|
||||||
|
except httpx.TransportError:
|
||||||
|
pass
|
||||||
|
time.sleep(READY_POLL_INTERVAL)
|
||||||
|
msg = f"vLLM server not ready after {max_wait}s"
|
||||||
|
raise TimeoutError(msg)
|
||||||
|
|
||||||
|
def complete(self, prompt: str, model: str, *, temperature: float = 0.0, max_tokens: int = 4096) -> str:
|
||||||
|
"""Send a prompt to /v1/completions and return the response text."""
|
||||||
|
payload = {
|
||||||
|
"model": model,
|
||||||
|
"prompt": prompt,
|
||||||
|
"temperature": temperature,
|
||||||
|
"max_tokens": max_tokens,
|
||||||
|
}
|
||||||
|
logger.info("Sending prompt to %s (%d chars)", model, len(prompt))
|
||||||
|
response = self._client.post("/v1/completions", json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
return data["choices"][0]["text"]
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""Close the HTTP client."""
|
||||||
|
self._client.close()
|
||||||
|
|
||||||
|
def __enter__(self) -> Self:
|
||||||
|
"""Enter the context manager."""
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *args: object) -> None:
|
||||||
|
"""Close the HTTP client on exit."""
|
||||||
|
self.close()
|
||||||
Reference in New Issue
Block a user