Files
weave/pipelines/jobs/extract_bill_topics.py
T

697 lines
24 KiB
Python

"""Extract bill topics from bill text using a configurable topic catalog."""
from __future__ import annotations

import json
import logging
import math
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Annotated, Any, Sequence

import httpx
import typer
from sqlalchemy import ColumnElement, Select, delete, exists, func, select
from sqlalchemy.orm import Session, selectinload

from pipelines.config import OpenAIConfig, get_config_dir, get_openai_config
from pipelines.orm.common import get_postgres_engine
from pipelines.orm.data_science_dev.congress import (
    Bill,
    BillText,
    BillTextSummary,
    BillTopic,
    BillTopicPosition,
    SubjectType,
    VoteClassification,
    VoteRelationship,
    VoteTextTarget,
)
logger: logging.Logger = logging.getLogger(__name__)

# OpenAI request settings. NOTE(review): the request code in this file reads
# the URL, project id and timeout from OpenAIConfig instead, so these three
# constants look unused here; they may be imported elsewhere — confirm before
# removing them.
OPENAI_PROJECT_ID: str = "proj_fQBPEXFgnS87Fk6wZwploFwE"
OPENAI_CHAT_COMPLETIONS_URL: str = "https://api.openai.com/v1/chat/completions"
REQUEST_TIMEOUT_SECONDS: int = 60

# Topic catalog used when no explicit path is given; lives in the pipelines
# config directory.
DEFAULT_TOPICS_PATH: Path = (
    get_config_dir() / "congressional_issues_comprehensive.json"
)
class TopicExtractionError(RuntimeError):
    """Signal a bad topic catalog, a failed model call, or a malformed reply.

    Derives from ``RuntimeError`` so broad runtime-error handlers still see it.
    """
@dataclass(frozen=True)
class TopicCatalog:
    """Immutable, normalized view of the bill topic catalog.

    Holds the catalog twice: grouped by category (for display/prompting) and
    as one flat list of candidate labels (for matching model output).
    """

    # Category name -> normalized topic labels listed under that category.
    topics_by_category: dict[str, list[str]]
    # Flat list of normalized topic labels (de-duplicated by load_topic_catalog).
    candidate_topics: list[str]
@dataclass(frozen=True)
class TopicExtractionDiagnostics:
    """Row counts for each input stage of topic extraction.

    Reported by the CLI to explain why a run selected few or no bills.
    """

    # Bill rows inside the congress / bill-id scope.
    bill_rows: int
    # bill_text rows inside the requested scope.
    bill_text_rows: int
    # In-scope bill_text rows that have at least one summary.
    summarized_bill_text_rows: int
    # Distinct in-scope bills with at least one summarized bill_text.
    bills_with_summaries: int
    # bill_topic rows counted by the diagnostics query.
    bill_topic_rows: int
    # Bills the extraction query would actually select.
    selected_bills: int
@dataclass(frozen=True)
class ExtractedBillTopic:
    """A single catalog topic attached to a bill, plus the stance of a Yes vote.

    ``support_position`` records whether voting Yes/Yea on the bill is for or
    against the topic.
    """

    # Topic label (normalized by the response parser before construction).
    topic: str
    # Stance of a Yes/Yea vote toward the topic.
    support_position: BillTopicPosition
    # Model-reported confidence, when the model supplied one.
    confidence: float | None = None
    # Model's short justification, when the model supplied one.
    evidence: str | None = None
def _select_bill_text_for_topic_extraction(bill: Bill) -> BillText | None:
"""Pick one summarized bill_text row from the already-loaded relationship."""
for bill_text in bill.bill_texts:
summary_row = bill_text.default_summary()
if summary_row and summary_row.summary.strip():
return bill_text
return None
def _bill_text_has_summary_clause() -> ColumnElement[bool]:
    """Build an EXISTS predicate: the current BillText row has any summary.

    The subquery compares against ``BillText.id`` without selecting from
    ``BillText`` itself, so it correlates to the enclosing query that does.
    """
    summaries_of_text = select(BillTextSummary.id).where(
        BillTextSummary.bill_text_id == BillText.id
    )
    return exists(summaries_of_text)
