Files
weave/pipelines/jobs/summarize_bills.py
T

310 lines
9.7 KiB
Python

"""Summarize bill_text rows with GPT-5 and store results in the database."""
from __future__ import annotations
import logging
import tomllib
from os import getenv
from typing import Annotated, Any
import httpx
import typer
from sqlalchemy import Select, exists, or_, select
from sqlalchemy.orm import Session, selectinload
from tiktoken import get_encoding
from pipelines.config import get_config_dir
from pipelines.orm.common import get_postgres_engine
from pipelines.orm.data_science_dev.congress import (
Bill,
BillText,
SubjectType,
VoteClassification,
VoteRelationship,
VoteTextTarget,
)
from pipelines.tools.bill_token_compression import compress_bill_text
# Module-level logger; ``main`` configures handlers and level via basicConfig.
logger = logging.getLogger(__name__)

# Chat Completions endpoint that ``call_openai_summary`` posts to.
OPENAI_CHAT_COMPLETIONS_URL: str = "https://api.openai.com/v1/chat/completions"
# Sent as the ``OpenAI-Project`` request header.
OPENAI_PROJECT_ID: str = "proj_fQBPEXFgnS87Fk6wZwploFwE"
# Timeout (seconds) handed to httpx for each summary request.
REQUEST_TIMEOUT_SECONDS: int = 60
def load_summarization_prompts(
    section: str = "summarization",
) -> dict[str, str]:
    """Load one table from the summarization prompt file.

    Reads ``<config dir>/prompts/summarization_prompts.toml`` and returns the
    top-level table named *section*. Callers in this module read its
    ``system_prompt`` and ``user_template`` keys.

    Args:
        section: Name of the TOML table to return.

    Returns:
        The parsed table for *section*.

    Raises:
        OSError: If the prompt file cannot be opened (e.g. it does not exist).
        tomllib.TOMLDecodeError: If the file is not valid TOML.
        KeyError: If *section* is not present in the file.
    """
    prompts_path = get_config_dir() / "prompts" / "summarization_prompts.toml"
    # TOML is specified as UTF-8. ``tomllib.load`` takes a binary handle and
    # decodes it itself, so the platform's locale encoding (which
    # ``read_text()`` would use) can never mis-decode the prompts.
    with prompts_path.open("rb") as prompts_file:
        return tomllib.load(prompts_file)[section]
class BillSummaryError(RuntimeError):
    """Signal that a bill summary could not be requested or produced.

    Raised for a missing ``CLOSEDAI_TOKEN``, bill text with nothing to
    summarize, a chat-completion payload without usable content, and an
    empty summary returned by the model.
    """
def call_openai_summary(
    *,
    model: str,
    messages: list[dict[str, str]],
) -> str:
    """Send *messages* to the OpenAI Chat Completions API and return the reply.

    Args:
        model: OpenAI model id placed in the request body.
        messages: Chat messages (``role``/``content`` dicts), sent unchanged.

    Returns:
        The assistant content extracted by ``extract_message_content``.

    Raises:
        BillSummaryError: If ``CLOSEDAI_TOKEN`` is unset or empty, or the
            response body has no usable message content.
        httpx.HTTPStatusError: If the API responds with a non-success status.
        httpx.HTTPError: For transport failures such as timeouts.
    """
    api_key = getenv("CLOSEDAI_TOKEN")
    if not api_key:
        msg = "CLOSEDAI_TOKEN is required"
        raise BillSummaryError(msg)
    response = httpx.post(
        OPENAI_CHAT_COMPLETIONS_URL,
        headers={
            "Authorization": f"Bearer {api_key}",
            "OpenAI-Project": OPENAI_PROJECT_ID,
            "Content-Type": "application/json",
        },
        json={
            "model": model,
            "messages": messages,
        },
        timeout=REQUEST_TIMEOUT_SECONDS,
    )
    if response.is_error:
        # raise_for_status() reports only the status line, so surface the
        # API's explanation at a level that is visible by default.
        logger.warning(
            "OpenAI request failed status=%s body=%s",
            response.status_code,
            response.text,
        )
    else:
        # Successful bodies carry the full generated summary; dumping them at
        # INFO for every row is noise, so keep them for debugging only.
        logger.debug("OpenAI response body=%s", response.text)
    response.raise_for_status()
    return extract_message_content(response.json())
