mirror of
https://github.com/RichieCahill/dotfiles.git
synced 2026-04-17 21:18:18 -04:00
added bill_token_compression.py
Tested on a sample of 100 bills matching the distribution of our data. Compression saves ~11.5% on prompt tokens; completion and reasoning tokens are roughly equal across the two sets (total = prompt + completion).

| set          | prompt  | completion | reasoning | total   |
|--------------|---------|------------|-----------|---------|
| compressed   | 349,460 | 157,110    | 112,128   | 506,570 |
| uncompressed | 394,948 | 154,710    | 110,080   | 549,658 |
| delta        | −45,488 | +2,400     | +2,048    | −43,088 |
This commit is contained in:
270
python/prompt_bench/compresion_test.py
Normal file
270
python/prompt_bench/compresion_test.py
Normal file
@@ -0,0 +1,270 @@
|
||||
"""Run two interactive OpenAI chat-completion sweeps over bill text.
|
||||
|
||||
Reads the first N bills from a CSV with a `text_content` column and sends two
|
||||
sweeps through `/v1/chat/completions` concurrently — one with the raw bill
|
||||
text, one with the compressed bill text. Each request's prompt is saved to
|
||||
disk alongside the OpenAI response id so the prompts and responses can be
|
||||
correlated later.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from os import getenv
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
|
||||
import httpx
|
||||
import typer
|
||||
|
||||
from python.prompt_bench.bill_token_compression import compress_bill_text
|
||||
|
||||
# Module-level logger; its level and format are set by logging.basicConfig() in main().
logger = logging.getLogger(__name__)

# Base URL; run_one_request() appends "/chat/completions" to it.
OPENAI_API_BASE = "https://api.openai.com/v1"
# Default for the --model CLI option.
DEFAULT_MODEL = "gpt-5.4-mini"
# Default for the --count CLI option (number of bills per sweep).
DEFAULT_COUNT = 100
# Sent as the "seed" field of every request payload and stored in every record.
SEED = 42
|
||||
|
||||
# System message sent verbatim with every request (see build_messages()).
# It is identical for the compressed and uncompressed sweeps, so only the
# bill text placed in the user message differs between the two.
SYSTEM_PROMPT = """You are a legislative analyst extracting policy substance from Congressional bill text.

Your job is to compress a bill into a dense, neutral structured summary that captures every distinct policy action — including secondary effects that might be buried in subsections.

EXTRACTION RULES:
- IGNORE: whereas clauses, congressional findings that are purely political statements, recitals, preambles, citations of existing law by number alone, and procedural boilerplate.
- FOCUS ON: operative verbs — what the bill SHALL do, PROHIBIT, REQUIRE, AUTHORIZE, AMEND, APPROPRIATE, or ESTABLISH.
- SURFACE ALL THREADS: If the bill touches multiple policy areas, list each thread separately. Do not collapse them.
- BE CONCRETE: Name the affected population, the mechanism, and the direction (expands/restricts/maintains).
- STAY NEUTRAL: No political framing. Describe what the text does, not what its sponsors claim it does.

OUTPUT FORMAT — plain structured text, not JSON:

OPERATIVE ACTIONS:
[Numbered list of what the bill actually does, one action per line, max 20 words each]

AFFECTED POPULATIONS:
[Who gains something, who loses something, or whose behavior is regulated]

MECHANISMS:
[How it works: new funding, mandate, prohibition, amendment to existing statute, grant program, study commission, etc.]

POLICY THREADS:
[List each distinct policy domain this bill touches, even minor ones. Use plain language, not domain codes.]

SYMBOLIC/PROCEDURAL ONLY:
[Yes or No — is this bill primarily a resolution, designation, or awareness declaration with no operative effect?]

LENGTH TARGET: 150-250 words total. Be ruthless about cutting. Density over completeness."""

# User message template; build_messages() fills the single `{text_content}`
# placeholder with the (raw or compressed) bill text via str.format().
USER_TEMPLATE = """Summarize the following Congressional bill according to your instructions.