def normalize_topic_label(value: str) -> str:
    """Canonicalize a topic label so equivalent labels compare equal.

    Trims surrounding whitespace and single/double quotes, drops trailing
    periods, collapses internal whitespace runs to one space, and lower-cases.
    """
    label = value.strip()
    label = label.strip("\"'").strip()
    label = label.rstrip(".").strip()
    collapsed = re.sub(r"\s+", " ", label)
    return collapsed.lower()
def load_topic_catalog(path: Path | None = None) -> TopicCatalog:
    """Load, validate, normalize, and flatten the bill topic catalog.

    The catalog file is a JSON object mapping category names to lists of
    topic labels. Every label is passed through ``normalize_topic_label``;
    a label that normalizes to one already seen (in any category) is skipped,
    so ``candidate_topics`` lists each topic once, in first-seen order.

    Args:
        path: Catalog file to read. Defaults to ``DEFAULT_TOPICS_PATH``.

    Returns:
        A ``TopicCatalog`` with per-category and flat topic lists.

    Raises:
        TopicExtractionError: If the file is missing, unreadable, not UTF-8,
            not valid JSON, or not shaped as ``{category: [topic, ...]}``.
    """
    topics_path = path or DEFAULT_TOPICS_PATH
    try:
        # JSON text is UTF-8; do not depend on the platform default encoding.
        raw = json.loads(topics_path.read_text(encoding="utf-8"))
    except FileNotFoundError as exc:
        msg = f"Topic catalog not found: {topics_path}"
        raise TopicExtractionError(msg) from exc
    except json.JSONDecodeError as exc:
        msg = f"Topic catalog is not valid JSON: {topics_path}: {exc}"
        raise TopicExtractionError(msg) from exc
    except (OSError, UnicodeDecodeError) as exc:
        # Permission errors, directories, bad bytes, etc. surface with the
        # same exception type as the other catalog problems.
        msg = f"Topic catalog could not be read: {topics_path}: {exc}"
        raise TopicExtractionError(msg) from exc

    if not isinstance(raw, dict):
        msg = "Topic catalog root must be an object mapping category names to lists"
        raise TopicExtractionError(msg)

    topics_by_category: dict[str, list[str]] = {}
    candidate_topics: list[str] = []
    seen_topics: set[str] = set()
    for category, topics in raw.items():
        if not isinstance(category, str) or not category.strip():
            msg = "Topic catalog category names must be non-empty strings"
            raise TopicExtractionError(msg)
        if not isinstance(topics, list):
            msg = f"Topic catalog category {category!r} must contain a list"
            raise TopicExtractionError(msg)
        # Keys that differ only by surrounding whitespace share one entry
        # rather than the later key silently replacing the earlier list.
        normalized_topics = topics_by_category.setdefault(category.strip(), [])
        for topic in topics:
            if not isinstance(topic, str):
                msg = f"Topic catalog category {category!r} contains a non-string topic"
                raise TopicExtractionError(msg)
            normalized_topic = normalize_topic_label(topic)
            if not normalized_topic:
                msg = f"Topic catalog category {category!r} contains a blank topic"
                raise TopicExtractionError(msg)
            if normalized_topic in seen_topics:
                continue
            seen_topics.add(normalized_topic)
            normalized_topics.append(normalized_topic)
            candidate_topics.append(normalized_topic)
    return TopicCatalog(
        topics_by_category=topics_by_category,
        candidate_topics=candidate_topics,
    )
