diff --git a/pipelines/tools/summarize_bills.py b/pipelines/tools/summarize_bills.py new file mode 100644 index 0000000..3931789 --- /dev/null +++ b/pipelines/tools/summarize_bills.py @@ -0,0 +1,268 @@ +"""Summarize bill_text rows with GPT-5 and store results in the database.""" + +from __future__ import annotations + +import logging +import tomllib +from os import getenv +from pathlib import Path +from typing import Annotated, Any + +import httpx +import typer +from sqlalchemy import Select, or_, select +from sqlalchemy.orm import Session, selectinload + +from pipelines.config import get_config_dir +from pipelines.orm.common import get_postgres_engine +from pipelines.orm.data_science_dev.congress import Bill, BillText +from pipelines.tools.bill_token_compression import compress_bill_text + +logger = logging.getLogger(__name__) + +OPENAI_PROJECT_ID = "proj_fQBPEXFgnS87Fk6wZwploFwE" + + + +def _find_prompts_path() -> Path: + return get_config_dir() / "prompts" / "summarization_prompts.toml" + + +def load_summarization_prompts( + section: str = "summarization", +) -> dict[str, str]: + return tomllib.loads(_find_prompts_path().read_text())[section] + + +class BillSummaryError(RuntimeError): + """Raised when a bill summary request or response is invalid.""" + + +def call_openai_summary( + *, + model: str, + messages: list[dict[str, str]], +) -> str: + """Call GPT and return the assistant message content.""" + api_key = getenv("CLOSEDAI_TOKEN") + if not api_key: + msg = "CLOSEDAI_TOKEN is required" + raise BillSummaryError(msg) + + response = httpx.post( + "https://api.openai.com/v1/chat/completions", + headers={ + "Authorization": f"Bearer {api_key}", + "OpenAI-Project": OPENAI_PROJECT_ID, + "Content-Type": "application/json", + }, + json={ + "model": model, + "messages": messages, + }, + timeout=60, + ) + response.raise_for_status() + return extract_message_content(response.json()) + + +def build_bill_summary_messages( + *, + bill_text: BillText, +) -> list[dict[str, str]]: + """Build the GPT prompt messages for one bill text row.""" + if not bill_text.text_content: + msg = f"bill_text id={bill_text.id} has no text_content" + raise BillSummaryError(msg) + + bill = bill_text.bill + compressed_text = compress_bill_text(bill_text.text_content) + if not compressed_text: + msg = f"bill_text id={bill_text.id} has no summarizable text_content" + raise BillSummaryError(msg) + + metadata = "\n".join( + ( + f"Congress: {bill.congress}", + f"Bill: {bill.bill_type} {bill.number}", + f"Title: {bill.title_short or bill.title or bill.official_title or ''}", + f"Top subject term: {bill.subjects_top_term or ''}", + f"Text version: {bill_text.version_code}" + + (f" ({bill_text.version_name})" if bill_text.version_name else ""), + ) + ) + summarization_prompts = load_summarization_prompts() + user_prompt = "\n\n".join( + ( + "BILL METADATA:", + metadata, + summarization_prompts["user_template"].format(text_content=compressed_text), + ) + ) + + return [ + {"role": "system", "content": summarization_prompts["system_prompt"]}, + { + "role": "user", + "content": user_prompt, + }, + ] + + +def summarize_bill_text( + *, + model: str, + bill_text: BillText, +) -> str: + """Generate and return a summary for one bill_text row.""" + messages = build_bill_summary_messages(bill_text=bill_text) + summary = call_openai_summary( + model=model, + messages=messages, + ).strip() + if not summary: + msg = f"Model returned an empty summary for bill_text id={bill_text.id}" + raise BillSummaryError(msg) + return summary + + +def store_bill_summary_result( + *, + bill_text: BillText, + summary: str, + model: str, +) -> None: + """Store a generated summary and the prompt/model metadata that produced it.""" + bill_text.summary = summary + bill_text.summarization_model = model + bill_text.summarization_system_prompt_version = "v1" + bill_text.summarization_user_prompt_version = "v1" + + +def create_select_bill_texts_for_summarization( + congress: int | None = None, + bill_ids: list[int] | None = None, + bill_text_ids: list[int] | None = None, + force: bool = False, + limit: int | None = None, +) -> Select: + """Select bill_text rows that have source text and need summaries.""" + stmt = ( + select(BillText) + .join(Bill, Bill.id == BillText.bill_id) + .where(BillText.text_content.is_not(None), BillText.text_content != "") + .options(selectinload(BillText.bill)) + .order_by(BillText.id) + ) + if congress is not None: + stmt = stmt.where(Bill.congress == congress) + if bill_ids: + stmt = stmt.where(BillText.bill_id.in_(bill_ids)) + if bill_text_ids: + stmt = stmt.where(BillText.id.in_(bill_text_ids)) + if not force: + stmt = stmt.where(or_(BillText.summary.is_(None), BillText.summary == "")) + if limit is not None: + stmt = stmt.limit(limit) + return stmt + + +def extract_message_content(data: dict[str, Any]) -> str: + """Extract message content from a chat-completions response body.""" + choices = data.get("choices") + if not isinstance(choices, list) or not choices: + msg = "Chat completion response did not contain choices" + raise BillSummaryError(msg) + + first = choices[0] + if not isinstance(first, dict): + msg = "Chat completion choice must be an object" + raise BillSummaryError(msg) + + message = first.get("message") + if isinstance(message, dict) and isinstance(message.get("content"), str): + return message["content"] + if isinstance(first.get("text"), str): + return first["text"] + + msg = "Chat completion response did not contain message content" + raise BillSummaryError(msg) + + +def main( + model: Annotated[str, typer.Option(help="OpenAI model id.")] = "gpt-5.4-mini", + congress: Annotated[ + int | None, typer.Option(help="Only process one Congress.") + ] = None, + bill_ids: Annotated[ + list[int] | None, + typer.Option( + "--bill-id", + help="Only process one internal bill.id. Repeat for multiple bills.", + ), + ] = None, + bill_text_ids: Annotated[ + list[int] | None, + typer.Option( + "--bill-text-id", + help="Only process one internal bill_text.id. Repeat for multiple rows.", + ), + ] = None, + limit: Annotated[int | None, typer.Option(help="Maximum rows to process.")] = None, + force: Annotated[ + bool, + typer.Option(help="Regenerate summaries for rows that already have a summary."), + ] = False, + dry_run: Annotated[ + bool, typer.Option(help="Print summaries without writing them to the database.") + ] = False, + log_level: Annotated[str, typer.Option(help="Log level.")] = "INFO", +) -> None: + """CLI entrypoint for generating and storing bill summaries.""" + logging.basicConfig( + level=log_level, + format="%(asctime)s %(levelname)s %(name)s: %(message)s", + ) + if not getenv("CLOSEDAI_TOKEN"): + message = "CLOSEDAI_TOKEN is required" + raise typer.BadParameter(message) + + engine = get_postgres_engine(name="DATA_SCIENCE_DEV") + with Session(engine) as session: + stmt = create_select_bill_texts_for_summarization( + congress=congress, + bill_ids=bill_ids, + bill_text_ids=bill_text_ids, + force=force, + limit=limit, + ) + bill_texts = session.scalars(stmt).all() + logger.info("Selected %d bill_text rows for summarization", len(bill_texts)) + + written = 0 + for index, bill_text in enumerate(bill_texts, 1): + summary = summarize_bill_text( + model=model, + bill_text=bill_text, + ) + store_bill_summary_result( + bill_text=bill_text, + summary=summary, + model=model, + ) + if index % 100 == 0: + session.commit() + written += 1 + session.commit() + logger.info("Stored summary for bill_text id=%s", bill_text.id) + + logger.info("Done: stored %d summaries", written) + + +def cli() -> None: + """Typer entry point.""" + typer.run(main) + + +if __name__ == "__main__": + cli()