def build_bill_summary_messages(
    *,
    bill_text: BillText,
    summarization_prompts: dict[str, str],
) -> tuple[list[dict[str, str]], int]:
    """Build the chat messages used to summarize one bill_text row.

    The row's ``text_content`` is compressed with ``compress_bill_text`` and
    substituted into ``summarization_prompts["user_template"]`` as
    ``{text_content}``. ``summarization_prompts["system_prompt"]`` is sent
    unchanged as the system message.

    Args:
        bill_text: Row whose ``text_content`` is summarized.
        summarization_prompts: Prompt table with ``system_prompt`` and
            ``user_template`` keys.

    Returns:
        A ``(messages, user_prompt_tokens)`` pair. The first item is the
        system + user messages. The second is the ``o200k_base`` token count
        of the user prompt, which callers use to reject over-long prompts.

    Raises:
        BillSummaryError: If the row has no text, or compression leaves
            nothing to summarize.
        KeyError: If a required prompt key is missing, or the template uses a
            placeholder other than ``text_content``.
    """
    if not bill_text.text_content:
        msg = f"bill_text id={bill_text.id} has no text_content"
        raise BillSummaryError(msg)
    compressed_text = compress_bill_text(bill_text.text_content)
    if not compressed_text:
        msg = f"bill_text id={bill_text.id} has no summarizable text_content"
        raise BillSummaryError(msg)
    user_prompt = summarization_prompts["user_template"].format(
        text_content=compressed_text
    )
    # Count with o200k_base so the caller can enforce an input-size limit
    # before spending an API request.
    user_prompt_tokens = len(get_encoding("o200k_base").encode(user_prompt))
    logger.info(
        "bill_text id=%s user_prompt_tokens=%d", bill_text.id, user_prompt_tokens
    )
    messages = [
        {"role": "system", "content": summarization_prompts["system_prompt"]},
        {"role": "user", "content": user_prompt},
    ]
    return messages, user_prompt_tokens
def summarize_bill_text(
    *,
    model: str,
    bill_text: BillText,
    summarization_prompts: dict[str, str],
    max_prompt_tokens: int = 272_000,
) -> str | None:
    """Generate a summary for one bill_text row.

    Args:
        model: OpenAI model id to request.
        bill_text: Row whose ``text_content`` is summarized.
        summarization_prompts: Prompt table with ``system_prompt`` and
            ``user_template`` keys.
        max_prompt_tokens: Largest allowed ``o200k_base`` token count for the
            user prompt. Longer prompts are skipped rather than sent.

    Returns:
        The stripped summary text. Returns ``None`` when the user prompt
        exceeds *max_prompt_tokens*; in that case a warning is logged and no
        API call is made.

    Raises:
        BillSummaryError: If the row has nothing to summarize, or the model
            returns an empty summary. Errors from ``call_openai_summary``
            (e.g. HTTP failures) propagate unchanged.
    """
    messages, user_prompt_tokens = build_bill_summary_messages(
        bill_text=bill_text,
        summarization_prompts=summarization_prompts,
    )
    # NOTE(review): the 272k default was carried over with the author's note
    # that it "may only be for gpt-5.4 mini"; it also counts only the user
    # prompt, not the system prompt. Confirm against the model docs.
    if user_prompt_tokens > max_prompt_tokens:
        logger.warning(
            "Compressed bill_text id=%s is too long for summarization (%d tokens)",
            bill_text.id,
            user_prompt_tokens,
        )
        return None
    summary = call_openai_summary(
        model=model,
        messages=messages,
    ).strip()
    if not summary:
        msg = f"Model returned an empty summary for bill_text id={bill_text.id}"
        raise BillSummaryError(msg)
    return summary
