239 lines
8.6 KiB
Python
239 lines
8.6 KiB
Python
"""Submit an OpenAI Batch API bill-summarization job over compressed text.
|
|
|
|
Reads the first N bills from a CSV with a `text_content` column, compresses
|
|
each via `bill_token_compression.compress_bill_text`, builds a JSONL file of
|
|
summarization requests, and submits it as an asynchronous Batch API job
|
|
against `/v1/chat/completions`. Also writes a CSV of per-bill pre/post-
|
|
compression token counts.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import csv
|
|
import json
|
|
import logging
|
|
import re
|
|
import sys
|
|
import tomllib
|
|
from os import getenv
|
|
from pathlib import Path
|
|
from typing import Annotated
|
|
|
|
import httpx
|
|
import typer
|
|
from tiktoken import Encoding, get_encoding
|
|
|
|
from python.prompt_bench.bill_token_compression import compress_bill_text
|
|
|
|
# Summarization prompts live in a TOML config so they can be tuned without code changes.
_PROMPTS_PATH = Path(__file__).resolve().parents[2] / "config" / "prompts" / "summarization_prompts.toml"
# Read eagerly at import time; a missing/malformed file fails fast here.
_PROMPTS = tomllib.loads(_PROMPTS_PATH.read_text())["summarization"]
SUMMARIZATION_SYSTEM_PROMPT: str = _PROMPTS["system_prompt"]
# Must contain a {text_content} placeholder (filled in build_request).
SUMMARIZATION_USER_TEMPLATE: str = _PROMPTS["user_template"]

logger = logging.getLogger(__name__)

# Base URL for the OpenAI REST API (Files and Batches endpoints).
OPENAI_API_BASE = "https://api.openai.com/v1"
|
|
|
|
|
|
def load_bills(csv_path: Path, count: int = 0) -> list[tuple[str, str]]:
    """Return (bill_id, text_content) tuples with non-empty text.

    Args:
        csv_path: CSV file with a ``text_content`` column; ids come from
            ``bill_id`` (or ``id``), suffixed with ``version_code`` if present.
        count: Maximum number of bills to return. If 0 or negative, all rows
            are returned.

    Returns:
        ``(unique_id, text_content)`` pairs in file order, skipping rows whose
        text is empty or whitespace-only.
    """
    # Bill texts routinely exceed csv's default 128 KiB field limit.
    # csv.field_size_limit(sys.maxsize) raises OverflowError on platforms
    # where the C long is narrower than sys.maxsize (e.g. 64-bit Windows),
    # so back off by halving until a value is accepted.
    max_field = sys.maxsize
    while True:
        try:
            csv.field_size_limit(max_field)
            break
        except OverflowError:
            max_field //= 2
    bills: list[tuple[str, str]] = []
    with csv_path.open(newline="", encoding="utf-8") as handle:
        reader = csv.DictReader(handle)
        for row in reader:
            text_content = (row.get("text_content") or "").strip()
            if not text_content:
                continue  # no usable text for this row
            bill_id = row.get("bill_id") or row.get("id") or f"row-{len(bills)}"
            version_code = row.get("version_code") or ""
            # Disambiguate multiple versions of the same bill.
            unique_id = f"{bill_id}-{version_code}" if version_code else bill_id
            bills.append((unique_id, text_content))
            if count > 0 and len(bills) >= count:
                break
    return bills
|
|
|
|
|
|
def safe_filename(value: str) -> str:
    """Sanitize *value* for use as a filename or batch custom_id."""
    # Collapse every run of disallowed characters into a single underscore,
    # then drop leading/trailing underscores left behind.
    cleaned = re.sub(r"[^A-Za-z0-9._-]+", "_", value)
    cleaned = cleaned.strip("_")
    return cleaned if cleaned else "unnamed"
|
|
|
|
|
|
def build_request(custom_id: str, model: str, bill_text: str) -> dict:
    """Assemble one batch-request line for the chat completions endpoint."""
    messages = [
        {"role": "system", "content": SUMMARIZATION_SYSTEM_PROMPT},
        {"role": "user", "content": SUMMARIZATION_USER_TEMPLATE.format(text_content=bill_text)},
    ]
    return {
        "custom_id": custom_id,
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {"model": model, "messages": messages},
    }
|
|
|
|
|
|
def write_jsonl(path: Path, lines: list[dict]) -> None:
    """Serialize each dict in *lines* as one JSON object per line at *path*."""
    # Build the whole payload first, then write once.
    payload = "".join(json.dumps(entry, ensure_ascii=False) + "\n" for entry in lines)
    with path.open("w", encoding="utf-8") as handle:
        handle.write(payload)
|
|
|
|
|
|
def upload_file(client: httpx.Client, path: Path) -> str:
    """Upload a JSONL file to the OpenAI Files API and return its file id."""
    url = f"{OPENAI_API_BASE}/files"
    with path.open("rb") as stream:
        # purpose=batch marks the file as batch-job input.
        response = client.post(
            url,
            files={"file": (path.name, stream, "application/jsonl")},
            data={"purpose": "batch"},
        )
    response.raise_for_status()
    return response.json()["id"]
|
|
|
|
|
|
def prepare_requests(
    bills: list[tuple[str, str]],
    *,
    model: str,
    encoder: Encoding,
) -> tuple[list[dict], list[dict]]:
    """Build (request_lines, token_rows) from bills.

    Each bill is compressed before being turned into a request line. Each
    `token_rows` entry records char and token counts for one bill so the
    caller can write a per-bill CSV.
    """
    requests_out: list[dict] = []
    stats: list[dict] = []
    for unique_id, raw_text in bills:
        compressed = compress_bill_text(raw_text)
        n_raw = len(encoder.encode(raw_text))
        n_compressed = len(encoder.encode(compressed))
        stats.append(
            {
                "bill_id": unique_id,
                "raw_chars": len(raw_text),
                "compressed_chars": len(compressed),
                "raw_tokens": n_raw,
                "compressed_tokens": n_compressed,
                # Guard against zero-token bills (ratio undefined -> None).
                "token_ratio": n_compressed / n_raw if n_raw else None,
            },
        )
        # custom_id must only contain filename-safe characters.
        requests_out.append(build_request(safe_filename(unique_id), model, compressed))
    return requests_out, stats
|
|
|
|
|
|
def write_token_csv(path: Path, token_rows: list[dict]) -> tuple[int, int]:
    """Write per-bill token counts to CSV. Returns (raw_total, compressed_total)."""
    columns = ["bill_id", "raw_chars", "compressed_chars", "raw_tokens", "compressed_tokens", "token_ratio"]
    raw_total = 0
    compressed_total = 0
    with path.open("w", newline="", encoding="utf-8") as sink:
        writer = csv.DictWriter(sink, fieldnames=columns)
        writer.writeheader()
        # Accumulate totals while streaming rows out.
        for row in token_rows:
            writer.writerow(row)
            raw_total += row["raw_tokens"]
            compressed_total += row["compressed_tokens"]
    return raw_total, compressed_total
|
|
|
|
|
|
def create_batch(client: httpx.Client, input_file_id: str, description: str) -> dict:
    """Create a batch job and return its full response payload."""
    payload = {
        "input_file_id": input_file_id,
        "endpoint": "/v1/chat/completions",
        # 24h is the standard Batch API completion window.
        "completion_window": "24h",
        "metadata": {"description": description},
    }
    response = client.post(f"{OPENAI_API_BASE}/batches", json=payload)
    response.raise_for_status()
    return response.json()
|
|
|
|
|
|
def main(
    csv_path: Annotated[Path, typer.Option("--csv", help="Bills CSV path")] = Path("bills.csv"),
    output_dir: Annotated[Path, typer.Option("--output-dir", help="Where to write JSONL + metadata")] = Path(
        "output/openai_batch",
    ),
    model: Annotated[str, typer.Option(help="OpenAI model id")] = "gpt-5-mini",
    count: Annotated[int, typer.Option(help="Max bills to process, 0 = all")] = 0,
    log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
) -> None:
    """Submit an OpenAI Batch job of compressed bill summaries.

    Writes `requests.jsonl`, `token_counts.csv`, and `batch.json` into
    *output_dir*, uploads the JSONL, and creates the batch job.

    Raises:
        typer.BadParameter: If no API key env var is set or the CSV is missing.
    """
    logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s")

    api_key = getenv("CLOSEDAI_TOKEN") or getenv("OPENAI_API_KEY")
    if not api_key:
        message = "Neither CLOSEDAI_TOKEN nor OPENAI_API_KEY is set"
        raise typer.BadParameter(message)
    if not csv_path.is_file():
        message = f"CSV not found: {csv_path}"
        raise typer.BadParameter(message)

    output_dir.mkdir(parents=True, exist_ok=True)

    # count == 0 means "all bills" — log that accurately instead of "0 bills".
    logger.info("Loading %s bills from %s", count if count > 0 else "all", csv_path)
    bills = load_bills(csv_path, count)
    if len(bills) < count:
        logger.warning("Only %d bills available (requested %d)", len(bills), count)

    encoder = get_encoding("o200k_base")
    request_lines, token_rows = prepare_requests(bills, model=model, encoder=encoder)

    token_csv_path = output_dir / "token_counts.csv"
    raw_tokens_total, compressed_tokens_total = write_token_csv(token_csv_path, token_rows)
    logger.info(
        "Token counts: raw=%d compressed=%d ratio=%.3f -> %s",
        raw_tokens_total,
        compressed_tokens_total,
        (compressed_tokens_total / raw_tokens_total) if raw_tokens_total else 0.0,
        token_csv_path,
    )

    jsonl_path = output_dir / "requests.jsonl"
    write_jsonl(jsonl_path, request_lines)
    logger.info("Wrote %s (%d bills)", jsonl_path, len(request_lines))

    headers = {"Authorization": f"Bearer {api_key}"}
    # Generous timeout: the multi-MB JSONL upload can be slow.
    with httpx.Client(headers=headers, timeout=httpx.Timeout(300.0)) as client:
        logger.info("Uploading JSONL")
        file_id = upload_file(client, jsonl_path)
        logger.info("Uploaded: %s", file_id)

        logger.info("Creating batch")
        batch = create_batch(client, file_id, f"compressed bill summaries x{len(request_lines)} ({model})")
        logger.info("Batch created: %s", batch["id"])

    metadata = {
        "model": model,
        "count": len(bills),
        "jsonl": str(jsonl_path),
        "input_file_id": file_id,
        "batch_id": batch["id"],
        "raw_tokens_total": raw_tokens_total,
        "compressed_tokens_total": compressed_tokens_total,
        "batch": batch,
    }
    metadata_path = output_dir / "batch.json"
    # Explicit encoding: write_text otherwise uses the locale's preferred
    # encoding, which can fail on non-ASCII content on some platforms.
    metadata_path.write_text(json.dumps(metadata, indent=2), encoding="utf-8")
    logger.info("Wrote metadata to %s", metadata_path)
|
|
|
|
|
|
def cli() -> None:
    """Console-script entry point: run `main` as a single-command Typer app."""
    typer.run(main)
|
|
|
|
|
|
if __name__ == "__main__":
    # Support direct `python path/to/script.py` execution.
    cli()
|