"""Run two interactive OpenAI chat-completion sweeps over bill text. Reads the first N bills from a CSV with a `text_content` column and sends two sweeps through `/v1/chat/completions` concurrently — one with the raw bill text, one with the compressed bill text. Each request's prompt is saved to disk alongside the OpenAI response id so the prompts and responses can be correlated later. """ from __future__ import annotations import csv import json import logging import re import sys import time import tomllib from concurrent.futures import ThreadPoolExecutor, as_completed from os import getenv from pathlib import Path from typing import Annotated import httpx import typer from python.prompt_bench.bill_token_compression import compress_bill_text _PROMPTS_PATH = Path(__file__).resolve().parents[2] / "config" / "prompts" / "summarization_prompts.toml" _PROMPTS = tomllib.loads(_PROMPTS_PATH.read_text())["summarization"] SUMMARIZATION_SYSTEM_PROMPT: str = _PROMPTS["system_prompt"] SUMMARIZATION_USER_TEMPLATE: str = _PROMPTS["user_template"] logger = logging.getLogger(__name__) OPENAI_API_BASE = "https://api.openai.com/v1" DEFAULT_MODEL = "gpt-5.4-mini" DEFAULT_COUNT = 100 SEED = 42 def load_bills(csv_path: Path, count: int) -> list[tuple[str, str]]: """Return up to `count` (bill_id, text_content) tuples with non-empty text.""" csv.field_size_limit(sys.maxsize) bills: list[tuple[str, str]] = [] with csv_path.open(newline="", encoding="utf-8") as handle: reader = csv.DictReader(handle) for row in reader: text_content = (row.get("text_content") or "").strip() if not text_content: continue bill_id = row.get("bill_id") or row.get("id") or f"row-{len(bills)}" version_code = row.get("version_code") or "" unique_id = f"{bill_id}-{version_code}" if version_code else bill_id bills.append((unique_id, text_content)) if len(bills) >= count: break return bills def build_messages(bill_text: str) -> list[dict]: """Return the system + user message pair for a bill.""" return [ {"role": "system", "content": SUMMARIZATION_SYSTEM_PROMPT}, {"role": "user", "content": SUMMARIZATION_USER_TEMPLATE.format(text_content=bill_text)}, ] def safe_filename(value: str) -> str: """Make a string safe for use as a filename.""" return re.sub(r"[^A-Za-z0-9._-]+", "_", value).strip("_") or "unnamed" def run_one_request( client: httpx.Client, *, bill_id: str, label: str, bill_text: str, model: str, output_path: Path, ) -> tuple[bool, float, str | None]: """Send one chat-completion request and persist prompt + response. Returns (success, elapsed_seconds, response_id). """ messages = build_messages(bill_text) payload = { "model": model, "messages": messages, "seed": SEED, } start = time.monotonic() record: dict = { "bill_id": bill_id, "label": label, "model": model, "seed": SEED, "input_chars": len(bill_text), "messages": messages, } try: response = client.post(f"{OPENAI_API_BASE}/chat/completions", json=payload) response.raise_for_status() body = response.json() except httpx.HTTPStatusError as error: elapsed = time.monotonic() - start record["error"] = { "status_code": error.response.status_code, "body": error.response.text, "elapsed_seconds": elapsed, } output_path.write_text(json.dumps(record, ensure_ascii=False, indent=2)) logger.exception("HTTP error for %s/%s after %.2fs", label, bill_id, elapsed) return False, elapsed, None except Exception as error: elapsed = time.monotonic() - start record["error"] = {"message": str(error), "elapsed_seconds": elapsed} output_path.write_text(json.dumps(record, ensure_ascii=False, indent=2)) logger.exception("Failed: %s/%s after %.2fs", label, bill_id, elapsed) return False, elapsed, None elapsed = time.monotonic() - start response_id = body.get("id") record["response_id"] = response_id record["elapsed_seconds"] = elapsed record["usage"] = body.get("usage") record["response"] = body output_path.write_text(json.dumps(record, ensure_ascii=False, indent=2)) logger.info("Done: %s/%s id=%s in %.2fs", label, bill_id, response_id, elapsed) return True, elapsed, response_id def main( csv_path: Annotated[Path, typer.Option("--csv", help="Bills CSV path")] = Path("bills.csv"), output_dir: Annotated[Path, typer.Option("--output-dir", help="Where to write per-request JSON")] = Path( "output/openai_runs", ), model: Annotated[str, typer.Option(help="OpenAI model id")] = DEFAULT_MODEL, count: Annotated[int, typer.Option(help="Number of bills per set")] = DEFAULT_COUNT, concurrency: Annotated[int, typer.Option(help="Concurrent in-flight requests")] = 16, log_level: Annotated[str, typer.Option(help="Log level")] = "INFO", ) -> None: """Run two interactive OpenAI sweeps (compressed + uncompressed) over bill text.""" logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s") api_key = getenv("CLOSEDAI_TOKEN") or getenv("OPENAI_API_KEY") if not api_key: message = "Neither CLOSEDAI_TOKEN nor OPENAI_API_KEY is set" raise typer.BadParameter(message) if not csv_path.is_file(): message = f"CSV not found: {csv_path}" raise typer.BadParameter(message) compressed_dir = output_dir / "compressed" uncompressed_dir = output_dir / "uncompressed" compressed_dir.mkdir(parents=True, exist_ok=True) uncompressed_dir.mkdir(parents=True, exist_ok=True) logger.info("Loading %d bills from %s", count, csv_path) bills = load_bills(csv_path, count) if len(bills) < count: logger.warning("Only %d bills available (requested %d)", len(bills), count) tasks: list[tuple[str, str, str, Path]] = [] for bill_id, text_content in bills: filename = f"{safe_filename(bill_id)}.json" tasks.append((bill_id, "compressed", compress_bill_text(text_content), compressed_dir / filename)) tasks.append((bill_id, "uncompressed", text_content, uncompressed_dir / filename)) logger.info("Submitting %d requests at concurrency=%d", len(tasks), concurrency) headers = {"Authorization": f"Bearer {api_key}"} completed = 0 failed = 0 index: list[dict] = [] wall_start = time.monotonic() with ( httpx.Client(headers=headers, timeout=httpx.Timeout(300.0)) as client, ThreadPoolExecutor( max_workers=concurrency, ) as executor, ): future_to_task = { executor.submit( run_one_request, client, bill_id=bill_id, label=label, bill_text=bill_text, model=model, output_path=output_path, ): (bill_id, label, output_path) for bill_id, label, bill_text, output_path in tasks } for future in as_completed(future_to_task): bill_id, label, output_path = future_to_task[future] success, elapsed, response_id = future.result() if success: completed += 1 else: failed += 1 index.append( { "bill_id": bill_id, "label": label, "response_id": response_id, "elapsed_seconds": elapsed, "success": success, "path": str(output_path), }, ) wall_elapsed = time.monotonic() - wall_start summary = { "model": model, "count": len(bills), "completed": completed, "failed": failed, "wall_seconds": wall_elapsed, "concurrency": concurrency, "results": index, } summary_path = output_dir / "summary.json" summary_path.write_text(json.dumps(summary, indent=2)) logger.info( "Done: completed=%d failed=%d wall=%.1fs summary=%s", completed, failed, wall_elapsed, summary_path, ) def cli() -> None: """Typer entry point.""" typer.run(main) if __name__ == "__main__": cli()