"""Summarize bill_text rows with GPT-5 and store results in the database.""" from __future__ import annotations import logging import tomllib from os import getenv from typing import Annotated, Any import httpx import typer from sqlalchemy import Select, exists, select from sqlalchemy.orm import Session, selectinload from tiktoken import get_encoding from pipelines.config import get_config_dir from pipelines.orm.common import get_postgres_engine from pipelines.orm.data_science_dev.congress import ( Bill, BillText, BillTextSummary, SubjectType, VoteClassification, VoteRelationship, VoteTextTarget, ) from pipelines.tools.bill_token_compression import compress_bill_text logger = logging.getLogger(__name__) OPENAI_CHAT_COMPLETIONS_URL = "https://api.openai.com/v1/chat/completions" OPENAI_PROJECT_ID = "proj_fQBPEXFgnS87Fk6wZwploFwE" REQUEST_TIMEOUT_SECONDS = 60 def load_summarization_prompts( section: str = "summarization", ) -> dict[str, str]: summarization_prompts = get_config_dir() / "prompts" / "summarization_prompts.toml" return tomllib.loads(summarization_prompts.read_text())[section] class BillSummaryError(RuntimeError): """Raised when a bill summary request or response is invalid.""" def call_openai_summary( *, model: str, messages: list[dict[str, str]], ) -> str: """Call GPT and return the assistant message content.""" api_key = getenv("CLOSEDAI_TOKEN") if not api_key: msg = "CLOSEDAI_TOKEN is required" raise BillSummaryError(msg) response = httpx.post( OPENAI_CHAT_COMPLETIONS_URL, headers={ "Authorization": f"Bearer {api_key}", "OpenAI-Project": OPENAI_PROJECT_ID, "Content-Type": "application/json", }, json={ "model": model, "messages": messages, }, timeout=REQUEST_TIMEOUT_SECONDS, ) logger.info(f"{response.text=}") response.raise_for_status() return extract_message_content(response.json()) def build_bill_summary_messages( *, bill_text: BillText, summarization_prompts: dict[str, str], ) -> list[dict[str, str]]: """Build the GPT prompt messages plus compressed text and user prompt.""" if not bill_text.text_content: msg = f"bill_text id={bill_text.id} has no text_content" raise BillSummaryError(msg) compressed_text = compress_bill_text(bill_text.text_content) if not compressed_text: msg = f"bill_text id={bill_text.id} has no summarizable text_content" raise BillSummaryError(msg) user_prompt = summarization_prompts["user_template"].format( text_content=compressed_text ) user_prompt_tokens = len(get_encoding("o200k_base").encode(user_prompt)) logger.info(f"{user_prompt_tokens=}") messages = [ {"role": "system", "content": summarization_prompts["system_prompt"]}, { "role": "user", "content": user_prompt, }, ] return messages, user_prompt_tokens def summarize_bill_text( *, model: str, bill_text: BillText, summarization_prompts: dict[str, str], ) -> str | None: """Generate and return a summary for one bill_text row.""" messages, user_prompt_tokens = build_bill_summary_messages( bill_text=bill_text, summarization_prompts=summarization_prompts, ) # This may only be for gpt-5.4 mini I need to read the docs if user_prompt_tokens > 272000: msg = f"Compressed bill_text id={bill_text.id} is too long for summarization ({user_prompt_tokens} tokens)" logger.warning(msg) return None summary = call_openai_summary( model=model, messages=messages, ).strip() if not summary: msg = f"Model returned an empty summary for bill_text id={bill_text.id}" raise BillSummaryError(msg) return summary def store_bill_summary_result( *, session: Session, bill_text: BillText, summary: str, model: str, ) -> BillTextSummary: """Store a generated summary and the prompt/model metadata that produced it.""" summary_row = BillTextSummary( bill_text=bill_text, summary=summary, summarization_model=model, summarization_system_prompt_version="v1.2", summarization_user_prompt_version="v1", ) session.add(summary_row) return summary_row def create_select_bill_texts_for_summarization( congress: int | None = None, bill_ids: list[int] | None = None, bill_text_ids: list[int] | None = None, with_votes_only: bool = False, force: bool = False, limit: int | None = None, ) -> Select[tuple[BillText]]: """Select bill_text rows that have source text and need summaries.""" stmt = ( select(BillText) .join(Bill, Bill.id == BillText.bill_id) .where(BillText.text_content.is_not(None), BillText.text_content != "") .options(selectinload(BillText.bill)) .order_by(BillText.id) ) if congress is not None: stmt = stmt.where(Bill.congress == congress) if bill_ids: stmt = stmt.where(BillText.bill_id.in_(bill_ids)) if bill_text_ids: stmt = stmt.where(BillText.id.in_(bill_text_ids)) if with_votes_only: stmt = stmt.where( exists( select(VoteTextTarget.vote_id) .join( VoteClassification, VoteClassification.vote_id == VoteTextTarget.vote_id, ) .where( VoteTextTarget.voted_text_version_id == BillText.id, VoteClassification.subject_type == SubjectType.MEASURE, VoteClassification.vote_relationship == VoteRelationship.DIRECT_TEXT_VOTE, VoteClassification.is_direct_vote_on_legislative_text.is_(True), VoteClassification.is_substantive_policy_vote.is_(True), VoteClassification.is_special_rule.is_(False), ) ) ) if not force: stmt = stmt.where( ~exists( select(BillTextSummary.id).where( BillTextSummary.bill_text_id == BillText.id ) ) ) if limit is not None: stmt = stmt.limit(limit) return stmt def extract_message_content(data: dict[str, Any]) -> str: """Extract message content from a chat-completions response body.""" choices = data.get("choices") if not isinstance(choices, list) or not choices: msg = "Chat completion response did not contain choices" raise BillSummaryError(msg) first = choices[0] if not isinstance(first, dict): msg = "Chat completion choice must be an object" raise BillSummaryError(msg) message = first.get("message") if isinstance(message, dict) and isinstance(message.get("content"), str): return message["content"] if isinstance(first.get("text"), str): return first["text"] msg = "Chat completion response did not contain message content" raise BillSummaryError(msg) def main( model: Annotated[str, typer.Option(help="OpenAI model id.")] = "gpt-5.4-mini", congress: Annotated[ int | None, typer.Option(help="Only process one Congress.") ] = None, bill_ids: Annotated[ list[int] | None, typer.Option( "--bill-id", help="Only process one internal bill.id. Repeat for multiple bills.", ), ] = None, bill_text_ids: Annotated[ list[int] | None, typer.Option( "--bill-text-id", help="Only process one internal bill_text.id. Repeat for multiple rows.", ), ] = None, with_votes_only: Annotated[ bool, typer.Option( "--with-votes-only", help="Only process bill_text rows linked to at least one vote.", ), ] = False, limit: Annotated[int | None, typer.Option(help="Maximum rows to process.")] = None, force: Annotated[ bool, typer.Option(help="Regenerate summaries for rows that already have a summary."), ] = False, dry_run: Annotated[ bool, typer.Option(help="Print summaries without writing them to the database.") ] = False, log_level: Annotated[str, typer.Option(help="Log level.")] = "INFO", ) -> None: """CLI entrypoint for generating and storing bill summaries.""" logging.basicConfig( level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s", ) if not getenv("CLOSEDAI_TOKEN"): message = "CLOSEDAI_TOKEN is required" raise typer.BadParameter(message) summarization_prompts = load_summarization_prompts() engine = get_postgres_engine(name="DATA_SCIENCE_DEV") with Session(engine) as session: stmt = create_select_bill_texts_for_summarization( congress=congress, bill_ids=bill_ids, bill_text_ids=bill_text_ids, with_votes_only=with_votes_only, force=force, limit=limit, ) bill_texts = session.scalars(stmt).all() logger.info("Selected %d bill_text rows for summarization", len(bill_texts)) written = 0 for index, bill_text in enumerate(bill_texts, 1): summary = summarize_bill_text( model=model, bill_text=bill_text, summarization_prompts=summarization_prompts, ) if summary is None: logger.warning("Skipping bill_text id=%s", bill_text.id) continue store_bill_summary_result( session=session, bill_text=bill_text, summary=summary, model=model, ) if index % 100 == 0: session.commit() written += 1 session.commit() logger.info("Stored summary for bill_text id=%s", bill_text.id) logger.info("Done: stored %d summaries", written) def cli() -> None: """Typer entry point.""" typer.run(main) if __name__ == "__main__": cli()