def build_topic_extraction_messages(
    *,
    bill: Bill,
    bill_text: str,
    candidate_topics: Sequence[str],
) -> list[dict[str, str]]:
    """Assemble the system and user chat messages for one extraction call.

    The system message defines the task and the strict JSON reply shape. The
    user message has three labelled sections ("BILL METADATA:", "CANDIDATE
    TOPICS:", "BILL TEXT:"); every header and body is separated by a blank
    line, and candidates are normalized and listed as "- topic" bullets.
    """
    display_title = bill.title_short or bill.title or bill.official_title or ""
    top_subject = bill.subjects_top_term or ""
    metadata_lines = (
        f"Congress: {bill.congress}",
        f"Bill: {bill.bill_type} {bill.number}",
        f"Title: {display_title}",
        f"Top subject term: {top_subject}",
    )
    bullet_lines = [f"- {normalize_topic_label(label)}" for label in candidate_topics]

    system_prompt = (
        "You extract policy topics from U.S. congressional bills.\n"
        'For each selected topic, decide whether a Yes/Yea vote on the bill is "for" or "against" that topic.\n'
        'Use "support_position": "for" when a Yes/Yea vote advances or supports the topic.\n'
        'Use "support_position": "against" when a Yes/Yea vote restricts, repeals, blocks, or opposes the topic.\n'
        "Select only topics from the provided candidate topic list.\n"
        "Omit topics that are not materially addressed by the bill.\n"
        "Return strict JSON only, with this shape:\n"
        '{"topics":[{"topic":"candidate topic","support_position":"for","confidence":0.0,"evidence":"short reason"}]}'
    )
    sections = (
        "BILL METADATA:",
        "\n".join(metadata_lines),
        "CANDIDATE TOPICS:",
        "\n".join(bullet_lines),
        "BILL TEXT:",
        bill_text,
    )
    user_prompt = "\n\n".join(sections)

    system_message = {"role": "system", "content": system_prompt}
    user_message = {"role": "user", "content": user_prompt}
    return [system_message, user_message]
def call_openai_topic_extraction(
    *,
    openai_config: OpenAIConfig,
    messages: list[dict[str, str]],
    model: str = "gpt-5.4-mini",
) -> str:
    """POST a chat-completions request and return the assistant's reply text.

    Args:
        openai_config: Supplies the endpoint URL, API key, project id, and
            request timeout.
        messages: Chat messages to send, e.g. from
            ``build_topic_extraction_messages``.
        model: Chat model name; defaults to the model this job has always used.

    Returns:
        The assistant message content.

    Raises:
        httpx.HTTPError: On transport failures or non-2xx status codes.
        TopicExtractionError: If the body is not JSON or has no message content.
    """
    response = httpx.post(
        openai_config.openai_chat_completions_url,
        headers={
            "Authorization": f"Bearer {openai_config.api_key}",
            "OpenAI-Project": openai_config.openai_project_id,
            "Content-Type": "application/json",
        },
        json={
            "model": model,
            "messages": messages,
        },
        timeout=openai_config.timeout_seconds,
    )
    response.raise_for_status()
    try:
        body = response.json()
    except ValueError as exc:
        # A 2xx with a non-JSON body would otherwise escape as a bare
        # ValueError, which callers handling (httpx.HTTPError,
        # TopicExtractionError) per bill would not catch.
        msg = f"Chat completion response body is not valid JSON: {exc}"
        raise TopicExtractionError(msg) from exc
    return extract_message_content(body)
def extract_message_content(data: dict[str, Any]) -> str:
    """Return the assistant text from a chat-completions response body.

    Reads ``choices[0].message.content``; falls back to ``choices[0].text``
    for completions-style bodies.

    Raises:
        TopicExtractionError: If the body is not an object, has no choices,
            or its first choice carries no string content. A model refusal
            is reported with its refusal text.
    """
    if not isinstance(data, dict):
        # Previously a non-object body crashed with AttributeError on .get.
        msg = "Chat completion response must be a JSON object"
        raise TopicExtractionError(msg)
    choices = data.get("choices")
    if not isinstance(choices, list) or not choices:
        msg = "Chat completion response did not contain choices"
        raise TopicExtractionError(msg)
    first = choices[0]
    if not isinstance(first, dict):
        msg = "Chat completion choice must be an object"
        raise TopicExtractionError(msg)

    message = first.get("message")
    if isinstance(message, dict) and isinstance(message.get("content"), str):
        return message["content"]
    if isinstance(first.get("text"), str):
        return first["text"]

    refusal = message.get("refusal") if isinstance(message, dict) else None
    if isinstance(refusal, str) and refusal:
        msg = f"Chat completion was refused: {refusal}"
        raise TopicExtractionError(msg)
    msg = "Chat completion response did not contain message content"
    raise TopicExtractionError(msg)