BILL TEXT:
{text_content}"""
|
||||
|
||||
|
||||
def load_bills(csv_path: Path, count: int) -> list[tuple[str, str]]:
    """Return up to `count` (bill_id, text_content) tuples with non-empty text.

    Rows are read in file order. Rows whose `text_content` column is missing
    or blank after stripping are skipped and do not count towards `count`.

    The identifier is taken from the `bill_id` column, falling back to `id`,
    falling back to `row-<n>` where `n` is the number of bills accepted so
    far. A non-empty `version_code` column is appended as `-<version_code>`.

    Args:
        csv_path: CSV file with a header row; read as UTF-8.
        count: Maximum number of bills to return. Zero or negative yields an
            empty list.

    Returns:
        A list of `(unique_id, stripped_text_content)` tuples, at most
        `count` long.
    """
    bills: list[tuple[str, str]] = []
    # Guard up front: the limit check below only runs after an append, so
    # without this a count of 0 would still return one bill.
    if count <= 0:
        return bills

    # Bill texts can exceed the csv module's default field size limit.
    # sys.maxsize overflows a C long on some platforms (OverflowError), so
    # shrink the requested limit until the csv module accepts it.
    field_limit = sys.maxsize
    while True:
        try:
            csv.field_size_limit(field_limit)
            break
        except OverflowError:
            field_limit //= 10

    with csv_path.open(newline="", encoding="utf-8") as handle:
        for row in csv.DictReader(handle):
            text_content = (row.get("text_content") or "").strip()
            if not text_content:
                continue
            bill_id = row.get("bill_id") or row.get("id") or f"row-{len(bills)}"
            version_code = row.get("version_code") or ""
            unique_id = f"{bill_id}-{version_code}" if version_code else bill_id
            bills.append((unique_id, text_content))
            if len(bills) >= count:
                break
    return bills
|
||||
|
||||
|
||||
def build_messages(bill_text: str) -> list[dict]:
    """Build the two-message chat payload (system, then user) for one bill.

    The system message carries SYSTEM_PROMPT unchanged; the user message is
    USER_TEMPLATE with `bill_text` substituted for its `text_content` field.
    """
    system_message = {"role": "system", "content": SYSTEM_PROMPT}
    user_message = {
        "role": "user",
        "content": USER_TEMPLATE.format(text_content=bill_text),
    }
    return [system_message, user_message]
|
||||
|
||||
|
||||
def safe_filename(value: str) -> str:
    """Make a string safe for use as a filename.

    Every run of characters outside `A-Z a-z 0-9 . _ -` is collapsed into a
    single underscore, then leading and trailing underscores are stripped.

    Returns:
        The sanitised name, or `"unnamed"` when nothing usable remains: the
        result is empty, or it is exactly `"."` or `".."`.
    """
    cleaned = re.sub(r"[^A-Za-z0-9._-]+", "_", value).strip("_")
    # "." and ".." survive the character filter but refer to directories,
    # so they cannot be used as file names.
    if not cleaned or cleaned in {".", ".."}:
        return "unnamed"
    return cleaned
|
||||
|
||||
|
||||
def _write_record(output_path: Path, record: dict) -> None:
    """Serialise `record` as indented JSON to `output_path`, always as UTF-8."""
    # ensure_ascii=False keeps non-ASCII bill text readable, so the file
    # encoding must be pinned; the locale default may not be able to encode it.
    output_path.write_text(
        json.dumps(record, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )


def run_one_request(
    client: httpx.Client,
    *,
    bill_id: str,
    label: str,
    bill_text: str,
    model: str,
    output_path: Path,
) -> tuple[bool, float, str | None]:
    """Send one chat-completion request and persist prompt + response.

    A JSON record containing the prompt messages and request metadata is
    always written to `output_path`. On success it also holds the response
    id, usage and full response body; on failure it holds an `error` entry.
    Request failures are logged and reported through the return value rather
    than raised.

    Args:
        client: HTTP client used to POST to the chat-completions endpoint.
        bill_id: Identifier stored in the record and used in log messages.
        label: Sweep name stored in the record and used in log messages.
        bill_text: Text substituted into the user message.
        model: Model id sent in the request payload.
        output_path: File the JSON record is written to.

    Returns (success, elapsed_seconds, response_id).
    """
    messages = build_messages(bill_text)
    payload = {
        "model": model,
        "messages": messages,
        "seed": SEED,
    }
    start = time.monotonic()
    record: dict = {
        "bill_id": bill_id,
        "label": label,
        "model": model,
        "seed": SEED,
        "input_chars": len(bill_text),
        "messages": messages,
    }
    try:
        response = client.post(f"{OPENAI_API_BASE}/chat/completions", json=payload)
        response.raise_for_status()
        body = response.json()
        # Validate inside the try so a non-object body is recorded as a
        # failure instead of raising AttributeError on body.get() below.
        if not isinstance(body, dict):
            message = f"Expected a JSON object, got {type(body).__name__}"
            raise TypeError(message)
    except httpx.HTTPStatusError as error:
        elapsed = time.monotonic() - start
        record["error"] = {
            "status_code": error.response.status_code,
            "body": error.response.text,
            "elapsed_seconds": elapsed,
        }
        _write_record(output_path, record)
        logger.exception("HTTP error for %s/%s after %.2fs", label, bill_id, elapsed)
        return False, elapsed, None
    except Exception as error:
        # Worker boundary: any other failure (timeout, connection error, bad
        # JSON) is recorded for this request so the rest of the sweep goes on.
        elapsed = time.monotonic() - start
        record["error"] = {"message": str(error), "elapsed_seconds": elapsed}
        _write_record(output_path, record)
        logger.exception("Failed: %s/%s after %.2fs", label, bill_id, elapsed)
        return False, elapsed, None

    elapsed = time.monotonic() - start
    response_id = body.get("id")
    record["response_id"] = response_id
    record["elapsed_seconds"] = elapsed
    record["usage"] = body.get("usage")
    record["response"] = body
    _write_record(output_path, record)
    logger.info("Done: %s/%s id=%s in %.2fs", label, bill_id, response_id, elapsed)
    return True, elapsed, response_id
|
||||
|
||||
|
||||
def main(
    csv_path: Annotated[Path, typer.Option("--csv", help="Bills CSV path")] = Path("bills.csv"),
    output_dir: Annotated[Path, typer.Option("--output-dir", help="Where to write per-request JSON")] = Path(
        "output/openai_runs",
    ),
    model: Annotated[str, typer.Option(help="OpenAI model id")] = DEFAULT_MODEL,
    count: Annotated[int, typer.Option(help="Number of bills per set")] = DEFAULT_COUNT,
    concurrency: Annotated[int, typer.Option(help="Concurrent in-flight requests")] = 16,
    log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
) -> None:
    """Run two interactive OpenAI sweeps (compressed + uncompressed) over bill text.

    Loads up to `count` bills from `csv_path`, then submits two requests per
    bill (one with `compress_bill_text` applied, one with the raw text) to a
    thread pool of `concurrency` workers. Per-request JSON records go to
    `<output_dir>/compressed` and `<output_dir>/uncompressed`; an overall
    `<output_dir>/summary.json` lists every result in completion order.

    Raises:
        typer.BadParameter: If neither CLOSEDAI_TOKEN nor OPENAI_API_KEY is
            set, or if `csv_path` is not an existing file.
    """
    # The logging module only recognises upper-case level names, so normalise
    # the value to accept e.g. `--log-level info`.
    logging.basicConfig(level=log_level.upper(), format="%(asctime)s %(levelname)s %(name)s: %(message)s")

    api_key = getenv("CLOSEDAI_TOKEN") or getenv("OPENAI_API_KEY")
    if not api_key:
        message = "Neither CLOSEDAI_TOKEN nor OPENAI_API_KEY is set"
        raise typer.BadParameter(message)
    if not csv_path.is_file():
        message = f"CSV not found: {csv_path}"
        raise typer.BadParameter(message)

    compressed_dir = output_dir / "compressed"
    uncompressed_dir = output_dir / "uncompressed"
    compressed_dir.mkdir(parents=True, exist_ok=True)
    uncompressed_dir.mkdir(parents=True, exist_ok=True)

    logger.info("Loading %d bills from %s", count, csv_path)
    bills = load_bills(csv_path, count)
    if len(bills) < count:
        logger.warning("Only %d bills available (requested %d)", len(bills), count)

    # Two tasks per bill, written under the same file name in sibling
    # directories so the compressed/uncompressed records pair up by name.
    tasks: list[tuple[str, str, str, Path]] = []
    for bill_id, text_content in bills:
        filename = f"{safe_filename(bill_id)}.json"
        tasks.append((bill_id, "compressed", compress_bill_text(text_content), compressed_dir / filename))
        tasks.append((bill_id, "uncompressed", text_content, uncompressed_dir / filename))

    logger.info("Submitting %d requests at concurrency=%d", len(tasks), concurrency)

    headers = {"Authorization": f"Bearer {api_key}"}
    completed = 0
    failed = 0
    index: list[dict] = []
    wall_start = time.monotonic()
    with (
        httpx.Client(headers=headers, timeout=httpx.Timeout(300.0)) as client,
        ThreadPoolExecutor(
            max_workers=concurrency,
        ) as executor,
    ):
        future_to_task = {
            executor.submit(
                run_one_request,
                client,
                bill_id=bill_id,
                label=label,
                bill_text=bill_text,
                model=model,
                output_path=output_path,
            ): (bill_id, label, output_path)
            for bill_id, label, bill_text, output_path in tasks
        }
        for future in as_completed(future_to_task):
            bill_id, label, output_path = future_to_task[future]
            try:
                success, elapsed, response_id = future.result()
            except Exception:
                # A worker can still raise (e.g. its record could not be
                # written). Count it as a failure so the remaining results
                # and summary.json are not lost.
                logger.exception("Worker crashed: %s/%s", label, bill_id)
                success, elapsed, response_id = False, None, None
            if success:
                completed += 1
            else:
                failed += 1
            index.append(
                {
                    "bill_id": bill_id,
                    "label": label,
                    "response_id": response_id,
                    "elapsed_seconds": elapsed,
                    "success": success,
                    "path": str(output_path),
                },
            )
    wall_elapsed = time.monotonic() - wall_start

    summary = {
        "model": model,
        "count": len(bills),
        "completed": completed,
        "failed": failed,
        "wall_seconds": wall_elapsed,
        "concurrency": concurrency,
        "results": index,
    }
    summary_path = output_dir / "summary.json"
    summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    logger.info(
        "Done: completed=%d failed=%d wall=%.1fs summary=%s",
        completed,
        failed,
        wall_elapsed,
        summary_path,
    )
|
||||
|
||||
|
||||
def cli() -> None:
    """Entry point: hand `main` to Typer, which parses the command line and calls it."""
    typer.run(main)


# Allow the module to be executed directly as a script.
if __name__ == "__main__":
    cli()
|
||||
Reference in New Issue
Block a user