Files
weave/pipelines/jobs/summarize_bills.py
T

324 lines
10 KiB
Python

"""Summarize bill_text rows with GPT-5 and store results in the database."""
from __future__ import annotations
import logging
import tomllib
from os import getenv
from typing import Annotated, Any
import httpx
import typer
from sqlalchemy import Select, exists, select
from sqlalchemy.orm import Session, selectinload
from tiktoken import get_encoding
from pipelines.config import get_config_dir
from pipelines.orm.common import get_postgres_engine
from pipelines.orm.data_science_dev.congress import (
Bill,
BillText,
BillTextSummary,
SubjectType,
VoteClassification,
VoteRelationship,
VoteTextTarget,
)
from pipelines.tools.bill_token_compression import compress_bill_text
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Endpoint that call_openai_summary posts chat-completions requests to.
OPENAI_CHAT_COMPLETIONS_URL = "https://api.openai.com/v1/chat/completions"
# Sent as the "OpenAI-Project" header on every request.
OPENAI_PROJECT_ID = "proj_fQBPEXFgnS87Fk6wZwploFwE"
# Timeout handed to httpx.post for each request.
REQUEST_TIMEOUT_SECONDS = 60
def load_summarization_prompts(
    section: str = "summarization",
) -> dict[str, str]:
    """Load one table of prompts from ``prompts/summarization_prompts.toml``.

    The file is looked up under the directory returned by ``get_config_dir()``.

    Args:
        section: Name of the top-level TOML table to return.

    Returns:
        The requested table as parsed by ``tomllib``.

    Raises:
        KeyError: If the file has no top-level table named ``section``.
    """
    prompts_path = get_config_dir() / "prompts" / "summarization_prompts.toml"
    # TOML documents are UTF-8 by specification, so decode explicitly instead
    # of relying on the platform's default text encoding.
    prompts_toml = prompts_path.read_text(encoding="utf-8")
    return tomllib.loads(prompts_toml)[section]
class BillSummaryError(RuntimeError):
    """Signal that a summary request could not be built or its response was unusable."""
def call_openai_summary(
    *,
    model: str,
    messages: list[dict[str, str]],
) -> str:
    """Send one chat-completions request and return the assistant's reply text.

    Args:
        model: Model id placed in the request body.
        messages: Chat messages placed in the request body.

    Returns:
        The content pulled out of the response by ``extract_message_content``.

    Raises:
        BillSummaryError: If ``CLOSEDAI_TOKEN`` is unset or empty, or the
            response body carries no usable message content.
        httpx.HTTPStatusError: If the API answers with an error status.
    """
    token = getenv("CLOSEDAI_TOKEN")
    if not token:
        raise BillSummaryError("CLOSEDAI_TOKEN is required")

    request_headers = {
        "Authorization": f"Bearer {token}",
        "OpenAI-Project": OPENAI_PROJECT_ID,
        "Content-Type": "application/json",
    }
    payload = {"model": model, "messages": messages}

    response = httpx.post(
        OPENAI_CHAT_COMPLETIONS_URL,
        headers=request_headers,
        json=payload,
        timeout=REQUEST_TIMEOUT_SECONDS,
    )
    # Logged before the status check so error bodies show up in the log too.
    logger.info(f"{response.text=}")
    response.raise_for_status()
    return extract_message_content(response.json())
def build_bill_summary_messages(
    *,
    bill_text: BillText,
    summarization_prompts: dict[str, str],
) -> tuple[list[dict[str, str]], int]:
    """Build the chat messages for one bill_text row and count the prompt tokens.

    Args:
        bill_text: Row whose ``text_content`` is compressed and embedded in the
            user prompt.
        summarization_prompts: Prompt config; must provide ``"system_prompt"``
            and ``"user_template"`` (a ``str.format`` template with a
            ``text_content`` field).

    Returns:
        A ``(messages, user_prompt_tokens)`` pair: the system + user message
        list, and the number of ``o200k_base`` tokens in the user prompt.

    Raises:
        BillSummaryError: If the row has no ``text_content`` or nothing is left
            after compression.
    """
    if not bill_text.text_content:
        msg = f"bill_text id={bill_text.id} has no text_content"
        raise BillSummaryError(msg)

    compressed_text = compress_bill_text(bill_text.text_content)
    if not compressed_text:
        msg = f"bill_text id={bill_text.id} has no summarizable text_content"
        raise BillSummaryError(msg)

    user_prompt = summarization_prompts["user_template"].format(
        text_content=compressed_text
    )
    # Only the user prompt is counted; the system prompt is not included.
    user_prompt_tokens = len(get_encoding("o200k_base").encode(user_prompt))
    logger.info("user_prompt_tokens=%s", user_prompt_tokens)

    messages = [
        {"role": "system", "content": summarization_prompts["system_prompt"]},
        {"role": "user", "content": user_prompt},
    ]
    return messages, user_prompt_tokens