def parse_topic_extraction_response(response_text: str) -> list[ExtractedBillTopic]:
    """Turn the model's JSON reply into validated, de-duplicated topics.

    Each entry of ``payload["topics"]`` becomes an ``ExtractedBillTopic``
    with a normalized label. Entries sharing the same (topic,
    support_position) pair collapse to the one with the higher
    ``_confidence_rank``; on a tie the earliest entry is kept. Results come
    back in the order each pair was first seen.

    Raises:
        TopicExtractionError: On unparseable JSON, a missing ``topics`` list,
            or any entry that is not an object, has a blank topic, an unknown
            ``support_position``, or an unparseable confidence.
    """
    payload = _load_json_response(response_text)
    raw_items = payload.get("topics")
    if not isinstance(raw_items, list):
        msg = "Topic extraction response must contain a topics list"
        raise TopicExtractionError(msg)

    best_by_key: dict[tuple[str, BillTopicPosition], ExtractedBillTopic] = {}
    for raw_item in raw_items:
        candidate = _parse_extracted_topic_item(raw_item)
        pair = (candidate.topic, candidate.support_position)
        incumbent = best_by_key.get(pair)
        if incumbent is None or _confidence_rank(candidate) > _confidence_rank(incumbent):
            best_by_key[pair] = candidate
    return list(best_by_key.values())


def _parse_extracted_topic_item(raw_item: Any) -> ExtractedBillTopic:
    """Validate and normalize a single entry of the reply's ``topics`` list."""
    if not isinstance(raw_item, dict):
        msg = "Topic extraction response topics must be objects"
        raise TopicExtractionError(msg)

    label = normalize_topic_label(_extract_topic_label(raw_item))
    if not label:
        msg = "Topic extraction response topic must not be blank"
        raise TopicExtractionError(msg)

    raw_position = raw_item.get("support_position")
    try:
        position = BillTopicPosition(raw_position)
    except ValueError as exc:
        msg = f"Invalid support_position: {raw_position!r}"
        raise TopicExtractionError(msg) from exc

    confidence = _parse_confidence(raw_item.get("confidence"))
    evidence = raw_item.get("evidence")
    if evidence is not None and not isinstance(evidence, str):
        evidence = str(evidence)

    return ExtractedBillTopic(
        topic=label,
        support_position=position,
        confidence=confidence,
        evidence=evidence,
    )
def extract_topics_for_bill_text(
    *,
    openai_config: OpenAIConfig,
    bill: Bill,
    text: str,
    candidate_topics: Sequence[str],
) -> list[ExtractedBillTopic]:
    """Run one model extraction over ``text`` and keep only catalog topics.

    Candidates are normalized, de-duplicated, and sent to the prompt in sorted
    order; any returned topic whose label is not among them is discarded.
    """
    allowed_topics = frozenset(map(normalize_topic_label, candidate_topics))
    reply_text = call_openai_topic_extraction(
        openai_config=openai_config,
        messages=build_topic_extraction_messages(
            bill=bill,
            bill_text=text,
            candidate_topics=sorted(allowed_topics),
        ),
    )
    parsed = parse_topic_extraction_response(reply_text)
    return [item for item in parsed if item.topic in allowed_topics]
