adding calculate_legislator_scores.py summarize_bills.py and extract_bill_topics.py

This commit is contained in:
2026-04-28 22:44:29 -04:00
committed by Richie
parent 2facb82bd4
commit e1beffef12
4 changed files with 1566 additions and 0 deletions
@@ -0,0 +1,574 @@
"""Calculate legislator topic scores from bill topics and roll-call votes."""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Annotated, Sequence
import typer
from sqlalchemy import (
ColumnElement,
Integer,
Select,
and_,
case,
cast,
delete,
extract,
func,
or_,
select,
tuple_,
)
from sqlalchemy.orm import Session
from pipelines.congress_vote_context import create_score_run, finalize_score_run
from pipelines.orm.common import get_postgres_engine
from pipelines.orm.data_science_dev.congress import (
BillTopic,
BillTopicPosition,
LegislatorScore,
SubjectType,
Vote,
VoteClassification,
VoteEffect,
VoteMeasureLink,
VoteMeasureRole,
VotePositionMeaning,
VoteRelationship,
VoteRecord,
)
from pipelines.pipelines.jobs.extract_bill_topics import normalize_topic_label
from pipelines.web.scoring import (
OPPOSE_POSITIONS,
SUPPORT_POSITIONS,
normalized_position_expression,
)
logger = logging.getLogger(__name__)
DELETE_BATCH_SIZE = 5_000
@dataclass(frozen=True)
class ScoreDiagnostics:
"""Counts for the input stages required to calculate legislator scores."""
bill_topic_rows: int
linked_vote_rows: int
vote_record_rows: int
topic_vote_links: int
scorable_vote_records: int
@dataclass(frozen=True)
class LegislatorScoreInput:
"""One aggregated score ready to store in legislator_score."""
legislator_id: int
year: int
topic: str
score: float
supportive: int
opposed: int
def create_legislator_score_query(
*,
congress: int | None = None,
bill_ids: Sequence[int] | None = None,
topics: Sequence[str] | None = None,
) -> Select:
"""Build the aggregate score query from extracted bill topics and vote records."""
normalized_vote = normalized_position_expression(VoteRecord.position)
supportive_vote = _supportive_vote_expression(normalized_vote)
opposed_vote = _opposed_vote_expression(normalized_vote)
supportive_count = func.sum(supportive_vote)
opposed_count = func.sum(opposed_vote)
total_count = supportive_count + opposed_count
vote_year = cast(extract("year", Vote.vote_date), Integer)
score = (100.0 * supportive_count / func.nullif(total_count, 0)).label("score")
stmt = (
select(
VoteRecord.legislator_id.label("legislator_id"),
vote_year.label("year"),
BillTopic.topic.label("topic"),
score,
supportive_count.label("supportive"),
opposed_count.label("opposed"),
total_count.label("total"),
)
.select_from(VoteRecord)
.join(Vote, Vote.id == VoteRecord.vote_id)
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
.join(VotePositionMeaning, VotePositionMeaning.vote_id == Vote.id)
.join(
VoteMeasureLink,
and_(
VoteMeasureLink.vote_id == Vote.id,
VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
),
)
.join(BillTopic, BillTopic.bill_id == VoteMeasureLink.measure_id)
.where(
*_eligible_vote_filters(),
_is_scorable_position(normalized_vote),
)
.group_by(VoteRecord.legislator_id, vote_year, BillTopic.topic)
.having(total_count > 0)
.order_by(VoteRecord.legislator_id, vote_year, BillTopic.topic)
)
if congress is not None:
stmt = stmt.where(Vote.congress == congress)
if bill_ids:
stmt = stmt.where(VoteMeasureLink.measure_id.in_(list(bill_ids)))
normalized_topics = _normalize_topics(topics)
if normalized_topics:
stmt = stmt.where(BillTopic.topic.in_(normalized_topics))
return stmt
def collect_legislator_scores(
session: Session,
*,
congress: int | None = None,
bill_ids: Sequence[int] | None = None,
topics: Sequence[str] | None = None,
) -> list[LegislatorScoreInput]:
"""Run the aggregate query and return score rows."""
rows = session.execute(
create_legislator_score_query(
congress=congress,
bill_ids=bill_ids,
topics=topics,
)
)
return [
LegislatorScoreInput(
legislator_id=int(row.legislator_id),
year=int(row.year),
topic=str(row.topic),
score=float(row.score),
supportive=int(row.supportive),
opposed=int(row.opposed),
)
for row in rows
if row.score is not None
]
def collect_score_diagnostics(
session: Session,
*,
congress: int | None = None,
bill_ids: Sequence[int] | None = None,
topics: Sequence[str] | None = None,
) -> ScoreDiagnostics:
"""Count score pipeline inputs for explaining empty score runs."""
normalized_topics = _normalize_topics(topics)
vote_filters = _vote_scope_filters(congress=congress, bill_ids=bill_ids)
topic_filters = _topic_scope_filters(bill_ids=bill_ids, topics=normalized_topics)
normalized_vote = normalized_position_expression(VoteRecord.position)
eligible_vote_filters = _eligible_vote_filters()
bill_topic_rows = session.scalar(
select(func.count(BillTopic.id)).where(*topic_filters)
)
linked_vote_rows = session.scalar(
select(func.count(func.distinct(Vote.id)))
.select_from(Vote)
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
.join(
VoteMeasureLink,
and_(
VoteMeasureLink.vote_id == Vote.id,
VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
),
)
.where(*vote_filters, *eligible_vote_filters)
)
vote_record_rows = session.scalar(
select(func.count())
.select_from(VoteRecord)
.join(Vote, Vote.id == VoteRecord.vote_id)
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
.where(*vote_filters, *eligible_vote_filters)
)
topic_vote_links = session.scalar(
select(func.count())
.select_from(VoteRecord)
.join(Vote, Vote.id == VoteRecord.vote_id)
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
.join(VotePositionMeaning, VotePositionMeaning.vote_id == Vote.id)
.join(
VoteMeasureLink,
and_(
VoteMeasureLink.vote_id == Vote.id,
VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
),
)
.join(BillTopic, BillTopic.bill_id == VoteMeasureLink.measure_id)
.where(*vote_filters, *topic_filters, *eligible_vote_filters)
)
scorable_vote_records = session.scalar(
select(func.count())
.select_from(VoteRecord)
.join(Vote, Vote.id == VoteRecord.vote_id)
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
.join(VotePositionMeaning, VotePositionMeaning.vote_id == Vote.id)
.join(
VoteMeasureLink,
and_(
VoteMeasureLink.vote_id == Vote.id,
VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
),
)
.join(BillTopic, BillTopic.bill_id == VoteMeasureLink.measure_id)
.where(
*vote_filters,
*topic_filters,
*eligible_vote_filters,
_is_scorable_position(normalized_vote),
)
)
return ScoreDiagnostics(
bill_topic_rows=bill_topic_rows or 0,
linked_vote_rows=linked_vote_rows or 0,
vote_record_rows=vote_record_rows or 0,
topic_vote_links=topic_vote_links or 0,
scorable_vote_records=scorable_vote_records or 0,
)
def store_legislator_scores(
session: Session,
rows: Sequence[LegislatorScoreInput],
*,
score_run_id: int | None,
replace_all: bool = False,
) -> int:
"""Replace matching score rows and insert the newly calculated scores."""
if replace_all:
session.execute(delete(LegislatorScore))
elif rows:
keys = [
(row.legislator_id, row.year, row.topic)
for row in rows
]
for key_batch in _batched(keys, DELETE_BATCH_SIZE):
session.execute(
delete(LegislatorScore).where(
tuple_(
LegislatorScore.legislator_id,
LegislatorScore.year,
LegislatorScore.topic,
).in_(key_batch)
)
)
session.add_all(
[
LegislatorScore(
legislator_id=row.legislator_id,
year=row.year,
topic=row.topic,
score=row.score,
score_run_id=score_run_id,
)
for row in rows
]
)
return len(rows)
def _supportive_vote_expression(
normalized_vote: ColumnElement[str | None],
) -> ColumnElement[int]:
supports_text = _position_matches_effect(normalized_vote, VoteEffect.SUPPORTS_TEXT)
opposes_text = _position_matches_effect(normalized_vote, VoteEffect.OPPOSES_TEXT)
return case(
(
and_(
BillTopic.support_position == BillTopicPosition.FOR,
supports_text,
),
1,
),
(
and_(
BillTopic.support_position == BillTopicPosition.AGAINST,
opposes_text,
),
1,
),
else_=0,
)
def _opposed_vote_expression(
normalized_vote: ColumnElement[str | None],
) -> ColumnElement[int]:
supports_text = _position_matches_effect(normalized_vote, VoteEffect.SUPPORTS_TEXT)
opposes_text = _position_matches_effect(normalized_vote, VoteEffect.OPPOSES_TEXT)
return case(
(
and_(
BillTopic.support_position == BillTopicPosition.FOR,
opposes_text,
),
1,
),
(
and_(
BillTopic.support_position == BillTopicPosition.AGAINST,
supports_text,
),
1,
),
else_=0,
)
def _position_matches_effect(
normalized_vote: ColumnElement[str | None],
effect: VoteEffect,
) -> ColumnElement[bool]:
return or_(
and_(
normalized_vote.in_(sorted(SUPPORT_POSITIONS)),
VotePositionMeaning.yea_effect == effect,
),
and_(
normalized_vote.in_(sorted(OPPOSE_POSITIONS)),
VotePositionMeaning.nay_effect == effect,
),
and_(
normalized_vote == "present",
VotePositionMeaning.present_effect == effect,
),
)
def _is_scorable_position(normalized_vote: ColumnElement[str | None]) -> ColumnElement[bool]:
return or_(
_position_matches_effect(normalized_vote, VoteEffect.SUPPORTS_TEXT),
_position_matches_effect(normalized_vote, VoteEffect.OPPOSES_TEXT),
)
def _normalize_topics(topics: Sequence[str] | None) -> list[str]:
normalized: list[str] = []
seen: set[str] = set()
for topic in topics or []:
value = normalize_topic_label(topic)
if value and value not in seen:
normalized.append(value)
seen.add(value)
return normalized
def _batched[T](items: Sequence[T], batch_size: int) -> list[Sequence[T]]:
return [
items[index : index + batch_size]
for index in range(0, len(items), batch_size)
]
def _vote_scope_filters(
*,
congress: int | None,
bill_ids: Sequence[int] | None,
) -> list[ColumnElement[bool]]:
filters: list[ColumnElement[bool]] = []
if congress is not None:
filters.append(Vote.congress == congress)
if bill_ids:
filters.append(VoteMeasureLink.measure_id.in_(list(bill_ids)))
return filters
def _topic_scope_filters(
*,
bill_ids: Sequence[int] | None,
topics: Sequence[str],
) -> list[ColumnElement[bool]]:
filters: list[ColumnElement[bool]] = []
if bill_ids:
filters.append(BillTopic.bill_id.in_(list(bill_ids)))
if topics:
filters.append(BillTopic.topic.in_(list(topics)))
return filters
def _has_score_scope(
*,
congress: int | None,
bill_ids: Sequence[int] | None,
topics: Sequence[str] | None,
) -> bool:
return congress is not None or bool(bill_ids) or bool(topics)
def _eligible_vote_filters() -> list[ColumnElement[bool]]:
return [
VoteClassification.subject_type == SubjectType.MEASURE,
VoteClassification.vote_relationship == VoteRelationship.DIRECT_TEXT_VOTE,
VoteClassification.is_direct_vote_on_legislative_text.is_(True),
VoteClassification.is_substantive_policy_vote.is_(True),
VoteClassification.is_special_rule.is_(False),
]
def main(
congress: Annotated[
int | None,
typer.Option(help="Only score votes from one Congress."),
] = None,
bill_ids: Annotated[
list[int] | None,
typer.Option(
"--bill-id",
help="Only score votes linked to one internal bill.id. Repeatable.",
),
] = None,
topics: Annotated[
list[str] | None,
typer.Option("--topic", help="Only score one normalized topic. Repeatable."),
] = None,
replace_all: Annotated[
bool,
typer.Option(
help="Delete every existing legislator score before inserting. "
"Unfiltered runs do this automatically."
),
] = False,
dry_run: Annotated[
bool,
typer.Option(help="Calculate scores without writing to the database."),
] = False,
log_level: Annotated[str, typer.Option(help="Log level.")] = "INFO",
diagnose: Annotated[
bool,
typer.Option(help="Log input-stage counts even when rows are calculated."),
] = False,
) -> None:
"""CLI entrypoint for calculating and storing legislator topic scores."""
logging.basicConfig(
level=log_level,
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
with Session(engine) as session:
rows = collect_legislator_scores(
session,
congress=congress,
bill_ids=bill_ids,
topics=topics,
)
logger.info("Calculated %d legislator topic score rows", len(rows))
if diagnose or not rows:
diagnostics = collect_score_diagnostics(
session,
congress=congress,
bill_ids=bill_ids,
topics=topics,
)
_log_diagnostics(diagnostics)
if dry_run:
session.rollback()
return
score_run = create_score_run(session)
should_replace_all = replace_all or not _has_score_scope(
congress=congress,
bill_ids=bill_ids,
topics=topics,
)
written = store_legislator_scores(
session,
rows,
score_run_id=score_run.id,
replace_all=should_replace_all,
)
included_vote_count = session.scalar(
select(func.count(func.distinct(Vote.id)))
.select_from(VoteRecord)
.join(Vote, Vote.id == VoteRecord.vote_id)
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
.join(VotePositionMeaning, VotePositionMeaning.vote_id == Vote.id)
.join(
VoteMeasureLink,
and_(
VoteMeasureLink.vote_id == Vote.id,
VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
),
)
.join(BillTopic, BillTopic.bill_id == VoteMeasureLink.measure_id)
.where(
*_vote_scope_filters(congress=congress, bill_ids=bill_ids),
*_topic_scope_filters(bill_ids=bill_ids, topics=_normalize_topics(topics)),
*_eligible_vote_filters(),
_is_scorable_position(normalized_position_expression(VoteRecord.position)),
)
) or 0
total_scoped_votes = session.scalar(
select(func.count(func.distinct(Vote.id)))
.select_from(Vote)
.join(VoteClassification, VoteClassification.vote_id == Vote.id)
.join(
VoteMeasureLink,
and_(
VoteMeasureLink.vote_id == Vote.id,
VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
),
)
.where(*_vote_scope_filters(congress=congress, bill_ids=bill_ids))
) or 0
finalize_score_run(
session,
score_run=score_run,
included_vote_count=included_vote_count,
excluded_vote_count=max(total_scoped_votes - included_vote_count, 0),
)
session.commit()
logger.info("Stored %d legislator topic score rows", written)
def _log_diagnostics(diagnostics: ScoreDiagnostics) -> None:
logger.info(
"Score input diagnostics: bill_topic_rows=%d linked_vote_rows=%d "
"vote_record_rows=%d topic_vote_links=%d scorable_vote_records=%d",
diagnostics.bill_topic_rows,
diagnostics.linked_vote_rows,
diagnostics.vote_record_rows,
diagnostics.topic_vote_links,
diagnostics.scorable_vote_records,
)
if diagnostics.bill_topic_rows == 0:
logger.warning(
"No extracted bill topics matched the score scope. Run "
"pipelines.tools.extract_bill_topics after bill summarization."
)
elif diagnostics.linked_vote_rows == 0:
logger.warning("No direct substantive text votes matched the score scope.")
elif diagnostics.vote_record_rows == 0:
logger.warning("No individual vote records matched the score scope.")
elif diagnostics.topic_vote_links == 0:
logger.warning(
"Bill topics exist, but none are attached to bills that have eligible scored votes."
)
elif diagnostics.scorable_vote_records == 0:
logger.warning(
"Topic-vote links exist, but no joined vote records had Yea/Aye/Yes/Nay/No positions."
)
if __name__ == "__main__":
typer.run(main)
+682
View File
@@ -0,0 +1,682 @@
"""Extract bill topics from bill text using a configurable topic catalog."""
from __future__ import annotations
import json
import logging
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Annotated, Any, Sequence
import httpx
import typer
from sqlalchemy import ColumnElement, Select, delete, exists, func, select
from sqlalchemy.orm import Session, selectinload
from pipelines.config import OpenAIConfig, get_config_dir, get_openai_config
from pipelines.orm.common import get_postgres_engine
from pipelines.orm.data_science_dev.congress import (
Bill,
BillText,
BillTopic,
BillTopicPosition,
SubjectType,
VoteClassification,
VoteRelationship,
VoteTextTarget,
)
logger = logging.getLogger(__name__)
OPENAI_PROJECT_ID = "proj_fQBPEXFgnS87Fk6wZwploFwE"
OPENAI_CHAT_COMPLETIONS_URL = "https://api.openai.com/v1/chat/completions"
REQUEST_TIMEOUT_SECONDS = 60
DEFAULT_TOPICS_PATH = get_config_dir() / "congressional_issues_comprehensive.json"
class TopicExtractionError(RuntimeError):
"""Raised when a topic extraction request or response is invalid."""
@dataclass(frozen=True)
class TopicCatalog:
"""Loaded topic catalog with categories for prompting and flat candidates."""
topics_by_category: dict[str, list[str]]
candidate_topics: list[str]
@dataclass(frozen=True)
class TopicExtractionDiagnostics:
"""Counts for the bill summary inputs needed by topic extraction."""
bill_rows: int
bill_text_rows: int
summarized_bill_text_rows: int
bills_with_summaries: int
bill_topic_rows: int
selected_bills: int
@dataclass(frozen=True)
class ExtractedBillTopic:
"""One extracted bill topic and yes-vote stance."""
topic: str
support_position: BillTopicPosition
confidence: float | None = None
evidence: str | None = None
def _select_bill_text_for_topic_extraction(bill: Bill) -> BillText | None:
"""Pick one summarized bill_text row from the already-loaded relationship."""
for bill_text in bill.bill_texts:
if bill_text.summary and bill_text.summary.strip():
return bill_text
return None
def normalize_topic_label(value: str) -> str:
"""Normalize a topic label for storage, comparison, and de-duping."""
normalized = value.strip().strip("\"'")
normalized = normalized.strip().rstrip(".").strip()
return re.sub(r"\s+", " ", normalized).lower()
def load_topic_catalog(path: Path | None = None) -> TopicCatalog:
"""Load, validate, normalize, and flatten the bill topic catalog."""
topics_path = path or DEFAULT_TOPICS_PATH
try:
raw = json.loads(topics_path.read_text())
except FileNotFoundError as exc:
msg = f"Topic catalog not found: {topics_path}"
raise TopicExtractionError(msg) from exc
except json.JSONDecodeError as exc:
msg = f"Topic catalog is not valid JSON: {topics_path}: {exc}"
raise TopicExtractionError(msg) from exc
if not isinstance(raw, dict):
msg = "Topic catalog root must be an object mapping category names to lists"
raise TopicExtractionError(msg)
topics_by_category: dict[str, list[str]] = {}
candidate_topics: list[str] = []
seen_topics: set[str] = set()
for category, topics in raw.items():
if not isinstance(category, str) or not category.strip():
msg = "Topic catalog category names must be non-empty strings"
raise TopicExtractionError(msg)
if not isinstance(topics, list):
msg = f"Topic catalog category {category!r} must contain a list"
raise TopicExtractionError(msg)
normalized_topics: list[str] = []
for topic in topics:
if not isinstance(topic, str):
msg = f"Topic catalog category {category!r} contains a non-string topic"
raise TopicExtractionError(msg)
normalized_topic = normalize_topic_label(topic)
if not normalized_topic:
msg = f"Topic catalog category {category!r} contains a blank topic"
raise TopicExtractionError(msg)
if normalized_topic in seen_topics:
continue
seen_topics.add(normalized_topic)
normalized_topics.append(normalized_topic)
candidate_topics.append(normalized_topic)
topics_by_category[category.strip()] = normalized_topics
return TopicCatalog(
topics_by_category=topics_by_category,
candidate_topics=candidate_topics,
)
def build_topic_extraction_messages(
*,
bill: Bill,
bill_text: str,
candidate_topics: Sequence[str],
) -> list[dict[str, str]]:
"""Build GPT messages for extracting a bill's scored topics."""
normalized_candidates = [normalize_topic_label(topic) for topic in candidate_topics]
candidate_list = "\n".join(f"- {topic}" for topic in normalized_candidates)
metadata = "\n".join(
(
f"Congress: {bill.congress}",
f"Bill: {bill.bill_type} {bill.number}",
f"Title: {bill.title_short or bill.title or bill.official_title or ''}",
f"Top subject term: {bill.subjects_top_term or ''}",
)
)
system_prompt = (
"You extract policy topics from U.S. congressional bills.\n"
'For each selected topic, decide whether a Yes/Yea vote on the bill is "for" or "against" that topic.\n'
'Use "support_position": "for" when a Yes/Yea vote advances or supports the topic.\n'
'Use "support_position": "against" when a Yes/Yea vote restricts, repeals, blocks, or opposes the topic.\n'
"Select only topics from the provided candidate topic list.\n"
"Omit topics that are not materially addressed by the bill.\n"
"Return strict JSON only, with this shape:\n"
'{"topics":[{"topic":"candidate topic","support_position":"for","confidence":0.0,"evidence":"short reason"}]}'
)
user_prompt = "\n\n".join(
(
"BILL METADATA:",
metadata,
"CANDIDATE TOPICS:",
candidate_list,
"BILL TEXT:",
bill_text,
)
)
return [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
def call_openai_topic_extraction(
*,
openai_config: OpenAIConfig,
messages: list[dict[str, str]],
) -> str:
"""Call GPT and return the assistant message content."""
response = httpx.post(
openai_config.openai_chat_completions_url,
headers={
"Authorization": f"Bearer {openai_config.api_key}",
"OpenAI-Project": openai_config.openai_project_id,
"Content-Type": "application/json",
},
json={
"model": "gpt-5.4-mini",
"messages": messages,
},
timeout=openai_config.timeout_seconds,
)
response.raise_for_status()
return extract_message_content(response.json())
def extract_message_content(data: dict[str, Any]) -> str:
"""Extract message content from a chat-completions response body."""
choices = data.get("choices")
if not isinstance(choices, list) or not choices:
msg = "Chat completion response did not contain choices"
raise TopicExtractionError(msg)
first = choices[0]
if not isinstance(first, dict):
msg = "Chat completion choice must be an object"
raise TopicExtractionError(msg)
message = first.get("message")
if isinstance(message, dict) and isinstance(message.get("content"), str):
return message["content"]
if isinstance(first.get("text"), str):
return first["text"]
msg = "Chat completion response did not contain message content"
raise TopicExtractionError(msg)
def parse_topic_extraction_response(response_text: str) -> list[ExtractedBillTopic]:
"""Parse, normalize, validate, and de-dupe a topic extraction response."""
payload = _load_json_response(response_text)
topics = payload.get("topics")
if not isinstance(topics, list):
msg = "Topic extraction response must contain a topics list"
raise TopicExtractionError(msg)
deduped: dict[tuple[str, BillTopicPosition], ExtractedBillTopic] = {}
for item in topics:
if not isinstance(item, dict):
msg = "Topic extraction response topics must be objects"
raise TopicExtractionError(msg)
raw_topic = _extract_topic_label(item)
topic = normalize_topic_label(raw_topic)
if not topic:
msg = "Topic extraction response topic must not be blank"
raise TopicExtractionError(msg)
raw_position = item.get("support_position")
try:
support_position = BillTopicPosition(raw_position)
except ValueError as exc:
msg = f"Invalid support_position: {raw_position!r}"
raise TopicExtractionError(msg) from exc
confidence = _parse_confidence(item.get("confidence"))
evidence = item.get("evidence")
if evidence is not None and not isinstance(evidence, str):
evidence = str(evidence)
extracted = ExtractedBillTopic(
topic=topic,
support_position=support_position,
confidence=confidence,
evidence=evidence,
)
key = (topic, support_position)
existing = deduped.get(key)
if existing is None or _confidence_rank(extracted) > _confidence_rank(existing):
deduped[key] = extracted
return list(deduped.values())
def extract_topics_for_bill_text(
*,
openai_config: OpenAIConfig,
bill: Bill,
text: str,
candidate_topics: Sequence[str],
) -> list[ExtractedBillTopic]:
"""Extract accepted catalog topics for a bill text string."""
normalized_candidates = {normalize_topic_label(topic) for topic in candidate_topics}
messages = build_topic_extraction_messages(
bill=bill,
bill_text=text,
candidate_topics=sorted(normalized_candidates),
)
response_text = call_openai_topic_extraction(
openai_config=openai_config,
messages=messages,
)
extracted_topics = parse_topic_extraction_response(response_text)
return [topic for topic in extracted_topics if topic.topic in normalized_candidates]
def store_bill_topic_result(
*,
session: Session,
bill: Bill,
topics: Sequence[ExtractedBillTopic],
replace_existing: bool = True,
) -> None:
"""Store extracted topics for one bill."""
if replace_existing:
session.execute(delete(BillTopic).where(BillTopic.bill_id == bill.id))
for topic in topics:
session.add(
BillTopic(
bill_id=bill.id,
topic=normalize_topic_label(topic.topic),
support_position=topic.support_position,
)
)
def create_select_bills_for_topic_extraction(
congress: int | None = None,
bill_ids: list[int] | None = None,
bill_text_ids: list[int] | None = None,
with_votes_only: bool = False,
force: bool = False,
limit: int | None = None,
) -> Select[tuple[Bill]]:
"""Select bill rows that have summarized bill_text rows for topic extraction."""
has_summary = (BillText.summary.is_not(None), BillText.summary != "")
summarized_text_filters: list[ColumnElement[bool]] = [
BillText.bill_id == Bill.id,
*has_summary,
]
if with_votes_only:
summarized_text_filters.append(
exists(
select(VoteTextTarget.vote_id)
.join(
VoteClassification,
VoteClassification.vote_id == VoteTextTarget.vote_id,
)
.where(
VoteTextTarget.voted_text_version_id == BillText.id,
VoteClassification.subject_type == SubjectType.MEASURE,
VoteClassification.vote_relationship
== VoteRelationship.DIRECT_TEXT_VOTE,
VoteClassification.is_direct_vote_on_legislative_text.is_(True),
VoteClassification.is_substantive_policy_vote.is_(True),
VoteClassification.is_special_rule.is_(False),
)
)
)
summarized_text_exists = exists(select(BillText.id).where(*summarized_text_filters))
stmt = (
select(Bill)
.where(summarized_text_exists)
.options(selectinload(Bill.bill_texts.and_(*summarized_text_filters[1:])))
.order_by(Bill.id)
)
if congress is not None:
stmt = stmt.where(Bill.congress == congress)
if bill_ids:
stmt = stmt.where(Bill.id.in_(bill_ids))
if bill_text_ids:
selected_text_exists = exists(
select(BillText.id).where(
BillText.bill_id == Bill.id,
BillText.id.in_(bill_text_ids),
*summarized_text_filters[1:],
)
)
stmt = stmt.where(selected_text_exists)
if not force:
stmt = stmt.where(
~exists(select(BillTopic.id).where(BillTopic.bill_id == Bill.id))
)
if limit is not None:
stmt = stmt.limit(limit)
return stmt
def collect_topic_extraction_diagnostics(
session: Session,
*,
congress: int | None = None,
bill_ids: list[int] | None = None,
bill_text_ids: list[int] | None = None,
with_votes_only: bool = False,
force: bool = False,
limit: int | None = None,
) -> TopicExtractionDiagnostics:
"""Count topic extraction inputs for explaining empty selections."""
bill_filters = []
bill_text_filters: list[ColumnElement[bool]] = []
if congress is not None:
bill_filters.append(Bill.congress == congress)
if bill_ids:
bill_filters.append(Bill.id.in_(bill_ids))
bill_text_filters.append(BillText.bill_id.in_(bill_ids))
if bill_text_ids:
bill_text_filters.append(BillText.id.in_(bill_text_ids))
if with_votes_only:
bill_text_filters.append(
exists(
select(VoteTextTarget.vote_id)
.join(
VoteClassification,
VoteClassification.vote_id == VoteTextTarget.vote_id,
)
.where(
VoteTextTarget.voted_text_version_id == BillText.id,
VoteClassification.subject_type == SubjectType.MEASURE,
VoteClassification.vote_relationship
== VoteRelationship.DIRECT_TEXT_VOTE,
VoteClassification.is_direct_vote_on_legislative_text.is_(True),
VoteClassification.is_substantive_policy_vote.is_(True),
VoteClassification.is_special_rule.is_(False),
)
)
)
has_summary = (BillText.summary.is_not(None), BillText.summary != "")
summary_filters = [*bill_text_filters, *has_summary]
bills_with_summaries = session.scalar(
select(func.count(func.distinct(Bill.id)))
.select_from(Bill)
.join(BillText, BillText.bill_id == Bill.id)
.where(*bill_filters, *summary_filters)
)
selected_bills = session.scalar(
select(func.count()).select_from(
create_select_bills_for_topic_extraction(
congress=congress,
bill_ids=bill_ids,
bill_text_ids=bill_text_ids,
with_votes_only=with_votes_only,
force=force,
limit=limit,
).subquery()
)
)
return TopicExtractionDiagnostics(
bill_rows=session.scalar(select(func.count(Bill.id)).where(*bill_filters)) or 0,
bill_text_rows=_count_bill_texts(
session,
bill_filters=bill_filters,
bill_text_filters=bill_text_filters,
),
summarized_bill_text_rows=_count_bill_texts(
session,
bill_filters=bill_filters,
bill_text_filters=summary_filters,
),
bills_with_summaries=bills_with_summaries or 0,
bill_topic_rows=session.scalar(select(func.count(BillTopic.id))) or 0,
selected_bills=selected_bills or 0,
)
def _load_json_response(response_text: str) -> dict[str, Any]:
text = response_text.strip()
fenced = re.fullmatch(r"```(?:json)?\s*(.*?)\s*```", text, flags=re.DOTALL)
if fenced:
text = fenced.group(1).strip()
try:
payload = json.loads(text)
except json.JSONDecodeError as exc:
msg = f"Topic extraction response is not valid JSON: {exc}"
raise TopicExtractionError(msg) from exc
if not isinstance(payload, dict):
msg = "Topic extraction response must be a JSON object"
raise TopicExtractionError(msg)
return payload
def _parse_confidence(raw: Any) -> float | None:
if raw is None:
return None
try:
return float(raw)
except (TypeError, ValueError) as exc:
msg = f"Invalid confidence: {raw!r}"
raise TopicExtractionError(msg) from exc
def _confidence_rank(topic: ExtractedBillTopic) -> tuple[int, float]:
if topic.confidence is None:
return (0, 0.0)
return (1, topic.confidence)
def _extract_topic_label(item: dict[str, Any]) -> str:
raw_topic = item.get("topic")
if isinstance(raw_topic, str):
return raw_topic
if isinstance(raw_topic, dict):
for key in ("topic", "label", "name", "title"):
value = raw_topic.get(key)
if isinstance(value, str):
return value
msg = "Topic extraction response topic must be a string"
raise TopicExtractionError(msg)
def _count_bill_texts(
session: Session,
*,
bill_filters: Sequence[ColumnElement[bool]],
bill_text_filters: Sequence[ColumnElement[bool]],
) -> int:
stmt = select(func.count(BillText.id))
if bill_filters:
stmt = stmt.join(Bill, Bill.id == BillText.bill_id).where(*bill_filters)
return session.scalar(stmt.where(*bill_text_filters)) or 0
def main(
topics_path: Annotated[
Path, typer.Option(help="Path to congressional issue topic JSON.")
] = DEFAULT_TOPICS_PATH,
congress: Annotated[
int | None, typer.Option(help="Only process one Congress.")
] = None,
bill_ids: Annotated[
list[int] | None,
typer.Option(
"--bill-id",
help="Only process one internal bill.id. Repeat for multiple bills.",
),
] = None,
bill_text_ids: Annotated[
list[int] | None,
typer.Option(
"--bill-text-id",
help="Only process one internal bill_text.id. Repeat for multiple rows.",
),
] = None,
with_votes_only: Annotated[
bool,
typer.Option(
"--with-votes-only",
help="Only process summarized bill_text rows linked to at least one vote.",
),
] = True,
limit: Annotated[int | None, typer.Option(help="Maximum rows to process.")] = None,
force: Annotated[
bool,
typer.Option(help="Regenerate topics for bills that already have topics."),
] = False,
dry_run: Annotated[
bool,
typer.Option(help="Select bills and print diagnostics without calling OpenAI."),
] = False,
diagnose: Annotated[
bool,
typer.Option(help="Log input-stage counts before processing."),
] = False,
log_level: Annotated[str, typer.Option(help="Log level.")] = "INFO",
) -> None:
"""CLI entrypoint for generating and storing bill topics."""
logging.basicConfig(
level=log_level,
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
topic_catalog = load_topic_catalog(topics_path)
logger.info(
"Loaded %d candidate topics from %s",
len(topic_catalog.candidate_topics),
topics_path,
)
engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
with Session(engine) as session:
if diagnose or dry_run:
diagnostics = collect_topic_extraction_diagnostics(
session,
congress=congress,
bill_ids=bill_ids,
bill_text_ids=bill_text_ids,
with_votes_only=with_votes_only,
force=force,
limit=limit,
)
_log_topic_extraction_diagnostics(diagnostics)
if dry_run:
return
openai_config = get_openai_config()
stmt = create_select_bills_for_topic_extraction(
congress=congress,
bill_ids=bill_ids,
bill_text_ids=bill_text_ids,
with_votes_only=with_votes_only,
force=force,
limit=limit,
)
bills = session.scalars(stmt).all()
logger.info("Selected %d bills for topic extraction", len(bills))
written = 0
failed = 0
for index, bill in enumerate(bills, 1):
bill_text = _select_bill_text_for_topic_extraction(bill)
if bill_text is None:
logger.warning("Skipping bill id=%s: no usable summary", bill.id)
continue
summary = bill_text.summary.strip()
try:
extracted_topics = extract_topics_for_bill_text(
openai_config=openai_config,
bill=bill,
text=summary,
candidate_topics=topic_catalog.candidate_topics,
)
except (httpx.HTTPError, TopicExtractionError):
failed += 1
logger.exception(
"Skipping bill id=%s after topic extraction failure", bill.id
)
continue
store_bill_topic_result(
session=session,
bill=bill,
topics=extracted_topics,
replace_existing=True,
)
written += 1
if index % 100 == 0:
session.commit()
logger.info(
"Stored %d topics for bill id=%s",
len(extracted_topics),
bill.id,
)
session.commit()
logger.info(
"Done: stored topic results for %d bills; failed %d bills",
written,
failed,
)
def _log_topic_extraction_diagnostics(
diagnostics: TopicExtractionDiagnostics,
) -> None:
logger.info(
"Topic extraction diagnostics: bill_rows=%d bill_text_rows=%d "
"summarized_bill_text_rows=%d bills_with_summaries=%d "
"bill_topic_rows=%d selected_bills=%d",
diagnostics.bill_rows,
diagnostics.bill_text_rows,
diagnostics.summarized_bill_text_rows,
diagnostics.bills_with_summaries,
diagnostics.bill_topic_rows,
diagnostics.selected_bills,
)
if diagnostics.bill_rows == 0:
logger.warning("No bills matched the topic extraction scope.")
elif diagnostics.bill_text_rows == 0:
logger.warning("No bill_text rows matched the topic extraction scope.")
elif diagnostics.summarized_bill_text_rows == 0:
logger.warning(
"No summarized bill_text rows matched the topic extraction scope. "
"Run pipelines.tools.summarize_bills first."
)
elif diagnostics.selected_bills == 0 and diagnostics.bill_topic_rows > 0:
logger.warning(
"No bills selected because matching bills already have topics. "
"Use --force to regenerate them."
)
elif diagnostics.selected_bills == 0:
logger.warning("No bills selected for topic extraction.")
if __name__ == "__main__":
typer.run(main)
+309
View File
@@ -0,0 +1,309 @@
"""Summarize bill_text rows with GPT-5 and store results in the database."""
from __future__ import annotations
import logging
import tomllib
from os import getenv
from typing import Annotated, Any
import httpx
import typer
from sqlalchemy import Select, exists, or_, select
from sqlalchemy.orm import Session, selectinload
from tiktoken import get_encoding
from pipelines.config import get_config_dir
from pipelines.orm.common import get_postgres_engine
from pipelines.orm.data_science_dev.congress import (
Bill,
BillText,
SubjectType,
VoteClassification,
VoteRelationship,
VoteTextTarget,
)
from pipelines.tools.bill_token_compression import compress_bill_text
logger = logging.getLogger(__name__)
OPENAI_CHAT_COMPLETIONS_URL = "https://api.openai.com/v1/chat/completions"
OPENAI_PROJECT_ID = "proj_fQBPEXFgnS87Fk6wZwploFwE"
REQUEST_TIMEOUT_SECONDS = 60
def load_summarization_prompts(
section: str = "summarization",
) -> dict[str, str]:
summarization_prompts = get_config_dir() / "prompts" / "summarization_prompts.toml"
return tomllib.loads(summarization_prompts.read_text())[section]
class BillSummaryError(RuntimeError):
"""Raised when a bill summary request or response is invalid."""
def call_openai_summary(
*,
model: str,
messages: list[dict[str, str]],
) -> str:
"""Call GPT and return the assistant message content."""
api_key = getenv("CLOSEDAI_TOKEN")
if not api_key:
msg = "CLOSEDAI_TOKEN is required"
raise BillSummaryError(msg)
response = httpx.post(
OPENAI_CHAT_COMPLETIONS_URL,
headers={
"Authorization": f"Bearer {api_key}",
"OpenAI-Project": OPENAI_PROJECT_ID,
"Content-Type": "application/json",
},
json={
"model": model,
"messages": messages,
},
timeout=REQUEST_TIMEOUT_SECONDS,
)
logger.info(f"{response.text=}")
response.raise_for_status()
return extract_message_content(response.json())
def build_bill_summary_messages(
*,
bill_text: BillText,
summarization_prompts: dict[str, str],
) -> list[dict[str, str]]:
"""Build the GPT prompt messages plus compressed text and user prompt."""
if not bill_text.text_content:
msg = f"bill_text id={bill_text.id} has no text_content"
raise BillSummaryError(msg)
compressed_text = compress_bill_text(bill_text.text_content)
if not compressed_text:
msg = f"bill_text id={bill_text.id} has no summarizable text_content"
raise BillSummaryError(msg)
user_prompt = summarization_prompts["user_template"].format(
text_content=compressed_text
)
user_prompt_tokens = len(get_encoding("o200k_base").encode(user_prompt))
logger.info(f"{user_prompt_tokens=}")
messages = [
{"role": "system", "content": summarization_prompts["system_prompt"]},
{
"role": "user",
"content": user_prompt,
},
]
return messages, user_prompt_tokens
def summarize_bill_text(
*,
model: str,
bill_text: BillText,
summarization_prompts: dict[str, str],
) -> str:
"""Generate and return a summary for one bill_text row."""
messages, user_prompt_tokens = build_bill_summary_messages(
bill_text=bill_text,
summarization_prompts=summarization_prompts,
)
# This may only be for gpt-5.4 mini I need to read the docs
if user_prompt_tokens > 272000:
msg = f"Compressed bill_text id={bill_text.id} is too long for summarization ({user_prompt_tokens} tokens)"
logger.warning(msg)
return None
summary = call_openai_summary(
model=model,
messages=messages,
).strip()
if not summary:
msg = f"Model returned an empty summary for bill_text id={bill_text.id}"
raise BillSummaryError(msg)
return summary
def store_bill_summary_result(
*,
bill_text: BillText,
summary: str,
model: str,
) -> None:
"""Store a generated summary and the prompt/model metadata that produced it."""
bill_text.summary = summary
bill_text.summarization_model = model
bill_text.summarization_system_prompt_version = "v1.2"
bill_text.summarization_user_prompt_version = "v1"
def create_select_bill_texts_for_summarization(
congress: int | None = None,
bill_ids: list[int] | None = None,
bill_text_ids: list[int] | None = None,
with_votes_only: bool = False,
force: bool = False,
limit: int | None = None,
) -> Select:
"""Select bill_text rows that have source text and need summaries."""
stmt = (
select(BillText)
.join(Bill, Bill.id == BillText.bill_id)
.where(BillText.text_content.is_not(None), BillText.text_content != "")
.options(selectinload(BillText.bill))
.order_by(BillText.id)
)
if congress is not None:
stmt = stmt.where(Bill.congress == congress)
if bill_ids:
stmt = stmt.where(BillText.bill_id.in_(bill_ids))
if bill_text_ids:
stmt = stmt.where(BillText.id.in_(bill_text_ids))
if with_votes_only:
stmt = stmt.where(
exists(
select(VoteTextTarget.vote_id)
.join(
VoteClassification,
VoteClassification.vote_id == VoteTextTarget.vote_id,
)
.where(
VoteTextTarget.voted_text_version_id == BillText.id,
VoteClassification.subject_type == SubjectType.MEASURE,
VoteClassification.vote_relationship
== VoteRelationship.DIRECT_TEXT_VOTE,
VoteClassification.is_direct_vote_on_legislative_text.is_(True),
VoteClassification.is_substantive_policy_vote.is_(True),
VoteClassification.is_special_rule.is_(False),
)
)
)
if not force:
stmt = stmt.where(or_(BillText.summary.is_(None), BillText.summary == ""))
if limit is not None:
stmt = stmt.limit(limit)
return stmt
def extract_message_content(data: dict[str, Any]) -> str:
"""Extract message content from a chat-completions response body."""
choices = data.get("choices")
if not isinstance(choices, list) or not choices:
msg = "Chat completion response did not contain choices"
raise BillSummaryError(msg)
first = choices[0]
if not isinstance(first, dict):
msg = "Chat completion choice must be an object"
raise BillSummaryError(msg)
message = first.get("message")
if isinstance(message, dict) and isinstance(message.get("content"), str):
return message["content"]
if isinstance(first.get("text"), str):
return first["text"]
msg = "Chat completion response did not contain message content"
raise BillSummaryError(msg)
def main(
model: Annotated[str, typer.Option(help="OpenAI model id.")] = "gpt-5.4-mini",
congress: Annotated[
int | None, typer.Option(help="Only process one Congress.")
] = None,
bill_ids: Annotated[
list[int] | None,
typer.Option(
"--bill-id",
help="Only process one internal bill.id. Repeat for multiple bills.",
),
] = None,
bill_text_ids: Annotated[
list[int] | None,
typer.Option(
"--bill-text-id",
help="Only process one internal bill_text.id. Repeat for multiple rows.",
),
] = None,
with_votes_only: Annotated[
bool,
typer.Option(
"--with-votes-only",
help="Only process bill_text rows linked to at least one vote.",
),
] = False,
limit: Annotated[int | None, typer.Option(help="Maximum rows to process.")] = None,
force: Annotated[
bool,
typer.Option(help="Regenerate summaries for rows that already have a summary."),
] = False,
dry_run: Annotated[
bool, typer.Option(help="Print summaries without writing them to the database.")
] = False,
log_level: Annotated[str, typer.Option(help="Log level.")] = "INFO",
) -> None:
"""CLI entrypoint for generating and storing bill summaries."""
logging.basicConfig(
level=log_level,
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
if not getenv("CLOSEDAI_TOKEN"):
message = "CLOSEDAI_TOKEN is required"
raise typer.BadParameter(message)
summarization_prompts = load_summarization_prompts()
engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
with Session(engine) as session:
stmt = create_select_bill_texts_for_summarization(
congress=congress,
bill_ids=bill_ids,
bill_text_ids=bill_text_ids,
with_votes_only=with_votes_only,
force=force,
limit=limit,
)
bill_texts = session.scalars(stmt).all()
logger.info("Selected %d bill_text rows for summarization", len(bill_texts))
written = 0
for index, bill_text in enumerate(bill_texts, 1):
summary = summarize_bill_text(
model=model,
bill_text=bill_text,
summarization_prompts=summarization_prompts,
)
if summary is None:
logger.warning("Skipping bill_text id=%s", bill_text.id)
continue
store_bill_summary_result(
bill_text=bill_text,
summary=summary,
model=model,
)
if index % 100 == 0:
session.commit()
written += 1
session.commit()
logger.info("Stored summary for bill_text id=%s", bill_text.id)
logger.info("Done: stored %d summaries", written)
def cli() -> None:
"""Typer entry point."""
typer.run(main)
if __name__ == "__main__":
cli()