def store_bill_summary_result(
    *,
    bill_text: BillText,
    summary: str,
    model: str,
) -> None:
    """Record *summary* on *bill_text*, with the model and prompt versions used.

    Only attributes on the object are set. This function does not flush or
    commit; persisting the change is up to the caller.
    """
    # NOTE(review): versions are hard-coded; presumably they track revisions of
    # summarization_prompts.toml. Confirm and bump together with the prompts.
    system_prompt_version = "v1.2"
    user_prompt_version = "v1"

    bill_text.summarization_model = model
    bill_text.summary = summary
    bill_text.summarization_system_prompt_version = system_prompt_version
    bill_text.summarization_user_prompt_version = user_prompt_version
def create_select_bill_texts_for_summarization(
    congress: int | None = None,
    bill_ids: list[int] | None = None,
    bill_text_ids: list[int] | None = None,
    with_votes_only: bool = False,
    force: bool = False,
    limit: int | None = None,
) -> Select:
    """Build the SELECT of bill_text rows that should be summarized.

    Rows must always have non-NULL, non-empty ``text_content``. The optional
    filters are applied as follows:

    - ``congress`` matches ``Bill.congress``.
    - ``bill_ids`` and ``bill_text_ids`` match ``BillText.bill_id`` and
      ``BillText.id``. Empty lists apply no filter.
    - ``with_votes_only`` requires a ``VoteTextTarget`` pointing at the row
      whose vote's ``VoteClassification`` is a ``MEASURE`` /
      ``DIRECT_TEXT_VOTE``, is a direct vote on legislative text, is a
      substantive policy vote, and is not a special rule.
    - Unless ``force`` is set, rows whose summary is NULL or empty are the
      only ones kept.

    Results are ordered by ``BillText.id``, eager-load ``BillText.bill`` and,
    when ``limit`` is given, are capped at that many rows.
    """
    criteria = [
        BillText.text_content.is_not(None),
        BillText.text_content != "",
    ]
    if congress is not None:
        criteria.append(Bill.congress == congress)
    if bill_ids:
        criteria.append(BillText.bill_id.in_(bill_ids))
    if bill_text_ids:
        criteria.append(BillText.id.in_(bill_text_ids))
    if with_votes_only:
        # Correlated on BillText.id from the outer query.
        qualifying_votes = (
            select(VoteTextTarget.vote_id)
            .join(
                VoteClassification,
                VoteClassification.vote_id == VoteTextTarget.vote_id,
            )
            .where(
                VoteTextTarget.voted_text_version_id == BillText.id,
                VoteClassification.subject_type == SubjectType.MEASURE,
                VoteClassification.vote_relationship
                == VoteRelationship.DIRECT_TEXT_VOTE,
                VoteClassification.is_direct_vote_on_legislative_text.is_(True),
                VoteClassification.is_substantive_policy_vote.is_(True),
                VoteClassification.is_special_rule.is_(False),
            )
        )
        criteria.append(exists(qualifying_votes))
    if not force:
        criteria.append(or_(BillText.summary.is_(None), BillText.summary == ""))

    query = (
        select(BillText)
        .join(Bill, Bill.id == BillText.bill_id)
        .where(*criteria)
        .options(selectinload(BillText.bill))
        .order_by(BillText.id)
    )
    if limit is None:
        return query
    return query.limit(limit)
def extract_message_content(data: dict[str, Any]) -> str:
    """Pull the reply text out of a decoded chat-completions response body.

    Only the first choice is inspected. Its ``message["content"]`` string is
    preferred. A ``text`` string directly on the choice is accepted as a
    fallback.

    Raises:
        BillSummaryError: If ``choices`` is missing, empty or not a list; if
            the first choice is not a dict; or if neither field holds a string.
    """
    choices = data.get("choices")
    if not (isinstance(choices, list) and choices):
        msg = "Chat completion response did not contain choices"
        raise BillSummaryError(msg)

    choice = choices[0]
    if not isinstance(choice, dict):
        msg = "Chat completion choice must be an object"
        raise BillSummaryError(msg)

    message = choice.get("message")
    content = message.get("content") if isinstance(message, dict) else None
    if isinstance(content, str):
        return content

    fallback_text = choice.get("text")
    if isinstance(fallback_text, str):
        return fallback_text

    msg = "Chat completion response did not contain message content"
    raise BillSummaryError(msg)