def store_bill_topic_result(
    *,
    session: Session,
    bill: Bill,
    topics: Sequence[ExtractedBillTopic],
    replace_existing: bool = True,
) -> None:
    """Stage ``BillTopic`` rows for ``bill`` in ``session`` without committing.

    When ``replace_existing`` is true, a DELETE of the bill's current topic
    rows is executed first. Only the (re-normalized) topic label and
    ``support_position`` are written; confidence and evidence are not stored.
    """
    if replace_existing:
        stale_rows = delete(BillTopic).where(BillTopic.bill_id == bill.id)
        session.execute(stale_rows)

    new_rows = [
        BillTopic(
            bill_id=bill.id,
            topic=normalize_topic_label(item.topic),
            support_position=item.support_position,
        )
        for item in topics
    ]
    for row in new_rows:
        session.add(row)
def create_select_bills_for_topic_extraction(
    congress: int | None = None,
    bill_ids: list[int] | None = None,
    bill_text_ids: list[int] | None = None,
    with_votes_only: bool = False,
    force: bool = False,
    limit: int | None = None,
) -> Select[tuple[Bill]]:
    """Build the query that selects bills eligible for topic extraction.

    A bill qualifies when at least one of its bill_text rows passes every
    text filter:
      * it has at least one summary;
      * with ``with_votes_only``: it is the target of a direct, substantive,
        non-special-rule MEASURE vote;
      * with ``bill_text_ids``: its id is in that list.
    The same text filters restrict the eager-loaded ``Bill.bill_texts``, so
    callers picking a text from the loaded collection only see qualifying
    texts (in particular, only the requested ``bill_text_ids``).

    Args:
        congress: Only bills from this Congress.
        bill_ids: Only these bill ids.
        bill_text_ids: Only bills (and loaded texts) with these bill_text ids.
        with_votes_only: Require a qualifying vote on the bill_text.
        force: Include bills that already have topics.
        limit: Maximum number of bills returned.

    Returns:
        A ``SELECT`` over ``Bill`` ordered by ``Bill.id``.
    """
    text_filters: list[ColumnElement[bool]] = [_bill_text_has_summary_clause()]
    if with_votes_only:
        text_filters.append(
            exists(
                select(VoteTextTarget.vote_id)
                .join(
                    VoteClassification,
                    VoteClassification.vote_id == VoteTextTarget.vote_id,
                )
                .where(
                    VoteTextTarget.voted_text_version_id == BillText.id,
                    VoteClassification.subject_type == SubjectType.MEASURE,
                    VoteClassification.vote_relationship
                    == VoteRelationship.DIRECT_TEXT_VOTE,
                    VoteClassification.is_direct_vote_on_legislative_text.is_(True),
                    VoteClassification.is_substantive_policy_vote.is_(True),
                    VoteClassification.is_special_rule.is_(False),
                )
            )
        )
    if bill_text_ids:
        # Part of the shared text filters so it also limits the loaded
        # bill_texts; otherwise --bill-text-id selects the bill but the caller
        # may pick a different (unrequested) text version of it.
        text_filters.append(BillText.id.in_(bill_text_ids))

    eligible_text_exists = exists(
        select(BillText.id).where(BillText.bill_id == Bill.id, *text_filters)
    )
    bill_text_loader = selectinload(Bill.bill_texts.and_(*text_filters))
    stmt = (
        select(Bill)
        .where(eligible_text_exists)
        .options(
            bill_text_loader.selectinload(BillText.summaries),
            bill_text_loader.selectinload(BillText.primary_summary),
        )
        .order_by(Bill.id)
    )
    if congress is not None:
        stmt = stmt.where(Bill.congress == congress)
    if bill_ids:
        stmt = stmt.where(Bill.id.in_(bill_ids))
    if not force:
        stmt = stmt.where(
            ~exists(select(BillTopic.id).where(BillTopic.bill_id == Bill.id))
        )
    if limit is not None:
        stmt = stmt.limit(limit)
    return stmt