def summarize_bill_text(
    *,
    model: str,
    bill_text: BillText,
    summarization_prompts: dict[str, str],
    max_prompt_tokens: int = 272_000,
) -> str | None:
    """Generate and return a summary for one bill_text row.

    Args:
        model: Model id passed through to ``call_openai_summary``.
        bill_text: Row to summarize.
        summarization_prompts: Prompt config passed through to
            ``build_bill_summary_messages``.
        max_prompt_tokens: Largest user-prompt token count that is still sent
            to the model. TODO: the 272,000 default may only apply to
            gpt-5.4 mini; confirm against the model documentation.

    Returns:
        The stripped summary text, or ``None`` when the user prompt exceeds
        ``max_prompt_tokens`` and the row is skipped.

    Raises:
        BillSummaryError: If the row has no summarizable text, or the model
            returns an empty or whitespace-only summary.
    """
    messages, user_prompt_tokens = build_bill_summary_messages(
        bill_text=bill_text,
        summarization_prompts=summarization_prompts,
    )

    if user_prompt_tokens > max_prompt_tokens:
        # Oversized prompts are skipped rather than raised so a batch run can
        # carry on with the remaining rows.
        logger.warning(
            "Compressed bill_text id=%s is too long for summarization (%s tokens)",
            bill_text.id,
            user_prompt_tokens,
        )
        return None

    summary = call_openai_summary(
        model=model,
        messages=messages,
    ).strip()
    if not summary:
        msg = f"Model returned an empty summary for bill_text id={bill_text.id}"
        raise BillSummaryError(msg)
    return summary
def store_bill_summary_result(
    *,
    session: Session,
    bill_text: BillText,
    summary: str,
    model: str,
) -> BillTextSummary:
    """Add a new BillTextSummary row to the session and return it.

    The row records the summary text, the model id, and the prompt versions
    fixed in this function ("v1.2" system, "v1" user). Nothing is flushed or
    committed here; that is left to the caller.
    """
    row_fields = {
        "bill_text": bill_text,
        "summary": summary,
        "summarization_model": model,
        "summarization_system_prompt_version": "v1.2",
        "summarization_user_prompt_version": "v1",
    }
    new_row = BillTextSummary(**row_fields)
    session.add(new_row)
    return new_row
def create_select_bill_texts_for_summarization(
    congress: int | None = None,
    bill_ids: list[int] | None = None,
    bill_text_ids: list[int] | None = None,
    with_votes_only: bool = False,
    force: bool = False,
    limit: int | None = None,
) -> Select[tuple[BillText]]:
    """Build the SELECT for bill_text rows that are candidates for summarization.

    Rows always need non-NULL, non-empty ``text_content`` and come back ordered
    by ``BillText.id`` with their ``bill`` relationship eagerly loaded.

    Args:
        congress: Keep only rows whose bill belongs to this Congress.
        bill_ids: Keep only rows for these ``bill.id`` values; an empty list
            applies no filter.
        bill_text_ids: Keep only these ``bill_text.id`` values; an empty list
            applies no filter.
        with_votes_only: Keep only rows targeted by at least one vote whose
            classification passes the direct-text-vote checks below.
        force: When False, drop rows that already have a BillTextSummary.
        limit: Maximum number of rows to return.

    Returns:
        The composed, unexecuted ``Select`` statement.
    """
    query = select(BillText).join(Bill, Bill.id == BillText.bill_id)
    query = query.where(
        BillText.text_content.is_not(None),
        BillText.text_content != "",
    )
    query = query.options(selectinload(BillText.bill)).order_by(BillText.id)

    if congress is not None:
        query = query.where(Bill.congress == congress)
    if bill_ids:
        query = query.where(BillText.bill_id.in_(bill_ids))
    if bill_text_ids:
        query = query.where(BillText.id.in_(bill_text_ids))

    if with_votes_only:
        # Correlated subquery: a vote on this exact text version, classified as
        # a direct, substantive vote on a measure and not a special rule.
        qualifying_vote = (
            select(VoteTextTarget.vote_id)
            .join(
                VoteClassification,
                VoteClassification.vote_id == VoteTextTarget.vote_id,
            )
            .where(
                VoteTextTarget.voted_text_version_id == BillText.id,
                VoteClassification.subject_type == SubjectType.MEASURE,
                VoteClassification.vote_relationship
                == VoteRelationship.DIRECT_TEXT_VOTE,
                VoteClassification.is_direct_vote_on_legislative_text.is_(True),
                VoteClassification.is_substantive_policy_vote.is_(True),
                VoteClassification.is_special_rule.is_(False),
            )
        )
        query = query.where(exists(qualifying_vote))

    if not force:
        # Skip rows that already have at least one stored summary.
        existing_summary = select(BillTextSummary.id).where(
            BillTextSummary.bill_text_id == BillText.id
        )
        query = query.where(~exists(existing_summary))

    if limit is not None:
        query = query.limit(limit)
    return query