def main(
    model: Annotated[str, typer.Option(help="OpenAI model id.")] = "gpt-5.4-mini",
    congress: Annotated[
        int | None, typer.Option(help="Only process one Congress.")
    ] = None,
    bill_ids: Annotated[
        list[int] | None,
        typer.Option(
            "--bill-id",
            help="Only process one internal bill.id. Repeat for multiple bills.",
        ),
    ] = None,
    bill_text_ids: Annotated[
        list[int] | None,
        typer.Option(
            "--bill-text-id",
            help="Only process one internal bill_text.id. Repeat for multiple rows.",
        ),
    ] = None,
    with_votes_only: Annotated[
        bool,
        typer.Option(
            "--with-votes-only",
            help="Only process bill_text rows linked to at least one vote.",
        ),
    ] = False,
    limit: Annotated[int | None, typer.Option(help="Maximum rows to process.")] = None,
    force: Annotated[
        bool,
        typer.Option(help="Regenerate summaries for rows that already have a summary."),
    ] = False,
    dry_run: Annotated[
        bool, typer.Option(help="Print summaries without writing them to the database.")
    ] = False,
    log_level: Annotated[str, typer.Option(help="Log level.")] = "INFO",
) -> None:
    """CLI entrypoint for generating and storing bill summaries.

    Selects bill_text rows using the filter options and summarizes each one
    with ``summarize_bill_text``. With ``--dry-run``, summaries are printed and
    nothing is written. Otherwise each summary is stored and committed
    immediately, so completed API calls are not lost if a later row fails.

    Rows are skipped (with a warning) when their prompt is too long or when
    summarizing them raises ``BillSummaryError`` (e.g. no summarizable text,
    empty model reply). Other errors, such as HTTP failures, abort the run.

    Raises:
        typer.BadParameter: If ``CLOSEDAI_TOKEN`` is not set.
    """
    logging.basicConfig(
        # logging only recognises upper-case level names; accept "info" etc.
        level=log_level.upper(),
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )
    if not getenv("CLOSEDAI_TOKEN"):
        message = "CLOSEDAI_TOKEN is required"
        raise typer.BadParameter(message)

    summarization_prompts = load_summarization_prompts()
    engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
    written = 0
    skipped = 0
    with Session(engine) as session:
        stmt = create_select_bill_texts_for_summarization(
            congress=congress,
            bill_ids=bill_ids,
            bill_text_ids=bill_text_ids,
            with_votes_only=with_votes_only,
            force=force,
            limit=limit,
        )
        bill_texts = session.scalars(stmt).all()
        logger.info("Selected %d bill_text rows for summarization", len(bill_texts))

        for bill_text in bill_texts:
            # Capture the id up front: after a commit the instance is expired,
            # and reading an attribute would trigger a reload.
            bill_text_id = bill_text.id
            try:
                summary = summarize_bill_text(
                    model=model,
                    bill_text=bill_text,
                    summarization_prompts=summarization_prompts,
                )
            except BillSummaryError as exc:
                # A single unsummarizable row should not abort the whole batch.
                logger.warning("Skipping bill_text id=%s: %s", bill_text_id, exc)
                skipped += 1
                continue
            if summary is None:
                logger.warning("Skipping bill_text id=%s", bill_text_id)
                skipped += 1
                continue

            if dry_run:
                typer.echo(f"--- bill_text id={bill_text_id} ---\n{summary}\n")
                continue

            store_bill_summary_result(
                bill_text=bill_text,
                summary=summary,
                model=model,
            )
            # Commit per row: each summary costs an API call, so persist it
            # immediately instead of batching.
            session.commit()
            written += 1
            logger.info("Stored summary for bill_text id=%s", bill_text_id)

    logger.info("Done: stored %d summaries, skipped %d rows", written, skipped)
def cli() -> None:
    """Typer entry point: run ``main`` as a single-command CLI via ``typer.run``."""
    typer.run(main)


# Allow running this module directly as a script.
if __name__ == "__main__":
    cli()