def collect_topic_extraction_diagnostics(
    session: Session,
    *,
    congress: int | None = None,
    bill_ids: list[int] | None = None,
    bill_text_ids: list[int] | None = None,
    with_votes_only: bool = False,
    force: bool = False,
    limit: int | None = None,
) -> TopicExtractionDiagnostics:
    """Count topic-extraction inputs stage by stage to explain empty runs.

    ``congress`` and ``bill_ids`` scope the bill-level counts; ``bill_ids``,
    ``bill_text_ids`` and ``with_votes_only`` scope the bill_text-level
    counts. ``selected_bills`` runs the real selection query (including
    ``force`` and ``limit``). ``bill_topic_rows`` counts existing topics only
    for bills in the bill-level scope.
    """
    bill_filters: list[ColumnElement[bool]] = []
    bill_text_filters: list[ColumnElement[bool]] = []
    if congress is not None:
        bill_filters.append(Bill.congress == congress)
    if bill_ids:
        bill_filters.append(Bill.id.in_(bill_ids))
        bill_text_filters.append(BillText.bill_id.in_(bill_ids))
    if bill_text_ids:
        bill_text_filters.append(BillText.id.in_(bill_text_ids))
    if with_votes_only:
        bill_text_filters.append(
            exists(
                select(VoteTextTarget.vote_id)
                .join(
                    VoteClassification,
                    VoteClassification.vote_id == VoteTextTarget.vote_id,
                )
                .where(
                    VoteTextTarget.voted_text_version_id == BillText.id,
                    VoteClassification.subject_type == SubjectType.MEASURE,
                    VoteClassification.vote_relationship
                    == VoteRelationship.DIRECT_TEXT_VOTE,
                    VoteClassification.is_direct_vote_on_legislative_text.is_(True),
                    VoteClassification.is_substantive_policy_vote.is_(True),
                    VoteClassification.is_special_rule.is_(False),
                )
            )
        )
    summary_filters = [*bill_text_filters, _bill_text_has_summary_clause()]

    bills_with_summaries = session.scalar(
        select(func.count(func.distinct(Bill.id)))
        .select_from(Bill)
        .join(BillText, BillText.bill_id == Bill.id)
        .where(*bill_filters, *summary_filters)
    )
    selected_bills = session.scalar(
        select(func.count()).select_from(
            create_select_bills_for_topic_extraction(
                congress=congress,
                bill_ids=bill_ids,
                bill_text_ids=bill_text_ids,
                with_votes_only=with_votes_only,
                force=force,
                limit=limit,
            ).subquery()
        )
    )
    bill_topic_query = select(func.count(BillTopic.id))
    if bill_filters:
        # Count only topics of in-scope bills: a table-wide count would report
        # topics of unrelated bills and trigger a misleading "--force" hint.
        bill_topic_query = bill_topic_query.join(
            Bill, Bill.id == BillTopic.bill_id
        ).where(*bill_filters)

    return TopicExtractionDiagnostics(
        bill_rows=session.scalar(select(func.count(Bill.id)).where(*bill_filters)) or 0,
        bill_text_rows=_count_bill_texts(
            session,
            bill_filters=bill_filters,
            bill_text_filters=bill_text_filters,
        ),
        summarized_bill_text_rows=_count_bill_texts(
            session,
            bill_filters=bill_filters,
            bill_text_filters=summary_filters,
        ),
        bills_with_summaries=bills_with_summaries or 0,
        bill_topic_rows=session.scalar(bill_topic_query) or 0,
        selected_bills=selected_bills or 0,
    )
