added summarize_bills.py
This commit is contained in:
268
pipelines/tools/summarize_bills.py
Normal file
268
pipelines/tools/summarize_bills.py
Normal file
@@ -0,0 +1,268 @@
|
|||||||
|
"""Summarize bill_text rows with GPT-5 and store results in the database."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import tomllib
|
||||||
|
from os import getenv
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Annotated, Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import typer
|
||||||
|
from sqlalchemy import Select, or_, select
|
||||||
|
from sqlalchemy.orm import Session, selectinload
|
||||||
|
|
||||||
|
from pipelines.config import get_config_dir
|
||||||
|
from pipelines.orm.common import get_postgres_engine
|
||||||
|
from pipelines.orm.data_science_dev.congress import Bill, BillText
|
||||||
|
from pipelines.tools.bill_token_compression import compress_bill_text
|
||||||
|
|
||||||
|
# Module-level logger; level and format are configured in main() via logging.basicConfig.
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Project identifier sent as the "OpenAI-Project" header by call_openai_summary().
OPENAI_PROJECT_ID = "proj_fQBPEXFgnS87Fk6wZwploFwE"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _find_prompts_path() -> Path:
    """Return the path of the summarization prompts TOML file.

    The path is ``prompts/summarization_prompts.toml`` beneath the directory
    returned by ``get_config_dir()``. Existence of the file is not checked here.
    """
    prompts_dir = get_config_dir() / "prompts"
    return prompts_dir / "summarization_prompts.toml"
|
||||||
|
|
||||||
|
|
||||||
|
def load_summarization_prompts(
    section: str = "summarization",
) -> dict[str, str]:
    """Load one table of prompts from the summarization prompts TOML file.

    Args:
        section: Name of the top-level TOML table to return.

    Returns:
        The contents of the requested table. Callers in this module read the
        ``system_prompt`` and ``user_template`` keys from it.

    Raises:
        OSError: If the prompts file cannot be opened.
        tomllib.TOMLDecodeError: If the file is not valid TOML.
        KeyError: If ``section`` is not a top-level key of the file.
    """
    # TOML documents are UTF-8 by specification. Reading the file in binary
    # mode lets tomllib do the decoding instead of relying on the platform's
    # default text encoding, which is not UTF-8 everywhere.
    with _find_prompts_path().open("rb") as prompts_file:
        return tomllib.load(prompts_file)[section]
|
||||||
|
|
||||||
|
|
||||||
|
class BillSummaryError(RuntimeError):
    """Signal that a bill summary could not be requested or read back.

    Raised in this module when the API token is missing, a bill_text row has
    no usable text, the chat-completions response carries no message content,
    or the model returns an empty summary.
    """
|
||||||
|
|
||||||
|
|
||||||
|
def call_openai_summary(
    *,
    model: str,
    messages: list[dict[str, str]],
) -> str:
    """Send one chat-completions request and return the assistant's reply text.

    Args:
        model: Model id placed in the request body.
        messages: Chat messages to send, passed through unchanged.

    Returns:
        The message content pulled out of the JSON response by
        ``extract_message_content``.

    Raises:
        BillSummaryError: If ``CLOSEDAI_TOKEN`` is unset or empty, or the
            response body carries no message content.
        httpx.HTTPStatusError: If the API answers with an error status
            (raised by ``raise_for_status``).
    """
    token = getenv("CLOSEDAI_TOKEN")
    if not token:
        msg = "CLOSEDAI_TOKEN is required"
        raise BillSummaryError(msg)

    request_headers = {
        "Authorization": f"Bearer {token}",
        "OpenAI-Project": OPENAI_PROJECT_ID,
        "Content-Type": "application/json",
    }
    request_body = {"model": model, "messages": messages}

    reply = httpx.post(
        "https://api.openai.com/v1/chat/completions",
        headers=request_headers,
        json=request_body,
        timeout=60,
    )
    # Fail on error statuses before attempting to interpret the body.
    reply.raise_for_status()
    return extract_message_content(reply.json())
|
||||||
|
|
||||||
|
|
||||||
|
def build_bill_summary_messages(
    *,
    bill_text: BillText,
) -> list[dict[str, str]]:
    """Assemble the system and user chat messages for one bill_text row.

    The user message is three parts separated by blank lines: a
    ``BILL METADATA:`` header, the bill's metadata lines, and the
    ``user_template`` prompt filled in with the compressed bill text.

    Args:
        bill_text: Row to summarize; its related ``bill`` supplies the metadata.

    Returns:
        A two-element list: the system message followed by the user message.

    Raises:
        BillSummaryError: If the row has no ``text_content``, or nothing is
            left of it after ``compress_bill_text``.
    """
    raw_text = bill_text.text_content
    if not raw_text:
        msg = f"bill_text id={bill_text.id} has no text_content"
        raise BillSummaryError(msg)

    bill = bill_text.bill
    compressed_text = compress_bill_text(raw_text)
    if not compressed_text:
        msg = f"bill_text id={bill_text.id} has no summarizable text_content"
        raise BillSummaryError(msg)

    # Use the first available title, falling back to an empty string.
    title = bill.title_short or bill.title or bill.official_title or ""
    top_term = bill.subjects_top_term or ""
    version_line = f"Text version: {bill_text.version_code}"
    if bill_text.version_name:
        version_line += f" ({bill_text.version_name})"

    metadata_lines = [
        f"Congress: {bill.congress}",
        f"Bill: {bill.bill_type} {bill.number}",
        f"Title: {title}",
        f"Top subject term: {top_term}",
        version_line,
    ]

    prompts = load_summarization_prompts()
    filled_template = prompts["user_template"].format(text_content=compressed_text)
    user_prompt = "\n\n".join(
        ("BILL METADATA:", "\n".join(metadata_lines), filled_template)
    )

    system_message = {"role": "system", "content": prompts["system_prompt"]}
    user_message = {"role": "user", "content": user_prompt}
    return [system_message, user_message]
|
||||||
|
|
||||||
|
|
||||||
|
def summarize_bill_text(
    *,
    model: str,
    bill_text: BillText,
) -> str:
    """Produce the summary text for a single bill_text row.

    Builds the prompt messages for the row, sends them to the model, and
    returns the reply with surrounding whitespace removed.

    Raises:
        BillSummaryError: If the reply is empty once stripped (in addition to
            anything raised while building the prompt or calling the API).
    """
    prompt_messages = build_bill_summary_messages(bill_text=bill_text)
    raw_reply = call_openai_summary(model=model, messages=prompt_messages)
    cleaned = raw_reply.strip()
    if cleaned:
        return cleaned

    msg = f"Model returned an empty summary for bill_text id={bill_text.id}"
    raise BillSummaryError(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def store_bill_summary_result(
    *,
    bill_text: BillText,
    summary: str,
    model: str,
    system_prompt_version: str = "v1",
    user_prompt_version: str = "v1",
) -> None:
    """Record a generated summary and its provenance on a bill_text row.

    Only attributes on ``bill_text`` are assigned; nothing is flushed or
    committed here, so the caller remains responsible for the session.

    Args:
        bill_text: Row to update in place.
        summary: Generated summary text.
        model: Model id that produced the summary.
        system_prompt_version: Version label of the system prompt used.
            Defaults to ``"v1"``, the value that was previously hard-coded.
        user_prompt_version: Version label of the user prompt template used.
            Defaults to ``"v1"``, the value that was previously hard-coded.
    """
    bill_text.summary = summary
    bill_text.summarization_model = model
    # The versions are parameters so a caller using revised prompts can record
    # the real version instead of a stale constant.
    bill_text.summarization_system_prompt_version = system_prompt_version
    bill_text.summarization_user_prompt_version = user_prompt_version
|
||||||
|
|
||||||
|
|
||||||
|
def create_select_bill_texts_for_summarization(
    congress: int | None = None,
    bill_ids: list[int] | None = None,
    bill_text_ids: list[int] | None = None,
    force: bool = False,
    limit: int | None = None,
) -> Select:
    """Build the query for bill_text rows that should be summarized.

    Rows must have non-empty ``text_content``. Results are ordered by
    ``BillText.id`` and each row's ``bill`` relationship is loaded with
    ``selectinload``.

    Args:
        congress: If given, keep only rows whose bill belongs to this Congress.
        bill_ids: If non-empty, keep only rows for these ``bill.id`` values.
        bill_text_ids: If non-empty, keep only these ``bill_text.id`` values.
        force: When false, rows that already have a non-empty summary are
            excluded; when true they are included.
        limit: If given, cap the number of rows returned.

    Returns:
        The SELECT statement; it is not executed here.
    """
    conditions = [
        BillText.text_content.is_not(None),
        BillText.text_content != "",
    ]
    if congress is not None:
        conditions.append(Bill.congress == congress)
    if bill_ids:
        conditions.append(BillText.bill_id.in_(bill_ids))
    if bill_text_ids:
        conditions.append(BillText.id.in_(bill_text_ids))
    if not force:
        # A NULL or empty summary both count as "not summarized yet".
        missing_summary = or_(BillText.summary.is_(None), BillText.summary == "")
        conditions.append(missing_summary)

    query = (
        select(BillText)
        .join(Bill, Bill.id == BillText.bill_id)
        .where(*conditions)
        .options(selectinload(BillText.bill))
        .order_by(BillText.id)
    )
    return query if limit is None else query.limit(limit)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_message_content(data: dict[str, Any]) -> str:
    """Pull the reply text out of a chat-completions response body.

    Only the first entry of ``choices`` is inspected. A string at
    ``message.content`` is preferred; otherwise a string ``text`` field on the
    choice is used.

    Args:
        data: Decoded JSON response body.

    Returns:
        The reply text of the first choice (may be an empty string).

    Raises:
        BillSummaryError: If ``choices`` is missing, empty or not a list, if
            the first choice is not an object, or if it holds no string content.
    """
    choices = data.get("choices")
    if not (isinstance(choices, list) and choices):
        msg = "Chat completion response did not contain choices"
        raise BillSummaryError(msg)

    head = choices[0]
    if not isinstance(head, dict):
        msg = "Chat completion choice must be an object"
        raise BillSummaryError(msg)

    message = head.get("message")
    content = message.get("content") if isinstance(message, dict) else None
    if isinstance(content, str):
        return content

    # Fall back to a plain "text" field on the choice itself.
    fallback_text = head.get("text")
    if isinstance(fallback_text, str):
        return fallback_text

    msg = "Chat completion response did not contain message content"
    raise BillSummaryError(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def main(
    model: Annotated[str, typer.Option(help="OpenAI model id.")] = "gpt-5.4-mini",
    congress: Annotated[
        int | None, typer.Option(help="Only process one Congress.")
    ] = None,
    bill_ids: Annotated[
        list[int] | None,
        typer.Option(
            "--bill-id",
            help="Only process one internal bill.id. Repeat for multiple bills.",
        ),
    ] = None,
    bill_text_ids: Annotated[
        list[int] | None,
        typer.Option(
            "--bill-text-id",
            help="Only process one internal bill_text.id. Repeat for multiple rows.",
        ),
    ] = None,
    limit: Annotated[int | None, typer.Option(help="Maximum rows to process.")] = None,
    force: Annotated[
        bool,
        typer.Option(help="Regenerate summaries for rows that already have a summary."),
    ] = False,
    dry_run: Annotated[
        bool, typer.Option(help="Print summaries without writing them to the database.")
    ] = False,
    log_level: Annotated[str, typer.Option(help="Log level.")] = "INFO",
) -> None:
    """CLI entrypoint for generating and storing bill summaries.

    Selects the matching bill_text rows from the ``DATA_SCIENCE_DEV`` database,
    asks the model for a summary of each one, and stores the result. With
    ``--dry-run`` each summary is printed instead and nothing is written.

    A row that raises ``BillSummaryError`` (for example, no summarizable text
    or an empty model reply) is logged and skipped so it cannot block the rows
    after it. Any other exception, including HTTP errors from the API, still
    aborts the run.

    Raises:
        typer.BadParameter: If ``CLOSEDAI_TOKEN`` is not set.
        typer.Exit: With code 1 after all rows were attempted, if at least one
            row was skipped because of a ``BillSummaryError``.
    """
    logging.basicConfig(
        # logging only recognises upper-case level names ("INFO", not "info").
        level=log_level.upper(),
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )
    if not getenv("CLOSEDAI_TOKEN"):
        message = "CLOSEDAI_TOKEN is required"
        raise typer.BadParameter(message)

    engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
    written = 0
    failed = 0
    # expire_on_commit=False keeps the already-loaded rows (and their eagerly
    # loaded bills) usable after each per-row commit instead of forcing a
    # reload of every remaining row.
    with Session(engine, expire_on_commit=False) as session:
        stmt = create_select_bill_texts_for_summarization(
            congress=congress,
            bill_ids=bill_ids,
            bill_text_ids=bill_text_ids,
            force=force,
            limit=limit,
        )
        bill_texts = session.scalars(stmt).all()
        logger.info("Selected %d bill_text rows for summarization", len(bill_texts))

        for bill_text in bill_texts:
            try:
                summary = summarize_bill_text(
                    model=model,
                    bill_text=bill_text,
                )
            except BillSummaryError:
                # Row-specific problem: record it and move on to the next row.
                failed += 1
                logger.exception(
                    "Could not summarize bill_text id=%s; skipping", bill_text.id
                )
                continue

            if dry_run:
                # Honour --dry-run: show the result, leave the database untouched.
                typer.echo(f"bill_text id={bill_text.id}\n{summary}\n")
                continue

            store_bill_summary_result(
                bill_text=bill_text,
                summary=summary,
                model=model,
            )
            # Commit per row so a failure later in the run never discards
            # summaries that were already generated.
            session.commit()
            written += 1
            logger.info("Stored summary for bill_text id=%s", bill_text.id)

    logger.info("Done: stored %d summaries", written)
    if failed:
        logger.error("%d bill_text rows could not be summarized", failed)
        raise typer.Exit(code=1)
|
||||||
|
|
||||||
|
|
||||||
|
def cli() -> None:
    """Run ``main`` as a Typer command-line application."""
    typer.run(main)
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: start the CLI only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    cli()
|
||||||
Reference in New Issue
Block a user