def extract_message_content(data: dict[str, Any]) -> str:
    """Return the text of the first choice in a chat-completions response body.

    The first choice's ``message.content`` is preferred; if that is not a
    string, the choice's own ``text`` field is used instead.

    Raises:
        BillSummaryError: If ``choices`` is missing, empty, or not a list, if
            the first choice is not an object, or if neither field holds a
            string.
    """
    choice_list = data.get("choices")
    if not (isinstance(choice_list, list) and choice_list):
        raise BillSummaryError("Chat completion response did not contain choices")

    head = choice_list[0]
    if not isinstance(head, dict):
        raise BillSummaryError("Chat completion choice must be an object")

    chat_message = head.get("message")
    if isinstance(chat_message, dict):
        content = chat_message.get("content")
        if isinstance(content, str):
            return content

    fallback_text = head.get("text")
    if isinstance(fallback_text, str):
        return fallback_text

    raise BillSummaryError("Chat completion response did not contain message content")
def main(
    model: Annotated[str, typer.Option(help="OpenAI model id.")] = "gpt-5.4-mini",
    congress: Annotated[
        int | None, typer.Option(help="Only process one Congress.")
    ] = None,
    bill_ids: Annotated[
        list[int] | None,
        typer.Option(
            "--bill-id",
            help="Only process one internal bill.id. Repeat for multiple bills.",
        ),
    ] = None,
    bill_text_ids: Annotated[
        list[int] | None,
        typer.Option(
            "--bill-text-id",
            help="Only process one internal bill_text.id. Repeat for multiple rows.",
        ),
    ] = None,
    with_votes_only: Annotated[
        bool,
        typer.Option(
            "--with-votes-only",
            help="Only process bill_text rows linked to at least one vote.",
        ),
    ] = False,
    limit: Annotated[int | None, typer.Option(help="Maximum rows to process.")] = None,
    force: Annotated[
        bool,
        typer.Option(help="Regenerate summaries for rows that already have a summary."),
    ] = False,
    dry_run: Annotated[
        bool, typer.Option(help="Print summaries without writing them to the database.")
    ] = False,
    log_level: Annotated[str, typer.Option(help="Log level.")] = "INFO",
) -> None:
    """CLI entrypoint for generating and storing bill summaries.

    Selects the matching bill_text rows, requests a summary for each, and
    either stores it or, with ``--dry-run``, prints it and leaves the database
    untouched. Rows for which ``summarize_bill_text`` returns ``None`` are
    skipped with a warning.

    Raises:
        typer.BadParameter: If ``CLOSEDAI_TOKEN`` is not set.
        BillSummaryError: Propagated from ``summarize_bill_text``; the run
            stops at the first failing row.
    """
    logging.basicConfig(
        # Level names are upper-case; accept "info" as well as "INFO".
        level=log_level.upper(),
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )
    # Fail before touching the database if no request could succeed anyway.
    if not getenv("CLOSEDAI_TOKEN"):
        message = "CLOSEDAI_TOKEN is required"
        raise typer.BadParameter(message)

    summarization_prompts = load_summarization_prompts()
    engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
    written = 0
    with Session(engine) as session:
        stmt = create_select_bill_texts_for_summarization(
            congress=congress,
            bill_ids=bill_ids,
            bill_text_ids=bill_text_ids,
            with_votes_only=with_votes_only,
            force=force,
            limit=limit,
        )
        bill_texts = session.scalars(stmt).all()
        logger.info("Selected %d bill_text rows for summarization", len(bill_texts))

        for bill_text in bill_texts:
            summary = summarize_bill_text(
                model=model,
                bill_text=bill_text,
                summarization_prompts=summarization_prompts,
            )
            if summary is None:
                logger.warning("Skipping bill_text id=%s", bill_text.id)
                continue

            if dry_run:
                # Honour --dry-run: show the result, write nothing.
                typer.echo(f"bill_text id={bill_text.id}\n{summary}\n")
                continue

            store_bill_summary_result(
                session=session,
                bill_text=bill_text,
                summary=summary,
                model=model,
            )
            # Commit per row so summaries already generated are kept if a
            # later row raises; each row already costs an API round-trip.
            session.commit()
            written += 1
            logger.info("Stored summary for bill_text id=%s", bill_text.id)

    logger.info("Done: stored %d summaries", written)
def cli() -> None:
    """Run ``main`` as a Typer command-line application."""
    typer.run(main)


# Only start the CLI when this file is executed as a script.
if __name__ == "__main__":
    cli()