def _load_json_response(response_text: str) -> dict[str, Any]:
text = response_text.strip()
fenced = re.fullmatch(r"```(?:json)?\s*(.*?)\s*```", text, flags=re.DOTALL)
if fenced:
text = fenced.group(1).strip()
try:
payload = json.loads(text)
except json.JSONDecodeError as exc:
msg = f"Topic extraction response is not valid JSON: {exc}"
raise TopicExtractionError(msg) from exc
if not isinstance(payload, dict):
msg = "Topic extraction response must be a JSON object"
raise TopicExtractionError(msg)
return payload
def _parse_confidence(raw: Any) -> float | None:
if raw is None:
return None
try:
return float(raw)
except (TypeError, ValueError) as exc:
msg = f"Invalid confidence: {raw!r}"
raise TopicExtractionError(msg) from exc
def _confidence_rank(topic: ExtractedBillTopic) -> tuple[int, float]:
if topic.confidence is None:
return (0, 0.0)
return (1, topic.confidence)
def _extract_topic_label(item: dict[str, Any]) -> str:
raw_topic = item.get("topic")
if isinstance(raw_topic, str):
return raw_topic
if isinstance(raw_topic, dict):
for key in ("topic", "label", "name", "title"):
value = raw_topic.get(key)
if isinstance(value, str):
return value
msg = "Topic extraction response topic must be a string"
raise TopicExtractionError(msg)
def _count_bill_texts(
    session: Session,
    *,
    bill_filters: Sequence[ColumnElement[bool]],
    bill_text_filters: Sequence[ColumnElement[bool]],
) -> int:
    """Count bill_text rows that pass ``bill_text_filters`` (and bill scope).

    ``bill`` is joined only when ``bill_filters`` is non-empty, so an unscoped
    count stays a single-table query. A NULL result is reported as 0.
    """
    count_query = select(func.count(BillText.id))
    if bill_filters:
        count_query = count_query.join(Bill, Bill.id == BillText.bill_id)
        count_query = count_query.where(*bill_filters)
    count_query = count_query.where(*bill_text_filters)
    total = session.scalar(count_query)
    return total or 0
