"""Summarize bill_text rows with GPT-5 and store results in the database.""" from __future__ import annotations import logging import tomllib from os import getenv from typing import Annotated, Any import httpx import typer from sqlalchemy import Select, exists, or_, select from sqlalchemy.orm import Session, selectinload from tiktoken import get_encoding from pipelines.config import get_config_dir from pipelines.orm.common import get_postgres_engine from pipelines.orm.data_science_dev.congress import ( Bill, BillText, SubjectType, VoteClassification, VoteRelationship, VoteTextTarget, ) from pipelines.tools.bill_token_compression import compress_bill_text logger = logging.getLogger(__name__) OPENAI_CHAT_COMPLETIONS_URL = "https://api.openai.com/v1/chat/completions" OPENAI_PROJECT_ID = "proj_fQBPEXFgnS87Fk6wZwploFwE" REQUEST_TIMEOUT_SECONDS = 60 def load_summarization_prompts( section: str = "summarization", ) -> dict[str, str]: summarization_prompts = get_config_dir() / "prompts" / "summarization_prompts.toml" return tomllib.loads(summarization_prompts.read_text())[section] class BillSummaryError(RuntimeError): """Raised when a bill summary request or response is invalid.""" def call_openai_summary( *, model: str, messages: list[dict[str, str]], ) -> str: """Call GPT and return the assistant message content.""" api_key = getenv("CLOSEDAI_TOKEN") if not api_key: msg = "CLOSEDAI_TOKEN is required" raise BillSummaryError(msg) response = httpx.post( OPENAI_CHAT_COMPLETIONS_URL, headers={ "Authorization": f"Bearer {api_key}", "OpenAI-Project": OPENAI_PROJECT_ID, "Content-Type": "application/json", }, json={ "model": model, "messages": messages, }, timeout=REQUEST_TIMEOUT_SECONDS, ) logger.info(f"{response.text=}") response.raise_for_status() return extract_message_content(response.json()) def build_bill_summary_messages( *, bill_text: BillText, summarization_prompts: dict[str, str], ) -> list[dict[str, str]]: """Build the GPT prompt messages plus compressed text and user prompt.""" if not bill_text.text_content: msg = f"bill_text id={bill_text.id} has no text_content" raise BillSummaryError(msg) compressed_text = compress_bill_text(bill_text.text_content) if not compressed_text: msg = f"bill_text id={bill_text.id} has no summarizable text_content" raise BillSummaryError(msg) user_prompt = summarization_prompts["user_template"].format( text_content=compressed_text ) user_prompt_tokens = len(get_encoding("o200k_base").encode(user_prompt)) logger.info(f"{user_prompt_tokens=}") messages = [ {"role": "system", "content": summarization_prompts["system_prompt"]}, { "role": "user", "content": user_prompt, }, ] return messages, user_prompt_tokens def summarize_bill_text( *, model: str, bill_text: BillText, summarization_prompts: dict[str, str], ) -> str: """Generate and return a summary for one bill_text row.""" messages, user_prompt_tokens = build_bill_summary_messages( bill_text=bill_text, summarization_prompts=summarization_prompts, ) # This may only be for gpt-5.4 mini I need to read the docs if user_prompt_tokens > 272000: msg = f"Compressed bill_text id={bill_text.id} is too long for summarization ({user_prompt_tokens} tokens)" logger.warning(msg) return None summary = call_openai_summary( model=model, messages=messages, ).strip() if not summary: msg = f"Model returned an empty summary for bill_text id={bill_text.id}" raise BillSummaryError(msg) return summary def store_bill_summary_result( *, bill_text: BillText, summary: str, model: str, ) -> None: """Store a generated summary and the prompt/model metadata that produced it.""" bill_text.summary = summary bill_text.summarization_model = model bill_text.summarization_system_prompt_version = "v1.2" bill_text.summarization_user_prompt_version = "v1" def create_select_bill_texts_for_summarization( congress: int | None = None, bill_ids: list[int] | None = None, bill_text_ids: list[int] | None = None, with_votes_only: bool = False, force: bool = False, limit: int | None = None, ) -> Select: """Select bill_text rows that have source text and need summaries.""" stmt = ( select(BillText) .join(Bill, Bill.id == BillText.bill_id) .where(BillText.text_content.is_not(None), BillText.text_content != "") .options(selectinload(BillText.bill)) .order_by(BillText.id) ) if congress is not None: stmt = stmt.where(Bill.congress == congress) if bill_ids: stmt = stmt.where(BillText.bill_id.in_(bill_ids)) if bill_text_ids: stmt = stmt.where(BillText.id.in_(bill_text_ids)) if with_votes_only: stmt = stmt.where( exists( select(VoteTextTarget.vote_id) .join( VoteClassification, VoteClassification.vote_id == VoteTextTarget.vote_id, ) .where( VoteTextTarget.voted_text_version_id == BillText.id, VoteClassification.subject_type == SubjectType.MEASURE, VoteClassification.vote_relationship == VoteRelationship.DIRECT_TEXT_VOTE, VoteClassification.is_direct_vote_on_legislative_text.is_(True), VoteClassification.is_substantive_policy_vote.is_(True), VoteClassification.is_special_rule.is_(False), ) ) ) if not force: stmt = stmt.where(or_(BillText.summary.is_(None), BillText.summary == "")) if limit is not None: stmt = stmt.limit(limit) return stmt def extract_message_content(data: dict[str, Any]) -> str: """Extract message content from a chat-completions response body.""" choices = data.get("choices") if not isinstance(choices, list) or not choices: msg = "Chat completion response did not contain choices" raise BillSummaryError(msg) first = choices[0] if not isinstance(first, dict): msg = "Chat completion choice must be an object" raise BillSummaryError(msg) message = first.get("message") if isinstance(message, dict) and isinstance(message.get("content"), str): return message["content"] if isinstance(first.get("text"), str): return first["text"] msg = "Chat completion response did not contain message content" raise BillSummaryError(msg) def main( model: Annotated[str, typer.Option(help="OpenAI model id.")] = "gpt-5.4-mini", congress: Annotated[ int | None, typer.Option(help="Only process one Congress.") ] = None, bill_ids: Annotated[ list[int] | None, typer.Option( "--bill-id", help="Only process one internal bill.id. Repeat for multiple bills.", ), ] = None, bill_text_ids: Annotated[ list[int] | None, typer.Option( "--bill-text-id", help="Only process one internal bill_text.id. Repeat for multiple rows.", ), ] = None, with_votes_only: Annotated[ bool, typer.Option( "--with-votes-only", help="Only process bill_text rows linked to at least one vote.", ), ] = False, limit: Annotated[int | None, typer.Option(help="Maximum rows to process.")] = None, force: Annotated[ bool, typer.Option(help="Regenerate summaries for rows that already have a summary."), ] = False, dry_run: Annotated[ bool, typer.Option(help="Print summaries without writing them to the database.") ] = False, log_level: Annotated[str, typer.Option(help="Log level.")] = "INFO", ) -> None: """CLI entrypoint for generating and storing bill summaries.""" logging.basicConfig( level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s", ) if not getenv("CLOSEDAI_TOKEN"): message = "CLOSEDAI_TOKEN is required" raise typer.BadParameter(message) summarization_prompts = load_summarization_prompts() engine = get_postgres_engine(name="DATA_SCIENCE_DEV") with Session(engine) as session: stmt = create_select_bill_texts_for_summarization( congress=congress, bill_ids=bill_ids, bill_text_ids=bill_text_ids, with_votes_only=with_votes_only, force=force, limit=limit, ) bill_texts = session.scalars(stmt).all() logger.info("Selected %d bill_text rows for summarization", len(bill_texts)) written = 0 for index, bill_text in enumerate(bill_texts, 1): summary = summarize_bill_text( model=model, bill_text=bill_text, summarization_prompts=summarization_prompts, ) if summary is None: logger.warning("Skipping bill_text id=%s", bill_text.id) continue store_bill_summary_result( bill_text=bill_text, summary=summary, model=model, ) if index % 100 == 0: session.commit() written += 1 session.commit() logger.info("Stored summary for bill_text id=%s", bill_text.id) logger.info("Done: stored %d summaries", written) def cli() -> None: """Typer entry point.""" typer.run(main) if __name__ == "__main__": cli()