Merge pull request 'adding calculate_legislator_scores.py summarize_bills.py and extract_bill_topics.py' (#6) from feature/making-jobs-dir into main
Reviewed-on: #6
This commit was merged in pull request #6.
This commit is contained in:
@@ -0,0 +1 @@
|
||||
"""Prompt benchmarking system for evaluating LLMs via vLLM."""
|
||||
@@ -0,0 +1,574 @@
|
||||
"""Calculate legislator topic scores from bill topics and roll-call votes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, Sequence
|
||||
|
||||
import typer
|
||||
from sqlalchemy import (
|
||||
ColumnElement,
|
||||
Integer,
|
||||
Select,
|
||||
and_,
|
||||
case,
|
||||
cast,
|
||||
delete,
|
||||
extract,
|
||||
func,
|
||||
or_,
|
||||
select,
|
||||
tuple_,
|
||||
)
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from pipelines.congress_vote_context import create_score_run, finalize_score_run
|
||||
from pipelines.orm.common import get_postgres_engine
|
||||
from pipelines.orm.data_science_dev.congress import (
|
||||
BillTopic,
|
||||
BillTopicPosition,
|
||||
LegislatorScore,
|
||||
SubjectType,
|
||||
Vote,
|
||||
VoteClassification,
|
||||
VoteEffect,
|
||||
VoteMeasureLink,
|
||||
VoteMeasureRole,
|
||||
VotePositionMeaning,
|
||||
VoteRelationship,
|
||||
VoteRecord,
|
||||
)
|
||||
from pipelines.pipelines.jobs.extract_bill_topics import normalize_topic_label
|
||||
from pipelines.web.scoring import (
|
||||
OPPOSE_POSITIONS,
|
||||
SUPPORT_POSITIONS,
|
||||
normalized_position_expression,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DELETE_BATCH_SIZE = 5_000
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ScoreDiagnostics:
    """Row counts for each input stage feeding the legislator score query.

    Every field is a plain count produced by ``collect_score_diagnostics``;
    a zero at some stage is what ``_log_diagnostics`` reports on.
    """

    # Count of BillTopic rows within the requested scope.
    bill_topic_rows: int
    # Count of distinct eligible votes linked to a measure.
    linked_vote_rows: int
    # Count of vote records attached to eligible votes.
    vote_record_rows: int
    # Count of vote records joined through to a bill topic.
    topic_vote_links: int
    # Count of topic-joined vote records with a scorable position.
    scorable_vote_records: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class LegislatorScoreInput:
    """A single aggregated (legislator, year, topic) score awaiting storage."""

    # Grouping key of the aggregate query.
    legislator_id: int
    year: int
    topic: str
    # 100 * supportive / (supportive + opposed), as computed by the score query.
    score: float
    # Vote tallies the score was derived from.
    supportive: int
    opposed: int
|
||||
|
||||
|
||||
def create_legislator_score_query(
    *,
    congress: int | None = None,
    bill_ids: Sequence[int] | None = None,
    topics: Sequence[str] | None = None,
) -> Select:
    """Assemble the per-(legislator, year, topic) aggregate score statement.

    Vote records are joined to their vote, its classification and position
    meaning, the measure the vote was on, and that measure's bill topics.
    The score is ``100 * supportive / (supportive + opposed)``.

    Args:
        congress: When not ``None``, restrict to votes of this Congress.
        bill_ids: When non-empty, restrict to votes on these measure ids.
        topics: When any survive normalization, restrict to those topics.

    Returns:
        A ``Select`` yielding legislator_id, year, topic, score, supportive,
        opposed and total columns.
    """
    position = normalized_position_expression(VoteRecord.position)
    supportive_total = func.sum(_supportive_vote_expression(position))
    opposed_total = func.sum(_opposed_vote_expression(position))
    counted_total = supportive_total + opposed_total
    year_of_vote = cast(extract("year", Vote.vote_date), Integer)
    # NULLIF guards the division; HAVING below drops zero-total groups anyway.
    score_column = (100.0 * supportive_total / func.nullif(counted_total, 0)).label("score")

    grouping = (VoteRecord.legislator_id, year_of_vote, BillTopic.topic)
    voted_on_link = and_(
        VoteMeasureLink.vote_id == Vote.id,
        VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
    )

    query = select(
        VoteRecord.legislator_id.label("legislator_id"),
        year_of_vote.label("year"),
        BillTopic.topic.label("topic"),
        score_column,
        supportive_total.label("supportive"),
        opposed_total.label("opposed"),
        counted_total.label("total"),
    ).select_from(VoteRecord)
    query = query.join(Vote, Vote.id == VoteRecord.vote_id)
    query = query.join(VoteClassification, VoteClassification.vote_id == Vote.id)
    query = query.join(VotePositionMeaning, VotePositionMeaning.vote_id == Vote.id)
    query = query.join(VoteMeasureLink, voted_on_link)
    query = query.join(BillTopic, BillTopic.bill_id == VoteMeasureLink.measure_id)
    query = (
        query.where(*_eligible_vote_filters(), _is_scorable_position(position))
        .group_by(*grouping)
        .having(counted_total > 0)
        .order_by(*grouping)
    )

    # Optional scope restrictions, appended in a fixed order.
    scope_conditions: list[ColumnElement[bool]] = []
    if congress is not None:
        scope_conditions.append(Vote.congress == congress)
    if bill_ids:
        scope_conditions.append(VoteMeasureLink.measure_id.in_(list(bill_ids)))
    wanted_topics = _normalize_topics(topics)
    if wanted_topics:
        scope_conditions.append(BillTopic.topic.in_(wanted_topics))
    for condition in scope_conditions:
        query = query.where(condition)

    return query
|
||||
|
||||
|
||||
def collect_legislator_scores(
    session: Session,
    *,
    congress: int | None = None,
    bill_ids: Sequence[int] | None = None,
    topics: Sequence[str] | None = None,
) -> list[LegislatorScoreInput]:
    """Execute the aggregate score query and convert rows to score inputs.

    Rows whose ``score`` is ``None`` are dropped. The scope arguments are
    forwarded unchanged to ``create_legislator_score_query``.
    """
    statement = create_legislator_score_query(
        congress=congress,
        bill_ids=bill_ids,
        topics=topics,
    )
    scores: list[LegislatorScoreInput] = []
    for record in session.execute(statement):
        if record.score is None:
            continue
        scores.append(
            LegislatorScoreInput(
                legislator_id=int(record.legislator_id),
                year=int(record.year),
                topic=str(record.topic),
                score=float(record.score),
                supportive=int(record.supportive),
                opposed=int(record.opposed),
            )
        )
    return scores
|
||||
|
||||
|
||||
def collect_score_diagnostics(
    session: Session,
    *,
    congress: int | None = None,
    bill_ids: Sequence[int] | None = None,
    topics: Sequence[str] | None = None,
) -> ScoreDiagnostics:
    """Count score pipeline inputs for explaining empty score runs.

    Five counts are taken, each a narrower stage of the score query:
    bill topics in scope, distinct eligible votes linked to a measure,
    vote records on eligible votes, vote records joined through to a bill
    topic, and the subset of those with a scorable position.

    Args:
        session: Open SQLAlchemy session used to run the count queries.
        congress: When not ``None``, restrict votes to this Congress.
        bill_ids: When non-empty, restrict to these measure/bill ids.
        topics: When any survive normalization, restrict to those topics.

    Returns:
        A ``ScoreDiagnostics`` with ``None`` scalar results coerced to 0.
    """
    normalized_topics = _normalize_topics(topics)
    vote_filters = _vote_scope_filters(congress=congress, bill_ids=bill_ids)
    topic_filters = _topic_scope_filters(bill_ids=bill_ids, topics=normalized_topics)
    normalized_vote = normalized_position_expression(VoteRecord.position)
    eligible_vote_filters = _eligible_vote_filters()
    voted_on_link = and_(
        VoteMeasureLink.vote_id == Vote.id,
        VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
    )

    bill_topic_rows = session.scalar(
        select(func.count(BillTopic.id)).where(*topic_filters)
    )

    linked_vote_rows = session.scalar(
        select(func.count(func.distinct(Vote.id)))
        .select_from(Vote)
        .join(VoteClassification, VoteClassification.vote_id == Vote.id)
        .join(VoteMeasureLink, voted_on_link)
        .where(*vote_filters, *eligible_vote_filters)
    )

    # This query does not join VoteMeasureLink, so the bill scope must not be
    # applied as a bare column filter: that would add an uncorrelated FROM
    # entry (a cartesian product) and inflate the count. A correlated EXISTS
    # restricts by bill without multiplying vote-record rows.
    record_filters = _vote_scope_filters(congress=congress, bill_ids=None)
    if bill_ids:
        record_filters.append(
            select(VoteMeasureLink.vote_id)
            .where(voted_on_link, VoteMeasureLink.measure_id.in_(list(bill_ids)))
            .exists()
        )
    vote_record_rows = session.scalar(
        select(func.count())
        .select_from(VoteRecord)
        .join(Vote, Vote.id == VoteRecord.vote_id)
        .join(VoteClassification, VoteClassification.vote_id == Vote.id)
        .where(*record_filters, *eligible_vote_filters)
    )

    # The last two stages share the same join chain as the score query.
    topic_joined_records = (
        select(func.count())
        .select_from(VoteRecord)
        .join(Vote, Vote.id == VoteRecord.vote_id)
        .join(VoteClassification, VoteClassification.vote_id == Vote.id)
        .join(VotePositionMeaning, VotePositionMeaning.vote_id == Vote.id)
        .join(VoteMeasureLink, voted_on_link)
        .join(BillTopic, BillTopic.bill_id == VoteMeasureLink.measure_id)
        .where(*vote_filters, *topic_filters, *eligible_vote_filters)
    )
    topic_vote_links = session.scalar(topic_joined_records)
    scorable_vote_records = session.scalar(
        topic_joined_records.where(_is_scorable_position(normalized_vote))
    )

    return ScoreDiagnostics(
        bill_topic_rows=bill_topic_rows or 0,
        linked_vote_rows=linked_vote_rows or 0,
        vote_record_rows=vote_record_rows or 0,
        topic_vote_links=topic_vote_links or 0,
        scorable_vote_records=scorable_vote_records or 0,
    )
|
||||
|
||||
|
||||
def store_legislator_scores(
    session: Session,
    rows: Sequence[LegislatorScoreInput],
    *,
    score_run_id: int | None,
    replace_all: bool = False,
) -> int:
    """Delete superseded score rows, then add the freshly calculated ones.

    Args:
        session: Session the deletes and inserts are issued on (no commit here).
        rows: Scores to store.
        score_run_id: Run identifier stamped on every inserted row.
        replace_all: When true, delete every existing ``LegislatorScore`` row;
            otherwise delete only rows sharing a (legislator, year, topic) key
            with ``rows``, in batches of ``DELETE_BATCH_SIZE``.

    Returns:
        The number of rows added to the session.
    """
    if replace_all:
        session.execute(delete(LegislatorScore))
    else:
        score_key = tuple_(
            LegislatorScore.legislator_id,
            LegislatorScore.year,
            LegislatorScore.topic,
        )
        incoming_keys = [(item.legislator_id, item.year, item.topic) for item in rows]
        # No keys means no batches, so an empty ``rows`` deletes nothing.
        for chunk in _batched(incoming_keys, DELETE_BATCH_SIZE):
            session.execute(delete(LegislatorScore).where(score_key.in_(chunk)))

    new_scores = []
    for item in rows:
        new_scores.append(
            LegislatorScore(
                legislator_id=item.legislator_id,
                year=item.year,
                topic=item.topic,
                score=item.score,
                score_run_id=score_run_id,
            )
        )
    session.add_all(new_scores)
    return len(rows)
|
||||
|
||||
|
||||
def _supportive_vote_expression(
    normalized_vote: ColumnElement[str | None],
) -> ColumnElement[int]:
    """Return a CASE yielding 1 when the vote agrees with the topic, else 0.

    A vote agrees when the bill is FOR the topic and the vote supports the
    text, or the bill is AGAINST the topic and the vote opposes the text.
    """
    for_topic_and_backs_text = and_(
        BillTopic.support_position == BillTopicPosition.FOR,
        _position_matches_effect(normalized_vote, VoteEffect.SUPPORTS_TEXT),
    )
    against_topic_and_blocks_text = and_(
        BillTopic.support_position == BillTopicPosition.AGAINST,
        _position_matches_effect(normalized_vote, VoteEffect.OPPOSES_TEXT),
    )
    return case(
        (for_topic_and_backs_text, 1),
        (against_topic_and_blocks_text, 1),
        else_=0,
    )
|
||||
|
||||
|
||||
def _opposed_vote_expression(
    normalized_vote: ColumnElement[str | None],
) -> ColumnElement[int]:
    """Return a CASE yielding 1 when the vote goes against the topic, else 0.

    A vote goes against the topic when the bill is FOR the topic and the vote
    opposes the text, or the bill is AGAINST the topic and the vote supports
    the text.
    """
    for_topic_and_blocks_text = and_(
        BillTopic.support_position == BillTopicPosition.FOR,
        _position_matches_effect(normalized_vote, VoteEffect.OPPOSES_TEXT),
    )
    against_topic_and_backs_text = and_(
        BillTopic.support_position == BillTopicPosition.AGAINST,
        _position_matches_effect(normalized_vote, VoteEffect.SUPPORTS_TEXT),
    )
    return case(
        (for_topic_and_blocks_text, 1),
        (against_topic_and_backs_text, 1),
        else_=0,
    )
|
||||
|
||||
|
||||
def _position_matches_effect(
    normalized_vote: ColumnElement[str | None],
    effect: VoteEffect,
) -> ColumnElement[bool]:
    """Return a predicate: does the cast position carry ``effect`` for this vote?

    The position is matched against the support set, the oppose set, or the
    literal ``"present"``, and paired with the corresponding
    ``VotePositionMeaning`` effect column.
    """
    position_and_meaning = (
        (normalized_vote.in_(sorted(SUPPORT_POSITIONS)), VotePositionMeaning.yea_effect),
        (normalized_vote.in_(sorted(OPPOSE_POSITIONS)), VotePositionMeaning.nay_effect),
        (normalized_vote == "present", VotePositionMeaning.present_effect),
    )
    branches = [
        and_(position_test, meaning_column == effect)
        for position_test, meaning_column in position_and_meaning
    ]
    return or_(*branches)
|
||||
|
||||
|
||||
def _is_scorable_position(normalized_vote: ColumnElement[str | None]) -> ColumnElement[bool]:
    """Return a predicate true when the position either supports or opposes the text."""
    scorable_effects = (VoteEffect.SUPPORTS_TEXT, VoteEffect.OPPOSES_TEXT)
    return or_(
        *[_position_matches_effect(normalized_vote, effect) for effect in scorable_effects]
    )
|
||||
|
||||
|
||||
def _normalize_topics(topics: Sequence[str] | None) -> list[str]:
    """Normalize topic labels, dropping blanks and duplicates, keeping first-seen order."""
    # A dict keeps insertion order, so it doubles as an ordered set.
    unique_labels: dict[str, None] = {}
    for raw_topic in topics or ():
        label = normalize_topic_label(raw_topic)
        if label:
            unique_labels.setdefault(label, None)
    return list(unique_labels)
|
||||
|
||||
|
||||
def _batched[T](items: Sequence[T], batch_size: int) -> list[Sequence[T]]:
|
||||
return [
|
||||
items[index : index + batch_size]
|
||||
for index in range(0, len(items), batch_size)
|
||||
]
|
||||
|
||||
|
||||
def _vote_scope_filters(
    *,
    congress: int | None,
    bill_ids: Sequence[int] | None,
) -> list[ColumnElement[bool]]:
    """Build the vote-side scope predicates (Congress first, then measure ids).

    A ``None`` congress or an empty/``None`` ``bill_ids`` contributes nothing.
    The measure filter references ``VoteMeasureLink``, so queries using it are
    expected to join that table.
    """
    by_congress = [Vote.congress == congress] if congress is not None else []
    by_measure = [VoteMeasureLink.measure_id.in_(list(bill_ids))] if bill_ids else []
    return [*by_congress, *by_measure]
|
||||
|
||||
|
||||
def _topic_scope_filters(
    *,
    bill_ids: Sequence[int] | None,
    topics: Sequence[str],
) -> list[ColumnElement[bool]]:
    """Build the BillTopic-side scope predicates (bill ids first, then topics).

    Empty or ``None`` inputs contribute no predicate.
    """
    by_bill = [BillTopic.bill_id.in_(list(bill_ids))] if bill_ids else []
    by_topic = [BillTopic.topic.in_(list(topics))] if topics else []
    return [*by_bill, *by_topic]
|
||||
|
||||
|
||||
def _has_score_scope(
|
||||
*,
|
||||
congress: int | None,
|
||||
bill_ids: Sequence[int] | None,
|
||||
topics: Sequence[str] | None,
|
||||
) -> bool:
|
||||
return congress is not None or bool(bill_ids) or bool(topics)
|
||||
|
||||
|
||||
def _eligible_vote_filters() -> list[ColumnElement[bool]]:
    """Return the VoteClassification predicates a vote must satisfy to be scored.

    The vote must be a measure vote, classified as a direct text vote, flagged
    as a direct vote on legislative text, flagged substantive, and not a
    special rule.
    """
    classification = VoteClassification
    filters: list[ColumnElement[bool]] = [
        classification.subject_type == SubjectType.MEASURE,
        classification.vote_relationship == VoteRelationship.DIRECT_TEXT_VOTE,
    ]
    filters.append(classification.is_direct_vote_on_legislative_text.is_(True))
    filters.append(classification.is_substantive_policy_vote.is_(True))
    filters.append(classification.is_special_rule.is_(False))
    return filters
|
||||
|
||||
|
||||
def main(
    congress: Annotated[
        int | None,
        typer.Option(help="Only score votes from one Congress."),
    ] = None,
    bill_ids: Annotated[
        list[int] | None,
        typer.Option(
            "--bill-id",
            help="Only score votes linked to one internal bill.id. Repeatable.",
        ),
    ] = None,
    topics: Annotated[
        list[str] | None,
        typer.Option("--topic", help="Only score one normalized topic. Repeatable."),
    ] = None,
    replace_all: Annotated[
        bool,
        typer.Option(
            help="Delete every existing legislator score before inserting. "
            "Unfiltered runs do this automatically."
        ),
    ] = False,
    dry_run: Annotated[
        bool,
        typer.Option(help="Calculate scores without writing to the database."),
    ] = False,
    log_level: Annotated[str, typer.Option(help="Log level.")] = "INFO",
    diagnose: Annotated[
        bool,
        typer.Option(help="Log input-stage counts even when rows are calculated."),
    ] = False,
) -> None:
    """CLI entrypoint for calculating and storing legislator topic scores.

    Steps: calculate scores for the requested scope, optionally log input
    diagnostics (always when no rows were produced), then, unless ``dry_run``,
    create a score run, store the rows, record included/excluded vote counts
    on the run, and commit.
    """
    logging.basicConfig(
        # logging only recognises upper-case level names; accept e.g. "info".
        level=log_level.upper(),
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
    with Session(engine) as session:
        rows = collect_legislator_scores(
            session,
            congress=congress,
            bill_ids=bill_ids,
            topics=topics,
        )
        logger.info("Calculated %d legislator topic score rows", len(rows))
        if diagnose or not rows:
            diagnostics = collect_score_diagnostics(
                session,
                congress=congress,
                bill_ids=bill_ids,
                topics=topics,
            )
            _log_diagnostics(diagnostics)

        if dry_run:
            session.rollback()
            return

        score_run = create_score_run(session)
        # NOTE(review): an unfiltered run replaces everything, and
        # store_legislator_scores deletes before inserting even when ``rows``
        # is empty, so an empty unfiltered result clears the table. Confirm
        # that this is intended.
        should_replace_all = replace_all or not _has_score_scope(
            congress=congress,
            bill_ids=bill_ids,
            topics=topics,
        )
        written = store_legislator_scores(
            session,
            rows,
            score_run_id=score_run.id,
            replace_all=should_replace_all,
        )

        vote_scope = _vote_scope_filters(congress=congress, bill_ids=bill_ids)
        voted_on_link = and_(
            VoteMeasureLink.vote_id == Vote.id,
            VoteMeasureLink.role == VoteMeasureRole.VOTED_ON,
        )
        # Distinct votes that contributed at least one scorable record.
        included_vote_count = session.scalar(
            select(func.count(func.distinct(Vote.id)))
            .select_from(VoteRecord)
            .join(Vote, Vote.id == VoteRecord.vote_id)
            .join(VoteClassification, VoteClassification.vote_id == Vote.id)
            .join(VotePositionMeaning, VotePositionMeaning.vote_id == Vote.id)
            .join(VoteMeasureLink, voted_on_link)
            .join(BillTopic, BillTopic.bill_id == VoteMeasureLink.measure_id)
            .where(
                *vote_scope,
                *_topic_scope_filters(bill_ids=bill_ids, topics=_normalize_topics(topics)),
                *_eligible_vote_filters(),
                _is_scorable_position(normalized_position_expression(VoteRecord.position)),
            )
        ) or 0
        # All distinct measure-linked votes in scope, eligible or not.
        total_scoped_votes = session.scalar(
            select(func.count(func.distinct(Vote.id)))
            .select_from(Vote)
            .join(VoteClassification, VoteClassification.vote_id == Vote.id)
            .join(VoteMeasureLink, voted_on_link)
            .where(*vote_scope)
        ) or 0
        finalize_score_run(
            session,
            score_run=score_run,
            included_vote_count=included_vote_count,
            excluded_vote_count=max(total_scoped_votes - included_vote_count, 0),
        )
        session.commit()
        logger.info("Stored %d legislator topic score rows", written)
|
||||
|
||||
|
||||
def _log_diagnostics(diagnostics: ScoreDiagnostics) -> None:
    """Log all stage counts, then warn about the first stage that is empty."""
    logger.info(
        "Score input diagnostics: bill_topic_rows=%d linked_vote_rows=%d "
        "vote_record_rows=%d topic_vote_links=%d scorable_vote_records=%d",
        diagnostics.bill_topic_rows,
        diagnostics.linked_vote_rows,
        diagnostics.vote_record_rows,
        diagnostics.topic_vote_links,
        diagnostics.scorable_vote_records,
    )
    # Stages in pipeline order; only the earliest empty one is reported.
    stage_warnings = (
        (
            diagnostics.bill_topic_rows,
            "No extracted bill topics matched the score scope. Run "
            "pipelines.tools.extract_bill_topics after bill summarization.",
        ),
        (
            diagnostics.linked_vote_rows,
            "No direct substantive text votes matched the score scope.",
        ),
        (
            diagnostics.vote_record_rows,
            "No individual vote records matched the score scope.",
        ),
        (
            diagnostics.topic_vote_links,
            "Bill topics exist, but none are attached to bills that have eligible scored votes.",
        ),
        (
            diagnostics.scorable_vote_records,
            "Topic-vote links exist, but no joined vote records had Yea/Aye/Yes/Nay/No positions.",
        ),
    )
    for stage_count, warning_text in stage_warnings:
        if stage_count == 0:
            logger.warning(warning_text)
            break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
typer.run(main)
|
||||
@@ -0,0 +1,682 @@
|
||||
"""Extract bill topics from bill text using a configurable topic catalog."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any, Sequence
|
||||
|
||||
import httpx
|
||||
import typer
|
||||
from sqlalchemy import ColumnElement, Select, delete, exists, func, select
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
from pipelines.config import OpenAIConfig, get_config_dir, get_openai_config
|
||||
from pipelines.orm.common import get_postgres_engine
|
||||
from pipelines.orm.data_science_dev.congress import (
|
||||
Bill,
|
||||
BillText,
|
||||
BillTopic,
|
||||
BillTopicPosition,
|
||||
SubjectType,
|
||||
VoteClassification,
|
||||
VoteRelationship,
|
||||
VoteTextTarget,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OPENAI_PROJECT_ID = "proj_fQBPEXFgnS87Fk6wZwploFwE"
|
||||
OPENAI_CHAT_COMPLETIONS_URL = "https://api.openai.com/v1/chat/completions"
|
||||
REQUEST_TIMEOUT_SECONDS = 60
|
||||
DEFAULT_TOPICS_PATH = get_config_dir() / "congressional_issues_comprehensive.json"
|
||||
|
||||
|
||||
class TopicExtractionError(RuntimeError):
    """Signals an invalid topic catalog, extraction request, or extraction response."""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TopicCatalog:
    """Topic catalog as loaded from disk: grouped by category and flattened."""

    # Category name -> normalized topic labels in that category.
    topics_by_category: dict[str, list[str]]
    # Every normalized topic label once, in catalog order.
    candidate_topics: list[str]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TopicExtractionDiagnostics:
    """Input counts used to explain why topic extraction selected few or no bills."""

    # Row counts of the bill and bill_text inputs.
    bill_rows: int
    bill_text_rows: int
    # bill_text rows that carry a summary, and distinct bills owning one.
    summarized_bill_text_rows: int
    bills_with_summaries: int
    # Existing bill_topic rows.
    bill_topic_rows: int
    # Bills the extraction selection query would return.
    selected_bills: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ExtractedBillTopic:
    """A topic extracted for a bill, with the stance a yes vote takes on it."""

    # Normalized topic label.
    topic: str
    # Whether a yes vote on the bill is for or against the topic.
    support_position: BillTopicPosition
    # Optional model-reported confidence and short justification.
    confidence: float | None = None
    evidence: str | None = None
|
||||
|
||||
|
||||
def _select_bill_text_for_topic_extraction(bill: Bill) -> BillText | None:
|
||||
"""Pick one summarized bill_text row from the already-loaded relationship."""
|
||||
for bill_text in bill.bill_texts:
|
||||
if bill_text.summary and bill_text.summary.strip():
|
||||
return bill_text
|
||||
return None
|
||||
|
||||
|
||||
def normalize_topic_label(value: str) -> str:
    """Canonicalize a topic label so equal topics compare and store identically.

    Strips surrounding whitespace and quote characters, trailing periods,
    collapses internal whitespace runs to one space, and lower-cases.
    """
    unquoted = value.strip().strip("\"'").strip()
    without_trailing_period = unquoted.rstrip(".").strip()
    single_spaced = re.sub(r"\s+", " ", without_trailing_period)
    return single_spaced.lower()
|
||||
|
||||
|
||||
def load_topic_catalog(path: Path | None = None) -> TopicCatalog:
    """Load, validate, normalize, and flatten the bill topic catalog.

    The catalog file must be a JSON object mapping category names to lists of
    topic strings. Topics are normalized with ``normalize_topic_label``; a
    topic already seen (in any category) is skipped silently.

    Args:
        path: Catalog location; falls back to ``DEFAULT_TOPICS_PATH``.

    Returns:
        A ``TopicCatalog`` with per-category topics and the flat candidate list.

    Raises:
        TopicExtractionError: If the file is missing, is not valid JSON, or
            does not have the expected structure.
    """
    topics_path = path or DEFAULT_TOPICS_PATH
    try:
        # Decode explicitly as UTF-8 rather than the platform locale encoding,
        # so non-ASCII topic labels load the same on every host.
        raw = json.loads(topics_path.read_text(encoding="utf-8"))
    except FileNotFoundError as exc:
        msg = f"Topic catalog not found: {topics_path}"
        raise TopicExtractionError(msg) from exc
    except json.JSONDecodeError as exc:
        msg = f"Topic catalog is not valid JSON: {topics_path}: {exc}"
        raise TopicExtractionError(msg) from exc

    if not isinstance(raw, dict):
        msg = "Topic catalog root must be an object mapping category names to lists"
        raise TopicExtractionError(msg)

    topics_by_category: dict[str, list[str]] = {}
    candidate_topics: list[str] = []
    seen_topics: set[str] = set()

    for category, topics in raw.items():
        if not isinstance(category, str) or not category.strip():
            msg = "Topic catalog category names must be non-empty strings"
            raise TopicExtractionError(msg)
        if not isinstance(topics, list):
            msg = f"Topic catalog category {category!r} must contain a list"
            raise TopicExtractionError(msg)

        normalized_topics: list[str] = []
        for topic in topics:
            if not isinstance(topic, str):
                msg = f"Topic catalog category {category!r} contains a non-string topic"
                raise TopicExtractionError(msg)
            normalized_topic = normalize_topic_label(topic)
            if not normalized_topic:
                msg = f"Topic catalog category {category!r} contains a blank topic"
                raise TopicExtractionError(msg)
            # The first occurrence wins; later duplicates are dropped.
            if normalized_topic in seen_topics:
                continue
            seen_topics.add(normalized_topic)
            normalized_topics.append(normalized_topic)
            candidate_topics.append(normalized_topic)

        topics_by_category[category.strip()] = normalized_topics

    return TopicCatalog(
        topics_by_category=topics_by_category,
        candidate_topics=candidate_topics,
    )
|
||||
|
||||
|
||||
def build_topic_extraction_messages(
    *,
    bill: Bill,
    bill_text: str,
    candidate_topics: Sequence[str],
) -> list[dict[str, str]]:
    """Compose the system and user chat messages for topic extraction.

    The user message carries bill metadata, the normalized candidate topics as
    a bulleted list, and the bill text. The system message fixes the task and
    the strict JSON response shape.
    """
    bullet_lines = [f"- {normalize_topic_label(topic)}" for topic in candidate_topics]
    candidate_list = "\n".join(bullet_lines)

    display_title = bill.title_short or bill.title or bill.official_title or ""
    metadata_lines = [
        f"Congress: {bill.congress}",
        f"Bill: {bill.bill_type} {bill.number}",
        f"Title: {display_title}",
        f"Top subject term: {bill.subjects_top_term or ''}",
    ]
    metadata = "\n".join(metadata_lines)

    system_lines = (
        "You extract policy topics from U.S. congressional bills.",
        'For each selected topic, decide whether a Yes/Yea vote on the bill is "for" or "against" that topic.',
        'Use "support_position": "for" when a Yes/Yea vote advances or supports the topic.',
        'Use "support_position": "against" when a Yes/Yea vote restricts, repeals, blocks, or opposes the topic.',
        "Select only topics from the provided candidate topic list.",
        "Omit topics that are not materially addressed by the bill.",
        "Return strict JSON only, with this shape:",
        '{"topics":[{"topic":"candidate topic","support_position":"for","confidence":0.0,"evidence":"short reason"}]}',
    )
    system_prompt = "\n".join(system_lines)

    user_prompt = (
        f"BILL METADATA:\n\n{metadata}\n\n"
        f"CANDIDATE TOPICS:\n\n{candidate_list}\n\n"
        f"BILL TEXT:\n\n{bill_text}"
    )

    system_message = {"role": "system", "content": system_prompt}
    user_message = {"role": "user", "content": user_prompt}
    return [system_message, user_message]
|
||||
|
||||
|
||||
def call_openai_topic_extraction(
    *,
    openai_config: OpenAIConfig,
    messages: list[dict[str, str]],
    model: str = "gpt-5.4-mini",
) -> str:
    """Call the chat-completions endpoint and return the assistant message content.

    Args:
        openai_config: Supplies the endpoint URL, API key, project id and
            request timeout.
        messages: Chat messages to send.
        model: Model name placed in the request body; defaults to the value
            that was previously hard-coded.

    Returns:
        The message content extracted by ``extract_message_content``.

    Raises:
        httpx.HTTPStatusError: If the response status is an error
            (via ``raise_for_status``).
        TopicExtractionError: If the response body has no usable content.
    """
    request_headers = {
        "Authorization": f"Bearer {openai_config.api_key}",
        "OpenAI-Project": openai_config.openai_project_id,
        "Content-Type": "application/json",
    }
    response = httpx.post(
        openai_config.openai_chat_completions_url,
        headers=request_headers,
        json={
            "model": model,
            "messages": messages,
        },
        timeout=openai_config.timeout_seconds,
    )
    response.raise_for_status()
    return extract_message_content(response.json())
|
||||
|
||||
|
||||
def extract_message_content(data: dict[str, Any]) -> str:
    """Pull the text of the first choice out of a chat-completions response body.

    Prefers ``choices[0].message.content``; falls back to ``choices[0].text``.

    Raises:
        TopicExtractionError: If there are no choices, the first choice is not
            an object, or neither field holds a string.
    """
    choices = data.get("choices")
    if not (isinstance(choices, list) and choices):
        raise TopicExtractionError("Chat completion response did not contain choices")

    head = choices[0]
    if not isinstance(head, dict):
        raise TopicExtractionError("Chat completion choice must be an object")

    message = head.get("message")
    content = message.get("content") if isinstance(message, dict) else None
    if isinstance(content, str):
        return content

    legacy_text = head.get("text")
    if isinstance(legacy_text, str):
        return legacy_text

    raise TopicExtractionError("Chat completion response did not contain message content")
|
||||
|
||||
|
||||
def parse_topic_extraction_response(response_text: str) -> list[ExtractedBillTopic]:
    """Turn a raw extraction response into validated, de-duplicated topics.

    Entries are keyed by (normalized topic, support position); when a key
    repeats, the entry with the higher ``_confidence_rank`` is kept.

    Raises:
        TopicExtractionError: If ``topics`` is not a list, an entry is not an
            object, a topic normalizes to blank, or ``support_position`` is
            not a valid ``BillTopicPosition``.
    """
    entries = _load_json_response(response_text).get("topics")
    if not isinstance(entries, list):
        raise TopicExtractionError("Topic extraction response must contain a topics list")

    best_by_key: dict[tuple[str, BillTopicPosition], ExtractedBillTopic] = {}
    for entry in entries:
        if not isinstance(entry, dict):
            raise TopicExtractionError("Topic extraction response topics must be objects")

        label = normalize_topic_label(_extract_topic_label(entry))
        if not label:
            raise TopicExtractionError("Topic extraction response topic must not be blank")

        position_value = entry.get("support_position")
        try:
            position = BillTopicPosition(position_value)
        except ValueError as exc:
            raise TopicExtractionError(f"Invalid support_position: {position_value!r}") from exc

        confidence = _parse_confidence(entry.get("confidence"))
        evidence = entry.get("evidence")
        # Non-string evidence is stringified rather than rejected.
        if not (evidence is None or isinstance(evidence, str)):
            evidence = str(evidence)

        candidate = ExtractedBillTopic(
            topic=label,
            support_position=position,
            confidence=confidence,
            evidence=evidence,
        )
        dedupe_key = (label, position)
        current = best_by_key.get(dedupe_key)
        if current is not None and _confidence_rank(candidate) <= _confidence_rank(current):
            continue
        best_by_key[dedupe_key] = candidate

    return list(best_by_key.values())
|
||||
|
||||
|
||||
def extract_topics_for_bill_text(
    *,
    openai_config: OpenAIConfig,
    bill: Bill,
    text: str,
    candidate_topics: Sequence[str],
) -> list[ExtractedBillTopic]:
    """Run topic extraction on ``text`` and keep only topics from the candidate set.

    Candidates are normalized and de-duplicated, sent to the model in sorted
    order, and any returned topic outside that set is discarded.
    """
    allowed_topics = {normalize_topic_label(candidate) for candidate in candidate_topics}
    prompt_messages = build_topic_extraction_messages(
        bill=bill,
        bill_text=text,
        candidate_topics=sorted(allowed_topics),
    )
    raw_response = call_openai_topic_extraction(
        openai_config=openai_config,
        messages=prompt_messages,
    )
    accepted: list[ExtractedBillTopic] = []
    for extracted in parse_topic_extraction_response(raw_response):
        if extracted.topic in allowed_topics:
            accepted.append(extracted)
    return accepted
|
||||
|
||||
|
||||
def store_bill_topic_result(
    *,
    session: Session,
    bill: Bill,
    topics: Sequence[ExtractedBillTopic],
    replace_existing: bool = True,
) -> None:
    """Add one ``BillTopic`` row per extracted topic for ``bill``.

    When ``replace_existing`` is true, the bill's current topic rows are
    deleted first. Only the normalized topic and support position are stored.
    Nothing is committed here.
    """
    if replace_existing:
        clear_existing = delete(BillTopic).where(BillTopic.bill_id == bill.id)
        session.execute(clear_existing)

    session.add_all(
        BillTopic(
            bill_id=bill.id,
            topic=normalize_topic_label(extracted.topic),
            support_position=extracted.support_position,
        )
        for extracted in topics
    )
|
||||
|
||||
|
||||
def create_select_bills_for_topic_extraction(
    congress: int | None = None,
    bill_ids: list[int] | None = None,
    bill_text_ids: list[int] | None = None,
    with_votes_only: bool = False,
    force: bool = False,
    limit: int | None = None,
) -> Select[tuple[Bill]]:
    """Build the statement selecting bills that are ready for topic extraction.

    A bill qualifies when it has at least one bill_text row with a non-empty
    summary. Optional arguments narrow the selection:

    Args:
        congress: Restrict to one Congress.
        bill_ids: Restrict to these bill ids.
        bill_text_ids: Require a qualifying text among these bill_text ids.
        with_votes_only: Qualifying texts must be targeted by an eligible
            direct, substantive text vote.
        force: When false, skip bills that already have BillTopic rows.
        limit: Maximum number of bills.

    Returns:
        A ``Select`` of ``Bill`` ordered by id, eager-loading only the
        qualifying bill texts.
    """
    belongs_to_bill = BillText.bill_id == Bill.id
    # Conditions on the text itself, independent of which bill owns it.
    text_conditions: list[ColumnElement[bool]] = [
        BillText.summary.is_not(None),
        BillText.summary != "",
    ]
    if with_votes_only:
        eligible_vote_on_text = exists(
            select(VoteTextTarget.vote_id)
            .join(
                VoteClassification,
                VoteClassification.vote_id == VoteTextTarget.vote_id,
            )
            .where(
                VoteTextTarget.voted_text_version_id == BillText.id,
                VoteClassification.subject_type == SubjectType.MEASURE,
                VoteClassification.vote_relationship
                == VoteRelationship.DIRECT_TEXT_VOTE,
                VoteClassification.is_direct_vote_on_legislative_text.is_(True),
                VoteClassification.is_substantive_policy_vote.is_(True),
                VoteClassification.is_special_rule.is_(False),
            )
        )
        text_conditions.append(eligible_vote_on_text)

    has_qualifying_text = exists(
        select(BillText.id).where(belongs_to_bill, *text_conditions)
    )
    # NOTE(review): the eager load is not narrowed by ``bill_text_ids``, so a
    # bill's other qualifying texts are loaded too; confirm callers expect that.
    query = (
        select(Bill)
        .where(has_qualifying_text)
        .options(selectinload(Bill.bill_texts.and_(*text_conditions)))
        .order_by(Bill.id)
    )

    if congress is not None:
        query = query.where(Bill.congress == congress)
    if bill_ids:
        query = query.where(Bill.id.in_(bill_ids))
    if bill_text_ids:
        has_requested_text = exists(
            select(BillText.id).where(
                BillText.bill_id == Bill.id,
                BillText.id.in_(bill_text_ids),
                *text_conditions,
            )
        )
        query = query.where(has_requested_text)
    if not force:
        already_has_topics = exists(
            select(BillTopic.id).where(BillTopic.bill_id == Bill.id)
        )
        query = query.where(~already_has_topics)
    if limit is not None:
        query = query.limit(limit)
    return query
|
||||
|
||||
|
||||
def collect_topic_extraction_diagnostics(
    session: Session,
    *,
    congress: int | None = None,
    bill_ids: list[int] | None = None,
    bill_text_ids: list[int] | None = None,
    with_votes_only: bool = False,
    force: bool = False,
    limit: int | None = None,
) -> TopicExtractionDiagnostics:
    """Count topic extraction inputs for explaining empty selections.

    Args:
        session: Open database session used to run the count queries.
        congress: When given, restrict bill counts to this Congress.
        bill_ids: When given, restrict counts to these ``bill.id`` values.
        bill_text_ids: When given, restrict bill_text counts to these
            ``bill_text.id`` values.
        with_votes_only: When true, only count bill_text rows that are the
            target of a direct, substantive, non-special-rule measure vote.
        force: Forwarded to ``create_select_bills_for_topic_extraction``.
        limit: Forwarded to ``create_select_bills_for_topic_extraction``.

    Returns:
        A ``TopicExtractionDiagnostics`` holding one count per input stage.
    """
    bill_filters: list[ColumnElement[bool]] = []
    bill_text_filters: list[ColumnElement[bool]] = []
    if congress is not None:
        bill_filters.append(Bill.congress == congress)
    if bill_ids:
        bill_filters.append(Bill.id.in_(bill_ids))
        bill_text_filters.append(BillText.bill_id.in_(bill_ids))
    if bill_text_ids:
        bill_text_filters.append(BillText.id.in_(bill_text_ids))
    if with_votes_only:
        bill_text_filters.append(
            exists(
                select(VoteTextTarget.vote_id)
                .join(
                    VoteClassification,
                    VoteClassification.vote_id == VoteTextTarget.vote_id,
                )
                .where(
                    VoteTextTarget.voted_text_version_id == BillText.id,
                    VoteClassification.subject_type == SubjectType.MEASURE,
                    VoteClassification.vote_relationship
                    == VoteRelationship.DIRECT_TEXT_VOTE,
                    VoteClassification.is_direct_vote_on_legislative_text.is_(True),
                    VoteClassification.is_substantive_policy_vote.is_(True),
                    VoteClassification.is_special_rule.is_(False),
                )
            )
        )

    has_summary = (BillText.summary.is_not(None), BillText.summary != "")
    summary_filters = [*bill_text_filters, *has_summary]

    bills_with_summaries = session.scalar(
        select(func.count(func.distinct(Bill.id)))
        .select_from(Bill)
        .join(BillText, BillText.bill_id == Bill.id)
        .where(*bill_filters, *summary_filters)
    )
    selected_bills = session.scalar(
        select(func.count()).select_from(
            create_select_bills_for_topic_extraction(
                congress=congress,
                bill_ids=bill_ids,
                bill_text_ids=bill_text_ids,
                with_votes_only=with_votes_only,
                force=force,
                limit=limit,
            ).subquery()
        )
    )

    # Scope the topic count the same way as the other counts. A global count
    # would make the "matching bills already have topics" hint fire whenever
    # any bill in the database has topics.
    bill_topic_stmt = select(func.count(BillTopic.id))
    if bill_filters:
        bill_topic_stmt = bill_topic_stmt.join(
            Bill, Bill.id == BillTopic.bill_id
        ).where(*bill_filters)
    if bill_text_ids:
        bill_topic_stmt = bill_topic_stmt.where(
            BillTopic.bill_id.in_(
                select(BillText.bill_id).where(BillText.id.in_(bill_text_ids))
            )
        )

    return TopicExtractionDiagnostics(
        bill_rows=session.scalar(select(func.count(Bill.id)).where(*bill_filters)) or 0,
        bill_text_rows=_count_bill_texts(
            session,
            bill_filters=bill_filters,
            bill_text_filters=bill_text_filters,
        ),
        summarized_bill_text_rows=_count_bill_texts(
            session,
            bill_filters=bill_filters,
            bill_text_filters=summary_filters,
        ),
        bills_with_summaries=bills_with_summaries or 0,
        bill_topic_rows=session.scalar(bill_topic_stmt) or 0,
        selected_bills=selected_bills or 0,
    )
|
||||
|
||||
|
||||
def _load_json_response(response_text: str) -> dict[str, Any]:
|
||||
text = response_text.strip()
|
||||
fenced = re.fullmatch(r"```(?:json)?\s*(.*?)\s*```", text, flags=re.DOTALL)
|
||||
if fenced:
|
||||
text = fenced.group(1).strip()
|
||||
|
||||
try:
|
||||
payload = json.loads(text)
|
||||
except json.JSONDecodeError as exc:
|
||||
msg = f"Topic extraction response is not valid JSON: {exc}"
|
||||
raise TopicExtractionError(msg) from exc
|
||||
if not isinstance(payload, dict):
|
||||
msg = "Topic extraction response must be a JSON object"
|
||||
raise TopicExtractionError(msg)
|
||||
return payload
|
||||
|
||||
|
||||
def _parse_confidence(raw: Any) -> float | None:
|
||||
if raw is None:
|
||||
return None
|
||||
try:
|
||||
return float(raw)
|
||||
except (TypeError, ValueError) as exc:
|
||||
msg = f"Invalid confidence: {raw!r}"
|
||||
raise TopicExtractionError(msg) from exc
|
||||
|
||||
|
||||
def _confidence_rank(topic: ExtractedBillTopic) -> tuple[int, float]:
|
||||
if topic.confidence is None:
|
||||
return (0, 0.0)
|
||||
return (1, topic.confidence)
|
||||
|
||||
|
||||
def _extract_topic_label(item: dict[str, Any]) -> str:
|
||||
raw_topic = item.get("topic")
|
||||
if isinstance(raw_topic, str):
|
||||
return raw_topic
|
||||
if isinstance(raw_topic, dict):
|
||||
for key in ("topic", "label", "name", "title"):
|
||||
value = raw_topic.get(key)
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
|
||||
msg = "Topic extraction response topic must be a string"
|
||||
raise TopicExtractionError(msg)
|
||||
|
||||
|
||||
def _count_bill_texts(
    session: Session,
    *,
    bill_filters: Sequence[ColumnElement[bool]],
    bill_text_filters: Sequence[ColumnElement[bool]],
) -> int:
    """Count bill_text rows matching the given bill and bill_text filters.

    The ``Bill`` table is only joined when there are bill-level filters to
    apply. A missing scalar result is reported as ``0``.
    """
    count_stmt = select(func.count(BillText.id))
    if bill_filters:
        joined = count_stmt.join(Bill, Bill.id == BillText.bill_id)
        count_stmt = joined.where(*bill_filters)
    count_stmt = count_stmt.where(*bill_text_filters)
    total = session.scalar(count_stmt)
    return total or 0
|
||||
|
||||
|
||||
def main(
    topics_path: Annotated[
        Path, typer.Option(help="Path to congressional issue topic JSON.")
    ] = DEFAULT_TOPICS_PATH,
    congress: Annotated[
        int | None, typer.Option(help="Only process one Congress.")
    ] = None,
    bill_ids: Annotated[
        list[int] | None,
        typer.Option(
            "--bill-id",
            help="Only process one internal bill.id. Repeat for multiple bills.",
        ),
    ] = None,
    bill_text_ids: Annotated[
        list[int] | None,
        typer.Option(
            "--bill-text-id",
            help="Only process one internal bill_text.id. Repeat for multiple rows.",
        ),
    ] = None,
    with_votes_only: Annotated[
        bool,
        typer.Option(
            # The default is True, so a negative flag is required for the
            # filter to be switchable from the command line at all.
            "--with-votes-only/--no-with-votes-only",
            help="Only process summarized bill_text rows linked to at least one vote.",
        ),
    ] = True,
    limit: Annotated[int | None, typer.Option(help="Maximum rows to process.")] = None,
    force: Annotated[
        bool,
        typer.Option(help="Regenerate topics for bills that already have topics."),
    ] = False,
    dry_run: Annotated[
        bool,
        typer.Option(help="Select bills and print diagnostics without calling OpenAI."),
    ] = False,
    diagnose: Annotated[
        bool,
        typer.Option(help="Log input-stage counts before processing."),
    ] = False,
    log_level: Annotated[str, typer.Option(help="Log level.")] = "INFO",
) -> None:
    """CLI entrypoint for generating and storing bill topics.

    Loads the topic catalog, optionally logs input diagnostics, selects the
    bills in scope, extracts topics for each bill's summary and stores them.
    Bills whose extraction raises ``httpx.HTTPError`` or
    ``TopicExtractionError`` are logged and skipped. With ``dry_run`` the
    function returns after logging diagnostics.
    """
    logging.basicConfig(
        # logging only accepts upper-case level names; tolerate "info" etc.
        level=log_level.upper(),
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    topic_catalog = load_topic_catalog(topics_path)
    logger.info(
        "Loaded %d candidate topics from %s",
        len(topic_catalog.candidate_topics),
        topics_path,
    )

    engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
    with Session(engine) as session:
        if diagnose or dry_run:
            diagnostics = collect_topic_extraction_diagnostics(
                session,
                congress=congress,
                bill_ids=bill_ids,
                bill_text_ids=bill_text_ids,
                with_votes_only=with_votes_only,
                force=force,
                limit=limit,
            )
            _log_topic_extraction_diagnostics(diagnostics)
        if dry_run:
            return

        openai_config = get_openai_config()

        stmt = create_select_bills_for_topic_extraction(
            congress=congress,
            bill_ids=bill_ids,
            bill_text_ids=bill_text_ids,
            with_votes_only=with_votes_only,
            force=force,
            limit=limit,
        )
        bills = session.scalars(stmt).all()
        logger.info("Selected %d bills for topic extraction", len(bills))

        written = 0
        failed = 0
        for bill in bills:
            bill_text = _select_bill_text_for_topic_extraction(bill)
            if bill_text is None:
                logger.warning("Skipping bill id=%s: no usable summary", bill.id)
                continue
            summary = bill_text.summary.strip()

            try:
                extracted_topics = extract_topics_for_bill_text(
                    openai_config=openai_config,
                    bill=bill,
                    text=summary,
                    candidate_topics=topic_catalog.candidate_topics,
                )
            except (httpx.HTTPError, TopicExtractionError):
                failed += 1
                logger.exception(
                    "Skipping bill id=%s after topic extraction failure", bill.id
                )
                continue

            store_bill_topic_result(
                session=session,
                bill=bill,
                topics=extracted_topics,
                replace_existing=True,
            )
            written += 1
            # Batch on stored results, not loop position, so skipped or
            # failed bills cannot cause a periodic commit to be missed.
            if written % 100 == 0:
                session.commit()
            logger.info(
                "Stored %d topics for bill id=%s",
                len(extracted_topics),
                bill.id,
            )

        session.commit()
        logger.info(
            "Done: stored topic results for %d bills; failed %d bills",
            written,
            failed,
        )
|
||||
|
||||
|
||||
def _log_topic_extraction_diagnostics(
    diagnostics: TopicExtractionDiagnostics,
) -> None:
    """Log the diagnostic counts, then at most one warning.

    The warning names the first stage that came up empty, checked in order:
    bills, bill_text rows, summarized bill_text rows, then the final selection.
    """
    logger.info(
        "Topic extraction diagnostics: bill_rows=%d bill_text_rows=%d "
        "summarized_bill_text_rows=%d bills_with_summaries=%d "
        "bill_topic_rows=%d selected_bills=%d",
        diagnostics.bill_rows,
        diagnostics.bill_text_rows,
        diagnostics.summarized_bill_text_rows,
        diagnostics.bills_with_summaries,
        diagnostics.bill_topic_rows,
        diagnostics.selected_bills,
    )

    nothing_selected = diagnostics.selected_bills == 0
    warning: str | None = None
    if diagnostics.bill_rows == 0:
        warning = "No bills matched the topic extraction scope."
    elif diagnostics.bill_text_rows == 0:
        warning = "No bill_text rows matched the topic extraction scope."
    elif diagnostics.summarized_bill_text_rows == 0:
        warning = (
            "No summarized bill_text rows matched the topic extraction scope. "
            "Run pipelines.tools.summarize_bills first."
        )
    elif nothing_selected and diagnostics.bill_topic_rows > 0:
        warning = (
            "No bills selected because matching bills already have topics. "
            "Use --force to regenerate them."
        )
    elif nothing_selected:
        warning = "No bills selected for topic extraction."

    if warning is not None:
        logger.warning(warning)
|
||||
|
||||
|
||||
# Run the Typer CLI only when this file is executed as a script.
if __name__ == "__main__":
    typer.run(main)
|
||||
@@ -0,0 +1,309 @@
|
||||
"""Summarize bill_text rows with GPT-5 and store results in the database."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import tomllib
|
||||
from os import getenv
|
||||
from typing import Annotated, Any
|
||||
|
||||
import httpx
|
||||
import typer
|
||||
from sqlalchemy import Select, exists, or_, select
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
from tiktoken import get_encoding
|
||||
|
||||
|
||||
from pipelines.config import get_config_dir
|
||||
from pipelines.orm.common import get_postgres_engine
|
||||
from pipelines.orm.data_science_dev.congress import (
|
||||
Bill,
|
||||
BillText,
|
||||
SubjectType,
|
||||
VoteClassification,
|
||||
VoteRelationship,
|
||||
VoteTextTarget,
|
||||
)
|
||||
from pipelines.tools.bill_token_compression import compress_bill_text
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Endpoint that call_openai_summary posts chat-completion requests to.
OPENAI_CHAT_COMPLETIONS_URL = "https://api.openai.com/v1/chat/completions"
# Sent as the "OpenAI-Project" request header.
OPENAI_PROJECT_ID = "proj_fQBPEXFgnS87Fk6wZwploFwE"
# Passed as the httpx timeout for each summary request.
REQUEST_TIMEOUT_SECONDS = 60
|
||||
|
||||
|
||||
def load_summarization_prompts(
    section: str = "summarization",
) -> dict[str, str]:
    """Load one table of prompts from ``prompts/summarization_prompts.toml``.

    Args:
        section: Name of the top-level TOML table to return.

    Returns:
        The mapping stored under ``section`` in the prompts file.

    Raises:
        KeyError: If the file has no table named ``section``.
    """
    summarization_prompts = get_config_dir() / "prompts" / "summarization_prompts.toml"

    # TOML is UTF-8 by specification. Reading in binary lets tomllib decode it
    # instead of relying on the platform's default text encoding.
    with summarization_prompts.open("rb") as prompts_file:
        return tomllib.load(prompts_file)[section]
|
||||
|
||||
|
||||
class BillSummaryError(RuntimeError):
    """Signals an invalid bill summary request or an unusable model response."""
|
||||
|
||||
|
||||
def call_openai_summary(
    *,
    model: str,
    messages: list[dict[str, str]],
) -> str:
    """Call GPT and return the assistant message content.

    Args:
        model: Model id sent in the request body.
        messages: Chat messages sent in the request body.

    Returns:
        The message content extracted by ``extract_message_content``.

    Raises:
        BillSummaryError: If ``CLOSEDAI_TOKEN`` is unset, the response body is
            not valid JSON, or it carries no message content.
        httpx.HTTPError: If the request fails or returns an error status.
    """
    api_key = getenv("CLOSEDAI_TOKEN")
    if not api_key:
        msg = "CLOSEDAI_TOKEN is required"
        raise BillSummaryError(msg)

    response = httpx.post(
        OPENAI_CHAT_COMPLETIONS_URL,
        headers={
            "Authorization": f"Bearer {api_key}",
            "OpenAI-Project": OPENAI_PROJECT_ID,
            "Content-Type": "application/json",
        },
        json={
            "model": model,
            "messages": messages,
        },
        timeout=REQUEST_TIMEOUT_SECONDS,
    )
    # Logged before the status check so error bodies are visible too.
    logger.info("response.text=%r", response.text)
    response.raise_for_status()

    try:
        payload = response.json()
    except ValueError as exc:
        # Surface a malformed body as the module's own error type so callers
        # handling BillSummaryError also cover this case.
        msg = "Chat completion response was not valid JSON"
        raise BillSummaryError(msg) from exc
    return extract_message_content(payload)
|
||||
|
||||
|
||||
def build_bill_summary_messages(
    *,
    bill_text: BillText,
    summarization_prompts: dict[str, str],
) -> tuple[list[dict[str, str]], int]:
    """Build the chat messages for one bill_text row and count prompt tokens.

    Args:
        bill_text: Row whose ``text_content`` is compressed and summarized.
        summarization_prompts: Mapping providing ``"system_prompt"`` and a
            ``"user_template"`` with a ``{text_content}`` placeholder.

    Returns:
        A ``(messages, user_prompt_tokens)`` pair: the system and user
        messages, and the ``o200k_base`` token count of the user prompt.

    Raises:
        BillSummaryError: If the row has no text, or nothing remains after
            compression.
    """
    if not bill_text.text_content:
        msg = f"bill_text id={bill_text.id} has no text_content"
        raise BillSummaryError(msg)

    compressed_text = compress_bill_text(bill_text.text_content)
    if not compressed_text:
        msg = f"bill_text id={bill_text.id} has no summarizable text_content"
        raise BillSummaryError(msg)

    user_prompt = summarization_prompts["user_template"].format(
        text_content=compressed_text
    )

    user_prompt_tokens = len(get_encoding("o200k_base").encode(user_prompt))
    logger.info("user_prompt_tokens=%d", user_prompt_tokens)

    messages = [
        {"role": "system", "content": summarization_prompts["system_prompt"]},
        {
            "role": "user",
            "content": user_prompt,
        },
    ]
    return messages, user_prompt_tokens
|
||||
|
||||
|
||||
def summarize_bill_text(
    *,
    model: str,
    bill_text: BillText,
    summarization_prompts: dict[str, str],
    max_prompt_tokens: int = 272_000,
) -> str | None:
    """Generate and return a summary for one bill_text row.

    Args:
        model: Model id passed to ``call_openai_summary``.
        bill_text: Row to summarize.
        summarization_prompts: Prompts passed to ``build_bill_summary_messages``.
        max_prompt_tokens: Largest user prompt, in tokens, that is sent to
            the model.
            NOTE(review): the 272000 default may only apply to gpt-5.4 mini;
            confirm against the model documentation.

    Returns:
        The stripped summary, or ``None`` when the prompt exceeds
        ``max_prompt_tokens`` and the row is skipped.

    Raises:
        BillSummaryError: If the model returns an empty summary.
    """
    messages, user_prompt_tokens = build_bill_summary_messages(
        bill_text=bill_text,
        summarization_prompts=summarization_prompts,
    )
    if user_prompt_tokens > max_prompt_tokens:
        logger.warning(
            "Compressed bill_text id=%s is too long for summarization (%s tokens)",
            bill_text.id,
            user_prompt_tokens,
        )
        return None

    summary = call_openai_summary(
        model=model,
        messages=messages,
    ).strip()
    if not summary:
        msg = f"Model returned an empty summary for bill_text id={bill_text.id}"
        raise BillSummaryError(msg)
    return summary
|
||||
|
||||
|
||||
def store_bill_summary_result(
    *,
    bill_text: BillText,
    summary: str,
    model: str,
) -> None:
    """Record a summary on ``bill_text`` with its model and prompt versions.

    Only attributes on the given object are set; nothing is flushed or
    committed here.
    """
    updates = {
        "summary": summary,
        "summarization_model": model,
        "summarization_system_prompt_version": "v1.2",
        "summarization_user_prompt_version": "v1",
    }
    for attribute, value in updates.items():
        setattr(bill_text, attribute, value)
|
||||
|
||||
|
||||
def create_select_bill_texts_for_summarization(
    congress: int | None = None,
    bill_ids: list[int] | None = None,
    bill_text_ids: list[int] | None = None,
    with_votes_only: bool = False,
    force: bool = False,
    limit: int | None = None,
) -> Select:
    """Build the query for bill_text rows that have text and need a summary.

    Rows are ordered by ``BillText.id`` and must have non-empty
    ``text_content``. Each optional argument narrows the selection further;
    unless ``force`` is set, rows that already have a summary are excluded.
    """
    criteria = [
        BillText.text_content.is_not(None),
        BillText.text_content != "",
    ]
    if congress is not None:
        criteria.append(Bill.congress == congress)
    if bill_ids:
        criteria.append(BillText.bill_id.in_(bill_ids))
    if bill_text_ids:
        criteria.append(BillText.id.in_(bill_text_ids))
    if with_votes_only:
        # Keep only text versions targeted by a direct, substantive,
        # non-special-rule vote on a measure.
        direct_vote_on_text = (
            select(VoteTextTarget.vote_id)
            .join(
                VoteClassification,
                VoteClassification.vote_id == VoteTextTarget.vote_id,
            )
            .where(
                VoteTextTarget.voted_text_version_id == BillText.id,
                VoteClassification.subject_type == SubjectType.MEASURE,
                VoteClassification.vote_relationship
                == VoteRelationship.DIRECT_TEXT_VOTE,
                VoteClassification.is_direct_vote_on_legislative_text.is_(True),
                VoteClassification.is_substantive_policy_vote.is_(True),
                VoteClassification.is_special_rule.is_(False),
            )
        )
        criteria.append(exists(direct_vote_on_text))
    if not force:
        criteria.append(or_(BillText.summary.is_(None), BillText.summary == ""))

    query = (
        select(BillText)
        .join(Bill, Bill.id == BillText.bill_id)
        .where(*criteria)
        .options(selectinload(BillText.bill))
        .order_by(BillText.id)
    )
    return query if limit is None else query.limit(limit)
|
||||
|
||||
|
||||
def extract_message_content(data: dict[str, Any]) -> str:
    """Return the text of the first choice in a chat-completions body.

    The first choice's ``message.content`` is preferred; a string ``text``
    field on the choice is used as a fallback.

    Raises:
        BillSummaryError: If ``choices`` is missing or empty, the first choice
            is not an object, or neither field holds a string.
    """
    choice_list = data.get("choices")
    if not (isinstance(choice_list, list) and choice_list):
        msg = "Chat completion response did not contain choices"
        raise BillSummaryError(msg)

    head = choice_list[0]
    if not isinstance(head, dict):
        msg = "Chat completion choice must be an object"
        raise BillSummaryError(msg)

    message = head.get("message")
    content = message.get("content") if isinstance(message, dict) else None
    if isinstance(content, str):
        return content

    fallback_text = head.get("text")
    if isinstance(fallback_text, str):
        return fallback_text

    msg = "Chat completion response did not contain message content"
    raise BillSummaryError(msg)
|
||||
|
||||
|
||||
def main(
    model: Annotated[str, typer.Option(help="OpenAI model id.")] = "gpt-5.4-mini",
    congress: Annotated[
        int | None, typer.Option(help="Only process one Congress.")
    ] = None,
    bill_ids: Annotated[
        list[int] | None,
        typer.Option(
            "--bill-id",
            help="Only process one internal bill.id. Repeat for multiple bills.",
        ),
    ] = None,
    bill_text_ids: Annotated[
        list[int] | None,
        typer.Option(
            "--bill-text-id",
            help="Only process one internal bill_text.id. Repeat for multiple rows.",
        ),
    ] = None,
    with_votes_only: Annotated[
        bool,
        typer.Option(
            "--with-votes-only",
            help="Only process bill_text rows linked to at least one vote.",
        ),
    ] = False,
    limit: Annotated[int | None, typer.Option(help="Maximum rows to process.")] = None,
    force: Annotated[
        bool,
        typer.Option(help="Regenerate summaries for rows that already have a summary."),
    ] = False,
    dry_run: Annotated[
        bool, typer.Option(help="Print summaries without writing them to the database.")
    ] = False,
    log_level: Annotated[str, typer.Option(help="Log level.")] = "INFO",
) -> None:
    """CLI entrypoint for generating and storing bill summaries.

    Selects the bill_text rows in scope, summarizes each one, and commits
    every stored summary immediately so a later failure cannot discard it.
    Rows whose request raises ``httpx.HTTPError`` or ``BillSummaryError`` are
    logged and skipped. With ``dry_run`` each summary is printed and nothing
    is written to the database.

    Raises:
        typer.BadParameter: If ``CLOSEDAI_TOKEN`` is not set.
    """
    logging.basicConfig(
        # logging only accepts upper-case level names; tolerate "info" etc.
        level=log_level.upper(),
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )
    if not getenv("CLOSEDAI_TOKEN"):
        message = "CLOSEDAI_TOKEN is required"
        raise typer.BadParameter(message)

    summarization_prompts = load_summarization_prompts()
    engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
    written = 0
    failed = 0
    with Session(engine) as session:
        stmt = create_select_bill_texts_for_summarization(
            congress=congress,
            bill_ids=bill_ids,
            bill_text_ids=bill_text_ids,
            with_votes_only=with_votes_only,
            force=force,
            limit=limit,
        )
        bill_texts = session.scalars(stmt).all()
        logger.info("Selected %d bill_text rows for summarization", len(bill_texts))

        for bill_text in bill_texts:
            try:
                summary = summarize_bill_text(
                    model=model,
                    bill_text=bill_text,
                    summarization_prompts=summarization_prompts,
                )
            except (httpx.HTTPError, BillSummaryError):
                # One bad row or transient API failure must not abort the
                # whole batch; record it and move on.
                failed += 1
                logger.exception(
                    "Skipping bill_text id=%s after summarization failure",
                    bill_text.id,
                )
                continue
            if summary is None:
                logger.warning("Skipping bill_text id=%s", bill_text.id)
                continue
            if dry_run:
                # Honor --dry-run: show the result but leave the row untouched.
                typer.echo(f"bill_text id={bill_text.id}\n{summary}\n")
                continue

            store_bill_summary_result(
                bill_text=bill_text,
                summary=summary,
                model=model,
            )
            session.commit()
            written += 1
            logger.info("Stored summary for bill_text id=%s", bill_text.id)

    if failed:
        logger.warning("Failed to summarize %d bill_text rows", failed)
    logger.info("Done: stored %d summaries", written)
|
||||
|
||||
|
||||
def cli() -> None:
    """Run ``main`` as a Typer command-line application."""
    typer.run(main)
|
||||
|
||||
|
||||
# Invoke the Typer entry point only when this file is executed as a script.
if __name__ == "__main__":
    cli()
|
||||
Reference in New Issue
Block a user