def main(
    topics_path: Annotated[
        Path, typer.Option(help="Path to congressional issue topic JSON.")
    ] = DEFAULT_TOPICS_PATH,
    congress: Annotated[
        int | None, typer.Option(help="Only process one Congress.")
    ] = None,
    bill_ids: Annotated[
        list[int] | None,
        typer.Option(
            "--bill-id",
            help="Only process one internal bill.id. Repeat for multiple bills.",
        ),
    ] = None,
    bill_text_ids: Annotated[
        list[int] | None,
        typer.Option(
            "--bill-text-id",
            help="Only process one internal bill_text.id. Repeat for multiple rows.",
        ),
    ] = None,
    with_votes_only: Annotated[
        bool,
        typer.Option(
            # A lone "--with-votes-only" name on a True-default bool yields a
            # flag that can never be turned off; pair it with a negative form.
            "--with-votes-only/--no-with-votes-only",
            help="Only process summarized bill_text rows linked to at least one vote.",
        ),
    ] = True,
    limit: Annotated[
        int | None, typer.Option(help="Maximum bills to process.")
    ] = None,
    force: Annotated[
        bool,
        typer.Option(help="Regenerate topics for bills that already have topics."),
    ] = False,
    dry_run: Annotated[
        bool,
        typer.Option(help="Select bills and print diagnostics without calling OpenAI."),
    ] = False,
    diagnose: Annotated[
        bool,
        typer.Option(help="Log input-stage counts before processing."),
    ] = False,
    log_level: Annotated[str, typer.Option(help="Log level.")] = "INFO",
) -> None:
    """CLI entrypoint: extract topics for selected bills and store them.

    Loads the topic catalog, optionally logs input diagnostics
    (``--diagnose``/``--dry-run``; ``--dry-run`` stops there), then, for each
    selected bill, sends its default summary to the model and replaces the
    bill's stored topics with the result. Bills whose extraction fails are
    logged and counted as failed; the session is committed every 100 selected
    bills and once at the end.
    """
    logging.basicConfig(
        # logging only recognizes upper-case level names; accept "info" too.
        level=log_level.upper(),
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )
    topic_catalog = load_topic_catalog(topics_path)
    logger.info(
        "Loaded %d candidate topics from %s",
        len(topic_catalog.candidate_topics),
        topics_path,
    )
    engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
    with Session(engine) as session:
        if diagnose or dry_run:
            diagnostics = collect_topic_extraction_diagnostics(
                session,
                congress=congress,
                bill_ids=bill_ids,
                bill_text_ids=bill_text_ids,
                with_votes_only=with_votes_only,
                force=force,
                limit=limit,
            )
            _log_topic_extraction_diagnostics(diagnostics)
            if dry_run:
                return
        openai_config = get_openai_config()
        stmt = create_select_bills_for_topic_extraction(
            congress=congress,
            bill_ids=bill_ids,
            bill_text_ids=bill_text_ids,
            with_votes_only=with_votes_only,
            force=force,
            limit=limit,
        )
        bills = session.scalars(stmt).all()
        logger.info("Selected %d bills for topic extraction", len(bills))
        written = 0
        failed = 0
        for index, bill in enumerate(bills, 1):
            bill_text = _select_bill_text_for_topic_extraction(bill)
            if bill_text is None:
                logger.warning("Skipping bill id=%s: no usable summary", bill.id)
                continue
            # Defensive re-check; the selector above already required a
            # non-blank default summary.
            summary_row = bill_text.default_summary()
            if summary_row is None:
                logger.warning("Skipping bill id=%s: no default summary", bill.id)
                continue
            summary = summary_row.summary.strip()
            try:
                extracted_topics = extract_topics_for_bill_text(
                    openai_config=openai_config,
                    bill=bill,
                    text=summary,
                    candidate_topics=topic_catalog.candidate_topics,
                )
            except (httpx.HTTPError, TopicExtractionError):
                failed += 1
                logger.exception(
                    "Skipping bill id=%s after topic extraction failure", bill.id
                )
                continue
            store_bill_topic_result(
                session=session,
                bill=bill,
                topics=extracted_topics,
                replace_existing=True,
            )
            written += 1
            if index % 100 == 0:
                session.commit()
            logger.info(
                "Stored %d topics for bill id=%s",
                len(extracted_topics),
                bill.id,
            )
        session.commit()
        logger.info(
            "Done: stored topic results for %d bills; failed %d bills",
            written,
            failed,
        )
def _log_topic_extraction_diagnostics(
    diagnostics: TopicExtractionDiagnostics,
) -> None:
    """Log the diagnostic counts, then warn about the earliest empty stage.

    Stages are checked from broadest (bills) to narrowest (final selection);
    at most one warning is emitted.
    """
    stats = diagnostics
    logger.info(
        "Topic extraction diagnostics: bill_rows=%d bill_text_rows=%d "
        "summarized_bill_text_rows=%d bills_with_summaries=%d "
        "bill_topic_rows=%d selected_bills=%d",
        stats.bill_rows,
        stats.bill_text_rows,
        stats.summarized_bill_text_rows,
        stats.bills_with_summaries,
        stats.bill_topic_rows,
        stats.selected_bills,
    )
    if stats.bill_rows == 0:
        logger.warning("No bills matched the topic extraction scope.")
        return
    if stats.bill_text_rows == 0:
        logger.warning("No bill_text rows matched the topic extraction scope.")
        return
    if stats.summarized_bill_text_rows == 0:
        logger.warning(
            "No summarized bill_text rows matched the topic extraction scope. "
            "Run pipelines.tools.summarize_bills first."
        )
        return
    if stats.selected_bills != 0:
        return
    if stats.bill_topic_rows > 0:
        logger.warning(
            "No bills selected because matching bills already have topics. "
            "Use --force to regenerate them."
        )
    else:
        logger.warning("No bills selected for topic extraction.")
if __name__ == "__main__":
    # Script entry point: typer.run exposes main() as a single-command CLI
    # built from its annotated parameters.
    typer.run(main)